# misc.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler

import cv2
import math
import time
import datetime
import numpy as np
from copy import deepcopy
from thop import profile
from collections import defaultdict, deque

from .distributed_utils import is_dist_avail_and_initialized

# ---------------------------- Train tools ----------------------------
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """
    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)

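# Illustrative usage sketch (assumption, not part of the original module):
# SmoothedValue keeps a rolling window plus running totals, so `median`/`avg`
# reflect recent iterations while `global_avg` covers the whole run.
def _example_smoothed_value():
    meter = SmoothedValue(window_size=20)
    for loss in (0.9, 0.7, 0.8):
        meter.update(loss)
    print(meter)  # printed with the default fmt "{median:.4f} ({global_avg:.4f})"
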
class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        if torch.cuda.is_available():
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}',
                'max mem: {memory:.0f}'
            ])
        else:
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}'
            ])
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))

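# Illustrative usage sketch (assumption): in a typical training loop the logger
# wraps the dataloader and prints smoothed meters every `print_freq` iterations.
# `model` and `criterion` are placeholders for the user's own modules.
def _example_metric_logger(dataloader, model, criterion):
    logger = MetricLogger(delimiter="  ")
    for images, targets in logger.log_every(dataloader, print_freq=10, header='Epoch: [0]'):
        loss = criterion(model(images), targets)
        logger.update(loss=loss)
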
# ---------------------------- For Dataset ----------------------------
## build dataloader
def build_dataloader(args, dataset, batch_size, collate_fn=None):
    # distributed
    if args.distributed:
        sampler = DistributedSampler(dataset)
    else:
        sampler = torch.utils.data.RandomSampler(dataset)

    batch_sampler_train = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)
    dataloader = DataLoader(dataset, batch_sampler=batch_sampler_train,
                            collate_fn=collate_fn, num_workers=args.num_workers, pin_memory=True)

    return dataloader

## collate_fn for dataloader
class CollateFunc(object):
    def __call__(self, batch):
        targets = []
        images = []

        for sample in batch:
            image = sample[0]
            target = sample[1]

            images.append(image)
            targets.append(target)

        images = torch.stack(images, 0)  # [B, C, H, W]

        return images, targets

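# Illustrative usage sketch (assumption): `args` only needs the two attributes
# read above (`distributed`, `num_workers`); any dataset that returns
# (image_tensor, target) pairs works with CollateFunc.
def _example_build_dataloader(dataset):
    from types import SimpleNamespace
    args = SimpleNamespace(distributed=False, num_workers=4)
    return build_dataloader(args, dataset, batch_size=16, collate_fn=CollateFunc())
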
# ---------------------------- For Loss ----------------------------
## FocalLoss
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        num_boxes: Normalization factor for the final loss (typically the number
                   of target boxes).
        alpha: (optional) Weighting factor in range (0, 1) to balance
               positive vs negative examples. Default = 0.25; a negative value
               disables the weighting.
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
    Returns:
        Loss tensor
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss.mean(1).sum() / num_boxes

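# Illustrative usage sketch (assumption): logits and targets are shaped
# [num_predictions, num_classes]; `num_boxes` normalizes by the number of GT boxes.
def _example_focal_loss():
    logits = torch.randn(100, 80)   # raw class logits
    targets = torch.zeros(100, 80)  # binary one-hot labels
    targets[0, 3] = 1.0
    return sigmoid_focal_loss(logits, targets, num_boxes=1)
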
## Varifocal Loss
def varifocal_loss_with_logits(pred_logits,
                               gt_score,
                               label,
                               normalizer=1.0,
                               alpha=0.75,
                               gamma=2.0):
    pred_score = pred_logits.sigmoid()
    weight = alpha * pred_score.pow(gamma) * (1 - label) + gt_score * label
    loss = F.binary_cross_entropy_with_logits(
        pred_logits, gt_score, weight=weight, reduction='none')

    return loss.mean(1).sum() / normalizer

## InverseSigmoid
def inverse_sigmoid(x, eps=1e-5):
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)

    return torch.log(x1 / x2)

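# Illustrative sanity check (assumption): inverse_sigmoid is the (clamped) logit
# function, so composing it with sigmoid approximately recovers the input.
def _example_inverse_sigmoid():
    x = torch.tensor([-3.0, 0.0, 3.0])
    recovered = inverse_sigmoid(torch.sigmoid(x))
    print(torch.allclose(recovered, x, atol=1e-3))  # True
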
# ---------------------------- For Model ----------------------------
## fuse Conv & BN layer
def fuse_conv_bn(module):
    """Recursively fuse conv and bn in a module.
    During inference, a batch norm layer only applies a fixed per-channel affine
    transform based on its running mean and variance, so it can be folded into
    the preceding conv layer to save computation and simplify the network
    structure.
    Args:
        module (nn.Module): Module to be fused.
    Returns:
        nn.Module: Fused module.
    """
    last_conv = None
    last_conv_name = None

    def _fuse_conv_bn(conv, bn):
        """Fuse conv and bn into one module.
        Args:
            conv (nn.Module): Conv to be fused.
            bn (nn.Module): BN to be fused.
        Returns:
            nn.Module: Fused module.
        """
        conv_w = conv.weight
        conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
            bn.running_mean)

        factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        conv.weight = nn.Parameter(conv_w *
                                   factor.reshape([conv.out_channels, 1, 1, 1]))
        conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias)
        return conv

    for name, child in module.named_children():
        if isinstance(child,
                      (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)):
            if last_conv is None:  # only fuse BN that is after Conv
                continue
            fused_conv = _fuse_conv_bn(last_conv, child)
            module._modules[last_conv_name] = fused_conv
            # To reduce changes, set BN as Identity instead of deleting it.
            module._modules[name] = nn.Identity()
            last_conv = None
        elif isinstance(child, nn.Conv2d):
            last_conv = child
            last_conv_name = name
        else:
            fuse_conv_bn(child)
    return module

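# Illustrative check (assumption): after fusion, the fused conv should reproduce
# the Conv->BN output in eval mode up to floating-point error.
def _example_fuse_conv_bn():
    m = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8)).eval()
    x = torch.randn(1, 3, 32, 32)
    y_ref = m(x)
    y_fused = fuse_conv_bn(deepcopy(m))(x)
    print(torch.allclose(y_ref, y_fused, atol=1e-5))  # True
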
## replace module
def replace_module(module, replaced_module_type, new_module_type, replace_func=None) -> nn.Module:
    """
    Replace the given module type in a module with a new type. Mostly used for deployment.
    Args:
        module (nn.Module): model to apply the replace operation on.
        replaced_module_type (Type): module type to be replaced.
        new_module_type (Type): module type to replace with.
        replace_func (function): python function describing the replace logic. Default value: None.
    Returns:
        model (nn.Module): module in which the replacement has been applied.
    """
    def default_replace_func(replaced_module_type, new_module_type):
        return new_module_type()

    if replace_func is None:
        replace_func = default_replace_func

    model = module
    if isinstance(module, replaced_module_type):
        model = replace_func(replaced_module_type, new_module_type)
    else:  # recursively replace
        for name, child in module.named_children():
            new_child = replace_module(child, replaced_module_type, new_module_type, replace_func)
            if new_child is not child:  # child is already replaced
                model.add_module(name, new_child)

    return model

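# Illustrative usage sketch (assumption): before export, swap every nn.SiLU for
# the export-friendly SiLU module defined later in this file.
def _example_replace_silu(model):
    return replace_module(model, nn.SiLU, SiLU)
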
## compute FLOPs & Parameters
def compute_flops(model, img_size, device):
    x = torch.randn(1, 3, img_size, img_size).to(device)
    print('==============================')
    flops, params = profile(model, inputs=(x, ), verbose=False)
    # thop counts multiply-accumulates, so multiply by 2 to report FLOPs.
    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
    print('Params : {:.2f} M'.format(params / 1e6))

## load trained weight
def load_weight(model, path_to_ckpt, fuse_cbn=False, fuse_rep_conv=False):
    # check ckpt file
    if path_to_ckpt is None:
        print('no weight file ...')
    else:
        checkpoint = torch.load(path_to_ckpt, map_location='cpu')
        print('--------------------------------------')
        print('Best model info:')
        print('Epoch: {}'.format(checkpoint["epoch"]))
        print('mAP: {}'.format(checkpoint["mAP"]))
        print('--------------------------------------')
        checkpoint_state_dict = checkpoint["model"]
        model.load_state_dict(checkpoint_state_dict)

        print('Finished loading model!')

    # fuse rep conv
    if fuse_rep_conv:
        print("Fusing RepConv ...")
        for m in model.modules():
            if hasattr(m, 'fuse_convs'):
                m.fuse_convs()

    # fuse conv & bn
    if fuse_cbn:
        print('Fusing Conv & BN ...')
        model = fuse_conv_bn(model)

    return model

## Model EMA
class ModelEMA(object):
    def __init__(self, model, ema_decay=0.9999, ema_tau=2000, resume=None):
        # Create EMA
        self.ema = deepcopy(self.de_parallel(model)).eval()  # FP32 EMA
        self.updates = 0  # number of EMA updates
        self.decay = lambda x: ema_decay * (1 - math.exp(-x / ema_tau))  # decay exponential ramp (to help early epochs)
        for p in self.ema.parameters():
            p.requires_grad_(False)

        if resume is not None and resume.lower() != "none":
            self.load_resume(resume)

        print("Initialize ModelEMA's updates: {}".format(self.updates))

    def load_resume(self, resume):
        checkpoint = torch.load(resume)
        if 'ema_updates' in checkpoint.keys():
            print('--Load ModelEMA updates from the checkpoint: ', resume)
            # checkpoint state dict
            self.updates = checkpoint.pop("ema_updates")

    def is_parallel(self, model):
        # Returns True if model is of type DP or DDP
        return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)

    def de_parallel(self, model):
        # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
        return model.module if self.is_parallel(model) else model

    def copy_attr(self, a, b, include=(), exclude=()):
        # Copy attributes from b to a, options to only include [...] and to exclude [...]
        for k, v in b.__dict__.items():
            if (len(include) and k not in include) or k.startswith('_') or k in exclude:
                continue
            else:
                setattr(a, k, v)

    def update(self, model):
        # Update EMA parameters
        self.updates += 1
        d = self.decay(self.updates)

        msd = self.de_parallel(model).state_dict()  # model state_dict
        for k, v in self.ema.state_dict().items():
            if v.dtype.is_floating_point:  # true for FP16 and FP32
                v *= d
                v += (1 - d) * msd[k].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        # Update EMA attributes
        self.copy_attr(self.ema, model, include, exclude)

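# Illustrative usage sketch (assumption): the EMA shadow model is updated after
# every optimizer step, and the smoothed weights are what get evaluated/exported.
# `model`, `optimizer`, `dataloader`, and `criterion` are placeholders.
def _example_model_ema(model, optimizer, dataloader, criterion):
    ema = ModelEMA(model)
    for images, targets in dataloader:
        loss = criterion(model(images), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.update(model)
    return ema.ema  # the smoothed model
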
## SiLU
class SiLU(nn.Module):
    """export-friendly version of nn.SiLU()"""
    @staticmethod
    def forward(x):
        return x * torch.sigmoid(x)

# ---------------------------- NMS ----------------------------
## basic NMS
def nms(bboxes, scores, nms_thresh):
    """Pure Python NMS."""
    x1 = bboxes[:, 0]  # xmin
    y1 = bboxes[:, 1]  # ymin
    x2 = bboxes[:, 2]  # xmax
    y2 = bboxes[:, 3]  # ymax

    areas = (x2 - x1) * (y2 - y1)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        # compute iou
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(1e-10, xx2 - xx1)
        h = np.maximum(1e-10, yy2 - yy1)
        inter = w * h
        iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-14)
        # keep only the boxes whose IoU with the current box is below the threshold
        inds = np.where(iou <= nms_thresh)[0]
        order = order[inds + 1]

    return keep

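# Illustrative check (assumption): two heavily overlapping boxes collapse to the
# higher-scoring one, while a distant box survives.
def _example_nms():
    boxes = np.array([[0., 0., 10., 10.],
                      [1., 1., 10., 10.],
                      [50., 50., 60., 60.]])
    scores = np.array([0.9, 0.8, 0.7])
    print(nms(boxes, scores, nms_thresh=0.5))  # keeps boxes 0 and 2
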
## class-agnostic NMS
def multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh):
    # nms
    keep = nms(bboxes, scores, nms_thresh)

    scores = scores[keep]
    labels = labels[keep]
    bboxes = bboxes[keep]

    return scores, labels, bboxes

## class-aware NMS
def multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes):
    # nms
    keep = np.zeros(len(bboxes), dtype=np.int32)
    for i in range(num_classes):
        inds = np.where(labels == i)[0]
        if len(inds) == 0:
            continue
        c_bboxes = bboxes[inds]
        c_scores = scores[inds]
        c_keep = nms(c_bboxes, c_scores, nms_thresh)
        keep[inds[c_keep]] = 1

    keep = np.where(keep > 0)
    scores = scores[keep]
    labels = labels[keep]
    bboxes = bboxes[keep]

    return scores, labels, bboxes

## multi-class NMS
def multiclass_nms(scores, labels, bboxes, nms_thresh, num_classes, class_agnostic=False):
    if class_agnostic:
        return multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh)
    else:
        return multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes)

# ---------------------------- Processor for Deployment ----------------------------
## Pre-processor
class PreProcessor(object):
    def __init__(self, img_size, keep_ratio=True):
        self.img_size = img_size
        self.keep_ratio = keep_ratio
        self.input_size = [img_size, img_size]

    def __call__(self, image, swap=(2, 0, 1)):
        """
        Input:
            image: (ndarray) [H, W, 3] or [H, W]
            swap: (tuple) axis order used to convert [H, W, C] to [C, H, W]
        """
        if len(image.shape) == 3:
            padded_img = np.ones((self.input_size[0], self.input_size[1], 3), np.float32) * 114.
        else:
            padded_img = np.ones(self.input_size, np.float32) * 114.

        # resize
        if self.keep_ratio:
            orig_h, orig_w = image.shape[:2]
            r = min(self.input_size[0] / orig_h, self.input_size[1] / orig_w)
            resize_size = (int(orig_w * r), int(orig_h * r))
            if r != 1:
                resized_img = cv2.resize(image, resize_size, interpolation=cv2.INTER_LINEAR)
            else:
                resized_img = image

            # padding
            padded_img[:resized_img.shape[0], :resized_img.shape[1]] = resized_img

            # [H, W, C] -> [C, H, W]
            padded_img = padded_img.transpose(swap)
            padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) / 255.

            return padded_img, r
        else:
            orig_h, orig_w = image.shape[:2]
            r = np.array([self.input_size[0] / orig_h, self.input_size[1] / orig_w])
            if [orig_h, orig_w] == self.input_size:
                resized_img = image
            else:
                resized_img = cv2.resize(image, self.input_size, interpolation=cv2.INTER_LINEAR)

            return resized_img, r

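# Illustrative usage sketch (assumption): resize-with-padding to a square input
# and convert to a normalized CHW float array for the deployed model.
def _example_preprocess():
    preproc = PreProcessor(img_size=640, keep_ratio=True)
    image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
    x, ratio = preproc(image)
    print(x.shape, ratio)  # (3, 640, 640) and the resize ratio used to map boxes back
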
## Post-processor
class PostProcessor(object):
    def __init__(self, num_classes, conf_thresh=0.15, nms_thresh=0.5):
        self.num_classes = num_classes
        self.conf_thresh = conf_thresh
        self.nms_thresh = nms_thresh

    def __call__(self, predictions):
        """
        Input:
            predictions: (ndarray) [n_anchors_all, 4+1+C]
        """
        bboxes = predictions[..., :4]
        scores = predictions[..., 4:]

        # scores & labels
        labels = np.argmax(scores, axis=1)                     # [M,]
        scores = scores[(np.arange(scores.shape[0]), labels)]  # [M,]

        # thresh
        keep = np.where(scores > self.conf_thresh)
        scores = scores[keep]
        labels = labels[keep]
        bboxes = bboxes[keep]

        # nms
        scores, labels, bboxes = multiclass_nms(
            scores, labels, bboxes, self.nms_thresh, self.num_classes, True)

        return bboxes, scores, labels

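# Illustrative end-to-end deployment sketch (assumption): `session_run` stands in
# for whatever inference backend is used (e.g. an ONNX Runtime session) and is
# assumed to return [n_anchors_all, ...] predictions, box coordinates followed by
# the class scores that PostProcessor consumes, in the padded-image scale.
def _example_deploy_pipeline(image, session_run, num_classes=80, img_size=640):
    preproc = PreProcessor(img_size=img_size, keep_ratio=True)
    postproc = PostProcessor(num_classes=num_classes, conf_thresh=0.15, nms_thresh=0.5)
    x, ratio = preproc(image)
    predictions = session_run(x[None])  # add the batch dimension
    bboxes, scores, labels = postproc(predictions)
    bboxes /= ratio                     # map boxes back to the original image scale
    return bboxes, scores, labels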