engine.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555
  1. import torch
  2. import torch.distributed as dist
  3. import os
  4. import random
  5. # ----------------- Extra Components -----------------
  6. from utils import distributed_utils
  7. from utils.misc import MetricLogger, SmoothedValue
  8. from utils.vis_tools import vis_data
  9. # ----------------- Optimizer & LrScheduler Components -----------------
  10. from utils.solver.optimizer import build_yolo_optimizer, build_rtdetr_optimizer
  11. from utils.solver.lr_scheduler import LinearWarmUpLrScheduler, build_lr_scheduler
  12. class YoloTrainer(object):
  13. def __init__(self,
  14. # Basic parameters
  15. args,
  16. cfg,
  17. device,
  18. # Model parameters
  19. model,
  20. model_ema,
  21. criterion,
  22. # Data parameters
  23. train_transform,
  24. val_transform,
  25. dataset,
  26. train_loader,
  27. evaluator,
  28. ):
  29. # ------------------- basic parameters -------------------
  30. self.args = args
  31. self.cfg = cfg
  32. self.epoch = 0
  33. self.best_map = -1.
  34. self.device = device
  35. self.criterion = criterion
  36. self.heavy_eval = False
  37. self.model_ema = model_ema
  38. # weak augmentatino stage
  39. self.second_stage = False
  40. self.second_stage_epoch = cfg.no_aug_epoch
  41. # path to save model
  42. self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
  43. os.makedirs(self.path_to_save, exist_ok=True)
  44. # ---------------------------- Transform ----------------------------
  45. self.train_transform = train_transform
  46. self.val_transform = val_transform
  47. # ---------------------------- Dataset & Dataloader ----------------------------
  48. self.dataset = dataset
  49. self.train_loader = train_loader
  50. # ---------------------------- Evaluator ----------------------------
  51. self.evaluator = evaluator
  52. # ---------------------------- Build Grad. Scaler ----------------------------
  53. self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
  54. # ---------------------------- Build Optimizer ----------------------------
  55. self.grad_accumulate = max(64 // args.batch_size, 1)
  56. cfg.base_lr = cfg.per_image_lr * args.batch_size * self.grad_accumulate
  57. cfg.min_lr = cfg.base_lr * cfg.min_lr_ratio
  58. self.optimizer, self.start_epoch = build_yolo_optimizer(cfg, model, args.resume)
  59. # ---------------------------- Build LR Scheduler ----------------------------
  60. warmup_iters = cfg.warmup_epoch * len(self.train_loader)
  61. self.lr_scheduler_warmup = LinearWarmUpLrScheduler(warmup_iters, cfg.base_lr, cfg.warmup_bias_lr)
  62. self.lr_scheduler = build_lr_scheduler(cfg, self.optimizer, args.resume)
  63. def train(self, model):
  64. for epoch in range(self.start_epoch, self.cfg.max_epoch):
  65. if self.args.distributed:
  66. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  67. # check second stage
  68. if epoch >= (self.cfg.max_epoch - self.second_stage_epoch - 1) and not self.second_stage:
  69. self.check_second_stage()
  70. # save model of the last mosaic epoch
  71. weight_name = '{}_last_mosaic_epoch.pth'.format(self.args.model)
  72. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  73. print('Saving state of the last Mosaic epoch-{}.'.format(self.epoch))
  74. torch.save({'model': model.state_dict(),
  75. 'mAP': round(self.evaluator.map*100, 1),
  76. 'optimizer': self.optimizer.state_dict(),
  77. 'epoch': self.epoch,
  78. 'args': self.args},
  79. checkpoint_path)
  80. # train one epoch
  81. self.epoch = epoch
  82. self.train_one_epoch(model)
  83. # LR Schedule
  84. if (epoch + 1) > self.cfg.warmup_epoch:
  85. self.lr_scheduler.step()
  86. # eval one epoch
  87. if self.heavy_eval:
  88. model_eval = model.module if self.args.distributed else model
  89. self.eval(model_eval)
  90. else:
  91. model_eval = model.module if self.args.distributed else model
  92. if (epoch % self.cfg.eval_epoch) == 0 or (epoch == self.cfg.max_epoch - 1):
  93. self.eval(model_eval)
  94. if self.args.debug:
  95. print("For debug mode, we only train 1 epoch")
  96. break
  97. def eval(self, model):
  98. # set eval mode
  99. model.eval()
  100. model_eval = model if self.model_ema is None else self.model_ema.ema
  101. cur_map = -1.
  102. to_save = False
  103. if distributed_utils.is_main_process():
  104. if self.evaluator is None:
  105. print('No evaluator ... save model and go on training.')
  106. to_save = True
  107. weight_name = '{}_no_eval.pth'.format(self.args.model)
  108. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  109. else:
  110. print('Eval ...')
  111. # Evaluate
  112. with torch.no_grad():
  113. self.evaluator.evaluate(model_eval)
  114. cur_map = self.evaluator.map
  115. if cur_map > self.best_map:
  116. # update best-map
  117. self.best_map = cur_map
  118. to_save = True
  119. # Save model
  120. if to_save:
  121. print('Saving state, epoch:', self.epoch)
  122. weight_name = '{}_best.pth'.format(self.args.model)
  123. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  124. state_dicts = {
  125. 'model': model_eval.state_dict(),
  126. 'mAP': round(cur_map*100, 1),
  127. 'optimizer': self.optimizer.state_dict(),
  128. 'lr_scheduler': self.lr_scheduler.state_dict(),
  129. 'epoch': self.epoch,
  130. 'args': self.args,
  131. }
  132. if self.model_ema is not None:
  133. state_dicts["ema_updates"] = self.model_ema.updates
  134. torch.save(state_dicts, checkpoint_path)
  135. if self.args.distributed:
  136. # wait for all processes to synchronize
  137. dist.barrier()
  138. # set train mode.
  139. model.train()
  140. def train_one_epoch(self, model):
  141. metric_logger = MetricLogger(delimiter=" ")
  142. metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
  143. metric_logger.add_meter('size', SmoothedValue(window_size=1, fmt='{value:d}'))
  144. metric_logger.add_meter('gnorm', SmoothedValue(window_size=1, fmt='{value:.1f}'))
  145. header = 'Epoch: [{} / {}]'.format(self.epoch, self.cfg.max_epoch)
  146. epoch_size = len(self.train_loader)
  147. print_freq = 10
  148. gnorm = 0.0
  149. # basic parameters
  150. epoch_size = len(self.train_loader)
  151. img_size = self.cfg.train_img_size
  152. nw = epoch_size * self.cfg.warmup_epoch
  153. # Train one epoch
  154. for iter_i, (images, targets) in enumerate(metric_logger.log_every(self.train_loader, print_freq, header)):
  155. ni = iter_i + self.epoch * epoch_size
  156. # Warmup
  157. if nw > 0 and ni < nw:
  158. self.lr_scheduler_warmup(ni, self.optimizer)
  159. elif ni == nw:
  160. print("Warmup stage is over.")
  161. self.lr_scheduler_warmup.set_lr(self.optimizer, self.cfg.base_lr)
  162. # To device
  163. images = images.to(self.device, non_blocking=True).float()
  164. # Multi scale
  165. images, targets, img_size = self.rescale_image_targets(
  166. images, targets, self.cfg.max_stride, self.cfg.multi_scale)
  167. # Visualize train targets
  168. if self.args.vis_tgt:
  169. vis_data(images,
  170. targets,
  171. self.cfg.num_classes,
  172. self.cfg.normalize_coords,
  173. self.train_transform.color_format,
  174. self.cfg.pixel_mean,
  175. self.cfg.pixel_std,
  176. self.cfg.box_format)
  177. # Inference
  178. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  179. outputs = model(images)
  180. # Compute loss
  181. loss_dict = self.criterion(outputs=outputs, targets=targets)
  182. losses = loss_dict['losses']
  183. losses /= self.grad_accumulate
  184. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  185. # Backward
  186. self.scaler.scale(losses).backward()
  187. # Optimize
  188. if (iter_i + 1) % self.grad_accumulate == 0:
  189. if self.cfg.clip_max_norm > 0:
  190. self.scaler.unscale_(self.optimizer)
  191. gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.cfg.clip_max_norm)
  192. self.scaler.step(self.optimizer)
  193. self.scaler.update()
  194. self.optimizer.zero_grad()
  195. # ModelEMA
  196. if self.model_ema is not None:
  197. self.model_ema.update(model)
  198. # Update log
  199. metric_logger.update(**loss_dict_reduced)
  200. metric_logger.update(lr=self.optimizer.param_groups[2]["lr"])
  201. metric_logger.update(size=img_size)
  202. metric_logger.update(gnorm=gnorm)
  203. if self.args.debug:
  204. print("For debug mode, we only train 1 iteration")
  205. break
  206. # Gather the stats from all processes
  207. metric_logger.synchronize_between_processes()
  208. print("Averaged stats:", metric_logger)
  209. def rescale_image_targets(self, images, targets, max_stride, multi_scale_range=[0.5, 1.5]):
  210. """
  211. Deployed for Multi scale trick.
  212. """
  213. # During training phase, the shape of input image is square.
  214. old_img_size = images.shape[-1]
  215. min_img_size = old_img_size * multi_scale_range[0]
  216. max_img_size = old_img_size * multi_scale_range[1]
  217. # Choose a new image size
  218. new_img_size = random.randrange(min_img_size, max_img_size + max_stride, max_stride)
  219. # Resize
  220. if new_img_size != old_img_size:
  221. # interpolate
  222. images = torch.nn.functional.interpolate(
  223. input=images,
  224. size=new_img_size,
  225. mode='bilinear',
  226. align_corners=False)
  227. # rescale targets
  228. if not self.cfg.normalize_coords:
  229. for tgt in targets:
  230. boxes = tgt["boxes"].clone()
  231. labels = tgt["labels"].clone()
  232. boxes = torch.clamp(boxes, 0, old_img_size)
  233. # rescale box
  234. boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
  235. boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
  236. # refine tgt
  237. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  238. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  239. keep = (min_tgt_size >= 8)
  240. tgt["boxes"] = boxes[keep]
  241. tgt["labels"] = labels[keep]
  242. return images, targets, new_img_size
  243. def check_second_stage(self):
  244. # set second stage
  245. print('============== Second stage of Training ==============')
  246. self.second_stage = True
  247. self.heavy_eval = True
  248. # close mosaic augmentation
  249. if self.train_loader.dataset.mosaic_prob > 0.:
  250. print(' - Close < Mosaic Augmentation > ...')
  251. self.train_loader.dataset.mosaic_prob = 0.
  252. # close mixup augmentation
  253. if self.train_loader.dataset.mixup_prob > 0.:
  254. print(' - Close < Mixup Augmentation > ...')
  255. self.train_loader.dataset.mixup_prob = 0.
  256. # close copy-paste augmentation
  257. if self.train_loader.dataset.copy_paste > 0.:
  258. print(' - Close < Copy-paste Augmentation > ...')
  259. self.train_loader.dataset.copy_paste = 0.
  260. class RTDetrTrainer(object):
  261. def __init__(self,
  262. # Basic parameters
  263. args,
  264. cfg,
  265. device,
  266. # Model parameters
  267. model,
  268. model_ema,
  269. criterion,
  270. # Data parameters
  271. train_transform,
  272. val_transform,
  273. dataset,
  274. train_loader,
  275. evaluator,
  276. ):
  277. # ------------------- basic parameters -------------------
  278. self.args = args
  279. self.cfg = cfg
  280. self.epoch = 0
  281. self.best_map = -1.
  282. self.device = device
  283. self.criterion = criterion
  284. self.heavy_eval = False
  285. self.model_ema = model_ema
  286. # path to save model
  287. self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
  288. os.makedirs(self.path_to_save, exist_ok=True)
  289. # ---------------------------- Transform ----------------------------
  290. self.train_transform = train_transform
  291. self.val_transform = val_transform
  292. # ---------------------------- Dataset & Dataloader ----------------------------
  293. self.dataset = dataset
  294. self.train_loader = train_loader
  295. # ---------------------------- Evaluator ----------------------------
  296. self.evaluator = evaluator
  297. # ---------------------------- Build Grad. Scaler ----------------------------
  298. self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
  299. # ---------------------------- Build Optimizer ----------------------------
  300. self.grad_accumulate = max(16 // args.batch_size, 1)
  301. cfg.base_lr = cfg.per_image_lr * args.batch_size * self.grad_accumulate
  302. cfg.min_lr = cfg.base_lr * cfg.min_lr_ratio
  303. self.optimizer, self.start_epoch = build_rtdetr_optimizer(cfg, model, args.resume)
  304. # ---------------------------- Build LR Scheduler ----------------------------
  305. self.wp_lr_scheduler = LinearWarmUpLrScheduler(cfg.warmup_iters, cfg.base_lr)
  306. self.lr_scheduler = build_lr_scheduler(cfg, self.optimizer, args.resume)
  307. def train(self, model):
  308. for epoch in range(self.start_epoch, self.cfg.max_epoch):
  309. if self.args.distributed:
  310. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  311. # train one epoch
  312. self.epoch = epoch
  313. self.train_one_epoch(model)
  314. # LR Scheduler
  315. self.lr_scheduler.step()
  316. # eval one epoch
  317. if self.heavy_eval:
  318. model_eval = model.module if self.args.distributed else model
  319. self.eval(model_eval)
  320. else:
  321. model_eval = model.module if self.args.distributed else model
  322. if (epoch % self.cfg.eval_epoch) == 0 or (epoch == self.cfg.max_epoch - 1):
  323. self.eval(model_eval)
  324. if self.args.debug:
  325. print("For debug mode, we only train 1 epoch")
  326. break
  327. def eval(self, model):
  328. # set eval mode
  329. model.eval()
  330. model_eval = model if self.model_ema is None else self.model_ema.ema
  331. cur_map = -1.
  332. to_save = False
  333. if distributed_utils.is_main_process():
  334. if self.evaluator is None:
  335. print('No evaluator ... save model and go on training.')
  336. to_save = True
  337. weight_name = '{}_no_eval.pth'.format(self.args.model)
  338. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  339. else:
  340. print('Eval ...')
  341. # Evaluate
  342. with torch.no_grad():
  343. self.evaluator.evaluate(model_eval)
  344. cur_map = self.evaluator.map
  345. if cur_map > self.best_map:
  346. # update best-map
  347. self.best_map = cur_map
  348. to_save = True
  349. # Save model
  350. if to_save:
  351. print('Saving state, epoch:', self.epoch)
  352. weight_name = '{}_best.pth'.format(self.args.model)
  353. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  354. state_dicts = {
  355. 'model': model_eval.state_dict(),
  356. 'mAP': round(cur_map*100, 1),
  357. 'optimizer': self.optimizer.state_dict(),
  358. 'lr_scheduler': self.lr_scheduler.state_dict(),
  359. 'epoch': self.epoch,
  360. 'args': self.args,
  361. }
  362. if self.model_ema is not None:
  363. state_dicts["ema_updates"] = self.model_ema.updates
  364. torch.save(state_dicts, checkpoint_path)
  365. if self.args.distributed:
  366. # wait for all processes to synchronize
  367. dist.barrier()
  368. # set train mode.
  369. model.train()
  370. def train_one_epoch(self, model):
  371. metric_logger = MetricLogger(delimiter=" ")
  372. metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
  373. metric_logger.add_meter('size', SmoothedValue(window_size=1, fmt='{value:d}'))
  374. metric_logger.add_meter('gnorm', SmoothedValue(window_size=1, fmt='{value:.1f}'))
  375. header = 'Epoch: [{} / {}]'.format(self.epoch, self.cfg.max_epoch)
  376. epoch_size = len(self.train_loader)
  377. print_freq = 10
  378. gnorm = 0.0
  379. # basic parameters
  380. epoch_size = len(self.train_loader)
  381. img_size = self.cfg.train_img_size
  382. nw = self.cfg.warmup_iters
  383. lr_warmup_stage = True
  384. # Train one epoch
  385. for iter_i, (images, targets) in enumerate(metric_logger.log_every(self.train_loader, print_freq, header)):
  386. ni = iter_i + self.epoch * epoch_size
  387. # WarmUp
  388. if lr_warmup_stage:
  389. if ni % self.grad_accumulate == 0:
  390. ni = ni // self.grad_accumulate
  391. if ni < nw:
  392. self.wp_lr_scheduler(ni, self.optimizer)
  393. elif ni == nw and lr_warmup_stage:
  394. print('Warmup stage is over.')
  395. lr_warmup_stage = False
  396. self.wp_lr_scheduler.set_lr(self.optimizer, self.cfg.base_lr)
  397. # To device
  398. images = images.to(self.device, non_blocking=True).float()
  399. for tgt in targets:
  400. tgt['boxes'] = tgt['boxes'].to(self.device)
  401. tgt['labels'] = tgt['labels'].to(self.device)
  402. # Multi scale
  403. images, targets, img_size = self.rescale_image_targets(
  404. images, targets, self.cfg.max_stride, self.cfg.multi_scale)
  405. # Visualize train targets
  406. if self.args.vis_tgt:
  407. vis_data(images,
  408. targets,
  409. self.cfg.num_classes,
  410. self.cfg.normalize_coords,
  411. self.train_transform.color_format,
  412. self.cfg.pixel_mean,
  413. self.cfg.pixel_std,
  414. self.cfg.box_format)
  415. # Inference
  416. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  417. outputs = model(images, targets)
  418. # Compute loss
  419. loss_dict = self.criterion(outputs, targets)
  420. losses = sum(loss_dict.values())
  421. losses /= self.grad_accumulate
  422. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  423. # Backward
  424. self.scaler.scale(losses).backward()
  425. # Optimize
  426. if (iter_i + 1) % self.grad_accumulate == 0:
  427. if self.cfg.clip_max_norm > 0:
  428. self.scaler.unscale_(self.optimizer)
  429. gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.cfg.clip_max_norm)
  430. self.scaler.step(self.optimizer)
  431. self.scaler.update()
  432. self.optimizer.zero_grad()
  433. # ModelEMA
  434. if self.model_ema is not None:
  435. self.model_ema.update(model)
  436. # Update log
  437. metric_logger.update(**loss_dict_reduced)
  438. metric_logger.update(lr=self.optimizer.param_groups[2]["lr"])
  439. metric_logger.update(size=img_size)
  440. metric_logger.update(gnorm=gnorm)
  441. if self.args.debug:
  442. print("For debug mode, we only train 1 iteration")
  443. break
  444. def rescale_image_targets(self, images, targets, max_stride, multi_scale_range=[0.5, 1.5]):
  445. """
  446. Deployed for Multi scale trick.
  447. """
  448. # During training phase, the shape of input image is square.
  449. old_img_size = images.shape[-1]
  450. min_img_size = old_img_size * multi_scale_range[0]
  451. max_img_size = old_img_size * multi_scale_range[1]
  452. # Choose a new image size
  453. new_img_size = random.randrange(min_img_size, max_img_size + max_stride, max_stride)
  454. # Resize
  455. if new_img_size != old_img_size:
  456. # interpolate
  457. images = torch.nn.functional.interpolate(
  458. input=images,
  459. size=new_img_size,
  460. mode='bilinear',
  461. align_corners=False)
  462. return images, targets, new_img_size
  463. # Build Trainer
  464. def build_trainer(args, cfg, device, model, model_ema, criterion, train_transform, val_transform, dataset, train_loader, evaluator):
  465. # ----------------------- Det trainers -----------------------
  466. if cfg.trainer == 'yolo':
  467. return YoloTrainer(args, cfg, device, model, model_ema, criterion, train_transform, val_transform, dataset, train_loader, evaluator)
  468. elif cfg.trainer == 'rtdetr':
  469. return RTDetrTrainer(args, cfg, device, model, model_ema, criterion, train_transform, val_transform, dataset, train_loader, evaluator)
  470. else:
  471. raise NotImplementedError(cfg.trainer)