misc.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, DistributedSampler

import numpy as np
import os
import math
from copy import deepcopy

from evaluator.coco_evaluator import COCOAPIEvaluator
from evaluator.voc_evaluator import VOCAPIEvaluator
from evaluator.ourdataset_evaluator import OurDatasetEvaluator

from dataset.voc import VOCDetection, VOC_CLASSES
from dataset.coco import COCODataset, coco_class_index, coco_class_labels
from dataset.ourdataset import OurDataset, our_class_labels
from dataset.data_augment import build_transform

from utils import fuse_conv_bn
from models.yolov7.yolov7_basic import RepConv


# ---------------------------- For Dataset ----------------------------
## build dataset
def build_dataset(args, trans_config, device, is_train=False):
    # transform
    print('==============================')
    print('Transform Config: {}'.format(trans_config))
    train_transform = build_transform(args.img_size, trans_config, True)
    val_transform = build_transform(args.img_size, trans_config, False)

    # dataset
    if args.dataset == 'voc':
        data_dir = os.path.join(args.root, 'VOCdevkit')
        num_classes = 20
        class_names = VOC_CLASSES
        class_indexs = None
        # dataset
        dataset = VOCDetection(
            img_size=args.img_size,
            data_dir=data_dir,
            image_sets=[('2007', 'trainval'), ('2012', 'trainval')] if is_train else [('2007', 'test')],
            transform=train_transform,
            trans_config=trans_config,
            is_train=is_train
            )
        # evaluator
        evaluator = VOCAPIEvaluator(
            data_dir=data_dir,
            device=device,
            transform=val_transform
            )
    elif args.dataset == 'coco':
        data_dir = os.path.join(args.root, 'COCO')
        num_classes = 80
        class_names = coco_class_labels
        class_indexs = coco_class_index
        # dataset
        dataset = COCODataset(
            img_size=args.img_size,
            data_dir=data_dir,
            image_set='train2017' if is_train else 'val2017',
            transform=train_transform,
            trans_config=trans_config,
            is_train=is_train
            )
        # evaluator
        evaluator = COCOAPIEvaluator(
            data_dir=data_dir,
            device=device,
            transform=val_transform
            )
    elif args.dataset == 'ourdataset':
        data_dir = os.path.join(args.root, 'OurDataset')
        class_names = our_class_labels
        num_classes = len(our_class_labels)
        class_indexs = None
        # dataset
        dataset = OurDataset(
            data_dir=data_dir,
            img_size=args.img_size,
            image_set='train' if is_train else 'val',
            transform=train_transform,
            trans_config=trans_config,
            is_train=is_train
            )
        # evaluator
        evaluator = OurDatasetEvaluator(
            data_dir=data_dir,
            device=device,
            image_set='val',
            transform=val_transform
            )
    else:
        print('Unknown dataset: {} !! Only voc, coco, ourdataset are supported !!'.format(args.dataset))
        exit(1)
    print('==============================')
    print('Training model on:', args.dataset)
    print('The dataset size:', len(dataset))

    return dataset, (num_classes, class_names, class_indexs), evaluator
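
# A minimal usage sketch for build_dataset (not part of the original file).
# The args fields (dataset, root, img_size) are the ones build_dataset reads;
# the trans_config keys below are hypothetical placeholders, since the real
# keys are defined by build_transform in dataset.data_augment.
def _example_build_dataset():
    from argparse import Namespace
    args = Namespace(dataset='voc', root='/path/to/data', img_size=640)
    trans_config = {'mosaic_prob': 0.0, 'mixup_prob': 0.0}  # hypothetical keys
    dataset, (num_classes, class_names, class_indexs), evaluator = build_dataset(
        args, trans_config, device=torch.device('cpu'), is_train=True)
    print(num_classes, len(dataset))
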
## build dataloader
def build_dataloader(args, dataset, batch_size, collate_fn=None):
    # distributed
    if args.distributed:
        sampler = DistributedSampler(dataset)
    else:
        sampler = torch.utils.data.RandomSampler(dataset)

    batch_sampler_train = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)
    dataloader = DataLoader(dataset, batch_sampler=batch_sampler_train,
                            collate_fn=collate_fn, num_workers=args.num_workers, pin_memory=True)

    return dataloader

## collate_fn for dataloader
class CollateFunc(object):
    def __call__(self, batch):
        targets = []
        images = []

        for sample in batch:
            image = sample[0]
            target = sample[1]

            images.append(image)
            targets.append(target)

        images = torch.stack(images, 0)  # [B, C, H, W]

        return images, targets
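
# Usage sketch tying build_dataloader and CollateFunc together (an assumption,
# not original code). CollateFunc stacks images into one [B, C, H, W] tensor but
# keeps targets as a plain list, since detection targets vary in length per image.
def _example_build_dataloader(args, dataset):
    # args is expected to carry .distributed and .num_workers
    dataloader = build_dataloader(args, dataset, batch_size=16, collate_fn=CollateFunc())
    for images, targets in dataloader:
        # images: [B, C, H, W] tensor; targets: list of B per-image annotations
        break
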
# ---------------------------- For Model ----------------------------
## load trained weight
def load_weight(model, path_to_ckpt, fuse_cbn=False, fuse_repconv=False):
    # check ckpt file
    if path_to_ckpt is None:
        print('no weight file ...')
    else:
        checkpoint = torch.load(path_to_ckpt, map_location='cpu')
        checkpoint_state_dict = checkpoint.pop("model")
        model.load_state_dict(checkpoint_state_dict)
        print('Finished loading model!')

    # fuse repconv
    if fuse_repconv:
        print('Fusing RepConv block ...')
        for m in model.modules():
            if isinstance(m, RepConv):
                m.fuse_repvgg_block()

    # fuse conv & bn
    if fuse_cbn:
        print('Fusing Conv & BN ...')
        model = fuse_conv_bn.fuse_conv_bn(model)

    return model
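
# Usage sketch for load_weight (an assumption, not original code). Fusing
# RepConv blocks and Conv+BN pairs rewrites layers for inference only, so it
# pairs with eval(); the checkpoint path is a hypothetical example.
def _example_load_weight(model):
    model = load_weight(model, path_to_ckpt='weights/checkpoint.pth',
                        fuse_cbn=True, fuse_repconv=True)
    model.eval()
    return model
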
## Model EMA
class ModelEMA(object):
    """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
    Keeps a moving average of everything in the model state_dict (parameters and buffers)
    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        # Create EMA
        self.ema = deepcopy(self.de_parallel(model)).eval()  # FP32 EMA
        self.updates = updates  # number of EMA updates
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def is_parallel(self, model):
        # Returns True if model is of type DP or DDP
        return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)

    def de_parallel(self, model):
        # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
        return model.module if self.is_parallel(model) else model

    def copy_attr(self, a, b, include=(), exclude=()):
        # Copy attributes from b to a, options to only include [...] and to exclude [...]
        for k, v in b.__dict__.items():
            if (len(include) and k not in include) or k.startswith('_') or k in exclude:
                continue
            else:
                setattr(a, k, v)

    def update(self, model):
        # Update EMA parameters
        self.updates += 1
        d = self.decay(self.updates)

        msd = self.de_parallel(model).state_dict()  # model state_dict
        for k, v in self.ema.state_dict().items():
            if v.dtype.is_floating_point:  # true for FP16 and FP32
                v *= d
                v += (1 - d) * msd[k].detach()
        # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        # Update EMA attributes
        self.copy_attr(self.ema, model, include, exclude)
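
# Training-loop sketch for ModelEMA (an assumption, not original code):
# update() is called once per optimizer step, and the smoothed weights live
# in ema.ema, which is the model to use for evaluation and checkpoints.
# loss_fn is a hypothetical stand-in for the detector's criterion.
def _example_ema_loop(model, optimizer, dataloader, loss_fn):
    ema = ModelEMA(model, decay=0.9999, tau=2000)
    for images, targets in dataloader:
        optimizer.zero_grad()
        loss = loss_fn(model(images), targets)
        loss.backward()
        optimizer.step()
        ema.update(model)  # EMA tracks the raw model after every step
    return ema.ema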