# engine.py — trainers (YOLO / RTMDet / DETR) for this detection project.
import os
import random
import time

import numpy as np
import torch
import torch.distributed as dist

# ----------------- Extra Components -----------------
from utils import distributed_utils
from utils.misc import ModelEMA, CollateFunc, build_dataloader
from utils.vis_tools import vis_data
# ----------------- Evaluator Components -----------------
from evaluator.build import build_evluator
# ----------------- Optimizer & LrScheduler Components -----------------
from utils.solver.optimizer import build_yolo_optimizer, build_detr_optimizer
from utils.solver.lr_scheduler import build_lr_scheduler
# ----------------- Dataset Components -----------------
from dataset.build import build_dataset, build_transform
  18. # Trainer refered to YOLOv8
  19. class YoloTrainer(object):
  20. def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
  21. # ------------------- basic parameters -------------------
  22. self.args = args
  23. self.epoch = 0
  24. self.best_map = -1.
  25. self.last_opt_step = 0
  26. self.no_aug_epoch = args.no_aug_epoch
  27. self.clip_grad = 10
  28. self.device = device
  29. self.criterion = criterion
  30. self.world_size = world_size
  31. self.heavy_eval = False
  32. self.optimizer_dict = {'optimizer': 'sgd', 'momentum': 0.937, 'weight_decay': 5e-4, 'lr0': 0.01}
  33. self.ema_dict = {'ema_decay': 0.9999, 'ema_tau': 2000}
  34. self.lr_schedule_dict = {'scheduler': 'linear', 'lrf': 0.01}
  35. self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}
  36. # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
  37. self.data_cfg = data_cfg
  38. self.model_cfg = model_cfg
  39. self.trans_cfg = trans_cfg
  40. # ---------------------------- Build Transform ----------------------------
  41. self.train_transform, self.trans_cfg = build_transform(
  42. args=args, trans_config=self.trans_cfg, max_stride=model_cfg['max_stride'], is_train=True)
  43. self.val_transform, _ = build_transform(
  44. args=args, trans_config=self.trans_cfg, max_stride=model_cfg['max_stride'], is_train=False)
  45. # ---------------------------- Build Dataset & Dataloader ----------------------------
  46. self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
  47. self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())
  48. # ---------------------------- Build Evaluator ----------------------------
  49. self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)
  50. # ---------------------------- Build Grad. Scaler ----------------------------
  51. self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)
  52. # ---------------------------- Build Optimizer ----------------------------
  53. accumulate = max(1, round(64 / self.args.batch_size))
  54. self.optimizer_dict['weight_decay'] *= self.args.batch_size * accumulate / 64
  55. self.optimizer, self.start_epoch = build_yolo_optimizer(self.optimizer_dict, model, self.args.resume)
  56. # ---------------------------- Build LR Scheduler ----------------------------
  57. self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, self.args.max_epoch)
  58. self.lr_scheduler.last_epoch = self.start_epoch - 1 # do not move
  59. if self.args.resume:
  60. self.lr_scheduler.step()
  61. # ---------------------------- Build Model-EMA ----------------------------
  62. if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
  63. print('Build ModelEMA ...')
  64. self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
  65. else:
  66. self.model_ema = None
  67. def train(self, model):
  68. for epoch in range(self.start_epoch, self.args.max_epoch):
  69. if self.args.distributed:
  70. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  71. # check second stage
  72. if epoch >= (self.args.max_epoch - self.no_aug_epoch - 1):
  73. # close mosaic augmentation
  74. if self.train_loader.dataset.mosaic_prob > 0.:
  75. print('close Mosaic Augmentation ...')
  76. self.train_loader.dataset.mosaic_prob = 0.
  77. self.heavy_eval = True
  78. # close mixup augmentation
  79. if self.train_loader.dataset.mixup_prob > 0.:
  80. print('close Mixup Augmentation ...')
  81. self.train_loader.dataset.mixup_prob = 0.
  82. self.heavy_eval = True
  83. # train one epoch
  84. self.epoch = epoch
  85. self.train_one_epoch(model)
  86. # eval one epoch
  87. if self.heavy_eval:
  88. model_eval = model.module if self.args.distributed else model
  89. self.eval(model_eval)
  90. else:
  91. model_eval = model.module if self.args.distributed else model
  92. if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
  93. self.eval(model_eval)
  94. def eval(self, model):
  95. # chech model
  96. model_eval = model if self.model_ema is None else self.model_ema.ema
  97. # path to save model
  98. path_to_save = os.path.join(self.args.save_folder, self.args.dataset, self.args.model)
  99. os.makedirs(path_to_save, exist_ok=True)
  100. if distributed_utils.is_main_process():
  101. # check evaluator
  102. if self.evaluator is None:
  103. print('No evaluator ... save model and go on training.')
  104. print('Saving state, epoch: {}'.format(self.epoch + 1))
  105. weight_name = '{}_no_eval.pth'.format(self.args.model)
  106. checkpoint_path = os.path.join(path_to_save, weight_name)
  107. torch.save({'model': model_eval.state_dict(),
  108. 'mAP': -1.,
  109. 'optimizer': self.optimizer.state_dict(),
  110. 'epoch': self.epoch,
  111. 'args': self.args},
  112. checkpoint_path)
  113. else:
  114. print('eval ...')
  115. # set eval mode
  116. model_eval.trainable = False
  117. model_eval.eval()
  118. # evaluate
  119. with torch.no_grad():
  120. self.evaluator.evaluate(model_eval)
  121. # save model
  122. cur_map = self.evaluator.map
  123. if cur_map > self.best_map:
  124. # update best-map
  125. self.best_map = cur_map
  126. # save model
  127. print('Saving state, epoch:', self.epoch + 1)
  128. weight_name = '{}_best.pth'.format(self.args.model)
  129. checkpoint_path = os.path.join(path_to_save, weight_name)
  130. torch.save({'model': model_eval.state_dict(),
  131. 'mAP': round(self.best_map*100, 1),
  132. 'optimizer': self.optimizer.state_dict(),
  133. 'epoch': self.epoch,
  134. 'args': self.args},
  135. checkpoint_path)
  136. # set train mode.
  137. model_eval.trainable = True
  138. model_eval.train()
  139. if self.args.distributed:
  140. # wait for all processes to synchronize
  141. dist.barrier()
  142. def train_one_epoch(self, model):
  143. # basic parameters
  144. epoch_size = len(self.train_loader)
  145. img_size = self.args.img_size
  146. t0 = time.time()
  147. nw = epoch_size * self.args.wp_epoch
  148. accumulate = accumulate = max(1, round(64 / self.args.batch_size))
  149. # train one epoch
  150. for iter_i, (images, targets) in enumerate(self.train_loader):
  151. ni = iter_i + self.epoch * epoch_size
  152. # Warmup
  153. if ni <= nw:
  154. xi = [0, nw] # x interp
  155. accumulate = max(1, np.interp(ni, xi, [1, 64 / self.args.batch_size]).round())
  156. for j, x in enumerate(self.optimizer.param_groups):
  157. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  158. x['lr'] = np.interp(
  159. ni, xi, [self.warmup_dict['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
  160. if 'momentum' in x:
  161. x['momentum'] = np.interp(ni, xi, [self.warmup_dict['warmup_momentum'], self.optimizer_dict['momentum']])
  162. # to device
  163. images = images.to(self.device, non_blocking=True).float() / 255.
  164. # Multi scale
  165. if self.args.multi_scale:
  166. images, targets, img_size = self.rescale_image_targets(
  167. images, targets, self.model_cfg['stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
  168. else:
  169. targets = self.refine_targets(targets, self.args.min_box_size)
  170. # visualize train targets
  171. if self.args.vis_tgt:
  172. vis_data(images*255, targets)
  173. # inference
  174. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  175. outputs = model(images)
  176. # loss
  177. loss_dict = self.criterion(outputs=outputs, targets=targets, epoch=self.epoch)
  178. losses = loss_dict['losses']
  179. losses *= images.shape[0] # loss * bs
  180. # reduce
  181. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  182. # gradient averaged between devices in DDP mode
  183. losses *= distributed_utils.get_world_size()
  184. # backward
  185. self.scaler.scale(losses).backward()
  186. # Optimize
  187. if ni - self.last_opt_step >= accumulate:
  188. if self.clip_grad > 0:
  189. # unscale gradients
  190. self.scaler.unscale_(self.optimizer)
  191. # clip gradients
  192. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.clip_grad)
  193. # optimizer.step
  194. self.scaler.step(self.optimizer)
  195. self.scaler.update()
  196. self.optimizer.zero_grad()
  197. # ema
  198. if self.model_ema is not None:
  199. self.model_ema.update(model)
  200. self.last_opt_step = ni
  201. # display
  202. if distributed_utils.is_main_process() and iter_i % 10 == 0:
  203. t1 = time.time()
  204. cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
  205. # basic infor
  206. log = '[Epoch: {}/{}]'.format(self.epoch+1, self.args.max_epoch)
  207. log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
  208. log += '[lr: {:.6f}]'.format(cur_lr[2])
  209. # loss infor
  210. for k in loss_dict_reduced.keys():
  211. log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
  212. # other infor
  213. log += '[time: {:.2f}]'.format(t1 - t0)
  214. log += '[size: {}]'.format(img_size)
  215. # print log infor
  216. print(log, flush=True)
  217. t0 = time.time()
  218. self.lr_scheduler.step()
  219. def refine_targets(self, targets, min_box_size):
  220. # rescale targets
  221. for tgt in targets:
  222. boxes = tgt["boxes"].clone()
  223. labels = tgt["labels"].clone()
  224. # refine tgt
  225. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  226. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  227. keep = (min_tgt_size >= min_box_size)
  228. tgt["boxes"] = boxes[keep]
  229. tgt["labels"] = labels[keep]
  230. return targets
  231. def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
  232. """
  233. Deployed for Multi scale trick.
  234. """
  235. if isinstance(stride, int):
  236. max_stride = stride
  237. elif isinstance(stride, list):
  238. max_stride = max(stride)
  239. # During training phase, the shape of input image is square.
  240. old_img_size = images.shape[-1]
  241. new_img_size = random.randrange(old_img_size * multi_scale_range[0], old_img_size * multi_scale_range[1] + max_stride)
  242. new_img_size = new_img_size // max_stride * max_stride # size
  243. if new_img_size / old_img_size != 1:
  244. # interpolate
  245. images = torch.nn.functional.interpolate(
  246. input=images,
  247. size=new_img_size,
  248. mode='bilinear',
  249. align_corners=False)
  250. # rescale targets
  251. for tgt in targets:
  252. boxes = tgt["boxes"].clone()
  253. labels = tgt["labels"].clone()
  254. boxes = torch.clamp(boxes, 0, old_img_size)
  255. # rescale box
  256. boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
  257. boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
  258. # refine tgt
  259. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  260. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  261. keep = (min_tgt_size >= min_box_size)
  262. tgt["boxes"] = boxes[keep]
  263. tgt["labels"] = labels[keep]
  264. return images, targets, new_img_size
  265. # Trainer refered to RTMDet
  266. class RTMTrainer(object):
  267. def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
  268. # ------------------- basic parameters -------------------
  269. self.args = args
  270. self.epoch = 0
  271. self.best_map = -1.
  272. self.device = device
  273. self.criterion = criterion
  274. self.world_size = world_size
  275. self.no_aug_epoch = args.no_aug_epoch
  276. self.clip_grad = 35
  277. self.heavy_eval = False
  278. self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 5e-2, 'lr0': 0.001}
  279. self.ema_dict = {'ema_decay': 0.9998, 'ema_tau': 2000}
  280. self.lr_schedule_dict = {'scheduler': 'linear', 'lrf': 0.01}
  281. self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}
  282. # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
  283. self.data_cfg = data_cfg
  284. self.model_cfg = model_cfg
  285. self.trans_cfg = trans_cfg
  286. # ---------------------------- Build Transform ----------------------------
  287. self.train_transform, self.trans_cfg = build_transform(
  288. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
  289. self.val_transform, _ = build_transform(
  290. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)
  291. # ---------------------------- Build Dataset & Dataloader ----------------------------
  292. self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
  293. self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())
  294. # ---------------------------- Build Evaluator ----------------------------
  295. self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)
  296. # ---------------------------- Build Grad. Scaler ----------------------------
  297. self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)
  298. # ---------------------------- Build Optimizer ----------------------------
  299. self.optimizer_dict['lr0'] *= self.args.batch_size / 64
  300. self.optimizer, self.start_epoch = build_yolo_optimizer(self.optimizer_dict, model, self.args.resume)
  301. # ---------------------------- Build LR Scheduler ----------------------------
  302. self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, self.args.max_epoch)
  303. self.lr_scheduler.last_epoch = self.start_epoch - 1 # do not move
  304. if self.args.resume:
  305. self.lr_scheduler.step()
  306. # ---------------------------- Build Model-EMA ----------------------------
  307. if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
  308. print('Build ModelEMA ...')
  309. self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
  310. else:
  311. self.model_ema = None
  312. def train(self, model):
  313. for epoch in range(self.start_epoch, self.args.max_epoch):
  314. if self.args.distributed:
  315. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  316. # check second stage
  317. if epoch >= (self.args.max_epoch - self.no_aug_epoch - 1):
  318. # close mosaic augmentation
  319. if self.train_loader.dataset.mosaic_prob > 0.:
  320. print('close Mosaic Augmentation ...')
  321. self.train_loader.dataset.mosaic_prob = 0.
  322. self.heavy_eval = True
  323. # close mixup augmentation
  324. if self.train_loader.dataset.mixup_prob > 0.:
  325. print('close Mixup Augmentation ...')
  326. self.train_loader.dataset.mixup_prob = 0.
  327. self.heavy_eval = True
  328. # train one epoch
  329. self.epoch = epoch
  330. self.train_one_epoch(model)
  331. # eval one epoch
  332. if self.heavy_eval:
  333. model_eval = model.module if self.args.distributed else model
  334. self.eval(model_eval)
  335. else:
  336. model_eval = model.module if self.args.distributed else model
  337. if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
  338. self.eval(model_eval)
  339. def eval(self, model):
  340. # chech model
  341. model_eval = model if self.model_ema is None else self.model_ema.ema
  342. # path to save model
  343. path_to_save = os.path.join(self.args.save_folder, self.args.dataset, self.args.model)
  344. os.makedirs(path_to_save, exist_ok=True)
  345. if distributed_utils.is_main_process():
  346. # check evaluator
  347. if self.evaluator is None:
  348. print('No evaluator ... save model and go on training.')
  349. print('Saving state, epoch: {}'.format(self.epoch + 1))
  350. weight_name = '{}_no_eval.pth'.format(self.args.model)
  351. checkpoint_path = os.path.join(path_to_save, weight_name)
  352. torch.save({'model': model_eval.state_dict(),
  353. 'mAP': -1.,
  354. 'optimizer': self.optimizer.state_dict(),
  355. 'epoch': self.epoch,
  356. 'args': self.args},
  357. checkpoint_path)
  358. else:
  359. print('eval ...')
  360. # set eval mode
  361. model_eval.trainable = False
  362. model_eval.eval()
  363. # evaluate
  364. with torch.no_grad():
  365. self.evaluator.evaluate(model_eval)
  366. # save model
  367. cur_map = self.evaluator.map
  368. if cur_map > self.best_map:
  369. # update best-map
  370. self.best_map = cur_map
  371. # save model
  372. print('Saving state, epoch:', self.epoch + 1)
  373. weight_name = '{}_best.pth'.format(self.args.model)
  374. checkpoint_path = os.path.join(path_to_save, weight_name)
  375. torch.save({'model': model_eval.state_dict(),
  376. 'mAP': round(self.best_map*100, 1),
  377. 'optimizer': self.optimizer.state_dict(),
  378. 'epoch': self.epoch,
  379. 'args': self.args},
  380. checkpoint_path)
  381. # set train mode.
  382. model_eval.trainable = True
  383. model_eval.train()
  384. if self.args.distributed:
  385. # wait for all processes to synchronize
  386. dist.barrier()
  387. def train_one_epoch(self, model):
  388. # basic parameters
  389. epoch_size = len(self.train_loader)
  390. img_size = self.args.img_size
  391. t0 = time.time()
  392. nw = epoch_size * self.args.wp_epoch
  393. # Train one epoch
  394. for iter_i, (images, targets) in enumerate(self.train_loader):
  395. ni = iter_i + self.epoch * epoch_size
  396. # Warmup
  397. if ni <= nw:
  398. xi = [0, nw] # x interp
  399. for j, x in enumerate(self.optimizer.param_groups):
  400. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  401. x['lr'] = np.interp(
  402. ni, xi, [self.warmup_dict['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
  403. if 'momentum' in x:
  404. x['momentum'] = np.interp(ni, xi, [self.warmup_dict['warmup_momentum'], self.optimizer_dict['momentum']])
  405. # To device
  406. images = images.to(self.device, non_blocking=True).float() / 255.
  407. # Multi scale
  408. if self.args.multi_scale:
  409. images, targets, img_size = self.rescale_image_targets(
  410. images, targets, self.model_cfg['stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
  411. else:
  412. targets = self.refine_targets(targets, self.args.min_box_size)
  413. # Visualize train targets
  414. if self.args.vis_tgt:
  415. vis_data(images*255, targets)
  416. # Inference
  417. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  418. outputs = model(images)
  419. # Compute loss
  420. loss_dict = self.criterion(outputs=outputs, targets=targets, epoch=self.epoch)
  421. losses = loss_dict['losses']
  422. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  423. # Backward
  424. self.scaler.scale(losses).backward()
  425. # Optimize
  426. if self.clip_grad > 0:
  427. # unscale gradients
  428. self.scaler.unscale_(self.optimizer)
  429. # clip gradients
  430. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.clip_grad)
  431. # optimizer.step
  432. self.scaler.step(self.optimizer)
  433. self.scaler.update()
  434. self.optimizer.zero_grad()
  435. # ema
  436. if self.model_ema is not None:
  437. self.model_ema.update(model)
  438. # Logs
  439. if distributed_utils.is_main_process() and iter_i % 10 == 0:
  440. t1 = time.time()
  441. cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
  442. # basic infor
  443. log = '[Epoch: {}/{}]'.format(self.epoch+1, self.args.max_epoch)
  444. log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
  445. log += '[lr: {:.6f}]'.format(cur_lr[2])
  446. # loss infor
  447. for k in loss_dict_reduced.keys():
  448. log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
  449. # other infor
  450. log += '[time: {:.2f}]'.format(t1 - t0)
  451. log += '[size: {}]'.format(img_size)
  452. # print log infor
  453. print(log, flush=True)
  454. t0 = time.time()
  455. # LR Schedule
  456. self.lr_scheduler.step()
  457. def refine_targets(self, targets, min_box_size):
  458. # rescale targets
  459. for tgt in targets:
  460. boxes = tgt["boxes"].clone()
  461. labels = tgt["labels"].clone()
  462. # refine tgt
  463. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  464. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  465. keep = (min_tgt_size >= min_box_size)
  466. tgt["boxes"] = boxes[keep]
  467. tgt["labels"] = labels[keep]
  468. return targets
  469. def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
  470. """
  471. Deployed for Multi scale trick.
  472. """
  473. if isinstance(stride, int):
  474. max_stride = stride
  475. elif isinstance(stride, list):
  476. max_stride = max(stride)
  477. # During training phase, the shape of input image is square.
  478. old_img_size = images.shape[-1]
  479. new_img_size = random.randrange(old_img_size * multi_scale_range[0], old_img_size * multi_scale_range[1] + max_stride)
  480. new_img_size = new_img_size // max_stride * max_stride # size
  481. if new_img_size / old_img_size != 1:
  482. # interpolate
  483. images = torch.nn.functional.interpolate(
  484. input=images,
  485. size=new_img_size,
  486. mode='bilinear',
  487. align_corners=False)
  488. # rescale targets
  489. for tgt in targets:
  490. boxes = tgt["boxes"].clone()
  491. labels = tgt["labels"].clone()
  492. boxes = torch.clamp(boxes, 0, old_img_size)
  493. # rescale box
  494. boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
  495. boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
  496. # refine tgt
  497. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  498. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  499. keep = (min_tgt_size >= min_box_size)
  500. tgt["boxes"] = boxes[keep]
  501. tgt["labels"] = labels[keep]
  502. return images, targets, new_img_size
  503. # Trainer for DETR
  504. class DetrTrainer(object):
  505. def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
  506. # ------------------- basic parameters -------------------
  507. self.args = args
  508. self.epoch = 0
  509. self.best_map = -1.
  510. self.last_opt_step = 0
  511. self.no_aug_epoch = args.no_aug_epoch
  512. self.clip_grad = -1
  513. self.device = device
  514. self.criterion = criterion
  515. self.world_size = world_size
  516. self.heavy_eval = False
  517. self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 1e-4, 'lr0': 0.0001}
  518. self.ema_dict = {'ema_decay': 0.9998, 'ema_tau': 2000}
  519. self.lr_schedule_dict = {'scheduler': 'linear', 'lrf': 0.1}
  520. self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}
  521. # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
  522. self.data_cfg = data_cfg
  523. self.model_cfg = model_cfg
  524. self.trans_cfg = trans_cfg
  525. # ---------------------------- Build Transform ----------------------------
  526. self.train_transform, self.trans_cfg = build_transform(
  527. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
  528. self.val_transform, _ = build_transform(
  529. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)
  530. # ---------------------------- Build Dataset & Dataloader ----------------------------
  531. self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
  532. self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())
  533. # ---------------------------- Build Evaluator ----------------------------
  534. self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)
  535. # ---------------------------- Build Grad. Scaler ----------------------------
  536. self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)
  537. # ---------------------------- Build Optimizer ----------------------------
  538. self.optimizer_dict['lr0'] *= self.args.batch_size / 16.
  539. self.optimizer, self.start_epoch = build_detr_optimizer(self.optimizer_dict, model, self.args.resume)
  540. # ---------------------------- Build LR Scheduler ----------------------------
  541. self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, self.args.max_epoch)
  542. self.lr_scheduler.last_epoch = self.start_epoch - 1 # do not move
  543. if self.args.resume:
  544. self.lr_scheduler.step()
  545. # ---------------------------- Build Model-EMA ----------------------------
  546. if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
  547. print('Build ModelEMA ...')
  548. self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
  549. else:
  550. self.model_ema = None
  551. def train(self, model):
  552. for epoch in range(self.start_epoch, self.args.max_epoch):
  553. if self.args.distributed:
  554. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  555. # check second stage
  556. if epoch >= (self.args.max_epoch - self.no_aug_epoch - 1):
  557. # close mosaic augmentation
  558. if self.train_loader.dataset.mosaic_prob > 0.:
  559. print('close Mosaic Augmentation ...')
  560. self.train_loader.dataset.mosaic_prob = 0.
  561. self.heavy_eval = True
  562. # close mixup augmentation
  563. if self.train_loader.dataset.mixup_prob > 0.:
  564. print('close Mixup Augmentation ...')
  565. self.train_loader.dataset.mixup_prob = 0.
  566. self.heavy_eval = True
  567. # train one epoch
  568. self.epoch = epoch
  569. self.train_one_epoch(model)
  570. # eval one epoch
  571. if self.heavy_eval:
  572. model_eval = model.module if self.args.distributed else model
  573. self.eval(model_eval)
  574. else:
  575. model_eval = model.module if self.args.distributed else model
  576. if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
  577. self.eval(model_eval)
  578. def eval(self, model):
  579. # chech model
  580. model_eval = model if self.model_ema is None else self.model_ema.ema
  581. # path to save model
  582. path_to_save = os.path.join(self.args.save_folder, self.args.dataset, self.args.model)
  583. os.makedirs(path_to_save, exist_ok=True)
  584. if distributed_utils.is_main_process():
  585. # check evaluator
  586. if self.evaluator is None:
  587. print('No evaluator ... save model and go on training.')
  588. print('Saving state, epoch: {}'.format(self.epoch + 1))
  589. weight_name = '{}_no_eval.pth'.format(self.args.model)
  590. checkpoint_path = os.path.join(path_to_save, weight_name)
  591. torch.save({'model': model_eval.state_dict(),
  592. 'mAP': -1.,
  593. 'optimizer': self.optimizer.state_dict(),
  594. 'epoch': self.epoch,
  595. 'args': self.args},
  596. checkpoint_path)
  597. else:
  598. print('eval ...')
  599. # set eval mode
  600. model_eval.trainable = False
  601. model_eval.eval()
  602. # evaluate
  603. with torch.no_grad():
  604. self.evaluator.evaluate(model_eval)
  605. # save model
  606. cur_map = self.evaluator.map
  607. if cur_map > self.best_map:
  608. # update best-map
  609. self.best_map = cur_map
  610. # save model
  611. print('Saving state, epoch:', self.epoch + 1)
  612. weight_name = '{}_best.pth'.format(self.args.model)
  613. checkpoint_path = os.path.join(path_to_save, weight_name)
  614. torch.save({'model': model_eval.state_dict(),
  615. 'mAP': round(self.best_map*100, 1),
  616. 'optimizer': self.optimizer.state_dict(),
  617. 'epoch': self.epoch,
  618. 'args': self.args},
  619. checkpoint_path)
  620. # set train mode.
  621. model_eval.trainable = True
  622. model_eval.train()
  623. if self.args.distributed:
  624. # wait for all processes to synchronize
  625. dist.barrier()
