misc.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torch.utils.data import DataLoader, DistributedSampler
  5. import torchvision
  6. import cv2
  7. import math
  8. import numpy as np
  9. from copy import deepcopy
  10. from thop import profile
  11. # ---------------------------- For Dataset ----------------------------
  12. ## build dataloader
  13. def build_dataloader(args, dataset, batch_size, collate_fn=None):
  14. # distributed
  15. if args.distributed:
  16. sampler = DistributedSampler(dataset)
  17. else:
  18. sampler = torch.utils.data.RandomSampler(dataset)
  19. batch_sampler_train = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)
  20. dataloader = DataLoader(dataset, batch_sampler=batch_sampler_train,
  21. collate_fn=collate_fn, num_workers=args.num_workers, pin_memory=True)
  22. return dataloader
  23. ## collate_fn for dataloader
  24. class CollateFunc(object):
  25. def __call__(self, batch):
  26. targets = []
  27. images = []
  28. for sample in batch:
  29. image = sample[0]
  30. target = sample[1]
  31. images.append(image)
  32. targets.append(target)
  33. images = torch.stack(images, 0) # [B, C, H, W]
  34. return images, targets
  35. # ---------------------------- For Loss ----------------------------
  36. ## FocalLoss
  37. def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
  38. """
  39. Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
  40. Args:
  41. inputs: A float tensor of arbitrary shape.
  42. The predictions for each example.
  43. targets: A float tensor with the same shape as inputs. Stores the binary
  44. classification label for each element in inputs
  45. (0 for the negative class and 1 for the positive class).
  46. alpha: (optional) Weighting factor in range (0,1) to balance
  47. positive vs negative examples. Default = -1 (no weighting).
  48. gamma: Exponent of the modulating factor (1 - p_t) to
  49. balance easy vs hard examples.
  50. Returns:
  51. Loss tensor
  52. """
  53. prob = inputs.sigmoid()
  54. ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
  55. p_t = prob * targets + (1 - prob) * (1 - targets)
  56. loss = ce_loss * ((1 - p_t) ** gamma)
  57. if alpha >= 0:
  58. alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
  59. loss = alpha_t * loss
  60. return loss.mean(1).sum() / num_boxes
  61. ## InverseSigmoid
  62. def inverse_sigmoid(x, eps=1e-5):
  63. x = x.clamp(min=0, max=1)
  64. x1 = x.clamp(min=eps)
  65. x2 = (1 - x).clamp(min=eps)
  66. return torch.log(x1/x2)
  67. # ---------------------------- For Model ----------------------------
  68. ## fuse Conv & BN layer
  69. def fuse_conv_bn(module):
  70. """Recursively fuse conv and bn in a module.
  71. During inference, the functionary of batch norm layers is turned off
  72. but only the mean and var alone channels are used, which exposes the
  73. chance to fuse it with the preceding conv layers to save computations and
  74. simplify network structures.
  75. Args:
  76. module (nn.Module): Module to be fused.
  77. Returns:
  78. nn.Module: Fused module.
  79. """
  80. last_conv = None
  81. last_conv_name = None
  82. def _fuse_conv_bn(conv, bn):
  83. """Fuse conv and bn into one module.
  84. Args:
  85. conv (nn.Module): Conv to be fused.
  86. bn (nn.Module): BN to be fused.
  87. Returns:
  88. nn.Module: Fused module.
  89. """
  90. conv_w = conv.weight
  91. conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
  92. bn.running_mean)
  93. factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
  94. conv.weight = nn.Parameter(conv_w *
  95. factor.reshape([conv.out_channels, 1, 1, 1]))
  96. conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias)
  97. return conv
  98. for name, child in module.named_children():
  99. if isinstance(child,
  100. (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)):
  101. if last_conv is None: # only fuse BN that is after Conv
  102. continue
  103. fused_conv = _fuse_conv_bn(last_conv, child)
  104. module._modules[last_conv_name] = fused_conv
  105. # To reduce changes, set BN as Identity instead of deleting it.
  106. module._modules[name] = nn.Identity()
  107. last_conv = None
  108. elif isinstance(child, nn.Conv2d):
  109. last_conv = child
  110. last_conv_name = name
  111. else:
  112. fuse_conv_bn(child)
  113. return module
  114. ## replace module
  115. def replace_module(module, replaced_module_type, new_module_type, replace_func=None) -> nn.Module:
  116. """
  117. Replace given type in module to a new type. mostly used in deploy.
  118. Args:
  119. module (nn.Module): model to apply replace operation.
  120. replaced_module_type (Type): module type to be replaced.
  121. new_module_type (Type)
  122. replace_func (function): python function to describe replace logic. Defalut value None.
  123. Returns:
  124. model (nn.Module): module that already been replaced.
  125. """
  126. def default_replace_func(replaced_module_type, new_module_type):
  127. return new_module_type()
  128. if replace_func is None:
  129. replace_func = default_replace_func
  130. model = module
  131. if isinstance(module, replaced_module_type):
  132. model = replace_func(replaced_module_type, new_module_type)
  133. else: # recurrsively replace
  134. for name, child in module.named_children():
  135. new_child = replace_module(child, replaced_module_type, new_module_type)
  136. if new_child is not child: # child is already replaced
  137. model.add_module(name, new_child)
  138. return model
  139. ## compute FLOPs & Parameters
  140. def compute_flops(model, img_size, device):
  141. x = torch.randn(1, 3, img_size, img_size).to(device)
  142. print('==============================')
  143. flops, params = profile(model, inputs=(x, ), verbose=False)
  144. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  145. print('Params : {:.2f} M'.format(params / 1e6))
  146. ## load trained weight
  147. def load_weight(model, path_to_ckpt, fuse_cbn=False):
  148. # check ckpt file
  149. if path_to_ckpt is None:
  150. print('no weight file ...')
  151. else:
  152. checkpoint = torch.load(path_to_ckpt, map_location='cpu')
  153. print('--------------------------------------')
  154. print('Best model infor:')
  155. print('Epoch: {}'.format(checkpoint["epoch"]))
  156. print('mAP: {}'.format(checkpoint["mAP"]))
  157. print('--------------------------------------')
  158. checkpoint_state_dict = checkpoint["model"]
  159. model.load_state_dict(checkpoint_state_dict)
  160. print('Finished loading model!')
  161. # fuse conv & bn
  162. if fuse_cbn:
  163. print('Fusing Conv & BN ...')
  164. model = fuse_conv_bn(model)
  165. return model
  166. ## Model EMA
  167. class ModelEMA(object):
  168. """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
  169. Keeps a moving average of everything in the model state_dict (parameters and buffers)
  170. For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  171. """
  172. def __init__(self, cfg, model, updates=0):
  173. # Create EMA
  174. self.ema = deepcopy(self.de_parallel(model)).eval() # FP32 EMA
  175. self.updates = updates # number of EMA updates
  176. self.decay = lambda x: cfg['ema_decay'] * (1 - math.exp(-x / cfg['ema_tau'])) # decay exponential ramp (to help early epochs)
  177. for p in self.ema.parameters():
  178. p.requires_grad_(False)
  179. def is_parallel(self, model):
  180. # Returns True if model is of type DP or DDP
  181. return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
  182. def de_parallel(self, model):
  183. # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
  184. return model.module if self.is_parallel(model) else model
  185. def copy_attr(self, a, b, include=(), exclude=()):
  186. # Copy attributes from b to a, options to only include [...] and to exclude [...]
  187. for k, v in b.__dict__.items():
  188. if (len(include) and k not in include) or k.startswith('_') or k in exclude:
  189. continue
  190. else:
  191. setattr(a, k, v)
  192. def update(self, model):
  193. # Update EMA parameters
  194. self.updates += 1
  195. d = self.decay(self.updates)
  196. msd = self.de_parallel(model).state_dict() # model state_dict
  197. for k, v in self.ema.state_dict().items():
  198. if v.dtype.is_floating_point: # true for FP16 and FP32
  199. v *= d
  200. v += (1 - d) * msd[k].detach()
  201. # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'
  202. def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
  203. # Update EMA attributes
  204. self.copy_attr(self.ema, model, include, exclude)
  205. ## SiLU
  206. class SiLU(nn.Module):
  207. """export-friendly version of nn.SiLU()"""
  208. @staticmethod
  209. def forward(x):
  210. return x * torch.sigmoid(x)
  211. # ---------------------------- NMS ----------------------------
  212. ## basic NMS
  213. def nms(bboxes, scores, nms_thresh):
  214. """"Pure Python NMS."""
  215. x1 = bboxes[:, 0] #xmin
  216. y1 = bboxes[:, 1] #ymin
  217. x2 = bboxes[:, 2] #xmax
  218. y2 = bboxes[:, 3] #ymax
  219. areas = (x2 - x1) * (y2 - y1)
  220. order = scores.argsort()[::-1]
  221. keep = []
  222. while order.size > 0:
  223. i = order[0]
  224. keep.append(i)
  225. # compute iou
  226. xx1 = np.maximum(x1[i], x1[order[1:]])
  227. yy1 = np.maximum(y1[i], y1[order[1:]])
  228. xx2 = np.minimum(x2[i], x2[order[1:]])
  229. yy2 = np.minimum(y2[i], y2[order[1:]])
  230. w = np.maximum(1e-10, xx2 - xx1)
  231. h = np.maximum(1e-10, yy2 - yy1)
  232. inter = w * h
  233. iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-14)
  234. #reserve all the boundingbox whose ovr less than thresh
  235. inds = np.where(iou <= nms_thresh)[0]
  236. order = order[inds + 1]
  237. return keep
  238. ## class-agnostic NMS
  239. def multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh):
  240. # nms
  241. keep = nms(bboxes, scores, nms_thresh)
  242. scores = scores[keep]
  243. labels = labels[keep]
  244. bboxes = bboxes[keep]
  245. return scores, labels, bboxes
  246. ## class-aware NMS
  247. def multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes):
  248. # nms
  249. keep = np.zeros(len(bboxes), dtype=np.int32)
  250. for i in range(num_classes):
  251. inds = np.where(labels == i)[0]
  252. if len(inds) == 0:
  253. continue
  254. c_bboxes = bboxes[inds]
  255. c_scores = scores[inds]
  256. c_keep = nms(c_bboxes, c_scores, nms_thresh)
  257. keep[inds[c_keep]] = 1
  258. keep = np.where(keep > 0)
  259. scores = scores[keep]
  260. labels = labels[keep]
  261. bboxes = bboxes[keep]
  262. return scores, labels, bboxes
  263. ## multi-class NMS
  264. def multiclass_nms(scores, labels, bboxes, nms_thresh, num_classes, class_agnostic=False):
  265. if class_agnostic:
  266. return multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh)
  267. else:
  268. return multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes)
  269. def non_max_suppression(
  270. prediction,
  271. conf_thres=0.25,
  272. iou_thres=0.45,
  273. classes=None,
  274. agnostic=False,
  275. multi_label=False,
  276. max_det=300,
  277. nc=0, # number of classes (optional)
  278. max_nms=30000,
  279. max_wh=7680,
  280. ):
  281. """
  282. Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.
  283. Args:
  284. prediction (torch.Tensor): A tensor of shape (batch_size, num_classes + 4 + num_masks, num_boxes)
  285. containing the predicted boxes, classes, and masks. The tensor should be in the format
  286. output by a model, such as YOLO.
  287. conf_thres (float): The confidence threshold below which boxes will be filtered out.
  288. Valid values are between 0.0 and 1.0.
  289. iou_thres (float): The IoU threshold below which boxes will be filtered out during NMS.
  290. Valid values are between 0.0 and 1.0.
  291. classes (List[int]): A list of class indices to consider. If None, all classes will be considered.
  292. agnostic (bool): If True, the model is agnostic to the number of classes, and all
  293. classes will be considered as one.
  294. multi_label (bool): If True, each box may have multiple labels.
  295. labels (List[List[Union[int, float, torch.Tensor]]]): A list of lists, where each inner
  296. list contains the apriori labels for a given image. The list should be in the format
  297. output by a dataloader, with each label being a tuple of (class_index, x1, y1, x2, y2).
  298. max_det (int): The maximum number of boxes to keep after NMS.
  299. nc (int, optional): The number of classes output by the model. Any indices after this will be considered masks.
  300. max_time_img (float): The maximum time (seconds) for processing one image.
  301. max_nms (int): The maximum number of boxes into torchvision.ops.nms().
  302. max_wh (int): The maximum box width and height in pixels
  303. Returns:
  304. (List[torch.Tensor]): A list of length batch_size, where each element is a tensor of
  305. shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns
  306. (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
  307. """
  308. # Checks
  309. assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
  310. assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
  311. device = prediction.device # [N, C+4]
  312. nc = nc or (prediction.shape[1] - 4) # number of classes
  313. xc = prediction[:, 4:].amax(1) > conf_thres # candidates
  314. # Settings
  315. multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
  316. output = torch.zeros((0, 6), device=device)
  317. # Apply constraints
  318. prediction = prediction[xc] # confidence
  319. # If none remain process next image
  320. if not prediction.shape[0]:
  321. pass
  322. # Detections matrix nx6 (xyxy, conf, cls)
  323. box, cls = prediction.split((4, nc), 1)
  324. if multi_label:
  325. i, j = torch.where(cls > conf_thres)
  326. prediction = torch.cat((box[i], prediction[i, 4 + j, None], j[:, None].float()), 1)
  327. else: # best class only
  328. conf, j = cls.max(1, keepdim=True)
  329. prediction = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
  330. # Filter by class
  331. if classes is not None:
  332. prediction = prediction[(prediction[:, 5:6] == torch.tensor(classes, device=device)).any(1)]
  333. # Check shape
  334. n = prediction.shape[0] # number of boxes
  335. if n > max_nms: # excess boxes
  336. prediction = prediction[prediction[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence and remove excess boxes
  337. # Batched NMS
  338. c = prediction[:, 5:6] * (0 if agnostic else max_wh) # classes
  339. boxes, scores = prediction[:, :4] + c, prediction[:, 4] # boxes (offset by class), scores
  340. i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
  341. i = i[:max_det] # limit detections
  342. output = prediction[i]
  343. return output
  344. # ---------------------------- Processor for Deployment ----------------------------
  345. ## Pre-processer
  346. class PreProcessor(object):
  347. def __init__(self, img_size):
  348. self.img_size = img_size
  349. self.input_size = [img_size, img_size]
  350. def __call__(self, image, swap=(2, 0, 1)):
  351. """
  352. Input:
  353. image: (ndarray) [H, W, 3] or [H, W]
  354. formar: color format
  355. """
  356. if len(image.shape) == 3:
  357. padded_img = np.ones((self.input_size[0], self.input_size[1], 3), np.float32) * 114.
  358. else:
  359. padded_img = np.ones(self.input_size, np.float32) * 114.
  360. # resize
  361. orig_h, orig_w = image.shape[:2]
  362. r = min(self.input_size[0] / orig_h, self.input_size[1] / orig_w)
  363. resize_size = (int(orig_w * r), int(orig_h * r))
  364. if r != 1:
  365. resized_img = cv2.resize(image, resize_size, interpolation=cv2.INTER_LINEAR)
  366. else:
  367. resized_img = image
  368. # padding
  369. padded_img[:resized_img.shape[0], :resized_img.shape[1]] = resized_img
  370. # [H, W, C] -> [C, H, W]
  371. padded_img = padded_img.transpose(swap)
  372. padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) / 255.
  373. return padded_img, r
  374. ## Post-processer
  375. class PostProcessor(object):
  376. def __init__(self, num_classes, conf_thresh=0.15, nms_thresh=0.5):
  377. self.num_classes = num_classes
  378. self.conf_thresh = conf_thresh
  379. self.nms_thresh = nms_thresh
  380. def __call__(self, predictions):
  381. """
  382. Input:
  383. predictions: (ndarray) [n_anchors_all, 4+1+C]
  384. """
  385. bboxes = predictions[..., :4]
  386. scores = predictions[..., 4:]
  387. # scores & labels
  388. labels = np.argmax(scores, axis=1) # [M,]
  389. scores = scores[(np.arange(scores.shape[0]), labels)] # [M,]
  390. # thresh
  391. keep = np.where(scores > self.conf_thresh)
  392. scores = scores[keep]
  393. labels = labels[keep]
  394. bboxes = bboxes[keep]
  395. # nms
  396. scores, labels, bboxes = multiclass_nms(
  397. scores, labels, bboxes, self.nms_thresh, self.num_classes, True)
  398. return bboxes, scores, labels