# engine.py
  1. import torch
  2. import torch.distributed as dist
  3. import os
  4. import random
  5. # ----------------- Extra Components -----------------
  6. from utils import distributed_utils
  7. from utils.misc import MetricLogger, SmoothedValue
  8. from utils.vis_tools import vis_data
  9. # ----------------- Optimizer & LrScheduler Components -----------------
  10. from utils.solver.optimizer import build_yolo_optimizer, build_rtdetr_optimizer
  11. from utils.solver.lr_scheduler import LinearWarmUpLrScheduler, build_lr_scheduler
class YoloTrainer(object):
    """Training driver for YOLO-family detectors.

    Owns the optimizer, LR schedulers, AMP gradient scaler, optional
    model EMA, evaluation and checkpointing for one training run.
    """
    def __init__(self,
                 # Basic parameters
                 args,
                 cfg,
                 device,
                 # Model parameters
                 model,
                 model_ema,
                 criterion,
                 # Data parameters
                 train_transform,
                 val_transform,
                 dataset,
                 train_loader,
                 evaluator,
                 ):
        # ------------------- basic parameters -------------------
        self.args = args
        self.cfg = cfg
        self.epoch = 0
        self.best_map = -1.          # best validation mAP seen so far
        self.device = device
        self.criterion = criterion
        self.heavy_eval = False      # becomes True in the second stage: evaluate every epoch
        self.model_ema = model_ema
        # Weak-augmentation ("second") stage: strong augmentations are
        # disabled for the last `no_aug_epoch` epochs of training.
        self.second_stage = False
        self.second_stage_epoch = cfg.no_aug_epoch
        # path to save model
        self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
        os.makedirs(self.path_to_save, exist_ok=True)
        # ---------------------------- Transform ----------------------------
        self.train_transform = train_transform
        self.val_transform = val_transform
        # ---------------------------- Dataset & Dataloader ----------------------------
        self.dataset = dataset
        self.train_loader = train_loader
        # ---------------------------- Evaluator ----------------------------
        self.evaluator = evaluator
        # ---------------------------- Build Grad. Scaler ----------------------------
        self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
        # ---------------------------- Build Optimizer ----------------------------
        # Accumulate gradients so the effective batch size is ~64 images,
        # and scale the base LR linearly with that effective batch size.
        cfg.grad_accumulate = max(64 // args.batch_size, 1)
        cfg.base_lr = cfg.per_image_lr * args.batch_size * cfg.grad_accumulate
        cfg.min_lr = cfg.base_lr * cfg.min_lr_ratio
        self.optimizer, self.start_epoch = build_yolo_optimizer(cfg, model, args.resume)
        # ---------------------------- Build LR Scheduler ----------------------------
        # Iteration-level linear warmup spanning the first warmup_epoch epochs.
        self.lr_scheduler_warmup = LinearWarmUpLrScheduler(cfg.base_lr, wp_iter=cfg.warmup_epoch * len(self.train_loader))
        self.lr_scheduler = build_lr_scheduler(cfg, self.optimizer, args.resume)
        # ---------------------------- Build Model-EMA ----------------------------
        if self.model_ema is not None:
            # When resuming, fast-forward the EMA update counter to match the
            # number of optimizer steps already taken.
            update_init = self.start_epoch * len(self.train_loader) // cfg.grad_accumulate
            print("Initialize ModelEMA's updates: {}".format(update_init))
            self.model_ema.updates = update_init
  67. def train(self, model):
  68. for epoch in range(self.start_epoch, self.cfg.max_epoch):
  69. if self.args.distributed:
  70. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  71. # check second stage
  72. if epoch >= (self.cfg.max_epoch - self.second_stage_epoch - 1) and not self.second_stage:
  73. self.check_second_stage()
  74. # save model of the last mosaic epoch
  75. weight_name = '{}_last_mosaic_epoch.pth'.format(self.args.model)
  76. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  77. print('Saving state of the last Mosaic epoch-{}.'.format(self.epoch))
  78. torch.save({'model': model.state_dict(),
  79. 'mAP': round(self.evaluator.map*100, 1),
  80. 'optimizer': self.optimizer.state_dict(),
  81. 'epoch': self.epoch,
  82. 'args': self.args},
  83. checkpoint_path)
  84. # train one epoch
  85. self.epoch = epoch
  86. self.train_one_epoch(model)
  87. # LR Schedule
  88. if (epoch + 1) > self.cfg.warmup_epoch:
  89. self.lr_scheduler.step()
  90. # eval one epoch
  91. if self.heavy_eval:
  92. model_eval = model.module if self.args.distributed else model
  93. self.eval(model_eval)
  94. else:
  95. model_eval = model.module if self.args.distributed else model
  96. if (epoch % self.cfg.eval_epoch) == 0 or (epoch == self.cfg.max_epoch - 1):
  97. self.eval(model_eval)
  98. if self.args.debug:
  99. print("For debug mode, we only train 1 epoch")
  100. break
    def eval(self, model):
        """Evaluate on the validation set (main process only) and checkpoint.

        Uses the EMA weights when ModelEMA is enabled. Saves a checkpoint
        unconditionally when there is no evaluator, otherwise only when the
        mAP improves on the best seen so far.
        """
        # set eval mode
        model.eval()
        # Prefer the EMA weights for evaluation/saving when available.
        model_eval = model if self.model_ema is None else self.model_ema.ema
        if distributed_utils.is_main_process():
            # check evaluator
            if self.evaluator is None:
                # No validation set: save the current state and keep training.
                print('No evaluator ... save model and go on training.')
                print('Saving state, epoch: {}'.format(self.epoch))
                weight_name = '{}_no_eval.pth'.format(self.args.model)
                checkpoint_path = os.path.join(self.path_to_save, weight_name)
                torch.save({'model': model_eval.state_dict(),
                            'mAP': -1.,
                            'optimizer': self.optimizer.state_dict(),
                            'lr_scheduler': self.lr_scheduler.state_dict(),
                            'epoch': self.epoch,
                            'args': self.args},
                           checkpoint_path)
            else:
                print('eval ...')
                # evaluate
                with torch.no_grad():
                    self.evaluator.evaluate(model_eval)
                # Save a "best" checkpoint only when the mAP improves.
                cur_map = self.evaluator.map
                if cur_map > self.best_map:
                    # update best-map
                    self.best_map = cur_map
                    # save model
                    print('Saving state, epoch:', self.epoch)
                    weight_name = '{}_best.pth'.format(self.args.model)
                    checkpoint_path = os.path.join(self.path_to_save, weight_name)
                    torch.save({'model': model_eval.state_dict(),
                                'mAP': round(self.best_map*100, 1),
                                'optimizer': self.optimizer.state_dict(),
                                'lr_scheduler': self.lr_scheduler.state_dict(),
                                'epoch': self.epoch,
                                'args': self.args},
                               checkpoint_path)
        if self.args.distributed:
            # wait for all processes to synchronize
            dist.barrier()
        # set train mode.
        model.train()
  145. def train_one_epoch(self, model):
  146. metric_logger = MetricLogger(delimiter=" ")
  147. metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
  148. metric_logger.add_meter('size', SmoothedValue(window_size=1, fmt='{value:d}'))
  149. header = 'Epoch: [{} / {}]'.format(self.epoch, self.cfg.max_epoch)
  150. epoch_size = len(self.train_loader)
  151. print_freq = 10
  152. # basic parameters
  153. epoch_size = len(self.train_loader)
  154. img_size = self.cfg.train_img_size
  155. nw = epoch_size * self.cfg.warmup_epoch
  156. # Train one epoch
  157. for iter_i, (images, targets) in enumerate(metric_logger.log_every(self.train_loader, print_freq, header)):
  158. ni = iter_i + self.epoch * epoch_size
  159. # Warmup
  160. if nw > 0 and ni < nw:
  161. self.lr_scheduler_warmup(ni, self.optimizer)
  162. elif ni == nw:
  163. print("Warmup stage is over.")
  164. self.lr_scheduler_warmup.set_lr(self.optimizer, self.cfg.base_lr)
  165. # To device
  166. images = images.to(self.device, non_blocking=True).float()
  167. # Multi scale
  168. images, targets, img_size = self.rescale_image_targets(
  169. images, targets, self.cfg.max_stride, self.cfg.multi_scale)
  170. # Visualize train targets
  171. if self.args.vis_tgt:
  172. vis_data(images,
  173. targets,
  174. self.cfg.num_classes,
  175. self.cfg.normalize_coords,
  176. self.train_transform.color_format,
  177. self.cfg.pixel_mean,
  178. self.cfg.pixel_std,
  179. self.cfg.box_format)
  180. # Inference
  181. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  182. outputs = model(images)
  183. # Compute loss
  184. loss_dict = self.criterion(outputs=outputs, targets=targets)
  185. losses = loss_dict['losses']
  186. losses /= self.cfg.grad_accumulate
  187. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  188. # Backward
  189. self.scaler.scale(losses).backward()
  190. # Gradient clip
  191. if self.cfg.clip_max_norm > 0:
  192. self.scaler.unscale_(self.optimizer)
  193. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.cfg.clip_max_norm)
  194. # Optimize
  195. if (iter_i + 1) % self.cfg.grad_accumulate == 0:
  196. self.scaler.step(self.optimizer)
  197. self.scaler.update()
  198. self.optimizer.zero_grad()
  199. # ModelEMA
  200. if self.model_ema is not None:
  201. self.model_ema.update(model)
  202. # Update log
  203. metric_logger.update(**loss_dict_reduced)
  204. metric_logger.update(lr=self.optimizer.param_groups[2]["lr"])
  205. metric_logger.update(size=img_size)
  206. if self.args.debug:
  207. print("For debug mode, we only train 1 iteration")
  208. break
  209. # Gather the stats from all processes
  210. metric_logger.synchronize_between_processes()
  211. print("Averaged stats:", metric_logger)
  212. def rescale_image_targets(self, images, targets, max_stride, multi_scale_range=[0.5, 1.5]):
  213. """
  214. Deployed for Multi scale trick.
  215. """
  216. # During training phase, the shape of input image is square.
  217. old_img_size = images.shape[-1]
  218. min_img_size = old_img_size * multi_scale_range[0]
  219. max_img_size = old_img_size * multi_scale_range[1]
  220. # Choose a new image size
  221. new_img_size = random.randrange(min_img_size, max_img_size + max_stride, max_stride)
  222. # Resize
  223. if new_img_size != old_img_size:
  224. # interpolate
  225. images = torch.nn.functional.interpolate(
  226. input=images,
  227. size=new_img_size,
  228. mode='bilinear',
  229. align_corners=False)
  230. # rescale targets
  231. if not self.cfg.normalize_coords:
  232. for tgt in targets:
  233. boxes = tgt["boxes"].clone()
  234. labels = tgt["labels"].clone()
  235. boxes = torch.clamp(boxes, 0, old_img_size)
  236. # rescale box
  237. boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
  238. boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
  239. # refine tgt
  240. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  241. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  242. keep = (min_tgt_size >= 1)
  243. tgt["boxes"] = boxes[keep]
  244. tgt["labels"] = labels[keep]
  245. return images, targets, new_img_size
  246. def check_second_stage(self):
  247. # set second stage
  248. print('============== Second stage of Training ==============')
  249. self.second_stage = True
  250. self.heavy_eval = True
  251. # close mosaic augmentation
  252. if self.train_loader.dataset.mosaic_prob > 0.:
  253. print(' - Close < Mosaic Augmentation > ...')
  254. self.train_loader.dataset.mosaic_prob = 0.
  255. # close mixup augmentation
  256. if self.train_loader.dataset.mixup_prob > 0.:
  257. print(' - Close < Mixup Augmentation > ...')
  258. self.train_loader.dataset.mixup_prob = 0.
  259. # close copy-paste augmentation
  260. if self.train_loader.dataset.copy_paste > 0.:
  261. print(' - Close < Copy-paste Augmentation > ...')
  262. self.train_loader.dataset.copy_paste = 0.
