misc.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572
  1. # ---------------------------------------------------------------------------
  2. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  3. # ---------------------------------------------------------------------------
  4. import time
  5. import math
  6. import datetime
  7. import numpy as np
  8. from typing import List
  9. from thop import profile
  10. from copy import deepcopy
  11. from collections import defaultdict, deque
  12. import torch
  13. import torch.nn as nn
  14. import torch.nn.functional as F
  15. import torch.distributed as dist
  16. from torch import Tensor
  17. from .distributed_utils import is_dist_avail_and_initialized
  18. # ---------------------------- Train tools ----------------------------
  19. class SmoothedValue(object):
  20. """Track a series of values and provide access to smoothed values over a
  21. window or the global series average.
  22. """
  23. def __init__(self, window_size=20, fmt=None):
  24. if fmt is None:
  25. fmt = "{median:.4f} ({global_avg:.4f})"
  26. self.deque = deque(maxlen=window_size)
  27. self.total = 0.0
  28. self.count = 0
  29. self.fmt = fmt
  30. def update(self, value, n=1):
  31. self.deque.append(value)
  32. self.count += n
  33. self.total += value * n
  34. def synchronize_between_processes(self):
  35. """
  36. Warning: does not synchronize the deque!
  37. """
  38. if not is_dist_avail_and_initialized():
  39. return
  40. t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
  41. dist.barrier()
  42. dist.all_reduce(t)
  43. t = t.tolist()
  44. self.count = int(t[0])
  45. self.total = t[1]
  46. @property
  47. def median(self):
  48. d = torch.tensor(list(self.deque))
  49. return d.median().item()
  50. @property
  51. def avg(self):
  52. d = torch.tensor(list(self.deque), dtype=torch.float32)
  53. return d.mean().item()
  54. @property
  55. def global_avg(self):
  56. return self.total / self.count
  57. @property
  58. def max(self):
  59. return max(self.deque)
  60. @property
  61. def value(self):
  62. return self.deque[-1]
  63. def __str__(self):
  64. return self.fmt.format(
  65. median=self.median,
  66. avg=self.avg,
  67. global_avg=self.global_avg,
  68. max=self.max,
  69. value=self.value)
  70. class MetricLogger(object):
  71. def __init__(self, delimiter="\t"):
  72. self.meters = defaultdict(SmoothedValue)
  73. self.delimiter = delimiter
  74. def update(self, **kwargs):
  75. for k, v in kwargs.items():
  76. if isinstance(v, torch.Tensor):
  77. v = v.item()
  78. assert isinstance(v, (float, int))
  79. self.meters[k].update(v)
  80. def __getattr__(self, attr):
  81. if attr in self.meters:
  82. return self.meters[attr]
  83. if attr in self.__dict__:
  84. return self.__dict__[attr]
  85. raise AttributeError("'{}' object has no attribute '{}'".format(
  86. type(self).__name__, attr))
  87. def __str__(self):
  88. loss_str = []
  89. for name, meter in self.meters.items():
  90. loss_str.append(
  91. "{}: {}".format(name, str(meter))
  92. )
  93. return self.delimiter.join(loss_str)
  94. def synchronize_between_processes(self):
  95. for meter in self.meters.values():
  96. meter.synchronize_between_processes()
  97. def add_meter(self, name, meter):
  98. self.meters[name] = meter
  99. def log_every(self, iterable, print_freq, header=None):
  100. i = 0
  101. if not header:
  102. header = ''
  103. start_time = time.time()
  104. end = time.time()
  105. iter_time = SmoothedValue(fmt='{avg:.4f}')
  106. data_time = SmoothedValue(fmt='{avg:.4f}')
  107. space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
  108. if torch.cuda.is_available():
  109. log_msg = self.delimiter.join([
  110. header,
  111. '[{0' + space_fmt + '}/{1}]',
  112. 'eta: {eta}',
  113. '{meters}',
  114. 'time: {time}',
  115. 'data: {data}',
  116. 'max mem: {memory:.0f}'
  117. ])
  118. else:
  119. log_msg = self.delimiter.join([
  120. header,
  121. '[{0' + space_fmt + '}/{1}]',
  122. 'eta: {eta}',
  123. '{meters}',
  124. 'time: {time}',
  125. 'data: {data}'
  126. ])
  127. MB = 1024.0 * 1024.0
  128. for obj in iterable:
  129. data_time.update(time.time() - end)
  130. yield obj
  131. iter_time.update(time.time() - end)
  132. if i % print_freq == 0 or i == len(iterable) - 1:
  133. eta_seconds = iter_time.global_avg * (len(iterable) - i)
  134. eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
  135. if torch.cuda.is_available():
  136. print(log_msg.format(
  137. i, len(iterable), eta=eta_string,
  138. meters=str(self),
  139. time=str(iter_time), data=str(data_time),
  140. memory=torch.cuda.max_memory_allocated() / MB))
  141. else:
  142. print(log_msg.format(
  143. i, len(iterable), eta=eta_string,
  144. meters=str(self),
  145. time=str(iter_time), data=str(data_time)))
  146. i += 1
  147. end = time.time()
  148. total_time = time.time() - start_time
  149. total_time_str = str(datetime.timedelta(seconds=int(total_time)))
  150. print('{} Total time: {} ({:.4f} s / it)'.format(
  151. header, total_time_str, total_time / len(iterable)))
  152. class SinkhornDistance(torch.nn.Module):
  153. def __init__(self, eps=1e-3, max_iter=100, reduction='none'):
  154. super(SinkhornDistance, self).__init__()
  155. self.eps = eps
  156. self.max_iter = max_iter
  157. self.reduction = reduction
  158. def forward(self, mu, nu, C):
  159. u = torch.ones_like(mu)
  160. v = torch.ones_like(nu)
  161. # Sinkhorn iterations
  162. for i in range(self.max_iter):
  163. v = self.eps * \
  164. (torch.log(
  165. nu + 1e-8) - torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)) + v
  166. u = self.eps * \
  167. (torch.log(
  168. mu + 1e-8) - torch.logsumexp(self.M(C, u, v), dim=-1)) + u
  169. U, V = u, v
  170. # Transport plan pi = diag(a)*K*diag(b)
  171. pi = torch.exp(
  172. self.M(C, U, V)).detach()
  173. # Sinkhorn distance
  174. cost = torch.sum(
  175. pi * C, dim=(-2, -1))
  176. return cost, pi
  177. def M(self, C, u, v):
  178. '''
  179. "Modified cost for logarithmic updates"
  180. "$M_{ij} = (-c_{ij} + u_i + v_j) / epsilon$"
  181. '''
  182. return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps
  183. # ---------------------------- Dataloader tools ----------------------------
  184. def _max_by_axis(the_list):
  185. # type: (List[List[int]]) -> List[int]
  186. maxes = the_list[0]
  187. for sublist in the_list[1:]:
  188. for index, item in enumerate(sublist):
  189. maxes[index] = max(maxes[index], item)
  190. return maxes
  191. def batch_tensor_from_tensor_list(tensor_list: List[Tensor]):
  192. # TODO make this more general
  193. if tensor_list[0].ndim == 3:
  194. # TODO make it support different-sized images
  195. max_size = _max_by_axis([list(img.shape) for img in tensor_list])
  196. # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
  197. batch_shape = [len(tensor_list)] + max_size
  198. b, c, h, w = batch_shape
  199. dtype = tensor_list[0].dtype
  200. device = tensor_list[0].device
  201. tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
  202. mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
  203. for img, pad_img, m in zip(tensor_list, tensor, mask):
  204. pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
  205. m[: img.shape[1], :img.shape[2]] = False
  206. else:
  207. raise ValueError('not supported')
  208. return tensor, mask
  209. def collate_fn(batch):
  210. batch = list(zip(*batch))
  211. batch[0] = batch_tensor_from_tensor_list(batch[0])
  212. return tuple(batch)
  213. # ---------------------------- For Model ----------------------------
  214. def match_name_keywords(n, name_keywords):
  215. out = False
  216. for b in name_keywords:
  217. if b in n:
  218. out = True
  219. break
  220. return out
  221. ## fuse Conv & BN layer
  222. def fuse_conv_bn(module):
  223. """Recursively fuse conv and bn in a module.
  224. During inference, the functionary of batch norm layers is turned off
  225. but only the mean and var alone channels are used, which exposes the
  226. chance to fuse it with the preceding conv layers to save computations and
  227. simplify network structures.
  228. Args:
  229. module (nn.Module): Module to be fused.
  230. Returns:
  231. nn.Module: Fused module.
  232. """
  233. last_conv = None
  234. last_conv_name = None
  235. def _fuse_conv_bn(conv, bn):
  236. """Fuse conv and bn into one module.
  237. Args:
  238. conv (nn.Module): Conv to be fused.
  239. bn (nn.Module): BN to be fused.
  240. Returns:
  241. nn.Module: Fused module.
  242. """
  243. conv_w = conv.weight
  244. conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
  245. bn.running_mean)
  246. factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
  247. conv.weight = nn.Parameter(conv_w *
  248. factor.reshape([conv.out_channels, 1, 1, 1]))
  249. conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias)
  250. return conv
  251. for name, child in module.named_children():
  252. if isinstance(child,
  253. (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)):
  254. if last_conv is None: # only fuse BN that is after Conv
  255. continue
  256. fused_conv = _fuse_conv_bn(last_conv, child)
  257. module._modules[last_conv_name] = fused_conv
  258. # To reduce changes, set BN as Identity instead of deleting it.
  259. module._modules[name] = nn.Identity()
  260. last_conv = None
  261. elif isinstance(child, nn.Conv2d):
  262. last_conv = child
  263. last_conv_name = name
  264. else:
  265. fuse_conv_bn(child)
  266. return module
  267. ## compute FLOPs & Parameters
  268. def compute_flops(model, min_size, max_size, device):
  269. if isinstance(min_size[0], List):
  270. min_size, max_size = min_size[0]
  271. else:
  272. min_size = min_size[0]
  273. x = torch.randn(1, 3, min_size, max_size).to(device)
  274. print('==============================')
  275. flops, params = profile(model, inputs=(x, ), verbose=False)
  276. print('GFLOPs : {:.2f}'.format(flops / 1e9))
  277. print('Params : {:.2f} M'.format(params / 1e6))
  278. ## load trained weight
  279. def load_weight(model, path_to_ckpt, fuse_cbn=False):
  280. # check ckpt file
  281. if path_to_ckpt is None:
  282. print('no weight file ...')
  283. else:
  284. checkpoint = torch.load(path_to_ckpt, map_location='cpu')
  285. print('--------------------------------------')
  286. print('Best model infor:')
  287. print('Epoch: {}'.format(checkpoint.pop("epoch")))
  288. print('mAP: {}'.format(checkpoint.pop("mAP")))
  289. print('--------------------------------------')
  290. checkpoint_state_dict = checkpoint.pop("model")
  291. model.load_state_dict(checkpoint_state_dict)
  292. print('Finished loading model!')
  293. # fuse conv & bn
  294. if fuse_cbn:
  295. print('Fusing Conv & BN ...')
  296. model = fuse_conv_bn(model)
  297. return model
  298. ## gradient clip
  299. def get_total_grad_norm(parameters, norm_type=2):
  300. parameters = list(filter(lambda p: p.grad is not None, parameters))
  301. norm_type = float(norm_type)
  302. device = parameters[0].grad.device
  303. total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]),
  304. norm_type)
  305. return total_norm
  306. ## param Dict
  307. def get_param_dict(model, cfg, return_name=False):
  308. # sanity check: a variable could not match backbone_names and linear_proj_names at the same time
  309. cfg['lr_backbone'] = cfg['base_lr'] * cfg['backbone_lr_ratio']
  310. for n, p in model.named_parameters():
  311. if match_name_keywords(n, cfg['lr_backbone_names']) and match_name_keywords(n, cfg['lr_linear_proj_names']):
  312. raise ValueError
  313. param_dicts = [
  314. {
  315. "params": [
  316. p if not return_name else n
  317. for n, p in model.named_parameters()
  318. if not match_name_keywords(n, cfg['lr_backbone_names'])
  319. and not match_name_keywords(n, cfg['lr_linear_proj_names'])
  320. and not match_name_keywords(n, cfg['wd_norm_names'])
  321. and p.requires_grad
  322. ],
  323. "lr": cfg['base_lr'],
  324. "weight_decay": cfg['weight_decay'],
  325. },
  326. {
  327. "params": [
  328. p if not return_name else n
  329. for n, p in model.named_parameters()
  330. if match_name_keywords(n, cfg['lr_backbone_names'])
  331. and not match_name_keywords(n, cfg['lr_linear_proj_names'])
  332. and not match_name_keywords(n, cfg['wd_norm_names'])
  333. and p.requires_grad
  334. ],
  335. "lr": cfg['lr_backbone'],
  336. "weight_decay": cfg['weight_decay'],
  337. },
  338. {
  339. "params": [
  340. p if not return_name else n
  341. for n, p in model.named_parameters()
  342. if not match_name_keywords(n, cfg['lr_backbone_names'])
  343. and match_name_keywords(n, cfg['lr_linear_proj_names'])
  344. and not match_name_keywords(n, cfg['wd_norm_names'])
  345. and p.requires_grad
  346. ],
  347. "lr": cfg['base_lr'] * cfg['lr_linear_proj_mult'],
  348. "weight_decay": cfg['weight_decay'],
  349. },
  350. {
  351. "params": [
  352. p if not return_name else n
  353. for n, p in model.named_parameters()
  354. if not match_name_keywords(n, cfg['lr_backbone_names'])
  355. and not match_name_keywords(n, cfg['lr_linear_proj_names'])
  356. and match_name_keywords(n, cfg['wd_norm_names'])
  357. and p.requires_grad
  358. ],
  359. "lr": cfg['base_lr'],
  360. "weight_decay": cfg['weight_decay'] * cfg['wd_norm_mult'],
  361. },
  362. {
  363. "params": [
  364. p if not return_name else n
  365. for n, p in model.named_parameters()
  366. if match_name_keywords(n, cfg['lr_backbone_names'])
  367. and not match_name_keywords(n, cfg['lr_linear_proj_names'])
  368. and match_name_keywords(n, cfg['wd_norm_names'])
  369. and p.requires_grad
  370. ],
  371. "lr": cfg['lr_backbone'],
  372. "weight_decay": cfg['weight_decay'] * cfg['wd_norm_mult'],
  373. },
  374. {
  375. "params": [
  376. p if not return_name else n
  377. for n, p in model.named_parameters()
  378. if not match_name_keywords(n, cfg['lr_backbone_names'])
  379. and match_name_keywords(n, cfg['lr_linear_proj_names'])
  380. and match_name_keywords(n, cfg['wd_norm_names'])
  381. and p.requires_grad
  382. ],
  383. "lr": cfg['base_lr'] * cfg['lr_linear_proj_mult'],
  384. "weight_decay": cfg['weight_decay'] * cfg['wd_norm_mult'],
  385. },
  386. ]
  387. return param_dicts
  388. ## Model EMA
  389. class ModelEMA(object):
  390. def __init__(self, cfg, model, updates=0):
  391. # Create EMA
  392. self.ema = deepcopy(self.de_parallel(model)).eval() # FP32 EMA
  393. self.updates = updates # number of EMA updates
  394. self.decay = lambda x: cfg['ema_decay'] * (1 - math.exp(-x / cfg['ema_tau'])) # decay exponential ramp (to help early epochs)
  395. for p in self.ema.parameters():
  396. p.requires_grad_(False)
  397. def is_parallel(self, model):
  398. # Returns True if model is of type DP or DDP
  399. return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
  400. def de_parallel(self, model):
  401. # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
  402. return model.module if self.is_parallel(model) else model
  403. def copy_attr(self, a, b, include=(), exclude=()):
  404. # Copy attributes from b to a, options to only include [...] and to exclude [...]
  405. for k, v in b.__dict__.items():
  406. if (len(include) and k not in include) or k.startswith('_') or k in exclude:
  407. continue
  408. else:
  409. setattr(a, k, v)
  410. def update(self, model):
  411. # Update EMA parameters
  412. self.updates += 1
  413. d = self.decay(self.updates)
  414. msd = self.de_parallel(model).state_dict() # model state_dict
  415. for k, v in self.ema.state_dict().items():
  416. if v.dtype.is_floating_point: # true for FP16 and FP32
  417. v *= d
  418. v += (1 - d) * msd[k].detach()
  419. # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'
  420. def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
  421. # Update EMA attributes
  422. self.copy_attr(self.ema, model, include, exclude)
  423. # ---------------------------- For Loss ----------------------------
  424. ## focal loss
  425. def sigmoid_focal_loss(inputs, targets, alpha: float = 0.25, gamma: float = 2):
  426. """
  427. Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
  428. Args:
  429. inputs: A float tensor of arbitrary shape.
  430. The predictions for each example.
  431. targets: A float tensor with the same shape as inputs. Stores the binary
  432. classification label for each element in inputs
  433. (0 for the negative class and 1 for the positive class).
  434. alpha: (optional) Weighting factor in range (0,1) to balance
  435. positive vs negative examples. Default = -1 (no weighting).
  436. gamma: Exponent of the modulating factor (1 - p_t) to
  437. balance easy vs hard examples.
  438. Returns:
  439. Loss tensor
  440. """
  441. prob = inputs.sigmoid()
  442. ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
  443. p_t = prob * targets + (1 - prob) * (1 - targets)
  444. loss = ce_loss * ((1 - p_t) ** gamma)
  445. if alpha >= 0:
  446. alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
  447. loss = alpha_t * loss
  448. return loss
  449. # ---------------------------- NMS ----------------------------
  450. def nms(bboxes, scores, nms_thresh):
  451. """"Pure Python NMS."""
  452. x1 = bboxes[:, 0] #xmin
  453. y1 = bboxes[:, 1] #ymin
  454. x2 = bboxes[:, 2] #xmax
  455. y2 = bboxes[:, 3] #ymax
  456. areas = (x2 - x1) * (y2 - y1)
  457. order = scores.argsort()[::-1]
  458. keep = []
  459. while order.size > 0:
  460. i = order[0]
  461. keep.append(i)
  462. # compute iou
  463. xx1 = np.maximum(x1[i], x1[order[1:]])
  464. yy1 = np.maximum(y1[i], y1[order[1:]])
  465. xx2 = np.minimum(x2[i], x2[order[1:]])
  466. yy2 = np.minimum(y2[i], y2[order[1:]])
  467. w = np.maximum(1e-10, xx2 - xx1)
  468. h = np.maximum(1e-10, yy2 - yy1)
  469. inter = w * h
  470. iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-14)
  471. #reserve all the boundingbox whose ovr less than thresh
  472. inds = np.where(iou <= nms_thresh)[0]
  473. order = order[inds + 1]
  474. return keep
  475. def multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh):
  476. # nms
  477. keep = nms(bboxes, scores, nms_thresh)
  478. scores = scores[keep]
  479. labels = labels[keep]
  480. bboxes = bboxes[keep]
  481. return scores, labels, bboxes
  482. def multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes):
  483. # nms
  484. keep = np.zeros(len(bboxes), dtype=np.int32)
  485. for i in range(num_classes):
  486. inds = np.where(labels == i)[0]
  487. if len(inds) == 0:
  488. continue
  489. c_bboxes = bboxes[inds]
  490. c_scores = scores[inds]
  491. c_keep = nms(c_bboxes, c_scores, nms_thresh)
  492. keep[inds[c_keep]] = 1
  493. keep = np.where(keep > 0)
  494. scores = scores[keep]
  495. labels = labels[keep]
  496. bboxes = bboxes[keep]
  497. return scores, labels, bboxes
  498. def multiclass_nms(scores, labels, bboxes, nms_thresh, num_classes, class_agnostic=False):
  499. if class_agnostic:
  500. return multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh)
  501. else:
  502. return multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes)