misc.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. # ---------------------------------------------------------------------------
  2. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  3. # ---------------------------------------------------------------------------
  4. import time
  5. import datetime
  6. import numpy as np
  7. from typing import List
  8. from thop import profile
  9. from collections import defaultdict, deque
  10. import torch
  11. import torch.nn as nn
  12. import torch.nn.functional as F
  13. import torch.distributed as dist
  14. from torch import Tensor
  15. from .distributed_utils import is_dist_avail_and_initialized
  16. # ---------------------------- Train tools ----------------------------
  17. class SmoothedValue(object):
  18. """Track a series of values and provide access to smoothed values over a
  19. window or the global series average.
  20. """
  21. def __init__(self, window_size=20, fmt=None):
  22. if fmt is None:
  23. fmt = "{median:.4f} ({global_avg:.4f})"
  24. self.deque = deque(maxlen=window_size)
  25. self.total = 0.0
  26. self.count = 0
  27. self.fmt = fmt
  28. def update(self, value, n=1):
  29. self.deque.append(value)
  30. self.count += n
  31. self.total += value * n
  32. def synchronize_between_processes(self):
  33. """
  34. Warning: does not synchronize the deque!
  35. """
  36. if not is_dist_avail_and_initialized():
  37. return
  38. t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
  39. dist.barrier()
  40. dist.all_reduce(t)
  41. t = t.tolist()
  42. self.count = int(t[0])
  43. self.total = t[1]
  44. @property
  45. def median(self):
  46. d = torch.tensor(list(self.deque))
  47. return d.median().item()
  48. @property
  49. def avg(self):
  50. d = torch.tensor(list(self.deque), dtype=torch.float32)
  51. return d.mean().item()
  52. @property
  53. def global_avg(self):
  54. return self.total / self.count
  55. @property
  56. def max(self):
  57. return max(self.deque)
  58. @property
  59. def value(self):
  60. return self.deque[-1]
  61. def __str__(self):
  62. return self.fmt.format(
  63. median=self.median,
  64. avg=self.avg,
  65. global_avg=self.global_avg,
  66. max=self.max,
  67. value=self.value)
  68. class MetricLogger(object):
  69. def __init__(self, delimiter="\t"):
  70. self.meters = defaultdict(SmoothedValue)
  71. self.delimiter = delimiter
  72. def update(self, **kwargs):
  73. for k, v in kwargs.items():
  74. if isinstance(v, torch.Tensor):
  75. v = v.item()
  76. assert isinstance(v, (float, int))
  77. self.meters[k].update(v)
  78. def __getattr__(self, attr):
  79. if attr in self.meters:
  80. return self.meters[attr]
  81. if attr in self.__dict__:
  82. return self.__dict__[attr]
  83. raise AttributeError("'{}' object has no attribute '{}'".format(
  84. type(self).__name__, attr))
  85. def __str__(self):
  86. loss_str = []
  87. for name, meter in self.meters.items():
  88. loss_str.append(
  89. "{}: {}".format(name, str(meter))
  90. )
  91. return self.delimiter.join(loss_str)
  92. def synchronize_between_processes(self):
  93. for meter in self.meters.values():
  94. meter.synchronize_between_processes()
  95. def add_meter(self, name, meter):
  96. self.meters[name] = meter
  97. def log_every(self, iterable, print_freq, header=None):
  98. i = 0
  99. if not header:
  100. header = ''
  101. start_time = time.time()
  102. end = time.time()
  103. iter_time = SmoothedValue(fmt='{avg:.4f}')
  104. data_time = SmoothedValue(fmt='{avg:.4f}')
  105. space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
  106. if torch.cuda.is_available():
  107. log_msg = self.delimiter.join([
  108. header,
  109. '[{0' + space_fmt + '}/{1}]',
  110. 'eta: {eta}',
  111. '{meters}',
  112. 'time: {time}',
  113. 'data: {data}',
  114. 'max mem: {memory:.0f}'
  115. ])
  116. else:
  117. log_msg = self.delimiter.join([
  118. header,
  119. '[{0' + space_fmt + '}/{1}]',
  120. 'eta: {eta}',
  121. '{meters}',
  122. 'time: {time}',
  123. 'data: {data}'
  124. ])
  125. MB = 1024.0 * 1024.0
  126. for obj in iterable:
  127. data_time.update(time.time() - end)
  128. yield obj
  129. iter_time.update(time.time() - end)
  130. if i % print_freq == 0 or i == len(iterable) - 1:
  131. eta_seconds = iter_time.global_avg * (len(iterable) - i)
  132. eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
  133. if torch.cuda.is_available():
  134. print(log_msg.format(
  135. i, len(iterable), eta=eta_string,
  136. meters=str(self),
  137. time=str(iter_time), data=str(data_time),
  138. memory=torch.cuda.max_memory_allocated() / MB))
  139. else:
  140. print(log_msg.format(
  141. i, len(iterable), eta=eta_string,
  142. meters=str(self),
  143. time=str(iter_time), data=str(data_time)))
  144. i += 1
  145. end = time.time()
  146. total_time = time.time() - start_time
  147. total_time_str = str(datetime.timedelta(seconds=int(total_time)))
  148. print('{} Total time: {} ({:.4f} s / it)'.format(
  149. header, total_time_str, total_time / len(iterable)))
  150. class SinkhornDistance(torch.nn.Module):
  151. def __init__(self, eps=1e-3, max_iter=100, reduction='none'):
  152. super(SinkhornDistance, self).__init__()
  153. self.eps = eps
  154. self.max_iter = max_iter
  155. self.reduction = reduction
  156. def forward(self, mu, nu, C):
  157. u = torch.ones_like(mu)
  158. v = torch.ones_like(nu)
  159. # Sinkhorn iterations
  160. for i in range(self.max_iter):
  161. v = self.eps * \
  162. (torch.log(
  163. nu + 1e-8) - torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)) + v
  164. u = self.eps * \
  165. (torch.log(
  166. mu + 1e-8) - torch.logsumexp(self.M(C, u, v), dim=-1)) + u
  167. U, V = u, v
  168. # Transport plan pi = diag(a)*K*diag(b)
  169. pi = torch.exp(
  170. self.M(C, U, V)).detach()
  171. # Sinkhorn distance
  172. cost = torch.sum(
  173. pi * C, dim=(-2, -1))
  174. return cost, pi
  175. def M(self, C, u, v):
  176. '''
  177. "Modified cost for logarithmic updates"
  178. "$M_{ij} = (-c_{ij} + u_i + v_j) / epsilon$"
  179. '''
  180. return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps
  181. # ---------------------------- Dataloader tools ----------------------------
  182. def _max_by_axis(the_list):
  183. # type: (List[List[int]]) -> List[int]
  184. maxes = the_list[0]
  185. for sublist in the_list[1:]:
  186. for index, item in enumerate(sublist):
  187. maxes[index] = max(maxes[index], item)
  188. return maxes
  189. def batch_tensor_from_tensor_list(tensor_list: List[Tensor]):
  190. # TODO make this more general
  191. if tensor_list[0].ndim == 3:
  192. # TODO make it support different-sized images
  193. max_size = _max_by_axis([list(img.shape) for img in tensor_list])
  194. # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
  195. batch_shape = [len(tensor_list)] + max_size
  196. b, c, h, w = batch_shape
  197. dtype = tensor_list[0].dtype
  198. device = tensor_list[0].device
  199. tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
  200. mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
  201. for img, pad_img, m in zip(tensor_list, tensor, mask):
  202. pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
  203. m[: img.shape[1], :img.shape[2]] = False
  204. else:
  205. raise ValueError('not supported')
  206. return tensor, mask
  207. def collate_fn(batch):
  208. batch = list(zip(*batch))
  209. batch[0] = batch_tensor_from_tensor_list(batch[0])
  210. return tuple(batch)
  211. # ---------------------------- For Model ----------------------------
  212. ## fuse Conv & BN layer
  213. def fuse_conv_bn(module):
  214. """Recursively fuse conv and bn in a module.
  215. During inference, the functionary of batch norm layers is turned off
  216. but only the mean and var alone channels are used, which exposes the
  217. chance to fuse it with the preceding conv layers to save computations and
  218. simplify network structures.
  219. Args:
  220. module (nn.Module): Module to be fused.
  221. Returns:
  222. nn.Module: Fused module.
  223. """
  224. last_conv = None
  225. last_conv_name = None
  226. def _fuse_conv_bn(conv, bn):
  227. """Fuse conv and bn into one module.
  228. Args:
  229. conv (nn.Module): Conv to be fused.
  230. bn (nn.Module): BN to be fused.
  231. Returns:
  232. nn.Module: Fused module.
  233. """
  234. conv_w = conv.weight
  235. conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
  236. bn.running_mean)
  237. factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
  238. conv.weight = nn.Parameter(conv_w *
  239. factor.reshape([conv.out_channels, 1, 1, 1]))
  240. conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias)
  241. return conv
  242. for name, child in module.named_children():
  243. if isinstance(child,
  244. (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)):
  245. if last_conv is None: # only fuse BN that is after Conv
  246. continue
  247. fused_conv = _fuse_conv_bn(last_conv, child)
  248. module._modules[last_conv_name] = fused_conv
  249. # To reduce changes, set BN as Identity instead of deleting it.
  250. module._modules[name] = nn.Identity()
  251. last_conv = None
  252. elif isinstance(child, nn.Conv2d):
  253. last_conv = child
  254. last_conv_name = name
  255. else:
  256. fuse_conv_bn(child)
  257. return module
  258. ## compute FLOPs & Parameters
  259. def compute_flops(model, min_size, max_size, device):
  260. if isinstance(min_size[0], List):
  261. min_size, max_size = min_size[0]
  262. else:
  263. min_size = min_size[0]
  264. x = torch.randn(1, 3, min_size, max_size).to(device)
  265. print('==============================')
  266. flops, params = profile(model, inputs=(x, ), verbose=False)
  267. print('GFLOPs : {:.2f}'.format(flops / 1e9))
  268. print('Params : {:.2f} M'.format(params / 1e6))
  269. ## load trained weight
  270. def load_weight(model, path_to_ckpt, fuse_cbn=False):
  271. # check ckpt file
  272. if path_to_ckpt is None:
  273. print('no weight file ...')
  274. else:
  275. checkpoint = torch.load(path_to_ckpt, map_location='cpu')
  276. print('--------------------------------------')
  277. print('Best model infor:')
  278. print('Epoch: {}'.format(checkpoint.pop("epoch")))
  279. print('mAP: {}'.format(checkpoint.pop("mAP")))
  280. print('--------------------------------------')
  281. checkpoint_state_dict = checkpoint.pop("model")
  282. model.load_state_dict(checkpoint_state_dict)
  283. print('Finished loading model!')
  284. # fuse conv & bn
  285. if fuse_cbn:
  286. print('Fusing Conv & BN ...')
  287. model = fuse_conv_bn(model)
  288. return model
  289. ## gradient clip
  290. def get_total_grad_norm(parameters, norm_type=2):
  291. parameters = list(filter(lambda p: p.grad is not None, parameters))
  292. norm_type = float(norm_type)
  293. device = parameters[0].grad.device
  294. total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]),
  295. norm_type)
  296. return total_norm
  297. # ---------------------------- For Loss ----------------------------
  298. ## focal loss
  299. def sigmoid_focal_loss(inputs, targets, alpha: float = 0.25, gamma: float = 2):
  300. """
  301. Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
  302. Args:
  303. inputs: A float tensor of arbitrary shape.
  304. The predictions for each example.
  305. targets: A float tensor with the same shape as inputs. Stores the binary
  306. classification label for each element in inputs
  307. (0 for the negative class and 1 for the positive class).
  308. alpha: (optional) Weighting factor in range (0,1) to balance
  309. positive vs negative examples. Default = -1 (no weighting).
  310. gamma: Exponent of the modulating factor (1 - p_t) to
  311. balance easy vs hard examples.
  312. Returns:
  313. Loss tensor
  314. """
  315. prob = inputs.sigmoid()
  316. ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
  317. p_t = prob * targets + (1 - prob) * (1 - targets)
  318. loss = ce_loss * ((1 - p_t) ** gamma)
  319. if alpha >= 0:
  320. alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
  321. loss = alpha_t * loss
  322. return loss
  323. # ---------------------------- NMS ----------------------------
  324. def nms(bboxes, scores, nms_thresh):
  325. """"Pure Python NMS."""
  326. x1 = bboxes[:, 0] #xmin
  327. y1 = bboxes[:, 1] #ymin
  328. x2 = bboxes[:, 2] #xmax
  329. y2 = bboxes[:, 3] #ymax
  330. areas = (x2 - x1) * (y2 - y1)
  331. order = scores.argsort()[::-1]
  332. keep = []
  333. while order.size > 0:
  334. i = order[0]
  335. keep.append(i)
  336. # compute iou
  337. xx1 = np.maximum(x1[i], x1[order[1:]])
  338. yy1 = np.maximum(y1[i], y1[order[1:]])
  339. xx2 = np.minimum(x2[i], x2[order[1:]])
  340. yy2 = np.minimum(y2[i], y2[order[1:]])
  341. w = np.maximum(1e-10, xx2 - xx1)
  342. h = np.maximum(1e-10, yy2 - yy1)
  343. inter = w * h
  344. iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-14)
  345. #reserve all the boundingbox whose ovr less than thresh
  346. inds = np.where(iou <= nms_thresh)[0]
  347. order = order[inds + 1]
  348. return keep
  349. def multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh):
  350. # nms
  351. keep = nms(bboxes, scores, nms_thresh)
  352. scores = scores[keep]
  353. labels = labels[keep]
  354. bboxes = bboxes[keep]
  355. return scores, labels, bboxes
  356. def multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes):
  357. # nms
  358. keep = np.zeros(len(bboxes), dtype=np.int32)
  359. for i in range(num_classes):
  360. inds = np.where(labels == i)[0]
  361. if len(inds) == 0:
  362. continue
  363. c_bboxes = bboxes[inds]
  364. c_scores = scores[inds]
  365. c_keep = nms(c_bboxes, c_scores, nms_thresh)
  366. keep[inds[c_keep]] = 1
  367. keep = np.where(keep > 0)
  368. scores = scores[keep]
  369. labels = labels[keep]
  370. bboxes = bboxes[keep]
  371. return scores, labels, bboxes
  372. def multiclass_nms(scores, labels, bboxes, nms_thresh, num_classes, class_agnostic=False):
  373. if class_agnostic:
  374. return multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh)
  375. else:
  376. return multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes)