misc.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import torch.distributed as dist
  5. from torch.utils.data import DataLoader, DistributedSampler
  6. import cv2
  7. import math
  8. import time
  9. import datetime
  10. import numpy as np
  11. from copy import deepcopy
  12. from thop import profile
  13. from collections import defaultdict, deque
  14. from .distributed_utils import is_dist_avail_and_initialized
  15. # ---------------------------- Train tools ----------------------------
  class SmoothedValue(object):
      """Track a series of values and provide access to smoothed values over a
      window or the global series average.

      ``median`` / ``avg`` / ``max`` / ``value`` only look at the last
      ``window_size`` values; ``global_avg`` uses the running ``total`` /
      ``count`` accumulated over every call to ``update``.
      """

      def __init__(self, window_size=20, fmt=None):
          if fmt is None:
              fmt = "{median:.4f} ({global_avg:.4f})"
          self.deque = deque(maxlen=window_size)  # sliding window of raw values
          self.total = 0.0  # sum of value * n over all updates
          self.count = 0  # sum of n over all updates
          self.fmt = fmt

      def update(self, value, n=1):
          """Record ``value``; ``n`` only weights it in the global average."""
          self.deque.append(value)
          self.count += n
          self.total += value * n

      def synchronize_between_processes(self):
          """
          Sum ``count`` and ``total`` across all ranks (no-op when not distributed).

          Warning: does not synchronize the deque!
          """
          if not is_dist_avail_and_initialized():
              return
          # Use CUDA when present (required by the NCCL backend); otherwise fall
          # back to CPU so CPU-only process groups (e.g. gloo) do not crash here.
          device = 'cuda' if torch.cuda.is_available() else 'cpu'
          t = torch.tensor([self.count, self.total], dtype=torch.float64, device=device)
          dist.barrier()
          dist.all_reduce(t)
          t = t.tolist()
          self.count = int(t[0])
          self.total = t[1]

      @property
      def median(self):
          """Median of the values currently in the window."""
          d = torch.tensor(list(self.deque))
          return d.median().item()

      @property
      def avg(self):
          """Mean of the values currently in the window."""
          d = torch.tensor(list(self.deque), dtype=torch.float32)
          return d.mean().item()

      @property
      def global_avg(self):
          """Weighted mean over every recorded value (raises if nothing was recorded)."""
          return self.total / self.count

      @property
      def max(self):
          """Largest value currently in the window."""
          return max(self.deque)

      @property
      def value(self):
          """Most recently recorded value."""
          return self.deque[-1]

      def __str__(self):
          return self.fmt.format(
              median=self.median,
              avg=self.avg,
              global_avg=self.global_avg,
              max=self.max,
              value=self.value)
  class MetricLogger(object):
      """Group named meters (SmoothedValue by default) and print training progress."""

      def __init__(self, delimiter="\t"):
          # Meter names not seen before are created on first use as SmoothedValue().
          self.meters = defaultdict(SmoothedValue)
          self.delimiter = delimiter

      def update(self, **kwargs):
          """Push one scalar per keyword argument into the meter of the same name."""
          for k, v in kwargs.items():
              if isinstance(v, torch.Tensor):
                  v = v.item()
              assert isinstance(v, (float, int))
              self.meters[k].update(v)

      def __getattr__(self, attr):
          """Expose meters as attributes, e.g. ``logger.loss``."""
          # Only reached when normal lookup fails. Read `meters` through __dict__:
          # on a partially-constructed instance (copy / unpickle) `self.meters`
          # would re-enter __getattr__ and recurse until RecursionError.
          instance_dict = self.__dict__
          meters = instance_dict.get('meters')
          if meters is not None and attr in meters:
              return meters[attr]
          if attr in instance_dict:
              return instance_dict[attr]
          raise AttributeError("'{}' object has no attribute '{}'".format(
              type(self).__name__, attr))

      def __str__(self):
          return self.delimiter.join(
              "{}: {}".format(name, str(meter)) for name, meter in self.meters.items()
          )

      def synchronize_between_processes(self):
          """Reduce every meter's count/total across ranks."""
          for meter in self.meters.values():
              meter.synchronize_between_processes()

      def add_meter(self, name, meter):
          """Register (or replace) the meter stored under ``name``."""
          self.meters[name] = meter

      def log_every(self, iterable, print_freq, header=None):
          """Yield the items of ``iterable`` while timing each step.

          A progress line is printed every ``print_freq`` iterations and on the
          last one, followed by a total-time summary once iteration finishes.
          ``iterable`` must support ``len()``.
          """
          header = header or ''
          num_iters = len(iterable)
          use_cuda = torch.cuda.is_available()
          start_time = time.time()
          end = time.time()
          iter_time = SmoothedValue(fmt='{avg:.4f}')  # full step time (data + work)
          data_time = SmoothedValue(fmt='{avg:.4f}')  # time spent waiting for the item
          space_fmt = ':' + str(len(str(num_iters))) + 'd'
          fields = [
              header,
              '[{0' + space_fmt + '}/{1}]',
              'eta: {eta}',
              '{meters}',
              'time: {time}',
              'data: {data}',
          ]
          if use_cuda:
              fields.append('max mem: {memory:.0f}')
          log_msg = self.delimiter.join(fields)
          MB = 1024.0 * 1024.0
          for i, obj in enumerate(iterable):
              data_time.update(time.time() - end)
              yield obj
              iter_time.update(time.time() - end)
              if i % print_freq == 0 or i == num_iters - 1:
                  # Iteration i has just finished, so num_iters - i - 1 remain.
                  eta_seconds = iter_time.global_avg * (num_iters - i - 1)
                  eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                  fmt_kwargs = dict(
                      eta=eta_string,
                      meters=str(self),
                      time=str(iter_time),
                      data=str(data_time),
                  )
                  if use_cuda:
                      fmt_kwargs['memory'] = torch.cuda.max_memory_allocated() / MB
                  print(log_msg.format(i, num_iters, **fmt_kwargs))
              end = time.time()
          total_time = time.time() - start_time
          total_time_str = str(datetime.timedelta(seconds=int(total_time)))
          # max(..., 1) keeps an empty iterable from raising ZeroDivisionError.
          print('{} Total time: {} ({:.4f} s / it)'.format(
              header, total_time_str, total_time / max(num_iters, 1)))
  149. # ---------------------------- For Dataset ----------------------------
  150. ## build dataloader
  def build_dataloader(args, dataset, batch_size, collate_fn=None):
      """Create a shuffling training DataLoader that drops the last incomplete batch.

      ``args.distributed`` selects DistributedSampler over RandomSampler and
      ``args.num_workers`` sets the number of loader workers.
      """
      if args.distributed:
          index_sampler = DistributedSampler(dataset)
      else:
          index_sampler = torch.utils.data.RandomSampler(dataset)
      batch_sampler = torch.utils.data.BatchSampler(index_sampler, batch_size, drop_last=True)
      return DataLoader(
          dataset,
          batch_sampler=batch_sampler,
          collate_fn=collate_fn,
          num_workers=args.num_workers,
          pin_memory=True,
      )
  161. ## collate_fn for dataloader
  class CollateFunc(object):
      """Collate ``(image, target)`` samples: images are stacked, targets kept as a list."""

      def __call__(self, batch):
          # Each sample is indexable: element 0 is the image tensor, element 1 its target.
          images = torch.stack([sample[0] for sample in batch], 0)  # [B, C, H, W]
          targets = [sample[1] for sample in batch]
          return images, targets
  173. # ---------------------------- For Loss ----------------------------
  174. ## FocalLoss
  def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
      """
      Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
      Args:
          inputs: A float tensor of logits with at least 2 dimensions
                  (the loss is averaged over dim 1).
                  The predictions for each example.
          targets: A float tensor with the same shape as inputs. Stores the binary
                   classification label for each element in inputs
                  (0 for the negative class and 1 for the positive class).
          num_boxes: Normaliser; the summed loss is divided by it.
          alpha: (optional) Weighting factor in range (0,1) to balance
                  positive vs negative examples. Default = 0.25;
                  pass a negative value to disable the weighting.
          gamma: Exponent of the modulating factor (1 - p_t) to
                 balance easy vs hard examples. Default = 2.
      Returns:
          Scalar loss tensor: element losses averaged over dim 1, summed over
          the remaining dims, divided by ``num_boxes``.
      """
      prob = inputs.sigmoid()
      ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
      # p_t is the predicted probability of the ground-truth class.
      p_t = prob * targets + (1 - prob) * (1 - targets)
      loss = ce_loss * ((1 - p_t) ** gamma)
      if alpha >= 0:
          alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
          loss = alpha_t * loss
      return loss.mean(1).sum() / num_boxes
  199. ## InverseSigmoid
  def inverse_sigmoid(x, eps=1e-5):
      """Return logit(x) = log(x / (1 - x)), with x clamped to [0, 1] and both
      numerator and denominator floored at ``eps`` to keep the result finite."""
      x = x.clamp(0, 1)
      numerator = x.clamp(min=eps)
      denominator = (1 - x).clamp(min=eps)
      return (numerator / denominator).log()
  205. # ---------------------------- For Model ----------------------------
  206. ## fuse Conv & BN layer
  def fuse_conv_bn(module):
      """Recursively fold BatchNorm layers into the preceding Conv2d for inference.

      In eval mode a BatchNorm only applies its running mean/var (and affine
      weight/bias), which is a per-channel affine map that can be merged into
      the convolution's weight and bias. Within each parent, every BatchNorm
      child is paired with the most recent Conv2d sibling registered before it.
      Other siblings are recursed into but do not reset that pairing.
      NOTE(review): this relies on registration order matching forward order.
      Fused BNs are replaced with nn.Identity. ``module`` is modified in place.

      Args:
          module (nn.Module): Module to be fused.
      Returns:
          nn.Module: The same (fused) module.
      """
      last_conv = None
      last_conv_name = None

      def _fuse_conv_bn(conv, bn):
          """Fold ``bn`` into ``conv`` (in place) and return ``conv``."""
          conv_w = conv.weight
          conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
              bn.running_mean)
          # BatchNorm with affine=False has no weight/bias: treat as gamma=1, beta=0.
          bn_w = bn.weight if bn.weight is not None else torch.ones_like(bn.running_mean)
          bn_b = bn.bias if bn.bias is not None else torch.zeros_like(bn.running_mean)
          factor = bn_w / torch.sqrt(bn.running_var + bn.eps)
          conv.weight = nn.Parameter(conv_w *
                                     factor.reshape([conv.out_channels, 1, 1, 1]))
          conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn_b)
          return conv

      for name, child in module.named_children():
          if isinstance(child,
                        (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)):
              if last_conv is None:  # only fuse BN that is after Conv
                  continue
              fused_conv = _fuse_conv_bn(last_conv, child)
              module._modules[last_conv_name] = fused_conv
              # To reduce changes, set BN as Identity instead of deleting it.
              module._modules[name] = nn.Identity()
              last_conv = None
          elif isinstance(child, nn.Conv2d):
              last_conv = child
              last_conv_name = name
          else:
              fuse_conv_bn(child)
      return module
  252. ## replace module
  def replace_module(module, replaced_module_type, new_module_type, replace_func=None) -> nn.Module:
      """
      Replace every submodule of a given type with a new one; mostly used for deployment.

      Args:
          module (nn.Module): model to apply the replace operation to.
          replaced_module_type (Type): module type to be replaced.
          new_module_type (Type): module type to substitute.
          replace_func (function): ``replace_func(replaced_module_type, new_module_type)``
              returns the replacement instance. Default: ``new_module_type()``.
              The same function is used at every nesting depth.
      Returns:
          model (nn.Module): the replacement if ``module`` itself matched, otherwise
          ``module`` with matching descendants replaced in place.
      """
      def default_replace_func(replaced_module_type, new_module_type):
          return new_module_type()

      if replace_func is None:
          replace_func = default_replace_func

      if isinstance(module, replaced_module_type):
          return replace_func(replaced_module_type, new_module_type)

      # Recursively replace, forwarding replace_func so nested modules use the
      # same replacement logic as the top level.
      for name, child in module.named_children():
          new_child = replace_module(child, replaced_module_type, new_module_type, replace_func)
          if new_child is not child:  # child was replaced
              module.add_module(name, new_child)
      return module
  277. ## compute FLOPs & Parameters
  def compute_flops(model, img_size, device):
      """Profile ``model`` on one random 3-channel ``img_size`` x ``img_size`` input
      and print its GFLOPs and parameter count (in millions)."""
      dummy_input = torch.randn(1, 3, img_size, img_size).to(device)
      print('==============================')
      ops, n_params = profile(model, inputs=(dummy_input, ), verbose=False)
      # The x2 presumably converts the multiply-accumulates reported by thop
      # into FLOPs -- confirm against the thop version in use.
      gflops = ops / 1e9 * 2
      print('GFLOPs : {:.2f}'.format(gflops))
      print('Params : {:.2f} M'.format(n_params / 1e6))
  284. ## load trained weight
  def load_weight(model, path_to_ckpt, fuse_cbn=False):
      """Load a trained checkpoint into ``model`` and optionally fuse Conv+BN.

      Args:
          model (nn.Module): model whose state dict is overwritten.
          path_to_ckpt (str | None): checkpoint path; ``None`` skips loading.
              The checkpoint must hold the weights under ``"model"``. The
              ``"epoch"`` and ``"mAP"`` entries are optional and only printed.
          fuse_cbn (bool): fold BatchNorm layers into preceding convs afterwards.
      Returns:
          nn.Module: the loaded (and possibly fused) model.
      """
      if path_to_ckpt is None:
          print('no weight file ...')
      else:
          # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
          checkpoint = torch.load(path_to_ckpt, map_location='cpu')
          print('--------------------------------------')
          print('Best model infor:')
          # Metadata is informational; a checkpoint without it still loads.
          print('Epoch: {}'.format(checkpoint.get("epoch")))
          print('mAP: {}'.format(checkpoint.get("mAP")))
          print('--------------------------------------')
          model.load_state_dict(checkpoint["model"])
          print('Finished loading model!')

      if fuse_cbn:
          print('Fusing Conv & BN ...')
          model = fuse_conv_bn(model)

      return model
  304. ## Model EMA
  class ModelEMA(object):
      """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
      Keeps a moving average of everything in the model state_dict (parameters and buffers)
      For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

      ``cfg`` must provide ``'ema_decay'`` (asymptotic decay) and ``'ema_tau'``
      (ramp time constant, in updates).
      """

      def __init__(self, cfg, model, updates=0):
          # Frozen, eval-mode deep copy of the unwrapped model that holds the averages.
          self.ema = deepcopy(self.de_parallel(model)).eval()
          self.updates = updates  # number of EMA updates performed so far
          # Decay ramps from 0 toward cfg['ema_decay'] so early updates track the
          # live model closely (helps early epochs).
          self.decay = lambda x: cfg['ema_decay'] * (1 - math.exp(-x / cfg['ema_tau']))
          for p in self.ema.parameters():
              p.requires_grad_(False)

      def is_parallel(self, model):
          # True if model is a DataParallel / DistributedDataParallel wrapper,
          # including subclasses of either.
          return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))

      def de_parallel(self, model):
          # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
          return model.module if self.is_parallel(model) else model

      def copy_attr(self, a, b, include=(), exclude=()):
          # Copy attributes from b to a, options to only include [...] and to exclude [...];
          # private (underscore-prefixed) attributes are never copied.
          for k, v in b.__dict__.items():
              if (len(include) and k not in include) or k.startswith('_') or k in exclude:
                  continue
              else:
                  setattr(a, k, v)

      def update(self, model):
          # In place, for every floating-point state entry: ema = d * ema + (1 - d) * model.
          # Non-floating entries are left untouched.
          self.updates += 1
          d = self.decay(self.updates)

          msd = self.de_parallel(model).state_dict()  # model state_dict
          for k, v in self.ema.state_dict().items():
              if v.dtype.is_floating_point:  # true for FP16 and FP32
                  v *= d
                  v += (1 - d) * msd[k].detach()

      def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
          # Update EMA attributes
          self.copy_attr(self.ema, model, include, exclude)
  343. ## SiLU
  class SiLU(nn.Module):
      """SiLU / swish activation x * sigmoid(x), spelled out with primitive ops so
      it exports cleanly (an export-friendly stand-in for nn.SiLU())."""

      @staticmethod
      def forward(x):
          gate = torch.sigmoid(x)
          return x * gate
  349. # ---------------------------- NMS ----------------------------
  350. ## basic NMS
  def nms(bboxes, scores, nms_thresh):
      """Greedy non-maximum suppression on NumPy arrays.

      Args:
          bboxes: (ndarray) [N, 4] boxes as (xmin, ymin, xmax, ymax).
          scores: (ndarray) [N,] confidence per box.
          nms_thresh: a candidate is dropped when its IoU with an already-kept
              box is greater than this value.
      Returns:
          list of kept indices into ``bboxes``, highest score first.
      """
      xmin, ymin, xmax, ymax = (bboxes[:, k] for k in range(4))
      areas = (xmax - xmin) * (ymax - ymin)
      order = scores.argsort()[::-1]  # indices by descending score

      keep = []
      while order.size > 0:
          best, rest = order[0], order[1:]
          keep.append(best)
          # Intersection of `best` with every remaining candidate (floored at 1e-10).
          inter_w = np.maximum(1e-10, np.minimum(xmax[best], xmax[rest]) - np.maximum(xmin[best], xmin[rest]))
          inter_h = np.maximum(1e-10, np.minimum(ymax[best], ymax[rest]) - np.maximum(ymin[best], ymin[rest]))
          inter = inter_w * inter_h
          iou = inter / (areas[best] + areas[rest] - inter + 1e-14)
          # Survivors: candidates overlapping `best` by at most nms_thresh.
          order = rest[iou <= nms_thresh]
      return keep
  376. ## class-agnostic NMS
  def multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh):
      """Run one NMS pass over all boxes, ignoring their labels.

      Returns (scores, labels, bboxes) restricted to the surviving boxes.
      """
      survivors = nms(bboxes, scores, nms_thresh)
      return scores[survivors], labels[survivors], bboxes[survivors]
  384. ## class-aware NMS
  def multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes):
      """Run NMS independently for each label in ``range(num_classes)``.

      Boxes of different labels never suppress each other; boxes whose label
      lies outside ``range(num_classes)`` are dropped. Returns
      (scores, labels, bboxes) for the survivors in their original order.
      """
      keep_mask = np.zeros(len(bboxes), dtype=bool)
      for cls_id in range(num_classes):
          cls_inds = np.flatnonzero(labels == cls_id)
          if cls_inds.size == 0:
              continue
          cls_keep = nms(bboxes[cls_inds], scores[cls_inds], nms_thresh)
          keep_mask[cls_inds[cls_keep]] = True
      selected = np.flatnonzero(keep_mask)
      return scores[selected], labels[selected], bboxes[selected]
  401. ## multi-class NMS
  def multiclass_nms(scores, labels, bboxes, nms_thresh, num_classes, class_agnostic=False):
      """Dispatch to per-class NMS (default) or a single class-agnostic pass."""
      if not class_agnostic:
          return multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes)
      return multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh)
  407. # ---------------------------- Processor for Deployment ----------------------------
  408. ## Pre-processor
  class PreProcessor(object):
      """Resize an image to the square ``img_size`` network input for deployment."""

      def __init__(self, img_size, keep_ratio=True):
          self.img_size = img_size
          self.keep_ratio = keep_ratio
          self.input_size = [img_size, img_size]  # [H, W]

      def __call__(self, image, swap=(2, 0, 1)):
          """
          Input:
              image: (ndarray) [H, W, 3] or [H, W]
              swap: axis permutation applied to the [H, W, C] padded image
                    (default gives [C, H, W]); only used when keep_ratio=True.
          Output:
              keep_ratio=True:  (padded_img, r) -- float32 image scaled to [0, 1],
                  letterboxed at the top-left over a value-114 background, and the
                  scalar resize ratio r. Grayscale input yields a single channel.
              keep_ratio=False: (resized_img, r) -- the image resized to input_size
                  and r = array([w_ratio, h_ratio]).
          """
          if len(image.shape) == 3:
              padded_img = np.ones((self.input_size[0], self.input_size[1], 3), np.float32) * 114.
          else:
              padded_img = np.ones(self.input_size, np.float32) * 114.

          if self.keep_ratio:
              # resize so the whole image fits, preserving aspect ratio
              orig_h, orig_w = image.shape[:2]
              r = min(self.input_size[0] / orig_h, self.input_size[1] / orig_w)
              resize_size = (int(orig_w * r), int(orig_h * r))  # cv2 dsize is (w, h)
              if r != 1:
                  resized_img = cv2.resize(image, resize_size, interpolation=cv2.INTER_LINEAR)
              else:
                  resized_img = image

              # padding
              padded_img[:resized_img.shape[0], :resized_img.shape[1]] = resized_img
              if padded_img.ndim == 2:
                  # Grayscale: add a channel axis so the 3-axis `swap` applies -> [1, H, W].
                  padded_img = padded_img[:, :, None]
              # [H, W, C] -> [C, H, W]
              padded_img = padded_img.transpose(swap)
              padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) / 255.

              return padded_img, r
          else:
              orig_h, orig_w = image.shape[:2]
              # Per-axis ratios [w_ratio, h_ratio]; input_size is [H, W].
              r = np.array([self.input_size[1] / orig_w, self.input_size[0] / orig_h])
              if [orig_h, orig_w] == self.input_size:
                  resized_img = image
              else:
                  resized_img = cv2.resize(image, (self.input_size[1], self.input_size[0]),
                                           interpolation=cv2.INTER_LINEAR)
              # NOTE(review): unlike the keep_ratio branch this returns the raw
              # [H, W, C] image without transpose or /255 -- confirm callers expect that.
              return resized_img, r
  447. ## Post-processor
  class PostProcessor(object):
      """Turn raw per-anchor predictions into thresholded, NMS-filtered detections."""

      def __init__(self, num_classes, conf_thresh=0.15, nms_thresh=0.5, class_agnostic=True):
          self.num_classes = num_classes
          self.conf_thresh = conf_thresh  # keep detections whose best score exceeds this
          self.nms_thresh = nms_thresh
          # True (the previous hard-coded behaviour) lets boxes of different
          # classes suppress each other; False runs NMS separately per class.
          self.class_agnostic = class_agnostic

      def __call__(self, predictions):
          """
          Input:
              predictions: (ndarray) [n_anchors_all, 4+C] -- columns 0:4 are the box
                  (consumed by nms() as xmin, ymin, xmax, ymax), columns 4: are the
                  per-class scores.
          Output:
              (bboxes [M, 4], scores [M,], labels [M,]) of the surviving detections.
          """
          bboxes = predictions[..., :4]
          scores = predictions[..., 4:]

          # scores & labels: best class per anchor and its score
          labels = np.argmax(scores, axis=1)  # [M,]
          scores = scores[(np.arange(scores.shape[0]), labels)]  # [M,]

          # thresh
          keep = np.where(scores > self.conf_thresh)
          scores = scores[keep]
          labels = labels[keep]
          bboxes = bboxes[keep]

          # nms
          scores, labels, bboxes = multiclass_nms(
              scores, labels, bboxes, self.nms_thresh, self.num_classes, self.class_agnostic)

          return bboxes, scores, labels