engine.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563
  1. import torch
  2. import torch.distributed as dist
  3. import os
  4. import random
  5. # ----------------- Extra Components -----------------
  6. from utils import distributed_utils
  7. from utils.misc import MetricLogger, SmoothedValue
  8. from utils.vis_tools import vis_data
  9. # ----------------- Optimizer & LrScheduler Components -----------------
  10. from utils.solver.optimizer import build_yolo_optimizer, build_rtdetr_optimizer
  11. from utils.solver.lr_scheduler import LinearWarmUpLrScheduler, build_lr_scheduler
  12. class YoloTrainer(object):
  13. def __init__(self,
  14. # Basic parameters
  15. args,
  16. cfg,
  17. device,
  18. # Model parameters
  19. model,
  20. model_ema,
  21. criterion,
  22. # Data parameters
  23. train_transform,
  24. val_transform,
  25. dataset,
  26. train_loader,
  27. evaluator,
  28. ):
  29. # ------------------- basic parameters -------------------
  30. self.args = args
  31. self.cfg = cfg
  32. self.epoch = 0
  33. self.best_map = -1.
  34. self.device = device
  35. self.criterion = criterion
  36. self.heavy_eval = False
  37. self.model_ema = model_ema
  38. # weak augmentatino stage
  39. self.second_stage = False
  40. self.second_stage_epoch = cfg.no_aug_epoch
  41. # path to save model
  42. self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
  43. os.makedirs(self.path_to_save, exist_ok=True)
  44. # ---------------------------- Transform ----------------------------
  45. self.train_transform = train_transform
  46. self.val_transform = val_transform
  47. # ---------------------------- Dataset & Dataloader ----------------------------
  48. self.dataset = dataset
  49. self.train_loader = train_loader
  50. # ---------------------------- Evaluator ----------------------------
  51. self.evaluator = evaluator
  52. # ---------------------------- Build Grad. Scaler ----------------------------
  53. self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
  54. # ---------------------------- Build Optimizer ----------------------------
  55. cfg.grad_accumulate = max(64 // args.batch_size, 1, args.grad_accumulate)
  56. cfg.base_lr = cfg.per_image_lr * args.batch_size * cfg.grad_accumulate
  57. cfg.min_lr = cfg.base_lr * cfg.min_lr_ratio
  58. self.optimizer, self.start_epoch = build_yolo_optimizer(cfg, model, args.resume)
  59. # ---------------------------- Build LR Scheduler ----------------------------
  60. warmup_iters = cfg.warmup_epoch * len(self.train_loader)
  61. self.lr_scheduler_warmup = LinearWarmUpLrScheduler(warmup_iters, cfg.base_lr, cfg.warmup_bias_lr, cfg.warmup_momentum)
  62. self.lr_scheduler = build_lr_scheduler(cfg, self.optimizer, args.resume)
  63. # ---------------------------- Build Model-EMA ----------------------------
  64. if self.model_ema is not None:
  65. update_init = self.start_epoch * len(self.train_loader) // cfg.grad_accumulate
  66. print("Initialize ModelEMA's updates: {}".format(update_init))
  67. self.model_ema.updates = update_init
  68. def train(self, model):
  69. for epoch in range(self.start_epoch, self.cfg.max_epoch):
  70. if self.args.distributed:
  71. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  72. # check second stage
  73. if epoch >= (self.cfg.max_epoch - self.second_stage_epoch - 1) and not self.second_stage:
  74. self.check_second_stage()
  75. # save model of the last mosaic epoch
  76. weight_name = '{}_last_mosaic_epoch.pth'.format(self.args.model)
  77. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  78. print('Saving state of the last Mosaic epoch-{}.'.format(self.epoch))
  79. torch.save({'model': model.state_dict(),
  80. 'mAP': round(self.evaluator.map*100, 1),
  81. 'optimizer': self.optimizer.state_dict(),
  82. 'epoch': self.epoch,
  83. 'args': self.args},
  84. checkpoint_path)
  85. # train one epoch
  86. self.epoch = epoch
  87. self.train_one_epoch(model)
  88. # LR Schedule
  89. if (epoch + 1) > self.cfg.warmup_epoch:
  90. self.lr_scheduler.step()
  91. # eval one epoch
  92. if self.heavy_eval:
  93. model_eval = model.module if self.args.distributed else model
  94. self.eval(model_eval)
  95. else:
  96. model_eval = model.module if self.args.distributed else model
  97. if (epoch % self.cfg.eval_epoch) == 0 or (epoch == self.cfg.max_epoch - 1):
  98. self.eval(model_eval)
  99. if self.args.debug:
  100. print("For debug mode, we only train 1 epoch")
  101. break
  102. def eval(self, model):
  103. # set eval mode
  104. model.eval()
  105. model_eval = model if self.model_ema is None else self.model_ema.ema
  106. if distributed_utils.is_main_process():
  107. # check evaluator
  108. if self.evaluator is None:
  109. print('No evaluator ... save model and go on training.')
  110. print('Saving state, epoch: {}'.format(self.epoch))
  111. weight_name = '{}_no_eval.pth'.format(self.args.model)
  112. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  113. torch.save({'model': model_eval.state_dict(),
  114. 'mAP': -1.,
  115. 'optimizer': self.optimizer.state_dict(),
  116. 'lr_scheduler': self.lr_scheduler.state_dict(),
  117. 'epoch': self.epoch,
  118. 'args': self.args},
  119. checkpoint_path)
  120. else:
  121. print('eval ...')
  122. # evaluate
  123. with torch.no_grad():
  124. self.evaluator.evaluate(model_eval)
  125. # save model
  126. cur_map = self.evaluator.map
  127. if cur_map > self.best_map:
  128. # update best-map
  129. self.best_map = cur_map
  130. # save model
  131. print('Saving state, epoch:', self.epoch)
  132. weight_name = '{}_best.pth'.format(self.args.model)
  133. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  134. torch.save({'model': model_eval.state_dict(),
  135. 'mAP': round(self.best_map*100, 1),
  136. 'optimizer': self.optimizer.state_dict(),
  137. 'lr_scheduler': self.lr_scheduler.state_dict(),
  138. 'epoch': self.epoch,
  139. 'args': self.args},
  140. checkpoint_path)
  141. if self.args.distributed:
  142. # wait for all processes to synchronize
  143. dist.barrier()
  144. # set train mode.
  145. model.train()
  146. def train_one_epoch(self, model):
  147. metric_logger = MetricLogger(delimiter=" ")
  148. metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
  149. metric_logger.add_meter('size', SmoothedValue(window_size=1, fmt='{value:d}'))
  150. header = 'Epoch: [{} / {}]'.format(self.epoch, self.cfg.max_epoch)
  151. epoch_size = len(self.train_loader)
  152. print_freq = 10
  153. # basic parameters
  154. epoch_size = len(self.train_loader)
  155. img_size = self.cfg.train_img_size
  156. nw = epoch_size * self.cfg.warmup_epoch
  157. # Train one epoch
  158. for iter_i, (images, targets) in enumerate(metric_logger.log_every(self.train_loader, print_freq, header)):
  159. ni = iter_i + self.epoch * epoch_size
  160. # Warmup
  161. if nw > 0 and ni < nw:
  162. self.lr_scheduler_warmup(ni, self.optimizer)
  163. elif ni == nw:
  164. print("Warmup stage is over.")
  165. self.lr_scheduler_warmup.set_lr(self.optimizer, self.cfg.base_lr)
  166. # To device
  167. images = images.to(self.device, non_blocking=True).float()
  168. # Multi scale
  169. images, targets, img_size = self.rescale_image_targets(
  170. images, targets, self.cfg.max_stride, self.cfg.multi_scale)
  171. # Visualize train targets
  172. if self.args.vis_tgt:
  173. vis_data(images,
  174. targets,
  175. self.cfg.num_classes,
  176. self.cfg.normalize_coords,
  177. self.train_transform.color_format,
  178. self.cfg.pixel_mean,
  179. self.cfg.pixel_std,
  180. self.cfg.box_format)
  181. # Inference
  182. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  183. outputs = model(images)
  184. # Compute loss
  185. loss_dict = self.criterion(outputs=outputs, targets=targets)
  186. losses = loss_dict['losses']
  187. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  188. losses /= self.cfg.grad_accumulate
  189. # Backward
  190. self.scaler.scale(losses).backward()
  191. # Gradient clip
  192. if self.cfg.clip_max_norm > 0:
  193. self.scaler.unscale_(self.optimizer)
  194. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.cfg.clip_max_norm)
  195. # Optimize
  196. if (iter_i + 1) % self.cfg.grad_accumulate == 0:
  197. self.scaler.step(self.optimizer)
  198. self.scaler.update()
  199. self.optimizer.zero_grad()
  200. # ModelEMA
  201. if self.model_ema is not None:
  202. self.model_ema.update(model)
  203. # Update log
  204. metric_logger.update(**loss_dict_reduced)
  205. metric_logger.update(lr=self.optimizer.param_groups[2]["lr"])
  206. metric_logger.update(size=img_size)
  207. if self.args.debug:
  208. print("For debug mode, we only train 1 iteration")
  209. break
  210. # Gather the stats from all processes
  211. metric_logger.synchronize_between_processes()
  212. print("Averaged stats:", metric_logger)
  213. def rescale_image_targets(self, images, targets, max_stride, multi_scale_range=[0.5, 1.5]):
  214. """
  215. Deployed for Multi scale trick.
  216. """
  217. # During training phase, the shape of input image is square.
  218. old_img_size = images.shape[-1]
  219. min_img_size = old_img_size * multi_scale_range[0]
  220. max_img_size = old_img_size * multi_scale_range[1]
  221. # Choose a new image size
  222. new_img_size = random.randrange(min_img_size, max_img_size + max_stride, max_stride)
  223. # Resize
  224. if new_img_size != old_img_size:
  225. # interpolate
  226. images = torch.nn.functional.interpolate(
  227. input=images,
  228. size=new_img_size,
  229. mode='bilinear',
  230. align_corners=False)
  231. # rescale targets
  232. if not self.cfg.normalize_coords:
  233. for tgt in targets:
  234. boxes = tgt["boxes"].clone()
  235. labels = tgt["labels"].clone()
  236. boxes = torch.clamp(boxes, 0, old_img_size)
  237. # rescale box
  238. boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
  239. boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
  240. # refine tgt
  241. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  242. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  243. keep = (min_tgt_size >= 1)
  244. tgt["boxes"] = boxes[keep]
  245. tgt["labels"] = labels[keep]
  246. return images, targets, new_img_size
  247. def check_second_stage(self):
  248. # set second stage
  249. print('============== Second stage of Training ==============')
  250. self.second_stage = True
  251. self.heavy_eval = True
  252. # close mosaic augmentation
  253. if self.train_loader.dataset.mosaic_prob > 0.:
  254. print(' - Close < Mosaic Augmentation > ...')
  255. self.train_loader.dataset.mosaic_prob = 0.
  256. # close mixup augmentation
  257. if self.train_loader.dataset.mixup_prob > 0.:
  258. print(' - Close < Mixup Augmentation > ...')
  259. self.train_loader.dataset.mixup_prob = 0.
  260. # close copy-paste augmentation
  261. if self.train_loader.dataset.copy_paste > 0.:
  262. print(' - Close < Copy-paste Augmentation > ...')
  263. self.train_loader.dataset.copy_paste = 0.
  264. class RTDetrTrainer(object):
  265. def __init__(self,
  266. # Basic parameters
  267. args,
  268. cfg,
  269. device,
  270. # Model parameters
  271. model,
  272. model_ema,
  273. criterion,
  274. # Data parameters
  275. train_transform,
  276. val_transform,
  277. dataset,
  278. train_loader,
  279. evaluator,
  280. ):
  281. # ------------------- basic parameters -------------------
  282. self.args = args
  283. self.cfg = cfg
  284. self.epoch = 0
  285. self.best_map = -1.
  286. self.device = device
  287. self.criterion = criterion
  288. self.heavy_eval = False
  289. self.model_ema = model_ema
  290. # path to save model
  291. self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
  292. os.makedirs(self.path_to_save, exist_ok=True)
  293. # ---------------------------- Transform ----------------------------
  294. self.train_transform = train_transform
  295. self.val_transform = val_transform
  296. # ---------------------------- Dataset & Dataloader ----------------------------
  297. self.dataset = dataset
  298. self.train_loader = train_loader
  299. # ---------------------------- Evaluator ----------------------------
  300. self.evaluator = evaluator
  301. # ---------------------------- Build Grad. Scaler ----------------------------
  302. self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
  303. # ---------------------------- Build Optimizer ----------------------------
  304. cfg.grad_accumulate = max(16 // args.batch_size, 1, args.grad_accumulate)
  305. cfg.base_lr = cfg.per_image_lr * args.batch_size * cfg.grad_accumulate
  306. cfg.min_lr = cfg.base_lr * cfg.min_lr_ratio
  307. self.optimizer, self.start_epoch = build_rtdetr_optimizer(cfg, model, args.resume)
  308. # ---------------------------- Build LR Scheduler ----------------------------
  309. self.wp_lr_scheduler = LinearWarmUpLrScheduler(cfg.base_lr, wp_iter=cfg.warmup_iters)
  310. self.lr_scheduler = build_lr_scheduler(cfg, self.optimizer, args.resume)
  311. # ---------------------------- Build Model-EMA ----------------------------
  312. if self.model_ema is not None:
  313. update_init = self.start_epoch * len(self.train_loader) // cfg.grad_accumulate
  314. print("Initialize ModelEMA's updates: {}".format(update_init))
  315. self.model_ema.updates = update_init
  316. def train(self, model):
  317. for epoch in range(self.start_epoch, self.cfg.max_epoch):
  318. if self.args.distributed:
  319. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  320. # train one epoch
  321. self.epoch = epoch
  322. self.train_one_epoch(model)
  323. # LR Scheduler
  324. self.lr_scheduler.step()
  325. # eval one epoch
  326. if self.heavy_eval:
  327. model_eval = model.module if self.args.distributed else model
  328. self.eval(model_eval)
  329. else:
  330. model_eval = model.module if self.args.distributed else model
  331. if (epoch % self.cfg.eval_epoch) == 0 or (epoch == self.cfg.max_epoch - 1):
  332. self.eval(model_eval)
  333. if self.args.debug:
  334. print("For debug mode, we only train 1 epoch")
  335. break
  336. def eval(self, model):
  337. # set eval mode
  338. model.eval()
  339. model_eval = model if self.model_ema is None else self.model_ema.ema
  340. if distributed_utils.is_main_process():
  341. # check evaluator
  342. if self.evaluator is None:
  343. print('No evaluator ... save model and go on training.')
  344. print('Saving state, epoch: {}'.format(self.epoch))
  345. weight_name = '{}_no_eval.pth'.format(self.args.model)
  346. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  347. torch.save({'model': model_eval.state_dict(),
  348. 'mAP': -1.,
  349. 'optimizer': self.optimizer.state_dict(),
  350. 'lr_scheduler': self.lr_scheduler.state_dict(),
  351. 'epoch': self.epoch,
  352. 'args': self.args},
  353. checkpoint_path)
  354. else:
  355. print('eval ...')
  356. # evaluate
  357. with torch.no_grad():
  358. self.evaluator.evaluate(model_eval)
  359. # save model
  360. cur_map = self.evaluator.map
  361. if cur_map > self.best_map:
  362. # update best-map
  363. self.best_map = cur_map
  364. # save model
  365. print('Saving state, epoch:', self.epoch)
  366. weight_name = '{}_best.pth'.format(self.args.model)
  367. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  368. torch.save({'model': model_eval.state_dict(),
  369. 'mAP': round(self.best_map*100, 1),
  370. 'optimizer': self.optimizer.state_dict(),
  371. 'lr_scheduler': self.lr_scheduler.state_dict(),
  372. 'epoch': self.epoch,
  373. 'args': self.args},
  374. checkpoint_path)
  375. if self.args.distributed:
  376. # wait for all processes to synchronize
  377. dist.barrier()
  378. # set train mode.
  379. model.train()
  380. def train_one_epoch(self, model):
  381. metric_logger = MetricLogger(delimiter=" ")
  382. metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
  383. metric_logger.add_meter('size', SmoothedValue(window_size=1, fmt='{value:d}'))
  384. metric_logger.add_meter('grad_norm', SmoothedValue(window_size=1, fmt='{value:.1f}'))
  385. header = 'Epoch: [{} / {}]'.format(self.epoch, self.cfg.max_epoch)
  386. epoch_size = len(self.train_loader)
  387. print_freq = 10
  388. # basic parameters
  389. epoch_size = len(self.train_loader)
  390. img_size = self.cfg.train_img_size
  391. nw = self.cfg.warmup_iters
  392. lr_warmup_stage = True
  393. # Train one epoch
  394. for iter_i, (images, targets) in enumerate(metric_logger.log_every(self.train_loader, print_freq, header)):
  395. ni = iter_i + self.epoch * epoch_size
  396. # WarmUp
  397. if ni < nw and lr_warmup_stage:
  398. self.wp_lr_scheduler(ni, self.optimizer)
  399. elif ni == nw and lr_warmup_stage:
  400. print('Warmup stage is over.')
  401. lr_warmup_stage = False
  402. self.wp_lr_scheduler.set_lr(self.optimizer, self.cfg.base_lr)
  403. # To device
  404. images = images.to(self.device, non_blocking=True).float()
  405. for tgt in targets:
  406. tgt['boxes'] = tgt['boxes'].to(self.device)
  407. tgt['labels'] = tgt['labels'].to(self.device)
  408. # Multi scale
  409. images, targets, img_size = self.rescale_image_targets(
  410. images, targets, self.cfg.max_stride, self.cfg.multi_scale)
  411. # Visualize train targets
  412. if self.args.vis_tgt:
  413. vis_data(images,
  414. targets,
  415. self.cfg.num_classes,
  416. self.cfg.normalize_coords,
  417. self.train_transform.color_format,
  418. self.cfg.pixel_mean,
  419. self.cfg.pixel_std,
  420. self.cfg.box_format)
  421. # Inference
  422. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  423. outputs = model(images, targets)
  424. loss_dict = self.criterion(outputs, targets)
  425. losses = sum(loss_dict.values())
  426. losses /= self.cfg.grad_accumulate
  427. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  428. # Backward
  429. self.scaler.scale(losses).backward()
  430. # Gradient clip
  431. grad_norm = None
  432. if self.cfg.clip_max_norm > 0:
  433. self.scaler.unscale_(self.optimizer)
  434. grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.cfg.clip_max_norm)
  435. # Optimize
  436. if (iter_i + 1) % self.cfg.grad_accumulate == 0:
  437. self.scaler.step(self.optimizer)
  438. self.scaler.update()
  439. self.optimizer.zero_grad()
  440. # ModelEMA
  441. if self.model_ema is not None:
  442. self.model_ema.update(model)
  443. # Update log
  444. metric_logger.update(loss=losses.item(), **loss_dict_reduced)
  445. metric_logger.update(lr=self.optimizer.param_groups[0]["lr"])
  446. metric_logger.update(grad_norm=grad_norm)
  447. metric_logger.update(size=img_size)
  448. if self.args.debug:
  449. print("For debug mode, we only train 1 iteration")
  450. break
  451. def rescale_image_targets(self, images, targets, max_stride, multi_scale_range=[0.5, 1.5]):
  452. """
  453. Deployed for Multi scale trick.
  454. """
  455. # During training phase, the shape of input image is square.
  456. old_img_size = images.shape[-1]
  457. min_img_size = old_img_size * multi_scale_range[0]
  458. max_img_size = old_img_size * multi_scale_range[1]
  459. # Choose a new image size
  460. new_img_size = random.randrange(min_img_size, max_img_size + max_stride, max_stride)
  461. # Resize
  462. if new_img_size != old_img_size:
  463. # interpolate
  464. images = torch.nn.functional.interpolate(
  465. input=images,
  466. size=new_img_size,
  467. mode='bilinear',
  468. align_corners=False)
  469. return images, targets, new_img_size
  470. # Build Trainer
  471. def build_trainer(args, cfg, device, model, model_ema, criterion, train_transform, val_transform, dataset, train_loader, evaluator):
  472. # ----------------------- Det trainers -----------------------
  473. if cfg.trainer == 'yolo':
  474. return YoloTrainer(args, cfg, device, model, model_ema, criterion, train_transform, val_transform, dataset, train_loader, evaluator)
  475. elif cfg.trainer == 'rtdetr':
  476. return RTDetrTrainer(args, cfg, device, model, model_ema, criterion, train_transform, val_transform, dataset, train_loader, evaluator)
  477. else:
  478. raise NotImplementedError(cfg.trainer)