misc.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, DistributedSampler

import numpy as np
import os
import math
from copy import deepcopy

from evaluator.coco_evaluator import COCOAPIEvaluator
from evaluator.voc_evaluator import VOCAPIEvaluator
from dataset.voc import VOCDetection, VOC_CLASSES
from dataset.coco import COCODataset, coco_class_index, coco_class_labels
from dataset.data_augment import build_transform
from utils import fuse_conv_bn
from models.yolov7.yolov7_basic import RepConv

# ---------------------------- For Dataset ----------------------------
## build dataset
def build_dataset(args, trans_config, device, is_train=False):
    # transform
    print('==============================')
    print('Transform Config: {}'.format(trans_config))
    train_transform = build_transform(args.img_size, trans_config, True)
    val_transform = build_transform(args.img_size, trans_config, False)

    # dataset
    if args.dataset == 'voc':
        data_dir = os.path.join(args.root, 'VOCdevkit')
        num_classes = 20
        class_names = VOC_CLASSES
        class_indexs = None
        # dataset: augment the trainval splits, keep the test split un-augmented
        dataset = VOCDetection(
            img_size=args.img_size,
            data_dir=data_dir,
            image_sets=[('2007', 'trainval'), ('2012', 'trainval')] if is_train else [('2007', 'test')],
            transform=train_transform if is_train else val_transform,
            trans_config=trans_config,
            is_train=is_train
            )
        # evaluator
        evaluator = VOCAPIEvaluator(
            data_dir=data_dir,
            device=device,
            transform=val_transform
            )
    elif args.dataset == 'coco':
        data_dir = os.path.join(args.root, 'COCO')
        num_classes = 80
        class_names = coco_class_labels
        class_indexs = coco_class_index
        # dataset: augment train2017, keep val2017 un-augmented
        dataset = COCODataset(
            img_size=args.img_size,
            data_dir=data_dir,
            image_set='train2017' if is_train else 'val2017',
            transform=train_transform if is_train else val_transform,
            trans_config=trans_config,
            is_train=is_train
            )
        # evaluator
        evaluator = COCOAPIEvaluator(
            data_dir=data_dir,
            device=device,
            transform=val_transform
            )
    else:
        print('Unknown dataset: {}. Only voc and coco are supported.'.format(args.dataset))
        exit(0)

    print('==============================')
    print('Training model on:', args.dataset)
    print('The dataset size:', len(dataset))

    return dataset, (num_classes, class_names, class_indexs), evaluator
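
# A minimal usage sketch for build_dataset. The `args` fields used here
# (dataset, root, img_size) are assumed to come from the training script's
# argument parser; `trans_config` and `device` are set up there as well:
#
#   dataset, (num_classes, class_names, class_indexs), evaluator = \
#       build_dataset(args, trans_config, device, is_train=True)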

## build dataloader
def build_dataloader(args, dataset, batch_size, collate_fn=None):
    # distributed: shard the dataset across processes; otherwise shuffle on a single process
    if args.distributed:
        sampler = DistributedSampler(dataset)
    else:
        sampler = torch.utils.data.RandomSampler(dataset)

    batch_sampler_train = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)
    dataloader = DataLoader(dataset, batch_sampler=batch_sampler_train,
                            collate_fn=collate_fn, num_workers=args.num_workers, pin_memory=True)

    return dataloader
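
# A minimal usage sketch (assumes `args` and `dataset` from build_dataset above;
# `batch_size` is the per-process batch size under DistributedSampler):
#
#   dataloader = build_dataloader(args, dataset, batch_size=16, collate_fn=CollateFunc())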

## collate_fn for dataloader
class CollateFunc(object):
    def __call__(self, batch):
        targets = []
        images = []

        for sample in batch:
            image = sample[0]
            target = sample[1]

            images.append(image)
            targets.append(target)

        # images share one size and can be stacked; targets stay a list
        # because the number of boxes varies from image to image
        images = torch.stack(images, 0)  # [B, C, H, W]

        return images, targets
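
# A sketch of the collated batch (shapes are illustrative assumptions):
#
#   images, targets = CollateFunc()([(img1, tgt1), (img2, tgt2)])
#   # images: Tensor of shape [2, C, H, W]
#   # targets: [tgt1, tgt2], per-image annotations of variable length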

# ---------------------------- For Model ----------------------------
## load trained weight
def load_weight(model, path_to_ckpt, fuse_cbn=False, fuse_repconv=False):
    # check ckpt file
    if path_to_ckpt is None:
        print('No weight file to load.')
    else:
        checkpoint = torch.load(path_to_ckpt, map_location='cpu')
        checkpoint_state_dict = checkpoint.pop("model")
        model.load_state_dict(checkpoint_state_dict)
        print('Finished loading model!')

    # fuse RepConv: merge each re-parameterizable block into a single conv for inference
    if fuse_repconv:
        print('Fusing RepConv block ...')
        for m in model.modules():
            if isinstance(m, RepConv):
                m.fuse_repvgg_block()

    # fuse conv & bn: fold BatchNorm statistics into the preceding conv weights
    if fuse_cbn:
        print('Fusing Conv & BN ...')
        model = fuse_conv_bn.fuse_conv_bn(model)

    return model
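
# A minimal usage sketch for inference (the checkpoint path is illustrative):
#
#   model = load_weight(model, 'weights/yolo.pth', fuse_cbn=True, fuse_repconv=True)
#   model = model.to(device).eval()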

## Model EMA
class ModelEMA(object):
    """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
    Keeps a moving average of everything in the model state_dict (parameters and buffers)
    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    """
    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        # Create EMA
        self.ema = deepcopy(self.de_parallel(model)).eval()  # FP32 EMA
        self.updates = updates  # number of EMA updates
        # decay exponential ramp (to help early epochs): starts near 0,
        # reaches decay * (1 - 1/e) ~= 0.632 * decay at x = tau, and approaches `decay` as x grows
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def is_parallel(self, model):
        # Returns True if model is of type DP or DDP
        return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)

    def de_parallel(self, model):
        # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
        return model.module if self.is_parallel(model) else model

    def copy_attr(self, a, b, include=(), exclude=()):
        # Copy attributes from b to a, optionally restricted to `include` and filtered by `exclude`
        for k, v in b.__dict__.items():
            if (len(include) and k not in include) or k.startswith('_') or k in exclude:
                continue
            else:
                setattr(a, k, v)

    def update(self, model):
        # Update EMA parameters
        self.updates += 1
        d = self.decay(self.updates)

        msd = self.de_parallel(model).state_dict()  # model state_dict
        for k, v in self.ema.state_dict().items():
            if v.dtype.is_floating_point:  # true for FP16 and FP32
                v *= d
                v += (1 - d) * msd[k].detach()
        # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        # Update EMA attributes
        self.copy_attr(self.ema, model, include, exclude)
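
# A minimal training-loop sketch (the model, dataloader, loss, and optimizer are
# assumed to exist elsewhere; the smoothed weights live in `ema.ema`, which is
# what evaluation and checkpointing would typically use):
#
#   ema = ModelEMA(model, decay=0.9999, tau=2000)
#   for images, targets in dataloader:
#       ...  # forward pass, loss, backward, optimizer.step()
#       ema.update(model)
#   # evaluate / save the EMA weights rather than the raw model
#   torch.save({'model': ema.ema.state_dict()}, 'weights/yolo_ema.pth')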