misc.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torch.utils.data import DataLoader, DistributedSampler
  5. import torchvision
  6. import cv2
  7. import math
  8. import numpy as np
  9. from copy import deepcopy
  10. from thop import profile
  11. # ---------------------------- For Dataset ----------------------------
  12. ## build dataloader
  13. def build_dataloader(args, dataset, batch_size, collate_fn=None):
  14. # distributed
  15. if args.distributed:
  16. sampler = DistributedSampler(dataset)
  17. else:
  18. sampler = torch.utils.data.RandomSampler(dataset)
  19. batch_sampler_train = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)
  20. dataloader = DataLoader(dataset, batch_sampler=batch_sampler_train,
  21. collate_fn=collate_fn, num_workers=args.num_workers, pin_memory=True)
  22. return dataloader
  23. ## collate_fn for dataloader
  24. class CollateFunc(object):
  25. def __call__(self, batch):
  26. targets = []
  27. images = []
  28. for sample in batch:
  29. image = sample[0]
  30. target = sample[1]
  31. images.append(image)
  32. targets.append(target)
  33. images = torch.stack(images, 0) # [B, C, H, W]
  34. return images, targets
  35. # ---------------------------- For Loss ----------------------------
  36. ## FocalLoss
  37. def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
  38. """
  39. Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
  40. Args:
  41. inputs: A float tensor of arbitrary shape.
  42. The predictions for each example.
  43. targets: A float tensor with the same shape as inputs. Stores the binary
  44. classification label for each element in inputs
  45. (0 for the negative class and 1 for the positive class).
  46. alpha: (optional) Weighting factor in range (0,1) to balance
  47. positive vs negative examples. Default = -1 (no weighting).
  48. gamma: Exponent of the modulating factor (1 - p_t) to
  49. balance easy vs hard examples.
  50. Returns:
  51. Loss tensor
  52. """
  53. prob = inputs.sigmoid()
  54. ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
  55. p_t = prob * targets + (1 - prob) * (1 - targets)
  56. loss = ce_loss * ((1 - p_t) ** gamma)
  57. if alpha >= 0:
  58. alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
  59. loss = alpha_t * loss
  60. return loss.mean(1).sum() / num_boxes
  61. ## InverseSigmoid
  62. def inverse_sigmoid(x, eps=1e-5):
  63. x = x.clamp(min=0, max=1)
  64. x1 = x.clamp(min=eps)
  65. x2 = (1 - x).clamp(min=eps)
  66. return torch.log(x1/x2)
  67. # ---------------------------- For Model ----------------------------
  68. ## fuse Conv & BN layer
  69. def fuse_conv_bn(module):
  70. """Recursively fuse conv and bn in a module.
  71. During inference, the functionary of batch norm layers is turned off
  72. but only the mean and var alone channels are used, which exposes the
  73. chance to fuse it with the preceding conv layers to save computations and
  74. simplify network structures.
  75. Args:
  76. module (nn.Module): Module to be fused.
  77. Returns:
  78. nn.Module: Fused module.
  79. """
  80. last_conv = None
  81. last_conv_name = None
  82. def _fuse_conv_bn(conv, bn):
  83. """Fuse conv and bn into one module.
  84. Args:
  85. conv (nn.Module): Conv to be fused.
  86. bn (nn.Module): BN to be fused.
  87. Returns:
  88. nn.Module: Fused module.
  89. """
  90. conv_w = conv.weight
  91. conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
  92. bn.running_mean)
  93. factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
  94. conv.weight = nn.Parameter(conv_w *
  95. factor.reshape([conv.out_channels, 1, 1, 1]))
  96. conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias)
  97. return conv
  98. for name, child in module.named_children():
  99. if isinstance(child,
  100. (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)):
  101. if last_conv is None: # only fuse BN that is after Conv
  102. continue
  103. fused_conv = _fuse_conv_bn(last_conv, child)
  104. module._modules[last_conv_name] = fused_conv
  105. # To reduce changes, set BN as Identity instead of deleting it.
  106. module._modules[name] = nn.Identity()
  107. last_conv = None
  108. elif isinstance(child, nn.Conv2d):
  109. last_conv = child
  110. last_conv_name = name
  111. else:
  112. fuse_conv_bn(child)
  113. return module
  114. ## replace module
  115. def replace_module(module, replaced_module_type, new_module_type, replace_func=None) -> nn.Module:
  116. """
  117. Replace given type in module to a new type. mostly used in deploy.
  118. Args:
  119. module (nn.Module): model to apply replace operation.
  120. replaced_module_type (Type): module type to be replaced.
  121. new_module_type (Type)
  122. replace_func (function): python function to describe replace logic. Defalut value None.
  123. Returns:
  124. model (nn.Module): module that already been replaced.
  125. """
  126. def default_replace_func(replaced_module_type, new_module_type):
  127. return new_module_type()
  128. if replace_func is None:
  129. replace_func = default_replace_func
  130. model = module
  131. if isinstance(module, replaced_module_type):
  132. model = replace_func(replaced_module_type, new_module_type)
  133. else: # recurrsively replace
  134. for name, child in module.named_children():
  135. new_child = replace_module(child, replaced_module_type, new_module_type)
  136. if new_child is not child: # child is already replaced
  137. model.add_module(name, new_child)
  138. return model
  139. ## compute FLOPs & Parameters
  140. def compute_flops(model, img_size, device):
  141. x = torch.randn(1, 3, img_size, img_size).to(device)
  142. print('==============================')
  143. flops, params = profile(model, inputs=(x, ), verbose=False)
  144. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  145. print('Params : {:.2f} M'.format(params / 1e6))
  146. ## load trained weight
  147. def load_weight(model, path_to_ckpt, fuse_cbn=False):
  148. # check ckpt file
  149. if path_to_ckpt is None:
  150. print('no weight file ...')
  151. else:
  152. checkpoint = torch.load(path_to_ckpt, map_location='cpu')
  153. print('--------------------------------------')
  154. print('Best model infor:')
  155. print('Epoch: {}'.format(checkpoint["epoch"]))
  156. print('mAP: {}'.format(checkpoint["mAP"]))
  157. print('--------------------------------------')
  158. checkpoint_state_dict = checkpoint["model"]
  159. model.load_state_dict(checkpoint_state_dict)
  160. print('Finished loading model!')
  161. # fuse conv & bn
  162. if fuse_cbn:
  163. print('Fusing Conv & BN ...')
  164. model = fuse_conv_bn(model)
  165. return model
  166. ## Model EMA
  167. class ModelEMA(object):
  168. """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
  169. Keeps a moving average of everything in the model state_dict (parameters and buffers)
  170. For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  171. """
  172. def __init__(self, cfg, model, updates=0):
  173. # Create EMA
  174. self.ema = deepcopy(self.de_parallel(model)).eval() # FP32 EMA
  175. self.updates = updates # number of EMA updates
  176. self.decay = lambda x: cfg['ema_decay'] * (1 - math.exp(-x / cfg['ema_tau'])) # decay exponential ramp (to help early epochs)
  177. for p in self.ema.parameters():
  178. p.requires_grad_(False)
  179. def is_parallel(self, model):
  180. # Returns True if model is of type DP or DDP
  181. return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
  182. def de_parallel(self, model):
  183. # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
  184. return model.module if self.is_parallel(model) else model
  185. def copy_attr(self, a, b, include=(), exclude=()):
  186. # Copy attributes from b to a, options to only include [...] and to exclude [...]
  187. for k, v in b.__dict__.items():
  188. if (len(include) and k not in include) or k.startswith('_') or k in exclude:
  189. continue
  190. else:
  191. setattr(a, k, v)
  192. def update(self, model):
  193. # Update EMA parameters
  194. self.updates += 1
  195. d = self.decay(self.updates)
  196. msd = self.de_parallel(model).state_dict() # model state_dict
  197. for k, v in self.ema.state_dict().items():
  198. if v.dtype.is_floating_point: # true for FP16 and FP32
  199. v *= d
  200. v += (1 - d) * msd[k].detach()
  201. # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'
  202. def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
  203. # Update EMA attributes
  204. self.copy_attr(self.ema, model, include, exclude)
  205. ## SiLU
  206. class SiLU(nn.Module):
  207. """export-friendly version of nn.SiLU()"""
  208. @staticmethod
  209. def forward(x):
  210. return x * torch.sigmoid(x)
  211. # ---------------------------- NMS ----------------------------
  212. ## basic NMS
  213. def nms(bboxes, scores, nms_thresh):
  214. """"Pure Python NMS."""
  215. x1 = bboxes[:, 0] #xmin
  216. y1 = bboxes[:, 1] #ymin
  217. x2 = bboxes[:, 2] #xmax
  218. y2 = bboxes[:, 3] #ymax
  219. areas = (x2 - x1) * (y2 - y1)
  220. order = scores.argsort()[::-1]
  221. keep = []
  222. while order.size > 0:
  223. i = order[0]
  224. keep.append(i)
  225. # compute iou
  226. xx1 = np.maximum(x1[i], x1[order[1:]])
  227. yy1 = np.maximum(y1[i], y1[order[1:]])
  228. xx2 = np.minimum(x2[i], x2[order[1:]])
  229. yy2 = np.minimum(y2[i], y2[order[1:]])
  230. w = np.maximum(1e-10, xx2 - xx1)
  231. h = np.maximum(1e-10, yy2 - yy1)
  232. inter = w * h
  233. iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-14)
  234. #reserve all the boundingbox whose ovr less than thresh
  235. inds = np.where(iou <= nms_thresh)[0]
  236. order = order[inds + 1]
  237. return keep
  238. ## class-agnostic NMS
  239. def multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh):
  240. # nms
  241. keep = nms(bboxes, scores, nms_thresh)
  242. scores = scores[keep]
  243. labels = labels[keep]
  244. bboxes = bboxes[keep]
  245. return scores, labels, bboxes
  246. ## class-aware NMS
  247. def multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes):
  248. # nms
  249. keep = np.zeros(len(bboxes), dtype=np.int32)
  250. for i in range(num_classes):
  251. inds = np.where(labels == i)[0]
  252. if len(inds) == 0:
  253. continue
  254. c_bboxes = bboxes[inds]
  255. c_scores = scores[inds]
  256. c_keep = nms(c_bboxes, c_scores, nms_thresh)
  257. keep[inds[c_keep]] = 1
  258. keep = np.where(keep > 0)
  259. scores = scores[keep]
  260. labels = labels[keep]
  261. bboxes = bboxes[keep]
  262. return scores, labels, bboxes
  263. ## multi-class NMS
  264. def multiclass_nms(scores, labels, bboxes, nms_thresh, num_classes, class_agnostic=False):
  265. if class_agnostic:
  266. return multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh)
  267. else:
  268. return multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes)
  269. # ---------------------------- Processor for Deployment ----------------------------
  270. ## Pre-processer
  271. class PreProcessor(object):
  272. def __init__(self, img_size, keep_ratio=True):
  273. self.img_size = img_size
  274. self.keep_ratio = keep_ratio
  275. self.input_size = [img_size, img_size]
  276. def __call__(self, image, swap=(2, 0, 1)):
  277. """
  278. Input:
  279. image: (ndarray) [H, W, 3] or [H, W]
  280. formar: color format
  281. """
  282. if len(image.shape) == 3:
  283. padded_img = np.ones((self.input_size[0], self.input_size[1], 3), np.float32) * 114.
  284. else:
  285. padded_img = np.ones(self.input_size, np.float32) * 114.
  286. # resize
  287. if self.keep_ratio:
  288. orig_h, orig_w = image.shape[:2]
  289. r = min(self.input_size[0] / orig_h, self.input_size[1] / orig_w)
  290. resize_size = (int(orig_w * r), int(orig_h * r))
  291. if r != 1:
  292. resized_img = cv2.resize(image, resize_size, interpolation=cv2.INTER_LINEAR)
  293. else:
  294. resized_img = image
  295. # padding
  296. padded_img[:resized_img.shape[0], :resized_img.shape[1]] = resized_img
  297. # [H, W, C] -> [C, H, W]
  298. padded_img = padded_img.transpose(swap)
  299. padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) / 255.
  300. return padded_img, r
  301. else:
  302. orig_h, orig_w = image.shape[:2]
  303. r = np.array([self.input_size[0] / orig_w, self.input_size[1] / orig_w])
  304. if [orig_h, orig_w] == self.input_size:
  305. resized_img = image
  306. else:
  307. resized_img = cv2.resize(image, self.input_size, interpolation=cv2.INTER_LINEAR)
  308. return resized_img, r
  309. ## Post-processer
  310. class PostProcessor(object):
  311. def __init__(self, num_classes, conf_thresh=0.15, nms_thresh=0.5):
  312. self.num_classes = num_classes
  313. self.conf_thresh = conf_thresh
  314. self.nms_thresh = nms_thresh
  315. def __call__(self, predictions):
  316. """
  317. Input:
  318. predictions: (ndarray) [n_anchors_all, 4+1+C]
  319. """
  320. bboxes = predictions[..., :4]
  321. scores = predictions[..., 4:]
  322. # scores & labels
  323. labels = np.argmax(scores, axis=1) # [M,]
  324. scores = scores[(np.arange(scores.shape[0]), labels)] # [M,]
  325. # thresh
  326. keep = np.where(scores > self.conf_thresh)
  327. scores = scores[keep]
  328. labels = labels[keep]
  329. bboxes = bboxes[keep]
  330. # nms
  331. scores, labels, bboxes = multiclass_nms(
  332. scores, labels, bboxes, self.nms_thresh, self.num_classes, True)
  333. return bboxes, scores, labels