def train_one_epoch(self, model):
    """Run one full training epoch over ``self.train_loader``.

    Per iteration: linear LR/momentum warmup, optional multi-scale
    resizing of the batch, forward + loss under AMP autocast, scaled
    backward, optional gradient clipping, optimizer step, EMA update,
    and periodic console logging. Steps ``self.lr_scheduler`` once at
    the end of the epoch.

    Args:
        model: the (possibly DDP-wrapped) network being trained.
    """
    # basic parameters
    epoch_size = len(self.train_loader)          # iterations per epoch
    img_size = self.args.img_size                # current (square) input size
    t0 = time.time()
    nw = epoch_size * self.args.wp_epoch         # total number of warmup iterations
    # train one epoch
    for iter_i, (images, targets) in enumerate(self.train_loader):
        ni = iter_i + self.epoch * epoch_size    # global iteration counter
        # Warmup: linearly interpolate lr (and momentum) over the first nw iters
        if ni <= nw:
            xi = [0, nw]  # x interp
            for j, x in enumerate(self.optimizer.param_groups):
                # all lrs rise from 0.0 to the scheduled lr (initial_lr * lf(epoch))
                x['lr'] = np.interp(
                    ni, xi, [0.0, x['initial_lr'] * self.lf(self.epoch)])
                if 'momentum' in x:
                    # momentum ramps from warmup_momentum to the configured momentum
                    x['momentum'] = np.interp(ni, xi, [self.model_cfg['warmup_momentum'], self.model_cfg['momentum']])
        # To device; uint8 [0, 255] -> float [0, 1] (assumes loader yields uint8 — TODO confirm)
        images = images.to(self.device, non_blocking=True).float() / 255.
        # Multi scale: randomly resize the whole batch; otherwise only filter/normalize targets
        if self.args.multi_scale:
            images, targets, img_size = self.rescale_image_targets(
                images, targets, self.model_cfg['stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
        else:
            targets = self.refine_targets(targets, self.args.min_box_size, img_size)
        # Visualize targets (debug aid; de-normalizes images back to [0, 255])
        if self.args.vis_tgt:
            vis_data(images*255, targets)
        # Inference under mixed precision when --fp16 is set
        with torch.cuda.amp.autocast(enabled=self.args.fp16):
            outputs = model(images)
            # Compute loss
            loss_dict = self.criterion(outputs=outputs, targets=targets, epoch=self.epoch)
            losses = loss_dict['losses']
            # reduced copy is for logging only; backward uses the local `losses`
            loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
        # Backward (scaled to avoid fp16 underflow)
        self.scaler.scale(losses).backward()
        # Optimize
        if self.clip_grad > 0:
            # unscale gradients in place so the clip threshold is in true units
            self.scaler.unscale_(self.optimizer)
            # clip gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.clip_grad)
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()
        # Model EMA: track an exponential moving average of the weights
        if self.model_ema is not None:
            self.model_ema.update(model)
        self.last_opt_step = ni
        # Log every 10 iterations, main process only
        if distributed_utils.is_main_process() and iter_i % 10 == 0:
            t1 = time.time()
            cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
            # basic info: epoch / iter / lr of the first param group
            log = '[Epoch: {}/{}]'.format(self.epoch+1, self.args.max_epoch)
            log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
            log += '[lr: {:.6f}]'.format(cur_lr[0])
            # loss info: all entries when aux losses are requested, else only the main ones
            for k in loss_dict_reduced.keys():
                if self.args.vis_aux_loss:
                    log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
                else:
                    if k in ['loss_cls', 'loss_bbox', 'loss_giou', 'losses']:
                        log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
            # other info: elapsed wall time since last log, current image size
            log += '[time: {:.2f}]'.format(t1 - t0)
            log += '[size: {}]'.format(img_size)
            # print log info
            print(log, flush=True)
            t0 = time.time()
    # LR Scheduler: one scheduler step per epoch
    self.lr_scheduler.step()
  700. def refine_targets(self, targets, min_box_size, img_size):
  701. # rescale targets
  702. for tgt in targets:
  703. boxes = tgt["boxes"]
  704. labels = tgt["labels"]
  705. # refine tgt
  706. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  707. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  708. keep = (min_tgt_size >= min_box_size)
  709. # xyxy -> cxcywh
  710. new_boxes = torch.zeros_like(boxes)
  711. new_boxes[..., :2] = (boxes[..., 2:] + boxes[..., :2]) * 0.5
  712. new_boxes[..., 2:] = (boxes[..., 2:] - boxes[..., :2])
  713. # normalize
  714. new_boxes /= img_size
  715. del boxes
  716. tgt["boxes"] = new_boxes[keep]
  717. tgt["labels"] = labels[keep]
  718. return targets
  719. def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
  720. """
  721. Deployed for Multi scale trick.
  722. """
  723. if isinstance(stride, int):
  724. max_stride = stride
  725. elif isinstance(stride, list):
  726. max_stride = max(stride)
  727. # During training phase, the shape of input image is square.
  728. old_img_size = images.shape[-1]
  729. new_img_size = random.randrange(old_img_size * multi_scale_range[0], old_img_size * multi_scale_range[1] + max_stride)
  730. new_img_size = new_img_size // max_stride * max_stride # size
  731. if new_img_size / old_img_size != 1:
  732. # interpolate
  733. images = torch.nn.functional.interpolate(
  734. input=images,
  735. size=new_img_size,
  736. mode='bilinear',
  737. align_corners=False)
  738. # rescale targets
  739. for tgt in targets:
  740. boxes = tgt["boxes"].clone()
  741. labels = tgt["labels"].clone()
  742. boxes = torch.clamp(boxes, 0, old_img_size)
  743. # rescale box
  744. boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
  745. boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
  746. # refine tgt
  747. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  748. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  749. keep = (min_tgt_size >= min_box_size)
  750. # xyxy -> cxcywh
  751. new_boxes = torch.zeros_like(boxes)
  752. new_boxes[..., :2] = (boxes[..., 2:] + boxes[..., :2]) * 0.5
  753. new_boxes[..., 2:] = (boxes[..., 2:] - boxes[..., :2])
  754. # normalize
  755. new_boxes /= new_img_size
  756. del boxes
  757. tgt["boxes"] = new_boxes[keep]
  758. tgt["labels"] = labels[keep]
  759. return images, targets, new_img_size
  760. # Build Trainer
  761. def build_trainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
  762. if model_cfg['trainer_type'] == 'yolo':
  763. return YoloTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
  764. elif model_cfg['trainer_type'] == 'rtmdet':
  765. return RTMTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
  766. elif model_cfg['trainer_type'] == 'detr':
  767. return DetrTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
  768. else:
  769. raise NotImplementedError