misc.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import torch.distributed as dist
  5. from torch.utils.data import DataLoader, DistributedSampler
  6. import cv2
  7. import math
  8. import time
  9. import datetime
  10. import numpy as np
  11. from copy import deepcopy
  12. from thop import profile
  13. from collections import defaultdict, deque
  14. from .distributed_utils import is_dist_avail_and_initialized
  15. # ---------------------------- Train tools ----------------------------
  16. class SmoothedValue(object):
  17. """Track a series of values and provide access to smoothed values over a
  18. window or the global series average.
  19. """
  20. def __init__(self, window_size=20, fmt=None):
  21. if fmt is None:
  22. fmt = "{median:.4f} ({global_avg:.4f})"
  23. self.deque = deque(maxlen=window_size)
  24. self.total = 0.0
  25. self.count = 0
  26. self.fmt = fmt
  27. def update(self, value, n=1):
  28. self.deque.append(value)
  29. self.count += n
  30. self.total += value * n
  31. def synchronize_between_processes(self):
  32. """
  33. Warning: does not synchronize the deque!
  34. """
  35. if not is_dist_avail_and_initialized():
  36. return
  37. t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
  38. dist.barrier()
  39. dist.all_reduce(t)
  40. t = t.tolist()
  41. self.count = int(t[0])
  42. self.total = t[1]
  43. @property
  44. def median(self):
  45. d = torch.tensor(list(self.deque))
  46. return d.median().item()
  47. @property
  48. def avg(self):
  49. d = torch.tensor(list(self.deque), dtype=torch.float32)
  50. return d.mean().item()
  51. @property
  52. def global_avg(self):
  53. return self.total / self.count
  54. @property
  55. def max(self):
  56. return max(self.deque)
  57. @property
  58. def value(self):
  59. return self.deque[-1]
  60. def __str__(self):
  61. return self.fmt.format(
  62. median=self.median,
  63. avg=self.avg,
  64. global_avg=self.global_avg,
  65. max=self.max,
  66. value=self.value)
  67. class MetricLogger(object):
  68. def __init__(self, delimiter="\t"):
  69. self.meters = defaultdict(SmoothedValue)
  70. self.delimiter = delimiter
  71. def update(self, **kwargs):
  72. for k, v in kwargs.items():
  73. if isinstance(v, torch.Tensor):
  74. v = v.item()
  75. assert isinstance(v, (float, int))
  76. self.meters[k].update(v)
  77. def __getattr__(self, attr):
  78. if attr in self.meters:
  79. return self.meters[attr]
  80. if attr in self.__dict__:
  81. return self.__dict__[attr]
  82. raise AttributeError("'{}' object has no attribute '{}'".format(
  83. type(self).__name__, attr))
  84. def __str__(self):
  85. loss_str = []
  86. for name, meter in self.meters.items():
  87. loss_str.append(
  88. "{}: {}".format(name, str(meter))
  89. )
  90. return self.delimiter.join(loss_str)
  91. def synchronize_between_processes(self):
  92. for meter in self.meters.values():
  93. meter.synchronize_between_processes()
  94. def add_meter(self, name, meter):
  95. self.meters[name] = meter
  96. def log_every(self, iterable, print_freq, header=None):
  97. i = 0
  98. if not header:
  99. header = ''
  100. start_time = time.time()
  101. end = time.time()
  102. iter_time = SmoothedValue(fmt='{avg:.4f}')
  103. data_time = SmoothedValue(fmt='{avg:.4f}')
  104. space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
  105. if torch.cuda.is_available():
  106. log_msg = self.delimiter.join([
  107. header,
  108. '[{0' + space_fmt + '}/{1}]',
  109. 'eta: {eta}',
  110. '{meters}',
  111. 'time: {time}',
  112. 'data: {data}',
  113. 'max mem: {memory:.0f}'
  114. ])
  115. else:
  116. log_msg = self.delimiter.join([
  117. header,
  118. '[{0' + space_fmt + '}/{1}]',
  119. 'eta: {eta}',
  120. '{meters}',
  121. 'time: {time}',
  122. 'data: {data}'
  123. ])
  124. MB = 1024.0 * 1024.0
  125. for obj in iterable:
  126. data_time.update(time.time() - end)
  127. yield obj
  128. iter_time.update(time.time() - end)
  129. if i % print_freq == 0 or i == len(iterable) - 1:
  130. eta_seconds = iter_time.global_avg * (len(iterable) - i)
  131. eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
  132. if torch.cuda.is_available():
  133. print(log_msg.format(
  134. i, len(iterable), eta=eta_string,
  135. meters=str(self),
  136. time=str(iter_time), data=str(data_time),
  137. memory=torch.cuda.max_memory_allocated() / MB))
  138. else:
  139. print(log_msg.format(
  140. i, len(iterable), eta=eta_string,
  141. meters=str(self),
  142. time=str(iter_time), data=str(data_time)))
  143. i += 1
  144. end = time.time()
  145. total_time = time.time() - start_time
  146. total_time_str = str(datetime.timedelta(seconds=int(total_time)))
  147. print('{} Total time: {} ({:.4f} s / it)'.format(
  148. header, total_time_str, total_time / len(iterable)))
  149. # ---------------------------- For Dataset ----------------------------
  150. ## build dataloader
  151. def build_dataloader(args, dataset, batch_size, collate_fn=None):
  152. # distributed
  153. if args.distributed:
  154. sampler = DistributedSampler(dataset)
  155. else:
  156. sampler = torch.utils.data.RandomSampler(dataset)
  157. batch_sampler_train = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)
  158. dataloader = DataLoader(dataset, batch_sampler=batch_sampler_train,
  159. collate_fn=collate_fn, num_workers=args.num_workers, pin_memory=True)
  160. return dataloader
  161. ## collate_fn for dataloader
  162. class CollateFunc(object):
  163. def __call__(self, batch):
  164. targets = []
  165. images = []
  166. for sample in batch:
  167. image = sample[0]
  168. target = sample[1]
  169. images.append(image)
  170. targets.append(target)
  171. images = torch.stack(images, 0) # [B, C, H, W]
  172. return images, targets
  173. # ---------------------------- For Loss ----------------------------
  174. ## FocalLoss
  175. def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
  176. """
  177. Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
  178. Args:
  179. inputs: A float tensor of arbitrary shape.
  180. The predictions for each example.
  181. targets: A float tensor with the same shape as inputs. Stores the binary
  182. classification label for each element in inputs
  183. (0 for the negative class and 1 for the positive class).
  184. alpha: (optional) Weighting factor in range (0,1) to balance
  185. positive vs negative examples. Default = -1 (no weighting).
  186. gamma: Exponent of the modulating factor (1 - p_t) to
  187. balance easy vs hard examples.
  188. Returns:
  189. Loss tensor
  190. """
  191. prob = inputs.sigmoid()
  192. ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
  193. p_t = prob * targets + (1 - prob) * (1 - targets)
  194. loss = ce_loss * ((1 - p_t) ** gamma)
  195. if alpha >= 0:
  196. alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
  197. loss = alpha_t * loss
  198. return loss.mean(1).sum() / num_boxes
