misc.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torch.utils.data import DataLoader, DistributedSampler
  5. import numpy as np
  6. import os
  7. import math
  8. from copy import deepcopy
  9. from evaluator.coco_evaluator import COCOAPIEvaluator
  10. from evaluator.voc_evaluator import VOCAPIEvaluator
  11. from dataset.voc import VOCDetection, VOC_CLASSES
  12. from dataset.coco import COCODataset, coco_class_index, coco_class_labels
  13. from dataset.data_augment import build_transform
  14. def build_dataset(args, trans_config, device, is_train=False):
  15. # transform
  16. print('==============================')
  17. print('Transform Config: {}'.format(trans_config))
  18. train_transform = build_transform(args.img_size, trans_config, True)
  19. val_transform = build_transform(args.img_size, trans_config, False)
  20. # dataset
  21. if args.dataset == 'voc':
  22. data_dir = os.path.join(args.root, 'VOCdevkit')
  23. num_classes = 20
  24. class_names = VOC_CLASSES
  25. class_indexs = None
  26. # dataset
  27. dataset = VOCDetection(
  28. img_size=args.img_size,
  29. data_dir=data_dir,
  30. image_sets=[('2007', 'trainval'), ('2012', 'trainval')] if is_train else [('2007', 'test')],
  31. transform=train_transform,
  32. trans_config=trans_config,
  33. is_train=is_train
  34. )
  35. # evaluator
  36. evaluator = VOCAPIEvaluator(
  37. data_dir=data_dir,
  38. device=device,
  39. transform=val_transform
  40. )
  41. elif args.dataset == 'coco':
  42. data_dir = os.path.join(args.root, 'COCO')
  43. num_classes = 80
  44. class_names = coco_class_labels
  45. class_indexs = coco_class_index
  46. # dataset
  47. dataset = COCODataset(
  48. img_size=args.img_size,
  49. data_dir=data_dir,
  50. image_set='train2017' if is_train else 'val2017',
  51. transform=train_transform,
  52. trans_config=trans_config,
  53. is_train=is_train
  54. )
  55. # evaluator
  56. evaluator = COCOAPIEvaluator(
  57. data_dir=data_dir,
  58. device=device,
  59. transform=val_transform
  60. )
  61. else:
  62. print('unknow dataset !! Only support voc, coco !!')
  63. exit(0)
  64. print('==============================')
  65. print('Training model on:', args.dataset)
  66. print('The dataset size:', len(dataset))
  67. return dataset, (num_classes, class_names, class_indexs), evaluator
  68. def build_dataloader(args, dataset, batch_size, collate_fn=None):
  69. # distributed
  70. if args.distributed:
  71. sampler = DistributedSampler(dataset)
  72. else:
  73. sampler = torch.utils.data.RandomSampler(dataset)
  74. batch_sampler_train = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)
  75. dataloader = DataLoader(dataset, batch_sampler=batch_sampler_train,
  76. collate_fn=collate_fn, num_workers=args.num_workers, pin_memory=True)
  77. return dataloader
  78. def load_weight(model, path_to_ckpt):
  79. # check ckpt file
  80. if path_to_ckpt is None:
  81. print('no weight file ...')
  82. return model
  83. checkpoint = torch.load(path_to_ckpt, map_location='cpu')
  84. try:
  85. checkpoint_state_dict = checkpoint.pop("model")
  86. except:
  87. checkpoint_state_dict = checkpoint
  88. model.load_state_dict(checkpoint_state_dict)
  89. print('Finished loading model!')
  90. return model
  91. def is_parallel(model):
  92. # Returns True if model is of type DP or DDP
  93. return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
  94. def de_parallel(model):
  95. # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
  96. return model.module if is_parallel(model) else model
  97. def copy_attr(a, b, include=(), exclude=()):
  98. # Copy attributes from b to a, options to only include [...] and to exclude [...]
  99. for k, v in b.__dict__.items():
  100. if (len(include) and k not in include) or k.startswith('_') or k in exclude:
  101. continue
  102. else:
  103. setattr(a, k, v)
  104. # Model EMA
  105. class ModelEMA(object):
  106. """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
  107. Keeps a moving average of everything in the model state_dict (parameters and buffers)
  108. For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  109. """
  110. def __init__(self, model, decay=0.9999, tau=2000, updates=0):
  111. # Create EMA
  112. self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
  113. self.updates = updates # number of EMA updates
  114. self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
  115. for p in self.ema.parameters():
  116. p.requires_grad_(False)
  117. def update(self, model):
  118. # Update EMA parameters
  119. self.updates += 1
  120. d = self.decay(self.updates)
  121. msd = de_parallel(model).state_dict() # model state_dict
  122. for k, v in self.ema.state_dict().items():
  123. if v.dtype.is_floating_point: # true for FP16 and FP32
  124. v *= d
  125. v += (1 - d) * msd[k].detach()
  126. # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'
  127. def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
  128. # Update EMA attributes
  129. copy_attr(self.ema, model, include, exclude)
  130. class CollateFunc(object):
  131. def __call__(self, batch):
  132. targets = []
  133. images = []
  134. for sample in batch:
  135. image = sample[0]
  136. target = sample[1]
  137. images.append(image)
  138. targets.append(target)
  139. images = torch.stack(images, 0) # [B, C, H, W]
  140. return images, targets