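# engine.py
# Training engines for this detection repo: a YOLOv8-style trainer (SGD with
# warmup and gradient accumulation), an RTMDet-style trainer (AdamW, an
# optimizer step every iteration), and a DETR-style trainer (AdamW, targets in
# normalized cxcywh format), plus the build_trainer() factory that picks one
# based on model_cfg['trainer_type'].
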
import torch
import torch.distributed as dist
import time
import os
import numpy as np
import random

# ----------------- Extra Components -----------------
from utils import distributed_utils
from utils.misc import ModelEMA, CollateFunc, build_dataloader
from utils.vis_tools import vis_data

# ----------------- Evaluator Components -----------------
# note: 'build_evluator' matches the name exported by evaluator/build.py
from evaluator.build import build_evluator

# ----------------- Optimizer & LrScheduler Components -----------------
from utils.solver.optimizer import build_yolo_optimizer, build_detr_optimizer
from utils.solver.lr_scheduler import build_lr_scheduler

# ----------------- Dataset Components -----------------
from dataset.build import build_dataset, build_transform


# Trainer modeled after YOLOv8
class YoloTrainer(object):
    def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
        # ------------------- basic parameters -------------------
        self.args = args
        self.epoch = 0
        self.best_map = -1.
        self.last_opt_step = 0
        self.no_aug_epoch = args.no_aug_epoch
        self.clip_grad = 10
        self.device = device
        self.criterion = criterion
        self.world_size = world_size
        self.heavy_eval = False
        self.second_stage = False
        self.optimizer_dict = {'optimizer': 'sgd', 'momentum': 0.937, 'weight_decay': 5e-4, 'lr0': 0.01}
        self.ema_dict = {'ema_decay': 0.9999, 'ema_tau': 2000}
        self.lr_schedule_dict = {'scheduler': 'linear', 'lrf': 0.01}
        self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}

        # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
        self.data_cfg = data_cfg
        self.model_cfg = model_cfg
        self.trans_cfg = trans_cfg

        # ---------------------------- Build Transform ----------------------------
        self.train_transform, self.trans_cfg = build_transform(
            args=args, trans_config=self.trans_cfg, max_stride=model_cfg['max_stride'], is_train=True)
        self.val_transform, _ = build_transform(
            args=args, trans_config=self.trans_cfg, max_stride=model_cfg['max_stride'], is_train=False)

        # ---------------------------- Build Dataset & Dataloader ----------------------------
        self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
        self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())

        # ---------------------------- Build Evaluator ----------------------------
        self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)

        # ---------------------------- Build Grad. Scaler ----------------------------
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)

        # ---------------------------- Build Optimizer ----------------------------
        accumulate = max(1, round(64 / self.args.batch_size))
        self.optimizer_dict['weight_decay'] *= self.args.batch_size * accumulate / 64
        self.optimizer, self.start_epoch = build_yolo_optimizer(self.optimizer_dict, model, self.args.resume)

        # ---------------------------- Build LR Scheduler ----------------------------
        self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, self.args.max_epoch)
        self.lr_scheduler.last_epoch = self.start_epoch - 1  # do not move
        if self.args.resume:
            self.lr_scheduler.step()

        # ---------------------------- Build Model-EMA ----------------------------
        if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
            print('Build ModelEMA ...')
            self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
        else:
            self.model_ema = None
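
    # NOTE: check_second_stage() implements the late-training "no augmentation"
    # phase: once epoch >= max_epoch - no_aug_epoch - 1, train() calls it to
    # switch off Mosaic/Mixup and the geometric distortions and to rebuild the
    # transform pipeline, so the last epochs fine-tune on clean images (the
    # same trick YOLOX/YOLOv8 use when they "close mosaic" near the end).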
    def check_second_stage(self):
        # set second stage
        print('============== Second stage of Training ==============')
        self.second_stage = True

        # close mosaic augmentation
        if self.train_loader.dataset.mosaic_prob > 0.:
            print(' - Close < Mosaic Augmentation > ...')
            self.train_loader.dataset.mosaic_prob = 0.
            self.heavy_eval = True

        # close mixup augmentation
        if self.train_loader.dataset.mixup_prob > 0.:
            print(' - Close < Mixup Augmentation > ...')
            self.train_loader.dataset.mixup_prob = 0.
            self.heavy_eval = True

        # close rotation / shear / perspective augmentation
        if 'degrees' in self.trans_cfg.keys() and self.trans_cfg['degrees'] > 0.0:
            print(' - Close < degrees of rotation > ...')
            self.trans_cfg['degrees'] = 0.0
        if 'shear' in self.trans_cfg.keys() and self.trans_cfg['shear'] > 0.0:
            print(' - Close < shear > ...')
            self.trans_cfg['shear'] = 0.0
        if 'perspective' in self.trans_cfg.keys() and self.trans_cfg['perspective'] > 0.0:
            print(' - Close < perspective > ...')
            self.trans_cfg['perspective'] = 0.0

        # build a new transform for second stage
        print(' - Rebuild transforms ...')
        self.train_transform, self.trans_cfg = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
        self.train_loader.dataset.transform = self.train_transform

    def train(self, model):
        for epoch in range(self.start_epoch, self.args.max_epoch):
            if self.args.distributed:
                self.train_loader.batch_sampler.sampler.set_epoch(epoch)

            # check second stage
            if epoch >= (self.args.max_epoch - self.no_aug_epoch - 1) and not self.second_stage:
                self.check_second_stage()

            # train one epoch
            self.epoch = epoch
            self.train_one_epoch(model)

            # eval one epoch
            model_eval = model.module if self.args.distributed else model
            if self.heavy_eval:
                self.eval(model_eval)
            elif (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
                self.eval(model_eval)

    def eval(self, model):
        # check model
        model_eval = model if self.model_ema is None else self.model_ema.ema

        # path to save model
        path_to_save = os.path.join(self.args.save_folder, self.args.dataset, self.args.model)
        os.makedirs(path_to_save, exist_ok=True)

        if distributed_utils.is_main_process():
            # check evaluator
            if self.evaluator is None:
                print('No evaluator ... save model and go on training.')
                print('Saving state, epoch: {}'.format(self.epoch + 1))
                weight_name = '{}_no_eval.pth'.format(self.args.model)
                checkpoint_path = os.path.join(path_to_save, weight_name)
                torch.save({'model': model_eval.state_dict(),
                            'mAP': -1.,
                            'optimizer': self.optimizer.state_dict(),
                            'epoch': self.epoch,
                            'args': self.args},
                           checkpoint_path)
            else:
                print('eval ...')
                # set eval mode
                model_eval.trainable = False
                model_eval.eval()

                # evaluate
                with torch.no_grad():
                    self.evaluator.evaluate(model_eval)

                # save model if the current mAP is the best so far
                cur_map = self.evaluator.map
                if cur_map > self.best_map:
                    # update best-map
                    self.best_map = cur_map
                    # save model
                    print('Saving state, epoch:', self.epoch + 1)
                    weight_name = '{}_best.pth'.format(self.args.model)
                    checkpoint_path = os.path.join(path_to_save, weight_name)
                    torch.save({'model': model_eval.state_dict(),
                                'mAP': round(self.best_map*100, 1),
                                'optimizer': self.optimizer.state_dict(),
                                'epoch': self.epoch,
                                'args': self.args},
                               checkpoint_path)

                # set train mode.
                model_eval.trainable = True
                model_eval.train()

        if self.args.distributed:
            # wait for all processes to synchronize
            dist.barrier()
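
    # NOTE: train_one_epoch() combines three tricks:
    #   1. Linear warmup: for the first wp_epoch epochs, each param group's lr
    #      (and SGD momentum) is interpolated via np.interp from its warmup
    #      value to its scheduled value as a function of the iteration index.
    #   2. Gradient accumulation: optimizer.step() runs only every `accumulate`
    #      iterations, keeping the effective batch size near 64.
    #   3. AMP: forward under torch.cuda.amp.autocast, backward via GradScaler.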
    def train_one_epoch(self, model):
        # basic parameters
        epoch_size = len(self.train_loader)
        img_size = self.args.img_size
        t0 = time.time()
        nw = epoch_size * self.args.wp_epoch
        accumulate = max(1, round(64 / self.args.batch_size))

        # train one epoch
        for iter_i, (images, targets) in enumerate(self.train_loader):
            ni = iter_i + self.epoch * epoch_size
            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                accumulate = max(1, np.interp(ni, xi, [1, 64 / self.args.batch_size]).round())
                for j, x in enumerate(self.optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x['lr'] = np.interp(
                        ni, xi, [self.warmup_dict['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xi, [self.warmup_dict['warmup_momentum'], self.optimizer_dict['momentum']])

            # to device
            images = images.to(self.device, non_blocking=True).float() / 255.

            # Multi scale
            if self.args.multi_scale:
                images, targets, img_size = self.rescale_image_targets(
                    images, targets, self.model_cfg['stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
            else:
                targets = self.refine_targets(targets, self.args.min_box_size)

            # visualize train targets
            if self.args.vis_tgt:
                vis_data(images*255, targets)

            # inference
            with torch.cuda.amp.autocast(enabled=self.args.fp16):
                outputs = model(images)
                # loss
                loss_dict = self.criterion(outputs=outputs, targets=targets, epoch=self.epoch)
                losses = loss_dict['losses']
                losses *= images.shape[0]  # loss * bs

                # reduce
                loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)

                # gradient averaged between devices in DDP mode
                losses *= distributed_utils.get_world_size()

            # backward
            self.scaler.scale(losses).backward()

            # Optimize
            if ni - self.last_opt_step >= accumulate:
                if self.clip_grad > 0:
                    # unscale gradients
                    self.scaler.unscale_(self.optimizer)
                    # clip gradients
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.clip_grad)
                # optimizer.step
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()
                # ema
                if self.model_ema is not None:
                    self.model_ema.update(model)
                self.last_opt_step = ni

            # display
            if distributed_utils.is_main_process() and iter_i % 10 == 0:
                t1 = time.time()
                cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
                # basic info
                log = '[Epoch: {}/{}]'.format(self.epoch+1, self.args.max_epoch)
                log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
                log += '[lr: {:.6f}]'.format(cur_lr[2])
                # loss info
                for k in loss_dict_reduced.keys():
                    log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
                # other info
                log += '[time: {:.2f}]'.format(t1 - t0)
                log += '[size: {}]'.format(img_size)
                # print log info
                print(log, flush=True)
                t0 = time.time()

        self.lr_scheduler.step()
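
    # The two helpers below share the same small-box filter: any target whose
    # shorter side is below min_box_size is dropped. rescale_image_targets()
    # additionally implements the multi-scale trick, resizing each batch to a
    # random multiple of the max stride and rescaling boxes to match.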
    def refine_targets(self, targets, min_box_size):
        # filter out tiny targets
        for tgt in targets:
            boxes = tgt["boxes"].clone()
            labels = tgt["labels"].clone()
            # keep boxes whose shorter side is at least min_box_size
            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
            keep = (min_tgt_size >= min_box_size)

            tgt["boxes"] = boxes[keep]
            tgt["labels"] = labels[keep]

        return targets

    def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
        """
            Deployed for Multi scale trick.
        """
        if isinstance(stride, int):
            max_stride = stride
        elif isinstance(stride, list):
            max_stride = max(stride)

        # During training phase, the shape of input image is square.
        old_img_size = images.shape[-1]
        new_img_size = random.randrange(int(old_img_size * multi_scale_range[0]), int(old_img_size * multi_scale_range[1] + max_stride))
        new_img_size = new_img_size // max_stride * max_stride  # size
        if new_img_size / old_img_size != 1:
            # interpolate
            images = torch.nn.functional.interpolate(
                input=images,
                size=new_img_size,
                mode='bilinear',
                align_corners=False)

        # rescale targets
        for tgt in targets:
            boxes = tgt["boxes"].clone()
            labels = tgt["labels"].clone()
            boxes = torch.clamp(boxes, 0, old_img_size)
            # rescale box
            boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
            boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
            # refine tgt
            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
            keep = (min_tgt_size >= min_box_size)

            tgt["boxes"] = boxes[keep]
            tgt["labels"] = labels[keep]

        return images, targets, new_img_size
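
# NOTE: RTMTrainer differs from YoloTrainer mainly in its optimizer recipe:
# AdamW (lr0 scaled linearly with batch size) instead of SGD, a larger grad
# clip norm (35 vs 10), and an optimizer step on every iteration rather than
# with gradient accumulation.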
# Trainer modeled after RTMDet
class RTMTrainer(object):
    def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
        # ------------------- basic parameters -------------------
        self.args = args
        self.epoch = 0
        self.best_map = -1.
        self.device = device
        self.criterion = criterion
        self.world_size = world_size
        self.no_aug_epoch = args.no_aug_epoch
        self.clip_grad = 35
        self.heavy_eval = False
        self.second_stage = False
        self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 5e-2, 'lr0': 0.001}
        self.ema_dict = {'ema_decay': 0.9998, 'ema_tau': 2000}
        self.lr_schedule_dict = {'scheduler': 'linear', 'lrf': 0.01}
        self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}

        # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
        self.data_cfg = data_cfg
        self.model_cfg = model_cfg
        self.trans_cfg = trans_cfg

        # ---------------------------- Build Transform ----------------------------
        self.train_transform, self.trans_cfg = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
        self.val_transform, _ = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)

        # ---------------------------- Build Dataset & Dataloader ----------------------------
        self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
        self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())

        # ---------------------------- Build Evaluator ----------------------------
        self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)

        # ---------------------------- Build Grad. Scaler ----------------------------
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)

        # ---------------------------- Build Optimizer ----------------------------
        # linearly scale the base lr with batch size (reference batch size: 64)
        self.optimizer_dict['lr0'] *= self.args.batch_size / 64
        self.optimizer, self.start_epoch = build_yolo_optimizer(self.optimizer_dict, model, self.args.resume)

        # ---------------------------- Build LR Scheduler ----------------------------
        self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, self.args.max_epoch)
        self.lr_scheduler.last_epoch = self.start_epoch - 1  # do not move
        if self.args.resume:
            self.lr_scheduler.step()

        # ---------------------------- Build Model-EMA ----------------------------
        if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
            print('Build ModelEMA ...')
            self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
        else:
            self.model_ema = None

    def train(self, model):
        for epoch in range(self.start_epoch, self.args.max_epoch):
            if self.args.distributed:
                self.train_loader.batch_sampler.sampler.set_epoch(epoch)

            # check second stage
            if epoch >= (self.args.max_epoch - self.no_aug_epoch - 1) and not self.second_stage:
                self.check_second_stage()

            # train one epoch
            self.epoch = epoch
            self.train_one_epoch(model)

            # eval one epoch
            model_eval = model.module if self.args.distributed else model
            if self.heavy_eval:
                self.eval(model_eval)
            elif (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
                self.eval(model_eval)

    def eval(self, model):
        # check model
        model_eval = model if self.model_ema is None else self.model_ema.ema

        # path to save model
        path_to_save = os.path.join(self.args.save_folder, self.args.dataset, self.args.model)
        os.makedirs(path_to_save, exist_ok=True)

        if distributed_utils.is_main_process():
            # check evaluator
            if self.evaluator is None:
                print('No evaluator ... save model and go on training.')
                print('Saving state, epoch: {}'.format(self.epoch + 1))
                weight_name = '{}_no_eval.pth'.format(self.args.model)
                checkpoint_path = os.path.join(path_to_save, weight_name)
                torch.save({'model': model_eval.state_dict(),
                            'mAP': -1.,
                            'optimizer': self.optimizer.state_dict(),
                            'epoch': self.epoch,
                            'args': self.args},
                           checkpoint_path)
            else:
                print('eval ...')
                # set eval mode
                model_eval.trainable = False
                model_eval.eval()

                # evaluate
                with torch.no_grad():
                    self.evaluator.evaluate(model_eval)

                # save model if the current mAP is the best so far
                cur_map = self.evaluator.map
                if cur_map > self.best_map:
                    # update best-map
                    self.best_map = cur_map
                    # save model
                    print('Saving state, epoch:', self.epoch + 1)
                    weight_name = '{}_best.pth'.format(self.args.model)
                    checkpoint_path = os.path.join(path_to_save, weight_name)
                    torch.save({'model': model_eval.state_dict(),
                                'mAP': round(self.best_map*100, 1),
                                'optimizer': self.optimizer.state_dict(),
                                'epoch': self.epoch,
                                'args': self.args},
                               checkpoint_path)

                # set train mode.
                model_eval.trainable = True
                model_eval.train()

        if self.args.distributed:
            # wait for all processes to synchronize
            dist.barrier()

    def train_one_epoch(self, model):
        # basic parameters
        epoch_size = len(self.train_loader)
        img_size = self.args.img_size
        t0 = time.time()
        nw = epoch_size * self.args.wp_epoch

        # Train one epoch
        for iter_i, (images, targets) in enumerate(self.train_loader):
            ni = iter_i + self.epoch * epoch_size
            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                for j, x in enumerate(self.optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x['lr'] = np.interp(
                        ni, xi, [self.warmup_dict['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
                    # AdamW param groups carry no 'momentum' key, so this branch
                    # is inert here; the None-check keeps it safe regardless.
                    if 'momentum' in x and self.optimizer_dict['momentum'] is not None:
                        x['momentum'] = np.interp(ni, xi, [self.warmup_dict['warmup_momentum'], self.optimizer_dict['momentum']])

            # To device
            images = images.to(self.device, non_blocking=True).float() / 255.

            # Multi scale
            if self.args.multi_scale:
                images, targets, img_size = self.rescale_image_targets(
                    images, targets, self.model_cfg['stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
            else:
                targets = self.refine_targets(targets, self.args.min_box_size)

            # Visualize train targets
            if self.args.vis_tgt:
                vis_data(images*255, targets)

            # Inference
            with torch.cuda.amp.autocast(enabled=self.args.fp16):
                outputs = model(images)
                # Compute loss
                loss_dict = self.criterion(outputs=outputs, targets=targets, epoch=self.epoch)
                losses = loss_dict['losses']
                loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)

            # Backward
            self.scaler.scale(losses).backward()

            # Optimize (every iteration; no gradient accumulation here)
            if self.clip_grad > 0:
                # unscale gradients
                self.scaler.unscale_(self.optimizer)
                # clip gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.clip_grad)
            # optimizer.step
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad()

            # ema
            if self.model_ema is not None:
                self.model_ema.update(model)

            # Logs
            if distributed_utils.is_main_process() and iter_i % 10 == 0:
                t1 = time.time()
                cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
                # basic info
                log = '[Epoch: {}/{}]'.format(self.epoch+1, self.args.max_epoch)
                log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
                log += '[lr: {:.6f}]'.format(cur_lr[2])
                # loss info
                for k in loss_dict_reduced.keys():
                    log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
                # other info
                log += '[time: {:.2f}]'.format(t1 - t0)
                log += '[size: {}]'.format(img_size)
                # print log info
                print(log, flush=True)
                t0 = time.time()

        # LR Schedule
        self.lr_scheduler.step()

    def check_second_stage(self):
        # set second stage
        print('============== Second stage of Training ==============')
        self.second_stage = True

        # close mosaic augmentation
        if self.train_loader.dataset.mosaic_prob > 0.:
            print(' - Close < Mosaic Augmentation > ...')
            self.train_loader.dataset.mosaic_prob = 0.
            self.heavy_eval = True

        # close mixup augmentation
        if self.train_loader.dataset.mixup_prob > 0.:
            print(' - Close < Mixup Augmentation > ...')
            self.train_loader.dataset.mixup_prob = 0.
            self.heavy_eval = True

        # close rotation / shear / perspective augmentation
        if 'degrees' in self.trans_cfg.keys() and self.trans_cfg['degrees'] > 0.0:
            print(' - Close < degrees of rotation > ...')
            self.trans_cfg['degrees'] = 0.0
        if 'shear' in self.trans_cfg.keys() and self.trans_cfg['shear'] > 0.0:
            print(' - Close < shear > ...')
            self.trans_cfg['shear'] = 0.0
        if 'perspective' in self.trans_cfg.keys() and self.trans_cfg['perspective'] > 0.0:
            print(' - Close < perspective > ...')
            self.trans_cfg['perspective'] = 0.0

        # build a new transform for second stage
        print(' - Rebuild transforms ...')
        self.train_transform, self.trans_cfg = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
        self.train_loader.dataset.transform = self.train_transform

    def refine_targets(self, targets, min_box_size):
        # filter out tiny targets
        for tgt in targets:
            boxes = tgt["boxes"].clone()
            labels = tgt["labels"].clone()
            # keep boxes whose shorter side is at least min_box_size
            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
            keep = (min_tgt_size >= min_box_size)

            tgt["boxes"] = boxes[keep]
            tgt["labels"] = labels[keep]

        return targets

    def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
        """
            Deployed for Multi scale trick.
        """
        if isinstance(stride, int):
            max_stride = stride
        elif isinstance(stride, list):
            max_stride = max(stride)

        # During training phase, the shape of input image is square.
        old_img_size = images.shape[-1]
        new_img_size = random.randrange(int(old_img_size * multi_scale_range[0]), int(old_img_size * multi_scale_range[1] + max_stride))
        new_img_size = new_img_size // max_stride * max_stride  # size
        if new_img_size / old_img_size != 1:
            # interpolate
            images = torch.nn.functional.interpolate(
                input=images,
                size=new_img_size,
                mode='bilinear',
                align_corners=False)

        # rescale targets
        for tgt in targets:
            boxes = tgt["boxes"].clone()
            labels = tgt["labels"].clone()
            boxes = torch.clamp(boxes, 0, old_img_size)
            # rescale box
            boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
            boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
            # refine tgt
            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
            keep = (min_tgt_size >= min_box_size)

            tgt["boxes"] = boxes[keep]
            tgt["labels"] = labels[keep]

        return images, targets, new_img_size
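
# NOTE: DetrTrainer deviates further: AdamW with lr0 scaled against a reference
# batch size of 16, gradient clipping disabled by default (clip_grad = -1),
# plain 0 -> lr warmup for every param group, and targets converted to the
# normalized cxcywh format that a DETR-style criterion expects.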
# Trainer for DETR
class DetrTrainer(object):
    def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
        # ------------------- basic parameters -------------------
        self.args = args
        self.epoch = 0
        self.best_map = -1.
        self.last_opt_step = 0
        self.no_aug_epoch = args.no_aug_epoch
        self.clip_grad = -1
        self.device = device
        self.criterion = criterion
        self.world_size = world_size
        self.second_stage = False
        self.heavy_eval = False
        self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 1e-4, 'lr0': 0.0001}
        self.ema_dict = {'ema_decay': 0.9998, 'ema_tau': 2000}
        self.lr_schedule_dict = {'scheduler': 'linear', 'lrf': 0.1}
        self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}

        # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
        self.data_cfg = data_cfg
        self.model_cfg = model_cfg
        self.trans_cfg = trans_cfg

        # ---------------------------- Build Transform ----------------------------
        self.train_transform, self.trans_cfg = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
        self.val_transform, _ = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)

        # ---------------------------- Build Dataset & Dataloader ----------------------------
        self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
        self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())

        # ---------------------------- Build Evaluator ----------------------------
        self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)

        # ---------------------------- Build Grad. Scaler ----------------------------
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)

        # ---------------------------- Build Optimizer ----------------------------
        self.optimizer_dict['lr0'] *= self.args.batch_size / 16.
        self.optimizer, self.start_epoch = build_detr_optimizer(self.optimizer_dict, model, self.args.resume)

        # ---------------------------- Build LR Scheduler ----------------------------
        self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, self.args.max_epoch)
        self.lr_scheduler.last_epoch = self.start_epoch - 1  # do not move
        if self.args.resume:
            self.lr_scheduler.step()

        # ---------------------------- Build Model-EMA ----------------------------
        if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
            print('Build ModelEMA ...')
            self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
        else:
            self.model_ema = None

    def check_second_stage(self):
        # set second stage
        print('============== Second stage of Training ==============')
        self.second_stage = True

        # close mosaic augmentation
        if self.train_loader.dataset.mosaic_prob > 0.:
            print(' - Close < Mosaic Augmentation > ...')
            self.train_loader.dataset.mosaic_prob = 0.
            self.heavy_eval = True

        # close mixup augmentation
        if self.train_loader.dataset.mixup_prob > 0.:
            print(' - Close < Mixup Augmentation > ...')
            self.train_loader.dataset.mixup_prob = 0.
            self.heavy_eval = True

        # close rotation / shear / perspective augmentation
        if 'degrees' in self.trans_cfg.keys() and self.trans_cfg['degrees'] > 0.0:
            print(' - Close < degrees of rotation > ...')
            self.trans_cfg['degrees'] = 0.0
        if 'shear' in self.trans_cfg.keys() and self.trans_cfg['shear'] > 0.0:
            print(' - Close < shear > ...')
            self.trans_cfg['shear'] = 0.0
        if 'perspective' in self.trans_cfg.keys() and self.trans_cfg['perspective'] > 0.0:
            print(' - Close < perspective > ...')
            self.trans_cfg['perspective'] = 0.0

        # build a new transform for second stage
        print(' - Rebuild transforms ...')
        self.train_transform, self.trans_cfg = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
        self.train_loader.dataset.transform = self.train_transform

    def train(self, model):
        for epoch in range(self.start_epoch, self.args.max_epoch):
            if self.args.distributed:
                self.train_loader.batch_sampler.sampler.set_epoch(epoch)

            # check second stage
            if epoch >= (self.args.max_epoch - self.no_aug_epoch - 1) and not self.second_stage:
                self.check_second_stage()

            # train one epoch
            self.epoch = epoch
            self.train_one_epoch(model)

            # eval one epoch
            model_eval = model.module if self.args.distributed else model
            if self.heavy_eval:
                self.eval(model_eval)
            elif (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
                self.eval(model_eval)

    def eval(self, model):
        # check model
        model_eval = model if self.model_ema is None else self.model_ema.ema

        # path to save model
        path_to_save = os.path.join(self.args.save_folder, self.args.dataset, self.args.model)
        os.makedirs(path_to_save, exist_ok=True)

        if distributed_utils.is_main_process():
            # check evaluator
            if self.evaluator is None:
                print('No evaluator ... save model and go on training.')
                print('Saving state, epoch: {}'.format(self.epoch + 1))
                weight_name = '{}_no_eval.pth'.format(self.args.model)
                checkpoint_path = os.path.join(path_to_save, weight_name)
                torch.save({'model': model_eval.state_dict(),
                            'mAP': -1.,
                            'optimizer': self.optimizer.state_dict(),
                            'epoch': self.epoch,
                            'args': self.args},
                           checkpoint_path)
            else:
                print('eval ...')
                # set eval mode
                model_eval.trainable = False
                model_eval.eval()

                # evaluate
                with torch.no_grad():
                    self.evaluator.evaluate(model_eval)

                # save model if the current mAP is the best so far
                cur_map = self.evaluator.map
                if cur_map > self.best_map:
                    # update best-map
                    self.best_map = cur_map
                    # save model
                    print('Saving state, epoch:', self.epoch + 1)
                    weight_name = '{}_best.pth'.format(self.args.model)
                    checkpoint_path = os.path.join(path_to_save, weight_name)
                    torch.save({'model': model_eval.state_dict(),
                                'mAP': round(self.best_map*100, 1),
                                'optimizer': self.optimizer.state_dict(),
                                'epoch': self.epoch,
                                'args': self.args},
                               checkpoint_path)

                # set train mode.
                model_eval.trainable = True
                model_eval.train()

        if self.args.distributed:
            # wait for all processes to synchronize
            dist.barrier()

    def train_one_epoch(self, model):
        # basic parameters
        epoch_size = len(self.train_loader)
        img_size = self.args.img_size
        t0 = time.time()
        nw = epoch_size * self.args.wp_epoch

        # train one epoch
        for iter_i, (images, targets) in enumerate(self.train_loader):
            ni = iter_i + self.epoch * epoch_size
            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                for j, x in enumerate(self.optimizer.param_groups):
                    # all lrs rise linearly from 0.0 to the scheduled lr
                    x['lr'] = np.interp(
                        ni, xi, [0.0, x['initial_lr'] * self.lf(self.epoch)])
                    # AdamW param groups carry no 'momentum' key, so this is inert;
                    # the None-check keeps it safe regardless.
                    if 'momentum' in x and self.optimizer_dict['momentum'] is not None:
                        x['momentum'] = np.interp(ni, xi, [self.warmup_dict['warmup_momentum'], self.optimizer_dict['momentum']])

            # To device
            images = images.to(self.device, non_blocking=True).float() / 255.

            # Multi scale
            if self.args.multi_scale:
                images, targets, img_size = self.rescale_image_targets(
                    images, targets, self.model_cfg['stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
            else:
                targets = self.refine_targets(targets, self.args.min_box_size, img_size)

            # Visualize targets
            if self.args.vis_tgt:
                vis_data(images*255, targets)

            # Inference
            with torch.cuda.amp.autocast(enabled=self.args.fp16):
                outputs = model(images)
                # Compute loss
                loss_dict = self.criterion(outputs=outputs, targets=targets, epoch=self.epoch)
                losses = loss_dict['losses']
                loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)

            # Backward
            self.scaler.scale(losses).backward()

            # Optimize
            if self.clip_grad > 0:
                # unscale gradients
                self.scaler.unscale_(self.optimizer)
                # clip gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.clip_grad)
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad()

            # Model EMA
            if self.model_ema is not None:
                self.model_ema.update(model)
            self.last_opt_step = ni

            # Log
            if distributed_utils.is_main_process() and iter_i % 10 == 0:
                t1 = time.time()
                cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
                # basic info
                log = '[Epoch: {}/{}]'.format(self.epoch+1, self.args.max_epoch)
                log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
                log += '[lr: {:.6f}]'.format(cur_lr[0])
                # loss info
                for k in loss_dict_reduced.keys():
                    if self.args.vis_aux_loss:
                        log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
                    elif k in ['loss_cls', 'loss_bbox', 'loss_giou', 'losses']:
                        log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
                # other info
                log += '[time: {:.2f}]'.format(t1 - t0)
                log += '[size: {}]'.format(img_size)
                # print log info
                print(log, flush=True)
                t0 = time.time()

        # LR Scheduler
        self.lr_scheduler.step()
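
    # NOTE: unlike the YOLO/RTMDet trainers, the DETR criterion expects boxes
    # as normalized (cx, cy, w, h), so both helpers below convert the absolute
    # xyxy boxes and divide by the image size before matching.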
    def refine_targets(self, targets, min_box_size, img_size):
        # filter targets and convert boxes to normalized cxcywh
        for tgt in targets:
            boxes = tgt["boxes"]
            labels = tgt["labels"]
            # keep boxes whose shorter side is at least min_box_size
            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
            keep = (min_tgt_size >= min_box_size)
            # xyxy -> cxcywh
            new_boxes = torch.zeros_like(boxes)
            new_boxes[..., :2] = (boxes[..., 2:] + boxes[..., :2]) * 0.5
            new_boxes[..., 2:] = (boxes[..., 2:] - boxes[..., :2])
            # normalize
            new_boxes /= img_size
            del boxes

            tgt["boxes"] = new_boxes[keep]
            tgt["labels"] = labels[keep]

        return targets

    def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
        """
            Deployed for Multi scale trick.
        """
        if isinstance(stride, int):
            max_stride = stride
        elif isinstance(stride, list):
            max_stride = max(stride)

        # During training phase, the shape of input image is square.
        old_img_size = images.shape[-1]
        new_img_size = random.randrange(int(old_img_size * multi_scale_range[0]), int(old_img_size * multi_scale_range[1] + max_stride))
        new_img_size = new_img_size // max_stride * max_stride  # size
        if new_img_size / old_img_size != 1:
            # interpolate
            images = torch.nn.functional.interpolate(
                input=images,
                size=new_img_size,
                mode='bilinear',
                align_corners=False)

        # rescale targets
        for tgt in targets:
            boxes = tgt["boxes"].clone()
            labels = tgt["labels"].clone()
            boxes = torch.clamp(boxes, 0, old_img_size)
            # rescale box
            boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
            boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
            # refine tgt
            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
            keep = (min_tgt_size >= min_box_size)
            # xyxy -> cxcywh
            new_boxes = torch.zeros_like(boxes)
            new_boxes[..., :2] = (boxes[..., 2:] + boxes[..., :2]) * 0.5
            new_boxes[..., 2:] = (boxes[..., 2:] - boxes[..., :2])
            # normalize
            new_boxes /= new_img_size
            del boxes

            tgt["boxes"] = new_boxes[keep]
            tgt["labels"] = labels[keep]

        return images, targets, new_img_size

# Build Trainer
def build_trainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
    if model_cfg['trainer_type'] == 'yolo':
        return YoloTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
    elif model_cfg['trainer_type'] == 'rtmdet':
        return RTMTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
    elif model_cfg['trainer_type'] == 'detr':
        return DetrTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
    else:
        raise NotImplementedError('Unknown trainer type: {}'.format(model_cfg['trainer_type']))
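
# A minimal usage sketch (assumes an `args` namespace produced by the project's
# own training script and cfg/model/criterion objects built by its factories;
# the field values below are illustrative, not part of this module):
#
#   model_cfg = {'trainer_type': 'yolo', 'max_stride': 32,
#                'stride': [8, 16, 32], 'multi_scale': [0.5, 1.5]}
#   trainer = build_trainer(args, data_cfg, model_cfg, trans_cfg,
#                           device, model, criterion, world_size)
#   trainer.train(model)   # runs train_one_epoch + eval until args.max_epoch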