misc.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torch.utils.data import DataLoader, DistributedSampler
  5. import cv2
  6. import math
  7. import numpy as np
  8. from copy import deepcopy
  9. from thop import profile
  10. # ---------------------------- For Dataset ----------------------------
  11. ## build dataloader
  12. def build_dataloader(args, dataset, batch_size, collate_fn=None):
  13. # distributed
  14. if args.distributed:
  15. sampler = DistributedSampler(dataset)
  16. else:
  17. sampler = torch.utils.data.RandomSampler(dataset)
  18. batch_sampler_train = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)
  19. dataloader = DataLoader(dataset, batch_sampler=batch_sampler_train,
  20. collate_fn=collate_fn, num_workers=args.num_workers, pin_memory=True)
  21. return dataloader
  22. ## collate_fn for dataloader
  23. class CollateFunc(object):
  24. def __call__(self, batch):
  25. targets = []
  26. images = []
  27. for sample in batch:
  28. image = sample[0]
  29. target = sample[1]
  30. images.append(image)
  31. targets.append(target)
  32. images = torch.stack(images, 0) # [B, C, H, W]
  33. return images, targets
  34. # ---------------------------- For Model ----------------------------
  35. ## fuse Conv & BN layer
  36. def fuse_conv_bn(module):
  37. """Recursively fuse conv and bn in a module.
  38. During inference, the functionary of batch norm layers is turned off
  39. but only the mean and var alone channels are used, which exposes the
  40. chance to fuse it with the preceding conv layers to save computations and
  41. simplify network structures.
  42. Args:
  43. module (nn.Module): Module to be fused.
  44. Returns:
  45. nn.Module: Fused module.
  46. """
  47. last_conv = None
  48. last_conv_name = None
  49. def _fuse_conv_bn(conv, bn):
  50. """Fuse conv and bn into one module.
  51. Args:
  52. conv (nn.Module): Conv to be fused.
  53. bn (nn.Module): BN to be fused.
  54. Returns:
  55. nn.Module: Fused module.
  56. """
  57. conv_w = conv.weight
  58. conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
  59. bn.running_mean)
  60. factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
  61. conv.weight = nn.Parameter(conv_w *
  62. factor.reshape([conv.out_channels, 1, 1, 1]))
  63. conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias)
  64. return conv
  65. for name, child in module.named_children():
  66. if isinstance(child,
  67. (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)):
  68. if last_conv is None: # only fuse BN that is after Conv
  69. continue
  70. fused_conv = _fuse_conv_bn(last_conv, child)
  71. module._modules[last_conv_name] = fused_conv
  72. # To reduce changes, set BN as Identity instead of deleting it.
  73. module._modules[name] = nn.Identity()
  74. last_conv = None
  75. elif isinstance(child, nn.Conv2d):
  76. last_conv = child
  77. last_conv_name = name
  78. else:
  79. fuse_conv_bn(child)
  80. return module
  81. ## replace module
  82. def replace_module(module, replaced_module_type, new_module_type, replace_func=None) -> nn.Module:
  83. """
  84. Replace given type in module to a new type. mostly used in deploy.
  85. Args:
  86. module (nn.Module): model to apply replace operation.
  87. replaced_module_type (Type): module type to be replaced.
  88. new_module_type (Type)
  89. replace_func (function): python function to describe replace logic. Defalut value None.
  90. Returns:
  91. model (nn.Module): module that already been replaced.
  92. """
  93. def default_replace_func(replaced_module_type, new_module_type):
  94. return new_module_type()
  95. if replace_func is None:
  96. replace_func = default_replace_func
  97. model = module
  98. if isinstance(module, replaced_module_type):
  99. model = replace_func(replaced_module_type, new_module_type)
  100. else: # recurrsively replace
  101. for name, child in module.named_children():
  102. new_child = replace_module(child, replaced_module_type, new_module_type)
  103. if new_child is not child: # child is already replaced
  104. model.add_module(name, new_child)
  105. return model
  106. ## compute FLOPs & Parameters
  107. def compute_flops(model, img_size, device):
  108. x = torch.randn(1, 3, img_size, img_size).to(device)
  109. print('==============================')
  110. flops, params = profile(model, inputs=(x, ), verbose=False)
  111. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  112. print('Params : {:.2f} M'.format(params / 1e6))
  113. ## load trained weight
  114. def load_weight(model, path_to_ckpt, fuse_cbn=False):
  115. # check ckpt file
  116. if path_to_ckpt is None:
  117. print('no weight file ...')
  118. else:
  119. checkpoint = torch.load(path_to_ckpt, map_location='cpu')
  120. print('--------------------------------------')
  121. print('Best model infor:')
  122. print('Epoch: {}'.format(checkpoint.pop("epoch")))
  123. print('mAP: {}'.format(checkpoint.pop("mAP")))
  124. print('--------------------------------------')
  125. checkpoint_state_dict = checkpoint.pop("model")
  126. model.load_state_dict(checkpoint_state_dict)
  127. print('Finished loading model!')
  128. # fuse conv & bn
  129. if fuse_cbn:
  130. print('Fusing Conv & BN ...')
  131. model = fuse_conv_bn(model)
  132. return model
  133. ## Model EMA
  134. class ModelEMA(object):
  135. """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
  136. Keeps a moving average of everything in the model state_dict (parameters and buffers)
  137. For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  138. """
  139. def __init__(self, model, decay=0.9999, tau=2000, updates=0):
  140. # Create EMA
  141. self.ema = deepcopy(self.de_parallel(model)).eval() # FP32 EMA
  142. self.updates = updates # number of EMA updates
  143. self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
  144. for p in self.ema.parameters():
  145. p.requires_grad_(False)
  146. def is_parallel(self, model):
  147. # Returns True if model is of type DP or DDP
  148. return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
  149. def de_parallel(self, model):
  150. # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
  151. return model.module if self.is_parallel(model) else model
  152. def copy_attr(self, a, b, include=(), exclude=()):
  153. # Copy attributes from b to a, options to only include [...] and to exclude [...]
  154. for k, v in b.__dict__.items():
  155. if (len(include) and k not in include) or k.startswith('_') or k in exclude:
  156. continue
  157. else:
  158. setattr(a, k, v)
  159. def update(self, model):
  160. # Update EMA parameters
  161. self.updates += 1
  162. d = self.decay(self.updates)
  163. msd = self.de_parallel(model).state_dict() # model state_dict
  164. for k, v in self.ema.state_dict().items():
  165. if v.dtype.is_floating_point: # true for FP16 and FP32
  166. v *= d
  167. v += (1 - d) * msd[k].detach()
  168. # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'
  169. def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
  170. # Update EMA attributes
  171. self.copy_attr(self.ema, model, include, exclude)
  172. ## SiLU
  173. class SiLU(nn.Module):
  174. """export-friendly version of nn.SiLU()"""
  175. @staticmethod
  176. def forward(x):
  177. return x * torch.sigmoid(x)
  178. # ---------------------------- NMS ----------------------------
  179. ## basic NMS
  180. def nms(bboxes, scores, nms_thresh):
  181. """"Pure Python NMS."""
  182. x1 = bboxes[:, 0] #xmin
  183. y1 = bboxes[:, 1] #ymin
  184. x2 = bboxes[:, 2] #xmax
  185. y2 = bboxes[:, 3] #ymax
  186. areas = (x2 - x1) * (y2 - y1)
  187. order = scores.argsort()[::-1]
  188. keep = []
  189. while order.size > 0:
  190. i = order[0]
  191. keep.append(i)
  192. # compute iou
  193. xx1 = np.maximum(x1[i], x1[order[1:]])
  194. yy1 = np.maximum(y1[i], y1[order[1:]])
  195. xx2 = np.minimum(x2[i], x2[order[1:]])
  196. yy2 = np.minimum(y2[i], y2[order[1:]])
  197. w = np.maximum(1e-10, xx2 - xx1)
  198. h = np.maximum(1e-10, yy2 - yy1)
  199. inter = w * h
  200. iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-14)
  201. #reserve all the boundingbox whose ovr less than thresh
  202. inds = np.where(iou <= nms_thresh)[0]
  203. order = order[inds + 1]
  204. return keep
  205. ## class-agnostic NMS
  206. def multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh):
  207. # nms
  208. keep = nms(bboxes, scores, nms_thresh)
  209. scores = scores[keep]
  210. labels = labels[keep]
  211. bboxes = bboxes[keep]
  212. return scores, labels, bboxes
  213. ## class-aware NMS
  214. def multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes):
  215. # nms
  216. keep = np.zeros(len(bboxes), dtype=np.int)
  217. for i in range(num_classes):
  218. inds = np.where(labels == i)[0]
  219. if len(inds) == 0:
  220. continue
  221. c_bboxes = bboxes[inds]
  222. c_scores = scores[inds]
  223. c_keep = nms(c_bboxes, c_scores, nms_thresh)
  224. keep[inds[c_keep]] = 1
  225. keep = np.where(keep > 0)
  226. scores = scores[keep]
  227. labels = labels[keep]
  228. bboxes = bboxes[keep]
  229. return scores, labels, bboxes
  230. ## multi-class NMS
  231. def multiclass_nms(scores, labels, bboxes, nms_thresh, num_classes, class_agnostic=False):
  232. if class_agnostic:
  233. return multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh)
  234. else:
  235. return multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes)
  236. # ---------------------------- Processor for Deployment ----------------------------
