# engine.py
  1. import torch
  2. import torch.distributed as dist
  3. import time
  4. import os
  5. import numpy as np
  6. import random
  7. # ----------------- Extra Components -----------------
  8. from utils import distributed_utils
  9. from utils.misc import ModelEMA, CollateFunc, build_dataloader
  10. from utils.vis_tools import vis_data
  11. # ----------------- Evaluator Components -----------------
  12. from evaluator.build import build_evluator
  13. # ----------------- Optimizer & LrScheduler Components -----------------
  14. from utils.solver.optimizer import build_yolo_optimizer, build_detr_optimizer
  15. from utils.solver.lr_scheduler import build_lr_scheduler
  16. # ----------------- Dataset Components -----------------
  17. from dataset.build import build_dataset, build_transform
  18. # YOLOv8-style Trainer
  19. class Yolov8Trainer(object):
  20. def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
  21. # ------------------- basic parameters -------------------
  22. self.args = args
  23. self.epoch = 0
  24. self.best_map = -1.
  25. self.last_opt_step = 0
  26. self.no_aug_epoch = args.no_aug_epoch
  27. self.clip_grad = 10
  28. self.device = device
  29. self.criterion = criterion
  30. self.world_size = world_size
  31. self.heavy_eval = False
  32. self.second_stage = False
  33. # ---------------------------- Hyperparameters refer to YOLOv8 ----------------------------
  34. self.optimizer_dict = {'optimizer': 'sgd', 'momentum': 0.937, 'weight_decay': 5e-4, 'lr0': 0.01}
  35. self.ema_dict = {'ema_decay': 0.9999, 'ema_tau': 2000}
  36. self.lr_schedule_dict = {'scheduler': 'linear', 'lrf': 0.01}
  37. self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}
  38. # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
  39. self.data_cfg = data_cfg
  40. self.model_cfg = model_cfg
  41. self.trans_cfg = trans_cfg
  42. # ---------------------------- Build Transform ----------------------------
  43. self.train_transform, self.trans_cfg = build_transform(
  44. args=args, trans_config=self.trans_cfg, max_stride=model_cfg['max_stride'], is_train=True)
  45. self.val_transform, _ = build_transform(
  46. args=args, trans_config=self.trans_cfg, max_stride=model_cfg['max_stride'], is_train=False)
  47. # ---------------------------- Build Dataset & Dataloader ----------------------------
  48. self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
  49. self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())
  50. # ---------------------------- Build Evaluator ----------------------------
  51. self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)
  52. # ---------------------------- Build Grad. Scaler ----------------------------
  53. self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)
  54. # ---------------------------- Build Optimizer ----------------------------
  55. accumulate = max(1, round(64 / self.args.batch_size))
  56. print('Grad Accumulate: {}'.format(accumulate))
  57. self.optimizer_dict['weight_decay'] *= self.args.batch_size * accumulate / 64
  58. self.optimizer, self.start_epoch = build_yolo_optimizer(self.optimizer_dict, model, self.args.resume)
  59. # ---------------------------- Build LR Scheduler ----------------------------
  60. self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, self.args.max_epoch)
  61. self.lr_scheduler.last_epoch = self.start_epoch - 1 # do not move
  62. if self.args.resume:
  63. self.lr_scheduler.step()
  64. # ---------------------------- Build Model-EMA ----------------------------
  65. if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
  66. print('Build ModelEMA ...')
  67. self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
  68. else:
  69. self.model_ema = None
  70. def check_second_stage(self):
  71. # set second stage
  72. print('============== Second stage of Training ==============')
  73. self.second_stage = True
  74. # close mosaic augmentation
  75. if self.train_loader.dataset.mosaic_prob > 0.:
  76. print(' - Close < Mosaic Augmentation > ...')
  77. self.train_loader.dataset.mosaic_prob = 0.
  78. self.heavy_eval = True
  79. # close mixup augmentation
  80. if self.train_loader.dataset.mixup_prob > 0.:
  81. print(' - Close < Mixup Augmentation > ...')
  82. self.train_loader.dataset.mixup_prob = 0.
  83. self.heavy_eval = True
  84. # close rotation augmentation
  85. if 'degrees' in self.trans_cfg.keys() and self.trans_cfg['degrees'] > 0.0:
  86. print(' - Close < degress of rotation > ...')
  87. self.trans_cfg['degrees'] = 0.0
  88. if 'shear' in self.trans_cfg.keys() and self.trans_cfg['shear'] > 0.0:
  89. print(' - Close < shear of rotation >...')
  90. self.trans_cfg['shear'] = 0.0
  91. if 'perspective' in self.trans_cfg.keys() and self.trans_cfg['perspective'] > 0.0:
  92. print(' - Close < perspective of rotation > ...')
  93. self.trans_cfg['perspective'] = 0.0
  94. # close random affine
  95. if 'translate' in self.trans_cfg.keys() and self.trans_cfg['translate'] > 0.0:
  96. print(' - Close < translate of affine > ...')
  97. self.trans_cfg['translate'] = 0.0
  98. if 'scale' in self.trans_cfg.keys():
  99. print(' - Close < scale of affine >...')
  100. self.trans_cfg['scale'] = [1.0, 1.0]
  101. # build a new transform for second stage
  102. print(' - Rebuild transforms ...')
  103. self.train_transform, self.trans_cfg = build_transform(
  104. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
  105. self.train_loader.dataset.transform = self.train_transform
  106. def train(self, model):
  107. for epoch in range(self.start_epoch, self.args.max_epoch):
  108. if self.args.distributed:
  109. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  110. # check second stage
  111. if epoch >= (self.args.max_epoch - self.no_aug_epoch - 1) and not self.second_stage:
  112. self.check_second_stage()
  113. # train one epoch
  114. self.epoch = epoch
  115. self.train_one_epoch(model)
  116. # eval one epoch
  117. if self.heavy_eval:
  118. model_eval = model.module if self.args.distributed else model
  119. self.eval(model_eval)
  120. else:
  121. model_eval = model.module if self.args.distributed else model
  122. if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
  123. self.eval(model_eval)
  124. def eval(self, model):
  125. # chech model
  126. model_eval = model if self.model_ema is None else self.model_ema.ema
  127. # path to save model
  128. path_to_save = os.path.join(self.args.save_folder, self.args.dataset, self.args.model)
  129. os.makedirs(path_to_save, exist_ok=True)
  130. if distributed_utils.is_main_process():
  131. # check evaluator
  132. if self.evaluator is None:
  133. print('No evaluator ... save model and go on training.')
  134. print('Saving state, epoch: {}'.format(self.epoch + 1))
  135. weight_name = '{}_no_eval.pth'.format(self.args.model)
  136. checkpoint_path = os.path.join(path_to_save, weight_name)
  137. torch.save({'model': model_eval.state_dict(),
  138. 'mAP': -1.,
  139. 'optimizer': self.optimizer.state_dict(),
  140. 'epoch': self.epoch,
  141. 'args': self.args},
  142. checkpoint_path)
  143. else:
  144. print('eval ...')
  145. # set eval mode
  146. model_eval.trainable = False
  147. model_eval.eval()
  148. # evaluate
  149. with torch.no_grad():
  150. self.evaluator.evaluate(model_eval)
  151. # save model
  152. cur_map = self.evaluator.map
  153. if cur_map > self.best_map:
  154. # update best-map
  155. self.best_map = cur_map
  156. # save model
  157. print('Saving state, epoch:', self.epoch + 1)
  158. weight_name = '{}_best.pth'.format(self.args.model)
  159. checkpoint_path = os.path.join(path_to_save, weight_name)
  160. torch.save({'model': model_eval.state_dict(),
  161. 'mAP': round(self.best_map*100, 1),
  162. 'optimizer': self.optimizer.state_dict(),
  163. 'epoch': self.epoch,
  164. 'args': self.args},
  165. checkpoint_path)
  166. # set train mode.
  167. model_eval.trainable = True
  168. model_eval.train()
  169. if self.args.distributed:
  170. # wait for all processes to synchronize
  171. dist.barrier()
  172. def train_one_epoch(self, model):
  173. # basic parameters
  174. epoch_size = len(self.train_loader)
  175. img_size = self.args.img_size
  176. t0 = time.time()
  177. nw = epoch_size * self.args.wp_epoch
  178. accumulate = accumulate = max(1, round(64 / self.args.batch_size))
  179. # train one epoch
  180. for iter_i, (images, targets) in enumerate(self.train_loader):
  181. ni = iter_i + self.epoch * epoch_size
  182. # Warmup
  183. if ni <= nw:
  184. xi = [0, nw] # x interp
  185. accumulate = max(1, np.interp(ni, xi, [1, 64 / self.args.batch_size]).round())
  186. for j, x in enumerate(self.optimizer.param_groups):
  187. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  188. x['lr'] = np.interp(
  189. ni, xi, [self.warmup_dict['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
  190. if 'momentum' in x:
  191. x['momentum'] = np.interp(ni, xi, [self.warmup_dict['warmup_momentum'], self.optimizer_dict['momentum']])
  192. # to device
  193. images = images.to(self.device, non_blocking=True).float() / 255.
  194. # Multi scale
  195. if self.args.multi_scale:
  196. images, targets, img_size = self.rescale_image_targets(
  197. images, targets, self.model_cfg['stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
  198. else:
  199. targets = self.refine_targets(targets, self.args.min_box_size)
  200. # visualize train targets
  201. if self.args.vis_tgt:
  202. vis_data(images*255, targets)
  203. # inference
  204. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  205. outputs = model(images)
  206. # loss
  207. loss_dict = self.criterion(outputs=outputs, targets=targets, epoch=self.epoch)
  208. losses = loss_dict['losses']
  209. losses *= images.shape[0] # loss * bs
  210. # reduce
  211. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  212. # gradient averaged between devices in DDP mode
  213. losses *= distributed_utils.get_world_size()
  214. # backward
  215. self.scaler.scale(losses).backward()
  216. # Optimize
  217. if ni - self.last_opt_step >= accumulate:
  218. if self.clip_grad > 0:
  219. # unscale gradients
  220. self.scaler.unscale_(self.optimizer)
  221. # clip gradients
  222. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.clip_grad)
  223. # optimizer.step
  224. self.scaler.step(self.optimizer)
  225. self.scaler.update()
  226. self.optimizer.zero_grad()
  227. # ema
  228. if self.model_ema is not None:
  229. self.model_ema.update(model)
  230. self.last_opt_step = ni
  231. # display
  232. if distributed_utils.is_main_process() and iter_i % 10 == 0:
  233. t1 = time.time()
  234. cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
  235. # basic infor
  236. log = '[Epoch: {}/{}]'.format(self.epoch+1, self.args.max_epoch)
  237. log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
  238. log += '[lr: {:.6f}]'.format(cur_lr[2])
  239. # loss infor
  240. for k in loss_dict_reduced.keys():
  241. log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
  242. # other infor
  243. log += '[time: {:.2f}]'.format(t1 - t0)
  244. log += '[size: {}]'.format(img_size)
  245. # print log infor
  246. print(log, flush=True)
  247. t0 = time.time()
  248. self.lr_scheduler.step()
  249. def refine_targets(self, targets, min_box_size):
  250. # rescale targets
  251. for tgt in targets:
  252. boxes = tgt["boxes"].clone()
  253. labels = tgt["labels"].clone()
  254. # refine tgt
  255. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  256. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  257. keep = (min_tgt_size >= min_box_size)
  258. tgt["boxes"] = boxes[keep]
  259. tgt["labels"] = labels[keep]
  260. return targets
  261. def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
  262. """
  263. Deployed for Multi scale trick.
  264. """
  265. if isinstance(stride, int):
  266. max_stride = stride
  267. elif isinstance(stride, list):
  268. max_stride = max(stride)
  269. # During training phase, the shape of input image is square.
  270. old_img_size = images.shape[-1]
  271. new_img_size = random.randrange(old_img_size * multi_scale_range[0], old_img_size * multi_scale_range[1] + max_stride)
  272. new_img_size = new_img_size // max_stride * max_stride # size
  273. if new_img_size / old_img_size != 1:
  274. # interpolate
  275. images = torch.nn.functional.interpolate(
  276. input=images,
  277. size=new_img_size,
  278. mode='bilinear',
  279. align_corners=False)
  280. # rescale targets
  281. for tgt in targets:
  282. boxes = tgt["boxes"].clone()
  283. labels = tgt["labels"].clone()
  284. boxes = torch.clamp(boxes, 0, old_img_size)
  285. # rescale box
  286. boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
  287. boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
  288. # refine tgt
  289. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  290. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  291. keep = (min_tgt_size >= min_box_size)
  292. tgt["boxes"] = boxes[keep]
  293. tgt["labels"] = labels[keep]
  294. return images, targets, new_img_size
  295. # YOLOX-syle Trainer
  296. class YoloxTrainer(object):
  297. def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
  298. # ------------------- basic parameters -------------------
  299. self.args = args
  300. self.epoch = 0
  301. self.best_map = -1.
  302. self.device = device
  303. self.criterion = criterion
  304. self.world_size = world_size
  305. self.no_aug_epoch = args.no_aug_epoch
  306. self.heavy_eval = False
  307. self.second_stage = False
  308. # ---------------------------- Hyperparameters refer to YOLOX ----------------------------
  309. self.optimizer_dict = {'optimizer': 'sgd', 'momentum': 0.9, 'weight_decay': 5e-4, 'lr0': 0.01}
  310. self.ema_dict = {'ema_decay': 0.9999, 'ema_tau': 2000}
  311. self.lr_schedule_dict = {'scheduler': 'cosine', 'lrf': 0.05}
  312. self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}
  313. # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
  314. self.data_cfg = data_cfg
  315. self.model_cfg = model_cfg
  316. self.trans_cfg = trans_cfg
  317. # ---------------------------- Build Transform ----------------------------
  318. self.train_transform, self.trans_cfg = build_transform(
  319. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
  320. self.val_transform, _ = build_transform(
  321. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)
  322. # ---------------------------- Build Dataset & Dataloader ----------------------------
  323. self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
  324. self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())
  325. # ---------------------------- Build Evaluator ----------------------------
  326. self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)
  327. # ---------------------------- Build Grad. Scaler ----------------------------
  328. self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)
  329. # ---------------------------- Build Optimizer ----------------------------
  330. self.optimizer_dict['lr0'] *= self.args.batch_size / 64
  331. self.optimizer, self.start_epoch = build_yolo_optimizer(self.optimizer_dict, model, self.args.resume)
  332. # ---------------------------- Build LR Scheduler ----------------------------
  333. self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, self.args.max_epoch - self.no_aug_epoch)
  334. self.lr_scheduler.last_epoch = self.start_epoch - 1 # do not move
  335. if self.args.resume:
  336. self.lr_scheduler.step()
  337. # ---------------------------- Build Model-EMA ----------------------------
  338. if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
  339. print('Build ModelEMA ...')
  340. self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
  341. else:
  342. self.model_ema = None
  343. def train(self, model):
  344. for epoch in range(self.start_epoch, self.args.max_epoch):
  345. if self.args.distributed:
  346. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  347. # check second stage
  348. if epoch >= (self.args.max_epoch - self.no_aug_epoch - 1) and not self.second_stage:
  349. self.check_second_stage()
  350. # train one epoch
  351. self.epoch = epoch
  352. self.train_one_epoch(model)
  353. # eval one epoch
  354. if self.heavy_eval:
  355. model_eval = model.module if self.args.distributed else model
  356. self.eval(model_eval)
  357. else:
  358. model_eval = model.module if self.args.distributed else model
  359. if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
  360. self.eval(model_eval)
  361. def eval(self, model):
  362. # chech model
  363. model_eval = model if self.model_ema is None else self.model_ema.ema
  364. # path to save model
  365. path_to_save = os.path.join(self.args.save_folder, self.args.dataset, self.args.model)
  366. os.makedirs(path_to_save, exist_ok=True)
  367. if distributed_utils.is_main_process():
  368. # check evaluator
  369. if self.evaluator is None:
  370. print('No evaluator ... save model and go on training.')
  371. print('Saving state, epoch: {}'.format(self.epoch + 1))
  372. weight_name = '{}_no_eval.pth'.format(self.args.model)
  373. checkpoint_path = os.path.join(path_to_save, weight_name)
  374. torch.save({'model': model_eval.state_dict(),
  375. 'mAP': -1.,
  376. 'optimizer': self.optimizer.state_dict(),
  377. 'epoch': self.epoch,
  378. 'args': self.args},
  379. checkpoint_path)
  380. else:
  381. print('eval ...')
  382. # set eval mode
  383. model_eval.trainable = False
  384. model_eval.eval()
  385. # evaluate
  386. with torch.no_grad():
  387. self.evaluator.evaluate(model_eval)
  388. # save model
  389. cur_map = self.evaluator.map
  390. if cur_map > self.best_map:
  391. # update best-map
  392. self.best_map = cur_map
  393. # save model
  394. print('Saving state, epoch:', self.epoch + 1)
  395. weight_name = '{}_best.pth'.format(self.args.model)
  396. checkpoint_path = os.path.join(path_to_save, weight_name)
  397. torch.save({'model': model_eval.state_dict(),
  398. 'mAP': round(self.best_map*100, 1),
  399. 'optimizer': self.optimizer.state_dict(),
  400. 'epoch': self.epoch,
  401. 'args': self.args},
  402. checkpoint_path)
  403. # set train mode.
  404. model_eval.trainable = True
  405. model_eval.train()
  406. if self.args.distributed:
  407. # wait for all processes to synchronize
  408. dist.barrier()
  409. def train_one_epoch(self, model):
  410. # basic parameters
  411. epoch_size = len(self.train_loader)
  412. img_size = self.args.img_size
  413. t0 = time.time()
  414. nw = epoch_size * self.args.wp_epoch
  415. # Train one epoch
  416. for iter_i, (images, targets) in enumerate(self.train_loader):
  417. ni = iter_i + self.epoch * epoch_size
  418. # Warmup
  419. if ni <= nw:
  420. xi = [0, nw] # x interp
  421. for j, x in enumerate(self.optimizer.param_groups):
  422. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  423. x['lr'] = np.interp(
  424. ni, xi, [self.warmup_dict['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
  425. if 'momentum' in x:
  426. x['momentum'] = np.interp(ni, xi, [self.warmup_dict['warmup_momentum'], self.optimizer_dict['momentum']])
  427. # To device
  428. images = images.to(self.device, non_blocking=True).float() / 255.
  429. # Multi scale
  430. if self.args.multi_scale and ni % 10 == 0:
  431. images, targets, img_size = self.rescale_image_targets(
  432. images, targets, self.model_cfg['stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
  433. else:
  434. targets = self.refine_targets(targets, self.args.min_box_size)
  435. # Visualize train targets
  436. if self.args.vis_tgt:
  437. vis_data(images*255, targets)
  438. # Inference
  439. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  440. outputs = model(images)
  441. # Compute loss
  442. loss_dict = self.criterion(outputs=outputs, targets=targets, epoch=self.epoch)
  443. losses = loss_dict['losses']
  444. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  445. # Backward
  446. self.scaler.scale(losses).backward()
  447. # Optimize
  448. self.scaler.step(self.optimizer)
  449. self.scaler.update()
  450. self.optimizer.zero_grad()
  451. # ema
  452. if self.model_ema is not None:
  453. self.model_ema.update(model)
  454. # Logs
  455. if distributed_utils.is_main_process() and iter_i % 10 == 0:
  456. t1 = time.time()
  457. cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
  458. # basic infor
  459. log = '[Epoch: {}/{}]'.format(self.epoch+1, self.args.max_epoch)
  460. log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
  461. log += '[lr: {:.6f}]'.format(cur_lr[2])
  462. # loss infor
  463. for k in loss_dict_reduced.keys():
  464. log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
  465. # other infor
  466. log += '[time: {:.2f}]'.format(t1 - t0)
  467. log += '[size: {}]'.format(img_size)
  468. # print log infor
  469. print(log, flush=True)
  470. t0 = time.time()
  471. # LR Schedule
  472. if not self.second_stage:
  473. self.lr_scheduler.step()
  474. def check_second_stage(self):
  475. # set second stage
  476. print('============== Second stage of Training ==============')
  477. self.second_stage = True
  478. # close mosaic augmentation
  479. if self.train_loader.dataset.mosaic_prob > 0.:
  480. print(' - Close < Mosaic Augmentation > ...')
  481. self.train_loader.dataset.mosaic_prob = 0.
  482. self.heavy_eval = True
  483. # close mixup augmentation
  484. if self.train_loader.dataset.mixup_prob > 0.:
  485. print(' - Close < Mixup Augmentation > ...')
  486. self.train_loader.dataset.mixup_prob = 0.
  487. self.heavy_eval = True
  488. # close rotation augmentation
  489. if 'degrees' in self.trans_cfg.keys() and self.trans_cfg['degrees'] > 0.0:
  490. print(' - Close < degress of rotation > ...')
  491. self.trans_cfg['degrees'] = 0.0
  492. if 'shear' in self.trans_cfg.keys() and self.trans_cfg['shear'] > 0.0:
  493. print(' - Close < shear of rotation >...')
  494. self.trans_cfg['shear'] = 0.0
  495. if 'perspective' in self.trans_cfg.keys() and self.trans_cfg['perspective'] > 0.0:
  496. print(' - Close < perspective of rotation > ...')
  497. self.trans_cfg['perspective'] = 0.0
  498. # close random affine
  499. if 'translate' in self.trans_cfg.keys() and self.trans_cfg['translate'] > 0.0:
  500. print(' - Close < translate of affine > ...')
  501. self.trans_cfg['translate'] = 0.0
  502. if 'scale' in self.trans_cfg.keys():
  503. print(' - Close < scale of affine >...')
  504. self.trans_cfg['scale'] = [1.0, 1.0]
  505. # build a new transform for second stage
  506. print(' - Rebuild transforms ...')
  507. self.train_transform, self.trans_cfg = build_transform(
  508. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
  509. self.train_loader.dataset.transform = self.train_transform
  510. def refine_targets(self, targets, min_box_size):
  511. # rescale targets
  512. for tgt in targets:
  513. boxes = tgt["boxes"].clone()
  514. labels = tgt["labels"].clone()
  515. # refine tgt
  516. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  517. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  518. keep = (min_tgt_size >= min_box_size)
  519. tgt["boxes"] = boxes[keep]
  520. tgt["labels"] = labels[keep]
  521. return targets
  522. def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
  523. """
  524. Deployed for Multi scale trick.
  525. """
  526. if isinstance(stride, int):
  527. max_stride = stride
  528. elif isinstance(stride, list):
  529. max_stride = max(stride)
  530. # During training phase, the shape of input image is square.
  531. old_img_size = images.shape[-1]
  532. new_img_size = random.randrange(old_img_size * multi_scale_range[0], old_img_size * multi_scale_range[1] + max_stride)
  533. new_img_size = new_img_size // max_stride * max_stride # size
  534. if new_img_size / old_img_size != 1:
  535. # interpolate
  536. images = torch.nn.functional.interpolate(
  537. input=images,
  538. size=new_img_size,
  539. mode='bilinear',
  540. align_corners=False)
  541. # rescale targets
  542. for tgt in targets:
  543. boxes = tgt["boxes"].clone()
  544. labels = tgt["labels"].clone()
  545. boxes = torch.clamp(boxes, 0, old_img_size)
  546. # rescale box
  547. boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
  548. boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
  549. # refine tgt
  550. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  551. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  552. keep = (min_tgt_size >= min_box_size)
  553. tgt["boxes"] = boxes[keep]
  554. tgt["labels"] = labels[keep]
  555. return images, targets, new_img_size
  556. # RTMDet-syle Trainer
  557. class RTMTrainer(object):
  558. def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
  559. # ------------------- basic parameters -------------------
  560. self.args = args
  561. self.epoch = 0
  562. self.best_map = -1.
  563. self.device = device
  564. self.criterion = criterion
  565. self.world_size = world_size
  566. self.no_aug_epoch = args.no_aug_epoch
  567. self.clip_grad = 35
  568. self.heavy_eval = False
  569. self.second_stage = False
  570. # ---------------------------- Hyperparameters refer to RTMDet ----------------------------
  571. self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 5e-2, 'lr0': 0.001}
  572. self.ema_dict = {'ema_decay': 0.9998, 'ema_tau': 2000}
  573. self.lr_schedule_dict = {'scheduler': 'linear', 'lrf': 0.01}
  574. self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}
  575. # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
  576. self.data_cfg = data_cfg
  577. self.model_cfg = model_cfg
  578. self.trans_cfg = trans_cfg
  579. # ---------------------------- Build Transform ----------------------------
  580. self.train_transform, self.trans_cfg = build_transform(
  581. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
  582. self.val_transform, _ = build_transform(
  583. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)
  584. # ---------------------------- Build Dataset & Dataloader ----------------------------
  585. self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
  586. self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())
  587. # ---------------------------- Build Evaluator ----------------------------
  588. self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)
  589. # ---------------------------- Build Grad. Scaler ----------------------------
  590. self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)
  591. # ---------------------------- Build Optimizer ----------------------------
  592. self.optimizer_dict['lr0'] *= self.args.batch_size / 64
  593. self.optimizer, self.start_epoch = build_yolo_optimizer(self.optimizer_dict, model, self.args.resume)
  594. # ---------------------------- Build LR Scheduler ----------------------------
  595. self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, self.args.max_epoch - self.no_aug_epoch)
  596. self.lr_scheduler.last_epoch = self.start_epoch - 1 # do not move
  597. if self.args.resume:
  598. self.lr_scheduler.step()
  599. # ---------------------------- Build Model-EMA ----------------------------
  600. if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
  601. print('Build ModelEMA ...')
  602. self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
  603. else:
  604. self.model_ema = None
  605. def train(self, model):
  606. for epoch in range(self.start_epoch, self.args.max_epoch):
  607. if self.args.distributed:
  608. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  609. # check second stage
  610. if epoch >= (self.args.max_epoch - self.no_aug_epoch - 1) and not self.second_stage:
  611. self.check_second_stage()
  612. # train one epoch
  613. self.epoch = epoch
  614. self.train_one_epoch(model)
  615. # eval one epoch
  616. if self.heavy_eval:
  617. model_eval = model.module if self.args.distributed else model
  618. self.eval(model_eval)
  619. else:
  620. model_eval = model.module if self.args.distributed else model
  621. if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
  622. self.eval(model_eval)
  623. def eval(self, model):
  624. # chech model
  625. model_eval = model if self.model_ema is None else self.model_ema.ema
  626. # path to save model
  627. path_to_save = os.path.join(self.args.save_folder, self.args.dataset, self.args.model)
  628. os.makedirs(path_to_save, exist_ok=True)
  629. if distributed_utils.is_main_process():
  630. # check evaluator
  631. if self.evaluator is None:
  632. print('No evaluator ... save model and go on training.')
  633. print('Saving state, epoch: {}'.format(self.epoch + 1))
  634. weight_name = '{}_no_eval.pth'.format(self.args.model)
  635. checkpoint_path = os.path.join(path_to_save, weight_name)
  636. torch.save({'model': model_eval.state_dict(),
  637. 'mAP': -1.,
  638. 'optimizer': self.optimizer.state_dict(),
  639. 'epoch': self.epoch,
  640. 'args': self.args},
  641. checkpoint_path)
  642. else:
  643. print('eval ...')
  644. # set eval mode
  645. model_eval.trainable = False
  646. model_eval.eval()
  647. # evaluate
  648. with torch.no_grad():
  649. self.evaluator.evaluate(model_eval)
  650. # save model
  651. cur_map = self.evaluator.map
  652. if cur_map > self.best_map:
  653. # update best-map
  654. self.best_map = cur_map
  655. # save model
  656. print('Saving state, epoch:', self.epoch + 1)
  657. weight_name = '{}_best.pth'.format(self.args.model)
  658. checkpoint_path = os.path.join(path_to_save, weight_name)
  659. torch.save({'model': model_eval.state_dict(),
  660. 'mAP': round(self.best_map*100, 1),
  661. 'optimizer': self.optimizer.state_dict(),
  662. 'epoch': self.epoch,
  663. 'args': self.args},
  664. checkpoint_path)
  665. # set train mode.
  666. model_eval.trainable = True
  667. model_eval.train()
  668. if self.args.distributed:
  669. # wait for all processes to synchronize
  670. dist.barrier()
  671. def train_one_epoch(self, model):
  672. # basic parameters
  673. epoch_size = len(self.train_loader)
  674. img_size = self.args.img_size
  675. t0 = time.time()
  676. nw = epoch_size * self.args.wp_epoch
  677. # Train one epoch
  678. for iter_i, (images, targets) in enumerate(self.train_loader):
  679. ni = iter_i + self.epoch * epoch_size
  680. # Warmup
  681. if ni <= nw:
  682. xi = [0, nw] # x interp
  683. for j, x in enumerate(self.optimizer.param_groups):
  684. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  685. x['lr'] = np.interp(
  686. ni, xi, [self.warmup_dict['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
  687. if 'momentum' in x:
  688. x['momentum'] = np.interp(ni, xi, [self.warmup_dict['warmup_momentum'], self.optimizer_dict['momentum']])
  689. # To device
  690. images = images.to(self.device, non_blocking=True).float() / 255.
  691. # Multi scale
  692. if self.args.multi_scale:
  693. images, targets, img_size = self.rescale_image_targets(
  694. images, targets, self.model_cfg['stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
  695. else:
  696. targets = self.refine_targets(targets, self.args.min_box_size)
  697. # Visualize train targets
  698. if self.args.vis_tgt:
  699. vis_data(images*255, targets)
  700. # Inference
  701. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  702. outputs = model(images)
  703. # Compute loss
  704. loss_dict = self.criterion(outputs=outputs, targets=targets, epoch=self.epoch)
  705. losses = loss_dict['losses']
  706. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  707. # Backward
  708. self.scaler.scale(losses).backward()
  709. # Optimize
  710. if self.clip_grad > 0:
  711. # unscale gradients
  712. self.scaler.unscale_(self.optimizer)
  713. # clip gradients
  714. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.clip_grad)
  715. # optimizer.step
  716. self.scaler.step(self.optimizer)
  717. self.scaler.update()
  718. self.optimizer.zero_grad()
  719. # ema
  720. if self.model_ema is not None:
  721. self.model_ema.update(model)
  722. # Logs
  723. if distributed_utils.is_main_process() and iter_i % 10 == 0:
  724. t1 = time.time()
  725. cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
  726. # basic infor
  727. log = '[Epoch: {}/{}]'.format(self.epoch+1, self.args.max_epoch)
  728. log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
  729. log += '[lr: {:.6f}]'.format(cur_lr[2])
  730. # loss infor
  731. for k in loss_dict_reduced.keys():
  732. log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
  733. # other infor
  734. log += '[time: {:.2f}]'.format(t1 - t0)
  735. log += '[size: {}]'.format(img_size)
  736. # print log infor
  737. print(log, flush=True)
  738. t0 = time.time()
  739. # LR Schedule
  740. if not self.second_stage:
  741. self.lr_scheduler.step()
  742. def check_second_stage(self):
  743. # set second stage
  744. print('============== Second stage of Training ==============')
  745. self.second_stage = True
  746. # close mosaic augmentation
  747. if self.train_loader.dataset.mosaic_prob > 0.:
  748. print(' - Close < Mosaic Augmentation > ...')
  749. self.train_loader.dataset.mosaic_prob = 0.
  750. self.heavy_eval = True
  751. # close mixup augmentation
  752. if self.train_loader.dataset.mixup_prob > 0.:
  753. print(' - Close < Mixup Augmentation > ...')
  754. self.train_loader.dataset.mixup_prob = 0.
  755. self.heavy_eval = True
  756. # close rotation augmentation
  757. if 'degrees' in self.trans_cfg.keys() and self.trans_cfg['degrees'] > 0.0:
  758. print(' - Close < degress of rotation > ...')
  759. self.trans_cfg['degrees'] = 0.0
  760. if 'shear' in self.trans_cfg.keys() and self.trans_cfg['shear'] > 0.0:
  761. print(' - Close < shear of rotation >...')
  762. self.trans_cfg['shear'] = 0.0
  763. if 'perspective' in self.trans_cfg.keys() and self.trans_cfg['perspective'] > 0.0:
  764. print(' - Close < perspective of rotation > ...')
  765. self.trans_cfg['perspective'] = 0.0
  766. # build a new transform for second stage
  767. print(' - Rebuild transforms ...')
  768. self.train_transform, self.trans_cfg = build_transform(
  769. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
  770. self.train_loader.dataset.transform = self.train_transform
  771. def refine_targets(self, targets, min_box_size):
  772. # rescale targets
  773. for tgt in targets:
  774. boxes = tgt["boxes"].clone()
  775. labels = tgt["labels"].clone()
  776. # refine tgt
  777. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  778. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  779. keep = (min_tgt_size >= min_box_size)
  780. tgt["boxes"] = boxes[keep]
  781. tgt["labels"] = labels[keep]
  782. return targets
  783. def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
  784. """
  785. Deployed for Multi scale trick.
  786. """
  787. if isinstance(stride, int):
  788. max_stride = stride
  789. elif isinstance(stride, list):
  790. max_stride = max(stride)
  791. # During training phase, the shape of input image is square.
  792. old_img_size = images.shape[-1]
  793. new_img_size = random.randrange(old_img_size * multi_scale_range[0], old_img_size * multi_scale_range[1] + max_stride)
  794. new_img_size = new_img_size // max_stride * max_stride # size
  795. if new_img_size / old_img_size != 1:
  796. # interpolate
  797. images = torch.nn.functional.interpolate(
  798. input=images,
  799. size=new_img_size,
  800. mode='bilinear',
  801. align_corners=False)
  802. # rescale targets
  803. for tgt in targets:
  804. boxes = tgt["boxes"].clone()
  805. labels = tgt["labels"].clone()
  806. boxes = torch.clamp(boxes, 0, old_img_size)
  807. # rescale box
  808. boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
  809. boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
  810. # refine tgt
  811. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  812. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  813. keep = (min_tgt_size >= min_box_size)
  814. tgt["boxes"] = boxes[keep]
  815. tgt["labels"] = labels[keep]
  816. return images, targets, new_img_size
  817. # Trainer for DETR
  818. class DetrTrainer(object):
  819. def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
  820. # ------------------- basic parameters -------------------
  821. self.args = args
  822. self.epoch = 0
  823. self.best_map = -1.
  824. self.last_opt_step = 0
  825. self.no_aug_epoch = args.no_aug_epoch
  826. self.clip_grad = -1
  827. self.device = device
  828. self.criterion = criterion
  829. self.world_size = world_size
  830. self.second_stage = False
  831. self.heavy_eval = False
  832. self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 1e-4, 'lr0': 0.001, 'backbone_lr_raio': 0.1}
  833. self.ema_dict = {'ema_decay': 0.9998, 'ema_tau': 2000}
  834. self.lr_schedule_dict = {'scheduler': 'linear', 'lrf': 0.01}
  835. self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}
  836. # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
  837. self.data_cfg = data_cfg
  838. self.model_cfg = model_cfg
  839. self.trans_cfg = trans_cfg
  840. # ---------------------------- Build Transform ----------------------------
  841. self.train_transform, self.trans_cfg = build_transform(
  842. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
  843. self.val_transform, _ = build_transform(
  844. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)
  845. # ---------------------------- Build Dataset & Dataloader ----------------------------
  846. self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
  847. self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())
  848. # ---------------------------- Build Evaluator ----------------------------
  849. self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)
  850. # ---------------------------- Build Grad. Scaler ----------------------------
  851. self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)
  852. # ---------------------------- Build Optimizer ----------------------------
  853. self.optimizer_dict['lr0'] *= self.args.batch_size / 16.
  854. self.optimizer, self.start_epoch = build_detr_optimizer(self.optimizer_dict, model, self.args.resume)
  855. # ---------------------------- Build LR Scheduler ----------------------------
  856. self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, self.args.max_epoch)
  857. self.lr_scheduler.last_epoch = self.start_epoch - 1 # do not move
  858. if self.args.resume:
  859. self.lr_scheduler.step()
  860. # ---------------------------- Build Model-EMA ----------------------------
  861. if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
  862. print('Build ModelEMA ...')
  863. self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
  864. else:
  865. self.model_ema = None
  866. def check_second_stage(self):
  867. # set second stage
  868. print('============== Second stage of Training ==============')
  869. self.second_stage = True
  870. # close mosaic augmentation
  871. if self.train_loader.dataset.mosaic_prob > 0.:
  872. print(' - Close < Mosaic Augmentation > ...')
  873. self.train_loader.dataset.mosaic_prob = 0.
  874. self.heavy_eval = True
  875. # close mixup augmentation
  876. if self.train_loader.dataset.mixup_prob > 0.:
  877. print(' - Close < Mixup Augmentation > ...')
  878. self.train_loader.dataset.mixup_prob = 0.
  879. self.heavy_eval = True
  880. # close rotation augmentation
  881. if 'degrees' in self.trans_cfg.keys() and self.trans_cfg['degrees'] > 0.0:
  882. print(' - Close < degress of rotation > ...')
  883. self.trans_cfg['degrees'] = 0.0
  884. if 'shear' in self.trans_cfg.keys() and self.trans_cfg['shear'] > 0.0:
  885. print(' - Close < shear of rotation >...')
  886. self.trans_cfg['shear'] = 0.0
  887. if 'perspective' in self.trans_cfg.keys() and self.trans_cfg['perspective'] > 0.0:
  888. print(' - Close < perspective of rotation > ...')
  889. self.trans_cfg['perspective'] = 0.0
  890. # build a new transform for second stage
  891. print(' - Rebuild transforms ...')
  892. self.train_transform, self.trans_cfg = build_transform(
  893. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
  894. self.train_loader.dataset.transform = self.train_transform
  895. def train(self, model):
  896. for epoch in range(self.start_epoch, self.args.max_epoch):
  897. if self.args.distributed:
  898. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  899. # check second stage
  900. if epoch >= (self.args.max_epoch - self.no_aug_epoch - 1) and not self.second_stage:
  901. self.check_second_stage()
  902. # train one epoch
  903. self.epoch = epoch
  904. self.train_one_epoch(model)
  905. # eval one epoch
  906. if self.heavy_eval:
  907. model_eval = model.module if self.args.distributed else model
  908. self.eval(model_eval)
  909. else:
  910. model_eval = model.module if self.args.distributed else model
  911. if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
  912. self.eval(model_eval)
  913. def eval(self, model):
  914. # chech model
  915. model_eval = model if self.model_ema is None else self.model_ema.ema
  916. # path to save model
  917. path_to_save = os.path.join(self.args.save_folder, self.args.dataset, self.args.model)
  918. os.makedirs(path_to_save, exist_ok=True)
  919. if distributed_utils.is_main_process():
  920. # check evaluator
  921. if self.evaluator is None:
  922. print('No evaluator ... save model and go on training.')
  923. print('Saving state, epoch: {}'.format(self.epoch + 1))
  924. weight_name = '{}_no_eval.pth'.format(self.args.model)
  925. checkpoint_path = os.path.join(path_to_save, weight_name)
  926. torch.save({'model': model_eval.state_dict(),
  927. 'mAP': -1.,
  928. 'optimizer': self.optimizer.state_dict(),
  929. 'epoch': self.epoch,
  930. 'args': self.args},
  931. checkpoint_path)
  932. else:
  933. print('eval ...')
  934. # set eval mode
  935. model_eval.trainable = False
  936. model_eval.eval()
  937. # evaluate
  938. with torch.no_grad():
  939. self.evaluator.evaluate(model_eval)
  940. # save model
  941. cur_map = self.evaluator.map
  942. if cur_map > self.best_map:
  943. # update best-map
  944. self.best_map = cur_map
  945. # save model
  946. print('Saving state, epoch:', self.epoch + 1)
  947. weight_name = '{}_best.pth'.format(self.args.model)
  948. checkpoint_path = os.path.join(path_to_save, weight_name)
  949. torch.save({'model': model_eval.state_dict(),
  950. 'mAP': round(self.best_map*100, 1),
  951. 'optimizer': self.optimizer.state_dict(),
  952. 'epoch': self.epoch,
  953. 'args': self.args},
  954. checkpoint_path)
  955. # set train mode.
  956. model_eval.trainable = True
  957. model_eval.train()
  958. if self.args.distributed:
  959. # wait for all processes to synchronize
  960. dist.barrier()
  961. def train_one_epoch(self, model):
  962. # basic parameters
  963. epoch_size = len(self.train_loader)
  964. img_size = self.args.img_size
  965. t0 = time.time()
  966. nw = epoch_size * self.args.wp_epoch
  967. # train one epoch
  968. for iter_i, (images, targets) in enumerate(self.train_loader):
  969. ni = iter_i + self.epoch * epoch_size
  970. # Warmup
  971. if ni <= nw:
  972. xi = [0, nw] # x interp
  973. for j, x in enumerate(self.optimizer.param_groups):
  974. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  975. x['lr'] = np.interp(
  976. ni, xi, [0.0, x['initial_lr'] * self.lf(self.epoch)])
  977. if 'momentum' in x:
  978. x['momentum'] = np.interp(ni, xi, [self.model_cfg['warmup_momentum'], self.model_cfg['momentum']])
  979. # To device
  980. images = images.to(self.device, non_blocking=True).float() / 255.
  981. # Multi scale
  982. if self.args.multi_scale:
  983. images, targets, img_size = self.rescale_image_targets(
  984. images, targets, self.model_cfg['stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
  985. else:
  986. targets = self.refine_targets(targets, self.args.min_box_size, img_size)
  987. # Visualize targets
  988. if self.args.vis_tgt:
  989. vis_data(images*255, targets)
  990. # Inference
  991. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  992. outputs = model(images)
  993. # Compute loss
  994. loss_dict = self.criterion(outputs=outputs, targets=targets, epoch=self.epoch)
  995. losses = loss_dict['losses']
  996. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  997. # Backward
  998. self.scaler.scale(losses).backward()
  999. # Optimize
  1000. if self.clip_grad > 0:
  1001. # unscale gradients
  1002. self.scaler.unscale_(self.optimizer)
  1003. # clip gradients
  1004. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.clip_grad)
  1005. self.scaler.step(self.optimizer)
  1006. self.scaler.update()
  1007. self.optimizer.zero_grad()
  1008. # Model EMA
  1009. if self.model_ema is not None:
  1010. self.model_ema.update(model)
  1011. self.last_opt_step = ni
  1012. # Log
  1013. if distributed_utils.is_main_process() and iter_i % 10 == 0:
  1014. t1 = time.time()
  1015. cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
  1016. # basic infor
  1017. log = '[Epoch: {}/{}]'.format(self.epoch+1, self.args.max_epoch)
  1018. log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
  1019. log += '[lr: {:.6f}]'.format(cur_lr[0])
  1020. # loss infor
  1021. for k in loss_dict_reduced.keys():
  1022. if self.args.vis_aux_loss:
  1023. log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
  1024. else:
  1025. if k in ['loss_cls', 'loss_bbox', 'loss_giou', 'losses']:
  1026. log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
  1027. # other infor
  1028. log += '[time: {:.2f}]'.format(t1 - t0)
  1029. log += '[size: {}]'.format(img_size)
  1030. # print log infor
  1031. print(log, flush=True)
  1032. t0 = time.time()
  1033. # LR Scheduler
  1034. self.lr_scheduler.step()
  1035. def refine_targets(self, targets, min_box_size, img_size):
  1036. # rescale targets
  1037. for tgt in targets:
  1038. boxes = tgt["boxes"]
  1039. labels = tgt["labels"]
  1040. # refine tgt
  1041. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  1042. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  1043. keep = (min_tgt_size >= min_box_size)
  1044. # xyxy -> cxcywh
  1045. new_boxes = torch.zeros_like(boxes)
  1046. new_boxes[..., :2] = (boxes[..., 2:] + boxes[..., :2]) * 0.5
  1047. new_boxes[..., 2:] = (boxes[..., 2:] - boxes[..., :2])
  1048. # normalize
  1049. new_boxes /= img_size
  1050. del boxes
  1051. tgt["boxes"] = new_boxes[keep]
  1052. tgt["labels"] = labels[keep]
  1053. return targets
  1054. def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
  1055. """
  1056. Deployed for Multi scale trick.
  1057. """
  1058. if isinstance(stride, int):
  1059. max_stride = stride
  1060. elif isinstance(stride, list):
  1061. max_stride = max(stride)
  1062. # During training phase, the shape of input image is square.
  1063. old_img_size = images.shape[-1]
  1064. new_img_size = random.randrange(old_img_size * multi_scale_range[0], old_img_size * multi_scale_range[1] + max_stride)
  1065. new_img_size = new_img_size // max_stride * max_stride # size
  1066. if new_img_size / old_img_size != 1:
  1067. # interpolate
  1068. images = torch.nn.functional.interpolate(
  1069. input=images,
  1070. size=new_img_size,
  1071. mode='bilinear',
  1072. align_corners=False)
  1073. # rescale targets
  1074. for tgt in targets:
  1075. boxes = tgt["boxes"].clone()
  1076. labels = tgt["labels"].clone()
  1077. boxes = torch.clamp(boxes, 0, old_img_size)
  1078. # rescale box
  1079. boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
  1080. boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
  1081. # refine tgt
  1082. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  1083. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  1084. keep = (min_tgt_size >= min_box_size)
  1085. # xyxy -> cxcywh
  1086. new_boxes = torch.zeros_like(boxes)
  1087. new_boxes[..., :2] = (boxes[..., 2:] + boxes[..., :2]) * 0.5
  1088. new_boxes[..., 2:] = (boxes[..., 2:] - boxes[..., :2])
  1089. # normalize
  1090. new_boxes /= new_img_size
  1091. del boxes
  1092. tgt["boxes"] = new_boxes[keep]
  1093. tgt["labels"] = labels[keep]
  1094. return images, targets, new_img_size
  1095. # Build Trainer
  1096. def build_trainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
  1097. if model_cfg['trainer_type'] == 'yolov8':
  1098. return Yolov8Trainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
  1099. elif model_cfg['trainer_type'] == 'yolox':
  1100. return YoloxTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
  1101. elif model_cfg['trainer_type'] == 'rtmdet':
  1102. return RTMTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
  1103. elif model_cfg['trainer_type'] == 'detr':
  1104. return DetrTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
  1105. else:
  1106. raise NotImplementedError