misc.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torch.utils.data import DataLoader, DistributedSampler
  5. import numpy as np
  6. import os
  7. import math
  8. from copy import deepcopy
  9. from evaluator.coco_evaluator import COCOAPIEvaluator
  10. from evaluator.voc_evaluator import VOCAPIEvaluator
  11. from dataset.voc import VOCDetection, VOC_CLASSES
  12. from dataset.coco import COCODataset, coco_class_index, coco_class_labels
  13. from dataset.data_augment import build_transform
  14. def build_dataset(args, trans_config, device, is_train=False):
  15. # transform
  16. print('==============================')
  17. print('Transform Config: {}'.format(trans_config))
  18. train_transform = build_transform(args.img_size, trans_config, True)
  19. val_transform = build_transform(args.img_size, trans_config, False)
  20. # dataset
  21. if args.dataset == 'voc':
  22. data_dir = os.path.join(args.root, 'VOCdevkit')
  23. num_classes = 20
  24. class_names = VOC_CLASSES
  25. class_indexs = None
  26. # dataset
  27. dataset = VOCDetection(
  28. img_size=args.img_size,
  29. data_dir=data_dir,
  30. image_sets=[('2007', 'trainval'), ('2012', 'trainval')] if is_train else [('2007', 'test')],
  31. transform=train_transform,
  32. trans_config=trans_config,
  33. is_train=is_train
  34. )
  35. # evaluator
  36. evaluator = VOCAPIEvaluator(
  37. data_dir=data_dir,
  38. device=device,
  39. transform=val_transform
  40. )
  41. elif args.dataset == 'coco':
  42. data_dir = os.path.join(args.root, 'COCO')
  43. num_classes = 80
  44. class_names = coco_class_labels
  45. class_indexs = coco_class_index
  46. # dataset
  47. dataset = COCODataset(
  48. img_size=args.img_size,
  49. data_dir=data_dir,
  50. image_set='train2017',
  51. transform=train_transform,
  52. trans_config=trans_config,
  53. is_train=is_train
  54. )
  55. # evaluator
  56. evaluator = COCOAPIEvaluator(
  57. data_dir=data_dir,
  58. device=device,
  59. transform=val_transform
  60. )
  61. else:
  62. print('unknow dataset !! Only support voc, coco !!')
  63. exit(0)
  64. print('==============================')
  65. print('Training model on:', args.dataset)
  66. print('The dataset size:', len(dataset))
  67. return dataset, (num_classes, class_names, class_indexs), evaluator
  68. def build_dataloader(args, dataset, batch_size, collate_fn=None):
  69. # distributed
  70. if args.distributed:
  71. sampler = DistributedSampler(dataset)
  72. else:
  73. sampler = torch.utils.data.RandomSampler(dataset)
  74. batch_sampler_train = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)
  75. dataloader = DataLoader(dataset, batch_sampler=batch_sampler_train,
  76. collate_fn=collate_fn, num_workers=args.num_workers, pin_memory=True)
  77. return dataloader
  78. def load_weight(model, path_to_ckpt):
  79. # check ckpt file
  80. if path_to_ckpt is None:
  81. print('no weight file ...')
  82. return model
  83. checkpoint = torch.load(path_to_ckpt, map_location='cpu')
  84. try:
  85. checkpoint_state_dict = checkpoint.pop("model")
  86. except:
  87. checkpoint_state_dict = checkpoint
  88. model.load_state_dict(checkpoint_state_dict)
  89. print('Finished loading model!')
  90. return model
  91. def is_parallel(model):
  92. # Returns True if model is of type DP or DDP
  93. return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
  94. # Model EMA
  95. class ModelEMA(object):
  96. def __init__(self, model, decay=0.9999, updates=0):
  97. # create EMA
  98. self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA
  99. self.updates = updates
  100. self.decay = lambda x: decay * (1 - math.exp(-x / 2000.))
  101. for p in self.ema.parameters():
  102. p.requires_grad_(False)
  103. def update(self, model):
  104. # Update EMA parameters
  105. with torch.no_grad():
  106. self.updates += 1
  107. d = self.decay(self.updates)
  108. msd = model.module.state_dict() if is_parallel(model) else model.state_dict() # model state_dict
  109. for k, v in self.ema.state_dict().items():
  110. if v.dtype.is_floating_point:
  111. v *= d
  112. v += (1. - d) * msd[k].detach()
  113. class CollateFunc(object):
  114. def __call__(self, batch):
  115. targets = []
  116. images = []
  117. for sample in batch:
  118. image = sample[0]
  119. target = sample[1]
  120. images.append(image)
  121. targets.append(target)
  122. images = torch.stack(images, 0) # [B, C, H, W]
  123. return images, targets