import torch
import torch.distributed as dist
import time
import os
import numpy as np
import random

# ----------------- Extra Components -----------------
from utils import distributed_utils
from utils.misc import ModelEMA, CollateFunc, build_dataloader
from utils.vis_tools import vis_data

# ----------------- Evaluator Components -----------------
from evaluator.build import build_evluator

# ----------------- Optimizer & LrScheduler Components -----------------
from utils.solver.optimizer import build_yolo_optimizer, build_detr_optimizer
from utils.solver.lr_scheduler import build_lr_scheduler

# ----------------- Dataset Components -----------------
from dataset.build import build_dataset, build_transform


# Trainer modeled after YOLOv8
class YoloTrainer(object):
    def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion):
        # ------------------- basic parameters -------------------
        self.args = args
        self.epoch = 0
        self.best_map = -1.
        self.last_opt_step = 0
        self.device = device
        self.criterion = criterion
        self.heavy_eval = False
        # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
        self.data_cfg = data_cfg
        self.model_cfg = model_cfg
        self.trans_cfg = trans_cfg
        # ---------------------------- Build Transform ----------------------------
        self.train_transform, self.trans_cfg = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
        self.val_transform, _ = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)
        # ---------------------------- Build Dataset & Dataloader ----------------------------
        self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
        world_size = distributed_utils.get_world_size()
        self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // world_size, CollateFunc())
        # ---------------------------- Build Evaluator ----------------------------
        self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)
        # ---------------------------- Build Grad. Scaler ----------------------------
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)
        # ---------------------------- Build Optimizer ----------------------------
        accumulate = max(1, round(64 / self.args.batch_size))
        self.model_cfg['weight_decay'] *= self.args.batch_size * accumulate / 64
        self.optimizer, self.start_epoch = build_yolo_optimizer(self.model_cfg, model, self.args.resume)
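        # Gradient accumulation math: 'accumulate' targets an effective batch
        # of ~64 images (batch_size * accumulate ≈ 64; e.g. batch_size=16 ->
        # accumulate=4), and weight decay is rescaled by the same factor so
        # regularization matches the 64-image reference recipe.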
        # ---------------------------- Build LR Scheduler ----------------------------
        self.args.max_epoch += self.args.wp_epoch
        self.lr_scheduler, self.lf = build_lr_scheduler(self.model_cfg, self.optimizer, self.args.max_epoch)
        self.lr_scheduler.last_epoch = self.start_epoch - 1  # do not move
        if self.args.resume:
            self.lr_scheduler.step()
        # ---------------------------- Build Model-EMA ----------------------------
        if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
            print('Build ModelEMA ...')
            self.model_ema = ModelEMA(
                model,
                self.model_cfg['ema_decay'],
                self.model_cfg['ema_tau'],
                self.start_epoch * len(self.train_loader))
        else:
            self.model_ema = None

    def train(self, model):
        for epoch in range(self.start_epoch, self.args.max_epoch):
            if self.args.distributed:
                self.train_loader.batch_sampler.sampler.set_epoch(epoch)

            # check second stage
            if epoch >= (self.args.max_epoch - self.model_cfg['no_aug_epoch'] - 1):
                # close mosaic augmentation
                if self.train_loader.dataset.mosaic_prob > 0.:
                    print('close Mosaic Augmentation ...')
                    self.train_loader.dataset.mosaic_prob = 0.
                    self.heavy_eval = True
                # close mixup augmentation
                if self.train_loader.dataset.mixup_prob > 0.:
                    print('close Mixup Augmentation ...')
                    self.train_loader.dataset.mixup_prob = 0.
                    self.heavy_eval = True

            # train one epoch
            self.epoch = epoch
            self.train_one_epoch(model)

            # eval one epoch
            model_eval = model.module if self.args.distributed else model
            if self.heavy_eval:
                self.eval(model_eval)
            elif (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
                self.eval(model_eval)

    def eval(self, model):
        # check model
        model_eval = model if self.model_ema is None else self.model_ema.ema
        # path to save model
        path_to_save = os.path.join(self.args.save_folder, self.args.dataset, self.args.model)
        os.makedirs(path_to_save, exist_ok=True)
        if distributed_utils.is_main_process():
            # check evaluator
            if self.evaluator is None:
                print('No evaluator ... save model and go on training.')
                print('Saving state, epoch: {}'.format(self.epoch + 1))
                weight_name = '{}_no_eval.pth'.format(self.args.model)
                checkpoint_path = os.path.join(path_to_save, weight_name)
                torch.save({'model': model_eval.state_dict(),
                            'mAP': -1.,
                            'optimizer': self.optimizer.state_dict(),
                            'epoch': self.epoch,
                            'args': self.args},
                           checkpoint_path)
            else:
                print('eval ...')
                # set eval mode
                model_eval.trainable = False
                model_eval.eval()
                # evaluate
                with torch.no_grad():
                    self.evaluator.evaluate(model_eval)
                # save model
                cur_map = self.evaluator.map
                if cur_map > self.best_map:
                    # update best-map
                    self.best_map = cur_map
                    # save model
                    print('Saving state, epoch:', self.epoch + 1)
                    weight_name = '{}_best.pth'.format(self.args.model)
                    checkpoint_path = os.path.join(path_to_save, weight_name)
                    torch.save({'model': model_eval.state_dict(),
                                'mAP': round(self.best_map * 100, 1),
                                'optimizer': self.optimizer.state_dict(),
                                'epoch': self.epoch,
                                'args': self.args},
                               checkpoint_path)
                # set train mode.
                model_eval.trainable = True
                model_eval.train()
        if self.args.distributed:
            # wait for all processes to synchronize
            dist.barrier()

    def train_one_epoch(self, model):
        # basic parameters
        epoch_size = len(self.train_loader)
        img_size = self.args.img_size
        t0 = time.time()
        nw = epoch_size * self.args.wp_epoch
        accumulate = max(1, round(64 / self.args.batch_size))

        # Train one epoch
        for iter_i, (images, targets) in enumerate(self.train_loader):
            ni = iter_i + self.epoch * epoch_size
            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                accumulate = max(1, int(np.interp(ni, xi, [1, 64 / self.args.batch_size]).round()))
                for j, x in enumerate(self.optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x['lr'] = np.interp(
                        ni, xi, [self.model_cfg['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xi, [self.model_cfg['warmup_momentum'], self.model_cfg['momentum']])
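                # Over the first nw iterations, np.interp moves each group's lr
                # linearly to initial_lr * lf(epoch): the bias group (j == 0)
                # decays from warmup_bias_lr (typically 0.1) while the other
                # groups ramp up from 0.0; momentum warms up the same way.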

            # To device
            images = images.to(self.device, non_blocking=True).float() / 255.

            # Multi scale
            if self.args.multi_scale:
                images, targets, img_size = self.rescale_image_targets(
                    images, targets, self.model_cfg['stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
            else:
                targets = self.refine_targets(targets, self.args.min_box_size)

            # Visualize train targets
            if self.args.vis_tgt:
                vis_data(images * 255, targets)

            # Inference
            with torch.cuda.amp.autocast(enabled=self.args.fp16):
                outputs = model(images)
                # Compute loss
                loss_dict = self.criterion(outputs=outputs, targets=targets)
                losses = loss_dict['losses']
                losses *= images.shape[0]  # loss * bs
            loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
            if self.args.distributed:
                # gradient averaged between devices in DDP mode
                losses *= distributed_utils.get_world_size()

            # Backward
            self.scaler.scale(losses).backward()

            # Optimize
            if ni - self.last_opt_step >= accumulate:
                if self.model_cfg['clip_grad'] > 0:
                    # unscale gradients
                    self.scaler.unscale_(self.optimizer)
                    # clip gradients
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.model_cfg['clip_grad'])
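                    # Note: unscale_ must run before clipping so the norm is
                    # computed on the true fp32 gradients; scaler.step() below
                    # then skips its own internal unscale for this optimizer.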
                # optimizer.step
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()
                # ema
                if self.model_ema is not None:
                    self.model_ema.update(model)
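                    # EMA sketch: the shadow weights follow
                    #   ema_p = d * ema_p + (1 - d) * p
                    # with d ramped toward ema_decay via ema_tau (an assumption
                    # about utils.misc.ModelEMA, following the YOLOv8 recipe).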
                self.last_opt_step = ni

            # Logs
            if distributed_utils.is_main_process() and iter_i % 10 == 0:
                t1 = time.time()
                cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
                # basic info
                log = '[Epoch: {}/{}]'.format(self.epoch + 1, self.args.max_epoch)
                log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
                log += '[lr: {:.6f}]'.format(cur_lr[2])
                # loss info (loss_dict_reduced is already averaged across
                # processes by reduce_dict, so no extra world_size division)
                for k in loss_dict_reduced.keys():
                    log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
                # other info
                log += '[time: {:.2f}]'.format(t1 - t0)
                log += '[size: {}]'.format(img_size)
                # print log info
                print(log, flush=True)
                t0 = time.time()

        # LR Schedule
        self.lr_scheduler.step()

    def refine_targets(self, targets, min_box_size):
        # filter out tiny target boxes (boxes are xyxy in absolute pixels)
        for tgt in targets:
            boxes = tgt["boxes"].clone()
            labels = tgt["labels"].clone()
            # keep boxes whose shorter side >= min_box_size
            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
            keep = (min_tgt_size >= min_box_size)
            tgt["boxes"] = boxes[keep]
            tgt["labels"] = labels[keep]
        return targets

    def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=(0.5, 1.5)):
        """
        Deployed for the multi-scale training trick.
        """
        if isinstance(stride, int):
            max_stride = stride
        else:  # list of per-level strides
            max_stride = max(stride)
        # During the training phase, the input images are square.
        old_img_size = images.shape[-1]
        new_img_size = random.randrange(int(old_img_size * multi_scale_range[0]),
                                        int(old_img_size * multi_scale_range[1] + max_stride))
        new_img_size = new_img_size // max_stride * max_stride  # snap to a multiple of max_stride
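        # e.g. old_img_size=640, range (0.5, 1.5), max_stride=32 ->
        # new_img_size is drawn from {320, 352, ..., 960}.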
        if new_img_size / old_img_size != 1:
            # interpolate
            images = torch.nn.functional.interpolate(
                input=images,
                size=new_img_size,
                mode='bilinear',
                align_corners=False)
        # rescale targets
        for tgt in targets:
            boxes = tgt["boxes"].clone()
            labels = tgt["labels"].clone()
            boxes = torch.clamp(boxes, 0, old_img_size)
            # rescale box
            boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
            boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
            # refine tgt
            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
            keep = (min_tgt_size >= min_box_size)
            tgt["boxes"] = boxes[keep]
            tgt["labels"] = labels[keep]
        return images, targets, new_img_size


# Trainer modeled after RTMDet
class RTMTrainer(object):
    def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion):
        # ------------------- basic parameters -------------------
        self.args = args
        self.epoch = 0
        self.best_map = -1.
        self.device = device
        self.criterion = criterion
        self.heavy_eval = False
        # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
        self.data_cfg = data_cfg
        self.model_cfg = model_cfg
        self.trans_cfg = trans_cfg
        # ---------------------------- Build Transform ----------------------------
        self.train_transform, self.trans_cfg = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
        self.val_transform, _ = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)
        # ---------------------------- Build Dataset & Dataloader ----------------------------
        self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
        world_size = distributed_utils.get_world_size()
        self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // world_size, CollateFunc())
        # ---------------------------- Build Evaluator ----------------------------
        self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)
        # ---------------------------- Build Grad. Scaler ----------------------------
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)
        # ---------------------------- Build Optimizer ----------------------------
        self.model_cfg['lr0'] *= self.args.batch_size / 64
        self.optimizer, self.start_epoch = build_yolo_optimizer(self.model_cfg, model, self.args.resume)
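        # Linear lr scaling rule: the base lr0 assumes a total batch of 64,
        # so it is scaled by batch_size / 64 here; unlike YoloTrainer, this
        # trainer uses no gradient accumulation (see train_one_epoch).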
        # ---------------------------- Build LR Scheduler ----------------------------
        self.args.max_epoch += self.args.wp_epoch
        self.lr_scheduler, self.lf = build_lr_scheduler(self.model_cfg, self.optimizer, self.args.max_epoch)
        self.lr_scheduler.last_epoch = self.start_epoch - 1  # do not move
        if self.args.resume:
            self.lr_scheduler.step()
        # ---------------------------- Build Model-EMA ----------------------------
        if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
            print('Build ModelEMA ...')
            self.model_ema = ModelEMA(
                model,
                self.model_cfg['ema_decay'],
                self.model_cfg['ema_tau'],
                self.start_epoch * len(self.train_loader))
        else:
            self.model_ema = None

    def train(self, model):
        for epoch in range(self.start_epoch, self.args.max_epoch):
            if self.args.distributed:
                self.train_loader.batch_sampler.sampler.set_epoch(epoch)

            # check second stage
            if epoch >= (self.args.max_epoch - self.model_cfg['no_aug_epoch'] - 1):
                # close mosaic augmentation
                if self.train_loader.dataset.mosaic_prob > 0.:
                    print('close Mosaic Augmentation ...')
                    self.train_loader.dataset.mosaic_prob = 0.
                    self.heavy_eval = True
                # close mixup augmentation
                if self.train_loader.dataset.mixup_prob > 0.:
                    print('close Mixup Augmentation ...')
                    self.train_loader.dataset.mixup_prob = 0.
                    self.heavy_eval = True

            # train one epoch
            self.epoch = epoch
            self.train_one_epoch(model)

            # eval one epoch
            model_eval = model.module if self.args.distributed else model
            if self.heavy_eval:
                self.eval(model_eval)
            elif (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
                self.eval(model_eval)

    def eval(self, model):
        # check model
        model_eval = model if self.model_ema is None else self.model_ema.ema
        # path to save model
        path_to_save = os.path.join(self.args.save_folder, self.args.dataset, self.args.model)
        os.makedirs(path_to_save, exist_ok=True)
        if distributed_utils.is_main_process():
            # check evaluator
            if self.evaluator is None:
                print('No evaluator ... save model and go on training.')
                print('Saving state, epoch: {}'.format(self.epoch + 1))
                weight_name = '{}_no_eval.pth'.format(self.args.model)
                checkpoint_path = os.path.join(path_to_save, weight_name)
                torch.save({'model': model_eval.state_dict(),
                            'mAP': -1.,
                            'optimizer': self.optimizer.state_dict(),
                            'epoch': self.epoch,
                            'args': self.args},
                           checkpoint_path)
            else:
                print('eval ...')
                # set eval mode
                model_eval.trainable = False
                model_eval.eval()
                # evaluate
                with torch.no_grad():
                    self.evaluator.evaluate(model_eval)
                # save model
                cur_map = self.evaluator.map
                if cur_map > self.best_map:
                    # update best-map
                    self.best_map = cur_map
                    # save model
                    print('Saving state, epoch:', self.epoch + 1)
                    weight_name = '{}_best.pth'.format(self.args.model)
                    checkpoint_path = os.path.join(path_to_save, weight_name)
                    torch.save({'model': model_eval.state_dict(),
                                'mAP': round(self.best_map * 100, 1),
                                'optimizer': self.optimizer.state_dict(),
                                'epoch': self.epoch,
                                'args': self.args},
                               checkpoint_path)
                # set train mode.
                model_eval.trainable = True
                model_eval.train()
        if self.args.distributed:
            # wait for all processes to synchronize
            dist.barrier()

    def train_one_epoch(self, model):
        # basic parameters
        epoch_size = len(self.train_loader)
        img_size = self.args.img_size
        t0 = time.time()
        nw = epoch_size * self.args.wp_epoch

        # Train one epoch
        for iter_i, (images, targets) in enumerate(self.train_loader):
            ni = iter_i + self.epoch * epoch_size
            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                for j, x in enumerate(self.optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x['lr'] = np.interp(
                        ni, xi, [self.model_cfg['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xi, [self.model_cfg['warmup_momentum'], self.model_cfg['momentum']])

            # To device
            images = images.to(self.device, non_blocking=True).float() / 255.

            # Multi scale
            if self.args.multi_scale:
                images, targets, img_size = self.rescale_image_targets(
                    images, targets, self.model_cfg['stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
            else:
                targets = self.refine_targets(targets, self.args.min_box_size)

            # Visualize train targets
            if self.args.vis_tgt:
                vis_data(images * 255, targets)

            # Inference
            with torch.cuda.amp.autocast(enabled=self.args.fp16):
                outputs = model(images)
                # Compute loss
                loss_dict = self.criterion(outputs=outputs, targets=targets)
                losses = loss_dict['losses']
            loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)

            # Backward
            self.scaler.scale(losses).backward()

            # Optimize
            if self.model_cfg['clip_grad'] > 0:
                # unscale gradients
                self.scaler.unscale_(self.optimizer)
                # clip gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.model_cfg['clip_grad'])
            # optimizer.step
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad()
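            # Unlike YoloTrainer, the optimizer steps every iteration (no
            # gradient accumulation); the large-batch behavior comes from the
            # linear lr scaling applied in __init__ instead.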
            # ema
            if self.model_ema is not None:
                self.model_ema.update(model)

            # Logs
            if distributed_utils.is_main_process() and iter_i % 10 == 0:
                t1 = time.time()
                cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
                # basic info
                log = '[Epoch: {}/{}]'.format(self.epoch + 1, self.args.max_epoch)
                log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
                log += '[lr: {:.6f}]'.format(cur_lr[2])
                # loss info
                for k in loss_dict_reduced.keys():
                    log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
                # other info
                log += '[time: {:.2f}]'.format(t1 - t0)
                log += '[size: {}]'.format(img_size)
                # print log info
                print(log, flush=True)
                t0 = time.time()

        # LR Schedule
        self.lr_scheduler.step()

    def refine_targets(self, targets, min_box_size):
        # filter out tiny target boxes (boxes are xyxy in absolute pixels)
        for tgt in targets:
            boxes = tgt["boxes"].clone()
            labels = tgt["labels"].clone()
            # keep boxes whose shorter side >= min_box_size
            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
            keep = (min_tgt_size >= min_box_size)
            tgt["boxes"] = boxes[keep]
            tgt["labels"] = labels[keep]
        return targets

    def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=(0.5, 1.5)):
        """
        Deployed for the multi-scale training trick.
        """
        if isinstance(stride, int):
            max_stride = stride
        else:  # list of per-level strides
            max_stride = max(stride)
        # During the training phase, the input images are square.
        old_img_size = images.shape[-1]
        new_img_size = random.randrange(int(old_img_size * multi_scale_range[0]),
                                        int(old_img_size * multi_scale_range[1] + max_stride))
        new_img_size = new_img_size // max_stride * max_stride  # snap to a multiple of max_stride
        if new_img_size / old_img_size != 1:
            # interpolate
            images = torch.nn.functional.interpolate(
                input=images,
                size=new_img_size,
                mode='bilinear',
                align_corners=False)
        # rescale targets
        for tgt in targets:
            boxes = tgt["boxes"].clone()
            labels = tgt["labels"].clone()
            boxes = torch.clamp(boxes, 0, old_img_size)
            # rescale box
            boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
            boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
            # refine tgt
            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
            keep = (min_tgt_size >= min_box_size)
            tgt["boxes"] = boxes[keep]
            tgt["labels"] = labels[keep]
        return images, targets, new_img_size


# Trainer for DETR
class DetrTrainer(object):
    def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion):
        # ------------------- basic parameters -------------------
        self.args = args
        self.epoch = 0
        self.best_map = -1.
        self.last_opt_step = 0
        self.device = device
        self.criterion = criterion
        self.heavy_eval = False
        # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
        self.data_cfg = data_cfg
        self.model_cfg = model_cfg
        self.trans_cfg = trans_cfg
        # ---------------------------- Build Transform ----------------------------
        self.train_transform, self.trans_cfg = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
        self.val_transform, _ = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)
        # ---------------------------- Build Dataset & Dataloader ----------------------------
        self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
        world_size = distributed_utils.get_world_size()
        self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // world_size, CollateFunc())
        # ---------------------------- Build Evaluator ----------------------------
        self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)
        # ---------------------------- Build Grad. Scaler ----------------------------
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)
        # ---------------------------- Build Optimizer ----------------------------
        self.model_cfg['lr0'] *= self.args.batch_size / 16.
        self.optimizer, self.start_epoch = build_detr_optimizer(self.model_cfg, model, self.args.resume)
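        # DETR-style recipes typically use a reference batch of 16, hence the
        # batch_size / 16 lr scaling here (vs. / 64 in the YOLO trainers).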
        # ---------------------------- Build LR Scheduler ----------------------------
        self.args.max_epoch += self.args.wp_epoch
        self.lr_scheduler, self.lf = build_lr_scheduler(self.model_cfg, self.optimizer, self.args.max_epoch)
        self.lr_scheduler.last_epoch = self.start_epoch - 1  # do not move
        if self.args.resume:
            self.lr_scheduler.step()
        # ---------------------------- Build Model-EMA ----------------------------
        if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
            print('Build ModelEMA ...')
            self.model_ema = ModelEMA(
                model,
                self.model_cfg['ema_decay'],
                self.model_cfg['ema_tau'],
                self.start_epoch * len(self.train_loader))
        else:
            self.model_ema = None

    def train(self, model):
        for epoch in range(self.start_epoch, self.args.max_epoch):
            if self.args.distributed:
                self.train_loader.batch_sampler.sampler.set_epoch(epoch)

            # check second stage
            if epoch >= (self.args.max_epoch - self.model_cfg['no_aug_epoch'] - 1):
                # close mosaic augmentation
                if self.train_loader.dataset.mosaic_prob > 0.:
                    print('close Mosaic Augmentation ...')
                    self.train_loader.dataset.mosaic_prob = 0.
                    self.heavy_eval = True
                # close mixup augmentation
                if self.train_loader.dataset.mixup_prob > 0.:
                    print('close Mixup Augmentation ...')
                    self.train_loader.dataset.mixup_prob = 0.
                    self.heavy_eval = True

            # train one epoch
            self.epoch = epoch
            self.train_one_epoch(model)

            # eval one epoch
            model_eval = model.module if self.args.distributed else model
            if self.heavy_eval:
                self.eval(model_eval)
            elif (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
                self.eval(model_eval)

    def eval(self, model):
        # check model
        model_eval = model if self.model_ema is None else self.model_ema.ema
        # path to save model
        path_to_save = os.path.join(self.args.save_folder, self.args.dataset, self.args.model)
        os.makedirs(path_to_save, exist_ok=True)
        if distributed_utils.is_main_process():
            # check evaluator
            if self.evaluator is None:
                print('No evaluator ... save model and go on training.')
                print('Saving state, epoch: {}'.format(self.epoch + 1))
                weight_name = '{}_no_eval.pth'.format(self.args.model)
                checkpoint_path = os.path.join(path_to_save, weight_name)
                torch.save({'model': model_eval.state_dict(),
                            'mAP': -1.,
                            'optimizer': self.optimizer.state_dict(),
                            'epoch': self.epoch,
                            'args': self.args},
                           checkpoint_path)
            else:
                print('eval ...')
                # set eval mode
                model_eval.trainable = False
                model_eval.eval()
                # evaluate
                with torch.no_grad():
                    self.evaluator.evaluate(model_eval)
                # save model
                cur_map = self.evaluator.map
                if cur_map > self.best_map:
                    # update best-map
                    self.best_map = cur_map
                    # save model
                    print('Saving state, epoch:', self.epoch + 1)
                    weight_name = '{}_best.pth'.format(self.args.model)
                    checkpoint_path = os.path.join(path_to_save, weight_name)
                    torch.save({'model': model_eval.state_dict(),
                                'mAP': round(self.best_map * 100, 1),
                                'optimizer': self.optimizer.state_dict(),
                                'epoch': self.epoch,
                                'args': self.args},
                               checkpoint_path)
                # set train mode.
                model_eval.trainable = True
                model_eval.train()
        if self.args.distributed:
            # wait for all processes to synchronize
            dist.barrier()

    def train_one_epoch(self, model):
        # basic parameters
        epoch_size = len(self.train_loader)
        img_size = self.args.img_size
        t0 = time.time()
        nw = epoch_size * self.args.wp_epoch

        # train one epoch
        for iter_i, (images, targets) in enumerate(self.train_loader):
            ni = iter_i + self.epoch * epoch_size
            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                for j, x in enumerate(self.optimizer.param_groups):
                    # all lrs rise from 0.0 to lr0 (no special bias warmup for DETR)
                    x['lr'] = np.interp(
                        ni, xi, [0.0, x['initial_lr'] * self.lf(self.epoch)])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xi, [self.model_cfg['warmup_momentum'], self.model_cfg['momentum']])
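                # Unlike the YOLO trainers, every param group (including biases)
                # warms up linearly from 0; AdamW-style DETR optimizers usually
                # have no 'momentum' entry, so that branch is typically a no-op
                # (an assumption about build_detr_optimizer's param groups).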

            # To device
            images = images.to(self.device, non_blocking=True).float() / 255.

            # Multi scale
            if self.args.multi_scale:
                images, targets, img_size = self.rescale_image_targets(
                    images, targets, self.model_cfg['stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
            else:
                targets = self.refine_targets(targets, self.args.min_box_size, img_size)

            # Visualize targets
            if self.args.vis_tgt:
                vis_data(images * 255, targets)

            # Inference
            with torch.cuda.amp.autocast(enabled=self.args.fp16):
                outputs = model(images)
                # Compute loss
                loss_dict = self.criterion(outputs=outputs, targets=targets)
                losses = loss_dict['losses']
            loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)

            # Backward
            self.scaler.scale(losses).backward()

            # Optimize
            if self.model_cfg['clip_grad'] > 0:
                # unscale gradients
                self.scaler.unscale_(self.optimizer)
                # clip gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.model_cfg['clip_grad'])
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad()

            # Model EMA
            if self.model_ema is not None:
                self.model_ema.update(model)
            self.last_opt_step = ni

            # Log
            if distributed_utils.is_main_process() and iter_i % 10 == 0:
                t1 = time.time()
                cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
                # basic info
                log = '[Epoch: {}/{}]'.format(self.epoch + 1, self.args.max_epoch)
                log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
                log += '[lr: {:.6f}]'.format(cur_lr[0])
                # loss info
                for k in loss_dict_reduced.keys():
                    if self.args.vis_aux_loss:
                        log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
                    elif k in ['loss_cls', 'loss_bbox', 'loss_giou', 'losses']:
                        log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
                # other info
                log += '[time: {:.2f}]'.format(t1 - t0)
                log += '[size: {}]'.format(img_size)
                # print log info
                print(log, flush=True)
                t0 = time.time()

        # LR Scheduler
        self.lr_scheduler.step()

    def refine_targets(self, targets, min_box_size, img_size):
        # filter tiny boxes and convert targets to the DETR format
        for tgt in targets:
            boxes = tgt["boxes"]
            labels = tgt["labels"]
            # keep boxes whose shorter side >= min_box_size
            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
            keep = (min_tgt_size >= min_box_size)
            # xyxy -> cxcywh
            new_boxes = torch.zeros_like(boxes)
            new_boxes[..., :2] = (boxes[..., 2:] + boxes[..., :2]) * 0.5
            new_boxes[..., 2:] = (boxes[..., 2:] - boxes[..., :2])
            # normalize to [0, 1] by image size
            new_boxes /= img_size
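            # e.g. an xyxy box (64, 128, 192, 256) in a 640x640 image becomes
            # cxcywh (128, 192, 128, 128) and then (0.2, 0.3, 0.2, 0.2) after
            # normalization, the format DETR-style criteria expect.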
            tgt["boxes"] = new_boxes[keep]
            tgt["labels"] = labels[keep]
        return targets

    def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=(0.5, 1.5)):
        """
        Deployed for the multi-scale training trick.
        """
        if isinstance(stride, int):
            max_stride = stride
        else:  # list of per-level strides
            max_stride = max(stride)
        # During the training phase, the input images are square.
        old_img_size = images.shape[-1]
        new_img_size = random.randrange(int(old_img_size * multi_scale_range[0]),
                                        int(old_img_size * multi_scale_range[1] + max_stride))
        new_img_size = new_img_size // max_stride * max_stride  # snap to a multiple of max_stride
        if new_img_size / old_img_size != 1:
            # interpolate
            images = torch.nn.functional.interpolate(
                input=images,
                size=new_img_size,
                mode='bilinear',
                align_corners=False)
        # rescale targets
        for tgt in targets:
            boxes = tgt["boxes"].clone()
            labels = tgt["labels"].clone()
            boxes = torch.clamp(boxes, 0, old_img_size)
            # rescale box
            boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
            boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
            # refine tgt
            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
            keep = (min_tgt_size >= min_box_size)
            # xyxy -> cxcywh
            new_boxes = torch.zeros_like(boxes)
            new_boxes[..., :2] = (boxes[..., 2:] + boxes[..., :2]) * 0.5
            new_boxes[..., 2:] = (boxes[..., 2:] - boxes[..., :2])
            # normalize
            new_boxes /= new_img_size
            tgt["boxes"] = new_boxes[keep]
            tgt["labels"] = labels[keep]
        return images, targets, new_img_size


# Build Trainer
def build_trainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion):
    if model_cfg['trainer_type'] == 'yolo':
        return YoloTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion)
    elif model_cfg['trainer_type'] == 'rtmdet':
        return RTMTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion)
    elif model_cfg['trainer_type'] == 'detr':
        return DetrTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion)
    else:
        raise NotImplementedError('Unknown trainer type: {}'.format(model_cfg['trainer_type']))
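

# Typical usage (a sketch; assumes an argparse-style `args` providing
# batch_size, fp16, resume, ema, distributed, etc., with the configs,
# model, and criterion built elsewhere in this repo):
#
#   trainer = build_trainer(args, data_cfg, model_cfg, trans_cfg,
#                           device, model, criterion)
#   trainer.train(model)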