misc.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torch.utils.data import DataLoader, DistributedSampler
  5. import os
  6. import cv2
  7. import math
  8. import numpy as np
  9. from copy import deepcopy
  10. from thop import profile
  11. from evaluator.coco_evaluator import COCOAPIEvaluator
  12. from evaluator.voc_evaluator import VOCAPIEvaluator
  13. from evaluator.ourdataset_evaluator import OurDatasetEvaluator
  14. from dataset.voc import VOCDetection, VOC_CLASSES
  15. from dataset.coco import COCODataset, coco_class_index, coco_class_labels
  16. from dataset.ourdataset import OurDataset, our_class_labels
  17. from dataset.data_augment import build_transform
  18. # ---------------------------- For Dataset ----------------------------
  19. ## build dataset
  20. def build_dataset(args, trans_config, device, is_train=False):
  21. # transform
  22. print('==============================')
  23. print('Transform Config: {}'.format(trans_config))
  24. train_transform = build_transform(args.img_size, trans_config, True)
  25. val_transform = build_transform(args.img_size, trans_config, False)
  26. # dataset
  27. if args.dataset == 'voc':
  28. data_dir = os.path.join(args.root, 'VOCdevkit')
  29. num_classes = 20
  30. class_names = VOC_CLASSES
  31. class_indexs = None
  32. # dataset
  33. dataset = VOCDetection(
  34. img_size=args.img_size,
  35. data_dir=data_dir,
  36. image_sets=[('2007', 'trainval'), ('2012', 'trainval')] if is_train else [('2007', 'test')],
  37. transform=train_transform,
  38. trans_config=trans_config,
  39. is_train=is_train
  40. )
  41. # evaluator
  42. evaluator = VOCAPIEvaluator(
  43. data_dir=data_dir,
  44. device=device,
  45. transform=val_transform
  46. )
  47. elif args.dataset == 'coco':
  48. data_dir = os.path.join(args.root, 'COCO')
  49. num_classes = 80
  50. class_names = coco_class_labels
  51. class_indexs = coco_class_index
  52. # dataset
  53. dataset = COCODataset(
  54. img_size=args.img_size,
  55. data_dir=data_dir,
  56. image_set='train2017' if is_train else 'val2017',
  57. transform=train_transform,
  58. trans_config=trans_config,
  59. is_train=is_train
  60. )
  61. # evaluator
  62. evaluator = COCOAPIEvaluator(
  63. data_dir=data_dir,
  64. device=device,
  65. transform=val_transform
  66. )
  67. elif args.dataset == 'ourdataset':
  68. data_dir = os.path.join(args.root, 'OurDataset')
  69. class_names = our_class_labels
  70. num_classes = len(our_class_labels)
  71. class_indexs = None
  72. # dataset
  73. dataset = OurDataset(
  74. data_dir=data_dir,
  75. img_size=args.img_size,
  76. image_set='train' if is_train else 'val',
  77. transform=train_transform,
  78. trans_config=trans_config,
  79. is_train=is_train
  80. )
  81. # evaluator
  82. evaluator = OurDatasetEvaluator(
  83. data_dir=data_dir,
  84. device=device,
  85. image_set='val',
  86. transform=val_transform
  87. )
  88. else:
  89. print('unknow dataset !! Only support voc, coco !!')
  90. exit(0)
  91. print('==============================')
  92. print('Training model on:', args.dataset)
  93. print('The dataset size:', len(dataset))
  94. return dataset, (num_classes, class_names, class_indexs), evaluator
  95. ## build dataloader
  96. def build_dataloader(args, dataset, batch_size, collate_fn=None):
  97. # distributed
  98. if args.distributed:
  99. sampler = DistributedSampler(dataset)
  100. else:
  101. sampler = torch.utils.data.RandomSampler(dataset)
  102. batch_sampler_train = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)
  103. dataloader = DataLoader(dataset, batch_sampler=batch_sampler_train,
  104. collate_fn=collate_fn, num_workers=args.num_workers, pin_memory=True)
  105. return dataloader
  106. ## collate_fn for dataloader
  107. class CollateFunc(object):
  108. def __call__(self, batch):
  109. targets = []
  110. images = []
  111. for sample in batch:
  112. image = sample[0]
  113. target = sample[1]
  114. images.append(image)
  115. targets.append(target)
  116. images = torch.stack(images, 0) # [B, C, H, W]
  117. return images, targets
  118. # ---------------------------- For Model ----------------------------
  119. ## fuse Conv & BN layer
  120. def fuse_conv_bn(module):
  121. """Recursively fuse conv and bn in a module.
  122. During inference, the functionary of batch norm layers is turned off
  123. but only the mean and var alone channels are used, which exposes the
  124. chance to fuse it with the preceding conv layers to save computations and
  125. simplify network structures.
  126. Args:
  127. module (nn.Module): Module to be fused.
  128. Returns:
  129. nn.Module: Fused module.
  130. """
  131. last_conv = None
  132. last_conv_name = None
  133. def _fuse_conv_bn(conv, bn):
  134. """Fuse conv and bn into one module.
  135. Args:
  136. conv (nn.Module): Conv to be fused.
  137. bn (nn.Module): BN to be fused.
  138. Returns:
  139. nn.Module: Fused module.
  140. """
  141. conv_w = conv.weight
  142. conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
  143. bn.running_mean)
  144. factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
  145. conv.weight = nn.Parameter(conv_w *
  146. factor.reshape([conv.out_channels, 1, 1, 1]))
  147. conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias)
  148. return conv
  149. for name, child in module.named_children():
  150. if isinstance(child,
  151. (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)):
  152. if last_conv is None: # only fuse BN that is after Conv
  153. continue
  154. fused_conv = _fuse_conv_bn(last_conv, child)
  155. module._modules[last_conv_name] = fused_conv
  156. # To reduce changes, set BN as Identity instead of deleting it.
  157. module._modules[name] = nn.Identity()
  158. last_conv = None
  159. elif isinstance(child, nn.Conv2d):
  160. last_conv = child
  161. last_conv_name = name
  162. else:
  163. fuse_conv_bn(child)
  164. return module
  165. ## replace module
  166. def replace_module(module, replaced_module_type, new_module_type, replace_func=None) -> nn.Module:
  167. """
  168. Replace given type in module to a new type. mostly used in deploy.
  169. Args:
  170. module (nn.Module): model to apply replace operation.
  171. replaced_module_type (Type): module type to be replaced.
  172. new_module_type (Type)
  173. replace_func (function): python function to describe replace logic. Defalut value None.
  174. Returns:
  175. model (nn.Module): module that already been replaced.
  176. """
  177. def default_replace_func(replaced_module_type, new_module_type):
  178. return new_module_type()
  179. if replace_func is None:
  180. replace_func = default_replace_func
  181. model = module
  182. if isinstance(module, replaced_module_type):
  183. model = replace_func(replaced_module_type, new_module_type)
  184. else: # recurrsively replace
  185. for name, child in module.named_children():
  186. new_child = replace_module(child, replaced_module_type, new_module_type)
  187. if new_child is not child: # child is already replaced
  188. model.add_module(name, new_child)
  189. return model
  190. ## compute FLOPs & Parameters
  191. def compute_flops(model, img_size, device):
  192. x = torch.randn(1, 3, img_size, img_size).to(device)
  193. print('==============================')
  194. flops, params = profile(model, inputs=(x, ), verbose=False)
  195. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  196. print('Params : {:.2f} M'.format(params / 1e6))
  197. ## load trained weight
  198. def load_weight(model, path_to_ckpt, fuse_cbn=False):
  199. # check ckpt file
  200. if path_to_ckpt is None:
  201. print('no weight file ...')
  202. else:
  203. checkpoint = torch.load(path_to_ckpt, map_location='cpu')
  204. checkpoint_state_dict = checkpoint.pop("model")
  205. model.load_state_dict(checkpoint_state_dict)
  206. print('Finished loading model!')
  207. # fuse conv & bn
  208. if fuse_cbn:
  209. print('Fusing Conv & BN ...')
  210. model = fuse_conv_bn(model)
  211. return model
  212. ## Model EMA
  213. class ModelEMA(object):
  214. """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
  215. Keeps a moving average of everything in the model state_dict (parameters and buffers)
  216. For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  217. """
  218. def __init__(self, model, decay=0.9999, tau=2000, updates=0):
  219. # Create EMA
  220. self.ema = deepcopy(self.de_parallel(model)).eval() # FP32 EMA
  221. self.updates = updates # number of EMA updates
  222. self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
  223. for p in self.ema.parameters():
  224. p.requires_grad_(False)
  225. def is_parallel(self, model):
  226. # Returns True if model is of type DP or DDP
  227. return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
  228. def de_parallel(self, model):
  229. # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
  230. return model.module if self.is_parallel(model) else model
  231. def copy_attr(self, a, b, include=(), exclude=()):
  232. # Copy attributes from b to a, options to only include [...] and to exclude [...]
  233. for k, v in b.__dict__.items():
  234. if (len(include) and k not in include) or k.startswith('_') or k in exclude:
  235. continue
  236. else:
  237. setattr(a, k, v)
  238. def update(self, model):
  239. # Update EMA parameters
  240. self.updates += 1
  241. d = self.decay(self.updates)
  242. msd = self.de_parallel(model).state_dict() # model state_dict
  243. for k, v in self.ema.state_dict().items():
  244. if v.dtype.is_floating_point: # true for FP16 and FP32
  245. v *= d
  246. v += (1 - d) * msd[k].detach()
  247. # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'
  248. def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
  249. # Update EMA attributes
  250. self.copy_attr(self.ema, model, include, exclude)
  251. # ---------------------------- NMS ----------------------------
  252. ## basic NMS
  253. def nms(bboxes, scores, nms_thresh):
  254. """"Pure Python NMS."""
  255. x1 = bboxes[:, 0] #xmin
  256. y1 = bboxes[:, 1] #ymin
  257. x2 = bboxes[:, 2] #xmax
  258. y2 = bboxes[:, 3] #ymax
  259. areas = (x2 - x1) * (y2 - y1)
  260. order = scores.argsort()[::-1]
  261. keep = []
  262. while order.size > 0:
  263. i = order[0]
  264. keep.append(i)
  265. # compute iou
  266. xx1 = np.maximum(x1[i], x1[order[1:]])
  267. yy1 = np.maximum(y1[i], y1[order[1:]])
  268. xx2 = np.minimum(x2[i], x2[order[1:]])
  269. yy2 = np.minimum(y2[i], y2[order[1:]])
  270. w = np.maximum(1e-10, xx2 - xx1)
  271. h = np.maximum(1e-10, yy2 - yy1)
  272. inter = w * h
  273. iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-14)
  274. #reserve all the boundingbox whose ovr less than thresh
  275. inds = np.where(iou <= nms_thresh)[0]
  276. order = order[inds + 1]
  277. return keep
  278. ## class-agnostic NMS
  279. def multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh):
  280. # nms
  281. keep = nms(bboxes, scores, nms_thresh)
  282. scores = scores[keep]
  283. labels = labels[keep]
  284. bboxes = bboxes[keep]
  285. return scores, labels, bboxes
  286. ## class-aware NMS
  287. def multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes):
  288. # nms
  289. keep = np.zeros(len(bboxes), dtype=np.int)
  290. for i in range(num_classes):
  291. inds = np.where(labels == i)[0]
  292. if len(inds) == 0:
  293. continue
  294. c_bboxes = bboxes[inds]
  295. c_scores = scores[inds]
  296. c_keep = nms(c_bboxes, c_scores, nms_thresh)
  297. keep[inds[c_keep]] = 1
  298. keep = np.where(keep > 0)
  299. scores = scores[keep]
  300. labels = labels[keep]
  301. bboxes = bboxes[keep]
  302. return scores, labels, bboxes
  303. ## multi-class NMS
  304. def multiclass_nms(scores, labels, bboxes, nms_thresh, num_classes, class_agnostic=False):
  305. if class_agnostic:
  306. return multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh)
  307. else:
  308. return multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes)
  309. # ---------------------------- Processor for Deployment ----------------------------
