# engine.py

import torch
import torch.distributed as dist
import os
import random

# ----------------- Extra Components -----------------
from utils import distributed_utils
from utils.misc import MetricLogger, SmoothedValue
from utils.vis_tools import vis_data

# ----------------- Optimizer & LrScheduler Components -----------------
from utils.solver.optimizer import build_yolo_optimizer
from utils.solver.lr_scheduler import LinearWarmUpLrScheduler, build_lr_scheduler


class YoloTrainer(object):
    def __init__(self,
                 # Basic parameters
                 args,
                 cfg,
                 device,
                 # Model parameters
                 model,
                 model_ema,
                 criterion,
                 # Data parameters
                 train_loader,
                 evaluator,
                 ):
        # ------------------- Basic parameters -------------------
        self.args = args
        self.cfg = cfg
        self.epoch = 0
        self.best_map = -1.
        self.device = device
        self.criterion = criterion
        self.heavy_eval = False
        self.model_ema = model_ema
        # Weak-augmentation (second) stage: mosaic / mixup / copy-paste are switched off in check_second_stage()
        self.second_stage = False
        self.second_stage_epoch = cfg.no_aug_epoch
        # Path to save model
        self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
        os.makedirs(self.path_to_save, exist_ok=True)

        # ---------------------------- Dataset & Dataloader ----------------------------
        self.train_loader = train_loader

        # ---------------------------- Evaluator ----------------------------
        self.evaluator = evaluator

        # ---------------------------- Build Grad. Scaler ----------------------------
        self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)

        # ---------------------------- Build Optimizer ----------------------------
        # Accumulate gradients so the effective batch size matches cfg.batch_size_base
        self.grad_accumulate = max(cfg.batch_size_base // args.batch_size, 1)
        # Auto scale learning rate to the effective batch size
        cfg.base_lr = cfg.base_lr / cfg.batch_size_base * args.batch_size * self.grad_accumulate
        cfg.min_lr = cfg.base_lr * cfg.min_lr_ratio
        self.optimizer, self.start_epoch = build_yolo_optimizer(cfg, model, args.resume)

        # ---------------------------- Build LR Scheduler ----------------------------
        warmup_iters = cfg.warmup_epoch * len(self.train_loader)
        self.lr_scheduler_warmup = LinearWarmUpLrScheduler(warmup_iters, cfg.base_lr, cfg.warmup_bias_lr)
        self.lr_scheduler = build_lr_scheduler(cfg, self.optimizer, args.resume)

        self.best_map = cfg.best_map / 100.0
        print("Best mAP metric: {}".format(self.best_map))

    def train(self, model):
        for epoch in range(self.start_epoch, self.cfg.max_epoch):
            if self.args.distributed:
                self.train_loader.batch_sampler.sampler.set_epoch(epoch)

            # Check second stage
            if epoch >= (self.cfg.max_epoch - self.second_stage_epoch - 1) and not self.second_stage:
                self.check_second_stage()

                # Save model of the last Mosaic epoch
                weight_name = '{}_last_mosaic_epoch.pth'.format(self.args.model)
                checkpoint_path = os.path.join(self.path_to_save, weight_name)
                print('Saving state of the last Mosaic epoch-{}.'.format(self.epoch))
                torch.save({'model': model.state_dict(),
                            'mAP': round(self.evaluator.map * 100, 1),
                            'optimizer': self.optimizer.state_dict(),
                            'epoch': self.epoch,
                            'args': self.args},
                           checkpoint_path)

            # Train one epoch
            self.epoch = epoch
            self.train_one_epoch(model)

            # LR schedule
            if (epoch + 1) > self.cfg.warmup_epoch:
                self.lr_scheduler.step()

            # Eval one epoch
            if self.heavy_eval:
                model_eval = model.module if self.args.distributed else model
                self.eval(model_eval)
            else:
                model_eval = model.module if self.args.distributed else model
                if (epoch % self.cfg.eval_epoch) == 0 or (epoch == self.cfg.max_epoch - 1):
                    self.eval(model_eval)

            if self.args.debug:
                print("For debug mode, we only train 1 epoch")
                break

    def eval(self, model):
        # Set eval mode
        model.eval()
        model_eval = model if self.model_ema is None else self.model_ema.ema
        cur_map = -1.
        to_save = False

        if distributed_utils.is_main_process():
            if self.evaluator is None:
                print('No evaluator ... save model and go on training.')
                to_save = True
                weight_name = '{}_no_eval.pth'.format(self.args.model)
                checkpoint_path = os.path.join(self.path_to_save, weight_name)
            else:
                print('Eval ...')
                # Evaluate
                with torch.no_grad():
                    self.evaluator.evaluate(model_eval)

                cur_map = self.evaluator.map
                if cur_map > self.best_map:
                    # Update best mAP
                    self.best_map = cur_map
                    to_save = True
                    weight_name = '{}_best.pth'.format(self.args.model)
                    checkpoint_path = os.path.join(self.path_to_save, weight_name)

            # Save model
            if to_save:
                print('Saving state, epoch:', self.epoch)
                state_dicts = {
                    'model': model_eval.state_dict(),
                    'mAP': round(cur_map * 100, 3),
                    'optimizer': self.optimizer.state_dict(),
                    'lr_scheduler': self.lr_scheduler.state_dict(),
                    'epoch': self.epoch,
                    'args': self.args,
                }
                if self.model_ema is not None:
                    state_dicts["ema_updates"] = self.model_ema.updates
                torch.save(state_dicts, checkpoint_path)

        if self.args.distributed:
            # Wait for all processes to synchronize
            dist.barrier()

        # Set train mode
        model.train()

    def train_one_epoch(self, model):
        metric_logger = MetricLogger(delimiter=" ")
        metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
        metric_logger.add_meter('size', SmoothedValue(window_size=1, fmt='{value:d}'))
        metric_logger.add_meter('gnorm', SmoothedValue(window_size=1, fmt='{value:.1f}'))
        header = 'Epoch: [{} / {}]'.format(self.epoch, self.cfg.max_epoch)
        print_freq = 10
        gnorm = 0.0

        # Basic parameters
        epoch_size = len(self.train_loader)
        img_size = self.cfg.train_img_size
        nw = epoch_size * self.cfg.warmup_epoch

        # Train one epoch
        for iter_i, (images, targets) in enumerate(metric_logger.log_every(self.train_loader, print_freq, header)):
            ni = iter_i + self.epoch * epoch_size

            # Warmup
            if nw > 0 and ni < nw:
                self.lr_scheduler_warmup(ni, self.optimizer)
            elif ni == nw:
                print("Warmup stage is over.")
                self.lr_scheduler_warmup.set_lr(self.optimizer, self.cfg.base_lr)

            # To device
            images = images.to(self.device, non_blocking=True).float()

            # Multi scale
            images, targets, img_size = self.rescale_image_targets(
                images, targets, self.cfg.max_stride, self.cfg.multi_scale)

            # Visualize train targets
            if self.args.vis_tgt:
                vis_data(images,
                         targets,
                         self.cfg.num_classes,
                         self.cfg.pixel_mean,
                         self.cfg.pixel_std,
                         )

            # Inference
            with torch.cuda.amp.autocast(enabled=self.args.fp16):
                outputs = model(images)
                # Compute loss
                loss_dict = self.criterion(outputs=outputs, targets=targets)
                losses = loss_dict['losses']
                losses /= self.grad_accumulate
                loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)

            # Backward
            self.scaler.scale(losses).backward()

            # Optimize
            if (iter_i + 1) % self.grad_accumulate == 0:
                if self.cfg.clip_max_norm > 0:
                    self.scaler.unscale_(self.optimizer)
                    gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.cfg.clip_max_norm)
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()

                # ModelEMA
                if self.model_ema is not None:
                    self.model_ema.update(model)

            # Update log
            metric_logger.update(**loss_dict_reduced)
            metric_logger.update(lr=self.optimizer.param_groups[2]["lr"])
            metric_logger.update(size=img_size)
            metric_logger.update(gnorm=gnorm)

            if self.args.debug:
                print("For debug mode, we only train 1 iteration")
                break

        # Gather the stats from all processes
        metric_logger.synchronize_between_processes()
        print("Averaged stats:", metric_logger)
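
    # A quick worked example of the multi-scale draw below (illustrative numbers,
    # assuming e.g. cfg.train_img_size = 640 and cfg.max_stride = 32 with the
    # default range [0.5, 1.5]): min/max sizes become 320 and 960, and
    # randrange(320, 960 + 32, 32) picks a new square size from
    # {320, 352, ..., 960}, i.e. always a multiple of the network stride.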

    def rescale_image_targets(self, images, targets, max_stride, multi_scale_range=[0.5, 1.5]):
        """
        Deployed for the multi-scale trick: randomly resize the input batch
        (and its targets) to a new square size that is a multiple of max_stride.
        """
        # During the training phase, the shape of the input image is square.
        old_img_size = images.shape[-1]
        min_img_size = int(old_img_size * multi_scale_range[0])
        max_img_size = int(old_img_size * multi_scale_range[1])

        # Choose a new image size (randrange requires integer bounds)
        new_img_size = random.randrange(min_img_size, max_img_size + max_stride, max_stride)

        # Resize
        if new_img_size != old_img_size:
            # Interpolate
            images = torch.nn.functional.interpolate(
                input=images,
                size=new_img_size,
                mode='bilinear',
                align_corners=False)

        # Rescale targets
        for tgt in targets:
            boxes = tgt["boxes"].clone()
            labels = tgt["labels"].clone()
            boxes = torch.clamp(boxes, 0, old_img_size)
            # Rescale boxes
            boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
            boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
            # Refine targets: drop boxes that become too small after rescaling
            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
            keep = (min_tgt_size >= 8)
            tgt["boxes"] = boxes[keep]
            tgt["labels"] = labels[keep]

        return images, targets, new_img_size

    def check_second_stage(self):
        # Set second stage
        print('============== Second stage of Training ==============')
        self.second_stage = True
        self.heavy_eval = True

        # Close mosaic augmentation
        if self.train_loader.dataset.mosaic_prob > 0.:
            print(' - Close < Mosaic Augmentation > ...')
            self.train_loader.dataset.mosaic_prob = 0.

        # Close mixup augmentation
        if self.train_loader.dataset.mixup_prob > 0.:
            print(' - Close < Mixup Augmentation > ...')
            self.train_loader.dataset.mixup_prob = 0.

        # Close copy-paste augmentation
        if self.train_loader.dataset.copy_paste > 0.:
            print(' - Close < Copy-paste Augmentation > ...')
            self.train_loader.dataset.copy_paste = 0.
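

# ------------------------------------------------------------------
# Usage sketch (illustrative only). The builder helpers referenced here
# (build_model, build_criterion, build_dataloader, build_evaluator) are
# hypothetical placeholders, not functions defined in this file; wire
# YoloTrainer up with whatever factory functions the project provides.
#
#   model, criterion = build_model(args, cfg, device), build_criterion(cfg, device)
#   train_loader = build_dataloader(args, cfg, is_train=True)
#   evaluator = build_evaluator(args, cfg, device)
#   trainer = YoloTrainer(args, cfg, device,
#                         model=model, model_ema=None, criterion=criterion,
#                         train_loader=train_loader, evaluator=evaluator)
#   trainer.train(model)
# ------------------------------------------------------------------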