# engine.py — training engines (YoloTrainer / RTDetrTrainer) and the build_trainer factory.
import os
import random

import numpy as np
import torch
import torch.distributed as dist

# ----------------- Extra Components -----------------
from utils import distributed_utils
from utils.misc import MetricLogger, SmoothedValue
from utils.vis_tools import vis_data
# ----------------- Optimizer & LrScheduler Components -----------------
from utils.solver.optimizer import build_yolo_optimizer, build_rtdetr_optimizer
from utils.solver.lr_scheduler import LinearWarmUpLrScheduler, build_lr_scheduler, build_lambda_lr_scheduler
  13. class YoloTrainer(object):
  14. def __init__(self,
  15. # Basic parameters
  16. args,
  17. cfg,
  18. device,
  19. # Model parameters
  20. model,
  21. model_ema,
  22. criterion,
  23. # Data parameters
  24. train_transform,
  25. val_transform,
  26. dataset,
  27. train_loader,
  28. evaluator,
  29. ):
  30. # ------------------- basic parameters -------------------
  31. self.args = args
  32. self.cfg = cfg
  33. self.epoch = 0
  34. self.best_map = -1.
  35. self.device = device
  36. self.criterion = criterion
  37. self.heavy_eval = False
  38. self.model_ema = model_ema
  39. # weak augmentatino stage
  40. self.second_stage = False
  41. self.second_stage_epoch = cfg.no_aug_epoch
  42. # path to save model
  43. self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
  44. os.makedirs(self.path_to_save, exist_ok=True)
  45. # ---------------------------- Transform ----------------------------
  46. self.train_transform = train_transform
  47. self.val_transform = val_transform
  48. # ---------------------------- Dataset & Dataloader ----------------------------
  49. self.dataset = dataset
  50. self.train_loader = train_loader
  51. # ---------------------------- Evaluator ----------------------------
  52. self.evaluator = evaluator
  53. # ---------------------------- Build Grad. Scaler ----------------------------
  54. self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
  55. # ---------------------------- Build Optimizer ----------------------------
  56. self.grad_accumulate = max(64 // args.batch_size, 1)
  57. cfg.base_lr = cfg.per_image_lr * args.batch_size * self.grad_accumulate
  58. cfg.min_lr = cfg.base_lr * cfg.min_lr_ratio
  59. self.optimizer, self.start_epoch = build_yolo_optimizer(cfg, model, args.resume)
  60. # ---------------------------- Build LR Scheduler ----------------------------
  61. self.lr_scheduler, self.lf = build_lambda_lr_scheduler(cfg, self.optimizer, cfg.max_epoch)
  62. self.lr_scheduler.last_epoch = self.start_epoch - 1 # do not move
  63. if self.args.resume and self.args.resume != 'None':
  64. self.lr_scheduler.step()
  65. def train(self, model):
  66. for epoch in range(self.start_epoch, self.cfg.max_epoch):
  67. if self.args.distributed:
  68. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  69. # check second stage
  70. if epoch >= (self.cfg.max_epoch - self.second_stage_epoch - 1) and not self.second_stage:
  71. self.check_second_stage()
  72. # save model of the last mosaic epoch
  73. weight_name = '{}_last_mosaic_epoch.pth'.format(self.args.model)
  74. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  75. print('Saving state of the last Mosaic epoch-{}.'.format(self.epoch))
  76. torch.save({'model': model.state_dict(),
  77. 'mAP': round(self.evaluator.map*100, 1),
  78. 'optimizer': self.optimizer.state_dict(),
  79. 'epoch': self.epoch,
  80. 'args': self.args},
  81. checkpoint_path)
  82. # train one epoch
  83. self.epoch = epoch
  84. self.train_one_epoch(model)
  85. # LR Schedule
  86. self.lr_scheduler.step()
  87. # eval one epoch
  88. if self.heavy_eval:
  89. model_eval = model.module if self.args.distributed else model
  90. self.eval(model_eval)
  91. else:
  92. model_eval = model.module if self.args.distributed else model
  93. if (epoch % self.cfg.eval_epoch) == 0 or (epoch == self.cfg.max_epoch - 1):
  94. self.eval(model_eval)
  95. if self.args.debug:
  96. print("For debug mode, we only train 1 epoch")
  97. break
  98. def eval(self, model):
  99. # set eval mode
  100. model.eval()
  101. model_eval = model if self.model_ema is None else self.model_ema.ema
  102. cur_map = -1.
  103. to_save = False
  104. if distributed_utils.is_main_process():
  105. if self.evaluator is None:
  106. print('No evaluator ... save model and go on training.')
  107. to_save = True
  108. weight_name = '{}_no_eval.pth'.format(self.args.model)
  109. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  110. else:
  111. print('Eval ...')
  112. # Evaluate
  113. with torch.no_grad():
  114. self.evaluator.evaluate(model_eval)
  115. cur_map = self.evaluator.map
  116. if cur_map > self.best_map:
  117. # update best-map
  118. self.best_map = cur_map
  119. to_save = True
  120. # Save model
  121. if to_save:
  122. print('Saving state, epoch:', self.epoch)
  123. weight_name = '{}_best.pth'.format(self.args.model)
  124. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  125. state_dicts = {
  126. 'model': model_eval.state_dict(),
  127. 'mAP': round(cur_map*100, 1),
  128. 'optimizer': self.optimizer.state_dict(),
  129. 'epoch': self.epoch,
  130. 'args': self.args,
  131. }
  132. if self.model_ema is not None:
  133. state_dicts["ema_updates"] = self.model_ema.updates
  134. torch.save(state_dicts, checkpoint_path)
  135. if self.args.distributed:
  136. # wait for all processes to synchronize
  137. dist.barrier()
  138. # set train mode.
  139. model.train()
  140. def train_one_epoch(self, model):
  141. metric_logger = MetricLogger(delimiter=" ")
  142. metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
  143. metric_logger.add_meter('size', SmoothedValue(window_size=1, fmt='{value:d}'))
  144. header = 'Epoch: [{} / {}]'.format(self.epoch, self.cfg.max_epoch)
  145. epoch_size = len(self.train_loader)
  146. print_freq = 10
  147. # basic parameters
  148. epoch_size = len(self.train_loader)
  149. img_size = self.cfg.train_img_size
  150. nw = epoch_size * self.cfg.warmup_epoch
  151. # Train one epoch
  152. for iter_i, (images, targets) in enumerate(metric_logger.log_every(self.train_loader, print_freq, header)):
  153. ni = iter_i + self.epoch * epoch_size
  154. # Warmup
  155. if ni <= nw:
  156. xi = [0, nw] # x interp
  157. for j, x in enumerate(self.optimizer.param_groups):
  158. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  159. x['lr'] = np.interp(
  160. ni, xi, [self.cfg.warmup_bias_lr if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
  161. if 'momentum' in x:
  162. x['momentum'] = np.interp(ni, xi, [self.cfg.warmup_momentum, self.cfg.momentum])
  163. # To device
  164. images = images.to(self.device, non_blocking=True).float()
  165. # Multi scale
  166. images, targets, img_size = self.rescale_image_targets(
  167. images, targets, self.cfg.max_stride, self.cfg.multi_scale)
  168. # Visualize train targets
  169. if self.args.vis_tgt:
  170. vis_data(images,
  171. targets,
  172. self.cfg.num_classes,
  173. self.cfg.normalize_coords,
  174. self.train_transform.color_format,
  175. self.cfg.pixel_mean,
  176. self.cfg.pixel_std,
  177. self.cfg.box_format)
  178. # Inference
  179. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  180. outputs = model(images)
  181. # Compute loss
  182. loss_dict = self.criterion(outputs=outputs, targets=targets)
  183. losses = loss_dict['losses']
  184. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  185. # Backward
  186. self.scaler.scale(losses).backward()
  187. # Optimize
  188. if (iter_i + 1) % self.grad_accumulate == 0:
  189. if self.cfg.clip_max_norm > 0:
  190. self.scaler.unscale_(self.optimizer)
  191. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.cfg.clip_max_norm)
  192. self.scaler.step(self.optimizer)
  193. self.scaler.update()
  194. self.optimizer.zero_grad()
  195. # ModelEMA
  196. if self.model_ema is not None:
  197. self.model_ema.update(model)
  198. # Update log
  199. metric_logger.update(**loss_dict_reduced)
  200. metric_logger.update(lr=self.optimizer.param_groups[2]["lr"])
  201. metric_logger.update(size=img_size)
  202. if self.args.debug:
  203. print("For debug mode, we only train 1 iteration")
  204. break
  205. # Gather the stats from all processes
  206. metric_logger.synchronize_between_processes()
  207. print("Averaged stats:", metric_logger)
  208. def rescale_image_targets(self, images, targets, max_stride, multi_scale_range=[0.5, 1.5]):
  209. """
  210. Deployed for Multi scale trick.
  211. """
  212. # During training phase, the shape of input image is square.
  213. old_img_size = images.shape[-1]
  214. min_img_size = old_img_size * multi_scale_range[0]
  215. max_img_size = old_img_size * multi_scale_range[1]
  216. # Choose a new image size
  217. new_img_size = random.randrange(min_img_size, max_img_size + max_stride, max_stride)
  218. # Resize
  219. if new_img_size != old_img_size:
  220. # interpolate
  221. images = torch.nn.functional.interpolate(
  222. input=images,
  223. size=new_img_size,
  224. mode='bilinear',
  225. align_corners=False)
  226. # rescale targets
  227. if not self.cfg.normalize_coords:
  228. for tgt in targets:
  229. boxes = tgt["boxes"].clone()
  230. labels = tgt["labels"].clone()
  231. boxes = torch.clamp(boxes, 0, old_img_size)
  232. # rescale box
  233. boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
  234. boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
  235. # refine tgt
  236. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  237. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  238. keep = (min_tgt_size >= 8)
  239. tgt["boxes"] = boxes[keep]
  240. tgt["labels"] = labels[keep]
  241. return images, targets, new_img_size
  242. def check_second_stage(self):
  243. # set second stage
  244. print('============== Second stage of Training ==============')
  245. self.second_stage = True
  246. self.heavy_eval = True
  247. # close mosaic augmentation
  248. if self.train_loader.dataset.mosaic_prob > 0.:
  249. print(' - Close < Mosaic Augmentation > ...')
  250. self.train_loader.dataset.mosaic_prob = 0.
  251. # close mixup augmentation
  252. if self.train_loader.dataset.mixup_prob > 0.:
  253. print(' - Close < Mixup Augmentation > ...')
  254. self.train_loader.dataset.mixup_prob = 0.
  255. # close copy-paste augmentation
  256. if self.train_loader.dataset.copy_paste > 0.:
  257. print(' - Close < Copy-paste Augmentation > ...')
  258. self.train_loader.dataset.copy_paste = 0.
  259. class RTDetrTrainer(object):
  260. def __init__(self,
  261. # Basic parameters
  262. args,
  263. cfg,
  264. device,
  265. # Model parameters
  266. model,
  267. model_ema,
  268. criterion,
  269. # Data parameters
  270. train_transform,
  271. val_transform,
  272. dataset,
  273. train_loader,
  274. evaluator,
  275. ):
  276. # ------------------- basic parameters -------------------
  277. self.args = args
  278. self.cfg = cfg
  279. self.epoch = 0
  280. self.best_map = -1.
  281. self.device = device
  282. self.criterion = criterion
  283. self.heavy_eval = False
  284. self.model_ema = model_ema
  285. # path to save model
  286. self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
  287. os.makedirs(self.path_to_save, exist_ok=True)
  288. # ---------------------------- Transform ----------------------------
  289. self.train_transform = train_transform
  290. self.val_transform = val_transform
  291. # ---------------------------- Dataset & Dataloader ----------------------------
  292. self.dataset = dataset
  293. self.train_loader = train_loader
  294. # ---------------------------- Evaluator ----------------------------
  295. self.evaluator = evaluator
  296. # ---------------------------- Build Grad. Scaler ----------------------------
  297. self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
  298. # ---------------------------- Build Optimizer ----------------------------
  299. cfg.base_lr = cfg.per_image_lr * args.batch_size
  300. cfg.min_lr = cfg.base_lr * cfg.min_lr_ratio
  301. self.optimizer, self.start_epoch = build_rtdetr_optimizer(cfg, model, args.resume)
  302. # ---------------------------- Build LR Scheduler ----------------------------
  303. self.wp_lr_scheduler = LinearWarmUpLrScheduler(cfg.warmup_iters, cfg.base_lr)
  304. self.lr_scheduler = build_lr_scheduler(cfg, self.optimizer, args.resume)
  305. def train(self, model):
  306. for epoch in range(self.start_epoch, self.cfg.max_epoch):
  307. if self.args.distributed:
  308. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  309. # train one epoch
  310. self.epoch = epoch
  311. self.train_one_epoch(model)
  312. # LR Scheduler
  313. self.lr_scheduler.step()
  314. # eval one epoch
  315. if self.heavy_eval:
  316. model_eval = model.module if self.args.distributed else model
  317. self.eval(model_eval)
  318. else:
  319. model_eval = model.module if self.args.distributed else model
  320. if (epoch % self.cfg.eval_epoch) == 0 or (epoch == self.cfg.max_epoch - 1):
  321. self.eval(model_eval)
  322. if self.args.debug:
  323. print("For debug mode, we only train 1 epoch")
  324. break
  325. def eval(self, model):
  326. # set eval mode
  327. model.eval()
  328. model_eval = model if self.model_ema is None else self.model_ema.ema
  329. cur_map = -1.
  330. to_save = False
  331. if distributed_utils.is_main_process():
  332. if self.evaluator is None:
  333. print('No evaluator ... save model and go on training.')
  334. to_save = True
  335. weight_name = '{}_no_eval.pth'.format(self.args.model)
  336. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  337. else:
  338. print('Eval ...')
  339. # Evaluate
  340. with torch.no_grad():
  341. self.evaluator.evaluate(model_eval)
  342. cur_map = self.evaluator.map
  343. if cur_map > self.best_map:
  344. # update best-map
  345. self.best_map = cur_map
  346. to_save = True
  347. # Save model
  348. if to_save:
  349. print('Saving state, epoch:', self.epoch)
  350. weight_name = '{}_best.pth'.format(self.args.model)
  351. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  352. state_dicts = {
  353. 'model': model_eval.state_dict(),
  354. 'mAP': round(cur_map*100, 1),
  355. 'optimizer': self.optimizer.state_dict(),
  356. 'lr_scheduler': self.lr_scheduler.state_dict(),
  357. 'epoch': self.epoch,
  358. 'args': self.args,
  359. }
  360. if self.model_ema is not None:
  361. state_dicts["ema_updates"] = self.model_ema.updates
  362. torch.save(state_dicts, checkpoint_path)
  363. if self.args.distributed:
  364. # wait for all processes to synchronize
  365. dist.barrier()
  366. # set train mode.
  367. model.train()
  368. def train_one_epoch(self, model):
  369. metric_logger = MetricLogger(delimiter=" ")
  370. metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
  371. metric_logger.add_meter('size', SmoothedValue(window_size=1, fmt='{value:d}'))
  372. metric_logger.add_meter('grad_norm', SmoothedValue(window_size=1, fmt='{value:.1f}'))
  373. header = 'Epoch: [{} / {}]'.format(self.epoch, self.cfg.max_epoch)
  374. epoch_size = len(self.train_loader)
  375. print_freq = 10
  376. # basic parameters
  377. epoch_size = len(self.train_loader)
  378. img_size = self.cfg.train_img_size
  379. nw = self.cfg.warmup_iters
  380. lr_warmup_stage = True
  381. # Train one epoch
  382. for iter_i, (images, targets) in enumerate(metric_logger.log_every(self.train_loader, print_freq, header)):
  383. ni = iter_i + self.epoch * epoch_size
  384. # WarmUp
  385. if ni < nw and lr_warmup_stage:
  386. self.wp_lr_scheduler(ni, self.optimizer)
  387. elif ni == nw and lr_warmup_stage:
  388. print('Warmup stage is over.')
  389. lr_warmup_stage = False
  390. self.wp_lr_scheduler.set_lr(self.optimizer, self.cfg.base_lr)
  391. # To device
  392. images = images.to(self.device, non_blocking=True).float()
  393. for tgt in targets:
  394. tgt['boxes'] = tgt['boxes'].to(self.device)
  395. tgt['labels'] = tgt['labels'].to(self.device)
  396. # Multi scale
  397. images, targets, img_size = self.rescale_image_targets(
  398. images, targets, self.cfg.max_stride, self.cfg.multi_scale)
  399. # Visualize train targets
  400. if self.args.vis_tgt:
  401. vis_data(images,
  402. targets,
  403. self.cfg.num_classes,
  404. self.cfg.normalize_coords,
  405. self.train_transform.color_format,
  406. self.cfg.pixel_mean,
  407. self.cfg.pixel_std,
  408. self.cfg.box_format)
  409. # Inference
  410. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  411. outputs = model(images, targets)
  412. loss_dict = self.criterion(outputs, targets)
  413. losses = sum(loss_dict.values())
  414. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  415. # Backward
  416. self.scaler.scale(losses).backward()
  417. # Optimize
  418. if self.cfg.clip_max_norm > 0:
  419. self.scaler.unscale_(self.optimizer)
  420. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.cfg.clip_max_norm)
  421. self.scaler.step(self.optimizer)
  422. self.scaler.update()
  423. self.optimizer.zero_grad()
  424. # ModelEMA
  425. if self.model_ema is not None:
  426. self.model_ema.update(model)
  427. # Update log
  428. metric_logger.update(**loss_dict_reduced)
  429. metric_logger.update(lr=self.optimizer.param_groups[2]["lr"])
  430. metric_logger.update(size=img_size)
  431. if self.args.debug:
  432. print("For debug mode, we only train 1 iteration")
  433. break
  434. def rescale_image_targets(self, images, targets, max_stride, multi_scale_range=[0.5, 1.5]):
  435. """
  436. Deployed for Multi scale trick.
  437. """
  438. # During training phase, the shape of input image is square.
  439. old_img_size = images.shape[-1]
  440. min_img_size = old_img_size * multi_scale_range[0]
  441. max_img_size = old_img_size * multi_scale_range[1]
  442. # Choose a new image size
  443. new_img_size = random.randrange(min_img_size, max_img_size + max_stride, max_stride)
  444. # Resize
  445. if new_img_size != old_img_size:
  446. # interpolate
  447. images = torch.nn.functional.interpolate(
  448. input=images,
  449. size=new_img_size,
  450. mode='bilinear',
  451. align_corners=False)
  452. return images, targets, new_img_size
  453. # Build Trainer
  454. def build_trainer(args, cfg, device, model, model_ema, criterion, train_transform, val_transform, dataset, train_loader, evaluator):
  455. # ----------------------- Det trainers -----------------------
  456. if cfg.trainer == 'yolo':
  457. return YoloTrainer(args, cfg, device, model, model_ema, criterion, train_transform, val_transform, dataset, train_loader, evaluator)
  458. elif cfg.trainer == 'rtdetr':
  459. return RTDetrTrainer(args, cfg, device, model, model_ema, criterion, train_transform, val_transform, dataset, train_loader, evaluator)
  460. else:
  461. raise NotImplementedError(cfg.trainer)