# engine.py — training engines (YoloTrainer / RTDetrTrainer) and trainer factory.
import os
import random

import torch
import torch.distributed as dist

# ----------------- Extra Components -----------------
from utils import distributed_utils
from utils.misc import MetricLogger, SmoothedValue
from utils.vis_tools import vis_data
# ----------------- Optimizer & LrScheduler Components -----------------
from utils.solver.optimizer import build_yolo_optimizer, build_rtdetr_optimizer
from utils.solver.lr_scheduler import LinearWarmUpLrScheduler, build_lr_scheduler
  12. class YoloTrainer(object):
  13. def __init__(self,
  14. # Basic parameters
  15. args,
  16. cfg,
  17. device,
  18. # Model parameters
  19. model,
  20. model_ema,
  21. criterion,
  22. # Data parameters
  23. train_transform,
  24. val_transform,
  25. dataset,
  26. train_loader,
  27. evaluator,
  28. ):
  29. # ------------------- basic parameters -------------------
  30. self.args = args
  31. self.cfg = cfg
  32. self.epoch = 0
  33. self.best_map = -1.
  34. self.device = device
  35. self.criterion = criterion
  36. self.heavy_eval = False
  37. self.model_ema = model_ema
  38. # weak augmentatino stage
  39. self.second_stage = False
  40. self.second_stage_epoch = cfg.no_aug_epoch
  41. # path to save model
  42. self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
  43. os.makedirs(self.path_to_save, exist_ok=True)
  44. # ---------------------------- Transform ----------------------------
  45. self.train_transform = train_transform
  46. self.val_transform = val_transform
  47. # ---------------------------- Dataset & Dataloader ----------------------------
  48. self.dataset = dataset
  49. self.train_loader = train_loader
  50. # ---------------------------- Evaluator ----------------------------
  51. self.evaluator = evaluator
  52. # ---------------------------- Build Grad. Scaler ----------------------------
  53. self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
  54. # ---------------------------- Build Optimizer ----------------------------
  55. cfg.base_lr = cfg.per_image_lr * args.batch_size
  56. cfg.min_lr = cfg.base_lr * cfg.min_lr_ratio
  57. self.optimizer, self.start_epoch = build_yolo_optimizer(cfg, model, args.resume)
  58. # ---------------------------- Build LR Scheduler ----------------------------
  59. warmup_iters = cfg.warmup_epoch * len(self.train_loader)
  60. self.lr_scheduler_warmup = LinearWarmUpLrScheduler(warmup_iters, cfg.base_lr, cfg.warmup_bias_lr, cfg.warmup_momentum)
  61. self.lr_scheduler = build_lr_scheduler(cfg, self.optimizer, args.resume)
  62. def train(self, model):
  63. for epoch in range(self.start_epoch, self.cfg.max_epoch):
  64. if self.args.distributed:
  65. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  66. # check second stage
  67. if epoch >= (self.cfg.max_epoch - self.second_stage_epoch - 1) and not self.second_stage:
  68. self.check_second_stage()
  69. # save model of the last mosaic epoch
  70. weight_name = '{}_last_mosaic_epoch.pth'.format(self.args.model)
  71. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  72. print('Saving state of the last Mosaic epoch-{}.'.format(self.epoch))
  73. torch.save({'model': model.state_dict(),
  74. 'mAP': round(self.evaluator.map*100, 1),
  75. 'optimizer': self.optimizer.state_dict(),
  76. 'epoch': self.epoch,
  77. 'args': self.args},
  78. checkpoint_path)
  79. # train one epoch
  80. self.epoch = epoch
  81. self.train_one_epoch(model)
  82. # LR Schedule
  83. if (epoch + 1) > self.cfg.warmup_epoch:
  84. self.lr_scheduler.step()
  85. # eval one epoch
  86. if self.heavy_eval:
  87. model_eval = model.module if self.args.distributed else model
  88. self.eval(model_eval)
  89. else:
  90. model_eval = model.module if self.args.distributed else model
  91. if (epoch % self.cfg.eval_epoch) == 0 or (epoch == self.cfg.max_epoch - 1):
  92. self.eval(model_eval)
  93. if self.args.debug:
  94. print("For debug mode, we only train 1 epoch")
  95. break
  96. def eval(self, model):
  97. # set eval mode
  98. model.eval()
  99. model_eval = model if self.model_ema is None else self.model_ema.ema
  100. cur_map = -1.
  101. to_save = False
  102. if distributed_utils.is_main_process():
  103. if self.evaluator is None:
  104. print('No evaluator ... save model and go on training.')
  105. to_save = True
  106. weight_name = '{}_no_eval.pth'.format(self.args.model)
  107. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  108. else:
  109. print('Eval ...')
  110. # Evaluate
  111. with torch.no_grad():
  112. self.evaluator.evaluate(model_eval)
  113. cur_map = self.evaluator.map
  114. if cur_map > self.best_map:
  115. # update best-map
  116. self.best_map = cur_map
  117. to_save = True
  118. # Save model
  119. if to_save:
  120. print('Saving state, epoch:', self.epoch)
  121. weight_name = '{}_best.pth'.format(self.args.model)
  122. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  123. state_dicts = {
  124. 'model': model_eval.state_dict(),
  125. 'mAP': round(cur_map*100, 1),
  126. 'optimizer': self.optimizer.state_dict(),
  127. 'lr_scheduler': self.lr_scheduler.state_dict(),
  128. 'epoch': self.epoch,
  129. 'args': self.args,
  130. }
  131. if self.model_ema is not None:
  132. state_dicts["ema_updates"] = self.model_ema.updates
  133. torch.save(state_dicts, checkpoint_path)
  134. if self.args.distributed:
  135. # wait for all processes to synchronize
  136. dist.barrier()
  137. # set train mode.
  138. model.train()
  139. def train_one_epoch(self, model):
  140. metric_logger = MetricLogger(delimiter=" ")
  141. metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
  142. metric_logger.add_meter('size', SmoothedValue(window_size=1, fmt='{value:d}'))
  143. header = 'Epoch: [{} / {}]'.format(self.epoch, self.cfg.max_epoch)
  144. epoch_size = len(self.train_loader)
  145. print_freq = 10
  146. # basic parameters
  147. epoch_size = len(self.train_loader)
  148. img_size = self.cfg.train_img_size
  149. nw = epoch_size * self.cfg.warmup_epoch
  150. # Train one epoch
  151. for iter_i, (images, targets) in enumerate(metric_logger.log_every(self.train_loader, print_freq, header)):
  152. ni = iter_i + self.epoch * epoch_size
  153. # Warmup
  154. if nw > 0 and ni < nw:
  155. self.lr_scheduler_warmup(ni, self.optimizer)
  156. elif ni == nw:
  157. print("Warmup stage is over.")
  158. self.lr_scheduler_warmup.set_lr(self.optimizer, self.cfg.base_lr)
  159. # To device
  160. images = images.to(self.device, non_blocking=True).float()
  161. # Multi scale
  162. images, targets, img_size = self.rescale_image_targets(
  163. images, targets, self.cfg.max_stride, self.cfg.multi_scale)
  164. # Visualize train targets
  165. if self.args.vis_tgt:
  166. vis_data(images,
  167. targets,
  168. self.cfg.num_classes,
  169. self.cfg.normalize_coords,
  170. self.train_transform.color_format,
  171. self.cfg.pixel_mean,
  172. self.cfg.pixel_std,
  173. self.cfg.box_format)
  174. # Inference
  175. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  176. outputs = model(images)
  177. # Compute loss
  178. loss_dict = self.criterion(outputs=outputs, targets=targets)
  179. losses = loss_dict['losses']
  180. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  181. # Backward
  182. self.scaler.scale(losses).backward()
  183. # Optimize
  184. if self.cfg.clip_max_norm > 0:
  185. self.scaler.unscale_(self.optimizer)
  186. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.cfg.clip_max_norm)
  187. self.scaler.step(self.optimizer)
  188. self.scaler.update()
  189. self.optimizer.zero_grad()
  190. # ModelEMA
  191. if self.model_ema is not None:
  192. self.model_ema.update(model)
  193. # Update log
  194. metric_logger.update(**loss_dict_reduced)
  195. metric_logger.update(lr=self.optimizer.param_groups[2]["lr"])
  196. metric_logger.update(size=img_size)
  197. if self.args.debug:
  198. print("For debug mode, we only train 1 iteration")
  199. break
  200. # Gather the stats from all processes
  201. metric_logger.synchronize_between_processes()
  202. print("Averaged stats:", metric_logger)
  203. def rescale_image_targets(self, images, targets, max_stride, multi_scale_range=[0.5, 1.5]):
  204. """
  205. Deployed for Multi scale trick.
  206. """
  207. # During training phase, the shape of input image is square.
  208. old_img_size = images.shape[-1]
  209. min_img_size = old_img_size * multi_scale_range[0]
  210. max_img_size = old_img_size * multi_scale_range[1]
  211. # Choose a new image size
  212. new_img_size = random.randrange(min_img_size, max_img_size + max_stride, max_stride)
  213. # Resize
  214. if new_img_size != old_img_size:
  215. # interpolate
  216. images = torch.nn.functional.interpolate(
  217. input=images,
  218. size=new_img_size,
  219. mode='bilinear',
  220. align_corners=False)
  221. # rescale targets
  222. if not self.cfg.normalize_coords:
  223. for tgt in targets:
  224. boxes = tgt["boxes"].clone()
  225. labels = tgt["labels"].clone()
  226. boxes = torch.clamp(boxes, 0, old_img_size)
  227. # rescale box
  228. boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
  229. boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
  230. # refine tgt
  231. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  232. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  233. keep = (min_tgt_size >= 1)
  234. tgt["boxes"] = boxes[keep]
  235. tgt["labels"] = labels[keep]
  236. return images, targets, new_img_size
  237. def check_second_stage(self):
  238. # set second stage
  239. print('============== Second stage of Training ==============')
  240. self.second_stage = True
  241. self.heavy_eval = True
  242. # close mosaic augmentation
  243. if self.train_loader.dataset.mosaic_prob > 0.:
  244. print(' - Close < Mosaic Augmentation > ...')
  245. self.train_loader.dataset.mosaic_prob = 0.
  246. # close mixup augmentation
  247. if self.train_loader.dataset.mixup_prob > 0.:
  248. print(' - Close < Mixup Augmentation > ...')
  249. self.train_loader.dataset.mixup_prob = 0.
  250. # close copy-paste augmentation
  251. if self.train_loader.dataset.copy_paste > 0.:
  252. print(' - Close < Copy-paste Augmentation > ...')
  253. self.train_loader.dataset.copy_paste = 0.
  254. class RTDetrTrainer(object):
  255. def __init__(self,
  256. # Basic parameters
  257. args,
  258. cfg,
  259. device,
  260. # Model parameters
  261. model,
  262. model_ema,
  263. criterion,
  264. # Data parameters
  265. train_transform,
  266. val_transform,
  267. dataset,
  268. train_loader,
  269. evaluator,
  270. ):
  271. # ------------------- basic parameters -------------------
  272. self.args = args
  273. self.cfg = cfg
  274. self.epoch = 0
  275. self.best_map = -1.
  276. self.device = device
  277. self.criterion = criterion
  278. self.heavy_eval = False
  279. self.model_ema = model_ema
  280. # path to save model
  281. self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
  282. os.makedirs(self.path_to_save, exist_ok=True)
  283. # ---------------------------- Transform ----------------------------
  284. self.train_transform = train_transform
  285. self.val_transform = val_transform
  286. # ---------------------------- Dataset & Dataloader ----------------------------
  287. self.dataset = dataset
  288. self.train_loader = train_loader
  289. # ---------------------------- Evaluator ----------------------------
  290. self.evaluator = evaluator
  291. # ---------------------------- Build Grad. Scaler ----------------------------
  292. self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
  293. # ---------------------------- Build Optimizer ----------------------------
  294. cfg.base_lr = cfg.per_image_lr * args.batch_size
  295. cfg.min_lr = cfg.base_lr * cfg.min_lr_ratio
  296. self.optimizer, self.start_epoch = build_rtdetr_optimizer(cfg, model, args.resume)
  297. # ---------------------------- Build LR Scheduler ----------------------------
  298. self.wp_lr_scheduler = LinearWarmUpLrScheduler(cfg.warmup_iters, cfg.base_lr)
  299. self.lr_scheduler = build_lr_scheduler(cfg, self.optimizer, args.resume)
  300. def train(self, model):
  301. for epoch in range(self.start_epoch, self.cfg.max_epoch):
  302. if self.args.distributed:
  303. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  304. # train one epoch
  305. self.epoch = epoch
  306. self.train_one_epoch(model)
  307. # LR Scheduler
  308. self.lr_scheduler.step()
  309. # eval one epoch
  310. if self.heavy_eval:
  311. model_eval = model.module if self.args.distributed else model
  312. self.eval(model_eval)
  313. else:
  314. model_eval = model.module if self.args.distributed else model
  315. if (epoch % self.cfg.eval_epoch) == 0 or (epoch == self.cfg.max_epoch - 1):
  316. self.eval(model_eval)
  317. if self.args.debug:
  318. print("For debug mode, we only train 1 epoch")
  319. break
  320. def eval(self, model):
  321. # set eval mode
  322. model.eval()
  323. model_eval = model if self.model_ema is None else self.model_ema.ema
  324. cur_map = -1.
  325. to_save = False
  326. if distributed_utils.is_main_process():
  327. if self.evaluator is None:
  328. print('No evaluator ... save model and go on training.')
  329. to_save = True
  330. weight_name = '{}_no_eval.pth'.format(self.args.model)
  331. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  332. else:
  333. print('Eval ...')
  334. # Evaluate
  335. with torch.no_grad():
  336. self.evaluator.evaluate(model_eval)
  337. cur_map = self.evaluator.map
  338. if cur_map > self.best_map:
  339. # update best-map
  340. self.best_map = cur_map
  341. to_save = True
  342. # Save model
  343. if to_save:
  344. print('Saving state, epoch:', self.epoch)
  345. weight_name = '{}_best.pth'.format(self.args.model)
  346. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  347. state_dicts = {
  348. 'model': model_eval.state_dict(),
  349. 'mAP': round(cur_map*100, 1),
  350. 'optimizer': self.optimizer.state_dict(),
  351. 'lr_scheduler': self.lr_scheduler.state_dict(),
  352. 'epoch': self.epoch,
  353. 'args': self.args,
  354. }
  355. if self.model_ema is not None:
  356. state_dicts["ema_updates"] = self.model_ema.updates
  357. torch.save(state_dicts, checkpoint_path)
  358. if self.args.distributed:
  359. # wait for all processes to synchronize
  360. dist.barrier()
  361. # set train mode.
  362. model.train()
  363. def train_one_epoch(self, model):
  364. metric_logger = MetricLogger(delimiter=" ")
  365. metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
  366. metric_logger.add_meter('size', SmoothedValue(window_size=1, fmt='{value:d}'))
  367. metric_logger.add_meter('grad_norm', SmoothedValue(window_size=1, fmt='{value:.1f}'))
  368. header = 'Epoch: [{} / {}]'.format(self.epoch, self.cfg.max_epoch)
  369. epoch_size = len(self.train_loader)
  370. print_freq = 10
  371. # basic parameters
  372. epoch_size = len(self.train_loader)
  373. img_size = self.cfg.train_img_size
  374. nw = self.cfg.warmup_iters
  375. lr_warmup_stage = True
  376. # Train one epoch
  377. for iter_i, (images, targets) in enumerate(metric_logger.log_every(self.train_loader, print_freq, header)):
  378. ni = iter_i + self.epoch * epoch_size
  379. # WarmUp
  380. if ni < nw and lr_warmup_stage:
  381. self.wp_lr_scheduler(ni, self.optimizer)
  382. elif ni == nw and lr_warmup_stage:
  383. print('Warmup stage is over.')
  384. lr_warmup_stage = False
  385. self.wp_lr_scheduler.set_lr(self.optimizer, self.cfg.base_lr)
  386. # To device
  387. images = images.to(self.device, non_blocking=True).float()
  388. for tgt in targets:
  389. tgt['boxes'] = tgt['boxes'].to(self.device)
  390. tgt['labels'] = tgt['labels'].to(self.device)
  391. # Multi scale
  392. images, targets, img_size = self.rescale_image_targets(
  393. images, targets, self.cfg.max_stride, self.cfg.multi_scale)
  394. # Visualize train targets
  395. if self.args.vis_tgt:
  396. vis_data(images,
  397. targets,
  398. self.cfg.num_classes,
  399. self.cfg.normalize_coords,
  400. self.train_transform.color_format,
  401. self.cfg.pixel_mean,
  402. self.cfg.pixel_std,
  403. self.cfg.box_format)
  404. # Inference
  405. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  406. outputs = model(images, targets)
  407. loss_dict = self.criterion(outputs, targets)
  408. losses = sum(loss_dict.values())
  409. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  410. # Backward
  411. self.scaler.scale(losses).backward()
  412. # Optimize
  413. if self.cfg.clip_max_norm > 0:
  414. self.scaler.unscale_(self.optimizer)
  415. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.cfg.clip_max_norm)
  416. self.scaler.step(self.optimizer)
  417. self.scaler.update()
  418. self.optimizer.zero_grad()
  419. # ModelEMA
  420. if self.model_ema is not None:
  421. self.model_ema.update(model)
  422. # Update log
  423. metric_logger.update(**loss_dict_reduced)
  424. metric_logger.update(lr=self.optimizer.param_groups[2]["lr"])
  425. metric_logger.update(size=img_size)
  426. if self.args.debug:
  427. print("For debug mode, we only train 1 iteration")
  428. break
  429. def rescale_image_targets(self, images, targets, max_stride, multi_scale_range=[0.5, 1.5]):
  430. """
  431. Deployed for Multi scale trick.
  432. """
  433. # During training phase, the shape of input image is square.
  434. old_img_size = images.shape[-1]
  435. min_img_size = old_img_size * multi_scale_range[0]
  436. max_img_size = old_img_size * multi_scale_range[1]
  437. # Choose a new image size
  438. new_img_size = random.randrange(min_img_size, max_img_size + max_stride, max_stride)
  439. # Resize
  440. if new_img_size != old_img_size:
  441. # interpolate
  442. images = torch.nn.functional.interpolate(
  443. input=images,
  444. size=new_img_size,
  445. mode='bilinear',
  446. align_corners=False)
  447. return images, targets, new_img_size
  448. # Build Trainer
  449. def build_trainer(args, cfg, device, model, model_ema, criterion, train_transform, val_transform, dataset, train_loader, evaluator):
  450. # ----------------------- Det trainers -----------------------
  451. if cfg.trainer == 'yolo':
  452. return YoloTrainer(args, cfg, device, model, model_ema, criterion, train_transform, val_transform, dataset, train_loader, evaluator)
  453. elif cfg.trainer == 'rtdetr':
  454. return RTDetrTrainer(args, cfg, device, model, model_ema, criterion, train_transform, val_transform, dataset, train_loader, evaluator)
  455. else:
  456. raise NotImplementedError(cfg.trainer)