engine.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544
  1. import torch
  2. import torch.distributed as dist
  3. import os
  4. import random
  5. # ----------------- Extra Components -----------------
  6. from utils import distributed_utils
  7. from utils.misc import MetricLogger, SmoothedValue
  8. from utils.vis_tools import vis_data
  9. # ----------------- Optimizer & LrScheduler Components -----------------
  10. from utils.solver.optimizer import build_yolo_optimizer, build_rtdetr_optimizer
  11. from utils.solver.lr_scheduler import LinearWarmUpLrScheduler, build_lr_scheduler
  12. class YoloTrainer(object):
  13. def __init__(self,
  14. # Basic parameters
  15. args,
  16. cfg,
  17. device,
  18. # Model parameters
  19. model,
  20. model_ema,
  21. criterion,
  22. # Data parameters
  23. train_transform,
  24. val_transform,
  25. dataset,
  26. train_loader,
  27. evaluator,
  28. ):
  29. # ------------------- basic parameters -------------------
  30. self.args = args
  31. self.cfg = cfg
  32. self.epoch = 0
  33. self.best_map = -1.
  34. self.device = device
  35. self.criterion = criterion
  36. self.heavy_eval = False
  37. self.model_ema = model_ema
  38. # weak augmentatino stage
  39. self.second_stage = False
  40. self.second_stage_epoch = cfg.no_aug_epoch
  41. # path to save model
  42. self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
  43. os.makedirs(self.path_to_save, exist_ok=True)
  44. # ---------------------------- Transform ----------------------------
  45. self.train_transform = train_transform
  46. self.val_transform = val_transform
  47. # ---------------------------- Dataset & Dataloader ----------------------------
  48. self.dataset = dataset
  49. self.train_loader = train_loader
  50. # ---------------------------- Evaluator ----------------------------
  51. self.evaluator = evaluator
  52. # ---------------------------- Build Grad. Scaler ----------------------------
  53. self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
  54. # ---------------------------- Build Optimizer ----------------------------
  55. cfg.base_lr = cfg.per_image_lr * args.batch_size
  56. cfg.min_lr = cfg.base_lr * cfg.min_lr_ratio
  57. self.optimizer, self.start_epoch = build_yolo_optimizer(cfg, model, args.resume)
  58. # ---------------------------- Build LR Scheduler ----------------------------
  59. warmup_iters = cfg.warmup_epoch * len(self.train_loader)
  60. self.lr_scheduler_warmup = LinearWarmUpLrScheduler(warmup_iters, cfg.base_lr, cfg.warmup_bias_lr)
  61. self.lr_scheduler = build_lr_scheduler(cfg, self.optimizer, args.resume)
  62. def train(self, model):
  63. for epoch in range(self.start_epoch, self.cfg.max_epoch):
  64. if self.args.distributed:
  65. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  66. # check second stage
  67. if epoch >= (self.cfg.max_epoch - self.second_stage_epoch - 1) and not self.second_stage:
  68. self.check_second_stage()
  69. # save model of the last mosaic epoch
  70. weight_name = '{}_last_mosaic_epoch.pth'.format(self.args.model)
  71. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  72. print('Saving state of the last Mosaic epoch-{}.'.format(self.epoch))
  73. torch.save({'model': model.state_dict(),
  74. 'mAP': round(self.evaluator.map*100, 1),
  75. 'optimizer': self.optimizer.state_dict(),
  76. 'epoch': self.epoch,
  77. 'args': self.args},
  78. checkpoint_path)
  79. # train one epoch
  80. self.epoch = epoch
  81. self.train_one_epoch(model)
  82. # LR Schedule
  83. if (epoch + 1) > self.cfg.warmup_epoch:
  84. self.lr_scheduler.step()
  85. # eval one epoch
  86. if self.heavy_eval:
  87. model_eval = model.module if self.args.distributed else model
  88. self.eval(model_eval)
  89. else:
  90. model_eval = model.module if self.args.distributed else model
  91. if (epoch % self.cfg.eval_epoch) == 0 or (epoch == self.cfg.max_epoch - 1):
  92. self.eval(model_eval)
  93. if self.args.debug:
  94. print("For debug mode, we only train 1 epoch")
  95. break
  96. def eval(self, model):
  97. # set eval mode
  98. model.eval()
  99. model_eval = model if self.model_ema is None else self.model_ema.ema
  100. cur_map = -1.
  101. to_save = False
  102. if distributed_utils.is_main_process():
  103. if self.evaluator is None:
  104. print('No evaluator ... save model and go on training.')
  105. to_save = True
  106. weight_name = '{}_no_eval.pth'.format(self.args.model)
  107. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  108. else:
  109. print('Eval ...')
  110. # Evaluate
  111. with torch.no_grad():
  112. self.evaluator.evaluate(model_eval)
  113. cur_map = self.evaluator.map
  114. if cur_map > self.best_map:
  115. # update best-map
  116. self.best_map = cur_map
  117. to_save = True
  118. # Save model
  119. if to_save:
  120. print('Saving state, epoch:', self.epoch)
  121. weight_name = '{}_best.pth'.format(self.args.model)
  122. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  123. state_dicts = {
  124. 'model': model_eval.state_dict(),
  125. 'mAP': round(cur_map*100, 1),
  126. 'optimizer': self.optimizer.state_dict(),
  127. 'lr_scheduler': self.lr_scheduler.state_dict(),
  128. 'epoch': self.epoch,
  129. 'args': self.args,
  130. }
  131. if self.model_ema is not None:
  132. state_dicts["ema_updates"] = self.model_ema.updates
  133. torch.save(state_dicts, checkpoint_path)
  134. if self.args.distributed:
  135. # wait for all processes to synchronize
  136. dist.barrier()
  137. # set train mode.
  138. model.train()
  139. def train_one_epoch(self, model):
  140. metric_logger = MetricLogger(delimiter=" ")
  141. metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
  142. metric_logger.add_meter('size', SmoothedValue(window_size=1, fmt='{value:d}'))
  143. metric_logger.add_meter('gnorm', SmoothedValue(window_size=1, fmt='{value:.1f}'))
  144. header = 'Epoch: [{} / {}]'.format(self.epoch, self.cfg.max_epoch)
  145. epoch_size = len(self.train_loader)
  146. print_freq = 10
  147. gnorm = 0.0
  148. # basic parameters
  149. epoch_size = len(self.train_loader)
  150. img_size = self.cfg.train_img_size
  151. nw = epoch_size * self.cfg.warmup_epoch
  152. # Train one epoch
  153. for iter_i, (images, targets) in enumerate(metric_logger.log_every(self.train_loader, print_freq, header)):
  154. ni = iter_i + self.epoch * epoch_size
  155. # Warmup
  156. if nw > 0 and ni < nw:
  157. self.lr_scheduler_warmup(ni, self.optimizer)
  158. elif ni == nw:
  159. print("Warmup stage is over.")
  160. self.lr_scheduler_warmup.set_lr(self.optimizer, self.cfg.base_lr)
  161. # To device
  162. images = images.to(self.device, non_blocking=True).float()
  163. # Multi scale
  164. images, targets, img_size = self.rescale_image_targets(
  165. images, targets, self.cfg.max_stride, self.cfg.multi_scale)
  166. # Visualize train targets
  167. if self.args.vis_tgt:
  168. vis_data(images,
  169. targets,
  170. self.cfg.num_classes,
  171. self.cfg.normalize_coords,
  172. self.train_transform.color_format,
  173. self.cfg.pixel_mean,
  174. self.cfg.pixel_std,
  175. self.cfg.box_format)
  176. # Inference
  177. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  178. outputs = model(images)
  179. # Compute loss
  180. loss_dict = self.criterion(outputs=outputs, targets=targets)
  181. losses = loss_dict['losses']
  182. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  183. # Backward
  184. self.scaler.scale(losses).backward()
  185. # Optimize
  186. if self.cfg.clip_max_norm > 0:
  187. self.scaler.unscale_(self.optimizer)
  188. gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.cfg.clip_max_norm)
  189. self.scaler.step(self.optimizer)
  190. self.scaler.update()
  191. self.optimizer.zero_grad()
  192. # ModelEMA
  193. if self.model_ema is not None:
  194. self.model_ema.update(model)
  195. # Update log
  196. metric_logger.update(**loss_dict_reduced)
  197. metric_logger.update(lr=self.optimizer.param_groups[2]["lr"])
  198. metric_logger.update(size=img_size)
  199. metric_logger.update(gnorm=gnorm)
  200. if self.args.debug:
  201. print("For debug mode, we only train 1 iteration")
  202. break
  203. # Gather the stats from all processes
  204. metric_logger.synchronize_between_processes()
  205. print("Averaged stats:", metric_logger)
  206. def rescale_image_targets(self, images, targets, max_stride, multi_scale_range=[0.5, 1.5]):
  207. """
  208. Deployed for Multi scale trick.
  209. """
  210. # During training phase, the shape of input image is square.
  211. old_img_size = images.shape[-1]
  212. min_img_size = old_img_size * multi_scale_range[0]
  213. max_img_size = old_img_size * multi_scale_range[1]
  214. # Choose a new image size
  215. new_img_size = random.randrange(min_img_size, max_img_size + max_stride, max_stride)
  216. # Resize
  217. if new_img_size != old_img_size:
  218. # interpolate
  219. images = torch.nn.functional.interpolate(
  220. input=images,
  221. size=new_img_size,
  222. mode='bilinear',
  223. align_corners=False)
  224. # rescale targets
  225. if not self.cfg.normalize_coords:
  226. for tgt in targets:
  227. boxes = tgt["boxes"].clone()
  228. labels = tgt["labels"].clone()
  229. boxes = torch.clamp(boxes, 0, old_img_size)
  230. # rescale box
  231. boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
  232. boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
  233. # refine tgt
  234. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  235. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  236. keep = (min_tgt_size >= 1)
  237. tgt["boxes"] = boxes[keep]
  238. tgt["labels"] = labels[keep]
  239. return images, targets, new_img_size
  240. def check_second_stage(self):
  241. # set second stage
  242. print('============== Second stage of Training ==============')
  243. self.second_stage = True
  244. self.heavy_eval = True
  245. # close mosaic augmentation
  246. if self.train_loader.dataset.mosaic_prob > 0.:
  247. print(' - Close < Mosaic Augmentation > ...')
  248. self.train_loader.dataset.mosaic_prob = 0.
  249. # close mixup augmentation
  250. if self.train_loader.dataset.mixup_prob > 0.:
  251. print(' - Close < Mixup Augmentation > ...')
  252. self.train_loader.dataset.mixup_prob = 0.
  253. # close copy-paste augmentation
  254. if self.train_loader.dataset.copy_paste > 0.:
  255. print(' - Close < Copy-paste Augmentation > ...')
  256. self.train_loader.dataset.copy_paste = 0.
  257. class RTDetrTrainer(object):
  258. def __init__(self,
  259. # Basic parameters
  260. args,
  261. cfg,
  262. device,
  263. # Model parameters
  264. model,
  265. model_ema,
  266. criterion,
  267. # Data parameters
  268. train_transform,
  269. val_transform,
  270. dataset,
  271. train_loader,
  272. evaluator,
  273. ):
  274. # ------------------- basic parameters -------------------
  275. self.args = args
  276. self.cfg = cfg
  277. self.epoch = 0
  278. self.best_map = -1.
  279. self.device = device
  280. self.criterion = criterion
  281. self.heavy_eval = False
  282. self.model_ema = model_ema
  283. # path to save model
  284. self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
  285. os.makedirs(self.path_to_save, exist_ok=True)
  286. # ---------------------------- Transform ----------------------------
  287. self.train_transform = train_transform
  288. self.val_transform = val_transform
  289. # ---------------------------- Dataset & Dataloader ----------------------------
  290. self.dataset = dataset
  291. self.train_loader = train_loader
  292. # ---------------------------- Evaluator ----------------------------
  293. self.evaluator = evaluator
  294. # ---------------------------- Build Grad. Scaler ----------------------------
  295. self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
  296. # ---------------------------- Build Optimizer ----------------------------
  297. self.grad_accumulate = max(16 // args.batch_size, 1)
  298. cfg.base_lr = cfg.per_image_lr * args.batch_size * self.grad_accumulate
  299. cfg.min_lr = cfg.base_lr * cfg.min_lr_ratio
  300. self.optimizer, self.start_epoch = build_rtdetr_optimizer(cfg, model, args.resume)
  301. # ---------------------------- Build LR Scheduler ----------------------------
  302. self.wp_lr_scheduler = LinearWarmUpLrScheduler(cfg.warmup_iters, cfg.base_lr)
  303. self.lr_scheduler = build_lr_scheduler(cfg, self.optimizer, args.resume)
  304. def train(self, model):
  305. for epoch in range(self.start_epoch, self.cfg.max_epoch):
  306. if self.args.distributed:
  307. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  308. # train one epoch
  309. self.epoch = epoch
  310. self.train_one_epoch(model)
  311. # LR Scheduler
  312. self.lr_scheduler.step()
  313. # eval one epoch
  314. if self.heavy_eval:
  315. model_eval = model.module if self.args.distributed else model
  316. self.eval(model_eval)
  317. else:
  318. model_eval = model.module if self.args.distributed else model
  319. if (epoch % self.cfg.eval_epoch) == 0 or (epoch == self.cfg.max_epoch - 1):
  320. self.eval(model_eval)
  321. if self.args.debug:
  322. print("For debug mode, we only train 1 epoch")
  323. break
  324. def eval(self, model):
  325. # set eval mode
  326. model.eval()
  327. model_eval = model if self.model_ema is None else self.model_ema.ema
  328. cur_map = -1.
  329. to_save = False
  330. if distributed_utils.is_main_process():
  331. if self.evaluator is None:
  332. print('No evaluator ... save model and go on training.')
  333. to_save = True
  334. weight_name = '{}_no_eval.pth'.format(self.args.model)
  335. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  336. else:
  337. print('Eval ...')
  338. # Evaluate
  339. with torch.no_grad():
  340. self.evaluator.evaluate(model_eval)
  341. cur_map = self.evaluator.map
  342. if cur_map > self.best_map:
  343. # update best-map
  344. self.best_map = cur_map
  345. to_save = True
  346. # Save model
  347. if to_save:
  348. print('Saving state, epoch:', self.epoch)
  349. weight_name = '{}_best.pth'.format(self.args.model)
  350. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  351. state_dicts = {
  352. 'model': model_eval.state_dict(),
  353. 'mAP': round(cur_map*100, 1),
  354. 'optimizer': self.optimizer.state_dict(),
  355. 'lr_scheduler': self.lr_scheduler.state_dict(),
  356. 'epoch': self.epoch,
  357. 'args': self.args,
  358. }
  359. if self.model_ema is not None:
  360. state_dicts["ema_updates"] = self.model_ema.updates
  361. torch.save(state_dicts, checkpoint_path)
  362. if self.args.distributed:
  363. # wait for all processes to synchronize
  364. dist.barrier()
  365. # set train mode.
  366. model.train()
  367. def train_one_epoch(self, model):
  368. metric_logger = MetricLogger(delimiter=" ")
  369. metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
  370. metric_logger.add_meter('size', SmoothedValue(window_size=1, fmt='{value:d}'))
  371. metric_logger.add_meter('grad_norm', SmoothedValue(window_size=1, fmt='{value:.1f}'))
  372. header = 'Epoch: [{} / {}]'.format(self.epoch, self.cfg.max_epoch)
  373. epoch_size = len(self.train_loader)
  374. print_freq = 10
  375. # basic parameters
  376. epoch_size = len(self.train_loader)
  377. img_size = self.cfg.train_img_size
  378. nw = self.cfg.warmup_iters
  379. lr_warmup_stage = True
  380. # Train one epoch
  381. for iter_i, (images, targets) in enumerate(metric_logger.log_every(self.train_loader, print_freq, header)):
  382. ni = iter_i + self.epoch * epoch_size
  383. # WarmUp
  384. if ni < nw and lr_warmup_stage:
  385. self.wp_lr_scheduler(ni, self.optimizer)
  386. elif ni == nw and lr_warmup_stage:
  387. print('Warmup stage is over.')
  388. lr_warmup_stage = False
  389. self.wp_lr_scheduler.set_lr(self.optimizer, self.cfg.base_lr)
  390. # To device
  391. images = images.to(self.device, non_blocking=True).float()
  392. for tgt in targets:
  393. tgt['boxes'] = tgt['boxes'].to(self.device)
  394. tgt['labels'] = tgt['labels'].to(self.device)
  395. # Multi scale
  396. images, targets, img_size = self.rescale_image_targets(
  397. images, targets, self.cfg.max_stride, self.cfg.multi_scale)
  398. # Visualize train targets
  399. if self.args.vis_tgt:
  400. vis_data(images,
  401. targets,
  402. self.cfg.num_classes,
  403. self.cfg.normalize_coords,
  404. self.train_transform.color_format,
  405. self.cfg.pixel_mean,
  406. self.cfg.pixel_std,
  407. self.cfg.box_format)
  408. # Inference
  409. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  410. outputs = model(images, targets)
  411. loss_dict = self.criterion(outputs, targets)
  412. losses = sum(loss_dict.values())
  413. losses /= self.grad_accumulate
  414. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  415. # Backward
  416. self.scaler.scale(losses).backward()
  417. # Optimize
  418. if (iter_i + 1) % self.grad_accumulate == 0:
  419. if self.cfg.clip_max_norm > 0:
  420. self.scaler.unscale_(self.optimizer)
  421. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.cfg.clip_max_norm)
  422. self.scaler.step(self.optimizer)
  423. self.scaler.update()
  424. self.optimizer.zero_grad()
  425. # ModelEMA
  426. if self.model_ema is not None:
  427. self.model_ema.update(model)
  428. # Update log
  429. metric_logger.update(**loss_dict_reduced)
  430. metric_logger.update(lr=self.optimizer.param_groups[2]["lr"])
  431. metric_logger.update(size=img_size)
  432. if self.args.debug:
  433. print("For debug mode, we only train 1 iteration")
  434. break
  435. def rescale_image_targets(self, images, targets, max_stride, multi_scale_range=[0.5, 1.5]):
  436. """
  437. Deployed for Multi scale trick.
  438. """
  439. # During training phase, the shape of input image is square.
  440. old_img_size = images.shape[-1]
  441. min_img_size = old_img_size * multi_scale_range[0]
  442. max_img_size = old_img_size * multi_scale_range[1]
  443. # Choose a new image size
  444. new_img_size = random.randrange(min_img_size, max_img_size + max_stride, max_stride)
  445. # Resize
  446. if new_img_size != old_img_size:
  447. # interpolate
  448. images = torch.nn.functional.interpolate(
  449. input=images,
  450. size=new_img_size,
  451. mode='bilinear',
  452. align_corners=False)
  453. return images, targets, new_img_size
  454. # Build Trainer
  455. def build_trainer(args, cfg, device, model, model_ema, criterion, train_transform, val_transform, dataset, train_loader, evaluator):
  456. # ----------------------- Det trainers -----------------------
  457. if cfg.trainer == 'yolo':
  458. return YoloTrainer(args, cfg, device, model, model_ema, criterion, train_transform, val_transform, dataset, train_loader, evaluator)
  459. elif cfg.trainer == 'rtdetr':
  460. return RTDetrTrainer(args, cfg, device, model, model_ema, criterion, train_transform, val_transform, dataset, train_loader, evaluator)
  461. else:
  462. raise NotImplementedError(cfg.trainer)