misc.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torch.utils.data import DataLoader, DistributedSampler
  5. import cv2
  6. import math
  7. import numpy as np
  8. from copy import deepcopy
  9. from thop import profile
  10. # ---------------------------- For Dataset ----------------------------
  11. ## build dataloader
  def build_dataloader(args, dataset, batch_size, collate_fn=None):
      """Build a training DataLoader that yields fixed-size batches.

      A ``DistributedSampler`` is used when ``args.distributed`` is truthy,
      otherwise a ``RandomSampler``. The trailing incomplete batch is dropped.

      Args:
          args: namespace providing ``distributed`` and ``num_workers``.
          dataset: the dataset to sample from.
          batch_size (int): number of samples per batch.
          collate_fn (callable, optional): forwarded to ``DataLoader``.

      Returns:
          DataLoader: a loader driven by a ``BatchSampler`` (pinned memory).
      """
      sampler = (DistributedSampler(dataset) if args.distributed
                 else torch.utils.data.RandomSampler(dataset))
      batches = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)
      return DataLoader(
          dataset,
          batch_sampler=batches,
          collate_fn=collate_fn,
          num_workers=args.num_workers,
          pin_memory=True,
      )
  22. ## collate_fn for dataloader
  class CollateFunc(object):
      """Collate ``(image, target, ...)`` samples into a batch.

      The first element of every sample is stacked along a new leading
      dimension; the second element is collected into a plain list in
      sample order. Any further elements of a sample are ignored.
      """

      def __call__(self, batch):
          pairs = [(sample[0], sample[1]) for sample in batch]
          images = torch.stack([img for img, _ in pairs], 0)  # [B, C, H, W]
          targets = [tgt for _, tgt in pairs]
          return images, targets
  34. # ---------------------------- For Loss ----------------------------
  35. ## FocalLoss
  def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
      """
      Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

      Args:
          inputs: A float tensor of raw logits with at least two dimensions;
                  the per-element loss is averaged over dim 1.
          targets: A float tensor with the same shape as inputs. Stores the binary
                   classification label for each element in inputs
                   (0 for the negative class and 1 for the positive class).
          num_boxes: Normalizer; the summed loss is divided by it.
          alpha: Weighting factor in range (0, 1) to balance positive vs
                 negative examples. Default = 0.25; pass a negative value to
                 disable the weighting.
          gamma: Exponent of the modulating factor (1 - p_t) to
                 balance easy vs hard examples. Default = 2.

      Returns:
          Scalar loss tensor equal to ``loss.mean(1).sum() / num_boxes``.
      """
      prob = inputs.sigmoid()
      ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
      # p_t is the predicted probability of the ground-truth class.
      p_t = prob * targets + (1 - prob) * (1 - targets)
      loss = ce_loss * ((1 - p_t) ** gamma)
      if alpha >= 0:
          alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
          loss = alpha_t * loss
      return loss.mean(1).sum() / num_boxes
  60. ## InverseSigmoid
  def inverse_sigmoid(x, eps=1e-5):
      """Numerically safe logit, the inverse of ``torch.sigmoid``.

      ``x`` is first clipped to [0, 1]; numerator and denominator of the
      odds ratio are then floored at ``eps`` so the result stays finite
      at the endpoints.
      """
      x = torch.clamp(x, min=0, max=1)
      numerator = torch.clamp(x, min=eps)
      denominator = torch.clamp(1 - x, min=eps)
      odds = numerator / denominator
      return odds.log()
  66. # ---------------------------- For Model ----------------------------
  67. ## fuse Conv & BN layer
  def fuse_conv_bn(module):
      """Recursively fuse conv and bn in a module, in place.

      During inference a batch-norm layer only applies a fixed per-channel
      affine transform computed from its running statistics, so it can be
      folded into the weight and bias of the preceding conv layer to save
      computation and simplify the network structure.

      A BN is fused only when it is the child registered directly after an
      ``nn.Conv2d`` (no other child in between) and it tracks running
      statistics. Fused BNs are replaced by ``nn.Identity`` so child names
      and module structure stay unchanged.

      Args:
          module (nn.Module): Module to be fused (modified in place).

      Returns:
          nn.Module: The fused module.
      """

      def _fuse_conv_bn(conv, bn):
          """Fold ``bn`` into ``conv``'s weight and bias and return ``conv``."""
          with torch.no_grad():
              conv_w = conv.weight
              conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
                  bn.running_mean)
              # BNs built with affine=False have no weight/bias: scale 1, shift 0.
              bn_w = bn.weight if bn.weight is not None else torch.ones_like(bn.running_mean)
              bn_b = bn.bias if bn.bias is not None else torch.zeros_like(bn.running_mean)
              factor = bn_w / torch.sqrt(bn.running_var + bn.eps)
              conv.weight = nn.Parameter(conv_w *
                                         factor.reshape([conv.out_channels, 1, 1, 1]))
              conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn_b)
          return conv

      last_conv = None
      last_conv_name = None
      for name, child in module.named_children():
          if isinstance(child,
                        (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)):
              # Only fuse a BN that directly follows a conv and has running stats
              # (without them the BN normalizes with batch statistics in eval).
              if last_conv is None or child.running_mean is None:
                  last_conv = None
                  continue
              module._modules[last_conv_name] = _fuse_conv_bn(last_conv, child)
              # To reduce changes, set BN as Identity instead of deleting it.
              module._modules[name] = nn.Identity()
              last_conv = None
          elif isinstance(child, nn.Conv2d):
              last_conv = child
              last_conv_name = name
          else:
              # Any other child breaks the Conv -> BN adjacency: folding a BN
              # across e.g. an activation would change the network's output.
              last_conv = None
              fuse_conv_bn(child)
      return module
  113. ## replace module
  def replace_module(module, replaced_module_type, new_module_type, replace_func=None) -> nn.Module:
      """
      Replace every submodule of a given type with a new module. Mostly used in deploy.

      Args:
          module (nn.Module): model to apply the replace operation to.
          replaced_module_type (Type): module type to be replaced.
          new_module_type (Type): type of the replacement module.
          replace_func (function): called as
              ``replace_func(replaced_module_type, new_module_type)`` to build
              each replacement, at every nesting level. Default value None,
              which constructs ``new_module_type()``.

      Returns:
          model (nn.Module): if ``module`` itself is an instance of
          ``replaced_module_type``, its replacement; otherwise ``module``,
          modified in place.
      """
      def default_replace_func(replaced_module_type, new_module_type):
          return new_module_type()

      if replace_func is None:
          replace_func = default_replace_func

      if isinstance(module, replaced_module_type):
          return replace_func(replaced_module_type, new_module_type)

      # Recursively replace, forwarding replace_func so custom replacement
      # logic also applies to nested submodules, not only the top level.
      for name, child in module.named_children():
          new_child = replace_module(child, replaced_module_type, new_module_type, replace_func)
          if new_child is not child:  # child was replaced
              module.add_module(name, new_child)
      return module
  138. ## compute FLOPs & Parameters
  def compute_flops(model, img_size, device):
      """Print FLOPs and parameter count of ``model`` for one square RGB input.

      Profiling is delegated to ``thop.profile`` on a random tensor of shape
      [1, 3, img_size, img_size] placed on ``device``. The count thop returns
      is doubled before printing (it is treated as multiply-accumulates).
      """
      dummy = torch.randn(1, 3, img_size, img_size).to(device)
      print('==============================')
      macs, n_params = profile(model, inputs=(dummy, ), verbose=False)
      gflops = macs / 1e9 * 2
      mparams = n_params / 1e6
      print('GFLOPs : {:.2f}'.format(gflops))
      print('Params : {:.2f} M'.format(mparams))
  145. ## load trained weight
  def load_weight(model, path_to_ckpt, fuse_cbn=False):
      """Load trained weights into ``model`` and optionally fuse Conv & BN.

      Args:
          model (nn.Module): model to load the weights into.
          path_to_ckpt (str or None): checkpoint path. When None, no weights
              are loaded and a notice is printed.
          fuse_cbn (bool): if True, fold BatchNorm layers into their preceding
              convs via ``fuse_conv_bn`` after loading.

      Returns:
          nn.Module: the (possibly fused) model.

      Raises:
          KeyError: if the checkpoint has no ``"model"`` entry.
      """
      # check ckpt file
      if path_to_ckpt is None:
          print('no weight file ...')
      else:
          # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
          checkpoint = torch.load(path_to_ckpt, map_location='cpu')
          print('--------------------------------------')
          print('Best model infor:')
          # Training metadata is informational only; tolerate checkpoints without it.
          print('Epoch: {}'.format(checkpoint.get("epoch", 'N/A')))
          print('mAP: {}'.format(checkpoint.get("mAP", 'N/A')))
          print('--------------------------------------')
          model.load_state_dict(checkpoint["model"])
          print('Finished loading model!')
      # fuse conv & bn
      if fuse_cbn:
          print('Fusing Conv & BN ...')
          model = fuse_conv_bn(model)
      return model
  165. ## Model EMA
  class ModelEMA(object):
      """Exponential Moving Average (EMA) of a model's floating-point state.

      Adapted from https://github.com/rwightman/pytorch-image-models. A frozen,
      eval-mode deep copy of the (de-parallelized) model is kept in ``self.ema``;
      each ``update`` blends every floating-point ``state_dict`` entry
      (parameters and buffers) toward the live model. For EMA background see
      https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

      The decay is ``cfg['ema_decay'] * (1 - exp(-updates / cfg['ema_tau']))``,
      an exponential ramp starting near 0 to help early epochs.
      """

      def __init__(self, cfg, model, updates=0):
          # Deep copy used as the EMA model; it is never trained directly.
          self.ema = deepcopy(self.de_parallel(model)).eval()
          self.updates = updates  # number of EMA updates performed so far

          def _ramped_decay(step):
              return cfg['ema_decay'] * (1 - math.exp(-step / cfg['ema_tau']))

          self.decay = _ramped_decay
          for param in self.ema.parameters():
              param.requires_grad_(False)

      def is_parallel(self, model):
          """Return True if ``model``'s exact type is DataParallel or DistributedDataParallel."""
          wrappers = (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
          return type(model) in wrappers

      def de_parallel(self, model):
          """Return the wrapped module for a DP/DDP model, otherwise ``model`` itself."""
          if self.is_parallel(model):
              return model.module
          return model

      def copy_attr(self, a, b, include=(), exclude=()):
          """Copy public attributes of ``b`` onto ``a``.

          Names starting with ``_`` or listed in ``exclude`` are skipped; when
          ``include`` is non-empty, only names it contains are copied.
          """
          for key, value in b.__dict__.items():
              skip = (len(include) and key not in include) or key.startswith('_') or key in exclude
              if not skip:
                  setattr(a, key, value)

      def update(self, model):
          """Blend the live model's floating-point state into the EMA copy in place."""
          self.updates += 1
          decay = self.decay(self.updates)
          model_state = self.de_parallel(model).state_dict()
          for key, ema_value in self.ema.state_dict().items():
              if not ema_value.dtype.is_floating_point:
                  continue  # non-floating entries are left untouched
              ema_value.mul_(decay)
              ema_value.add_((1 - decay) * model_state[key].detach())

      def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
          """Copy public attributes of ``model`` onto the EMA copy via ``copy_attr``."""
          self.copy_attr(self.ema, model, include, exclude)
  204. ## SiLU
  class SiLU(nn.Module):
      """Export-friendly version of ``nn.SiLU()``: computes ``x * sigmoid(x)``
      with elementary ops."""

      @staticmethod
      def forward(x):
          gate = torch.sigmoid(x)
          return x * gate
  210. # ---------------------------- NMS ----------------------------
  211. ## basic NMS
  def nms(bboxes, scores, nms_thresh):
      """Greedy non-maximum suppression in pure NumPy.

      Args:
          bboxes: (ndarray) [N, 4] boxes as (xmin, ymin, xmax, ymax).
          scores: (ndarray) [N,] score of each box.
          nms_thresh: (float) a candidate is suppressed when its IoU with the
              currently selected box is greater than this value.

      Returns:
          list: indices of kept boxes, in descending score order.
      """
      xmin = bboxes[:, 0]
      ymin = bboxes[:, 1]
      xmax = bboxes[:, 2]
      ymax = bboxes[:, 3]
      areas = (xmax - xmin) * (ymax - ymin)
      order = scores.argsort()[::-1]  # indices sorted by descending score
      keep = []
      while order.size:
          best, rest = order[0], order[1:]
          keep.append(best)
          # intersection of the selected box with every remaining candidate
          inter_w = np.maximum(1e-10, np.minimum(xmax[best], xmax[rest]) - np.maximum(xmin[best], xmin[rest]))
          inter_h = np.maximum(1e-10, np.minimum(ymax[best], ymax[rest]) - np.maximum(ymin[best], ymin[rest]))
          inter = inter_w * inter_h
          iou = inter / (areas[best] + areas[rest] - inter + 1e-14)
          # retain only candidates whose overlap is at most the threshold
          order = rest[np.flatnonzero(iou <= nms_thresh)]
      return keep
  237. ## class-agnostic NMS
  def multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh):
      """Run one NMS pass over all boxes, ignoring their class labels.

      Returns:
          (scores, labels, bboxes) indexed by the indices ``nms`` keeps.
      """
      kept = nms(bboxes, scores, nms_thresh)
      return scores[kept], labels[kept], bboxes[kept]
  245. ## class-aware NMS
  def multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes):
      """Run NMS separately within each class id in ``range(num_classes)``.

      Boxes whose label is not in ``range(num_classes)`` are dropped. The
      surviving detections are returned in their original input order.

      Returns:
          (scores, labels, bboxes) of the surviving detections.
      """
      selected = np.zeros(len(bboxes), dtype=np.int32)
      for cls_id in range(num_classes):
          cls_inds = np.nonzero(labels == cls_id)[0]
          if cls_inds.size == 0:
              continue
          cls_keep = nms(bboxes[cls_inds], scores[cls_inds], nms_thresh)
          selected[cls_inds[cls_keep]] = 1
      kept = np.nonzero(selected)
      return scores[kept], labels[kept], bboxes[kept]
  262. ## multi-class NMS
  def multiclass_nms(scores, labels, bboxes, nms_thresh, num_classes, class_agnostic=False):
      """Dispatch to class-aware (default) or class-agnostic NMS.

      Returns:
          (scores, labels, bboxes) of the detections that survive NMS.
      """
      if not class_agnostic:
          return multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes)
      return multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh)
  268. # ---------------------------- Processor for Deployment ----------------------------
  269. ## Pre-processor
  class PreProcessor(object):
      """Letterbox an image into a square, normalized network input.

      The image is resized (keeping its aspect ratio) to fit inside
      ``img_size x img_size``, pasted into the top-left corner of a canvas
      filled with 114, transposed with ``swap`` and divided by 255.
      """

      def __init__(self, img_size):
          self.img_size = img_size
          self.input_size = [img_size, img_size]

      def __call__(self, image, swap=(2, 0, 1)):
          """
          Input:
              image: (ndarray) [H, W, C] or [H, W]
              swap: axis permutation applied to the padded [H, W, C] canvas;
                    the default yields [C, H, W]. A 2-D image gets a trailing
                    channel axis when ``swap`` has more axes than the canvas.
          Output:
              padded_img: (ndarray, float32) transposed canvas scaled by 1/255.
              r: (float) resize ratio applied to the original image.
          """
          in_h, in_w = self.input_size
          if image.ndim == 3:
              # Match the input's channel count instead of assuming 3 channels.
              padded_img = np.full((in_h, in_w, image.shape[2]), 114., dtype=np.float32)
          else:
              padded_img = np.full((in_h, in_w), 114., dtype=np.float32)
          # resize
          orig_h, orig_w = image.shape[:2]
          r = min(in_h / orig_h, in_w / orig_w)
          resize_size = (int(orig_w * r), int(orig_h * r))  # cv2 dsize is (width, height)
          if r != 1:
              resized_img = cv2.resize(image, resize_size, interpolation=cv2.INTER_LINEAR)
              # cv2.resize returns [h, w] for single-channel [H, W, 1] input;
              # restore the channel axis so the assignment below matches.
              if resized_img.ndim < padded_img.ndim:
                  resized_img = resized_img[..., None]
          else:
              resized_img = image
          # padding: paste into the top-left corner of the canvas
          padded_img[:resized_img.shape[0], :resized_img.shape[1]] = resized_img
          # [H, W] -> [H, W, 1] so the default 3-axis swap also works on grayscale.
          if padded_img.ndim < len(swap):
              padded_img = padded_img[..., None]
          # [H, W, C] -> [C, H, W]
          padded_img = padded_img.transpose(swap)
          padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) / 255.
          return padded_img, r
  298. ## Post-processor
  class PostProcessor(object):
      """Turn raw dense predictions into final detections.

      Args:
          num_classes (int): number of object classes.
          conf_thresh (float): a prediction is kept only if its best class
              score is greater than this value.
          nms_thresh (float): IoU threshold passed to NMS.
          class_agnostic (bool): if True (default, matching the previous
              hard-coded behavior) run one NMS over all classes; otherwise run
              NMS separately per class.
      """

      def __init__(self, num_classes, conf_thresh=0.15, nms_thresh=0.5, class_agnostic=True):
          self.num_classes = num_classes
          self.conf_thresh = conf_thresh
          self.nms_thresh = nms_thresh
          self.class_agnostic = class_agnostic

      def __call__(self, predictions):
          """
          Input:
              predictions: (ndarray) [M, 4+C] - the first 4 columns are box
                  coordinates (treated by NMS as xmin, ymin, xmax, ymax), the
                  remaining C columns are per-class scores. There is no
                  separate objectness column.
          Output:
              bboxes: (ndarray) [K, 4]
              scores: (ndarray) [K,]
              labels: (ndarray) [K,]
          """
          bboxes = predictions[..., :4]
          scores = predictions[..., 4:]
          # scores & labels: best class per prediction and its score
          labels = np.argmax(scores, axis=1)  # [M,]
          scores = scores[(np.arange(scores.shape[0]), labels)]  # [M,]
          # confidence threshold
          keep = np.where(scores > self.conf_thresh)
          scores = scores[keep]
          labels = labels[keep]
          bboxes = bboxes[keep]
          # nms
          scores, labels, bboxes = multiclass_nms(
              scores, labels, bboxes, self.nms_thresh, self.num_classes, self.class_agnostic)
          return bboxes, scores, labels