# engine.py — training engines: YoloTrainer / RTMTrainer / DetrTrainer.
import os
import random
import time

import numpy as np
import torch
import torch.distributed as dist

# ----------------- Extra Components -----------------
from utils import distributed_utils
from utils.misc import ModelEMA, CollateFunc, build_dataloader
from utils.vis_tools import vis_data
# ----------------- Evaluator Components -----------------
from evaluator.build import build_evluator
# ----------------- Optimizer & LrScheduler Components -----------------
from utils.solver.optimizer import build_yolo_optimizer, build_detr_optimizer
from utils.solver.lr_scheduler import build_lr_scheduler
# ----------------- Dataset Components -----------------
from dataset.build import build_dataset, build_transform
  18. # Trainer refered to YOLOv8
  19. class YoloTrainer(object):
  20. def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
  21. # ------------------- basic parameters -------------------
  22. self.args = args
  23. self.epoch = 0
  24. self.best_map = -1.
  25. self.last_opt_step = 0
  26. self.device = device
  27. self.criterion = criterion
  28. self.world_size = world_size
  29. self.heavy_eval = False
  30. self.no_aug_epoch = 20
  31. self.clip_grad = 10
  32. self.optimizer_dict = {'optimizer': 'sgd', 'momentum': 0.937, 'weight_decay': 5e-4, 'lr0': 0.01}
  33. self.ema_dict = {'ema_decay': 0.9999, 'ema_tau': 2000}
  34. self.lr_schedule_dict = {'scheduler': 'linear', 'lrf': 0.01}
  35. self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}
  36. # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
  37. self.data_cfg = data_cfg
  38. self.model_cfg = model_cfg
  39. self.trans_cfg = trans_cfg
  40. # ---------------------------- Build Transform ----------------------------
  41. self.train_transform, self.trans_cfg = build_transform(
  42. args=args, trans_config=self.trans_cfg, max_stride=model_cfg['max_stride'], is_train=True)
  43. self.val_transform, _ = build_transform(
  44. args=args, trans_config=self.trans_cfg, max_stride=model_cfg['max_stride'], is_train=False)
  45. # ---------------------------- Build Dataset & Dataloader ----------------------------
  46. self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
  47. self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())
  48. # ---------------------------- Build Evaluator ----------------------------
  49. self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)
  50. # ---------------------------- Build Grad. Scaler ----------------------------
  51. self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)
  52. # ---------------------------- Build Optimizer ----------------------------
  53. accumulate = max(1, round(64 / self.args.batch_size))
  54. self.optimizer_dict['weight_decay'] *= self.args.batch_size * accumulate / 64
  55. self.optimizer, self.start_epoch = build_yolo_optimizer(self.optimizer_dict, model, self.args.resume)
  56. # ---------------------------- Build LR Scheduler ----------------------------
  57. self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, self.args.max_epoch)
  58. self.lr_scheduler.last_epoch = self.start_epoch - 1 # do not move
  59. if self.args.resume:
  60. self.lr_scheduler.step()
  61. # ---------------------------- Build Model-EMA ----------------------------
  62. if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
  63. print('Build ModelEMA ...')
  64. self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
  65. else:
  66. self.model_ema = None
  67. def train(self, model):
  68. for epoch in range(self.start_epoch, self.args.max_epoch):
  69. if self.args.distributed:
  70. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  71. # check second stage
  72. if epoch >= (self.args.max_epoch - self.no_aug_epoch - 1):
  73. # close mosaic augmentation
  74. if self.train_loader.dataset.mosaic_prob > 0.:
  75. print('close Mosaic Augmentation ...')
  76. self.train_loader.dataset.mosaic_prob = 0.
  77. self.heavy_eval = True
  78. # close mixup augmentation
  79. if self.train_loader.dataset.mixup_prob > 0.:
  80. print('close Mixup Augmentation ...')
  81. self.train_loader.dataset.mixup_prob = 0.
  82. self.heavy_eval = True
  83. # train one epoch
  84. self.epoch = epoch
  85. self.train_one_epoch(model)
  86. # eval one epoch
  87. if self.heavy_eval:
  88. model_eval = model.module if self.args.distributed else model
  89. self.eval(model_eval)
  90. else:
  91. model_eval = model.module if self.args.distributed else model
  92. if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
  93. self.eval(model_eval)
  94. def eval(self, model):
  95. # chech model
  96. model_eval = model if self.model_ema is None else self.model_ema.ema
  97. # path to save model
  98. path_to_save = os.path.join(self.args.save_folder, self.args.dataset, self.args.model)
  99. os.makedirs(path_to_save, exist_ok=True)
  100. if distributed_utils.is_main_process():
  101. # check evaluator
  102. if self.evaluator is None:
  103. print('No evaluator ... save model and go on training.')
  104. print('Saving state, epoch: {}'.format(self.epoch + 1))
  105. weight_name = '{}_no_eval.pth'.format(self.args.model)
  106. checkpoint_path = os.path.join(path_to_save, weight_name)
  107. torch.save({'model': model_eval.state_dict(),
  108. 'mAP': -1.,
  109. 'optimizer': self.optimizer.state_dict(),
  110. 'epoch': self.epoch,
  111. 'args': self.args},
  112. checkpoint_path)
  113. else:
  114. print('eval ...')
  115. # set eval mode
  116. model_eval.trainable = False
  117. model_eval.eval()
  118. # evaluate
  119. with torch.no_grad():
  120. self.evaluator.evaluate(model_eval)
  121. # save model
  122. cur_map = self.evaluator.map
  123. if cur_map > self.best_map:
  124. # update best-map
  125. self.best_map = cur_map
  126. # save model
  127. print('Saving state, epoch:', self.epoch + 1)
  128. weight_name = '{}_best.pth'.format(self.args.model)
  129. checkpoint_path = os.path.join(path_to_save, weight_name)
  130. torch.save({'model': model_eval.state_dict(),
  131. 'mAP': round(self.best_map*100, 1),
  132. 'optimizer': self.optimizer.state_dict(),
  133. 'epoch': self.epoch,
  134. 'args': self.args},
  135. checkpoint_path)
  136. # set train mode.
  137. model_eval.trainable = True
  138. model_eval.train()
  139. if self.args.distributed:
  140. # wait for all processes to synchronize
  141. dist.barrier()
  142. def train_one_epoch(self, model):
  143. # basic parameters
  144. epoch_size = len(self.train_loader)
  145. img_size = self.args.img_size
  146. t0 = time.time()
  147. nw = epoch_size * self.args.wp_epoch
  148. accumulate = accumulate = max(1, round(64 / self.args.batch_size))
  149. # train one epoch
  150. for iter_i, (images, targets) in enumerate(self.train_loader):
  151. ni = iter_i + self.epoch * epoch_size
  152. # Warmup
  153. if ni <= nw:
  154. xi = [0, nw] # x interp
  155. accumulate = max(1, np.interp(ni, xi, [1, 64 / self.args.batch_size]).round())
  156. for j, x in enumerate(self.optimizer.param_groups):
  157. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  158. x['lr'] = np.interp(
  159. ni, xi, [self.warmup_dict['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
  160. if 'momentum' in x:
  161. x['momentum'] = np.interp(ni, xi, [self.warmup_dict['warmup_momentum'], self.optimizer_dict['momentum']])
  162. # to device
  163. images = images.to(self.device, non_blocking=True).float() / 255.
  164. # Multi scale
  165. if self.args.multi_scale:
  166. images, targets, img_size = self.rescale_image_targets(
  167. images, targets, self.model_cfg['stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
  168. else:
  169. targets = self.refine_targets(targets, self.args.min_box_size)
  170. # visualize train targets
  171. if self.args.vis_tgt:
  172. vis_data(images*255, targets)
  173. # inference
  174. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  175. outputs = model(images)
  176. # loss
  177. loss_dict = self.criterion(outputs=outputs, targets=targets)
  178. losses = loss_dict['losses']
  179. losses *= images.shape[0] # loss * bs
  180. # reduce
  181. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  182. # gradient averaged between devices in DDP mode
  183. losses *= distributed_utils.get_world_size()
  184. # backward
  185. self.scaler.scale(losses).backward()
  186. # Optimize
  187. if ni - self.last_opt_step >= accumulate:
  188. if self.clip_grad > 0:
  189. # unscale gradients
  190. self.scaler.unscale_(self.optimizer)
  191. # clip gradients
  192. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.clip_grad)
  193. # optimizer.step
  194. self.scaler.step(self.optimizer)
  195. self.scaler.update()
  196. self.optimizer.zero_grad()
  197. # ema
  198. if self.model_ema is not None:
  199. self.model_ema.update(model)
  200. self.last_opt_step = ni
  201. # display
  202. if distributed_utils.is_main_process() and iter_i % 10 == 0:
  203. t1 = time.time()
  204. cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
  205. # basic infor
  206. log = '[Epoch: {}/{}]'.format(self.epoch+1, self.args.max_epoch)
  207. log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
  208. log += '[lr: {:.6f}]'.format(cur_lr[2])
  209. # loss infor
  210. for k in loss_dict_reduced.keys():
  211. log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
  212. # other infor
  213. log += '[time: {:.2f}]'.format(t1 - t0)
  214. log += '[size: {}]'.format(img_size)
  215. # print log infor
  216. print(log, flush=True)
  217. t0 = time.time()
  218. self.lr_scheduler.step()
  219. def refine_targets(self, targets, min_box_size):
  220. # rescale targets
  221. for tgt in targets:
  222. boxes = tgt["boxes"].clone()
  223. labels = tgt["labels"].clone()
  224. # refine tgt
  225. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  226. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  227. keep = (min_tgt_size >= min_box_size)
  228. tgt["boxes"] = boxes[keep]
  229. tgt["labels"] = labels[keep]
  230. return targets
  231. def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
  232. """
  233. Deployed for Multi scale trick.
  234. """
  235. if isinstance(stride, int):
  236. max_stride = stride
  237. elif isinstance(stride, list):
  238. max_stride = max(stride)
  239. # During training phase, the shape of input image is square.
  240. old_img_size = images.shape[-1]
  241. new_img_size = random.randrange(old_img_size * multi_scale_range[0], old_img_size * multi_scale_range[1] + max_stride)
  242. new_img_size = new_img_size // max_stride * max_stride # size
  243. if new_img_size / old_img_size != 1:
  244. # interpolate
  245. images = torch.nn.functional.interpolate(
  246. input=images,
  247. size=new_img_size,
  248. mode='bilinear',
  249. align_corners=False)
  250. # rescale targets
  251. for tgt in targets:
  252. boxes = tgt["boxes"].clone()
  253. labels = tgt["labels"].clone()
  254. boxes = torch.clamp(boxes, 0, old_img_size)
  255. # rescale box
  256. boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
  257. boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
  258. # refine tgt
  259. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  260. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  261. keep = (min_tgt_size >= min_box_size)
  262. tgt["boxes"] = boxes[keep]
  263. tgt["labels"] = labels[keep]
  264. return images, targets, new_img_size
  265. # Trainer refered to RTMDet
  266. class RTMTrainer(object):
  267. def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
  268. # ------------------- basic parameters -------------------
  269. self.args = args
  270. self.epoch = 0
  271. self.best_map = -1.
  272. self.device = device
  273. self.criterion = criterion
  274. self.world_size = world_size
  275. self.heavy_eval = False
  276. self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 5e-2, 'lr0': 0.001}
  277. self.ema_dict = {'ema_decay': 0.9998, 'ema_tau': 2000}
  278. self.lr_schedule_dict = {'scheduler': 'cosine', 'lrf': 0.01}
  279. self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}
  280. # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
  281. self.data_cfg = data_cfg
  282. self.model_cfg = model_cfg
  283. self.trans_cfg = trans_cfg
  284. # ---------------------------- Build Transform ----------------------------
  285. self.train_transform, self.trans_cfg = build_transform(
  286. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
  287. self.val_transform, _ = build_transform(
  288. args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)
  289. # ---------------------------- Build Dataset & Dataloader ----------------------------
  290. self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
  291. self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())
  292. # ---------------------------- Build Evaluator ----------------------------
  293. self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)
  294. # ---------------------------- Build Grad. Scaler ----------------------------
  295. self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)
  296. # ---------------------------- Build Optimizer ----------------------------
  297. self.optimizer_dict['lr0'] *= self.args.batch_size / 64
  298. self.optimizer, self.start_epoch = build_yolo_optimizer(self.optimizer_dict, model, self.args.resume)
  299. # ---------------------------- Build LR Scheduler ----------------------------
  300. self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, self.args.max_epoch)
  301. self.lr_scheduler.last_epoch = self.start_epoch - 1 # do not move
  302. if self.args.resume:
  303. self.lr_scheduler.step()
  304. # ---------------------------- Build Model-EMA ----------------------------
  305. if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
  306. print('Build ModelEMA ...')
  307. self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
  308. else:
  309. self.model_ema = None
  310. def train(self, model):
  311. for epoch in range(self.start_epoch, self.args.max_epoch):
  312. if self.args.distributed:
  313. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  314. # check second stage
  315. if epoch >= (self.args.max_epoch - self.model_cfg['no_aug_epoch'] - 1):
  316. # close mosaic augmentation
  317. if self.train_loader.dataset.mosaic_prob > 0.:
  318. print('close Mosaic Augmentation ...')
  319. self.train_loader.dataset.mosaic_prob = 0.
  320. self.heavy_eval = True
  321. # close mixup augmentation
  322. if self.train_loader.dataset.mixup_prob > 0.:
  323. print('close Mixup Augmentation ...')
  324. self.train_loader.dataset.mixup_prob = 0.
  325. self.heavy_eval = True
  326. # train one epoch
  327. self.epoch = epoch
  328. self.train_one_epoch(model)
  329. # eval one epoch
  330. if self.heavy_eval:
  331. model_eval = model.module if self.args.distributed else model
  332. self.eval(model_eval)
  333. else:
  334. model_eval = model.module if self.args.distributed else model
  335. if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
  336. self.eval(model_eval)
  337. def eval(self, model):
  338. # chech model
  339. model_eval = model if self.model_ema is None else self.model_ema.ema
  340. # path to save model
  341. path_to_save = os.path.join(self.args.save_folder, self.args.dataset, self.args.model)
  342. os.makedirs(path_to_save, exist_ok=True)
  343. if distributed_utils.is_main_process():
  344. # check evaluator
  345. if self.evaluator is None:
  346. print('No evaluator ... save model and go on training.')
  347. print('Saving state, epoch: {}'.format(self.epoch + 1))
  348. weight_name = '{}_no_eval.pth'.format(self.args.model)
  349. checkpoint_path = os.path.join(path_to_save, weight_name)
  350. torch.save({'model': model_eval.state_dict(),
  351. 'mAP': -1.,
  352. 'optimizer': self.optimizer.state_dict(),
  353. 'epoch': self.epoch,
  354. 'args': self.args},
  355. checkpoint_path)
  356. else:
  357. print('eval ...')
  358. # set eval mode
  359. model_eval.trainable = False
  360. model_eval.eval()
  361. # evaluate
  362. with torch.no_grad():
  363. self.evaluator.evaluate(model_eval)
  364. # save model
  365. cur_map = self.evaluator.map
  366. if cur_map > self.best_map:
  367. # update best-map
  368. self.best_map = cur_map
  369. # save model
  370. print('Saving state, epoch:', self.epoch + 1)
  371. weight_name = '{}_best.pth'.format(self.args.model)
  372. checkpoint_path = os.path.join(path_to_save, weight_name)
  373. torch.save({'model': model_eval.state_dict(),
  374. 'mAP': round(self.best_map*100, 1),
  375. 'optimizer': self.optimizer.state_dict(),
  376. 'epoch': self.epoch,
  377. 'args': self.args},
  378. checkpoint_path)
  379. # set train mode.
  380. model_eval.trainable = True
  381. model_eval.train()
  382. if self.args.distributed:
  383. # wait for all processes to synchronize
  384. dist.barrier()
  385. def train_one_epoch(self, model):
  386. # basic parameters
  387. epoch_size = len(self.train_loader)
  388. img_size = self.args.img_size
  389. t0 = time.time()
  390. nw = epoch_size * self.args.wp_epoch
  391. # Train one epoch
  392. for iter_i, (images, targets) in enumerate(self.train_loader):
  393. ni = iter_i + self.epoch * epoch_size
  394. # Warmup
  395. if ni <= nw:
  396. xi = [0, nw] # x interp
  397. for j, x in enumerate(self.optimizer.param_groups):
  398. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  399. x['lr'] = np.interp(
  400. ni, xi, [self.model_cfg['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
  401. if 'momentum' in x:
  402. x['momentum'] = np.interp(ni, xi, [self.model_cfg['warmup_momentum'], self.model_cfg['momentum']])
  403. # To device
  404. images = images.to(self.device, non_blocking=True).float() / 255.
  405. # Multi scale
  406. if self.args.multi_scale:
  407. images, targets, img_size = self.rescale_image_targets(
  408. images, targets, self.model_cfg['stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
  409. else:
  410. targets = self.refine_targets(targets, self.args.min_box_size)
  411. # Visualize train targets
  412. if self.args.vis_tgt:
  413. vis_data(images*255, targets)
  414. # Inference
  415. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  416. outputs = model(images)
  417. # Compute loss
  418. loss_dict = self.criterion(outputs=outputs, targets=targets)
  419. losses = loss_dict['losses']
  420. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  421. # Backward
  422. self.scaler.scale(losses).backward()
  423. # Optimize
  424. if self.model_cfg['clip_grad'] > 0:
  425. # unscale gradients
  426. self.scaler.unscale_(self.optimizer)
  427. # clip gradients
  428. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.model_cfg['clip_grad'])
  429. # optimizer.step
  430. self.scaler.step(self.optimizer)
  431. self.scaler.update()
  432. self.optimizer.zero_grad()
  433. # ema
  434. if self.model_ema is not None:
  435. self.model_ema.update(model)
  436. # Logs
  437. if distributed_utils.is_main_process() and iter_i % 10 == 0:
  438. t1 = time.time()
  439. cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
  440. # basic infor
  441. log = '[Epoch: {}/{}]'.format(self.epoch+1, self.args.max_epoch)
  442. log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
  443. log += '[lr: {:.6f}]'.format(cur_lr[2])
  444. # loss infor
  445. for k in loss_dict_reduced.keys():
  446. log += '[{}: {:.2f}]'.format(k, loss_dict[k])
  447. # other infor
  448. log += '[time: {:.2f}]'.format(t1 - t0)
  449. log += '[size: {}]'.format(img_size)
  450. # print log infor
  451. print(log, flush=True)
  452. t0 = time.time()
  453. # LR Schedule
  454. self.lr_scheduler.step()
  455. def refine_targets(self, targets, min_box_size):
  456. # rescale targets
  457. for tgt in targets:
  458. boxes = tgt["boxes"].clone()
  459. labels = tgt["labels"].clone()
  460. # refine tgt
  461. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  462. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  463. keep = (min_tgt_size >= min_box_size)
  464. tgt["boxes"] = boxes[keep]
  465. tgt["labels"] = labels[keep]
  466. return targets
  467. def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
  468. """
  469. Deployed for Multi scale trick.
  470. """
  471. if isinstance(stride, int):
  472. max_stride = stride
  473. elif isinstance(stride, list):
  474. max_stride = max(stride)
  475. # During training phase, the shape of input image is square.
  476. old_img_size = images.shape[-1]
  477. new_img_size = random.randrange(old_img_size * multi_scale_range[0], old_img_size * multi_scale_range[1] + max_stride)
  478. new_img_size = new_img_size // max_stride * max_stride # size
  479. if new_img_size / old_img_size != 1:
  480. # interpolate
  481. images = torch.nn.functional.interpolate(
  482. input=images,
  483. size=new_img_size,
  484. mode='bilinear',
  485. align_corners=False)
  486. # rescale targets
  487. for tgt in targets:
  488. boxes = tgt["boxes"].clone()
  489. labels = tgt["labels"].clone()
  490. boxes = torch.clamp(boxes, 0, old_img_size)
  491. # rescale box
  492. boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
  493. boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
  494. # refine tgt
  495. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  496. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  497. keep = (min_tgt_size >= min_box_size)
  498. tgt["boxes"] = boxes[keep]
  499. tgt["labels"] = labels[keep]
  500. return images, targets, new_img_size
# Trainer for DETR
class DetrTrainer(object):
    def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
        """Build transforms, dataset/dataloader, evaluator, AMP scaler, optimizer,
        LR scheduler and (optionally) the model EMA for DETR-style training.

        Args:
            args: command-line namespace (batch_size, fp16, resume, ema, ...).
            data_cfg / model_cfg / trans_cfg: project config dicts.
            device: torch device training runs on.
            model: the detector to optimize.
            criterion: loss module; returns a dict with key 'losses'.
            world_size: number of DDP processes (1 for single GPU).
        """
        # ------------------- basic parameters -------------------
        self.args = args
        self.epoch = 0
        self.best_map = -1.
        self.last_opt_step = 0
        self.device = device
        self.criterion = criterion
        self.world_size = world_size
        self.heavy_eval = False
        # Hyperparameters for optimizer / EMA / LR schedule / warmup.
        self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 1e-4, 'lr0': 0.0001}
        self.ema_dict = {'ema_decay': 0.9998, 'ema_tau': 2000}
        self.lr_schedule_dict = {'scheduler': 'linear', 'lrf': 0.1}
        self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}
        # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
        self.data_cfg = data_cfg
        self.model_cfg = model_cfg
        self.trans_cfg = trans_cfg
        # ---------------------------- Build Transform ----------------------------
        # Note: building the train transform may rewrite self.trans_cfg.
        self.train_transform, self.trans_cfg = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
        self.val_transform, _ = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)
        # ---------------------------- Build Dataset & Dataloader ----------------------------
        # Per-process batch size = total batch size / world size.
        self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
        self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())
        # ---------------------------- Build Evaluator ----------------------------
        self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)
        # ---------------------------- Build Grad. Scaler ----------------------------
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)
        # ---------------------------- Build Optimizer ----------------------------
        # Linear LR scaling rule w.r.t. total batch size (nominal 16 for DETR).
        self.optimizer_dict['lr0'] *= self.args.batch_size / 16.
        self.optimizer, self.start_epoch = build_detr_optimizer(self.optimizer_dict, model, self.args.resume)
        # ---------------------------- Build LR Scheduler ----------------------------
        self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, self.args.max_epoch)
        self.lr_scheduler.last_epoch = self.start_epoch - 1  # do not move
        if self.args.resume:
            self.lr_scheduler.step()
        # ---------------------------- Build Model-EMA ----------------------------
        # EMA is kept on the main process only (rank -1 means non-distributed).
        if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
            print('Build ModelEMA ...')
            self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
        else:
            self.model_ema = None
  547. def train(self, model):
  548. for epoch in range(self.start_epoch, self.args.max_epoch):
  549. if self.args.distributed:
  550. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  551. # check second stage
  552. if epoch >= (self.args.max_epoch - self.model_cfg['no_aug_epoch'] - 1):
  553. # close mosaic augmentation
  554. if self.train_loader.dataset.mosaic_prob > 0.:
  555. print('close Mosaic Augmentation ...')
  556. self.train_loader.dataset.mosaic_prob = 0.
  557. self.heavy_eval = True
  558. # close mixup augmentation
  559. if self.train_loader.dataset.mixup_prob > 0.:
  560. print('close Mixup Augmentation ...')
  561. self.train_loader.dataset.mixup_prob = 0.
  562. self.heavy_eval = True
  563. # train one epoch
  564. self.epoch = epoch
  565. self.train_one_epoch(model)
  566. # eval one epoch
  567. if self.heavy_eval:
  568. model_eval = model.module if self.args.distributed else model
  569. self.eval(model_eval)
  570. else:
  571. model_eval = model.module if self.args.distributed else model
  572. if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
  573. self.eval(model_eval)
  574. def eval(self, model):
  575. # chech model
  576. model_eval = model if self.model_ema is None else self.model_ema.ema
  577. # path to save model
  578. path_to_save = os.path.join(self.args.save_folder, self.args.dataset, self.args.model)
  579. os.makedirs(path_to_save, exist_ok=True)
  580. if distributed_utils.is_main_process():
  581. # check evaluator
  582. if self.evaluator is None:
  583. print('No evaluator ... save model and go on training.')
  584. print('Saving state, epoch: {}'.format(self.epoch + 1))
  585. weight_name = '{}_no_eval.pth'.format(self.args.model)
  586. checkpoint_path = os.path.join(path_to_save, weight_name)
  587. torch.save({'model': model_eval.state_dict(),
  588. 'mAP': -1.,
  589. 'optimizer': self.optimizer.state_dict(),
  590. 'epoch': self.epoch,
  591. 'args': self.args},
  592. checkpoint_path)
  593. else:
  594. print('eval ...')
  595. # set eval mode
  596. model_eval.trainable = False
  597. model_eval.eval()
  598. # evaluate
  599. with torch.no_grad():
  600. self.evaluator.evaluate(model_eval)
  601. # save model
  602. cur_map = self.evaluator.map
  603. if cur_map > self.best_map:
  604. # update best-map
  605. self.best_map = cur_map
  606. # save model
  607. print('Saving state, epoch:', self.epoch + 1)
  608. weight_name = '{}_best.pth'.format(self.args.model)
  609. checkpoint_path = os.path.join(path_to_save, weight_name)
  610. torch.save({'model': model_eval.state_dict(),
  611. 'mAP': round(self.best_map*100, 1),
  612. 'optimizer': self.optimizer.state_dict(),
  613. 'epoch': self.epoch,
  614. 'args': self.args},
  615. checkpoint_path)
  616. # set train mode.
  617. model_eval.trainable = True
  618. model_eval.train()
  619. if self.args.distributed:
  620. # wait for all processes to synchronize
  621. dist.barrier()
  622. def train_one_epoch(self, model):
  623. # basic parameters
  624. epoch_size = len(self.train_loader)
  625. img_size = self.args.img_size
  626. t0 = time.time()
  627. nw = epoch_size * self.args.wp_epoch
  628. # train one epoch
  629. for iter_i, (images, targets) in enumerate(self.train_loader):
  630. ni = iter_i + self.epoch * epoch_size
  631. # Warmup
  632. if ni <= nw:
  633. xi = [0, nw] # x interp
  634. for j, x in enumerate(self.optimizer.param_groups):
  635. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  636. x['lr'] = np.interp(
  637. ni, xi, [0.0, x['initial_lr'] * self.lf(self.epoch)])
  638. if 'momentum' in x:
  639. x['momentum'] = np.interp(ni, xi, [self.model_cfg['warmup_momentum'], self.model_cfg['momentum']])
  640. # To device
  641. images = images.to(self.device, non_blocking=True).float() / 255.
  642. # Multi scale
  643. if self.args.multi_scale:
  644. images, targets, img_size = self.rescale_image_targets(
  645. images, targets, self.model_cfg['stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
  646. else:
  647. targets = self.refine_targets(targets, self.args.min_box_size, img_size)
  648. # Visualize targets
  649. if self.args.vis_tgt:
  650. vis_data(images*255, targets)
  651. # Inference
  652. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  653. outputs = model(images)
  654. # Compute loss
  655. loss_dict = self.criterion(outputs=outputs, targets=targets)
  656. losses = loss_dict['losses']
  657. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  658. # Backward
  659. self.scaler.scale(losses).backward()
  660. # Optimize
  661. if self.model_cfg['clip_grad'] > 0:
  662. # unscale gradients
  663. self.scaler.unscale_(self.optimizer)
  664. # clip gradients
  665. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.model_cfg['clip_grad'])
  666. self.scaler.step(self.optimizer)
  667. self.scaler.update()
  668. self.optimizer.zero_grad()
  669. # Model EMA
  670. if self.model_ema is not None:
  671. self.model_ema.update(model)
  672. self.last_opt_step = ni
  673. # Log
  674. if distributed_utils.is_main_process() and iter_i % 10 == 0:
  675. t1 = time.time()
  676. cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
  677. # basic infor
  678. log = '[Epoch: {}/{}]'.format(self.epoch+1, self.args.max_epoch)
  679. log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
  680. log += '[lr: {:.6f}]'.format(cur_lr[0])
  681. # loss infor
  682. for k in loss_dict_reduced.keys():
  683. if self.args.vis_aux_loss:
  684. log += '[{}: {:.2f}]'.format(k, loss_dict[k])
  685. else:
  686. if k in ['loss_cls', 'loss_bbox', 'loss_giou', 'losses']:
  687. log += '[{}: {:.2f}]'.format(k, loss_dict[k])
  688. # other infor
  689. log += '[time: {:.2f}]'.format(t1 - t0)
  690. log += '[size: {}]'.format(img_size)
  691. # print log infor
  692. print(log, flush=True)
  693. t0 = time.time()
  694. # LR Scheduler
  695. self.lr_scheduler.step()
  696. def refine_targets(self, targets, min_box_size, img_size):
  697. # rescale targets
  698. for tgt in targets:
  699. boxes = tgt["boxes"]
  700. labels = tgt["labels"]
  701. # refine tgt
  702. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  703. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  704. keep = (min_tgt_size >= min_box_size)
  705. # xyxy -> cxcywh
  706. new_boxes = torch.zeros_like(boxes)
  707. new_boxes[..., :2] = (boxes[..., 2:] + boxes[..., :2]) * 0.5
  708. new_boxes[..., 2:] = (boxes[..., 2:] - boxes[..., :2])
  709. # normalize
  710. new_boxes /= img_size
  711. del boxes
  712. tgt["boxes"] = new_boxes[keep]
  713. tgt["labels"] = labels[keep]
  714. return targets
  715. def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
  716. """
  717. Deployed for Multi scale trick.
  718. """
  719. if isinstance(stride, int):
  720. max_stride = stride
  721. elif isinstance(stride, list):
  722. max_stride = max(stride)
  723. # During training phase, the shape of input image is square.
  724. old_img_size = images.shape[-1]
  725. new_img_size = random.randrange(old_img_size * multi_scale_range[0], old_img_size * multi_scale_range[1] + max_stride)
  726. new_img_size = new_img_size // max_stride * max_stride # size
  727. if new_img_size / old_img_size != 1:
  728. # interpolate
  729. images = torch.nn.functional.interpolate(
  730. input=images,
  731. size=new_img_size,
  732. mode='bilinear',
  733. align_corners=False)
  734. # rescale targets
  735. for tgt in targets:
  736. boxes = tgt["boxes"].clone()
  737. labels = tgt["labels"].clone()
  738. boxes = torch.clamp(boxes, 0, old_img_size)
  739. # rescale box
  740. boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
  741. boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
  742. # refine tgt
  743. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  744. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  745. keep = (min_tgt_size >= min_box_size)
  746. # xyxy -> cxcywh
  747. new_boxes = torch.zeros_like(boxes)
  748. new_boxes[..., :2] = (boxes[..., 2:] + boxes[..., :2]) * 0.5
  749. new_boxes[..., 2:] = (boxes[..., 2:] - boxes[..., :2])
  750. # normalize
  751. new_boxes /= new_img_size
  752. del boxes
  753. tgt["boxes"] = new_boxes[keep]
  754. tgt["labels"] = labels[keep]
  755. return images, targets, new_img_size
  756. # Build Trainer
  757. def build_trainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
  758. if model_cfg['trainer_type'] == 'yolo':
  759. return YoloTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
  760. elif model_cfg['trainer_type'] == 'rtmdet':
  761. return RTMTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
  762. elif model_cfg['trainer_type'] == 'detr':
  763. return DetrTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
  764. else:
  765. raise NotImplementedError