## Varifocal loss
  200. def varifocal_loss_with_logits(pred_logits,
  201. gt_score,
  202. label,
  203. normalizer=1.0,
  204. alpha=0.75,
  205. gamma=2.0):
  206. pred_score = F.sigmoid(pred_logits)
  207. weight = alpha * pred_score.pow(gamma) * (1 - label) + gt_score * label
  208. loss = F.binary_cross_entropy_with_logits(
  209. pred_logits, gt_score, weight=weight, reduction='none')
  210. return loss.mean(1).sum() / normalizer
  211. ## InverseSigmoid
  212. def inverse_sigmoid(x, eps=1e-5):
  213. x = x.clamp(min=0, max=1)
  214. x1 = x.clamp(min=eps)
  215. x2 = (1 - x).clamp(min=eps)
  216. return torch.log(x1/x2)
  217. # ---------------------------- For Model ----------------------------
  218. ## fuse Conv & BN layer
  219. def fuse_conv_bn(module):
  220. """Recursively fuse conv and bn in a module.
  221. During inference, the functionary of batch norm layers is turned off
  222. but only the mean and var alone channels are used, which exposes the
  223. chance to fuse it with the preceding conv layers to save computations and
  224. simplify network structures.
  225. Args:
  226. module (nn.Module): Module to be fused.
  227. Returns:
  228. nn.Module: Fused module.
  229. """
  230. last_conv = None
  231. last_conv_name = None
  232. def _fuse_conv_bn(conv, bn):
  233. """Fuse conv and bn into one module.
  234. Args:
  235. conv (nn.Module): Conv to be fused.
  236. bn (nn.Module): BN to be fused.
  237. Returns:
  238. nn.Module: Fused module.
  239. """
  240. conv_w = conv.weight
  241. conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
  242. bn.running_mean)
  243. factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
  244. conv.weight = nn.Parameter(conv_w *
  245. factor.reshape([conv.out_channels, 1, 1, 1]))
  246. conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias)
  247. return conv
  248. for name, child in module.named_children():
  249. if isinstance(child,
  250. (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)):
  251. if last_conv is None: # only fuse BN that is after Conv
  252. continue
  253. fused_conv = _fuse_conv_bn(last_conv, child)
  254. module._modules[last_conv_name] = fused_conv
  255. # To reduce changes, set BN as Identity instead of deleting it.
  256. module._modules[name] = nn.Identity()
  257. last_conv = None
  258. elif isinstance(child, nn.Conv2d):
  259. last_conv = child
  260. last_conv_name = name
  261. else:
  262. fuse_conv_bn(child)
  263. return module
  264. ## replace module
  265. def replace_module(module, replaced_module_type, new_module_type, replace_func=None) -> nn.Module:
  266. """
  267. Replace given type in module to a new type. mostly used in deploy.
  268. Args:
  269. module (nn.Module): model to apply replace operation.
  270. replaced_module_type (Type): module type to be replaced.
  271. new_module_type (Type)
  272. replace_func (function): python function to describe replace logic. Defalut value None.
  273. Returns:
  274. model (nn.Module): module that already been replaced.
  275. """
  276. def default_replace_func(replaced_module_type, new_module_type):
  277. return new_module_type()
  278. if replace_func is None:
  279. replace_func = default_replace_func
  280. model = module
  281. if isinstance(module, replaced_module_type):
  282. model = replace_func(replaced_module_type, new_module_type)
  283. else: # recurrsively replace
  284. for name, child in module.named_children():
  285. new_child = replace_module(child, replaced_module_type, new_module_type)
  286. if new_child is not child: # child is already replaced
  287. model.add_module(name, new_child)
  288. return model
  289. ## compute FLOPs & Parameters
  290. def compute_flops(model, img_size, device):
  291. x = torch.randn(1, 3, img_size, img_size).to(device)
  292. print('==============================')
  293. flops, params = profile(model, inputs=(x, ), verbose=False)
  294. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  295. print('Params : {:.2f} M'.format(params / 1e6))
  296. ## load trained weight
  297. def load_weight(model, path_to_ckpt, fuse_cbn=False, fuse_rep_conv=False):
  298. # Check ckpt file
  299. if path_to_ckpt is None:
  300. print('no weight file ...')
  301. else:
  302. checkpoint = torch.load(path_to_ckpt, map_location='cpu')
  303. print('--------------------------------------')
  304. print('Best model infor:')
  305. print('Epoch: {}'.format(checkpoint["epoch"]))
  306. print('mAP: {}'.format(checkpoint["mAP"]))
  307. print('--------------------------------------')
  308. checkpoint_state_dict = checkpoint["model"]
  309. model.load_state_dict(checkpoint_state_dict)
  310. print('Finished loading model!')
  311. # Fuse conv & bn
  312. if fuse_cbn:
  313. print('Fusing Conv & BN ...')
  314. model = fuse_conv_bn(model)
  315. # Fuse RepConv
  316. if hasattr(model, "switch_deploy"):
  317. print("Reparam ...")
  318. model.switch_deploy()
  319. return model
  320. ## Model EMA
  321. class ModelEMA(object):
  322. def __init__(self, model, ema_decay=0.9999, ema_tau=2000, resume=None):
  323. # Create EMA
  324. self.ema = deepcopy(self.de_parallel(model)).eval() # FP32 EMA
  325. self.updates = 0 # number of EMA updates
  326. self.decay = lambda x: ema_decay * (1 - math.exp(-x / ema_tau)) # decay exponential ramp (to help early epochs)
  327. for p in self.ema.parameters():
  328. p.requires_grad_(False)
  329. if resume is not None and resume.lower() != "none":
  330. self.load_resume(resume)
  331. print("Initialize ModelEMA's updates: {}".format(self.updates))
  332. def load_resume(self, resume):
  333. checkpoint = torch.load(resume)
  334. if 'ema_updates' in checkpoint.keys():
  335. print('--Load ModelEMA updates from the checkpoint: ', resume)
  336. # checkpoint state dict
  337. self.updates = checkpoint.pop("ema_updates")
  338. def is_parallel(self, model):
  339. # Returns True if model is of type DP or DDP
  340. return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
  341. def de_parallel(self, model):
  342. # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
  343. return model.module if self.is_parallel(model) else model
  344. def copy_attr(self, a, b, include=(), exclude=()):
  345. # Copy attributes from b to a, options to only include [...] and to exclude [...]
  346. for k, v in b.__dict__.items():
  347. if (len(include) and k not in include) or k.startswith('_') or k in exclude:
  348. continue
  349. else:
  350. setattr(a, k, v)
  351. def update(self, model):
  352. # Update EMA parameters
  353. self.updates += 1
  354. d = self.decay(self.updates)
  355. msd = self.de_parallel(model).state_dict() # model state_dict
  356. for k, v in self.ema.state_dict().items():
  357. if v.dtype.is_floating_point: # true for FP16 and FP32
  358. v *= d
  359. v += (1 - d) * msd[k].detach()
  360. def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
  361. # Update EMA attributes
  362. self.copy_attr(self.ema, model, include, exclude)
  363. ## SiLU
  364. class SiLU(nn.Module):
  365. """export-friendly version of nn.SiLU()"""
  366. @staticmethod
  367. def forward(x):
  368. return x * torch.sigmoid(x)
  369. # ---------------------------- NMS ----------------------------
  370. ## basic NMS
  371. def nms(bboxes, scores, nms_thresh):
  372. """"Pure Python NMS."""
  373. x1 = bboxes[:, 0] #xmin
  374. y1 = bboxes[:, 1] #ymin
  375. x2 = bboxes[:, 2] #xmax
  376. y2 = bboxes[:, 3] #ymax
  377. areas = (x2 - x1) * (y2 - y1)
  378. order = scores.argsort()[::-1]
  379. keep = []
  380. while order.size > 0:
  381. i = order[0]
  382. keep.append(i)
  383. # compute iou
  384. xx1 = np.maximum(x1[i], x1[order[1:]])
  385. yy1 = np.maximum(y1[i], y1[order[1:]])
  386. xx2 = np.minimum(x2[i], x2[order[1:]])
  387. yy2 = np.minimum(y2[i], y2[order[1:]])
  388. w = np.maximum(1e-10, xx2 - xx1)
  389. h = np.maximum(1e-10, yy2 - yy1)
  390. inter = w * h
  391. iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-14)
  392. #reserve all the boundingbox whose ovr less than thresh
  393. inds = np.where(iou <= nms_thresh)[0]
  394. order = order[inds + 1]
  395. return keep
  396. ## class-agnostic NMS
  397. def multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh):
  398. # nms
  399. keep = nms(bboxes, scores, nms_thresh)
  400. scores = scores[keep]
  401. labels = labels[keep]
  402. bboxes = bboxes[keep]
  403. return scores, labels, bboxes
  404. ## class-aware NMS
  405. def multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes):
  406. # nms
  407. keep = np.zeros(len(bboxes), dtype=np.int32)
  408. for i in range(num_classes):
  409. inds = np.where(labels == i)[0]
  410. if len(inds) == 0:
  411. continue
  412. c_bboxes = bboxes[inds]
  413. c_scores = scores[inds]
  414. c_keep = nms(c_bboxes, c_scores, nms_thresh)
  415. keep[inds[c_keep]] = 1
  416. keep = np.where(keep > 0)
  417. scores = scores[keep]
  418. labels = labels[keep]
  419. bboxes = bboxes[keep]
  420. return scores, labels, bboxes
  421. ## multi-class NMS
  422. def multiclass_nms(scores, labels, bboxes, nms_thresh, num_classes, class_agnostic=False):
  423. if class_agnostic:
  424. return multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh)
  425. else:
  426. return multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes)
  427. # ---------------------------- Processor for Deployment ----------------------------