## Pre-processor
  238. class PreProcessor(object):
  239. def __init__(self, img_size):
  240. self.img_size = img_size
  241. self.input_size = [img_size, img_size]
  242. def __call__(self, image, swap=(2, 0, 1)):
  243. """
  244. Input:
  245. image: (ndarray) [H, W, 3] or [H, W]
  246. formar: color format
  247. """
  248. if len(image.shape) == 3:
  249. padded_img = np.ones((self.input_size[0], self.input_size[1], 3), np.float32) * 114.
  250. else:
  251. padded_img = np.ones(self.input_size, np.float32) * 114.
  252. # resize
  253. orig_h, orig_w = image.shape[:2]
  254. r = min(self.input_size[0] / orig_h, self.input_size[1] / orig_w)
  255. resize_size = (int(orig_w * r), int(orig_h * r))
  256. if r != 1:
  257. resized_img = cv2.resize(image, resize_size, interpolation=cv2.INTER_LINEAR)
  258. else:
  259. resized_img = image
  260. # padding
  261. padded_img[:resized_img.shape[0], :resized_img.shape[1]] = resized_img
  262. # [H, W, C] -> [C, H, W]
  263. padded_img = padded_img.transpose(swap)
  264. padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) / 255.
  265. return padded_img, r
## Post-processor
  267. class PostProcessor(object):
  268. def __init__(self, num_classes, conf_thresh=0.15, nms_thresh=0.5):
  269. self.num_classes = num_classes
  270. self.conf_thresh = conf_thresh
  271. self.nms_thresh = nms_thresh
  272. def __call__(self, predictions):
  273. """
  274. Input:
  275. predictions: (ndarray) [n_anchors_all, 4+1+C]
  276. """
  277. bboxes = predictions[..., :4]
  278. scores = predictions[..., 4:]
  279. # scores & labels
  280. labels = np.argmax(scores, axis=1) # [M,]
  281. scores = scores[(np.arange(scores.shape[0]), labels)] # [M,]
  282. # thresh
  283. keep = np.where(scores > self.conf_thresh)
  284. scores = scores[keep]
  285. labels = labels[keep]
  286. bboxes = bboxes[keep]
  287. # nms
  288. scores, labels, bboxes = multiclass_nms(
  289. scores, labels, bboxes, self.nms_thresh, self.num_classes, True)
  290. return bboxes, scores, labels