## Pre-processor
  311. class PreProcessor(object):
  312. def __init__(self, img_size):
  313. self.img_size = img_size
  314. self.input_size = [img_size, img_size]
  315. def __call__(self, image, swap=(2, 0, 1)):
  316. """
  317. Input:
  318. image: (ndarray) [H, W, 3] or [H, W]
  319. formar: color format
  320. """
  321. if len(image.shape) == 3:
  322. padded_img = np.ones((self.input_size[0], self.input_size[1], 3), np.float32) * 114.
  323. else:
  324. padded_img = np.ones(self.input_size, np.float32) * 114.
  325. # resize
  326. orig_h, orig_w = image.shape[:2]
  327. r = min(self.input_size[0] / orig_h, self.input_size[1] / orig_w)
  328. resize_size = (int(orig_w * r), int(orig_h * r))
  329. if r != 1:
  330. resized_img = cv2.resize(image, resize_size, interpolation=cv2.INTER_LINEAR)
  331. else:
  332. resized_img = image
  333. # padding
  334. padded_img[:resized_img.shape[0], :resized_img.shape[1]] = resized_img
  335. # [H, W, C] -> [C, H, W]
  336. padded_img = padded_img.transpose(swap)
  337. padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
  338. return padded_img, r
## Post-processor
  340. class PostProcessor(object):
  341. def __init__(self, img_size, strides, num_classes, conf_thresh=0.15, nms_thresh=0.5):
  342. self.img_size = img_size
  343. self.strides = strides
  344. self.num_classes = num_classes
  345. self.conf_thresh = conf_thresh
  346. self.nms_thresh = nms_thresh
  347. def __call__(self, predictions):
  348. """
  349. Input:
  350. predictions: (ndarray) [n_anchors_all, 4+1+C]
  351. """
  352. bboxes = predictions[..., :4]
  353. obj_preds = predictions[..., 4:5]
  354. cls_preds = predictions[..., 5:]
  355. scores = np.sqrt(obj_preds * cls_preds)
  356. # scores & labels
  357. labels = np.argmax(scores, axis=1) # [M,]
  358. scores = scores[(np.arange(scores.shape[0]), labels)] # [M,]
  359. # thresh
  360. keep = np.where(scores > self.conf_thresh)
  361. scores = scores[keep]
  362. labels = labels[keep]
  363. bboxes = bboxes[keep]
  364. # nms
  365. scores, labels, bboxes = multiclass_nms(
  366. scores, labels, bboxes, self.nms_thresh, self.num_classes, True)
  367. return bboxes, scores, labels