# ---------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ---------------------------------------------------------------------------
import time
import datetime
import numpy as np
from typing import List
from thop import profile
from collections import defaultdict, deque

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch import Tensor

from .distributed_utils import is_dist_avail_and_initialized


# ---------------------------- Train tools ----------------------------
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)
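
# Minimal usage sketch for SmoothedValue (hypothetical values, not part of the
# original training code): windowed statistics track only the last
# `window_size` updates, while `global_avg` covers everything.
#
#   meter = SmoothedValue(window_size=3, fmt='{value:.2f} ({avg:.2f})')
#   for v in (1.0, 2.0, 3.0, 4.0):
#       meter.update(v)
#   str(meter)          # -> '4.00 (3.00)': window holds the last 3 values
#   meter.global_avg    # -> 2.5: global average over all 4 updates
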
class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        if torch.cuda.is_available():
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}',
                'max mem: {memory:.0f}'
            ])
        else:
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}'
            ])
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))
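
# Typical training-loop sketch for MetricLogger (model, optimizer and
# data_loader are hypothetical): log_every wraps the dataloader and prints
# throughput, ETA, and all registered meters every `print_freq` iterations.
#
#   metric_logger = MetricLogger(delimiter="  ")
#   metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
#   for images, targets in metric_logger.log_every(data_loader, 10, header='Epoch: [0]'):
#       loss = model(images, targets)
#       metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]['lr'])
#   metric_logger.synchronize_between_processes()
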
class SinkhornDistance(torch.nn.Module):
    def __init__(self, eps=1e-3, max_iter=100, reduction='none'):
        super(SinkhornDistance, self).__init__()
        self.eps = eps
        self.max_iter = max_iter
        self.reduction = reduction

    def forward(self, mu, nu, C):
        u = torch.ones_like(mu)
        v = torch.ones_like(nu)

        # Sinkhorn iterations
        for i in range(self.max_iter):
            v = self.eps * (torch.log(nu + 1e-8) -
                            torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)) + v
            u = self.eps * (torch.log(mu + 1e-8) -
                            torch.logsumexp(self.M(C, u, v), dim=-1)) + u

        U, V = u, v
        # Transport plan pi = diag(a) * K * diag(b)
        pi = torch.exp(self.M(C, U, V)).detach()
        # Sinkhorn distance
        cost = torch.sum(pi * C, dim=(-2, -1))
        return cost, pi

    def M(self, C, u, v):
        r"""Modified cost for logarithmic updates:
        $M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$
        """
        return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps
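
# Shape sketch for SinkhornDistance.forward (assumed batched inputs): mu and
# nu are the source/target marginal weights and C is the pairwise cost matrix
# between their supports. Toy tensors, not part of the original file:
#
#   sinkhorn = SinkhornDistance(eps=1e-3, max_iter=100)
#   mu = torch.full((2, 5), 1.0 / 5)    # [B, N] source marginal
#   nu = torch.full((2, 7), 1.0 / 7)    # [B, M] target marginal
#   C = torch.rand(2, 5, 7)             # [B, N, M] cost matrix
#   cost, pi = sinkhorn(mu, nu, C)      # cost: [B], pi: [B, N, M]
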
# ---------------------------- Dataloader tools ----------------------------
def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    maxes = the_list[0]
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes


def batch_tensor_from_tensor_list(tensor_list: List[Tensor]):
    # TODO make this more general
    if tensor_list[0].ndim == 3:
        # TODO make it support different-sized images
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], : img.shape[2]] = False
    else:
        raise ValueError('not supported')
    return tensor, mask


def collate_fn(batch):
    batch = list(zip(*batch))
    batch[0] = batch_tensor_from_tensor_list(batch[0])
    return tuple(batch)
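
# Usage sketch: collate_fn pads every image in the batch up to the largest
# H and W and returns a bool mask marking padded pixels (True = padding).
# A dataset yielding (image, target) pairs is assumed:
#
#   loader = torch.utils.data.DataLoader(dataset, batch_size=8,
#                                        collate_fn=collate_fn)
#   (images, masks), targets = next(iter(loader))
#   # images: [8, 3, H_max, W_max], masks: [8, H_max, W_max]
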
# ---------------------------- For Model ----------------------------
## fuse Conv & BN layer
def fuse_conv_bn(module):
    """Recursively fuse conv and bn in a module.

    During inference, batch norm layers no longer update their statistics;
    only the per-channel running mean and variance are used, which makes it
    possible to fuse them into the preceding conv layers to save computation
    and simplify the network structure.

    Args:
        module (nn.Module): Module to be fused.
    Returns:
        nn.Module: Fused module.
    """
    last_conv = None
    last_conv_name = None

    def _fuse_conv_bn(conv, bn):
        """Fuse conv and bn into one module.

        Args:
            conv (nn.Module): Conv to be fused.
            bn (nn.Module): BN to be fused.
        Returns:
            nn.Module: Fused module.
        """
        conv_w = conv.weight
        conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
            bn.running_mean)

        factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        conv.weight = nn.Parameter(conv_w *
                                   factor.reshape([conv.out_channels, 1, 1, 1]))
        conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias)
        return conv

    for name, child in module.named_children():
        if isinstance(child,
                      (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)):
            if last_conv is None:  # only fuse BN that is after Conv
                continue
            fused_conv = _fuse_conv_bn(last_conv, child)
            module._modules[last_conv_name] = fused_conv
            # To reduce changes, set BN as Identity instead of deleting it.
            module._modules[name] = nn.Identity()
            last_conv = None
        elif isinstance(child, nn.Conv2d):
            last_conv = child
            last_conv_name = name
        else:
            fuse_conv_bn(child)
    return module
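
# Sanity-check sketch for fuse_conv_bn: in eval mode the fused network should
# produce (numerically) the same outputs as the original Conv+BN pair, since
# the fusion folds BN's affine transform into the conv weight and bias.
#
#   net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8)).eval()
#   x = torch.randn(1, 3, 32, 32)
#   y_ref = net(x)
#   fused = fuse_conv_bn(net)    # the BN module becomes nn.Identity()
#   assert torch.allclose(y_ref, fused(x), atol=1e-5)
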
## compute FLOPs & Parameters
def compute_flops(model, min_size, max_size, device):
    # min_size is a list whose first element is either an int
    # or an (h, w) pair; max_size is ignored in the latter case.
    if isinstance(min_size[0], list):
        min_size, max_size = min_size[0]
    else:
        min_size = min_size[0]
    x = torch.randn(1, 3, min_size, max_size).to(device)
    print('==============================')
    flops, params = profile(model, inputs=(x, ), verbose=False)
    print('GFLOPs : {:.2f}'.format(flops / 1e9))
    print('Params : {:.2f} M'.format(params / 1e6))
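
# Usage sketch (assuming thop is installed and the model takes a single image
# tensor); the 800/1333 resize convention here is only an illustrative guess:
#
#   compute_flops(model, min_size=[800], max_size=1333, device='cuda')
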
## load trained weight
def load_weight(model, path_to_ckpt, fuse_cbn=False):
    # check ckpt file
    if path_to_ckpt is None:
        print('no weight file ...')
    else:
        checkpoint = torch.load(path_to_ckpt, map_location='cpu')
        if "epoch" in checkpoint and "mAP" in checkpoint:
            print('--------------------------------------')
            print('Best model info:')
            print('Epoch: {}'.format(checkpoint.pop("epoch")))
            print('mAP: {}'.format(checkpoint.pop("mAP")))
            print('--------------------------------------')
        checkpoint_state_dict = checkpoint.pop("model")
        model.load_state_dict(checkpoint_state_dict)
        print('Finished loading model!')

    # fuse conv & bn
    if fuse_cbn:
        print('Fusing Conv & BN ...')
        model = fuse_conv_bn(model)

    return model
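
# Usage sketch (the checkpoint path is hypothetical); checkpoints saved by the
# training script are expected to hold a dict with a "model" state dict and,
# optionally, "epoch" and "mAP" entries:
#
#   model = load_weight(model, 'weights/best_model.pth', fuse_cbn=True)
#   model.eval()
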
## gradient clip
def get_total_grad_norm(parameters, norm_type=2):
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    device = parameters[0].grad.device
    total_norm = torch.norm(
        torch.stack([torch.norm(p.grad.detach(), norm_type).to(device)
                     for p in parameters]),
        norm_type)
    return total_norm
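
# Usage sketch inside a training step (max_norm is a hypothetical config
# value): report the total gradient norm, falling back to this helper when
# clipping via torch.nn.utils.clip_grad_norm_ is disabled.
#
#   loss.backward()
#   if max_norm > 0:
#       total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
#   else:
#       total_norm = get_total_grad_norm(model.parameters(), norm_type=2)
#   optimizer.step()
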
# ---------------------------- For Loss ----------------------------
## focal loss
def sigmoid_focal_loss(inputs, targets, alpha: float = 0.25, gamma: float = 2):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs: A float tensor of arbitrary shape.
            The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
            classification label for each element in inputs
            (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0, 1) to balance
            positive vs negative examples. A negative value disables the
            weighting. Default = 0.25.
        gamma: Exponent of the modulating factor (1 - p_t) to
            balance easy vs hard examples.
    Returns:
        Loss tensor
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss
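
# Usage sketch: per-element focal loss on a dense prediction map (the shapes
# are hypothetical). No reduction is applied here; summing over positives and
# normalizing is left to the caller.
#
#   logits = torch.randn(2, 100, 80)    # [B, num_boxes, num_classes]
#   targets = torch.zeros_like(logits)
#   targets[0, 0, 5] = 1.0              # one positive label
#   loss = sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2.0)
#   loss.shape                          # same shape as the inputs
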
# ---------------------------- NMS ----------------------------
def nms(bboxes, scores, nms_thresh):
    """Pure Python NMS."""
    x1 = bboxes[:, 0]  # xmin
    y1 = bboxes[:, 1]  # ymin
    x2 = bboxes[:, 2]  # xmax
    y2 = bboxes[:, 3]  # ymax

    areas = (x2 - x1) * (y2 - y1)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        # compute iou between the top-scoring box and the rest
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(1e-10, xx2 - xx1)
        h = np.maximum(1e-10, yy2 - yy1)
        inter = w * h

        iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-14)
        # keep only the boxes whose IoU with box i is below the threshold
        inds = np.where(iou <= nms_thresh)[0]
        order = order[inds + 1]

    return keep
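
# Usage sketch for the pure NumPy NMS (toy boxes in xyxy format): the two
# heavily overlapping boxes collapse to the higher-scoring one.
#
#   boxes = np.array([[0, 0, 10, 10], [1, 1, 11, 11], [50, 50, 60, 60]],
#                    dtype=np.float32)
#   scores = np.array([0.9, 0.8, 0.7], dtype=np.float32)
#   nms(boxes, scores, nms_thresh=0.5)    # -> [0, 2]
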
def multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh):
    # nms over all boxes at once, ignoring class labels
    keep = nms(bboxes, scores, nms_thresh)
    scores = scores[keep]
    labels = labels[keep]
    bboxes = bboxes[keep]
    return scores, labels, bboxes


def multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes):
    # nms applied independently within each class
    keep = np.zeros(len(bboxes), dtype=np.int32)
    for i in range(num_classes):
        inds = np.where(labels == i)[0]
        if len(inds) == 0:
            continue
        c_bboxes = bboxes[inds]
        c_scores = scores[inds]
        c_keep = nms(c_bboxes, c_scores, nms_thresh)
        keep[inds[c_keep]] = 1

    keep = np.where(keep > 0)
    scores = scores[keep]
    labels = labels[keep]
    bboxes = bboxes[keep]
    return scores, labels, bboxes


def multiclass_nms(scores, labels, bboxes, nms_thresh, num_classes, class_agnostic=False):
    if class_agnostic:
        return multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh)
    else:
        return multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes)
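
# Usage sketch: class-aware NMS only suppresses boxes within the same class,
# while class-agnostic NMS treats all boxes as one pool. Inputs are assumed to
# be the flattened per-box scores, integer labels, and xyxy boxes remaining
# after score thresholding:
#
#   scores, labels, bboxes = multiclass_nms(scores, labels, bboxes,
#                                           nms_thresh=0.5, num_classes=80,
#                                           class_agnostic=False)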