## Pre-processor
  429. class PreProcessor(object):
  430. def __init__(self, img_size, keep_ratio=True):
  431. self.img_size = img_size
  432. self.keep_ratio = keep_ratio
  433. self.input_size = [img_size, img_size]
  434. def __call__(self, image, swap=(2, 0, 1)):
  435. """
  436. Input:
  437. image: (ndarray) [H, W, 3] or [H, W]
  438. formar: color format
  439. """
  440. if len(image.shape) == 3:
  441. padded_img = np.ones((self.input_size[0], self.input_size[1], 3), np.float32) * 114.
  442. else:
  443. padded_img = np.ones(self.input_size, np.float32) * 114.
  444. # resize
  445. if self.keep_ratio:
  446. orig_h, orig_w = image.shape[:2]
  447. r = min(self.input_size[0] / orig_h, self.input_size[1] / orig_w)
  448. resize_size = (int(orig_w * r), int(orig_h * r))
  449. if r != 1:
  450. resized_img = cv2.resize(image, resize_size, interpolation=cv2.INTER_LINEAR)
  451. else:
  452. resized_img = image
  453. # padding
  454. padded_img[:resized_img.shape[0], :resized_img.shape[1]] = resized_img
  455. # [H, W, C] -> [C, H, W]
  456. padded_img = padded_img.transpose(swap)
  457. padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) / 255.
  458. return padded_img, r
  459. else:
  460. orig_h, orig_w = image.shape[:2]
  461. r = np.array([self.input_size[0] / orig_w, self.input_size[1] / orig_w])
  462. if [orig_h, orig_w] == self.input_size:
  463. resized_img = image
  464. else:
  465. resized_img = cv2.resize(image, self.input_size, interpolation=cv2.INTER_LINEAR)
  466. return resized_img, r
## Post-processor
  468. class PostProcessor(object):
  469. def __init__(self, num_classes, conf_thresh=0.15, nms_thresh=0.5):
  470. self.num_classes = num_classes
  471. self.conf_thresh = conf_thresh
  472. self.nms_thresh = nms_thresh
  473. def __call__(self, predictions):
  474. """
  475. Input:
  476. predictions: (ndarray) [n_anchors_all, 4+1+C]
  477. """
  478. bboxes = predictions[..., :4]
  479. scores = predictions[..., 4:]
  480. # scores & labels
  481. labels = np.argmax(scores, axis=1) # [M,]
  482. scores = scores[(np.arange(scores.shape[0]), labels)] # [M,]
  483. # thresh
  484. keep = np.where(scores > self.conf_thresh)
  485. scores = scores[keep]
  486. labels = labels[keep]
  487. bboxes = bboxes[keep]
  488. # nms
  489. scores, labels, bboxes = multiclass_nms(
  490. scores, labels, bboxes, self.nms_thresh, self.num_classes, True)
  491. return bboxes, scores, labels