engine.py 56 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318
  1. import torch
  2. import torch.distributed as dist
  3. import time
  4. import os
  5. import numpy as np
  6. import random
  7. # ----------------- Extra Components -----------------
  8. from utils import distributed_utils
  9. from utils.misc import ModelEMA, CollateFunc, build_dataloader
  10. from utils.vis_tools import vis_data
  11. # ----------------- Evaluator Components -----------------
  12. from evaluator.build import build_evluator
  13. # ----------------- Optimizer & LrScheduler Components -----------------
  14. from utils.solver.optimizer import build_yolo_optimizer, build_detr_optimizer
  15. from utils.solver.lr_scheduler import build_lr_scheduler
  16. # ----------------- Dataset Components -----------------
  17. from dataset.build import build_dataset, build_transform
  18. # YOLOv8-style Trainer
  19. class Yolov8Trainer(object):
  20. def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
  21. # ------------------- basic parameters -------------------
  22. self.args = args
  23. self.epoch = 0
  24. self.best_map = -1.
  25. self.last_opt_step = 0
  26. self.no_aug_epoch = args.no_aug_epoch
  27. self.clip_grad = 10
  28. self.device = device
  29. self.criterion = criterion
  30. self.world_size = world_size
  31. self.heavy_eval = False
  32. self.second_stage = False
  33. """
  34. The below hyperparameters refer to YOLOv8.
  35. """
  36. self.optimizer_dict = {'optimizer': 'sgd', 'momentum': 0.937, 'weight_decay': 5e-4, 'lr0': 0.01}
  37. self.ema_dict = {'ema_decay': 0.9999, 'ema_tau': 2000}
  38. self.lr_schedule_dict = {'scheduler': 'linear', 'lrf': 0.01}
  39. self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}
  40. # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
  41. self.data_cfg = data_cfg
  42. self.model_cfg = model_cfg
  43. self.trans_cfg = trans_cfg
  44. # ---------------------------- Build Transform ----------------------------
  45. self.train_transform, self.trans_cfg = build_transform(
  46. args=args, trans_config=self.trans_cfg, max_stride=model_cfg['max_stride'], is_train=True)
  47. self.val_transform, _ = build_transform(
  48. args=args, trans_config=self.trans_cfg, max_stride=model_cfg['max_stride'], is_train=False)
  49. # ---------------------------- Build Dataset & Dataloader ----------------------------
  50. self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
  51. self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())
  52. # ---------------------------- Build Evaluator ----------------------------
  53. self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)
  54. # ---------------------------- Build Grad. Scaler ----------------------------
  55. self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)
  56. # ---------------------------- Build Optimizer ----------------------------
  57. accumulate = max(1, round(64 / self.args.batch_size))
  58. self.optimizer_dict['weight_decay'] *= self.args.batch_size * accumulate / 64
  59. self.optimizer, self.start_epoch = build_yolo_optimizer(self.optimizer_dict, model, self.args.resume)
  60. # ---------------------------- Build LR Scheduler ----------------------------
  61. self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, self.args.max_epoch)
  62. self.lr_scheduler.last_epoch = self.start_epoch - 1 # do not move
  63. if self.args.resume:
  64. self.lr_scheduler.step()
  65. # ---------------------------- Build Model-EMA ----------------------------
  66. if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
  67. print('Build ModelEMA ...')
  68. self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
  69. else:
  70. self.model_ema = None
  71. def check_second_stage(self):
  72. # set second stage
  73. print('============== Second stage of Training ==============')
  74. self.second_stage = True
  75. # close mosaic augmentation
  76. if self.train_loader.dataset.mosaic_prob > 0.:
  77. print(' - Close < Mosaic Augmentation > ...')
  78. self.train_loader.dataset.mosaic_prob = 0.
  79. self.heavy_eval = True
  80. # close mixup augmentation
  81. if self.train_loader.dataset.mixup_prob > 0.:
  82. print(' - Close < Mixup Augmentation > ...')
  83. self.train_loader.dataset.mixup_prob = 0.
  84. self.heavy_eval = True
  85. # close rotation augmentation
  86. if 'degrees' in self.trans_cfg.keys() and self.trans_cfg['degrees'] > 0.0:
  87. print(' - Close < degress of rotation > ...')
  88. self.trans_cfg['degrees'] = 0.0
  89. if 'shear' in self.trans_cfg.keys() and self.trans_cfg['shear'] > 0.0:
  90. print(' - Close < shear of rotation >...')
  91. self.trans_cfg['shear'] = 0.0
  92. if 'perspective' in self.trans_cfg.keys() and self.trans_cfg['perspective'] > 0.0:
  93. print(' - Close < perspective of rotation > ...')
  94. self.trans_cfg['perspective'] = 0.0
  95. # build a new transform for second stage
  96. print(' - Rebuild transforms ...')
  97. self.train_transform, self.trans_cfg = build_transform(
  98. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
  99. self.train_loader.dataset.transform = self.train_transform
  100. def train(self, model):
  101. for epoch in range(self.start_epoch, self.args.max_epoch):
  102. if self.args.distributed:
  103. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  104. # check second stage
  105. if epoch >= (self.args.max_epoch - self.no_aug_epoch - 1) and not self.second_stage:
  106. self.check_second_stage()
  107. # train one epoch
  108. self.epoch = epoch
  109. self.train_one_epoch(model)
  110. # eval one epoch
  111. if self.heavy_eval:
  112. model_eval = model.module if self.args.distributed else model
  113. self.eval(model_eval)
  114. else:
  115. model_eval = model.module if self.args.distributed else model
  116. if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
  117. self.eval(model_eval)
  118. def eval(self, model):
  119. # chech model
  120. model_eval = model if self.model_ema is None else self.model_ema.ema
  121. # path to save model
  122. path_to_save = os.path.join(self.args.save_folder, self.args.dataset, self.args.model)
  123. os.makedirs(path_to_save, exist_ok=True)
  124. if distributed_utils.is_main_process():
  125. # check evaluator
  126. if self.evaluator is None:
  127. print('No evaluator ... save model and go on training.')
  128. print('Saving state, epoch: {}'.format(self.epoch + 1))
  129. weight_name = '{}_no_eval.pth'.format(self.args.model)
  130. checkpoint_path = os.path.join(path_to_save, weight_name)
  131. torch.save({'model': model_eval.state_dict(),
  132. 'mAP': -1.,
  133. 'optimizer': self.optimizer.state_dict(),
  134. 'epoch': self.epoch,
  135. 'args': self.args},
  136. checkpoint_path)
  137. else:
  138. print('eval ...')
  139. # set eval mode
  140. model_eval.trainable = False
  141. model_eval.eval()
  142. # evaluate
  143. with torch.no_grad():
  144. self.evaluator.evaluate(model_eval)
  145. # save model
  146. cur_map = self.evaluator.map
  147. if cur_map > self.best_map:
  148. # update best-map
  149. self.best_map = cur_map
  150. # save model
  151. print('Saving state, epoch:', self.epoch + 1)
  152. weight_name = '{}_best.pth'.format(self.args.model)
  153. checkpoint_path = os.path.join(path_to_save, weight_name)
  154. torch.save({'model': model_eval.state_dict(),
  155. 'mAP': round(self.best_map*100, 1),
  156. 'optimizer': self.optimizer.state_dict(),
  157. 'epoch': self.epoch,
  158. 'args': self.args},
  159. checkpoint_path)
  160. # set train mode.
  161. model_eval.trainable = True
  162. model_eval.train()
  163. if self.args.distributed:
  164. # wait for all processes to synchronize
  165. dist.barrier()
  166. def train_one_epoch(self, model):
  167. # basic parameters
  168. epoch_size = len(self.train_loader)
  169. img_size = self.args.img_size
  170. t0 = time.time()
  171. nw = epoch_size * self.args.wp_epoch
  172. accumulate = accumulate = max(1, round(64 / self.args.batch_size))
  173. # train one epoch
  174. for iter_i, (images, targets) in enumerate(self.train_loader):
  175. ni = iter_i + self.epoch * epoch_size
  176. # Warmup
  177. if ni <= nw:
  178. xi = [0, nw] # x interp
  179. accumulate = max(1, np.interp(ni, xi, [1, 64 / self.args.batch_size]).round())
  180. for j, x in enumerate(self.optimizer.param_groups):
  181. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  182. x['lr'] = np.interp(
  183. ni, xi, [self.warmup_dict['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
  184. if 'momentum' in x:
  185. x['momentum'] = np.interp(ni, xi, [self.warmup_dict['warmup_momentum'], self.optimizer_dict['momentum']])
  186. # to device
  187. images = images.to(self.device, non_blocking=True).float() / 255.
  188. # Multi scale
  189. if self.args.multi_scale:
  190. images, targets, img_size = self.rescale_image_targets(
  191. images, targets, self.model_cfg['stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
  192. else:
  193. targets = self.refine_targets(targets, self.args.min_box_size)
  194. # visualize train targets
  195. if self.args.vis_tgt:
  196. vis_data(images*255, targets)
  197. # inference
  198. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  199. outputs = model(images)
  200. # loss
  201. loss_dict = self.criterion(outputs=outputs, targets=targets, epoch=self.epoch)
  202. losses = loss_dict['losses']
  203. losses *= images.shape[0] # loss * bs
  204. # reduce
  205. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  206. # gradient averaged between devices in DDP mode
  207. losses *= distributed_utils.get_world_size()
  208. # backward
  209. self.scaler.scale(losses).backward()
  210. # Optimize
  211. if ni - self.last_opt_step >= accumulate:
  212. if self.clip_grad > 0:
  213. # unscale gradients
  214. self.scaler.unscale_(self.optimizer)
  215. # clip gradients
  216. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.clip_grad)
  217. # optimizer.step
  218. self.scaler.step(self.optimizer)
  219. self.scaler.update()
  220. self.optimizer.zero_grad()
  221. # ema
  222. if self.model_ema is not None:
  223. self.model_ema.update(model)
  224. self.last_opt_step = ni
  225. # display
  226. if distributed_utils.is_main_process() and iter_i % 10 == 0:
  227. t1 = time.time()
  228. cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
  229. # basic infor
  230. log = '[Epoch: {}/{}]'.format(self.epoch+1, self.args.max_epoch)
  231. log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
  232. log += '[lr: {:.6f}]'.format(cur_lr[2])
  233. # loss infor
  234. for k in loss_dict_reduced.keys():
  235. log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
  236. # other infor
  237. log += '[time: {:.2f}]'.format(t1 - t0)
  238. log += '[size: {}]'.format(img_size)
  239. # print log infor
  240. print(log, flush=True)
  241. t0 = time.time()
  242. self.lr_scheduler.step()
  243. def refine_targets(self, targets, min_box_size):
  244. # rescale targets
  245. for tgt in targets:
  246. boxes = tgt["boxes"].clone()
  247. labels = tgt["labels"].clone()
  248. # refine tgt
  249. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  250. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  251. keep = (min_tgt_size >= min_box_size)
  252. tgt["boxes"] = boxes[keep]
  253. tgt["labels"] = labels[keep]
  254. return targets
  255. def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
  256. """
  257. Deployed for Multi scale trick.
  258. """
  259. if isinstance(stride, int):
  260. max_stride = stride
  261. elif isinstance(stride, list):
  262. max_stride = max(stride)
  263. # During training phase, the shape of input image is square.
  264. old_img_size = images.shape[-1]
  265. new_img_size = random.randrange(old_img_size * multi_scale_range[0], old_img_size * multi_scale_range[1] + max_stride)
  266. new_img_size = new_img_size // max_stride * max_stride # size
  267. if new_img_size / old_img_size != 1:
  268. # interpolate
  269. images = torch.nn.functional.interpolate(
  270. input=images,
  271. size=new_img_size,
  272. mode='bilinear',
  273. align_corners=False)
  274. # rescale targets
  275. for tgt in targets:
  276. boxes = tgt["boxes"].clone()
  277. labels = tgt["labels"].clone()
  278. boxes = torch.clamp(boxes, 0, old_img_size)
  279. # rescale box
  280. boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
  281. boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
  282. # refine tgt
  283. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  284. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  285. keep = (min_tgt_size >= min_box_size)
  286. tgt["boxes"] = boxes[keep]
  287. tgt["labels"] = labels[keep]
  288. return images, targets, new_img_size
  289. # YOLOX-syle Trainer
  290. class YoloxTrainer(object):
  291. def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
  292. # ------------------- basic parameters -------------------
  293. self.args = args
  294. self.epoch = 0
  295. self.best_map = -1.
  296. self.device = device
  297. self.criterion = criterion
  298. self.world_size = world_size
  299. self.no_aug_epoch = args.no_aug_epoch
  300. self.heavy_eval = False
  301. self.second_stage = False
  302. """
  303. The below hyperparameters refer to YOLOX: https://github.com/open-mmlab/mmyolo/tree/main/configs/rtmdet.
  304. """
  305. self.optimizer_dict = {'optimizer': 'sgd', 'momentum': 0.9, 'weight_decay': 5e-4, 'lr0': 0.01}
  306. self.ema_dict = {'ema_decay': 0.9999, 'ema_tau': 2000}
  307. self.lr_schedule_dict = {'scheduler': 'cosine', 'lrf': 0.05}
  308. self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}
  309. # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
  310. self.data_cfg = data_cfg
  311. self.model_cfg = model_cfg
  312. self.trans_cfg = trans_cfg
  313. # ---------------------------- Build Transform ----------------------------
  314. self.train_transform, self.trans_cfg = build_transform(
  315. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
  316. self.val_transform, _ = build_transform(
  317. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)
  318. # ---------------------------- Build Dataset & Dataloader ----------------------------
  319. self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
  320. self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())
  321. # ---------------------------- Build Evaluator ----------------------------
  322. self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)
  323. # ---------------------------- Build Grad. Scaler ----------------------------
  324. self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)
  325. # ---------------------------- Build Optimizer ----------------------------
  326. self.optimizer_dict['lr0'] *= self.args.batch_size / 64
  327. self.optimizer, self.start_epoch = build_yolo_optimizer(self.optimizer_dict, model, self.args.resume)
  328. # ---------------------------- Build LR Scheduler ----------------------------
  329. self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, self.args.max_epoch)
  330. self.lr_scheduler.last_epoch = self.start_epoch - 1 # do not move
  331. if self.args.resume:
  332. self.lr_scheduler.step()
  333. # ---------------------------- Build Model-EMA ----------------------------
  334. if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
  335. print('Build ModelEMA ...')
  336. self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
  337. else:
  338. self.model_ema = None
  339. def train(self, model):
  340. for epoch in range(self.start_epoch, self.args.max_epoch):
  341. if self.args.distributed:
  342. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  343. # check second stage
  344. if epoch >= (self.args.max_epoch - self.no_aug_epoch - 1) and not self.second_stage:
  345. self.check_second_stage()
  346. # train one epoch
  347. self.epoch = epoch
  348. self.train_one_epoch(model)
  349. # eval one epoch
  350. if self.heavy_eval:
  351. model_eval = model.module if self.args.distributed else model
  352. self.eval(model_eval)
  353. else:
  354. model_eval = model.module if self.args.distributed else model
  355. if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
  356. self.eval(model_eval)
  357. def eval(self, model):
  358. # chech model
  359. model_eval = model if self.model_ema is None else self.model_ema.ema
  360. # path to save model
  361. path_to_save = os.path.join(self.args.save_folder, self.args.dataset, self.args.model)
  362. os.makedirs(path_to_save, exist_ok=True)
  363. if distributed_utils.is_main_process():
  364. # check evaluator
  365. if self.evaluator is None:
  366. print('No evaluator ... save model and go on training.')
  367. print('Saving state, epoch: {}'.format(self.epoch + 1))
  368. weight_name = '{}_no_eval.pth'.format(self.args.model)
  369. checkpoint_path = os.path.join(path_to_save, weight_name)
  370. torch.save({'model': model_eval.state_dict(),
  371. 'mAP': -1.,
  372. 'optimizer': self.optimizer.state_dict(),
  373. 'epoch': self.epoch,
  374. 'args': self.args},
  375. checkpoint_path)
  376. else:
  377. print('eval ...')
  378. # set eval mode
  379. model_eval.trainable = False
  380. model_eval.eval()
  381. # evaluate
  382. with torch.no_grad():
  383. self.evaluator.evaluate(model_eval)
  384. # save model
  385. cur_map = self.evaluator.map
  386. if cur_map > self.best_map:
  387. # update best-map
  388. self.best_map = cur_map
  389. # save model
  390. print('Saving state, epoch:', self.epoch + 1)
  391. weight_name = '{}_best.pth'.format(self.args.model)
  392. checkpoint_path = os.path.join(path_to_save, weight_name)
  393. torch.save({'model': model_eval.state_dict(),
  394. 'mAP': round(self.best_map*100, 1),
  395. 'optimizer': self.optimizer.state_dict(),
  396. 'epoch': self.epoch,
  397. 'args': self.args},
  398. checkpoint_path)
  399. # set train mode.
  400. model_eval.trainable = True
  401. model_eval.train()
  402. if self.args.distributed:
  403. # wait for all processes to synchronize
  404. dist.barrier()
  405. def train_one_epoch(self, model):
  406. # basic parameters
  407. epoch_size = len(self.train_loader)
  408. img_size = self.args.img_size
  409. t0 = time.time()
  410. nw = epoch_size * self.args.wp_epoch
  411. # Train one epoch
  412. for iter_i, (images, targets) in enumerate(self.train_loader):
  413. ni = iter_i + self.epoch * epoch_size
  414. # Warmup
  415. if ni <= nw:
  416. xi = [0, nw] # x interp
  417. for j, x in enumerate(self.optimizer.param_groups):
  418. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  419. x['lr'] = np.interp(
  420. ni, xi, [self.warmup_dict['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
  421. if 'momentum' in x:
  422. x['momentum'] = np.interp(ni, xi, [self.warmup_dict['warmup_momentum'], self.optimizer_dict['momentum']])
  423. # To device
  424. images = images.to(self.device, non_blocking=True).float() / 255.
  425. # Multi scale
  426. if self.args.multi_scale:
  427. images, targets, img_size = self.rescale_image_targets(
  428. images, targets, self.model_cfg['stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
  429. else:
  430. targets = self.refine_targets(targets, self.args.min_box_size)
  431. # Visualize train targets
  432. if self.args.vis_tgt:
  433. vis_data(images*255, targets)
  434. # Inference
  435. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  436. outputs = model(images)
  437. # Compute loss
  438. loss_dict = self.criterion(outputs=outputs, targets=targets, epoch=self.epoch)
  439. losses = loss_dict['losses']
  440. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  441. # Backward
  442. self.scaler.scale(losses).backward()
  443. # Optimize
  444. self.scaler.step(self.optimizer)
  445. self.scaler.update()
  446. self.optimizer.zero_grad()
  447. # ema
  448. if self.model_ema is not None:
  449. self.model_ema.update(model)
  450. # Logs
  451. if distributed_utils.is_main_process() and iter_i % 10 == 0:
  452. t1 = time.time()
  453. cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
  454. # basic infor
  455. log = '[Epoch: {}/{}]'.format(self.epoch+1, self.args.max_epoch)
  456. log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
  457. log += '[lr: {:.6f}]'.format(cur_lr[2])
  458. # loss infor
  459. for k in loss_dict_reduced.keys():
  460. log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
  461. # other infor
  462. log += '[time: {:.2f}]'.format(t1 - t0)
  463. log += '[size: {}]'.format(img_size)
  464. # print log infor
  465. print(log, flush=True)
  466. t0 = time.time()
  467. # LR Schedule
  468. self.lr_scheduler.step()
  469. def check_second_stage(self):
  470. # set second stage
  471. print('============== Second stage of Training ==============')
  472. self.second_stage = True
  473. # close mosaic augmentation
  474. if self.train_loader.dataset.mosaic_prob > 0.:
  475. print(' - Close < Mosaic Augmentation > ...')
  476. self.train_loader.dataset.mosaic_prob = 0.
  477. self.heavy_eval = True
  478. # close mixup augmentation
  479. if self.train_loader.dataset.mixup_prob > 0.:
  480. print(' - Close < Mixup Augmentation > ...')
  481. self.train_loader.dataset.mixup_prob = 0.
  482. self.heavy_eval = True
  483. # close rotation augmentation
  484. if 'degrees' in self.trans_cfg.keys() and self.trans_cfg['degrees'] > 0.0:
  485. print(' - Close < degress of rotation > ...')
  486. self.trans_cfg['degrees'] = 0.0
  487. if 'shear' in self.trans_cfg.keys() and self.trans_cfg['shear'] > 0.0:
  488. print(' - Close < shear of rotation >...')
  489. self.trans_cfg['shear'] = 0.0
  490. if 'perspective' in self.trans_cfg.keys() and self.trans_cfg['perspective'] > 0.0:
  491. print(' - Close < perspective of rotation > ...')
  492. self.trans_cfg['perspective'] = 0.0
  493. # build a new transform for second stage
  494. print(' - Rebuild transforms ...')
  495. self.train_transform, self.trans_cfg = build_transform(
  496. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
  497. self.train_loader.dataset.transform = self.train_transform
  498. def refine_targets(self, targets, min_box_size):
  499. # rescale targets
  500. for tgt in targets:
  501. boxes = tgt["boxes"].clone()
  502. labels = tgt["labels"].clone()
  503. # refine tgt
  504. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  505. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  506. keep = (min_tgt_size >= min_box_size)
  507. tgt["boxes"] = boxes[keep]
  508. tgt["labels"] = labels[keep]
  509. return targets
  510. def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
  511. """
  512. Deployed for Multi scale trick.
  513. """
  514. if isinstance(stride, int):
  515. max_stride = stride
  516. elif isinstance(stride, list):
  517. max_stride = max(stride)
  518. # During training phase, the shape of input image is square.
  519. old_img_size = images.shape[-1]
  520. new_img_size = random.randrange(old_img_size * multi_scale_range[0], old_img_size * multi_scale_range[1] + max_stride)
  521. new_img_size = new_img_size // max_stride * max_stride # size
  522. if new_img_size / old_img_size != 1:
  523. # interpolate
  524. images = torch.nn.functional.interpolate(
  525. input=images,
  526. size=new_img_size,
  527. mode='bilinear',
  528. align_corners=False)
  529. # rescale targets
  530. for tgt in targets:
  531. boxes = tgt["boxes"].clone()
  532. labels = tgt["labels"].clone()
  533. boxes = torch.clamp(boxes, 0, old_img_size)
  534. # rescale box
  535. boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
  536. boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
  537. # refine tgt
  538. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  539. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  540. keep = (min_tgt_size >= min_box_size)
  541. tgt["boxes"] = boxes[keep]
  542. tgt["labels"] = labels[keep]
  543. return images, targets, new_img_size
  544. # RTMDet-syle Trainer
  545. class RTMTrainer(object):
  546. def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
  547. # ------------------- basic parameters -------------------
  548. self.args = args
  549. self.epoch = 0
  550. self.best_map = -1.
  551. self.device = device
  552. self.criterion = criterion
  553. self.world_size = world_size
  554. self.no_aug_epoch = args.no_aug_epoch
  555. self.clip_grad = 35
  556. self.heavy_eval = False
  557. self.second_stage = False
  558. """
  559. The below hyperparameters refer to RTMDet: https://github.com/open-mmlab/mmyolo/tree/main/configs/rtmdet.
  560. """
  561. self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 5e-2, 'lr0': 0.001}
  562. self.ema_dict = {'ema_decay': 0.9998, 'ema_tau': 2000}
  563. self.lr_schedule_dict = {'scheduler': 'linear', 'lrf': 0.01}
  564. self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}
  565. # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
  566. self.data_cfg = data_cfg
  567. self.model_cfg = model_cfg
  568. self.trans_cfg = trans_cfg
  569. # ---------------------------- Build Transform ----------------------------
  570. self.train_transform, self.trans_cfg = build_transform(
  571. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
  572. self.val_transform, _ = build_transform(
  573. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)
  574. # ---------------------------- Build Dataset & Dataloader ----------------------------
  575. self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
  576. self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())
  577. # ---------------------------- Build Evaluator ----------------------------
  578. self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)
  579. # ---------------------------- Build Grad. Scaler ----------------------------
  580. self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)
  581. # ---------------------------- Build Optimizer ----------------------------
  582. self.optimizer_dict['lr0'] *= self.args.batch_size / 64
  583. self.optimizer, self.start_epoch = build_yolo_optimizer(self.optimizer_dict, model, self.args.resume)
  584. # ---------------------------- Build LR Scheduler ----------------------------
  585. self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, self.args.max_epoch)
  586. self.lr_scheduler.last_epoch = self.start_epoch - 1 # do not move
  587. if self.args.resume:
  588. self.lr_scheduler.step()
  589. # ---------------------------- Build Model-EMA ----------------------------
  590. if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
  591. print('Build ModelEMA ...')
  592. self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
  593. else:
  594. self.model_ema = None
  595. def train(self, model):
  596. for epoch in range(self.start_epoch, self.args.max_epoch):
  597. if self.args.distributed:
  598. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  599. # check second stage
  600. if epoch >= (self.args.max_epoch - self.no_aug_epoch - 1) and not self.second_stage:
  601. self.check_second_stage()
  602. # train one epoch
  603. self.epoch = epoch
  604. self.train_one_epoch(model)
  605. # eval one epoch
  606. if self.heavy_eval:
  607. model_eval = model.module if self.args.distributed else model
  608. self.eval(model_eval)
  609. else:
  610. model_eval = model.module if self.args.distributed else model
  611. if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
  612. self.eval(model_eval)
  613. def eval(self, model):
  614. # chech model
  615. model_eval = model if self.model_ema is None else self.model_ema.ema
  616. # path to save model
  617. path_to_save = os.path.join(self.args.save_folder, self.args.dataset, self.args.model)
  618. os.makedirs(path_to_save, exist_ok=True)
  619. if distributed_utils.is_main_process():
  620. # check evaluator
  621. if self.evaluator is None:
  622. print('No evaluator ... save model and go on training.')
  623. print('Saving state, epoch: {}'.format(self.epoch + 1))
  624. weight_name = '{}_no_eval.pth'.format(self.args.model)
  625. checkpoint_path = os.path.join(path_to_save, weight_name)
  626. torch.save({'model': model_eval.state_dict(),
  627. 'mAP': -1.,
  628. 'optimizer': self.optimizer.state_dict(),
  629. 'epoch': self.epoch,
  630. 'args': self.args},
  631. checkpoint_path)
  632. else:
  633. print('eval ...')
  634. # set eval mode
  635. model_eval.trainable = False
  636. model_eval.eval()
  637. # evaluate
  638. with torch.no_grad():
  639. self.evaluator.evaluate(model_eval)
  640. # save model
  641. cur_map = self.evaluator.map
  642. if cur_map > self.best_map:
  643. # update best-map
  644. self.best_map = cur_map
  645. # save model
  646. print('Saving state, epoch:', self.epoch + 1)
  647. weight_name = '{}_best.pth'.format(self.args.model)
  648. checkpoint_path = os.path.join(path_to_save, weight_name)
  649. torch.save({'model': model_eval.state_dict(),
  650. 'mAP': round(self.best_map*100, 1),
  651. 'optimizer': self.optimizer.state_dict(),
  652. 'epoch': self.epoch,
  653. 'args': self.args},
  654. checkpoint_path)
  655. # set train mode.
  656. model_eval.trainable = True
  657. model_eval.train()
  658. if self.args.distributed:
  659. # wait for all processes to synchronize
  660. dist.barrier()
def train_one_epoch(self, model):
    """Train for one epoch: linear LR warmup, AMP forward/backward with a
    grad scaler, optional grad clipping, EMA update, and periodic logging.
    Steps the LR scheduler once at the end of the epoch."""
    # basic parameters
    epoch_size = len(self.train_loader)
    img_size = self.args.img_size
    t0 = time.time()
    # total number of warmup iterations (warmup epochs * iters per epoch)
    nw = epoch_size * self.args.wp_epoch
    # Train one epoch
    for iter_i, (images, targets) in enumerate(self.train_loader):
        # global iteration index across all epochs
        ni = iter_i + self.epoch * epoch_size
        # Warmup: linearly interpolate lr (and momentum) over the first nw iters
        if ni <= nw:
            xi = [0, nw] # x interp
            for j, x in enumerate(self.optimizer.param_groups):
                # bias lr (group 0) falls from warmup_bias_lr to its scheduled
                # value; all other groups rise from 0.0 to theirs
                x['lr'] = np.interp(
                    ni, xi, [self.warmup_dict['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
                if 'momentum' in x:
                    # NOTE(review): self.optimizer_dict['momentum'] is None for
                    # adamw (see __init__ above this view); this relies on adamw
                    # param groups not exposing a 'momentum' key — confirm.
                    x['momentum'] = np.interp(ni, xi, [self.warmup_dict['warmup_momentum'], self.optimizer_dict['momentum']])
        # To device; uint8 [0, 255] -> float [0, 1]
        images = images.to(self.device, non_blocking=True).float() / 255.
        # Multi scale: randomly resize the batch and rescale its targets
        if self.args.multi_scale:
            images, targets, img_size = self.rescale_image_targets(
                images, targets, self.model_cfg['stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
        else:
            targets = self.refine_targets(targets, self.args.min_box_size)
        # Visualize train targets (debugging aid)
        if self.args.vis_tgt:
            vis_data(images*255, targets)
        # Inference under autocast when fp16 is enabled
        with torch.cuda.amp.autocast(enabled=self.args.fp16):
            outputs = model(images)
            # Compute loss
            loss_dict = self.criterion(outputs=outputs, targets=targets, epoch=self.epoch)
            losses = loss_dict['losses']
            # reduced copy (averaged across ranks) used only for logging
            loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
        # Backward through the grad scaler
        self.scaler.scale(losses).backward()
        # Optimize
        if self.clip_grad > 0:
            # unscale gradients before clipping so the norm is in true scale
            self.scaler.unscale_(self.optimizer)
            # clip gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.clip_grad)
        # optimizer.step (scaler skips it on inf/nan grads)
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()
        # ema
        if self.model_ema is not None:
            self.model_ema.update(model)
        # Logs: every 10 iterations, main process only
        if distributed_utils.is_main_process() and iter_i % 10 == 0:
            t1 = time.time()
            cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
            # basic infor
            log = '[Epoch: {}/{}]'.format(self.epoch+1, self.args.max_epoch)
            log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
            # NOTE(review): indexes param group 2 — assumes the yolo optimizer
            # builds at least 3 param groups; verify against the builder.
            log += '[lr: {:.6f}]'.format(cur_lr[2])
            # loss infor
            for k in loss_dict_reduced.keys():
                log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
            # other infor
            log += '[time: {:.2f}]'.format(t1 - t0)
            log += '[size: {}]'.format(img_size)
            # print log infor
            print(log, flush=True)
            t0 = time.time()
    # LR Schedule (once per epoch)
    self.lr_scheduler.step()
  731. def check_second_stage(self):
  732. # set second stage
  733. print('============== Second stage of Training ==============')
  734. self.second_stage = True
  735. # close mosaic augmentation
  736. if self.train_loader.dataset.mosaic_prob > 0.:
  737. print(' - Close < Mosaic Augmentation > ...')
  738. self.train_loader.dataset.mosaic_prob = 0.
  739. self.heavy_eval = True
  740. # close mixup augmentation
  741. if self.train_loader.dataset.mixup_prob > 0.:
  742. print(' - Close < Mixup Augmentation > ...')
  743. self.train_loader.dataset.mixup_prob = 0.
  744. self.heavy_eval = True
  745. # close rotation augmentation
  746. if 'degrees' in self.trans_cfg.keys() and self.trans_cfg['degrees'] > 0.0:
  747. print(' - Close < degress of rotation > ...')
  748. self.trans_cfg['degrees'] = 0.0
  749. if 'shear' in self.trans_cfg.keys() and self.trans_cfg['shear'] > 0.0:
  750. print(' - Close < shear of rotation >...')
  751. self.trans_cfg['shear'] = 0.0
  752. if 'perspective' in self.trans_cfg.keys() and self.trans_cfg['perspective'] > 0.0:
  753. print(' - Close < perspective of rotation > ...')
  754. self.trans_cfg['perspective'] = 0.0
  755. # build a new transform for second stage
  756. print(' - Rebuild transforms ...')
  757. self.train_transform, self.trans_cfg = build_transform(
  758. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
  759. self.train_loader.dataset.transform = self.train_transform
  760. def refine_targets(self, targets, min_box_size):
  761. # rescale targets
  762. for tgt in targets:
  763. boxes = tgt["boxes"].clone()
  764. labels = tgt["labels"].clone()
  765. # refine tgt
  766. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  767. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  768. keep = (min_tgt_size >= min_box_size)
  769. tgt["boxes"] = boxes[keep]
  770. tgt["labels"] = labels[keep]
  771. return targets
  772. def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
  773. """
  774. Deployed for Multi scale trick.
  775. """
  776. if isinstance(stride, int):
  777. max_stride = stride
  778. elif isinstance(stride, list):
  779. max_stride = max(stride)
  780. # During training phase, the shape of input image is square.
  781. old_img_size = images.shape[-1]
  782. new_img_size = random.randrange(old_img_size * multi_scale_range[0], old_img_size * multi_scale_range[1] + max_stride)
  783. new_img_size = new_img_size // max_stride * max_stride # size
  784. if new_img_size / old_img_size != 1:
  785. # interpolate
  786. images = torch.nn.functional.interpolate(
  787. input=images,
  788. size=new_img_size,
  789. mode='bilinear',
  790. align_corners=False)
  791. # rescale targets
  792. for tgt in targets:
  793. boxes = tgt["boxes"].clone()
  794. labels = tgt["labels"].clone()
  795. boxes = torch.clamp(boxes, 0, old_img_size)
  796. # rescale box
  797. boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
  798. boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
  799. # refine tgt
  800. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  801. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  802. keep = (min_tgt_size >= min_box_size)
  803. tgt["boxes"] = boxes[keep]
  804. tgt["labels"] = labels[keep]
  805. return images, targets, new_img_size
  806. # Trainer for DETR
  807. class DetrTrainer(object):
  808. def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
  809. # ------------------- basic parameters -------------------
  810. self.args = args
  811. self.epoch = 0
  812. self.best_map = -1.
  813. self.last_opt_step = 0
  814. self.no_aug_epoch = args.no_aug_epoch
  815. self.clip_grad = -1
  816. self.device = device
  817. self.criterion = criterion
  818. self.world_size = world_size
  819. self.second_stage = False
  820. self.heavy_eval = False
  821. self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 1e-4, 'lr0': 0.001, 'backbone_lr_raio': 0.1}
  822. self.ema_dict = {'ema_decay': 0.9998, 'ema_tau': 2000}
  823. self.lr_schedule_dict = {'scheduler': 'linear', 'lrf': 0.01}
  824. self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}
  825. # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
  826. self.data_cfg = data_cfg
  827. self.model_cfg = model_cfg
  828. self.trans_cfg = trans_cfg
  829. # ---------------------------- Build Transform ----------------------------
  830. self.train_transform, self.trans_cfg = build_transform(
  831. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
  832. self.val_transform, _ = build_transform(
  833. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)
  834. # ---------------------------- Build Dataset & Dataloader ----------------------------
  835. self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
  836. self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())
  837. # ---------------------------- Build Evaluator ----------------------------
  838. self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)
  839. # ---------------------------- Build Grad. Scaler ----------------------------
  840. self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)
  841. # ---------------------------- Build Optimizer ----------------------------
  842. self.optimizer_dict['lr0'] *= self.args.batch_size / 16.
  843. self.optimizer, self.start_epoch = build_detr_optimizer(self.optimizer_dict, model, self.args.resume)
  844. # ---------------------------- Build LR Scheduler ----------------------------
  845. self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, self.args.max_epoch)
  846. self.lr_scheduler.last_epoch = self.start_epoch - 1 # do not move
  847. if self.args.resume:
  848. self.lr_scheduler.step()
  849. # ---------------------------- Build Model-EMA ----------------------------
  850. if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
  851. print('Build ModelEMA ...')
  852. self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
  853. else:
  854. self.model_ema = None
  855. def check_second_stage(self):
  856. # set second stage
  857. print('============== Second stage of Training ==============')
  858. self.second_stage = True
  859. # close mosaic augmentation
  860. if self.train_loader.dataset.mosaic_prob > 0.:
  861. print(' - Close < Mosaic Augmentation > ...')
  862. self.train_loader.dataset.mosaic_prob = 0.
  863. self.heavy_eval = True
  864. # close mixup augmentation
  865. if self.train_loader.dataset.mixup_prob > 0.:
  866. print(' - Close < Mixup Augmentation > ...')
  867. self.train_loader.dataset.mixup_prob = 0.
  868. self.heavy_eval = True
  869. # close rotation augmentation
  870. if 'degrees' in self.trans_cfg.keys() and self.trans_cfg['degrees'] > 0.0:
  871. print(' - Close < degress of rotation > ...')
  872. self.trans_cfg['degrees'] = 0.0
  873. if 'shear' in self.trans_cfg.keys() and self.trans_cfg['shear'] > 0.0:
  874. print(' - Close < shear of rotation >...')
  875. self.trans_cfg['shear'] = 0.0
  876. if 'perspective' in self.trans_cfg.keys() and self.trans_cfg['perspective'] > 0.0:
  877. print(' - Close < perspective of rotation > ...')
  878. self.trans_cfg['perspective'] = 0.0
  879. # build a new transform for second stage
  880. print(' - Rebuild transforms ...')
  881. self.train_transform, self.trans_cfg = build_transform(
  882. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
  883. self.train_loader.dataset.transform = self.train_transform
  884. def train(self, model):
  885. for epoch in range(self.start_epoch, self.args.max_epoch):
  886. if self.args.distributed:
  887. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  888. # check second stage
  889. if epoch >= (self.args.max_epoch - self.no_aug_epoch - 1) and not self.second_stage:
  890. self.check_second_stage()
  891. # train one epoch
  892. self.epoch = epoch
  893. self.train_one_epoch(model)
  894. # eval one epoch
  895. if self.heavy_eval:
  896. model_eval = model.module if self.args.distributed else model
  897. self.eval(model_eval)
  898. else:
  899. model_eval = model.module if self.args.distributed else model
  900. if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
  901. self.eval(model_eval)
  902. def eval(self, model):
  903. # chech model
  904. model_eval = model if self.model_ema is None else self.model_ema.ema
  905. # path to save model
  906. path_to_save = os.path.join(self.args.save_folder, self.args.dataset, self.args.model)
  907. os.makedirs(path_to_save, exist_ok=True)
  908. if distributed_utils.is_main_process():
  909. # check evaluator
  910. if self.evaluator is None:
  911. print('No evaluator ... save model and go on training.')
  912. print('Saving state, epoch: {}'.format(self.epoch + 1))
  913. weight_name = '{}_no_eval.pth'.format(self.args.model)
  914. checkpoint_path = os.path.join(path_to_save, weight_name)
  915. torch.save({'model': model_eval.state_dict(),
  916. 'mAP': -1.,
  917. 'optimizer': self.optimizer.state_dict(),
  918. 'epoch': self.epoch,
  919. 'args': self.args},
  920. checkpoint_path)
  921. else:
  922. print('eval ...')
  923. # set eval mode
  924. model_eval.trainable = False
  925. model_eval.eval()
  926. # evaluate
  927. with torch.no_grad():
  928. self.evaluator.evaluate(model_eval)
  929. # save model
  930. cur_map = self.evaluator.map
  931. if cur_map > self.best_map:
  932. # update best-map
  933. self.best_map = cur_map
  934. # save model
  935. print('Saving state, epoch:', self.epoch + 1)
  936. weight_name = '{}_best.pth'.format(self.args.model)
  937. checkpoint_path = os.path.join(path_to_save, weight_name)
  938. torch.save({'model': model_eval.state_dict(),
  939. 'mAP': round(self.best_map*100, 1),
  940. 'optimizer': self.optimizer.state_dict(),
  941. 'epoch': self.epoch,
  942. 'args': self.args},
  943. checkpoint_path)
  944. # set train mode.
  945. model_eval.trainable = True
  946. model_eval.train()
  947. if self.args.distributed:
  948. # wait for all processes to synchronize
  949. dist.barrier()
  950. def train_one_epoch(self, model):
  951. # basic parameters
  952. epoch_size = len(self.train_loader)
  953. img_size = self.args.img_size
  954. t0 = time.time()
  955. nw = epoch_size * self.args.wp_epoch
  956. # train one epoch
  957. for iter_i, (images, targets) in enumerate(self.train_loader):
  958. ni = iter_i + self.epoch * epoch_size
  959. # Warmup
  960. if ni <= nw:
  961. xi = [0, nw] # x interp
  962. for j, x in enumerate(self.optimizer.param_groups):
  963. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  964. x['lr'] = np.interp(
  965. ni, xi, [0.0, x['initial_lr'] * self.lf(self.epoch)])
  966. if 'momentum' in x:
  967. x['momentum'] = np.interp(ni, xi, [self.model_cfg['warmup_momentum'], self.model_cfg['momentum']])
  968. # To device
  969. images = images.to(self.device, non_blocking=True).float() / 255.
  970. # Multi scale
  971. if self.args.multi_scale:
  972. images, targets, img_size = self.rescale_image_targets(
  973. images, targets, self.model_cfg['stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
  974. else:
  975. targets = self.refine_targets(targets, self.args.min_box_size, img_size)
  976. # Visualize targets
  977. if self.args.vis_tgt:
  978. vis_data(images*255, targets)
  979. # Inference
  980. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  981. outputs = model(images)
  982. # Compute loss
  983. loss_dict = self.criterion(outputs=outputs, targets=targets, epoch=self.epoch)
  984. losses = loss_dict['losses']
  985. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  986. # Backward
  987. self.scaler.scale(losses).backward()
  988. # Optimize
  989. if self.clip_grad > 0:
  990. # unscale gradients
  991. self.scaler.unscale_(self.optimizer)
  992. # clip gradients
  993. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.clip_grad)
  994. self.scaler.step(self.optimizer)
  995. self.scaler.update()
  996. self.optimizer.zero_grad()
  997. # Model EMA
  998. if self.model_ema is not None:
  999. self.model_ema.update(model)
  1000. self.last_opt_step = ni
  1001. # Log
  1002. if distributed_utils.is_main_process() and iter_i % 10 == 0:
  1003. t1 = time.time()
  1004. cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
  1005. # basic infor
  1006. log = '[Epoch: {}/{}]'.format(self.epoch+1, self.args.max_epoch)
  1007. log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
  1008. log += '[lr: {:.6f}]'.format(cur_lr[0])
  1009. # loss infor
  1010. for k in loss_dict_reduced.keys():
  1011. if self.args.vis_aux_loss:
  1012. log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
  1013. else:
  1014. if k in ['loss_cls', 'loss_bbox', 'loss_giou', 'losses']:
  1015. log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
  1016. # other infor
  1017. log += '[time: {:.2f}]'.format(t1 - t0)
  1018. log += '[size: {}]'.format(img_size)
  1019. # print log infor
  1020. print(log, flush=True)
  1021. t0 = time.time()
  1022. # LR Scheduler
  1023. self.lr_scheduler.step()
  1024. def refine_targets(self, targets, min_box_size, img_size):
  1025. # rescale targets
  1026. for tgt in targets:
  1027. boxes = tgt["boxes"]
  1028. labels = tgt["labels"]
  1029. # refine tgt
  1030. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  1031. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  1032. keep = (min_tgt_size >= min_box_size)
  1033. # xyxy -> cxcywh
  1034. new_boxes = torch.zeros_like(boxes)
  1035. new_boxes[..., :2] = (boxes[..., 2:] + boxes[..., :2]) * 0.5
  1036. new_boxes[..., 2:] = (boxes[..., 2:] - boxes[..., :2])
  1037. # normalize
  1038. new_boxes /= img_size
  1039. del boxes
  1040. tgt["boxes"] = new_boxes[keep]
  1041. tgt["labels"] = labels[keep]
  1042. return targets
  1043. def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
  1044. """
  1045. Deployed for Multi scale trick.
  1046. """
  1047. if isinstance(stride, int):
  1048. max_stride = stride
  1049. elif isinstance(stride, list):
  1050. max_stride = max(stride)
  1051. # During training phase, the shape of input image is square.
  1052. old_img_size = images.shape[-1]
  1053. new_img_size = random.randrange(old_img_size * multi_scale_range[0], old_img_size * multi_scale_range[1] + max_stride)
  1054. new_img_size = new_img_size // max_stride * max_stride # size
  1055. if new_img_size / old_img_size != 1:
  1056. # interpolate
  1057. images = torch.nn.functional.interpolate(
  1058. input=images,
  1059. size=new_img_size,
  1060. mode='bilinear',
  1061. align_corners=False)
  1062. # rescale targets
  1063. for tgt in targets:
  1064. boxes = tgt["boxes"].clone()
  1065. labels = tgt["labels"].clone()
  1066. boxes = torch.clamp(boxes, 0, old_img_size)
  1067. # rescale box
  1068. boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
  1069. boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
  1070. # refine tgt
  1071. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  1072. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  1073. keep = (min_tgt_size >= min_box_size)
  1074. # xyxy -> cxcywh
  1075. new_boxes = torch.zeros_like(boxes)
  1076. new_boxes[..., :2] = (boxes[..., 2:] + boxes[..., :2]) * 0.5
  1077. new_boxes[..., 2:] = (boxes[..., 2:] - boxes[..., :2])
  1078. # normalize
  1079. new_boxes /= new_img_size
  1080. del boxes
  1081. tgt["boxes"] = new_boxes[keep]
  1082. tgt["labels"] = labels[keep]
  1083. return images, targets, new_img_size
  1084. # Build Trainer
  1085. def build_trainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
  1086. if model_cfg['trainer_type'] == 'yolov8':
  1087. return Yolov8Trainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
  1088. elif model_cfg['trainer_type'] == 'yolox':
  1089. return YoloxTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
  1090. elif model_cfg['trainer_type'] == 'rtmdet':
  1091. return RTMTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
  1092. elif model_cfg['trainer_type'] == 'detr':
  1093. return DetrTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
  1094. else:
  1095. raise NotImplementedError