engine.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542
  1. import torch
  2. import torch.distributed as dist
  3. import os
  4. import numpy as np
  5. import random
  6. # ----------------- Extra Components -----------------
  7. from utils import distributed_utils
  8. from utils.misc import MetricLogger, SmoothedValue
  9. from utils.vis_tools import vis_data
  10. # ----------------- Optimizer & LrScheduler Components -----------------
  11. from utils.solver.optimizer import build_yolo_optimizer, build_rtdetr_optimizer
  12. from utils.solver.lr_scheduler import LinearWarmUpLrScheduler, build_lr_scheduler, build_lambda_lr_scheduler
  13. class YoloTrainer(object):
  14. def __init__(self,
  15. # Basic parameters
  16. args,
  17. cfg,
  18. device,
  19. # Model parameters
  20. model,
  21. model_ema,
  22. criterion,
  23. # Data parameters
  24. train_transform,
  25. val_transform,
  26. dataset,
  27. train_loader,
  28. evaluator,
  29. ):
  30. # ------------------- basic parameters -------------------
  31. self.args = args
  32. self.cfg = cfg
  33. self.epoch = 0
  34. self.best_map = -1.
  35. self.device = device
  36. self.criterion = criterion
  37. self.heavy_eval = False
  38. self.model_ema = model_ema
  39. # weak augmentatino stage
  40. self.second_stage = False
  41. self.second_stage_epoch = cfg.no_aug_epoch
  42. # path to save model
  43. self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
  44. os.makedirs(self.path_to_save, exist_ok=True)
  45. # ---------------------------- Transform ----------------------------
  46. self.train_transform = train_transform
  47. self.val_transform = val_transform
  48. # ---------------------------- Dataset & Dataloader ----------------------------
  49. self.dataset = dataset
  50. self.train_loader = train_loader
  51. # ---------------------------- Evaluator ----------------------------
  52. self.evaluator = evaluator
  53. # ---------------------------- Build Grad. Scaler ----------------------------
  54. self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
  55. # ---------------------------- Build Optimizer ----------------------------
  56. cfg.base_lr = cfg.per_image_lr * args.batch_size
  57. cfg.min_lr = cfg.base_lr * cfg.min_lr_ratio
  58. self.optimizer, self.start_epoch = build_yolo_optimizer(cfg, model, args.resume)
  59. # ---------------------------- Build LR Scheduler ----------------------------
  60. self.lr_scheduler, self.lf = build_lambda_lr_scheduler(cfg, self.optimizer, cfg.max_epoch)
  61. self.lr_scheduler.last_epoch = self.start_epoch - 1 # do not move
  62. if self.args.resume and self.args.resume != 'None':
  63. self.lr_scheduler.step()
  64. def train(self, model):
  65. for epoch in range(self.start_epoch, self.cfg.max_epoch):
  66. if self.args.distributed:
  67. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  68. # check second stage
  69. if epoch >= (self.cfg.max_epoch - self.second_stage_epoch - 1) and not self.second_stage:
  70. self.check_second_stage()
  71. # save model of the last mosaic epoch
  72. weight_name = '{}_last_mosaic_epoch.pth'.format(self.args.model)
  73. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  74. print('Saving state of the last Mosaic epoch-{}.'.format(self.epoch))
  75. torch.save({'model': model.state_dict(),
  76. 'mAP': round(self.evaluator.map*100, 1),
  77. 'optimizer': self.optimizer.state_dict(),
  78. 'epoch': self.epoch,
  79. 'args': self.args},
  80. checkpoint_path)
  81. # train one epoch
  82. self.epoch = epoch
  83. self.train_one_epoch(model)
  84. # LR Schedule
  85. self.lr_scheduler.step()
  86. # eval one epoch
  87. if self.heavy_eval:
  88. model_eval = model.module if self.args.distributed else model
  89. self.eval(model_eval)
  90. else:
  91. model_eval = model.module if self.args.distributed else model
  92. if (epoch % self.cfg.eval_epoch) == 0 or (epoch == self.cfg.max_epoch - 1):
  93. self.eval(model_eval)
  94. if self.args.debug:
  95. print("For debug mode, we only train 1 epoch")
  96. break
  97. def eval(self, model):
  98. # set eval mode
  99. model.eval()
  100. model_eval = model if self.model_ema is None else self.model_ema.ema
  101. cur_map = -1.
  102. to_save = False
  103. if distributed_utils.is_main_process():
  104. if self.evaluator is None:
  105. print('No evaluator ... save model and go on training.')
  106. to_save = True
  107. weight_name = '{}_no_eval.pth'.format(self.args.model)
  108. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  109. else:
  110. print('Eval ...')
  111. # Evaluate
  112. with torch.no_grad():
  113. self.evaluator.evaluate(model_eval)
  114. cur_map = self.evaluator.map
  115. if cur_map > self.best_map:
  116. # update best-map
  117. self.best_map = cur_map
  118. to_save = True
  119. # Save model
  120. if to_save:
  121. print('Saving state, epoch:', self.epoch)
  122. weight_name = '{}_best.pth'.format(self.args.model)
  123. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  124. state_dicts = {
  125. 'model': model_eval.state_dict(),
  126. 'mAP': round(cur_map*100, 1),
  127. 'optimizer': self.optimizer.state_dict(),
  128. 'epoch': self.epoch,
  129. 'args': self.args,
  130. }
  131. if self.model_ema is not None:
  132. state_dicts["ema_updates"] = self.model_ema.updates
  133. torch.save(state_dicts, checkpoint_path)
  134. if self.args.distributed:
  135. # wait for all processes to synchronize
  136. dist.barrier()
  137. # set train mode.
  138. model.train()
  139. def train_one_epoch(self, model):
  140. metric_logger = MetricLogger(delimiter=" ")
  141. metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
  142. metric_logger.add_meter('size', SmoothedValue(window_size=1, fmt='{value:d}'))
  143. header = 'Epoch: [{} / {}]'.format(self.epoch, self.cfg.max_epoch)
  144. epoch_size = len(self.train_loader)
  145. print_freq = 10
  146. # basic parameters
  147. epoch_size = len(self.train_loader)
  148. img_size = self.cfg.train_img_size
  149. nw = epoch_size * self.cfg.warmup_epoch
  150. # Train one epoch
  151. for iter_i, (images, targets) in enumerate(metric_logger.log_every(self.train_loader, print_freq, header)):
  152. ni = iter_i + self.epoch * epoch_size
  153. # Warmup
  154. if ni <= nw:
  155. xi = [0, nw] # x interp
  156. for j, x in enumerate(self.optimizer.param_groups):
  157. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  158. x['lr'] = np.interp(
  159. ni, xi, [self.cfg.warmup_bias_lr if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
  160. if 'momentum' in x:
  161. x['momentum'] = np.interp(ni, xi, [self.cfg.warmup_momentum, self.cfg.momentum])
  162. # To device
  163. images = images.to(self.device, non_blocking=True).float()
  164. # Multi scale
  165. images, targets, img_size = self.rescale_image_targets(
  166. images, targets, self.cfg.max_stride, self.cfg.multi_scale)
  167. # Visualize train targets
  168. if self.args.vis_tgt:
  169. vis_data(images,
  170. targets,
  171. self.cfg.num_classes,
  172. self.cfg.normalize_coords,
  173. self.train_transform.color_format,
  174. self.cfg.pixel_mean,
  175. self.cfg.pixel_std,
  176. self.cfg.box_format)
  177. # Inference
  178. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  179. outputs = model(images)
  180. # Compute loss
  181. loss_dict = self.criterion(outputs=outputs, targets=targets)
  182. losses = loss_dict['losses']
  183. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  184. # Backward
  185. self.scaler.scale(losses).backward()
  186. # Optimize
  187. if self.cfg.clip_max_norm > 0:
  188. self.scaler.unscale_(self.optimizer)
  189. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.cfg.clip_max_norm)
  190. self.scaler.step(self.optimizer)
  191. self.scaler.update()
  192. self.optimizer.zero_grad()
  193. # ModelEMA
  194. if self.model_ema is not None:
  195. self.model_ema.update(model)
  196. # Update log
  197. metric_logger.update(**loss_dict_reduced)
  198. metric_logger.update(lr=self.optimizer.param_groups[2]["lr"])
  199. metric_logger.update(size=img_size)
  200. if self.args.debug:
  201. print("For debug mode, we only train 1 iteration")
  202. break
  203. # Gather the stats from all processes
  204. metric_logger.synchronize_between_processes()
  205. print("Averaged stats:", metric_logger)
  206. def rescale_image_targets(self, images, targets, max_stride, multi_scale_range=[0.5, 1.5]):
  207. """
  208. Deployed for Multi scale trick.
  209. """
  210. # During training phase, the shape of input image is square.
  211. old_img_size = images.shape[-1]
  212. min_img_size = old_img_size * multi_scale_range[0]
  213. max_img_size = old_img_size * multi_scale_range[1]
  214. # Choose a new image size
  215. new_img_size = random.randrange(min_img_size, max_img_size + max_stride, max_stride)
  216. # Resize
  217. if new_img_size != old_img_size:
  218. # interpolate
  219. images = torch.nn.functional.interpolate(
  220. input=images,
  221. size=new_img_size,
  222. mode='bilinear',
  223. align_corners=False)
  224. # rescale targets
  225. if not self.cfg.normalize_coords:
  226. for tgt in targets:
  227. boxes = tgt["boxes"].clone()
  228. labels = tgt["labels"].clone()
  229. boxes = torch.clamp(boxes, 0, old_img_size)
  230. # rescale box
  231. boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
  232. boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
  233. # refine tgt
  234. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  235. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  236. keep = (min_tgt_size >= 8)
  237. tgt["boxes"] = boxes[keep]
  238. tgt["labels"] = labels[keep]
  239. return images, targets, new_img_size
  240. def check_second_stage(self):
  241. # set second stage
  242. print('============== Second stage of Training ==============')
  243. self.second_stage = True
  244. self.heavy_eval = True
  245. # close mosaic augmentation
  246. if self.train_loader.dataset.mosaic_prob > 0.:
  247. print(' - Close < Mosaic Augmentation > ...')
  248. self.train_loader.dataset.mosaic_prob = 0.
  249. # close mixup augmentation
  250. if self.train_loader.dataset.mixup_prob > 0.:
  251. print(' - Close < Mixup Augmentation > ...')
  252. self.train_loader.dataset.mixup_prob = 0.
  253. # close copy-paste augmentation
  254. if self.train_loader.dataset.copy_paste > 0.:
  255. print(' - Close < Copy-paste Augmentation > ...')
  256. self.train_loader.dataset.copy_paste = 0.
  257. class RTDetrTrainer(object):
  258. def __init__(self,
  259. # Basic parameters
  260. args,
  261. cfg,
  262. device,
  263. # Model parameters
  264. model,
  265. model_ema,
  266. criterion,
  267. # Data parameters
  268. train_transform,
  269. val_transform,
  270. dataset,
  271. train_loader,
  272. evaluator,
  273. ):
  274. # ------------------- basic parameters -------------------
  275. self.args = args
  276. self.cfg = cfg
  277. self.epoch = 0
  278. self.best_map = -1.
  279. self.device = device
  280. self.criterion = criterion
  281. self.heavy_eval = False
  282. self.model_ema = model_ema
  283. # path to save model
  284. self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
  285. os.makedirs(self.path_to_save, exist_ok=True)
  286. # ---------------------------- Transform ----------------------------
  287. self.train_transform = train_transform
  288. self.val_transform = val_transform
  289. # ---------------------------- Dataset & Dataloader ----------------------------
  290. self.dataset = dataset
  291. self.train_loader = train_loader
  292. # ---------------------------- Evaluator ----------------------------
  293. self.evaluator = evaluator
  294. # ---------------------------- Build Grad. Scaler ----------------------------
  295. self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
  296. # ---------------------------- Build Optimizer ----------------------------
  297. cfg.base_lr = cfg.per_image_lr * args.batch_size
  298. cfg.min_lr = cfg.base_lr * cfg.min_lr_ratio
  299. self.optimizer, self.start_epoch = build_rtdetr_optimizer(cfg, model, args.resume)
  300. # ---------------------------- Build LR Scheduler ----------------------------
  301. self.wp_lr_scheduler = LinearWarmUpLrScheduler(cfg.warmup_iters, cfg.base_lr)
  302. self.lr_scheduler = build_lr_scheduler(cfg, self.optimizer, args.resume)
  303. def train(self, model):
  304. for epoch in range(self.start_epoch, self.cfg.max_epoch):
  305. if self.args.distributed:
  306. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  307. # train one epoch
  308. self.epoch = epoch
  309. self.train_one_epoch(model)
  310. # LR Scheduler
  311. self.lr_scheduler.step()
  312. # eval one epoch
  313. if self.heavy_eval:
  314. model_eval = model.module if self.args.distributed else model
  315. self.eval(model_eval)
  316. else:
  317. model_eval = model.module if self.args.distributed else model
  318. if (epoch % self.cfg.eval_epoch) == 0 or (epoch == self.cfg.max_epoch - 1):
  319. self.eval(model_eval)
  320. if self.args.debug:
  321. print("For debug mode, we only train 1 epoch")
  322. break
  323. def eval(self, model):
  324. # set eval mode
  325. model.eval()
  326. model_eval = model if self.model_ema is None else self.model_ema.ema
  327. cur_map = -1.
  328. to_save = False
  329. if distributed_utils.is_main_process():
  330. if self.evaluator is None:
  331. print('No evaluator ... save model and go on training.')
  332. to_save = True
  333. weight_name = '{}_no_eval.pth'.format(self.args.model)
  334. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  335. else:
  336. print('Eval ...')
  337. # Evaluate
  338. with torch.no_grad():
  339. self.evaluator.evaluate(model_eval)
  340. cur_map = self.evaluator.map
  341. if cur_map > self.best_map:
  342. # update best-map
  343. self.best_map = cur_map
  344. to_save = True
  345. # Save model
  346. if to_save:
  347. print('Saving state, epoch:', self.epoch)
  348. weight_name = '{}_best.pth'.format(self.args.model)
  349. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  350. state_dicts = {
  351. 'model': model_eval.state_dict(),
  352. 'mAP': round(cur_map*100, 1),
  353. 'optimizer': self.optimizer.state_dict(),
  354. 'lr_scheduler': self.lr_scheduler.state_dict(),
  355. 'epoch': self.epoch,
  356. 'args': self.args,
  357. }
  358. if self.model_ema is not None:
  359. state_dicts["ema_updates"] = self.model_ema.updates
  360. torch.save(state_dicts, checkpoint_path)
  361. if self.args.distributed:
  362. # wait for all processes to synchronize
  363. dist.barrier()
  364. # set train mode.
  365. model.train()
  366. def train_one_epoch(self, model):
  367. metric_logger = MetricLogger(delimiter=" ")
  368. metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
  369. metric_logger.add_meter('size', SmoothedValue(window_size=1, fmt='{value:d}'))
  370. metric_logger.add_meter('grad_norm', SmoothedValue(window_size=1, fmt='{value:.1f}'))
  371. header = 'Epoch: [{} / {}]'.format(self.epoch, self.cfg.max_epoch)
  372. epoch_size = len(self.train_loader)
  373. print_freq = 10
  374. # basic parameters
  375. epoch_size = len(self.train_loader)
  376. img_size = self.cfg.train_img_size
  377. nw = self.cfg.warmup_iters
  378. lr_warmup_stage = True
  379. # Train one epoch
  380. for iter_i, (images, targets) in enumerate(metric_logger.log_every(self.train_loader, print_freq, header)):
  381. ni = iter_i + self.epoch * epoch_size
  382. # WarmUp
  383. if ni < nw and lr_warmup_stage:
  384. self.wp_lr_scheduler(ni, self.optimizer)
  385. elif ni == nw and lr_warmup_stage:
  386. print('Warmup stage is over.')
  387. lr_warmup_stage = False
  388. self.wp_lr_scheduler.set_lr(self.optimizer, self.cfg.base_lr)
  389. # To device
  390. images = images.to(self.device, non_blocking=True).float()
  391. for tgt in targets:
  392. tgt['boxes'] = tgt['boxes'].to(self.device)
  393. tgt['labels'] = tgt['labels'].to(self.device)
  394. # Multi scale
  395. images, targets, img_size = self.rescale_image_targets(
  396. images, targets, self.cfg.max_stride, self.cfg.multi_scale)
  397. # Visualize train targets
  398. if self.args.vis_tgt:
  399. vis_data(images,
  400. targets,
  401. self.cfg.num_classes,
  402. self.cfg.normalize_coords,
  403. self.train_transform.color_format,
  404. self.cfg.pixel_mean,
  405. self.cfg.pixel_std,
  406. self.cfg.box_format)
  407. # Inference
  408. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  409. outputs = model(images, targets)
  410. loss_dict = self.criterion(outputs, targets)
  411. losses = sum(loss_dict.values())
  412. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  413. # Backward
  414. self.scaler.scale(losses).backward()
  415. # Optimize
  416. if self.cfg.clip_max_norm > 0:
  417. self.scaler.unscale_(self.optimizer)
  418. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.cfg.clip_max_norm)
  419. self.scaler.step(self.optimizer)
  420. self.scaler.update()
  421. self.optimizer.zero_grad()
  422. # ModelEMA
  423. if self.model_ema is not None:
  424. self.model_ema.update(model)
  425. # Update log
  426. metric_logger.update(**loss_dict_reduced)
  427. metric_logger.update(lr=self.optimizer.param_groups[2]["lr"])
  428. metric_logger.update(size=img_size)
  429. if self.args.debug:
  430. print("For debug mode, we only train 1 iteration")
  431. break
  432. def rescale_image_targets(self, images, targets, max_stride, multi_scale_range=[0.5, 1.5]):
  433. """
  434. Deployed for Multi scale trick.
  435. """
  436. # During training phase, the shape of input image is square.
  437. old_img_size = images.shape[-1]
  438. min_img_size = old_img_size * multi_scale_range[0]
  439. max_img_size = old_img_size * multi_scale_range[1]
  440. # Choose a new image size
  441. new_img_size = random.randrange(min_img_size, max_img_size + max_stride, max_stride)
  442. # Resize
  443. if new_img_size != old_img_size:
  444. # interpolate
  445. images = torch.nn.functional.interpolate(
  446. input=images,
  447. size=new_img_size,
  448. mode='bilinear',
  449. align_corners=False)
  450. return images, targets, new_img_size
  451. # Build Trainer
  452. def build_trainer(args, cfg, device, model, model_ema, criterion, train_transform, val_transform, dataset, train_loader, evaluator):
  453. # ----------------------- Det trainers -----------------------
  454. if cfg.trainer == 'yolo':
  455. return YoloTrainer(args, cfg, device, model, model_ema, criterion, train_transform, val_transform, dataset, train_loader, evaluator)
  456. elif cfg.trainer == 'rtdetr':
  457. return RTDetrTrainer(args, cfg, device, model, model_ema, criterion, train_transform, val_transform, dataset, train_loader, evaluator)
  458. else:
  459. raise NotImplementedError(cfg.trainer)