misc.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import torch.distributed as dist
  5. from torch.utils.data import DataLoader, DistributedSampler
  6. import cv2
  7. import math
  8. import time
  9. import datetime
  10. import numpy as np
  11. from copy import deepcopy
  12. from thop import profile
  13. from collections import defaultdict, deque
  14. from .distributed_utils import is_dist_avail_and_initialized
  15. # ---------------------------- Train tools ----------------------------
  16. class SmoothedValue(object):
  17. """Track a series of values and provide access to smoothed values over a
  18. window or the global series average.
  19. """
  20. def __init__(self, window_size=20, fmt=None):
  21. if fmt is None:
  22. fmt = "{median:.4f} ({global_avg:.4f})"
  23. self.deque = deque(maxlen=window_size)
  24. self.total = 0.0
  25. self.count = 0
  26. self.fmt = fmt
  27. def update(self, value, n=1):
  28. self.deque.append(value)
  29. self.count += n
  30. self.total += value * n
  31. def synchronize_between_processes(self):
  32. """
  33. Warning: does not synchronize the deque!
  34. """
  35. if not is_dist_avail_and_initialized():
  36. return
  37. t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
  38. dist.barrier()
  39. dist.all_reduce(t)
  40. t = t.tolist()
  41. self.count = int(t[0])
  42. self.total = t[1]
  43. @property
  44. def median(self):
  45. d = torch.tensor(list(self.deque))
  46. return d.median().item()
  47. @property
  48. def avg(self):
  49. d = torch.tensor(list(self.deque), dtype=torch.float32)
  50. return d.mean().item()
  51. @property
  52. def global_avg(self):
  53. return self.total / self.count
  54. @property
  55. def max(self):
  56. return max(self.deque)
  57. @property
  58. def value(self):
  59. return self.deque[-1]
  60. def __str__(self):
  61. return self.fmt.format(
  62. median=self.median,
  63. avg=self.avg,
  64. global_avg=self.global_avg,
  65. max=self.max,
  66. value=self.value)
  67. class MetricLogger(object):
  68. def __init__(self, delimiter="\t"):
  69. self.meters = defaultdict(SmoothedValue)
  70. self.delimiter = delimiter
  71. def update(self, **kwargs):
  72. for k, v in kwargs.items():
  73. if isinstance(v, torch.Tensor):
  74. v = v.item()
  75. assert isinstance(v, (float, int))
  76. self.meters[k].update(v)
  77. def __getattr__(self, attr):
  78. if attr in self.meters:
  79. return self.meters[attr]
  80. if attr in self.__dict__:
  81. return self.__dict__[attr]
  82. raise AttributeError("'{}' object has no attribute '{}'".format(
  83. type(self).__name__, attr))
  84. def __str__(self):
  85. loss_str = []
  86. for name, meter in self.meters.items():
  87. loss_str.append(
  88. "{}: {}".format(name, str(meter))
  89. )
  90. return self.delimiter.join(loss_str)
  91. def synchronize_between_processes(self):
  92. for meter in self.meters.values():
  93. meter.synchronize_between_processes()
  94. def add_meter(self, name, meter):
  95. self.meters[name] = meter
  96. def log_every(self, iterable, print_freq, header=None):
  97. i = 0
  98. if not header:
  99. header = ''
  100. start_time = time.time()
  101. end = time.time()
  102. iter_time = SmoothedValue(fmt='{avg:.4f}')
  103. data_time = SmoothedValue(fmt='{avg:.4f}')
  104. space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
  105. if torch.cuda.is_available():
  106. log_msg = self.delimiter.join([
  107. header,
  108. '[{0' + space_fmt + '}/{1}]',
  109. 'eta: {eta}',
  110. '{meters}',
  111. 'time: {time}',
  112. 'data: {data}',
  113. 'max mem: {memory:.0f}'
  114. ])
  115. else:
  116. log_msg = self.delimiter.join([
  117. header,
  118. '[{0' + space_fmt + '}/{1}]',
  119. 'eta: {eta}',
  120. '{meters}',
  121. 'time: {time}',
  122. 'data: {data}'
  123. ])
  124. MB = 1024.0 * 1024.0
  125. for obj in iterable:
  126. data_time.update(time.time() - end)
  127. yield obj
  128. iter_time.update(time.time() - end)
  129. if i % print_freq == 0 or i == len(iterable) - 1:
  130. eta_seconds = iter_time.global_avg * (len(iterable) - i)
  131. eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
  132. if torch.cuda.is_available():
  133. print(log_msg.format(
  134. i, len(iterable), eta=eta_string,
  135. meters=str(self),
  136. time=str(iter_time), data=str(data_time),
  137. memory=torch.cuda.max_memory_allocated() / MB))
  138. else:
  139. print(log_msg.format(
  140. i, len(iterable), eta=eta_string,
  141. meters=str(self),
  142. time=str(iter_time), data=str(data_time)))
  143. i += 1
  144. end = time.time()
  145. total_time = time.time() - start_time
  146. total_time_str = str(datetime.timedelta(seconds=int(total_time)))
  147. print('{} Total time: {} ({:.4f} s / it)'.format(
  148. header, total_time_str, total_time / len(iterable)))
  149. # ---------------------------- For Dataset ----------------------------
  150. ## build dataloader
  151. def build_dataloader(args, dataset, batch_size, collate_fn=None):
  152. # distributed
  153. if args.distributed:
  154. sampler = DistributedSampler(dataset)
  155. else:
  156. sampler = torch.utils.data.RandomSampler(dataset)
  157. batch_sampler_train = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)
  158. dataloader = DataLoader(dataset, batch_sampler=batch_sampler_train,
  159. collate_fn=collate_fn, num_workers=args.num_workers, pin_memory=True)
  160. return dataloader
  161. ## collate_fn for dataloader
  162. class CollateFunc(object):
  163. def __call__(self, batch):
  164. targets = []
  165. images = []
  166. for sample in batch:
  167. image = sample[0]
  168. target = sample[1]
  169. images.append(image)
  170. targets.append(target)
  171. images = torch.stack(images, 0) # [B, C, H, W]
  172. return images, targets
  173. # ---------------------------- For Loss ----------------------------
  174. ## FocalLoss
  175. def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
  176. """
  177. Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
  178. Args:
  179. inputs: A float tensor of arbitrary shape.
  180. The predictions for each example.
  181. targets: A float tensor with the same shape as inputs. Stores the binary
  182. classification label for each element in inputs
  183. (0 for the negative class and 1 for the positive class).
  184. alpha: (optional) Weighting factor in range (0,1) to balance
  185. positive vs negative examples. Default = -1 (no weighting).
  186. gamma: Exponent of the modulating factor (1 - p_t) to
  187. balance easy vs hard examples.
  188. Returns:
  189. Loss tensor
  190. """
  191. prob = inputs.sigmoid()
  192. ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
  193. p_t = prob * targets + (1 - prob) * (1 - targets)
  194. loss = ce_loss * ((1 - p_t) ** gamma)
  195. if alpha >= 0:
  196. alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
  197. loss = alpha_t * loss
  198. return loss.mean(1).sum() / num_boxes
  199. ## Varifocal Loss
  200. def varifocal_loss_with_logits(pred_logits,
  201. gt_score,
  202. label,
  203. normalizer=1.0,
  204. alpha=0.75,
  205. gamma=2.0):
  206. pred_score = F.sigmoid(pred_logits)
  207. weight = alpha * pred_score.pow(gamma) * (1 - label) + gt_score * label
  208. loss = F.binary_cross_entropy_with_logits(
  209. pred_logits, gt_score, weight=weight, reduction='none')
  210. return loss.mean(1).sum() / normalizer
  211. ## InverseSigmoid
  212. def inverse_sigmoid(x, eps=1e-5):
  213. x = x.clamp(min=0, max=1)
  214. x1 = x.clamp(min=eps)
  215. x2 = (1 - x).clamp(min=eps)
  216. return torch.log(x1/x2)
  217. # ---------------------------- For Model ----------------------------
  218. ## fuse Conv & BN layer
  219. def fuse_conv_bn(module):
  220. """Recursively fuse conv and bn in a module.
  221. During inference, the functionary of batch norm layers is turned off
  222. but only the mean and var alone channels are used, which exposes the
  223. chance to fuse it with the preceding conv layers to save computations and
  224. simplify network structures.
  225. Args:
  226. module (nn.Module): Module to be fused.
  227. Returns:
  228. nn.Module: Fused module.
  229. """
  230. last_conv = None
  231. last_conv_name = None
  232. def _fuse_conv_bn(conv, bn):
  233. """Fuse conv and bn into one module.
  234. Args:
  235. conv (nn.Module): Conv to be fused.
  236. bn (nn.Module): BN to be fused.
  237. Returns:
  238. nn.Module: Fused module.
  239. """
  240. conv_w = conv.weight
  241. conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
  242. bn.running_mean)
  243. factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
  244. conv.weight = nn.Parameter(conv_w *
  245. factor.reshape([conv.out_channels, 1, 1, 1]))
  246. conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias)
  247. return conv
  248. for name, child in module.named_children():
  249. if isinstance(child,
  250. (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)):
  251. if last_conv is None: # only fuse BN that is after Conv
  252. continue
  253. fused_conv = _fuse_conv_bn(last_conv, child)
  254. module._modules[last_conv_name] = fused_conv
  255. # To reduce changes, set BN as Identity instead of deleting it.
  256. module._modules[name] = nn.Identity()
  257. last_conv = None
  258. elif isinstance(child, nn.Conv2d):
  259. last_conv = child
  260. last_conv_name = name
  261. else:
  262. fuse_conv_bn(child)
  263. return module
  264. ## replace module
  265. def replace_module(module, replaced_module_type, new_module_type, replace_func=None) -> nn.Module:
  266. """
  267. Replace given type in module to a new type. mostly used in deploy.
  268. Args:
  269. module (nn.Module): model to apply replace operation.
  270. replaced_module_type (Type): module type to be replaced.
  271. new_module_type (Type)
  272. replace_func (function): python function to describe replace logic. Defalut value None.
  273. Returns:
  274. model (nn.Module): module that already been replaced.
  275. """
  276. def default_replace_func(replaced_module_type, new_module_type):
  277. return new_module_type()
  278. if replace_func is None:
  279. replace_func = default_replace_func
  280. model = module
  281. if isinstance(module, replaced_module_type):
  282. model = replace_func(replaced_module_type, new_module_type)
  283. else: # recurrsively replace
  284. for name, child in module.named_children():
  285. new_child = replace_module(child, replaced_module_type, new_module_type)
  286. if new_child is not child: # child is already replaced
  287. model.add_module(name, new_child)
  288. return model
  289. ## compute FLOPs & Parameters
  290. def compute_flops(model, img_size, device):
  291. x = torch.randn(1, 3, img_size, img_size).to(device)
  292. print('==============================')
  293. flops, params = profile(model, inputs=(x, ), verbose=False)
  294. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  295. print('Params : {:.2f} M'.format(params / 1e6))
  296. ## load trained weight
  297. def load_weight(model, path_to_ckpt, fuse_cbn=False, fuse_rep_conv=False):
  298. # check ckpt file
  299. if path_to_ckpt is None:
  300. print('no weight file ...')
  301. else:
  302. checkpoint = torch.load(path_to_ckpt, map_location='cpu')
  303. print('--------------------------------------')
  304. print('Best model infor:')
  305. print('Epoch: {}'.format(checkpoint["epoch"]))
  306. print('mAP: {}'.format(checkpoint["mAP"]))
  307. print('--------------------------------------')
  308. checkpoint_state_dict = checkpoint["model"]
  309. model.load_state_dict(checkpoint_state_dict)
  310. print('Finished loading model!')
  311. # fuse rep conv
  312. if fuse_rep_conv:
  313. print("Fusing RepConv ...")
  314. for m in model.modules():
  315. if hasattr(m, 'fuse_convs'):
  316. m.fuse_convs()
  317. # fuse conv & bn
  318. if fuse_cbn:
  319. print('Fusing Conv & BN ...')
  320. model = fuse_conv_bn(model)
  321. return model
  322. def get_total_grad_norm(parameters, norm_type=2):
  323. parameters = list(filter(lambda p: p.grad is not None, parameters))
  324. norm_type = float(norm_type)
  325. device = parameters[0].grad.device
  326. total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]),
  327. norm_type)
  328. return total_norm
  329. ## Model EMA
  330. class ModelEMA(object):
  331. def __init__(self, model, ema_decay=0.9999, ema_tau=2000, resume=None):
  332. # Create EMA
  333. self.ema = deepcopy(self.de_parallel(model)).eval() # FP32 EMA
  334. self.updates = 0 # number of EMA updates
  335. self.decay = lambda x: ema_decay * (1 - math.exp(-x / ema_tau)) # decay exponential ramp (to help early epochs)
  336. for p in self.ema.parameters():
  337. p.requires_grad_(False)
  338. if resume is not None and resume.lower() != "none":
  339. self.load_resume(resume)
  340. print("Initialize ModelEMA's updates: {}".format(self.updates))
  341. def load_resume(self, resume):
  342. checkpoint = torch.load(resume)
  343. if 'ema_updates' in checkpoint.keys():
  344. print('--Load ModelEMA updates from the checkpoint: ', resume)
  345. # checkpoint state dict
  346. self.updates = checkpoint.pop("ema_updates")
  347. def is_parallel(self, model):
  348. # Returns True if model is of type DP or DDP
  349. return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
  350. def de_parallel(self, model):
  351. # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
  352. return model.module if self.is_parallel(model) else model
  353. def copy_attr(self, a, b, include=(), exclude=()):
  354. # Copy attributes from b to a, options to only include [...] and to exclude [...]
  355. for k, v in b.__dict__.items():
  356. if (len(include) and k not in include) or k.startswith('_') or k in exclude:
  357. continue
  358. else:
  359. setattr(a, k, v)
  360. def update(self, model):
  361. # Update EMA parameters
  362. self.updates += 1
  363. d = self.decay(self.updates)
  364. msd = self.de_parallel(model).state_dict() # model state_dict
  365. for k, v in self.ema.state_dict().items():
  366. if v.dtype.is_floating_point: # true for FP16 and FP32
  367. v *= d
  368. v += (1 - d) * msd[k].detach()
  369. def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
  370. # Update EMA attributes
  371. self.copy_attr(self.ema, model, include, exclude)
  372. ## SiLU
  373. class SiLU(nn.Module):
  374. """export-friendly version of nn.SiLU()"""
  375. @staticmethod
  376. def forward(x):
  377. return x * torch.sigmoid(x)
  378. # ---------------------------- NMS ----------------------------
  379. ## basic NMS
  380. def nms(bboxes, scores, nms_thresh):
  381. """"Pure Python NMS."""
  382. x1 = bboxes[:, 0] #xmin
  383. y1 = bboxes[:, 1] #ymin
  384. x2 = bboxes[:, 2] #xmax
  385. y2 = bboxes[:, 3] #ymax
  386. areas = (x2 - x1) * (y2 - y1)
  387. order = scores.argsort()[::-1]
  388. keep = []
  389. while order.size > 0:
  390. i = order[0]
  391. keep.append(i)
  392. # compute iou
  393. xx1 = np.maximum(x1[i], x1[order[1:]])
  394. yy1 = np.maximum(y1[i], y1[order[1:]])
  395. xx2 = np.minimum(x2[i], x2[order[1:]])
  396. yy2 = np.minimum(y2[i], y2[order[1:]])
  397. w = np.maximum(1e-10, xx2 - xx1)
  398. h = np.maximum(1e-10, yy2 - yy1)
  399. inter = w * h
  400. iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-14)
  401. #reserve all the boundingbox whose ovr less than thresh
  402. inds = np.where(iou <= nms_thresh)[0]
  403. order = order[inds + 1]
  404. return keep
  405. ## class-agnostic NMS
  406. def multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh):
  407. # nms
  408. keep = nms(bboxes, scores, nms_thresh)
  409. scores = scores[keep]
  410. labels = labels[keep]
  411. bboxes = bboxes[keep]
  412. return scores, labels, bboxes
  413. ## class-aware NMS
  414. def multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes):
  415. # nms
  416. keep = np.zeros(len(bboxes), dtype=np.int32)
  417. for i in range(num_classes):
  418. inds = np.where(labels == i)[0]
  419. if len(inds) == 0:
  420. continue
  421. c_bboxes = bboxes[inds]
  422. c_scores = scores[inds]
  423. c_keep = nms(c_bboxes, c_scores, nms_thresh)
  424. keep[inds[c_keep]] = 1
  425. keep = np.where(keep > 0)
  426. scores = scores[keep]
  427. labels = labels[keep]
  428. bboxes = bboxes[keep]
  429. return scores, labels, bboxes
  430. ## multi-class NMS
  431. def multiclass_nms(scores, labels, bboxes, nms_thresh, num_classes, class_agnostic=False):
  432. if class_agnostic:
  433. return multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh)
  434. else:
  435. return multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes)
  436. # ---------------------------- Processor for Deployment ----------------------------
  437. ## Pre-processor
  438. class PreProcessor(object):
  439. def __init__(self, img_size, keep_ratio=True):
  440. self.img_size = img_size
  441. self.keep_ratio = keep_ratio
  442. self.input_size = [img_size, img_size]
  443. def __call__(self, image, swap=(2, 0, 1)):
  444. """
  445. Input:
  446. image: (ndarray) [H, W, 3] or [H, W]
  447. formar: color format
  448. """
  449. if len(image.shape) == 3:
  450. padded_img = np.ones((self.input_size[0], self.input_size[1], 3), np.float32) * 114.
  451. else:
  452. padded_img = np.ones(self.input_size, np.float32) * 114.
  453. # resize
  454. if self.keep_ratio:
  455. orig_h, orig_w = image.shape[:2]
  456. r = min(self.input_size[0] / orig_h, self.input_size[1] / orig_w)
  457. resize_size = (int(orig_w * r), int(orig_h * r))
  458. if r != 1:
  459. resized_img = cv2.resize(image, resize_size, interpolation=cv2.INTER_LINEAR)
  460. else:
  461. resized_img = image
  462. # padding
  463. padded_img[:resized_img.shape[0], :resized_img.shape[1]] = resized_img
  464. # [H, W, C] -> [C, H, W]
  465. padded_img = padded_img.transpose(swap)
  466. padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) / 255.
  467. return padded_img, r
  468. else:
  469. orig_h, orig_w = image.shape[:2]
  470. r = np.array([self.input_size[0] / orig_w, self.input_size[1] / orig_w])
  471. if [orig_h, orig_w] == self.input_size:
  472. resized_img = image
  473. else:
  474. resized_img = cv2.resize(image, self.input_size, interpolation=cv2.INTER_LINEAR)
  475. return resized_img, r
  476. ## Post-processor
  477. class PostProcessor(object):
  478. def __init__(self, num_classes, conf_thresh=0.15, nms_thresh=0.5):
  479. self.num_classes = num_classes
  480. self.conf_thresh = conf_thresh
  481. self.nms_thresh = nms_thresh
  482. def __call__(self, predictions):
  483. """
  484. Input:
  485. predictions: (ndarray) [n_anchors_all, 4+1+C]
  486. """
  487. bboxes = predictions[..., :4]
  488. scores = predictions[..., 4:]
  489. # scores & labels
  490. labels = np.argmax(scores, axis=1) # [M,]
  491. scores = scores[(np.arange(scores.shape[0]), labels)] # [M,]
  492. # thresh
  493. keep = np.where(scores > self.conf_thresh)
  494. scores = scores[keep]
  495. labels = labels[keep]
  496. bboxes = bboxes[keep]
  497. # nms
  498. scores, labels, bboxes = multiclass_nms(
  499. scores, labels, bboxes, self.nms_thresh, self.num_classes, True)
  500. return bboxes, scores, labels