class RTDetrTrainer(object):
    """Training driver for RT-DETR-style detectors.

    Owns the optimizer, LR schedulers, AMP gradient scaler, optional
    model EMA, evaluation and checkpointing for one training run.
    """
    def __init__(self,
                 # Basic parameters
                 args,
                 cfg,
                 device,
                 # Model parameters
                 model,
                 model_ema,
                 criterion,
                 # Data parameters
                 train_transform,
                 val_transform,
                 dataset,
                 train_loader,
                 evaluator,
                 ):
        # ------------------- basic parameters -------------------
        self.args = args
        self.cfg = cfg
        self.epoch = 0
        self.best_map = -1.          # best validation mAP seen so far
        self.device = device
        self.criterion = criterion
        self.heavy_eval = False      # when True, evaluate every epoch
        self.model_ema = model_ema
        # path to save model
        self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
        os.makedirs(self.path_to_save, exist_ok=True)
        # ---------------------------- Transform ----------------------------
        self.train_transform = train_transform
        self.val_transform = val_transform
        # ---------------------------- Dataset & Dataloader ----------------------------
        self.dataset = dataset
        self.train_loader = train_loader
        # ---------------------------- Evaluator ----------------------------
        self.evaluator = evaluator
        # ---------------------------- Build Grad. Scaler ----------------------------
        self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
        # ---------------------------- Build Optimizer ----------------------------
        # Accumulate gradients so the effective batch size is ~16 images,
        # and scale the base LR linearly with that effective batch size.
        cfg.grad_accumulate = max(16 // args.batch_size, 1)
        cfg.base_lr = cfg.per_image_lr * args.batch_size * cfg.grad_accumulate
        cfg.min_lr = cfg.base_lr * cfg.min_lr_ratio
        self.optimizer, self.start_epoch = build_rtdetr_optimizer(cfg, model, args.resume)
        # ---------------------------- Build LR Scheduler ----------------------------
        # Iteration-level linear warmup over a fixed number of iterations.
        self.wp_lr_scheduler = LinearWarmUpLrScheduler(cfg.base_lr, wp_iter=cfg.warmup_iters)
        self.lr_scheduler = build_lr_scheduler(cfg, self.optimizer, args.resume)
        # ---------------------------- Build Model-EMA ----------------------------
        if self.model_ema is not None:
            # When resuming, fast-forward the EMA update counter to match the
            # number of optimizer steps already taken.
            update_init = self.start_epoch * len(self.train_loader) // cfg.grad_accumulate
            print("Initialize ModelEMA's updates: {}".format(update_init))
            self.model_ema.updates = update_init
  315. def train(self, model):
  316. for epoch in range(self.start_epoch, self.cfg.max_epoch):
  317. if self.args.distributed:
  318. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  319. # train one epoch
  320. self.epoch = epoch
  321. self.train_one_epoch(model)
  322. # LR Scheduler
  323. self.lr_scheduler.step()
  324. # eval one epoch
  325. if self.heavy_eval:
  326. model_eval = model.module if self.args.distributed else model
  327. self.eval(model_eval)
  328. else:
  329. model_eval = model.module if self.args.distributed else model
  330. if (epoch % self.cfg.eval_epoch) == 0 or (epoch == self.cfg.max_epoch - 1):
  331. self.eval(model_eval)
  332. if self.args.debug:
  333. print("For debug mode, we only train 1 epoch")
  334. break
    def eval(self, model):
        """Evaluate on the validation set (main process only) and checkpoint.

        Uses the EMA weights when ModelEMA is enabled. Saves a checkpoint
        unconditionally when there is no evaluator, otherwise only when the
        mAP improves on the best seen so far.
        """
        # set eval mode
        model.eval()
        # Prefer the EMA weights for evaluation/saving when available.
        model_eval = model if self.model_ema is None else self.model_ema.ema
        if distributed_utils.is_main_process():
            # check evaluator
            if self.evaluator is None:
                # No validation set: save the current state and keep training.
                print('No evaluator ... save model and go on training.')
                print('Saving state, epoch: {}'.format(self.epoch))
                weight_name = '{}_no_eval.pth'.format(self.args.model)
                checkpoint_path = os.path.join(self.path_to_save, weight_name)
                torch.save({'model': model_eval.state_dict(),
                            'mAP': -1.,
                            'optimizer': self.optimizer.state_dict(),
                            'lr_scheduler': self.lr_scheduler.state_dict(),
                            'epoch': self.epoch,
                            'args': self.args},
                           checkpoint_path)
            else:
                print('eval ...')
                # evaluate
                with torch.no_grad():
                    self.evaluator.evaluate(model_eval)
                # Save a "best" checkpoint only when the mAP improves.
                cur_map = self.evaluator.map
                if cur_map > self.best_map:
                    # update best-map
                    self.best_map = cur_map
                    # save model
                    print('Saving state, epoch:', self.epoch)
                    weight_name = '{}_best.pth'.format(self.args.model)
                    checkpoint_path = os.path.join(self.path_to_save, weight_name)
                    torch.save({'model': model_eval.state_dict(),
                                'mAP': round(self.best_map*100, 1),
                                'optimizer': self.optimizer.state_dict(),
                                'lr_scheduler': self.lr_scheduler.state_dict(),
                                'epoch': self.epoch,
                                'args': self.args},
                               checkpoint_path)
        if self.args.distributed:
            # wait for all processes to synchronize
            dist.barrier()
        # set train mode.
        model.train()
  379. def train_one_epoch(self, model):
  380. metric_logger = MetricLogger(delimiter=" ")
  381. metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
  382. metric_logger.add_meter('size', SmoothedValue(window_size=1, fmt='{value:d}'))
  383. metric_logger.add_meter('grad_norm', SmoothedValue(window_size=1, fmt='{value:.1f}'))
  384. header = 'Epoch: [{} / {}]'.format(self.epoch, self.cfg.max_epoch)
  385. epoch_size = len(self.train_loader)
  386. print_freq = 10
  387. # basic parameters
  388. epoch_size = len(self.train_loader)
  389. img_size = self.cfg.train_img_size
  390. nw = self.cfg.warmup_iters
  391. lr_warmup_stage = True
  392. # Train one epoch
  393. for iter_i, (images, targets) in enumerate(metric_logger.log_every(self.train_loader, print_freq, header)):
  394. ni = iter_i + self.epoch * epoch_size
  395. # WarmUp
  396. if ni < nw and lr_warmup_stage:
  397. self.wp_lr_scheduler(ni, self.optimizer)
  398. elif ni == nw and lr_warmup_stage:
  399. print('Warmup stage is over.')
  400. lr_warmup_stage = False
  401. self.wp_lr_scheduler.set_lr(self.optimizer, self.cfg.base_lr)
  402. # To device
  403. images = images.to(self.device, non_blocking=True).float()
  404. for tgt in targets:
  405. tgt['boxes'] = tgt['boxes'].to(self.device)
  406. tgt['labels'] = tgt['labels'].to(self.device)
  407. # Multi scale
  408. images, targets, img_size = self.rescale_image_targets(
  409. images, targets, self.cfg.max_stride, self.cfg.multi_scale)
  410. # Visualize train targets
  411. if self.args.vis_tgt:
  412. vis_data(images,
  413. targets,
  414. self.cfg.num_classes,
  415. self.cfg.normalize_coords,
  416. self.train_transform.color_format,
  417. self.cfg.pixel_mean,
  418. self.cfg.pixel_std,
  419. self.cfg.box_format)
  420. # Inference
  421. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  422. outputs = model(images, targets)
  423. loss_dict = self.criterion(outputs, targets)
  424. losses = sum(loss_dict.values())
  425. losses /= self.cfg.grad_accumulate
  426. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  427. # Backward
  428. self.scaler.scale(losses).backward()
  429. # Gradient clip
  430. grad_norm = None
  431. if self.cfg.clip_max_norm > 0:
  432. self.scaler.unscale_(self.optimizer)
  433. grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.cfg.clip_max_norm)
  434. # Optimize
  435. if (iter_i + 1) % self.cfg.grad_accumulate == 0:
  436. self.scaler.step(self.optimizer)
  437. self.scaler.update()
  438. self.optimizer.zero_grad()
  439. # ModelEMA
  440. if self.model_ema is not None:
  441. self.model_ema.update(model)
  442. # Update log
  443. metric_logger.update(loss=losses.item(), **loss_dict_reduced)
  444. metric_logger.update(lr=self.optimizer.param_groups[0]["lr"])
  445. metric_logger.update(grad_norm=grad_norm)
  446. metric_logger.update(size=img_size)
  447. if self.args.debug:
  448. print("For debug mode, we only train 1 iteration")
  449. break
  450. def rescale_image_targets(self, images, targets, max_stride, multi_scale_range=[0.5, 1.5]):
  451. """
  452. Deployed for Multi scale trick.
  453. """
  454. # During training phase, the shape of input image is square.
  455. old_img_size = images.shape[-1]
  456. min_img_size = old_img_size * multi_scale_range[0]
  457. max_img_size = old_img_size * multi_scale_range[1]
  458. # Choose a new image size
  459. new_img_size = random.randrange(min_img_size, max_img_size + max_stride, max_stride)
  460. # Resize
  461. if new_img_size != old_img_size:
  462. # interpolate
  463. images = torch.nn.functional.interpolate(
  464. input=images,
  465. size=new_img_size,
  466. mode='bilinear',
  467. align_corners=False)
  468. return images, targets, new_img_size
  469. # Build Trainer
  470. def build_trainer(args, cfg, device, model, model_ema, criterion, train_transform, val_transform, dataset, train_loader, evaluator):
  471. # ----------------------- Det trainers -----------------------
  472. if cfg.trainer == 'yolo':
  473. return YoloTrainer(args, cfg, device, model, model_ema, criterion, train_transform, val_transform, dataset, train_loader, evaluator)
  474. elif cfg.trainer == 'rtdetr':
  475. return RTDetrTrainer(args, cfg, device, model, model_ema, criterion, train_transform, val_transform, dataset, train_loader, evaluator)
  476. else:
  477. raise NotImplementedError(cfg.trainer)