# engine.py (68 KB)
  1. import torch
  2. import torch.distributed as dist
  3. import time
  4. import os
  5. import numpy as np
  6. import random
  7. # ----------------- Extra Components -----------------
  8. from utils import distributed_utils
  9. from utils.misc import ModelEMA, CollateFunc, build_dataloader
  10. from utils.vis_tools import vis_data
  11. # ----------------- Evaluator Components -----------------
  12. from evaluator.build import build_evluator
  13. # ----------------- Optimizer & LrScheduler Components -----------------
  14. from utils.solver.optimizer import build_yolo_optimizer, build_detr_optimizer
  15. from utils.solver.lr_scheduler import build_lr_scheduler
  16. # ----------------- Dataset Components -----------------
  17. from dataset.build import build_dataset, build_transform
  18. # YOLOv8 Trainer
  19. class Yolov8Trainer(object):
  20. def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
  21. # ------------------- basic parameters -------------------
  22. self.args = args
  23. self.epoch = 0
  24. self.best_map = -1.
  25. self.device = device
  26. self.criterion = criterion
  27. self.world_size = world_size
  28. self.heavy_eval = False
  29. self.last_opt_step = 0
  30. self.clip_grad = 10
  31. # weak augmentatino stage
  32. self.second_stage = False
  33. self.third_stage = False
  34. self.second_stage_epoch = args.no_aug_epoch
  35. self.third_stage_epoch = args.no_aug_epoch // 2
  36. # path to save model
  37. self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
  38. os.makedirs(self.path_to_save, exist_ok=True)
  39. # ---------------------------- Hyperparameters refer to YOLOv8 ----------------------------
  40. self.optimizer_dict = {'optimizer': 'sgd', 'momentum': 0.937, 'weight_decay': 5e-4, 'lr0': 0.01}
  41. self.ema_dict = {'ema_decay': 0.9999, 'ema_tau': 2000}
  42. self.lr_schedule_dict = {'scheduler': 'linear', 'lrf': 0.01}
  43. self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}
  44. # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
  45. self.data_cfg = data_cfg
  46. self.model_cfg = model_cfg
  47. self.trans_cfg = trans_cfg
  48. # ---------------------------- Build Transform ----------------------------
  49. self.train_transform, self.trans_cfg = build_transform(
  50. args=args, trans_config=self.trans_cfg, max_stride=model_cfg['max_stride'], is_train=True)
  51. self.val_transform, _ = build_transform(
  52. args=args, trans_config=self.trans_cfg, max_stride=model_cfg['max_stride'], is_train=False)
  53. # ---------------------------- Build Dataset & Dataloader ----------------------------
  54. self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
  55. self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())
  56. # ---------------------------- Build Evaluator ----------------------------
  57. self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)
  58. # ---------------------------- Build Grad. Scaler ----------------------------
  59. self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)
  60. # ---------------------------- Build Optimizer ----------------------------
  61. accumulate = max(1, round(64 / self.args.batch_size))
  62. print('Grad Accumulate: {}'.format(accumulate))
  63. self.optimizer_dict['weight_decay'] *= self.args.batch_size * accumulate / 64
  64. self.optimizer, self.start_epoch = build_yolo_optimizer(self.optimizer_dict, model, self.args.resume)
  65. # ---------------------------- Build LR Scheduler ----------------------------
  66. self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, self.args.max_epoch)
  67. self.lr_scheduler.last_epoch = self.start_epoch - 1 # do not move
  68. if self.args.resume and self.args.resume != 'None':
  69. self.lr_scheduler.step()
  70. # ---------------------------- Build Model-EMA ----------------------------
  71. if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
  72. print('Build ModelEMA ...')
  73. self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
  74. else:
  75. self.model_ema = None
  76. def train(self, model):
  77. for epoch in range(self.start_epoch, self.args.max_epoch):
  78. if self.args.distributed:
  79. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  80. # check second stage
  81. if epoch >= (self.args.max_epoch - self.second_stage_epoch - 1) and not self.second_stage:
  82. self.check_second_stage()
  83. # save model of the last mosaic epoch
  84. weight_name = '{}_last_mosaic_epoch.pth'.format(self.args.model)
  85. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  86. print('Saving state of the last Mosaic epoch-{}.'.format(self.epoch + 1))
  87. torch.save({'model': model.state_dict(),
  88. 'mAP': round(self.evaluator.map*100, 1),
  89. 'optimizer': self.optimizer.state_dict(),
  90. 'epoch': self.epoch,
  91. 'args': self.args},
  92. checkpoint_path)
  93. # check third stage
  94. if epoch >= (self.args.max_epoch - self.third_stage_epoch - 1) and not self.third_stage:
  95. self.check_third_stage()
  96. # save model of the last mosaic epoch
  97. weight_name = '{}_last_weak_augment_epoch.pth'.format(self.args.model)
  98. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  99. print('Saving state of the last weak augment epoch-{}.'.format(self.epoch + 1))
  100. torch.save({'model': model.state_dict(),
  101. 'mAP': round(self.evaluator.map*100, 1),
  102. 'optimizer': self.optimizer.state_dict(),
  103. 'epoch': self.epoch,
  104. 'args': self.args},
  105. checkpoint_path)
  106. # train one epoch
  107. self.epoch = epoch
  108. self.train_one_epoch(model)
  109. # eval one epoch
  110. if self.heavy_eval:
  111. model_eval = model.module if self.args.distributed else model
  112. self.eval(model_eval)
  113. else:
  114. model_eval = model.module if self.args.distributed else model
  115. if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
  116. self.eval(model_eval)
  117. if self.args.debug:
  118. print("For debug mode, we only train 1 epoch")
  119. break
  120. def eval(self, model):
  121. # chech model
  122. model_eval = model if self.model_ema is None else self.model_ema.ema
  123. if distributed_utils.is_main_process():
  124. # check evaluator
  125. if self.evaluator is None:
  126. print('No evaluator ... save model and go on training.')
  127. print('Saving state, epoch: {}'.format(self.epoch + 1))
  128. weight_name = '{}_no_eval.pth'.format(self.args.model)
  129. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  130. torch.save({'model': model_eval.state_dict(),
  131. 'mAP': -1.,
  132. 'optimizer': self.optimizer.state_dict(),
  133. 'epoch': self.epoch,
  134. 'args': self.args},
  135. checkpoint_path)
  136. else:
  137. print('eval ...')
  138. # set eval mode
  139. model_eval.trainable = False
  140. model_eval.eval()
  141. # evaluate
  142. with torch.no_grad():
  143. self.evaluator.evaluate(model_eval)
  144. # save model
  145. cur_map = self.evaluator.map
  146. if cur_map > self.best_map:
  147. # update best-map
  148. self.best_map = cur_map
  149. # save model
  150. print('Saving state, epoch:', self.epoch + 1)
  151. weight_name = '{}_best.pth'.format(self.args.model)
  152. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  153. torch.save({'model': model_eval.state_dict(),
  154. 'mAP': round(self.best_map*100, 1),
  155. 'optimizer': self.optimizer.state_dict(),
  156. 'epoch': self.epoch,
  157. 'args': self.args},
  158. checkpoint_path)
  159. # set train mode.
  160. model_eval.trainable = True
  161. model_eval.train()
  162. if self.args.distributed:
  163. # wait for all processes to synchronize
  164. dist.barrier()
  165. def train_one_epoch(self, model):
  166. # basic parameters
  167. epoch_size = len(self.train_loader)
  168. img_size = self.args.img_size
  169. t0 = time.time()
  170. nw = epoch_size * self.args.wp_epoch
  171. accumulate = accumulate = max(1, round(64 / self.args.batch_size))
  172. # train one epoch
  173. for iter_i, (images, targets) in enumerate(self.train_loader):
  174. ni = iter_i + self.epoch * epoch_size
  175. # Warmup
  176. if ni <= nw:
  177. xi = [0, nw] # x interp
  178. accumulate = max(1, np.interp(ni, xi, [1, 64 / self.args.batch_size]).round())
  179. for j, x in enumerate(self.optimizer.param_groups):
  180. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  181. x['lr'] = np.interp(
  182. ni, xi, [self.warmup_dict['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
  183. if 'momentum' in x:
  184. x['momentum'] = np.interp(ni, xi, [self.warmup_dict['warmup_momentum'], self.optimizer_dict['momentum']])
  185. # to device
  186. images = images.to(self.device, non_blocking=True).float() / 255.
  187. # Multi scale
  188. if self.args.multi_scale:
  189. images, targets, img_size = self.rescale_image_targets(
  190. images, targets, self.model_cfg['stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
  191. else:
  192. targets = self.refine_targets(targets, self.args.min_box_size)
  193. # visualize train targets
  194. if self.args.vis_tgt:
  195. vis_data(images*255, targets)
  196. # inference
  197. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  198. outputs = model(images)
  199. # loss
  200. loss_dict = self.criterion(outputs=outputs, targets=targets, epoch=self.epoch)
  201. losses = loss_dict['losses']
  202. losses *= images.shape[0] # loss * bs
  203. # reduce
  204. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  205. # gradient averaged between devices in DDP mode
  206. losses *= distributed_utils.get_world_size()
  207. # backward
  208. self.scaler.scale(losses).backward()
  209. # Optimize
  210. if ni - self.last_opt_step >= accumulate:
  211. if self.clip_grad > 0:
  212. # unscale gradients
  213. self.scaler.unscale_(self.optimizer)
  214. # clip gradients
  215. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.clip_grad)
  216. # optimizer.step
  217. self.scaler.step(self.optimizer)
  218. self.scaler.update()
  219. self.optimizer.zero_grad()
  220. # ema
  221. if self.model_ema is not None:
  222. self.model_ema.update(model)
  223. self.last_opt_step = ni
  224. # display
  225. if distributed_utils.is_main_process() and iter_i % 10 == 0:
  226. t1 = time.time()
  227. cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
  228. # basic infor
  229. log = '[Epoch: {}/{}]'.format(self.epoch+1, self.args.max_epoch)
  230. log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
  231. log += '[lr: {:.6f}]'.format(cur_lr[2])
  232. # loss infor
  233. for k in loss_dict_reduced.keys():
  234. log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
  235. # other infor
  236. log += '[time: {:.2f}]'.format(t1 - t0)
  237. log += '[size: {}]'.format(img_size)
  238. # print log infor
  239. print(log, flush=True)
  240. t0 = time.time()
  241. if self.args.debug:
  242. print("For debug mode, we only train 1 iteration")
  243. break
  244. self.lr_scheduler.step()
  245. def check_second_stage(self):
  246. # set second stage
  247. print('============== Second stage of Training ==============')
  248. self.second_stage = True
  249. # close mosaic augmentation
  250. if self.train_loader.dataset.mosaic_prob > 0.:
  251. print(' - Close < Mosaic Augmentation > ...')
  252. self.train_loader.dataset.mosaic_prob = 0.
  253. self.heavy_eval = True
  254. # close mixup augmentation
  255. if self.train_loader.dataset.mixup_prob > 0.:
  256. print(' - Close < Mixup Augmentation > ...')
  257. self.train_loader.dataset.mixup_prob = 0.
  258. self.heavy_eval = True
  259. # close rotation augmentation
  260. if 'degrees' in self.trans_cfg.keys() and self.trans_cfg['degrees'] > 0.0:
  261. print(' - Close < degress of rotation > ...')
  262. self.trans_cfg['degrees'] = 0.0
  263. if 'shear' in self.trans_cfg.keys() and self.trans_cfg['shear'] > 0.0:
  264. print(' - Close < shear of rotation >...')
  265. self.trans_cfg['shear'] = 0.0
  266. if 'perspective' in self.trans_cfg.keys() and self.trans_cfg['perspective'] > 0.0:
  267. print(' - Close < perspective of rotation > ...')
  268. self.trans_cfg['perspective'] = 0.0
  269. # build a new transform for second stage
  270. print(' - Rebuild transforms ...')
  271. self.train_transform, self.trans_cfg = build_transform(
  272. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
  273. self.train_loader.dataset.transform = self.train_transform
  274. def check_third_stage(self):
  275. # set third stage
  276. print('============== Third stage of Training ==============')
  277. self.third_stage = True
  278. # close random affine
  279. if 'translate' in self.trans_cfg.keys() and self.trans_cfg['translate'] > 0.0:
  280. print(' - Close < translate of affine > ...')
  281. self.trans_cfg['translate'] = 0.0
  282. if 'scale' in self.trans_cfg.keys():
  283. print(' - Close < scale of affine >...')
  284. self.trans_cfg['scale'] = [1.0, 1.0]
  285. # build a new transform for second stage
  286. print(' - Rebuild transforms ...')
  287. self.train_transform, self.trans_cfg = build_transform(
  288. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
  289. self.train_loader.dataset.transform = self.train_transform
  290. def refine_targets(self, targets, min_box_size):
  291. # rescale targets
  292. for tgt in targets:
  293. boxes = tgt["boxes"].clone()
  294. labels = tgt["labels"].clone()
  295. # refine tgt
  296. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  297. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  298. keep = (min_tgt_size >= min_box_size)
  299. tgt["boxes"] = boxes[keep]
  300. tgt["labels"] = labels[keep]
  301. return targets
  302. def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
  303. """
  304. Deployed for Multi scale trick.
  305. """
  306. if isinstance(stride, int):
  307. max_stride = stride
  308. elif isinstance(stride, list):
  309. max_stride = max(stride)
  310. # During training phase, the shape of input image is square.
  311. old_img_size = images.shape[-1]
  312. new_img_size = random.randrange(old_img_size * multi_scale_range[0], old_img_size * multi_scale_range[1] + max_stride)
  313. new_img_size = new_img_size // max_stride * max_stride # size
  314. if new_img_size / old_img_size != 1:
  315. # interpolate
  316. images = torch.nn.functional.interpolate(
  317. input=images,
  318. size=new_img_size,
  319. mode='bilinear',
  320. align_corners=False)
  321. # rescale targets
  322. for tgt in targets:
  323. boxes = tgt["boxes"].clone()
  324. labels = tgt["labels"].clone()
  325. boxes = torch.clamp(boxes, 0, old_img_size)
  326. # rescale box
  327. boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
  328. boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
  329. # refine tgt
  330. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  331. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  332. keep = (min_tgt_size >= min_box_size)
  333. tgt["boxes"] = boxes[keep]
  334. tgt["labels"] = labels[keep]
  335. return images, targets, new_img_size
  336. # YOLOX Trainer
  337. class YoloxTrainer(object):
  338. def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
  339. # ------------------- basic parameters -------------------
  340. self.args = args
  341. self.epoch = 0
  342. self.best_map = -1.
  343. self.device = device
  344. self.criterion = criterion
  345. self.world_size = world_size
  346. self.grad_accumulate = args.grad_accumulate
  347. self.no_aug_epoch = args.no_aug_epoch
  348. self.heavy_eval = False
  349. # weak augmentatino stage
  350. self.second_stage = False
  351. self.third_stage = False
  352. self.second_stage_epoch = args.no_aug_epoch
  353. self.third_stage_epoch = args.no_aug_epoch // 2
  354. # path to save model
  355. self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
  356. os.makedirs(self.path_to_save, exist_ok=True)
  357. # ---------------------------- Hyperparameters refer to YOLOX ----------------------------
  358. self.optimizer_dict = {'optimizer': 'sgd', 'momentum': 0.9, 'weight_decay': 5e-4, 'lr0': 0.01}
  359. self.ema_dict = {'ema_decay': 0.9999, 'ema_tau': 2000}
  360. self.lr_schedule_dict = {'scheduler': 'cosine', 'lrf': 0.05}
  361. self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}
  362. # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
  363. self.data_cfg = data_cfg
  364. self.model_cfg = model_cfg
  365. self.trans_cfg = trans_cfg
  366. # ---------------------------- Build Transform ----------------------------
  367. self.train_transform, self.trans_cfg = build_transform(
  368. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
  369. self.val_transform, _ = build_transform(
  370. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)
  371. # ---------------------------- Build Dataset & Dataloader ----------------------------
  372. self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
  373. self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())
  374. # ---------------------------- Build Evaluator ----------------------------
  375. self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)
  376. # ---------------------------- Build Grad. Scaler ----------------------------
  377. self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)
  378. # ---------------------------- Build Optimizer ----------------------------
  379. self.optimizer_dict['lr0'] *= self.args.batch_size * self.grad_accumulate / 64
  380. self.optimizer, self.start_epoch = build_yolo_optimizer(self.optimizer_dict, model, self.args.resume)
  381. # ---------------------------- Build LR Scheduler ----------------------------
  382. self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, self.args.max_epoch - self.no_aug_epoch)
  383. self.lr_scheduler.last_epoch = self.start_epoch - 1 # do not move
  384. if self.args.resume and self.args.resume != 'None':
  385. self.lr_scheduler.step()
  386. # ---------------------------- Build Model-EMA ----------------------------
  387. if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
  388. print('Build ModelEMA ...')
  389. self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
  390. else:
  391. self.model_ema = None
  392. def train(self, model):
  393. for epoch in range(self.start_epoch, self.args.max_epoch):
  394. if self.args.distributed:
  395. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  396. # check second stage
  397. if epoch >= (self.args.max_epoch - self.second_stage_epoch - 1) and not self.second_stage:
  398. self.check_second_stage()
  399. # save model of the last mosaic epoch
  400. weight_name = '{}_last_mosaic_epoch.pth'.format(self.args.model)
  401. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  402. print('Saving state of the last Mosaic epoch-{}.'.format(self.epoch + 1))
  403. torch.save({'model': model.state_dict(),
  404. 'mAP': round(self.evaluator.map*100, 1),
  405. 'optimizer': self.optimizer.state_dict(),
  406. 'epoch': self.epoch,
  407. 'args': self.args},
  408. checkpoint_path)
  409. # check third stage
  410. if epoch >= (self.args.max_epoch - self.third_stage_epoch - 1) and not self.third_stage:
  411. self.check_third_stage()
  412. # save model of the last mosaic epoch
  413. weight_name = '{}_last_weak_augment_epoch.pth'.format(self.args.model)
  414. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  415. print('Saving state of the last weak augment epoch-{}.'.format(self.epoch + 1))
  416. torch.save({'model': model.state_dict(),
  417. 'mAP': round(self.evaluator.map*100, 1),
  418. 'optimizer': self.optimizer.state_dict(),
  419. 'epoch': self.epoch,
  420. 'args': self.args},
  421. checkpoint_path)
  422. # train one epoch
  423. self.epoch = epoch
  424. self.train_one_epoch(model)
  425. # eval one epoch
  426. if self.heavy_eval:
  427. model_eval = model.module if self.args.distributed else model
  428. self.eval(model_eval)
  429. else:
  430. model_eval = model.module if self.args.distributed else model
  431. if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
  432. self.eval(model_eval)
  433. if self.args.debug:
  434. print("For debug mode, we only train 1 epoch")
  435. break
  436. def eval(self, model):
  437. # chech model
  438. model_eval = model if self.model_ema is None else self.model_ema.ema
  439. if distributed_utils.is_main_process():
  440. # check evaluator
  441. if self.evaluator is None:
  442. print('No evaluator ... save model and go on training.')
  443. print('Saving state, epoch: {}'.format(self.epoch + 1))
  444. weight_name = '{}_no_eval.pth'.format(self.args.model)
  445. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  446. torch.save({'model': model_eval.state_dict(),
  447. 'mAP': -1.,
  448. 'optimizer': self.optimizer.state_dict(),
  449. 'epoch': self.epoch,
  450. 'args': self.args},
  451. checkpoint_path)
  452. else:
  453. print('eval ...')
  454. # set eval mode
  455. model_eval.trainable = False
  456. model_eval.eval()
  457. # evaluate
  458. with torch.no_grad():
  459. self.evaluator.evaluate(model_eval)
  460. # save model
  461. cur_map = self.evaluator.map
  462. if cur_map > self.best_map:
  463. # update best-map
  464. self.best_map = cur_map
  465. # save model
  466. print('Saving state, epoch:', self.epoch + 1)
  467. weight_name = '{}_best.pth'.format(self.args.model)
  468. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  469. torch.save({'model': model_eval.state_dict(),
  470. 'mAP': round(self.best_map*100, 1),
  471. 'optimizer': self.optimizer.state_dict(),
  472. 'epoch': self.epoch,
  473. 'args': self.args},
  474. checkpoint_path)
  475. # set train mode.
  476. model_eval.trainable = True
  477. model_eval.train()
  478. if self.args.distributed:
  479. # wait for all processes to synchronize
  480. dist.barrier()
  481. def train_one_epoch(self, model):
  482. # basic parameters
  483. epoch_size = len(self.train_loader)
  484. img_size = self.args.img_size
  485. t0 = time.time()
  486. nw = epoch_size * self.args.wp_epoch
  487. # Train one epoch
  488. for iter_i, (images, targets) in enumerate(self.train_loader):
  489. ni = iter_i + self.epoch * epoch_size
  490. # Warmup
  491. if ni <= nw:
  492. xi = [0, nw] # x interp
  493. for j, x in enumerate(self.optimizer.param_groups):
  494. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  495. x['lr'] = np.interp(
  496. ni, xi, [self.warmup_dict['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
  497. if 'momentum' in x:
  498. x['momentum'] = np.interp(ni, xi, [self.warmup_dict['warmup_momentum'], self.optimizer_dict['momentum']])
  499. # To device
  500. images = images.to(self.device, non_blocking=True).float() / 255.
  501. # Multi scale
  502. if self.args.multi_scale and ni % 10 == 0:
  503. images, targets, img_size = self.rescale_image_targets(
  504. images, targets, self.model_cfg['stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
  505. else:
  506. targets = self.refine_targets(targets, self.args.min_box_size)
  507. # Visualize train targets
  508. if self.args.vis_tgt:
  509. vis_data(images*255, targets)
  510. # Inference
  511. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  512. outputs = model(images)
  513. # Compute loss
  514. loss_dict = self.criterion(outputs=outputs, targets=targets, epoch=self.epoch)
  515. losses = loss_dict['losses']
  516. # Grad Accu
  517. if self.grad_accumulate > 1:
  518. losses /= self.grad_accumulate
  519. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  520. # Backward
  521. self.scaler.scale(losses).backward()
  522. # Optimize
  523. if ni % self.grad_accumulate == 0:
  524. self.scaler.step(self.optimizer)
  525. self.scaler.update()
  526. self.optimizer.zero_grad()
  527. # ema
  528. if self.model_ema is not None:
  529. self.model_ema.update(model)
  530. # Logs
  531. if distributed_utils.is_main_process() and iter_i % 10 == 0:
  532. t1 = time.time()
  533. cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
  534. # basic infor
  535. log = '[Epoch: {}/{}]'.format(self.epoch+1, self.args.max_epoch)
  536. log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
  537. log += '[lr: {:.6f}]'.format(cur_lr[2])
  538. # loss infor
  539. for k in loss_dict_reduced.keys():
  540. loss_val = loss_dict_reduced[k]
  541. if k == 'losses':
  542. loss_val *= self.grad_accumulate
  543. log += '[{}: {:.2f}]'.format(k, loss_val)
  544. # other infor
  545. log += '[time: {:.2f}]'.format(t1 - t0)
  546. log += '[size: {}]'.format(img_size)
  547. # print log infor
  548. print(log, flush=True)
  549. t0 = time.time()
  550. if self.args.debug:
  551. print("For debug mode, we only train 1 iteration")
  552. break
  553. # LR Schedule
  554. if not self.second_stage:
  555. self.lr_scheduler.step()
  def check_second_stage(self):
      """Switch the trainer into its second stage (strong augmentation off).

      Sets ``self.second_stage``, zeroes the dataset's mosaic and mixup
      probabilities when they are positive (each such change also turns on
      ``self.heavy_eval``), resets positive ``degrees`` / ``shear`` /
      ``perspective`` entries of ``self.trans_cfg`` to 0.0, and finally
      rebuilds the training transform and installs it on the dataset.
      """
      print('============== Second stage of Training ==============')
      self.second_stage = True
      dataset = self.train_loader.dataset

      # Dataset-level augmentations: mosaic first, then mixup.
      for prob_attr, notice in (
              ('mosaic_prob', ' - Close < Mosaic Augmentation > ...'),
              ('mixup_prob', ' - Close < Mixup Augmentation > ...')):
          if getattr(dataset, prob_attr) > 0.:
              print(notice)
              setattr(dataset, prob_attr, 0.)
              self.heavy_eval = True

      # Transform-config entries that are only reset when present and positive.
      for cfg_key, notice in (
              ('degrees', ' - Close < degress of rotation > ...'),
              ('shear', ' - Close < shear of rotation >...'),
              ('perspective', ' - Close < perspective of rotation > ...')):
          if cfg_key in self.trans_cfg.keys() and self.trans_cfg[cfg_key] > 0.0:
              print(notice)
              self.trans_cfg[cfg_key] = 0.0

      # Rebuild the training transform from the updated config.
      print(' - Rebuild transforms ...')
      rebuilt = build_transform(
          args=self.args,
          trans_config=self.trans_cfg,
          max_stride=self.model_cfg['max_stride'],
          is_train=True)
      self.train_transform, self.trans_cfg = rebuilt
      dataset.transform = self.train_transform
  def check_third_stage(self):
      """Switch the trainer into its third stage (random affine jitter off).

      Sets ``self.third_stage``, resets a positive ``translate`` entry of
      ``self.trans_cfg`` to 0.0, pins an existing ``scale`` entry to
      ``[1.0, 1.0]``, then rebuilds the training transform and installs it
      on the dataset.
      """
      print('============== Third stage of Training ==============')
      self.third_stage = True
      cfg = self.trans_cfg

      # Translation is only reset when it is present and positive.
      has_translate = 'translate' in cfg.keys()
      if has_translate and cfg['translate'] > 0.0:
          print(' - Close < translate of affine > ...')
          cfg['translate'] = 0.0

      # Scale is overwritten whenever the key exists, whatever its value.
      if 'scale' in cfg.keys():
          print(' - Close < scale of affine >...')
          cfg['scale'] = [1.0, 1.0]

      # Rebuild the training transform for the third stage.
      print(' - Rebuild transforms ...')
      rebuilt = build_transform(
          args=self.args,
          trans_config=self.trans_cfg,
          max_stride=self.model_cfg['max_stride'],
          is_train=True)
      self.train_transform, self.trans_cfg = rebuilt
      self.train_loader.dataset.transform = self.train_transform
  def refine_targets(self, targets, min_box_size):
      """Drop boxes whose shorter side is below ``min_box_size``.

      Each element of ``targets`` is a mapping with ``"boxes"`` and
      ``"labels"`` tensors; both entries are replaced in place by filtered
      copies. The width/height computation (last two columns minus first
      two) implies an ``(x1, y1, x2, y2)`` box layout.

      Returns the same ``targets`` object.
      """
      for target in targets:
          box_copy = target["boxes"].clone()
          label_copy = target["labels"].clone()
          # Side lengths, then the shorter of the two for every box.
          side_lengths = box_copy[..., 2:] - box_copy[..., :2]
          shortest_side = side_lengths.min(dim=-1)[0]
          big_enough = shortest_side >= min_box_size
          target["boxes"], target["labels"] = box_copy[big_enough], label_copy[big_enough]
      return targets
  def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=(0.5, 1.5)):
      """Randomly resize a square image batch and its boxes (multi-scale trick).

      Args:
          images: image batch; the size is read from ``images.shape[-1]`` and
              the batch is treated as square (as the original comment states).
          targets: list of mappings with ``"boxes"`` and ``"labels"`` tensors;
              the boxes are indexed as ``(x1, y1, x2, y2)`` columns.
          stride: a single stride or a list/tuple of strides; the largest one
              defines the granularity of the sampled size.
          min_box_size: boxes whose shorter side is smaller after rescaling
              are dropped.
          multi_scale_range: ``(low, high)`` factors applied to the old size.

      Returns:
          ``(images, targets, new_img_size)``; ``targets`` is the same list
          with its entries replaced by the rescaled, filtered tensors.
      """
      # Accept tuples as well as lists; previously a tuple left max_stride unbound.
      if isinstance(stride, (list, tuple)):
          max_stride = max(stride)
      else:
          max_stride = stride
      max_stride = int(max_stride)

      # During training phase, the shape of input image is square.
      old_img_size = images.shape[-1]

      # random.randrange rejects float bounds on Python >= 3.12, so cast first.
      low = int(old_img_size * multi_scale_range[0])
      high = int(old_img_size * multi_scale_range[1] + max_stride)
      new_img_size = random.randrange(low, high)
      new_img_size = new_img_size // max_stride * max_stride  # snap down to a stride multiple

      if new_img_size != old_img_size:
          images = torch.nn.functional.interpolate(
              input=images,
              size=new_img_size,
              mode='bilinear',
              align_corners=False)

      # Rescale the boxes and drop the ones that became too small.
      for tgt in targets:
          boxes = tgt["boxes"].clone()
          labels = tgt["labels"].clone()
          boxes = torch.clamp(boxes, 0, old_img_size)
          boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
          boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
          tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
          min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
          keep = (min_tgt_size >= min_box_size)
          tgt["boxes"] = boxes[keep]
          tgt["labels"] = labels[keep]

      return images, targets, new_img_size
  647. # RTCDet Trainer
  class RTCTrainer(object):
      """Training driver selected for the 'rtcdet' trainer type.

      ``__init__`` builds the transforms, dataset/loader, evaluator, AMP grad
      scaler, optimizer, LR scheduler and optional model EMA. ``train`` then
      runs a staged schedule: full augmentation, a second stage without
      mosaic / mixup / rotation-style jitter, and a third stage that also
      removes translate / scale jitter.
      """

      def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
          """Assemble every training component from ``args`` and the config dicts.

          Args:
              args: parsed options; the fields read here include
                  ``grad_accumulate``, ``no_aug_epoch``, ``save_folder``,
                  ``dataset``, ``model``, ``batch_size``, ``fp16``, ``resume``,
                  ``max_epoch`` and ``ema``.
              data_cfg: dataset configuration forwarded to the dataset and
                  evaluator builders.
              model_cfg: model configuration; ``'max_stride'`` is read here.
              trans_cfg: augmentation configuration (replaced by whatever
                  ``build_transform`` returns).
              device: device used for evaluation and for the input batches.
              model: network handed to the optimizer and EMA builders.
              criterion: loss callable used in ``train_one_epoch``.
              world_size: number of processes; divides the global batch size.
          """
          # ------------------- basic parameters -------------------
          self.args = args
          self.epoch = 0
          self.best_map = -1.
          self.device = device
          self.criterion = criterion
          self.world_size = world_size
          self.grad_accumulate = args.grad_accumulate
          self.clip_grad = 35
          self.heavy_eval = False
          # weak augmentation stages
          self.second_stage = False
          self.third_stage = False
          self.second_stage_epoch = args.no_aug_epoch
          self.third_stage_epoch = args.no_aug_epoch // 2
          # path to save model
          self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
          os.makedirs(self.path_to_save, exist_ok=True)

          # ---------------------------- Hyperparameters refer to RTMDet ----------------------------
          self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 5e-2, 'lr0': 0.001}
          self.ema_dict = {'ema_decay': 0.9998, 'ema_tau': 2000}
          self.lr_schedule_dict = {'scheduler': 'linear', 'lrf': 0.01}
          self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}

          # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
          self.data_cfg = data_cfg
          self.model_cfg = model_cfg
          self.trans_cfg = trans_cfg

          # ---------------------------- Build Transform ----------------------------
          self.train_transform, self.trans_cfg = build_transform(
              args=args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
          self.val_transform, _ = build_transform(
              args=args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)

          # ---------------------------- Build Dataset & Dataloader ----------------------------
          self.dataset, self.dataset_info = build_dataset(
              args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
          self.train_loader = build_dataloader(
              args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())

          # ---------------------------- Build Evaluator ----------------------------
          self.evaluator = build_evluator(args, self.data_cfg, self.val_transform, self.device)

          # ---------------------------- Build Grad. Scaler ----------------------------
          self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)

          # ---------------------------- Build Optimizer ----------------------------
          # Base LR is scaled linearly with the effective batch size (reference: 64).
          self.optimizer_dict['lr0'] *= args.batch_size * self.grad_accumulate / 64
          self.optimizer, self.start_epoch = build_yolo_optimizer(self.optimizer_dict, model, args.resume)

          # ---------------------------- Build LR Scheduler ----------------------------
          self.lr_scheduler, self.lf = build_lr_scheduler(
              self.lr_schedule_dict, self.optimizer, args.max_epoch - args.no_aug_epoch)
          self.lr_scheduler.last_epoch = self.start_epoch - 1  # do not move
          if self.args.resume and self.args.resume != 'None':
              self.lr_scheduler.step()

          # ---------------------------- Build Model-EMA ----------------------------
          if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
              print('Build ModelEMA ...')
              self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
          else:
              self.model_ema = None

      def _current_map(self):
          """Return the evaluator's mAP in percent (one decimal), or -1. without an evaluator."""
          if self.evaluator is None:
              return -1.
          return round(self.evaluator.map * 100, 1)

      def _save_checkpoint(self, model, weight_name, map_value):
          """Write model/optimizer state, epoch and args to ``path_to_save/weight_name``."""
          checkpoint_path = os.path.join(self.path_to_save, weight_name)
          torch.save({'model': model.state_dict(),
                      'mAP': map_value,
                      'optimizer': self.optimizer.state_dict(),
                      'epoch': self.epoch,
                      'args': self.args},
                     checkpoint_path)

      def train(self, model):
          """Run epochs ``start_epoch .. max_epoch - 1`` with staged augmentation.

          A checkpoint is written right after each stage switch, before
          ``self.epoch`` is advanced, so it records the previous epoch.
          Evaluation runs every epoch once ``heavy_eval`` is set, otherwise
          every ``args.eval_epoch`` epochs and on the last epoch. With
          ``args.debug`` the loop stops after one epoch.
          """
          for epoch in range(self.start_epoch, self.args.max_epoch):
              if self.args.distributed:
                  self.train_loader.batch_sampler.sampler.set_epoch(epoch)

              # check second stage
              if epoch >= (self.args.max_epoch - self.second_stage_epoch - 1) and not self.second_stage:
                  self.check_second_stage()
                  # save model of the last mosaic epoch
                  print('Saving state of the last Mosaic epoch-{}.'.format(self.epoch + 1))
                  # _current_map() tolerates a missing evaluator, like eval() does.
                  self._save_checkpoint(
                      model, '{}_last_mosaic_epoch.pth'.format(self.args.model), self._current_map())

              # check third stage
              if epoch >= (self.args.max_epoch - self.third_stage_epoch - 1) and not self.third_stage:
                  self.check_third_stage()
                  # save model of the last weak-augmentation epoch
                  print('Saving state of the last weak augment epoch-{}.'.format(self.epoch + 1))
                  self._save_checkpoint(
                      model, '{}_last_weak_augment_epoch.pth'.format(self.args.model), self._current_map())

              # train one epoch
              self.epoch = epoch
              self.train_one_epoch(model)

              # eval one epoch
              model_eval = model.module if self.args.distributed else model
              if self.heavy_eval or (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
                  self.eval(model_eval)

              if self.args.debug:
                  print("For debug mode, we only train 1 epoch")
                  break

      def eval(self, model):
          """Evaluate on the main process and keep the best checkpoint.

          The EMA weights are evaluated when an EMA model exists, otherwise
          ``model`` itself. Without an evaluator the current weights are
          saved as ``<model>_no_eval.pth`` instead. In distributed runs all
          processes synchronize on a barrier afterwards.
          """
          # check model
          model_eval = model if self.model_ema is None else self.model_ema.ema

          if distributed_utils.is_main_process():
              if self.evaluator is None:
                  print('No evaluator ... save model and go on training.')
                  print('Saving state, epoch: {}'.format(self.epoch + 1))
                  self._save_checkpoint(model_eval, '{}_no_eval.pth'.format(self.args.model), -1.)
              else:
                  print('eval ...')
                  # set eval mode
                  model_eval.trainable = False
                  model_eval.eval()

                  # evaluate
                  with torch.no_grad():
                      self.evaluator.evaluate(model_eval)

                  # save model when the mAP improved
                  cur_map = self.evaluator.map
                  if cur_map > self.best_map:
                      self.best_map = cur_map
                      print('Saving state, epoch:', self.epoch + 1)
                      self._save_checkpoint(
                          model_eval, '{}_best.pth'.format(self.args.model), round(self.best_map * 100, 1))

                  # set train mode.
                  model_eval.trainable = True
                  model_eval.train()

          if self.args.distributed:
              # wait for all processes to synchronize
              dist.barrier()

      def train_one_epoch(self, model):
          """Train for one pass over ``self.train_loader``.

          Per iteration: linear LR warm-up for the first ``wp_epoch`` epochs,
          optional multi-scale resize, forward and loss under autocast,
          scaled backward, and an optimizer/EMA step every
          ``grad_accumulate`` iterations. The main process logs every 10
          iterations. The LR scheduler is stepped once at the end unless the
          second stage has started.
          """
          # basic parameters
          epoch_size = len(self.train_loader)
          img_size = self.args.img_size
          t0 = time.time()
          nw = epoch_size * self.args.wp_epoch
          # Defined up front so logging never hits an unbound name when the
          # first logged iteration is not an optimizer step (e.g. after resume).
          grad_norm = None

          for iter_i, (images, targets) in enumerate(self.train_loader):
              ni = iter_i + self.epoch * epoch_size

              # Warmup
              if ni <= nw:
                  xi = [0, nw]  # x interp
                  for j, group in enumerate(self.optimizer.param_groups):
                      # Group 0 starts from warmup_bias_lr (the original comment calls it
                      # the bias group); every other group rises from 0.0.
                      start_lr = self.warmup_dict['warmup_bias_lr'] if j == 0 else 0.0
                      group['lr'] = np.interp(
                          ni, xi, [start_lr, group['initial_lr'] * self.lf(self.epoch)])
                      if 'momentum' in group:
                          group['momentum'] = np.interp(
                              ni, xi, [self.warmup_dict['warmup_momentum'], self.optimizer_dict['momentum']])

              # To device
              images = images.to(self.device, non_blocking=True).float() / 255.

              # Multi scale
              if self.args.multi_scale:
                  images, targets, img_size = self.rescale_image_targets(
                      images, targets, self.model_cfg['stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
              else:
                  targets = self.refine_targets(targets, self.args.min_box_size)

              # Visualize train targets
              if self.args.vis_tgt:
                  vis_data(images*255, targets)

              # Inference
              with torch.cuda.amp.autocast(enabled=self.args.fp16):
                  outputs = model(images)
                  # Compute loss
                  loss_dict = self.criterion(outputs=outputs, targets=targets, epoch=self.epoch)
                  losses = loss_dict['losses']
                  # Grad Accumulate: in-place so the reduced dict sees the scaled value too.
                  if self.grad_accumulate > 1:
                      losses /= self.grad_accumulate
                  loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)

              # Backward
              self.scaler.scale(losses).backward()

              # Optimize
              if ni % self.grad_accumulate == 0:
                  grad_norm = None
                  if self.clip_grad > 0:
                      # unscale gradients before clipping them
                      self.scaler.unscale_(self.optimizer)
                      grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.clip_grad)
                  # optimizer.step
                  self.scaler.step(self.optimizer)
                  self.scaler.update()
                  self.optimizer.zero_grad()
                  # ema
                  if self.model_ema is not None:
                      self.model_ema.update(model)

              # Logs
              if distributed_utils.is_main_process() and iter_i % 10 == 0:
                  t1 = time.time()
                  cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
                  # basic info
                  log = '[Epoch: {}/{}]'.format(self.epoch+1, self.args.max_epoch)
                  log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
                  log += '[lr: {:.6f}]'.format(cur_lr[2])
                  # loss info
                  for k in loss_dict_reduced.keys():
                      loss_val = loss_dict_reduced[k]
                      if k == 'losses':
                          # undo the accumulation scaling for display
                          loss_val = loss_val * self.grad_accumulate
                      log += '[{}: {:.2f}]'.format(k, loss_val)
                  # other info; grad_norm is None until a clipped step has run
                  if grad_norm is not None:
                      log += '[grad_norm: {:.2f}]'.format(grad_norm)
                  log += '[time: {:.2f}]'.format(t1 - t0)
                  log += '[size: {}]'.format(img_size)
                  print(log, flush=True)
                  t0 = time.time()

              if self.args.debug:
                  print("For debug mode, we only train 1 iteration")
                  break

          # LR Schedule
          if not self.second_stage:
              self.lr_scheduler.step()

      def refine_targets(self, targets, min_box_size):
          """Drop boxes whose shorter side is below ``min_box_size``.

          Each target's ``"boxes"`` / ``"labels"`` entries are replaced in
          place by filtered copies; the same ``targets`` list is returned.
          """
          for tgt in targets:
              boxes = tgt["boxes"].clone()
              labels = tgt["labels"].clone()
              tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
              min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
              keep = (min_tgt_size >= min_box_size)
              tgt["boxes"] = boxes[keep]
              tgt["labels"] = labels[keep]
          return targets

      def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=(0.5, 1.5)):
          """Randomly resize a square image batch and its boxes (multi-scale trick).

          Args:
              images: image batch; the size is read from ``images.shape[-1]``.
              targets: list of mappings with ``"boxes"`` and ``"labels"``.
              stride: a single stride or a list/tuple of strides.
              min_box_size: minimum shorter side a rescaled box must keep.
              multi_scale_range: ``(low, high)`` factors applied to the old size.

          Returns:
              ``(images, targets, new_img_size)``.
          """
          # Accept tuples as well as lists; previously a tuple left max_stride unbound.
          if isinstance(stride, (list, tuple)):
              max_stride = max(stride)
          else:
              max_stride = stride
          max_stride = int(max_stride)

          # During training phase, the shape of input image is square.
          old_img_size = images.shape[-1]

          # random.randrange rejects float bounds on Python >= 3.12, so cast first.
          low = int(old_img_size * multi_scale_range[0])
          high = int(old_img_size * multi_scale_range[1] + max_stride)
          new_img_size = random.randrange(low, high)
          new_img_size = new_img_size // max_stride * max_stride  # snap down to a stride multiple

          if new_img_size != old_img_size:
              images = torch.nn.functional.interpolate(
                  input=images,
                  size=new_img_size,
                  mode='bilinear',
                  align_corners=False)

          # Rescale the boxes and drop the ones that became too small.
          for tgt in targets:
              boxes = tgt["boxes"].clone()
              labels = tgt["labels"].clone()
              boxes = torch.clamp(boxes, 0, old_img_size)
              boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
              boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
              tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
              min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
              keep = (min_tgt_size >= min_box_size)
              tgt["boxes"] = boxes[keep]
              tgt["labels"] = labels[keep]

          return images, targets, new_img_size

      def check_second_stage(self):
          """Enter the second stage: disable mosaic, mixup and rotation-style jitter.

          Disabling mosaic or mixup also turns on ``self.heavy_eval``. The
          training transform is rebuilt and installed on the dataset.
          """
          print('============== Second stage of Training ==============')
          self.second_stage = True

          # close mosaic augmentation
          if self.train_loader.dataset.mosaic_prob > 0.:
              print(' - Close < Mosaic Augmentation > ...')
              self.train_loader.dataset.mosaic_prob = 0.
              self.heavy_eval = True

          # close mixup augmentation
          if self.train_loader.dataset.mixup_prob > 0.:
              print(' - Close < Mixup Augmentation > ...')
              self.train_loader.dataset.mixup_prob = 0.
              self.heavy_eval = True

          # close rotation augmentation
          if 'degrees' in self.trans_cfg.keys() and self.trans_cfg['degrees'] > 0.0:
              print(' - Close < degress of rotation > ...')
              self.trans_cfg['degrees'] = 0.0
          if 'shear' in self.trans_cfg.keys() and self.trans_cfg['shear'] > 0.0:
              print(' - Close < shear of rotation >...')
              self.trans_cfg['shear'] = 0.0
          if 'perspective' in self.trans_cfg.keys() and self.trans_cfg['perspective'] > 0.0:
              print(' - Close < perspective of rotation > ...')
              self.trans_cfg['perspective'] = 0.0

          # build a new transform for second stage
          print(' - Rebuild transforms ...')
          self.train_transform, self.trans_cfg = build_transform(
              args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
          self.train_loader.dataset.transform = self.train_transform

      def check_third_stage(self):
          """Enter the third stage: disable translate / scale jitter and rebuild the transform."""
          print('============== Third stage of Training ==============')
          self.third_stage = True

          # close random affine
          if 'translate' in self.trans_cfg.keys() and self.trans_cfg['translate'] > 0.0:
              print(' - Close < translate of affine > ...')
              self.trans_cfg['translate'] = 0.0
          if 'scale' in self.trans_cfg.keys():
              print(' - Close < scale of affine >...')
              self.trans_cfg['scale'] = [1.0, 1.0]

          # build a new transform for third stage
          print(' - Rebuild transforms ...')
          self.train_transform, self.trans_cfg = build_transform(
              args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
          self.train_loader.dataset.transform = self.train_transform
  966. # RTRDet Trainer
  967. class RTRTrainer(object):
  968. def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
  969. # ------------------- Basic parameters -------------------
  970. self.args = args
  971. self.epoch = 0
  972. self.best_map = -1.
  973. self.device = device
  974. self.criterion = criterion
  975. self.world_size = world_size
  976. self.grad_accumulate = args.grad_accumulate
  977. self.clip_grad = 35
  978. self.heavy_eval = False
  979. # weak augmentatino stage
  980. self.second_stage = False
  981. self.third_stage = False
  982. self.second_stage_epoch = args.no_aug_epoch
  983. self.third_stage_epoch = args.no_aug_epoch // 2
  984. # path to save model
  985. self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
  986. os.makedirs(self.path_to_save, exist_ok=True)
  987. # ---------------------------- Hyperparameters refer to RTMDet ----------------------------
  988. self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 1e-4, 'lr0': 0.0001, 'backbone_lr_ratio': 0.1}
  989. self.ema_dict = {'ema_decay': 0.9998, 'ema_tau': 2000}
  990. self.lr_schedule_dict = {'scheduler': 'cosine', 'lrf': 0.05}
  991. self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}
  992. # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
  993. self.data_cfg = data_cfg
  994. self.model_cfg = model_cfg
  995. self.trans_cfg = trans_cfg
  996. # ---------------------------- Build Transform ----------------------------
  997. self.train_transform, self.trans_cfg = build_transform(
  998. args=args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
  999. self.val_transform, _ = build_transform(
  1000. args=args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)
  1001. # ---------------------------- Build Dataset & Dataloader ----------------------------
  1002. self.dataset, self.dataset_info = build_dataset(args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
  1003. self.train_loader = build_dataloader(args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())
  1004. # ---------------------------- Build Evaluator ----------------------------
  1005. self.evaluator = build_evluator(args, self.data_cfg, self.val_transform, self.device)
  1006. # ---------------------------- Build Grad. Scaler ----------------------------
  1007. self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
  1008. # ---------------------------- Build Optimizer ----------------------------
  1009. self.optimizer_dict['lr0'] *= self.args.batch_size / 16.
  1010. self.optimizer, self.start_epoch = build_detr_optimizer(self.optimizer_dict, model, self.args.resume)
  1011. # ---------------------------- Build LR Scheduler ----------------------------
  1012. self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, args.max_epoch - args.no_aug_epoch)
  1013. self.lr_scheduler.last_epoch = self.start_epoch - 1 # do not move
  1014. if self.args.resume and self.args.resume != 'None':
  1015. self.lr_scheduler.step()
  1016. # ---------------------------- Build Model-EMA ----------------------------
  1017. if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
  1018. print('Build ModelEMA ...')
  1019. self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
  1020. else:
  1021. self.model_ema = None
  1022. def train(self, model):
  1023. for epoch in range(self.start_epoch, self.args.max_epoch):
  1024. if self.args.distributed:
  1025. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  1026. # check second stage
  1027. if epoch >= (self.args.max_epoch - self.second_stage_epoch - 1) and not self.second_stage:
  1028. self.check_second_stage()
  1029. # save model of the last mosaic epoch
  1030. weight_name = '{}_last_mosaic_epoch.pth'.format(self.args.model)
  1031. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  1032. print('Saving state of the last Mosaic epoch-{}.'.format(self.epoch + 1))
  1033. torch.save({'model': model.state_dict(),
  1034. 'mAP': round(self.evaluator.map*100, 1),
  1035. 'optimizer': self.optimizer.state_dict(),
  1036. 'epoch': self.epoch,
  1037. 'args': self.args},
  1038. checkpoint_path)
  1039. # check third stage
  1040. if epoch >= (self.args.max_epoch - self.third_stage_epoch - 1) and not self.third_stage:
  1041. self.check_third_stage()
  1042. # save model of the last mosaic epoch
  1043. weight_name = '{}_last_weak_augment_epoch.pth'.format(self.args.model)
  1044. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  1045. print('Saving state of the last weak augment epoch-{}.'.format(self.epoch + 1))
  1046. torch.save({'model': model.state_dict(),
  1047. 'mAP': round(self.evaluator.map*100, 1),
  1048. 'optimizer': self.optimizer.state_dict(),
  1049. 'epoch': self.epoch,
  1050. 'args': self.args},
  1051. checkpoint_path)
  1052. # train one epoch
  1053. self.epoch = epoch
  1054. self.train_one_epoch(model)
  1055. # eval one epoch
  1056. if self.heavy_eval:
  1057. model_eval = model.module if self.args.distributed else model
  1058. self.eval(model_eval)
  1059. else:
  1060. model_eval = model.module if self.args.distributed else model
  1061. if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
  1062. self.eval(model_eval)
  1063. def eval(self, model):
  1064. # chech model
  1065. model_eval = model if self.model_ema is None else self.model_ema.ema
  1066. if distributed_utils.is_main_process():
  1067. # check evaluator
  1068. if self.evaluator is None:
  1069. print('No evaluator ... save model and go on training.')
  1070. print('Saving state, epoch: {}'.format(self.epoch + 1))
  1071. weight_name = '{}_no_eval.pth'.format(self.args.model)
  1072. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  1073. torch.save({'model': model_eval.state_dict(),
  1074. 'mAP': -1.,
  1075. 'optimizer': self.optimizer.state_dict(),
  1076. 'epoch': self.epoch,
  1077. 'args': self.args},
  1078. checkpoint_path)
  1079. else:
  1080. print('eval ...')
  1081. # set eval mode
  1082. model_eval.trainable = False
  1083. model_eval.eval()
  1084. # evaluate
  1085. with torch.no_grad():
  1086. self.evaluator.evaluate(model_eval)
  1087. # save model
  1088. cur_map = self.evaluator.map
  1089. if cur_map > self.best_map:
  1090. # update best-map
  1091. self.best_map = cur_map
  1092. # save model
  1093. print('Saving state, epoch:', self.epoch + 1)
  1094. weight_name = '{}_best.pth'.format(self.args.model)
  1095. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  1096. torch.save({'model': model_eval.state_dict(),
  1097. 'mAP': round(self.best_map*100, 1),
  1098. 'optimizer': self.optimizer.state_dict(),
  1099. 'epoch': self.epoch,
  1100. 'args': self.args},
  1101. checkpoint_path)
  1102. # set train mode.
  1103. model_eval.trainable = True
  1104. model_eval.train()
  1105. if self.args.distributed:
  1106. # wait for all processes to synchronize
  1107. dist.barrier()
  1108. def train_one_epoch(self, model):
  1109. # basic parameters
  1110. epoch_size = len(self.train_loader)
  1111. img_size = self.args.img_size
  1112. t0 = time.time()
  1113. nw = epoch_size * self.args.wp_epoch
  1114. # Train one epoch
  1115. for iter_i, (images, targets) in enumerate(self.train_loader):
  1116. ni = iter_i + self.epoch * epoch_size
  1117. # Warmup
  1118. if ni <= nw:
  1119. xi = [0, nw] # x interp
  1120. for j, x in enumerate(self.optimizer.param_groups):
  1121. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  1122. x['lr'] = np.interp( ni, xi, [0.0, x['initial_lr'] * self.lf(self.epoch)])
  1123. if 'momentum' in x:
  1124. x['momentum'] = np.interp(ni, xi, [self.warmup_dict['warmup_momentum'], self.optimizer_dict['momentum']])
  1125. # To device
  1126. images = images.to(self.device, non_blocking=True).float() / 255.
  1127. # Multi scale
  1128. if self.args.multi_scale:
  1129. images, targets, img_size = self.rescale_image_targets(
  1130. images, targets, self.model_cfg['max_stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
  1131. else:
  1132. targets = self.refine_targets(targets, self.args.min_box_size)
  1133. # Normalize bbox
  1134. targets = self.normalize_bbox(targets, img_size)
  1135. # Visualize train targets
  1136. if self.args.vis_tgt:
  1137. targets = self.denormalize_bbox(targets, img_size)
  1138. vis_data(images*255, targets)
  1139. # Inference
  1140. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  1141. outputs = model(images)
  1142. # Compute loss
  1143. loss_dict = self.criterion(outputs=outputs, targets=targets, epoch=self.epoch)
  1144. losses = loss_dict['losses']
  1145. # Grad Accumulate
  1146. if self.grad_accumulate > 1:
  1147. losses /= self.grad_accumulate
  1148. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  1149. # Backward
  1150. self.scaler.scale(losses).backward()
  1151. # Optimize
  1152. if ni % self.grad_accumulate == 0:
  1153. grad_norm = None
  1154. if self.clip_grad > 0:
  1155. # unscale gradients
  1156. self.scaler.unscale_(self.optimizer)
  1157. # clip gradients
  1158. grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.clip_grad)
  1159. # optimizer.step
  1160. self.scaler.step(self.optimizer)
  1161. self.scaler.update()
  1162. self.optimizer.zero_grad()
  1163. # ema
  1164. if self.model_ema is not None:
  1165. self.model_ema.update(model)
  1166. # Logs
  1167. if distributed_utils.is_main_process() and iter_i % 10 == 0:
  1168. t1 = time.time()
  1169. cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
  1170. # basic infor
  1171. log = '[Epoch: {}/{}]'.format(self.epoch+1, self.args.max_epoch)
  1172. log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
  1173. log += '[lr: {:.6f}]'.format(cur_lr[0])
  1174. # loss infor
  1175. for k in loss_dict_reduced.keys():
  1176. loss_val = loss_dict_reduced[k]
  1177. if k == 'losses':
  1178. loss_val *= self.grad_accumulate
  1179. log += '[{}: {:.2f}]'.format(k, loss_val)
  1180. # other infor
  1181. log += '[grad_norm: {:.2f}]'.format(grad_norm)
  1182. log += '[time: {:.2f}]'.format(t1 - t0)
  1183. log += '[size: {}]'.format(img_size)
  1184. # print log infor
  1185. print(log, flush=True)
  1186. t0 = time.time()
  1187. # LR Schedule
  1188. if not self.second_stage:
  1189. self.lr_scheduler.step()
  1190. def refine_targets(self, targets, min_box_size):
  1191. # rescale targets
  1192. for tgt in targets:
  1193. boxes = tgt["boxes"].clone()
  1194. labels = tgt["labels"].clone()
  1195. # refine tgt
  1196. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  1197. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  1198. keep = (min_tgt_size >= min_box_size)
  1199. tgt["boxes"] = boxes[keep]
  1200. tgt["labels"] = labels[keep]
  1201. return targets
  1202. def normalize_bbox(self, targets, img_size):
  1203. # normalize targets
  1204. for tgt in targets:
  1205. tgt["boxes"] /= img_size
  1206. return targets
  1207. def denormalize_bbox(self, targets, img_size):
  1208. # normalize targets
  1209. for tgt in targets:
  1210. tgt["boxes"] *= img_size
  1211. return targets
  1212. def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
  1213. """
  1214. Deployed for Multi scale trick.
  1215. """
  1216. if isinstance(stride, int):
  1217. max_stride = stride
  1218. elif isinstance(stride, list):
  1219. max_stride = max(stride)
  1220. # During training phase, the shape of input image is square.
  1221. old_img_size = images.shape[-1]
  1222. new_img_size = random.randrange(old_img_size * multi_scale_range[0], old_img_size * multi_scale_range[1] + max_stride)
  1223. new_img_size = new_img_size // max_stride * max_stride # size
  1224. if new_img_size / old_img_size != 1:
  1225. # interpolate
  1226. images = torch.nn.functional.interpolate(
  1227. input=images,
  1228. size=new_img_size,
  1229. mode='bilinear',
  1230. align_corners=False)
  1231. # rescale targets
  1232. for tgt in targets:
  1233. boxes = tgt["boxes"].clone()
  1234. labels = tgt["labels"].clone()
  1235. boxes = torch.clamp(boxes, 0, old_img_size)
  1236. # rescale box
  1237. boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
  1238. boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
  1239. # refine tgt
  1240. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  1241. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  1242. keep = (min_tgt_size >= min_box_size)
  1243. tgt["boxes"] = boxes[keep]
  1244. tgt["labels"] = labels[keep]
  1245. return images, targets, new_img_size
  def check_second_stage(self):
      """Switch training into its second stage.

      Marks ``self.second_stage``, zeroes the dataset's mosaic and mixup
      probabilities (setting ``self.heavy_eval`` whenever one of them was
      still active), zeroes any positive 'degrees' / 'shear' / 'perspective'
      entry of ``self.trans_cfg``, and finally rebuilds the training
      transform and installs it on the training dataset.
      """
      print('============== Second stage of Training ==============')
      self.second_stage = True

      dataset = self.train_loader.dataset

      # Mosaic is handled before mixup; closing either one turns heavy_eval on.
      mix_switches = (
          ('mosaic_prob', ' - Close < Mosaic Augmentation > ...'),
          ('mixup_prob', ' - Close < Mixup Augmentation > ...'),
      )
      for prob_attr, notice in mix_switches:
          if getattr(dataset, prob_attr) > 0.:
              print(notice)
              setattr(dataset, prob_attr, 0.)
              self.heavy_eval = True

      # Geometric settings are only touched when present and strictly positive.
      cfg = self.trans_cfg
      geometric_switches = (
          ('degrees', ' - Close < degress of rotation > ...'),
          ('shear', ' - Close < shear of rotation >...'),
          ('perspective', ' - Close < perspective of rotation > ...'),
      )
      for key, notice in geometric_switches:
          if key in cfg and cfg[key] > 0.0:
              print(notice)
              cfg[key] = 0.0

      # Rebuild the transform from the edited config and hand it to the dataset.
      print(' - Rebuild transforms ...')
      rebuilt_transform, rebuilt_cfg = build_transform(
          args=self.args,
          trans_config=cfg,
          max_stride=self.model_cfg['max_stride'],
          is_train=True)
      self.train_transform = rebuilt_transform
      self.trans_cfg = rebuilt_cfg
      dataset.transform = rebuilt_transform
  def check_third_stage(self):
      """Switch training into its third stage.

      Marks ``self.third_stage``, zeroes a positive 'translate' entry of
      ``self.trans_cfg``, resets 'scale' to ``[1.0, 1.0]`` whenever that key
      exists, then rebuilds the training transform and installs it on the
      training dataset.
      """
      print('============== Third stage of Training ==============')
      self.third_stage = True

      cfg = self.trans_cfg

      # Translation is only closed when it is configured and still positive.
      if 'translate' in cfg:
          if cfg['translate'] > 0.0:
              print(' - Close < translate of affine > ...')
              cfg['translate'] = 0.0

      # Scale is reset unconditionally as long as the key is present.
      if 'scale' in cfg:
          print(' - Close < scale of affine >...')
          cfg['scale'] = [1.0, 1.0]

      # Rebuild the transform from the edited config and hand it to the dataset.
      print(' - Rebuild transforms ...')
      rebuilt_transform, rebuilt_cfg = build_transform(
          args=self.args,
          trans_config=cfg,
          max_stride=self.model_cfg['max_stride'],
          is_train=True)
      self.train_transform = rebuilt_transform
      self.trans_cfg = rebuilt_cfg
      self.train_loader.dataset.transform = rebuilt_transform
  1291. # Build Trainer
  def build_trainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
      """Instantiate the trainer selected by ``model_cfg['trainer_type']``.

      Supported values are 'yolov8', 'yolox', 'rtcdet' and 'rtrdet'. All
      arguments are forwarded unchanged, in the same order, to the chosen
      trainer class.

      Returns:
          The constructed trainer instance.

      Raises:
          NotImplementedError: if ``model_cfg['trainer_type']`` is not one of
              the supported values; the message names the rejected value.
      """
      trainer_type = model_cfg['trainer_type']

      # Keep the class lookup inside each branch so only the selected trainer
      # class is resolved.
      if trainer_type == 'yolov8':
          trainer_cls = Yolov8Trainer
      elif trainer_type == 'yolox':
          trainer_cls = YoloxTrainer
      elif trainer_type == 'rtcdet':
          trainer_cls = RTCTrainer
      elif trainer_type == 'rtrdet':
          trainer_cls = RTRTrainer
      else:
          raise NotImplementedError(
              "Unknown trainer type: {!r}. Supported types: "
              "'yolov8', 'yolox', 'rtcdet', 'rtrdet'.".format(trainer_type))

      return trainer_cls(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)