# engine.py — training engines (Yolov8Trainer / YoloxTrainer)
  1. import torch
  2. import torch.distributed as dist
  3. import time
  4. import os
  5. import numpy as np
  6. import random
  7. # ----------------- Extra Components -----------------
  8. from utils import distributed_utils
  9. from utils.misc import ModelEMA, CollateFunc, build_dataloader
  10. from utils.vis_tools import vis_data
  11. # ----------------- Evaluator Components -----------------
  12. from evaluator.build import build_evluator
  13. # ----------------- Optimizer & LrScheduler Components -----------------
  14. from utils.solver.optimizer import build_yolo_optimizer, build_detr_optimizer
  15. from utils.solver.lr_scheduler import build_lr_scheduler
  16. # ----------------- Dataset Components -----------------
  17. from dataset.build import build_dataset, build_transform
  18. # YOLOv8 Trainer
  19. class Yolov8Trainer(object):
  20. def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
  21. # ------------------- basic parameters -------------------
  22. self.args = args
  23. self.epoch = 0
  24. self.best_map = -1.
  25. self.device = device
  26. self.criterion = criterion
  27. self.world_size = world_size
  28. self.heavy_eval = False
  29. self.last_opt_step = 0
  30. self.clip_grad = 10
  31. # weak augmentatino stage
  32. self.second_stage = False
  33. self.third_stage = False
  34. self.second_stage_epoch = args.no_aug_epoch
  35. self.third_stage_epoch = args.no_aug_epoch // 2
  36. # path to save model
  37. self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
  38. os.makedirs(self.path_to_save, exist_ok=True)
  39. # ---------------------------- Hyperparameters refer to YOLOv8 ----------------------------
  40. self.optimizer_dict = {'optimizer': 'sgd', 'momentum': 0.937, 'weight_decay': 5e-4, 'lr0': 0.01}
  41. self.ema_dict = {'ema_decay': 0.9999, 'ema_tau': 2000}
  42. self.lr_schedule_dict = {'scheduler': 'linear', 'lrf': 0.01}
  43. self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}
  44. # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
  45. self.data_cfg = data_cfg
  46. self.model_cfg = model_cfg
  47. self.trans_cfg = trans_cfg
  48. # ---------------------------- Build Transform ----------------------------
  49. self.train_transform, self.trans_cfg = build_transform(
  50. args=args, trans_config=self.trans_cfg, max_stride=model_cfg['max_stride'], is_train=True)
  51. self.val_transform, _ = build_transform(
  52. args=args, trans_config=self.trans_cfg, max_stride=model_cfg['max_stride'], is_train=False)
  53. # ---------------------------- Build Dataset & Dataloader ----------------------------
  54. self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
  55. self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())
  56. # ---------------------------- Build Evaluator ----------------------------
  57. self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)
  58. # ---------------------------- Build Grad. Scaler ----------------------------
  59. self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)
  60. # ---------------------------- Build Optimizer ----------------------------
  61. accumulate = max(1, round(64 / self.args.batch_size))
  62. print('Grad Accumulate: {}'.format(accumulate))
  63. self.optimizer_dict['weight_decay'] *= self.args.batch_size * accumulate / 64
  64. self.optimizer, self.start_epoch = build_yolo_optimizer(self.optimizer_dict, model, self.args.resume)
  65. # ---------------------------- Build LR Scheduler ----------------------------
  66. self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, self.args.max_epoch)
  67. self.lr_scheduler.last_epoch = self.start_epoch - 1 # do not move
  68. if self.args.resume:
  69. self.lr_scheduler.step()
  70. # ---------------------------- Build Model-EMA ----------------------------
  71. if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
  72. print('Build ModelEMA ...')
  73. self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
  74. else:
  75. self.model_ema = None
  76. def train(self, model):
  77. for epoch in range(self.start_epoch, self.args.max_epoch):
  78. if self.args.distributed:
  79. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  80. # check second stage
  81. if epoch >= (self.args.max_epoch - self.second_stage_epoch - 1) and not self.second_stage:
  82. self.check_second_stage()
  83. # save model of the last mosaic epoch
  84. weight_name = '{}_last_mosaic_epoch.pth'.format(self.args.model)
  85. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  86. print('Saving state of the last Mosaic epoch-{}.'.format(self.epoch + 1))
  87. torch.save({'model': model.state_dict(),
  88. 'mAP': round(self.evaluator.map*100, 1),
  89. 'optimizer': self.optimizer.state_dict(),
  90. 'epoch': self.epoch,
  91. 'args': self.args},
  92. checkpoint_path)
  93. # check third stage
  94. if epoch >= (self.args.max_epoch - self.third_stage_epoch - 1) and not self.third_stage:
  95. self.check_third_stage()
  96. # save model of the last mosaic epoch
  97. weight_name = '{}_last_weak_augment_epoch.pth'.format(self.args.model)
  98. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  99. print('Saving state of the last weak augment epoch-{}.'.format(self.epoch + 1))
  100. torch.save({'model': model.state_dict(),
  101. 'mAP': round(self.evaluator.map*100, 1),
  102. 'optimizer': self.optimizer.state_dict(),
  103. 'epoch': self.epoch,
  104. 'args': self.args},
  105. checkpoint_path)
  106. # train one epoch
  107. self.epoch = epoch
  108. self.train_one_epoch(model)
  109. # eval one epoch
  110. if self.heavy_eval:
  111. model_eval = model.module if self.args.distributed else model
  112. self.eval(model_eval)
  113. else:
  114. model_eval = model.module if self.args.distributed else model
  115. if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
  116. self.eval(model_eval)
  117. def eval(self, model):
  118. # chech model
  119. model_eval = model if self.model_ema is None else self.model_ema.ema
  120. if distributed_utils.is_main_process():
  121. # check evaluator
  122. if self.evaluator is None:
  123. print('No evaluator ... save model and go on training.')
  124. print('Saving state, epoch: {}'.format(self.epoch + 1))
  125. weight_name = '{}_no_eval.pth'.format(self.args.model)
  126. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  127. torch.save({'model': model_eval.state_dict(),
  128. 'mAP': -1.,
  129. 'optimizer': self.optimizer.state_dict(),
  130. 'epoch': self.epoch,
  131. 'args': self.args},
  132. checkpoint_path)
  133. else:
  134. print('eval ...')
  135. # set eval mode
  136. model_eval.trainable = False
  137. model_eval.eval()
  138. # evaluate
  139. with torch.no_grad():
  140. self.evaluator.evaluate(model_eval)
  141. # save model
  142. cur_map = self.evaluator.map
  143. if cur_map > self.best_map:
  144. # update best-map
  145. self.best_map = cur_map
  146. # save model
  147. print('Saving state, epoch:', self.epoch + 1)
  148. weight_name = '{}_best.pth'.format(self.args.model)
  149. checkpoint_path = os.path.join(self.path_to_save, weight_name)
  150. torch.save({'model': model_eval.state_dict(),
  151. 'mAP': round(self.best_map*100, 1),
  152. 'optimizer': self.optimizer.state_dict(),
  153. 'epoch': self.epoch,
  154. 'args': self.args},
  155. checkpoint_path)
  156. # set train mode.
  157. model_eval.trainable = True
  158. model_eval.train()
  159. if self.args.distributed:
  160. # wait for all processes to synchronize
  161. dist.barrier()
  162. def train_one_epoch(self, model):
  163. # basic parameters
  164. epoch_size = len(self.train_loader)
  165. img_size = self.args.img_size
  166. t0 = time.time()
  167. nw = epoch_size * self.args.wp_epoch
  168. accumulate = accumulate = max(1, round(64 / self.args.batch_size))
  169. # train one epoch
  170. for iter_i, (images, targets) in enumerate(self.train_loader):
  171. ni = iter_i + self.epoch * epoch_size
  172. # Warmup
  173. if ni <= nw:
  174. xi = [0, nw] # x interp
  175. accumulate = max(1, np.interp(ni, xi, [1, 64 / self.args.batch_size]).round())
  176. for j, x in enumerate(self.optimizer.param_groups):
  177. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  178. x['lr'] = np.interp(
  179. ni, xi, [self.warmup_dict['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
  180. if 'momentum' in x:
  181. x['momentum'] = np.interp(ni, xi, [self.warmup_dict['warmup_momentum'], self.optimizer_dict['momentum']])
  182. # to device
  183. images = images.to(self.device, non_blocking=True).float() / 255.
  184. # Multi scale
  185. if self.args.multi_scale:
  186. images, targets, img_size = self.rescale_image_targets(
  187. images, targets, self.model_cfg['stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
  188. else:
  189. targets = self.refine_targets(targets, self.args.min_box_size)
  190. # visualize train targets
  191. if self.args.vis_tgt:
  192. vis_data(images*255, targets)
  193. # inference
  194. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  195. outputs = model(images)
  196. # loss
  197. loss_dict = self.criterion(outputs=outputs, targets=targets, epoch=self.epoch)
  198. losses = loss_dict['losses']
  199. losses *= images.shape[0] # loss * bs
  200. # reduce
  201. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  202. # gradient averaged between devices in DDP mode
  203. losses *= distributed_utils.get_world_size()
  204. # backward
  205. self.scaler.scale(losses).backward()
  206. # Optimize
  207. if ni - self.last_opt_step >= accumulate:
  208. if self.clip_grad > 0:
  209. # unscale gradients
  210. self.scaler.unscale_(self.optimizer)
  211. # clip gradients
  212. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.clip_grad)
  213. # optimizer.step
  214. self.scaler.step(self.optimizer)
  215. self.scaler.update()
  216. self.optimizer.zero_grad()
  217. # ema
  218. if self.model_ema is not None:
  219. self.model_ema.update(model)
  220. self.last_opt_step = ni
  221. # display
  222. if distributed_utils.is_main_process() and iter_i % 10 == 0:
  223. t1 = time.time()
  224. cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
  225. # basic infor
  226. log = '[Epoch: {}/{}]'.format(self.epoch+1, self.args.max_epoch)
  227. log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
  228. log += '[lr: {:.6f}]'.format(cur_lr[2])
  229. # loss infor
  230. for k in loss_dict_reduced.keys():
  231. log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
  232. # other infor
  233. log += '[time: {:.2f}]'.format(t1 - t0)
  234. log += '[size: {}]'.format(img_size)
  235. # print log infor
  236. print(log, flush=True)
  237. t0 = time.time()
  238. self.lr_scheduler.step()
  239. def check_second_stage(self):
  240. # set second stage
  241. print('============== Second stage of Training ==============')
  242. self.second_stage = True
  243. # close mosaic augmentation
  244. if self.train_loader.dataset.mosaic_prob > 0.:
  245. print(' - Close < Mosaic Augmentation > ...')
  246. self.train_loader.dataset.mosaic_prob = 0.
  247. self.heavy_eval = True
  248. # close mixup augmentation
  249. if self.train_loader.dataset.mixup_prob > 0.:
  250. print(' - Close < Mixup Augmentation > ...')
  251. self.train_loader.dataset.mixup_prob = 0.
  252. self.heavy_eval = True
  253. # close rotation augmentation
  254. if 'degrees' in self.trans_cfg.keys() and self.trans_cfg['degrees'] > 0.0:
  255. print(' - Close < degress of rotation > ...')
  256. self.trans_cfg['degrees'] = 0.0
  257. if 'shear' in self.trans_cfg.keys() and self.trans_cfg['shear'] > 0.0:
  258. print(' - Close < shear of rotation >...')
  259. self.trans_cfg['shear'] = 0.0
  260. if 'perspective' in self.trans_cfg.keys() and self.trans_cfg['perspective'] > 0.0:
  261. print(' - Close < perspective of rotation > ...')
  262. self.trans_cfg['perspective'] = 0.0
  263. # build a new transform for second stage
  264. print(' - Rebuild transforms ...')
  265. self.train_transform, self.trans_cfg = build_transform(
  266. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
  267. self.train_loader.dataset.transform = self.train_transform
  268. def check_third_stage(self):
  269. # set third stage
  270. print('============== Third stage of Training ==============')
  271. self.third_stage = True
  272. # close random affine
  273. if 'translate' in self.trans_cfg.keys() and self.trans_cfg['translate'] > 0.0:
  274. print(' - Close < translate of affine > ...')
  275. self.trans_cfg['translate'] = 0.0
  276. if 'scale' in self.trans_cfg.keys():
  277. print(' - Close < scale of affine >...')
  278. self.trans_cfg['scale'] = [1.0, 1.0]
  279. # build a new transform for second stage
  280. print(' - Rebuild transforms ...')
  281. self.train_transform, self.trans_cfg = build_transform(
  282. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
  283. self.train_loader.dataset.transform = self.train_transform
  284. def refine_targets(self, targets, min_box_size):
  285. # rescale targets
  286. for tgt in targets:
  287. boxes = tgt["boxes"].clone()
  288. labels = tgt["labels"].clone()
  289. # refine tgt
  290. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  291. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  292. keep = (min_tgt_size >= min_box_size)
  293. tgt["boxes"] = boxes[keep]
  294. tgt["labels"] = labels[keep]
  295. return targets
  296. def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
  297. """
  298. Deployed for Multi scale trick.
  299. """
  300. if isinstance(stride, int):
  301. max_stride = stride
  302. elif isinstance(stride, list):
  303. max_stride = max(stride)
  304. # During training phase, the shape of input image is square.
  305. old_img_size = images.shape[-1]
  306. new_img_size = random.randrange(old_img_size * multi_scale_range[0], old_img_size * multi_scale_range[1] + max_stride)
  307. new_img_size = new_img_size // max_stride * max_stride # size
  308. if new_img_size / old_img_size != 1:
  309. # interpolate
  310. images = torch.nn.functional.interpolate(
  311. input=images,
  312. size=new_img_size,
  313. mode='bilinear',
  314. align_corners=False)
  315. # rescale targets
  316. for tgt in targets:
  317. boxes = tgt["boxes"].clone()
  318. labels = tgt["labels"].clone()
  319. boxes = torch.clamp(boxes, 0, old_img_size)
  320. # rescale box
  321. boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
  322. boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
  323. # refine tgt
  324. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  325. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  326. keep = (min_tgt_size >= min_box_size)
  327. tgt["boxes"] = boxes[keep]
  328. tgt["labels"] = labels[keep]
  329. return images, targets, new_img_size
  330. # YOLOX Trainer
  331. class YoloxTrainer(object):
    def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
        """Build everything the YOLOX recipe needs: transforms, dataset and
        dataloader, evaluator, AMP scaler, SGD optimizer, cosine LR scheduler
        and (on the main rank) a model EMA.

        Construction order matters: transforms feed the dataset, the dataset
        feeds the loader, and the optimizer must exist before the scheduler
        and EMA (both depend on ``start_epoch`` returned by the optimizer
        builder when resuming).
        """
        # ------------------- basic parameters -------------------
        self.args = args
        self.epoch = 0
        self.best_map = -1.                          # best mAP so far; updated in eval()
        self.device = device
        self.criterion = criterion
        self.world_size = world_size
        self.grad_accumulate = args.grad_accumulate  # fixed accumulation count (cf. train_one_epoch)
        self.no_aug_epoch = args.no_aug_epoch        # length of the no-augmentation tail
        self.heavy_eval = False
        # Weak-augmentation stages: strong augs are closed near the end of training.
        self.second_stage = False
        self.third_stage = False
        self.second_stage_epoch = args.no_aug_epoch
        self.third_stage_epoch = args.no_aug_epoch // 2
        # Path to save model checkpoints: <save_folder>/<dataset>/<model>/.
        self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
        os.makedirs(self.path_to_save, exist_ok=True)
        # ---------------- Hyperparameters refer to YOLOX ----------------
        self.optimizer_dict = {'optimizer': 'sgd', 'momentum': 0.9, 'weight_decay': 5e-4, 'lr0': 0.01}
        self.ema_dict = {'ema_decay': 0.9999, 'ema_tau': 2000}
        self.lr_schedule_dict = {'scheduler': 'cosine', 'lrf': 0.05}
        self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}
        # ---------------- Dataset / model / transform configs ----------------
        self.data_cfg = data_cfg
        self.model_cfg = model_cfg
        self.trans_cfg = trans_cfg
        # ---------------- Build transforms (train + val) ----------------
        self.train_transform, self.trans_cfg = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
        self.val_transform, _ = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)
        # ---------------- Build dataset & dataloader ----------------
        # Per-rank batch size is the global batch size divided by world size.
        self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
        self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())
        # ---------------- Build evaluator (may be None; eval() handles that) ----------------
        self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)
        # ---------------- Build AMP gradient scaler ----------------
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)
        # ---------------- Build optimizer ----------------
        # Linear scaling rule: base LR scales with effective batch size / 64.
        self.optimizer_dict['lr0'] *= self.args.batch_size * self.grad_accumulate / 64
        self.optimizer, self.start_epoch = build_yolo_optimizer(self.optimizer_dict, model, self.args.resume)
        # ---------------- Build LR scheduler ----------------
        # The schedule spans only max_epoch - no_aug_epoch epochs; during the
        # no-aug tail train_one_epoch stops stepping the scheduler.
        self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, self.args.max_epoch - self.no_aug_epoch)
        self.lr_scheduler.last_epoch = self.start_epoch - 1  # do not move
        if self.args.resume:
            self.lr_scheduler.step()
        # ---------------- Build model EMA (main process only) ----------------
        if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
            print('Build ModelEMA ...')
            self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
        else:
            self.model_ema = None
    def train(self, model):
        """Run the full training loop from ``start_epoch`` to ``max_epoch``.

        Handles the stage transitions (closing strong augmentations), saves a
        checkpoint at each transition, trains one epoch, then evaluates —
        every epoch in heavy-eval mode, otherwise every ``eval_epoch`` epochs
        and on the final epoch.
        """
        for epoch in range(self.start_epoch, self.args.max_epoch):
            if self.args.distributed:
                # Reshuffle per-rank data splits each epoch.
                self.train_loader.batch_sampler.sampler.set_epoch(epoch)
            # Check second stage: close mosaic/mixup.
            if epoch >= (self.args.max_epoch - self.second_stage_epoch - 1) and not self.second_stage:
                self.check_second_stage()
                # Save model of the last mosaic epoch.
                weight_name = '{}_last_mosaic_epoch.pth'.format(self.args.model)
                checkpoint_path = os.path.join(self.path_to_save, weight_name)
                print('Saving state of the last Mosaic epoch-{}.'.format(self.epoch + 1))
                torch.save({'model': model.state_dict(),
                            # NOTE(review): crashes if evaluator is None — eval() supports that case; confirm
                            'mAP': round(self.evaluator.map*100, 1),
                            'optimizer': self.optimizer.state_dict(),
                            'epoch': self.epoch,
                            'args': self.args},
                           checkpoint_path)
            # Check third stage: close remaining affine jitter.
            if epoch >= (self.args.max_epoch - self.third_stage_epoch - 1) and not self.third_stage:
                self.check_third_stage()
                # Save model of the last weak-augment epoch.
                weight_name = '{}_last_weak_augment_epoch.pth'.format(self.args.model)
                checkpoint_path = os.path.join(self.path_to_save, weight_name)
                print('Saving state of the last weak augment epoch-{}.'.format(self.epoch + 1))
                torch.save({'model': model.state_dict(),
                            'mAP': round(self.evaluator.map*100, 1),
                            'optimizer': self.optimizer.state_dict(),
                            'epoch': self.epoch,
                            'args': self.args},
                           checkpoint_path)
            # Train one epoch.
            self.epoch = epoch
            self.train_one_epoch(model)
            # Eval one epoch (unwrap DDP before evaluating).
            if self.heavy_eval:
                model_eval = model.module if self.args.distributed else model
                self.eval(model_eval)
            else:
                model_eval = model.module if self.args.distributed else model
                if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
                    self.eval(model_eval)
    def eval(self, model):
        """Evaluate on the main process — using EMA weights when available —
        and checkpoint when the mAP improves; other ranks wait at a barrier.

        With no evaluator configured, a ``*_no_eval.pth`` checkpoint is saved
        instead and training continues.
        """
        # Prefer the EMA weights for evaluation when EMA is enabled.
        model_eval = model if self.model_ema is None else self.model_ema.ema
        if distributed_utils.is_main_process():
            # Check evaluator.
            if self.evaluator is None:
                print('No evaluator ... save model and go on training.')
                print('Saving state, epoch: {}'.format(self.epoch + 1))
                weight_name = '{}_no_eval.pth'.format(self.args.model)
                checkpoint_path = os.path.join(self.path_to_save, weight_name)
                torch.save({'model': model_eval.state_dict(),
                            'mAP': -1.,
                            'optimizer': self.optimizer.state_dict(),
                            'epoch': self.epoch,
                            'args': self.args},
                           checkpoint_path)
            else:
                print('eval ...')
                # Set eval mode.
                model_eval.trainable = False
                model_eval.eval()
                # Evaluate without building an autograd graph.
                with torch.no_grad():
                    self.evaluator.evaluate(model_eval)
                # Save the model when the mAP improves.
                cur_map = self.evaluator.map
                if cur_map > self.best_map:
                    # Update best-map.
                    self.best_map = cur_map
                    print('Saving state, epoch:', self.epoch + 1)
                    weight_name = '{}_best.pth'.format(self.args.model)
                    checkpoint_path = os.path.join(self.path_to_save, weight_name)
                    torch.save({'model': model_eval.state_dict(),
                                'mAP': round(self.best_map*100, 1),
                                'optimizer': self.optimizer.state_dict(),
                                'epoch': self.epoch,
                                'args': self.args},
                               checkpoint_path)
                # Set train mode.
                model_eval.trainable = True
                model_eval.train()
        if self.args.distributed:
            # wait for all processes to synchronize
            dist.barrier()
    def train_one_epoch(self, model):
        """Train ``model`` for one epoch (YOLOX recipe): LR/momentum warmup,
        AMP forward/backward, fixed gradient accumulation, periodic logging,
        and one LR-scheduler step while strong augmentation is still on."""
        # Basic parameters.
        epoch_size = len(self.train_loader)
        img_size = self.args.img_size
        t0 = time.time()
        nw = epoch_size * self.args.wp_epoch  # total number of warmup iterations
        # Train one epoch.
        for iter_i, (images, targets) in enumerate(self.train_loader):
            ni = iter_i + self.epoch * epoch_size  # global iteration index
            # Warmup: linearly interpolate LR and momentum over the first nw iters.
            if ni <= nw:
                xi = [0, nw]  # x interp
                for j, x in enumerate(self.optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x['lr'] = np.interp(
                        ni, xi, [self.warmup_dict['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xi, [self.warmup_dict['warmup_momentum'], self.optimizer_dict['momentum']])
            # To device; pixel values scaled from [0, 255] to [0, 1].
            images = images.to(self.device, non_blocking=True).float() / 255.
            # Multi scale: resize only every 10th global iteration.
            if self.args.multi_scale and ni % 10 == 0:
                images, targets, img_size = self.rescale_image_targets(
                    images, targets, self.model_cfg['stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
            else:
                targets = self.refine_targets(targets, self.args.min_box_size)
            # Visualize train targets.
            if self.args.vis_tgt:
                vis_data(images*255, targets)
            # Inference + loss under autocast (mixed precision when fp16).
            with torch.cuda.amp.autocast(enabled=self.args.fp16):
                outputs = model(images)
                # Compute loss.
                loss_dict = self.criterion(outputs=outputs, targets=targets, epoch=self.epoch)
                losses = loss_dict['losses']
                # Grad accumulation: average the loss over the accumulated steps.
                if self.grad_accumulate > 1:
                    losses /= self.grad_accumulate
                # Reduce for logging across ranks.
                loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
            # Backward (scaled for AMP).
            self.scaler.scale(losses).backward()
            # Optimize once every `grad_accumulate` iterations.
            if ni % self.grad_accumulate == 0:
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()
                # EMA tracks the raw weights after each optimizer step.
                if self.model_ema is not None:
                    self.model_ema.update(model)
            # Logs (main process only, every 10 iterations).
            if distributed_utils.is_main_process() and iter_i % 10 == 0:
                t1 = time.time()
                cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
                # Basic info.
                log = '[Epoch: {}/{}]'.format(self.epoch+1, self.args.max_epoch)
                log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
                log += '[lr: {:.6f}]'.format(cur_lr[2])  # NOTE(review): assumes >= 3 param groups — confirm optimizer layout
                # Loss info.
                for k in loss_dict_reduced.keys():
                    loss_val = loss_dict_reduced[k]
                    if k == 'losses':
                        # Undo the accumulation scaling for display.
                        # NOTE(review): `*=` mutates the reduced value in place if it is a tensor — confirm harmless
                        loss_val *= self.grad_accumulate
                    log += '[{}: {:.2f}]'.format(k, loss_val)
                # Other info.
                log += '[time: {:.2f}]'.format(t1 - t0)
                log += '[size: {}]'.format(img_size)
                # Print log info.
                print(log, flush=True)
                t0 = time.time()
        # LR schedule: only step while strong augmentation is still on; during
        # the no-aug tail the LR stays at its final scheduled value.
        if not self.second_stage:
            self.lr_scheduler.step()
  544. def check_second_stage(self):
  545. # set second stage
  546. print('============== Second stage of Training ==============')
  547. self.second_stage = True
  548. # close mosaic augmentation
  549. if self.train_loader.dataset.mosaic_prob > 0.:
  550. print(' - Close < Mosaic Augmentation > ...')
  551. self.train_loader.dataset.mosaic_prob = 0.
  552. self.heavy_eval = True
  553. # close mixup augmentation
  554. if self.train_loader.dataset.mixup_prob > 0.:
  555. print(' - Close < Mixup Augmentation > ...')
  556. self.train_loader.dataset.mixup_prob = 0.
  557. self.heavy_eval = True
  558. # close rotation augmentation
  559. if 'degrees' in self.trans_cfg.keys() and self.trans_cfg['degrees'] > 0.0:
  560. print(' - Close < degress of rotation > ...')
  561. self.trans_cfg['degrees'] = 0.0
  562. if 'shear' in self.trans_cfg.keys() and self.trans_cfg['shear'] > 0.0:
  563. print(' - Close < shear of rotation >...')
  564. self.trans_cfg['shear'] = 0.0
  565. if 'perspective' in self.trans_cfg.keys() and self.trans_cfg['perspective'] > 0.0:
  566. print(' - Close < perspective of rotation > ...')
  567. self.trans_cfg['perspective'] = 0.0
  568. # build a new transform for second stage
  569. print(' - Rebuild transforms ...')
  570. self.train_transform, self.trans_cfg = build_transform(
  571. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
  572. self.train_loader.dataset.transform = self.train_transform
  573. def check_third_stage(self):
  574. # set third stage
  575. print('============== Third stage of Training ==============')
  576. self.third_stage = True
  577. # close random affine
  578. if 'translate' in self.trans_cfg.keys() and self.trans_cfg['translate'] > 0.0:
  579. print(' - Close < translate of affine > ...')
  580. self.trans_cfg['translate'] = 0.0
  581. if 'scale' in self.trans_cfg.keys():
  582. print(' - Close < scale of affine >...')
  583. self.trans_cfg['scale'] = [1.0, 1.0]
  584. # build a new transform for second stage
  585. print(' - Rebuild transforms ...')
  586. self.train_transform, self.trans_cfg = build_transform(
  587. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
  588. self.train_loader.dataset.transform = self.train_transform
  589. def refine_targets(self, targets, min_box_size):
  590. # rescale targets
  591. for tgt in targets:
  592. boxes = tgt["boxes"].clone()
  593. labels = tgt["labels"].clone()
  594. # refine tgt
  595. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  596. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  597. keep = (min_tgt_size >= min_box_size)
  598. tgt["boxes"] = boxes[keep]
  599. tgt["labels"] = labels[keep]
  600. return targets
  601. def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
  602. """
  603. Deployed for Multi scale trick.
  604. """
  605. if isinstance(stride, int):
  606. max_stride = stride
  607. elif isinstance(stride, list):
  608. max_stride = max(stride)
  609. # During training phase, the shape of input image is square.
  610. old_img_size = images.shape[-1]
  611. new_img_size = random.randrange(old_img_size * multi_scale_range[0], old_img_size * multi_scale_range[1] + max_stride)
  612. new_img_size = new_img_size // max_stride * max_stride # size
  613. if new_img_size / old_img_size != 1:
  614. # interpolate
  615. images = torch.nn.functional.interpolate(
  616. input=images,
  617. size=new_img_size,
  618. mode='bilinear',
  619. align_corners=False)
  620. # rescale targets
  621. for tgt in targets:
  622. boxes = tgt["boxes"].clone()
  623. labels = tgt["labels"].clone()
  624. boxes = torch.clamp(boxes, 0, old_img_size)
  625. # rescale box
  626. boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
  627. boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
  628. # refine tgt
  629. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  630. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  631. keep = (min_tgt_size >= min_box_size)
  632. tgt["boxes"] = boxes[keep]
  633. tgt["labels"] = labels[keep]
  634. return images, targets, new_img_size
  635. # RTCDet Trainer
class RTCTrainer(object):
    """Trainer for RTCDet models.

    Owns the full training state: dataloaders, transforms, optimizer,
    AMP grad scaler, cosine LR scheduler with warmup, and an optional
    model EMA. Hyperparameters follow RTMDet (see ``__init__``).

    Training runs in three stages: full augmentation, then a "second stage"
    with mosaic/mixup/rotation disabled (last ``no_aug_epoch`` epochs), then
    a "third stage" with random affine disabled as well.
    """

    def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
        # ------------------- basic parameters -------------------
        self.args = args
        self.epoch = 0                 # current epoch index
        self.best_map = -1.            # best validation mAP seen so far
        self.device = device
        self.criterion = criterion
        self.world_size = world_size   # number of distributed processes
        self.grad_accumulate = args.grad_accumulate
        self.clip_grad = 35            # max gradient norm for clipping
        self.heavy_eval = False        # when True, evaluate after every epoch
        # weak augmentation stage flags
        self.second_stage = False
        self.third_stage = False
        self.second_stage_epoch = args.no_aug_epoch
        self.third_stage_epoch = args.no_aug_epoch // 2
        # path to save model
        self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
        os.makedirs(self.path_to_save, exist_ok=True)
        # ---------------------------- Hyperparameters refer to RTMDet ----------------------------
        self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 5e-2, 'lr0': 0.001}
        self.ema_dict = {'ema_decay': 0.9998, 'ema_tau': 2000}
        self.lr_schedule_dict = {'scheduler': 'cosine', 'lrf': 0.05}
        self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}
        # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
        self.data_cfg = data_cfg
        self.model_cfg = model_cfg
        self.trans_cfg = trans_cfg
        # ---------------------------- Build Transform ----------------------------
        self.train_transform, self.trans_cfg = build_transform(
            args=args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
        self.val_transform, _ = build_transform(
            args=args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)
        # ---------------------------- Build Dataset & Dataloader ----------------------------
        self.dataset, self.dataset_info = build_dataset(args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
        self.train_loader = build_dataloader(args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())
        # ---------------------------- Build Evaluator ----------------------------
        self.evaluator = build_evluator(args, self.data_cfg, self.val_transform, self.device)
        # ---------------------------- Build Grad. Scaler ----------------------------
        self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
        # ---------------------------- Build Optimizer ----------------------------
        # linear LR scaling: reference effective batch size is 64
        self.optimizer_dict['lr0'] *= args.batch_size * self.grad_accumulate / 64
        self.optimizer, self.start_epoch = build_yolo_optimizer(self.optimizer_dict, model, args.resume)
        # ---------------------------- Build LR Scheduler ----------------------------
        self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, args.max_epoch - args.no_aug_epoch)
        self.lr_scheduler.last_epoch = self.start_epoch - 1  # do not move
        if self.args.resume:
            self.lr_scheduler.step()
        # ---------------------------- Build Model-EMA ----------------------------
        # EMA is kept on the main process only
        if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
            print('Build ModelEMA ...')
            self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
        else:
            self.model_ema = None

    def train(self, model):
        """Main epoch loop: handles stage switching, checkpointing the last
        epoch of each stage, per-epoch training and (periodic) evaluation."""
        for epoch in range(self.start_epoch, self.args.max_epoch):
            if self.args.distributed:
                self.train_loader.batch_sampler.sampler.set_epoch(epoch)
            # check second stage
            if epoch >= (self.args.max_epoch - self.second_stage_epoch - 1) and not self.second_stage:
                self.check_second_stage()
                # save model of the last mosaic epoch
                weight_name = '{}_last_mosaic_epoch.pth'.format(self.args.model)
                checkpoint_path = os.path.join(self.path_to_save, weight_name)
                print('Saving state of the last Mosaic epoch-{}.'.format(self.epoch + 1))
                torch.save({'model': model.state_dict(),
                            'mAP': round(self.evaluator.map*100, 1),
                            'optimizer': self.optimizer.state_dict(),
                            'epoch': self.epoch,
                            'args': self.args},
                           checkpoint_path)
            # check third stage
            if epoch >= (self.args.max_epoch - self.third_stage_epoch - 1) and not self.third_stage:
                self.check_third_stage()
                # save model of the last weak-augment epoch
                weight_name = '{}_last_weak_augment_epoch.pth'.format(self.args.model)
                checkpoint_path = os.path.join(self.path_to_save, weight_name)
                print('Saving state of the last weak augment epoch-{}.'.format(self.epoch + 1))
                torch.save({'model': model.state_dict(),
                            'mAP': round(self.evaluator.map*100, 1),
                            'optimizer': self.optimizer.state_dict(),
                            'epoch': self.epoch,
                            'args': self.args},
                           checkpoint_path)
            # train one epoch
            self.epoch = epoch
            self.train_one_epoch(model)
            # eval one epoch
            if self.heavy_eval:
                model_eval = model.module if self.args.distributed else model
                self.eval(model_eval)
            else:
                model_eval = model.module if self.args.distributed else model
                if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
                    self.eval(model_eval)

    def eval(self, model):
        """Evaluate on the main process (preferring EMA weights) and save a
        checkpoint when mAP improves; other ranks wait at a barrier."""
        # check model: use EMA weights when available
        model_eval = model if self.model_ema is None else self.model_ema.ema
        if distributed_utils.is_main_process():
            # check evaluator
            if self.evaluator is None:
                print('No evaluator ... save model and go on training.')
                print('Saving state, epoch: {}'.format(self.epoch + 1))
                weight_name = '{}_no_eval.pth'.format(self.args.model)
                checkpoint_path = os.path.join(self.path_to_save, weight_name)
                torch.save({'model': model_eval.state_dict(),
                            'mAP': -1.,
                            'optimizer': self.optimizer.state_dict(),
                            'epoch': self.epoch,
                            'args': self.args},
                           checkpoint_path)
            else:
                print('eval ...')
                # set eval mode
                model_eval.trainable = False
                model_eval.eval()
                # evaluate
                with torch.no_grad():
                    self.evaluator.evaluate(model_eval)
                # save model when the mAP improves
                cur_map = self.evaluator.map
                if cur_map > self.best_map:
                    # update best-map
                    self.best_map = cur_map
                    # save model
                    print('Saving state, epoch:', self.epoch + 1)
                    weight_name = '{}_best.pth'.format(self.args.model)
                    checkpoint_path = os.path.join(self.path_to_save, weight_name)
                    torch.save({'model': model_eval.state_dict(),
                                'mAP': round(self.best_map*100, 1),
                                'optimizer': self.optimizer.state_dict(),
                                'epoch': self.epoch,
                                'args': self.args},
                               checkpoint_path)
                # set train mode.
                model_eval.trainable = True
                model_eval.train()
        if self.args.distributed:
            # wait for all processes to synchronize
            dist.barrier()

    def train_one_epoch(self, model):
        """One full pass over the train loader with warmup, optional
        multi-scale resizing, AMP forward/backward, gradient accumulation,
        clipping, EMA updates and periodic console logging."""
        # basic parameters
        epoch_size = len(self.train_loader)
        img_size = self.args.img_size
        t0 = time.time()
        nw = epoch_size * self.args.wp_epoch  # number of warmup iterations
        # Train one epoch
        for iter_i, (images, targets) in enumerate(self.train_loader):
            ni = iter_i + self.epoch * epoch_size  # global iteration counter
            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                for j, x in enumerate(self.optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x['lr'] = np.interp(
                        ni, xi, [self.warmup_dict['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xi, [self.warmup_dict['warmup_momentum'], self.optimizer_dict['momentum']])
            # To device (uint8 [0,255] -> float [0,1])
            images = images.to(self.device, non_blocking=True).float() / 255.
            # Multi scale
            if self.args.multi_scale:
                # NOTE(review): rescale_image_targets feeds floats to
                # random.randrange; that raises on Python >= 3.11 — confirm.
                images, targets, img_size = self.rescale_image_targets(
                    images, targets, self.model_cfg['stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
            else:
                targets = self.refine_targets(targets, self.args.min_box_size)
            # Visualize train targets
            if self.args.vis_tgt:
                vis_data(images*255, targets)
            # Inference under autocast (mixed precision when fp16 is enabled)
            with torch.cuda.amp.autocast(enabled=self.args.fp16):
                outputs = model(images)
                # Compute loss
                loss_dict = self.criterion(outputs=outputs, targets=targets, epoch=self.epoch)
                losses = loss_dict['losses']
                # Grad Accumulate: scale loss so accumulated grads average out
                if self.grad_accumulate > 1:
                    losses /= self.grad_accumulate
                loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
            # Backward
            self.scaler.scale(losses).backward()
            # Optimize: step only every grad_accumulate iterations
            if ni % self.grad_accumulate == 0:
                grad_norm = None
                if self.clip_grad > 0:
                    # unscale gradients
                    self.scaler.unscale_(self.optimizer)
                    # clip gradients
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.clip_grad)
                # optimizer.step
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()
                # ema
                if self.model_ema is not None:
                    self.model_ema.update(model)
            # Logs (main process, every 10 iterations)
            if distributed_utils.is_main_process() and iter_i % 10 == 0:
                t1 = time.time()
                cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
                # basic info
                log = '[Epoch: {}/{}]'.format(self.epoch+1, self.args.max_epoch)
                log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
                # assumes the optimizer has >= 3 param groups — TODO confirm
                log += '[lr: {:.6f}]'.format(cur_lr[2])
                # loss info (undo the grad-accumulate scaling for display)
                for k in loss_dict_reduced.keys():
                    loss_val = loss_dict_reduced[k]
                    if k == 'losses':
                        loss_val *= self.grad_accumulate
                    log += '[{}: {:.2f}]'.format(k, loss_val)
                # other info
                # NOTE(review): grad_norm would be None here if clip_grad <= 0,
                # breaking the float format; clip_grad is fixed at 35 above.
                log += '[grad_norm: {:.2f}]'.format(grad_norm)
                log += '[time: {:.2f}]'.format(t1 - t0)
                log += '[size: {}]'.format(img_size)
                # print log info
                print(log, flush=True)
                t0 = time.time()
        # LR Schedule (frozen once the second stage starts)
        if not self.second_stage:
            self.lr_scheduler.step()

    def refine_targets(self, targets, min_box_size):
        """Drop boxes (and their labels) whose shorter side is below
        ``min_box_size``; mutates and returns ``targets``."""
        # rescale targets
        for tgt in targets:
            boxes = tgt["boxes"].clone()
            labels = tgt["labels"].clone()
            # refine tgt
            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
            keep = (min_tgt_size >= min_box_size)
            tgt["boxes"] = boxes[keep]
            tgt["labels"] = labels[keep]
        return targets

    def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
        """
        Deployed for Multi scale trick.

        Picks a random square size (multiple of the max stride), resizes the
        batch bilinearly and rescales/filters the target boxes to match.
        Returns (images, targets, new_img_size).
        """
        if isinstance(stride, int):
            max_stride = stride
        elif isinstance(stride, list):
            max_stride = max(stride)
        # During training phase, the shape of input image is square.
        old_img_size = images.shape[-1]
        # NOTE(review): randrange needs int bounds on Python >= 3.11; these
        # products are floats — consider wrapping in int().
        new_img_size = random.randrange(old_img_size * multi_scale_range[0], old_img_size * multi_scale_range[1] + max_stride)
        new_img_size = new_img_size // max_stride * max_stride  # size
        if new_img_size / old_img_size != 1:
            # interpolate
            images = torch.nn.functional.interpolate(
                input=images,
                size=new_img_size,
                mode='bilinear',
                align_corners=False)
        # rescale targets
        for tgt in targets:
            boxes = tgt["boxes"].clone()
            labels = tgt["labels"].clone()
            boxes = torch.clamp(boxes, 0, old_img_size)
            # rescale box
            boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
            boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
            # refine tgt
            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
            keep = (min_tgt_size >= min_box_size)
            tgt["boxes"] = boxes[keep]
            tgt["labels"] = labels[keep]
        return images, targets, new_img_size

    def check_second_stage(self):
        """Disable mosaic/mixup and rotation-style augmentations, then
        rebuild the training transform (second training stage)."""
        # set second stage
        print('============== Second stage of Training ==============')
        self.second_stage = True
        # close mosaic augmentation
        if self.train_loader.dataset.mosaic_prob > 0.:
            print(' - Close < Mosaic Augmentation > ...')
            self.train_loader.dataset.mosaic_prob = 0.
            self.heavy_eval = True
        # close mixup augmentation
        if self.train_loader.dataset.mixup_prob > 0.:
            print(' - Close < Mixup Augmentation > ...')
            self.train_loader.dataset.mixup_prob = 0.
            self.heavy_eval = True
        # close rotation augmentation
        if 'degrees' in self.trans_cfg.keys() and self.trans_cfg['degrees'] > 0.0:
            print(' - Close < degress of rotation > ...')
            self.trans_cfg['degrees'] = 0.0
        if 'shear' in self.trans_cfg.keys() and self.trans_cfg['shear'] > 0.0:
            print(' - Close < shear of rotation >...')
            self.trans_cfg['shear'] = 0.0
        if 'perspective' in self.trans_cfg.keys() and self.trans_cfg['perspective'] > 0.0:
            print(' - Close < perspective of rotation > ...')
            self.trans_cfg['perspective'] = 0.0
        # build a new transform for second stage
        print(' - Rebuild transforms ...')
        self.train_transform, self.trans_cfg = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
        self.train_loader.dataset.transform = self.train_transform

    def check_third_stage(self):
        """Disable the remaining random-affine augmentations (translate,
        scale) and rebuild the training transform (third training stage)."""
        # set third stage
        print('============== Third stage of Training ==============')
        self.third_stage = True
        # close random affine
        if 'translate' in self.trans_cfg.keys() and self.trans_cfg['translate'] > 0.0:
            print(' - Close < translate of affine > ...')
            self.trans_cfg['translate'] = 0.0
        if 'scale' in self.trans_cfg.keys():
            print(' - Close < scale of affine >...')
            self.trans_cfg['scale'] = [1.0, 1.0]
        # build a new transform for third stage
        print(' - Rebuild transforms ...')
        self.train_transform, self.trans_cfg = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
        self.train_loader.dataset.transform = self.train_transform
  948. # RTRDet Trainer
class RTRTrainer(object):
    """Trainer for RTRDet (DETR-style real-time detector) models.

    Mirrors RTCTrainer but with DETR-style hyperparameters: AdamW with a
    backbone LR ratio, LR scaled against a reference batch size of 16, a
    plain 0 -> lr0 warmup (no bias-LR trick), and targets normalized to
    [0, 1] box coordinates before the loss.
    """

    def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
        # ------------------- Basic parameters -------------------
        self.args = args
        self.epoch = 0                 # current epoch index
        self.best_map = -1.            # best validation mAP seen so far
        self.device = device
        self.criterion = criterion
        self.world_size = world_size   # number of distributed processes
        self.grad_accumulate = args.grad_accumulate
        self.clip_grad = 35            # max gradient norm for clipping
        self.heavy_eval = False        # when True, evaluate after every epoch
        # weak augmentation stage flags
        self.second_stage = False
        self.third_stage = False
        self.second_stage_epoch = args.no_aug_epoch
        self.third_stage_epoch = args.no_aug_epoch // 2
        # path to save model
        self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
        os.makedirs(self.path_to_save, exist_ok=True)
        # ---------------------------- Hyperparameters refer to RTMDet ----------------------------
        self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 1e-4, 'lr0': 0.0001, 'backbone_lr_ratio': 0.1}
        self.ema_dict = {'ema_decay': 0.9998, 'ema_tau': 2000}
        self.lr_schedule_dict = {'scheduler': 'cosine', 'lrf': 0.05}
        self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}
        # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
        self.data_cfg = data_cfg
        self.model_cfg = model_cfg
        self.trans_cfg = trans_cfg
        # ---------------------------- Build Transform ----------------------------
        self.train_transform, self.trans_cfg = build_transform(
            args=args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
        self.val_transform, _ = build_transform(
            args=args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)
        # ---------------------------- Build Dataset & Dataloader ----------------------------
        self.dataset, self.dataset_info = build_dataset(args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
        self.train_loader = build_dataloader(args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())
        # ---------------------------- Build Evaluator ----------------------------
        self.evaluator = build_evluator(args, self.data_cfg, self.val_transform, self.device)
        # ---------------------------- Build Grad. Scaler ----------------------------
        self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
        # ---------------------------- Build Optimizer ----------------------------
        # linear LR scaling: reference batch size is 16 (DETR convention)
        self.optimizer_dict['lr0'] *= self.args.batch_size / 16.
        self.optimizer, self.start_epoch = build_detr_optimizer(self.optimizer_dict, model, self.args.resume)
        # ---------------------------- Build LR Scheduler ----------------------------
        self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, args.max_epoch - args.no_aug_epoch)
        self.lr_scheduler.last_epoch = self.start_epoch - 1  # do not move
        if self.args.resume:
            self.lr_scheduler.step()
        # ---------------------------- Build Model-EMA ----------------------------
        # EMA is kept on the main process only
        if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
            print('Build ModelEMA ...')
            self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
        else:
            self.model_ema = None

    def train(self, model):
        """Main epoch loop: handles stage switching, checkpointing the last
        epoch of each stage, per-epoch training and (periodic) evaluation."""
        for epoch in range(self.start_epoch, self.args.max_epoch):
            if self.args.distributed:
                self.train_loader.batch_sampler.sampler.set_epoch(epoch)
            # check second stage
            if epoch >= (self.args.max_epoch - self.second_stage_epoch - 1) and not self.second_stage:
                self.check_second_stage()
                # save model of the last mosaic epoch
                weight_name = '{}_last_mosaic_epoch.pth'.format(self.args.model)
                checkpoint_path = os.path.join(self.path_to_save, weight_name)
                print('Saving state of the last Mosaic epoch-{}.'.format(self.epoch + 1))
                torch.save({'model': model.state_dict(),
                            'mAP': round(self.evaluator.map*100, 1),
                            'optimizer': self.optimizer.state_dict(),
                            'epoch': self.epoch,
                            'args': self.args},
                           checkpoint_path)
            # check third stage
            if epoch >= (self.args.max_epoch - self.third_stage_epoch - 1) and not self.third_stage:
                self.check_third_stage()
                # save model of the last weak-augment epoch
                weight_name = '{}_last_weak_augment_epoch.pth'.format(self.args.model)
                checkpoint_path = os.path.join(self.path_to_save, weight_name)
                print('Saving state of the last weak augment epoch-{}.'.format(self.epoch + 1))
                torch.save({'model': model.state_dict(),
                            'mAP': round(self.evaluator.map*100, 1),
                            'optimizer': self.optimizer.state_dict(),
                            'epoch': self.epoch,
                            'args': self.args},
                           checkpoint_path)
            # train one epoch
            self.epoch = epoch
            self.train_one_epoch(model)
            # eval one epoch
            if self.heavy_eval:
                model_eval = model.module if self.args.distributed else model
                self.eval(model_eval)
            else:
                model_eval = model.module if self.args.distributed else model
                if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
                    self.eval(model_eval)

    def eval(self, model):
        """Evaluate on the main process (preferring EMA weights) and save a
        checkpoint when mAP improves; other ranks wait at a barrier."""
        # check model: use EMA weights when available
        model_eval = model if self.model_ema is None else self.model_ema.ema
        if distributed_utils.is_main_process():
            # check evaluator
            if self.evaluator is None:
                print('No evaluator ... save model and go on training.')
                print('Saving state, epoch: {}'.format(self.epoch + 1))
                weight_name = '{}_no_eval.pth'.format(self.args.model)
                checkpoint_path = os.path.join(self.path_to_save, weight_name)
                torch.save({'model': model_eval.state_dict(),
                            'mAP': -1.,
                            'optimizer': self.optimizer.state_dict(),
                            'epoch': self.epoch,
                            'args': self.args},
                           checkpoint_path)
            else:
                print('eval ...')
                # set eval mode
                model_eval.trainable = False
                model_eval.eval()
                # evaluate
                with torch.no_grad():
                    self.evaluator.evaluate(model_eval)
                # save model when the mAP improves
                cur_map = self.evaluator.map
                if cur_map > self.best_map:
                    # update best-map
                    self.best_map = cur_map
                    # save model
                    print('Saving state, epoch:', self.epoch + 1)
                    weight_name = '{}_best.pth'.format(self.args.model)
                    checkpoint_path = os.path.join(self.path_to_save, weight_name)
                    torch.save({'model': model_eval.state_dict(),
                                'mAP': round(self.best_map*100, 1),
                                'optimizer': self.optimizer.state_dict(),
                                'epoch': self.epoch,
                                'args': self.args},
                               checkpoint_path)
                # set train mode.
                model_eval.trainable = True
                model_eval.train()
        if self.args.distributed:
            # wait for all processes to synchronize
            dist.barrier()

    def train_one_epoch(self, model):
        """One full pass over the train loader with warmup, optional
        multi-scale resizing, box normalization, AMP forward/backward,
        gradient accumulation, clipping, EMA updates and periodic logging."""
        # basic parameters
        epoch_size = len(self.train_loader)
        img_size = self.args.img_size
        t0 = time.time()
        nw = epoch_size * self.args.wp_epoch  # number of warmup iterations
        # Train one epoch
        for iter_i, (images, targets) in enumerate(self.train_loader):
            ni = iter_i + self.epoch * epoch_size  # global iteration counter
            # Warmup: all LRs rise linearly from 0.0 to their scheduled value
            if ni <= nw:
                xi = [0, nw]  # x interp
                for j, x in enumerate(self.optimizer.param_groups):
                    x['lr'] = np.interp( ni, xi, [0.0, x['initial_lr'] * self.lf(self.epoch)])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xi, [self.warmup_dict['warmup_momentum'], self.optimizer_dict['momentum']])
            # To device (uint8 [0,255] -> float [0,1])
            images = images.to(self.device, non_blocking=True).float() / 255.
            # Multi scale
            if self.args.multi_scale:
                images, targets, img_size = self.rescale_image_targets(
                    images, targets, self.model_cfg['max_stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
            else:
                targets = self.refine_targets(targets, self.args.min_box_size)
            # Normalize bbox to [0, 1] (DETR-style targets)
            targets = self.normalize_bbox(targets, img_size)
            # Visualize train targets
            if self.args.vis_tgt:
                targets = self.denormalize_bbox(targets, img_size)
                vis_data(images*255, targets)
                # NOTE(review): targets are NOT re-normalized after the
                # visualization above, so the loss then sees absolute
                # coordinates when --vis_tgt is set — confirm intended.
            # Inference under autocast (mixed precision when fp16 is enabled)
            with torch.cuda.amp.autocast(enabled=self.args.fp16):
                outputs = model(images)
                # Compute loss
                loss_dict = self.criterion(outputs=outputs, targets=targets, epoch=self.epoch)
                losses = loss_dict['losses']
                # Grad Accumulate: scale loss so accumulated grads average out
                if self.grad_accumulate > 1:
                    losses /= self.grad_accumulate
                loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
            # Backward
            self.scaler.scale(losses).backward()
            # Optimize: step only every grad_accumulate iterations
            if ni % self.grad_accumulate == 0:
                grad_norm = None
                if self.clip_grad > 0:
                    # unscale gradients
                    self.scaler.unscale_(self.optimizer)
                    # clip gradients
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.clip_grad)
                # optimizer.step
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()
                # ema
                if self.model_ema is not None:
                    self.model_ema.update(model)
            # Logs (main process, every 10 iterations)
            if distributed_utils.is_main_process() and iter_i % 10 == 0:
                t1 = time.time()
                cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
                # basic info
                log = '[Epoch: {}/{}]'.format(self.epoch+1, self.args.max_epoch)
                log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
                log += '[lr: {:.6f}]'.format(cur_lr[0])
                # loss info (undo the grad-accumulate scaling for display)
                for k in loss_dict_reduced.keys():
                    loss_val = loss_dict_reduced[k]
                    if k == 'losses':
                        loss_val *= self.grad_accumulate
                    log += '[{}: {:.2f}]'.format(k, loss_val)
                # other info
                # NOTE(review): grad_norm would be None here if clip_grad <= 0,
                # breaking the float format; clip_grad is fixed at 35 above.
                log += '[grad_norm: {:.2f}]'.format(grad_norm)
                log += '[time: {:.2f}]'.format(t1 - t0)
                log += '[size: {}]'.format(img_size)
                # print log info
                print(log, flush=True)
                t0 = time.time()
        # LR Schedule (frozen once the second stage starts)
        if not self.second_stage:
            self.lr_scheduler.step()

    def refine_targets(self, targets, min_box_size):
        """Drop boxes (and their labels) whose shorter side is below
        ``min_box_size``; mutates and returns ``targets``."""
        # rescale targets
        for tgt in targets:
            boxes = tgt["boxes"].clone()
            labels = tgt["labels"].clone()
            # refine tgt
            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
            keep = (min_tgt_size >= min_box_size)
            tgt["boxes"] = boxes[keep]
            tgt["labels"] = labels[keep]
        return targets

    def normalize_bbox(self, targets, img_size):
        """Divide box coordinates by ``img_size`` in place (-> [0, 1] range)."""
        for tgt in targets:
            tgt["boxes"] /= img_size
        return targets

    def denormalize_bbox(self, targets, img_size):
        """Multiply box coordinates by ``img_size`` in place (inverse of
        ``normalize_bbox``)."""
        for tgt in targets:
            tgt["boxes"] *= img_size
        return targets

    def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
        """
        Deployed for Multi scale trick.

        Picks a random square size (multiple of the max stride), resizes the
        batch bilinearly and rescales/filters the target boxes to match.
        Returns (images, targets, new_img_size).
        """
        if isinstance(stride, int):
            max_stride = stride
        elif isinstance(stride, list):
            max_stride = max(stride)
        # During training phase, the shape of input image is square.
        old_img_size = images.shape[-1]
        # NOTE(review): randrange needs int bounds on Python >= 3.11; these
        # products are floats — consider wrapping in int().
        new_img_size = random.randrange(old_img_size * multi_scale_range[0], old_img_size * multi_scale_range[1] + max_stride)
        new_img_size = new_img_size // max_stride * max_stride  # size
        if new_img_size / old_img_size != 1:
            # interpolate
            images = torch.nn.functional.interpolate(
                input=images,
                size=new_img_size,
                mode='bilinear',
                align_corners=False)
        # rescale targets
        for tgt in targets:
            boxes = tgt["boxes"].clone()
            labels = tgt["labels"].clone()
            boxes = torch.clamp(boxes, 0, old_img_size)
            # rescale box
            boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
            boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
            # refine tgt
            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
            keep = (min_tgt_size >= min_box_size)
            tgt["boxes"] = boxes[keep]
            tgt["labels"] = labels[keep]
        return images, targets, new_img_size

    def check_second_stage(self):
        """Disable mosaic/mixup and rotation-style augmentations, then
        rebuild the training transform (second training stage)."""
        # set second stage
        print('============== Second stage of Training ==============')
        self.second_stage = True
        # close mosaic augmentation
        if self.train_loader.dataset.mosaic_prob > 0.:
            print(' - Close < Mosaic Augmentation > ...')
            self.train_loader.dataset.mosaic_prob = 0.
            self.heavy_eval = True
        # close mixup augmentation
        if self.train_loader.dataset.mixup_prob > 0.:
            print(' - Close < Mixup Augmentation > ...')
            self.train_loader.dataset.mixup_prob = 0.
            self.heavy_eval = True
        # close rotation augmentation
        if 'degrees' in self.trans_cfg.keys() and self.trans_cfg['degrees'] > 0.0:
            print(' - Close < degress of rotation > ...')
            self.trans_cfg['degrees'] = 0.0
        if 'shear' in self.trans_cfg.keys() and self.trans_cfg['shear'] > 0.0:
            print(' - Close < shear of rotation >...')
            self.trans_cfg['shear'] = 0.0
        if 'perspective' in self.trans_cfg.keys() and self.trans_cfg['perspective'] > 0.0:
            print(' - Close < perspective of rotation > ...')
            self.trans_cfg['perspective'] = 0.0
        # build a new transform for second stage
        print(' - Rebuild transforms ...')
        self.train_transform, self.trans_cfg = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
        self.train_loader.dataset.transform = self.train_transform

    def check_third_stage(self):
        """Disable the remaining random-affine augmentations (translate,
        scale) and rebuild the training transform (third training stage)."""
        # set third stage
        print('============== Third stage of Training ==============')
        self.third_stage = True
        # close random affine
        if 'translate' in self.trans_cfg.keys() and self.trans_cfg['translate'] > 0.0:
            print(' - Close < translate of affine > ...')
            self.trans_cfg['translate'] = 0.0
        if 'scale' in self.trans_cfg.keys():
            print(' - Close < scale of affine >...')
            self.trans_cfg['scale'] = [1.0, 1.0]
        # build a new transform for third stage
        print(' - Rebuild transforms ...')
        self.train_transform, self.trans_cfg = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
        self.train_loader.dataset.transform = self.train_transform
  1273. # Build Trainer
  1274. def build_trainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
  1275. if model_cfg['trainer_type'] == 'yolov8':
  1276. return Yolov8Trainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
  1277. elif model_cfg['trainer_type'] == 'yolox':
  1278. return YoloxTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
  1279. elif model_cfg['trainer_type'] == 'rtcdet':
  1280. return RTCTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
  1281. elif model_cfg['trainer_type'] == 'rtrdet':
  1282. return RTRTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
  1283. else:
  1284. raise NotImplementedError