# engine.py — training engines: YoloTrainer, DetrTrainer, and the build_trainer factory.
  1. import torch
  2. import torch.distributed as dist
  3. import time
  4. import os
  5. import numpy as np
  6. import random
  7. # ----------------- Extra Components -----------------
  8. from utils import distributed_utils
  9. from utils.misc import ModelEMA, CollateFunc, build_dataloader
  10. from utils.vis_tools import vis_data
  11. # ----------------- Evaluator Components -----------------
  12. from evaluator.build import build_evluator
  13. # ----------------- Optimizer & LrScheduler Components -----------------
  14. from utils.solver.optimizer import build_yolo_optimizer, build_detr_optimizer
  15. from utils.solver.lr_scheduler import build_lr_scheduler
  16. # ----------------- Dataset Components -----------------
  17. from dataset.build import build_dataset, build_transform
# Trainer for YOLO
class YoloTrainer(object):
    """Trainer for YOLO-family detectors.

    Owns the full training state: transforms, dataset/dataloader, evaluator,
    AMP gradient scaler, optimizer, LR scheduler and (optionally) a model EMA.
    Call train() to run the whole schedule.
    """
    def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion):
        """Build every training component.

        Args:
            args: parsed command-line namespace (batch_size, fp16, resume, ema,
                wp_epoch, max_epoch, ...).
            data_cfg: dataset configuration dict.
            model_cfg: model/training configuration dict (lr0, weight_decay,
                max_stride, ema_decay, ...).
            trans_cfg: data-transform configuration dict.
            device: device used for training/evaluation.
            model: detector whose parameters are optimized.
            criterion: loss callable invoked as criterion(outputs=..., targets=...).

        NOTE(review): mutates `args.max_epoch` and `model_cfg['weight_decay']`
        in place — callers should not reuse those objects assuming they are
        unchanged.
        """
        # ------------------- basic parameters -------------------
        self.args = args
        self.epoch = 0
        self.best_map = -1.      # best mAP observed so far (-1 = none yet)
        self.last_opt_step = 0   # global iteration of the last optimizer step
        self.device = device
        self.criterion = criterion
        self.heavy_eval = False  # set True once strong augmentations are closed
        # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
        self.data_cfg = data_cfg
        self.model_cfg = model_cfg
        self.trans_cfg = trans_cfg
        # ---------------------------- Build Transform ----------------------------
        self.train_transform, self.trans_cfg = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
        self.val_transform, _ = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)
        # ---------------------------- Build Dataset & Dataloader ----------------------------
        self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
        world_size = distributed_utils.get_world_size()
        # per-process batch size = global batch size / world size
        self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // world_size, CollateFunc())
        # ---------------------------- Build Evaluator ----------------------------
        self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)
        # ---------------------------- Build Grad. Scaler ----------------------------
        # GradScaler is a no-op when fp16 is disabled.
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)
        # ---------------------------- Build Optimizer ----------------------------
        # Gradient accumulation keeps the effective batch size near 64;
        # weight decay is rescaled to match the effective batch.
        accumulate = max(1, round(64 / self.args.batch_size))
        self.model_cfg['weight_decay'] *= self.args.batch_size * accumulate / 64
        self.optimizer, self.start_epoch = build_yolo_optimizer(self.model_cfg, model, self.model_cfg['lr0'], self.args.resume)
        # ---------------------------- Build LR Scheduler ----------------------------
        self.args.max_epoch += self.args.wp_epoch
        self.lr_scheduler, self.lf = build_lr_scheduler(self.model_cfg, self.optimizer, self.args.max_epoch)
        self.lr_scheduler.last_epoch = self.start_epoch - 1  # do not move
        if self.args.resume:
            self.lr_scheduler.step()
        # ---------------------------- Build Model-EMA ----------------------------
        # EMA weights are maintained on the main process only.
        if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
            print('Build ModelEMA ...')
            self.model_ema = ModelEMA(
                model,
                self.model_cfg['ema_decay'],
                self.model_cfg['ema_tau'],
                self.start_epoch * len(self.train_loader))
        else:
            self.model_ema = None
  66. def train(self, model):
  67. for epoch in range(self.start_epoch, self.args.max_epoch):
  68. if self.args.distributed:
  69. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  70. # check second stage
  71. if epoch >= (self.args.max_epoch - self.model_cfg['no_aug_epoch'] - 1):
  72. # close mosaic augmentation
  73. if self.train_loader.dataset.mosaic_prob > 0.:
  74. print('close Mosaic Augmentation ...')
  75. self.train_loader.dataset.mosaic_prob = 0.
  76. self.heavy_eval = True
  77. # close mixup augmentation
  78. if self.train_loader.dataset.mixup_prob > 0.:
  79. print('close Mixup Augmentation ...')
  80. self.train_loader.dataset.mixup_prob = 0.
  81. self.heavy_eval = True
  82. # train one epoch
  83. self.train_one_epoch(model)
  84. # eval one epoch
  85. if self.heavy_eval:
  86. model_eval = model.module if self.args.distributed else model
  87. self.eval(model_eval)
  88. else:
  89. model_eval = model.module if self.args.distributed else model
  90. if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
  91. self.eval(model_eval)
    def eval(self, model):
        """Evaluate the (EMA) model on the main process and checkpoint it.

        With no evaluator configured, saves a `*_no_eval.pth` checkpoint and
        continues. Otherwise runs the evaluator and saves `*_best.pth`
        whenever the mAP improves. In distributed mode all ranks synchronize
        at a barrier before returning.
        """
        # check model: prefer the EMA weights when EMA is enabled
        model_eval = model if self.model_ema is None else self.model_ema.ema
        # path to save model
        path_to_save = os.path.join(self.args.save_folder, self.args.dataset, self.args.model)
        os.makedirs(path_to_save, exist_ok=True)
        if distributed_utils.is_main_process():
            # check evaluator
            if self.evaluator is None:
                print('No evaluator ... save model and go on training.')
                print('Saving state, epoch: {}'.format(self.epoch + 1))
                weight_name = '{}_no_eval.pth'.format(self.args.model)
                checkpoint_path = os.path.join(path_to_save, weight_name)
                torch.save({'model': model_eval.state_dict(),
                            'mAP': -1.,
                            'optimizer': self.optimizer.state_dict(),
                            'epoch': self.epoch,
                            'args': self.args},
                           checkpoint_path)
            else:
                print('eval ...')
                # set eval mode
                model_eval.trainable = False
                model_eval.eval()
                # evaluate
                with torch.no_grad():
                    self.evaluator.evaluate(model_eval)
                # save model only when the mAP improves
                cur_map = self.evaluator.map
                if cur_map > self.best_map:
                    # update best-map
                    self.best_map = cur_map
                    # save model
                    print('Saving state, epoch:', self.epoch + 1)
                    weight_name = '{}_best.pth'.format(self.args.model)
                    checkpoint_path = os.path.join(path_to_save, weight_name)
                    torch.save({'model': model_eval.state_dict(),
                                'mAP': round(self.best_map*100, 1),
                                'optimizer': self.optimizer.state_dict(),
                                'epoch': self.epoch,
                                'args': self.args},
                               checkpoint_path)
                # restore train mode before resuming training.
                model_eval.trainable = True
                model_eval.train()
        if self.args.distributed:
            # wait for all processes to synchronize
            dist.barrier()
  140. def train_one_epoch(self, model):
  141. # basic parameters
  142. epoch_size = len(self.train_loader)
  143. img_size = self.args.img_size
  144. t0 = time.time()
  145. nw = epoch_size * self.args.wp_epoch
  146. accumulate = accumulate = max(1, round(64 / self.args.batch_size))
  147. # Train one epoch
  148. for iter_i, (images, targets) in enumerate(self.train_loader):
  149. ni = iter_i + self.epoch * epoch_size
  150. # Warmup
  151. if ni <= nw:
  152. xi = [0, nw] # x interp
  153. accumulate = max(1, np.interp(ni, xi, [1, 64 / self.args.batch_size]).round())
  154. for j, x in enumerate(self.optimizer.param_groups):
  155. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  156. x['lr'] = np.interp(
  157. ni, xi, [self.model_cfg['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
  158. if 'momentum' in x:
  159. x['momentum'] = np.interp(ni, xi, [self.model_cfg['warmup_momentum'], self.model_cfg['momentum']])
  160. # To device
  161. images = images.to(self.device, non_blocking=True).float() / 255.
  162. # Multi scale
  163. if self.args.multi_scale:
  164. images, targets, img_size = self.rescale_image_targets(
  165. images, targets, model.stride, self.args.min_box_size, self.model_cfg['multi_scale'])
  166. else:
  167. targets = self.refine_targets(targets, self.args.min_box_size)
  168. # Visualize train targets
  169. if self.args.vis_tgt:
  170. vis_data(images*255, targets)
  171. # Inference
  172. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  173. outputs = model(images)
  174. # Compute loss
  175. loss_dict = self.criterion(outputs=outputs, targets=targets)
  176. losses = loss_dict['losses']
  177. losses *= images.shape[0] # loss * bs
  178. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  179. if self.args.distributed:
  180. # gradient averaged between devices in DDP mode
  181. losses *= distributed_utils.get_world_size()
  182. # Backward
  183. self.scaler.scale(losses).backward()
  184. # Optimize
  185. if ni - self.last_opt_step >= accumulate:
  186. if self.model_cfg['clip_grad'] > 0:
  187. # unscale gradients
  188. self.scaler.unscale_(self.optimizer)
  189. # clip gradients
  190. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.model_cfg['clip_grad'])
  191. # optimizer.step
  192. self.scaler.step(self.optimizer)
  193. self.scaler.update()
  194. self.optimizer.zero_grad()
  195. # ema
  196. if self.model_ema is not None:
  197. self.model_ema.update(model)
  198. self.last_opt_step = ni
  199. # Logs
  200. if distributed_utils.is_main_process() and iter_i % 10 == 0:
  201. t1 = time.time()
  202. cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
  203. # basic infor
  204. log = '[Epoch: {}/{}]'.format(self.epoch+1, self.args.max_epoch)
  205. log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
  206. log += '[lr: {:.6f}]'.format(cur_lr[2])
  207. # loss infor
  208. for k in loss_dict_reduced.keys():
  209. if k == 'losses' and self.args.distributed:
  210. world_size = distributed_utils.get_world_size()
  211. log += '[{}: {:.2f}]'.format(k, loss_dict[k] / world_size)
  212. else:
  213. log += '[{}: {:.2f}]'.format(k, loss_dict[k])
  214. # other infor
  215. log += '[time: {:.2f}]'.format(t1 - t0)
  216. log += '[size: {}]'.format(img_size)
  217. # print log infor
  218. print(log, flush=True)
  219. t0 = time.time()
  220. # LR Schedule
  221. self.lr_scheduler.step()
  222. self.epoch += 1
  223. def refine_targets(self, targets, min_box_size):
  224. # rescale targets
  225. for tgt in targets:
  226. boxes = tgt["boxes"].clone()
  227. labels = tgt["labels"].clone()
  228. # refine tgt
  229. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  230. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  231. keep = (min_tgt_size >= min_box_size)
  232. tgt["boxes"] = boxes[keep]
  233. tgt["labels"] = labels[keep]
  234. return targets
  235. def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
  236. """
  237. Deployed for Multi scale trick.
  238. """
  239. if isinstance(stride, int):
  240. max_stride = stride
  241. elif isinstance(stride, list):
  242. max_stride = max(stride)
  243. # During training phase, the shape of input image is square.
  244. old_img_size = images.shape[-1]
  245. new_img_size = random.randrange(old_img_size * multi_scale_range[0], old_img_size * multi_scale_range[1] + max_stride)
  246. new_img_size = new_img_size // max_stride * max_stride # size
  247. if new_img_size / old_img_size != 1:
  248. # interpolate
  249. images = torch.nn.functional.interpolate(
  250. input=images,
  251. size=new_img_size,
  252. mode='bilinear',
  253. align_corners=False)
  254. # rescale targets
  255. for tgt in targets:
  256. boxes = tgt["boxes"].clone()
  257. labels = tgt["labels"].clone()
  258. boxes = torch.clamp(boxes, 0, old_img_size)
  259. # rescale box
  260. boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
  261. boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
  262. # refine tgt
  263. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  264. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  265. keep = (min_tgt_size >= min_box_size)
  266. tgt["boxes"] = boxes[keep]
  267. tgt["labels"] = labels[keep]
  268. return images, targets, new_img_size
# Trainer for DETR
class DetrTrainer(object):
    """Trainer for DETR-style detectors.

    Mirrors YoloTrainer but uses the DETR optimizer recipe (base lr scaled
    with the global batch size) and normalized cxcywh target boxes.
    """
    def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion):
        """Build every training component.

        Args:
            args: parsed command-line namespace (batch_size, fp16, resume, ema,
                wp_epoch, max_epoch, ...).
            data_cfg: dataset configuration dict.
            model_cfg: model/training configuration dict (lr0, max_stride,
                ema_decay, ...).
            trans_cfg: data-transform configuration dict.
            device: device used for training/evaluation.
            model: detector whose parameters are optimized.
            criterion: loss callable invoked as criterion(outputs=..., targets=...).

        NOTE(review): mutates `args.max_epoch` and `model_cfg['lr0']` in
        place — callers should not reuse those objects assuming they are
        unchanged.
        """
        # ------------------- basic parameters -------------------
        self.args = args
        self.epoch = 0
        self.best_map = -1.      # best mAP observed so far (-1 = none yet)
        self.last_opt_step = 0   # global iteration of the last optimizer step
        self.device = device
        self.criterion = criterion
        self.heavy_eval = False  # set True once strong augmentations are closed
        # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
        self.data_cfg = data_cfg
        self.model_cfg = model_cfg
        self.trans_cfg = trans_cfg
        # ---------------------------- Build Transform ----------------------------
        self.train_transform, self.trans_cfg = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
        self.val_transform, _ = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)
        # ---------------------------- Build Dataset & Dataloader ----------------------------
        self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
        world_size = distributed_utils.get_world_size()
        # per-process batch size = global batch size / world size
        self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // world_size, CollateFunc())
        # ---------------------------- Build Evaluator ----------------------------
        self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)
        # ---------------------------- Build Grad. Scaler ----------------------------
        # GradScaler is a no-op when fp16 is disabled.
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)
        # ---------------------------- Build Optimizer ----------------------------
        # DETR recipe: scale the base lr linearly with the global batch size
        # (reference batch size is 16).
        self.model_cfg['lr0'] *= self.args.batch_size / 16.
        self.optimizer, self.start_epoch = build_detr_optimizer(model_cfg, model, self.args.resume)
        # ---------------------------- Build LR Scheduler ----------------------------
        self.args.max_epoch += self.args.wp_epoch
        self.lr_scheduler, self.lf = build_lr_scheduler(self.model_cfg, self.optimizer, self.args.max_epoch)
        self.lr_scheduler.last_epoch = self.start_epoch - 1  # do not move
        if self.args.resume:
            self.lr_scheduler.step()
        # ---------------------------- Build Model-EMA ----------------------------
        # EMA weights are maintained on the main process only.
        if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
            print('Build ModelEMA ...')
            self.model_ema = ModelEMA(
                model,
                self.model_cfg['ema_decay'],
                self.model_cfg['ema_tau'],
                self.start_epoch * len(self.train_loader))
        else:
            self.model_ema = None
  316. def train(self, model):
  317. for epoch in range(self.start_epoch, self.args.max_epoch):
  318. if self.args.distributed:
  319. self.train_loader.batch_sampler.sampler.set_epoch(epoch)
  320. # check second stage
  321. if epoch >= (self.args.max_epoch - self.model_cfg['no_aug_epoch'] - 1):
  322. # close mosaic augmentation
  323. if self.train_loader.dataset.mosaic_prob > 0.:
  324. print('close Mosaic Augmentation ...')
  325. self.train_loader.dataset.mosaic_prob = 0.
  326. self.heavy_eval = True
  327. # close mixup augmentation
  328. if self.train_loader.dataset.mixup_prob > 0.:
  329. print('close Mixup Augmentation ...')
  330. self.train_loader.dataset.mixup_prob = 0.
  331. self.heavy_eval = True
  332. # train one epoch
  333. self.train_one_epoch(model)
  334. # eval one epoch
  335. if self.heavy_eval:
  336. model_eval = model.module if self.args.distributed else model
  337. self.eval(model_eval)
  338. else:
  339. model_eval = model.module if self.args.distributed else model
  340. if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
  341. self.eval(model_eval)
    def eval(self, model):
        """Evaluate the (EMA) model on the main process and checkpoint it.

        With no evaluator configured, saves a `*_no_eval.pth` checkpoint and
        continues. Otherwise runs the evaluator and saves `*_best.pth`
        whenever the mAP improves. In distributed mode all ranks synchronize
        at a barrier before returning.
        """
        # check model: prefer the EMA weights when EMA is enabled
        model_eval = model if self.model_ema is None else self.model_ema.ema
        # path to save model
        path_to_save = os.path.join(self.args.save_folder, self.args.dataset, self.args.model)
        os.makedirs(path_to_save, exist_ok=True)
        if distributed_utils.is_main_process():
            # check evaluator
            if self.evaluator is None:
                print('No evaluator ... save model and go on training.')
                print('Saving state, epoch: {}'.format(self.epoch + 1))
                weight_name = '{}_no_eval.pth'.format(self.args.model)
                checkpoint_path = os.path.join(path_to_save, weight_name)
                torch.save({'model': model_eval.state_dict(),
                            'mAP': -1.,
                            'optimizer': self.optimizer.state_dict(),
                            'epoch': self.epoch,
                            'args': self.args},
                           checkpoint_path)
            else:
                print('eval ...')
                # set eval mode
                model_eval.trainable = False
                model_eval.eval()
                # evaluate
                with torch.no_grad():
                    self.evaluator.evaluate(model_eval)
                # save model only when the mAP improves
                cur_map = self.evaluator.map
                if cur_map > self.best_map:
                    # update best-map
                    self.best_map = cur_map
                    # save model
                    print('Saving state, epoch:', self.epoch + 1)
                    weight_name = '{}_best.pth'.format(self.args.model)
                    checkpoint_path = os.path.join(path_to_save, weight_name)
                    torch.save({'model': model_eval.state_dict(),
                                'mAP': round(self.best_map*100, 1),
                                'optimizer': self.optimizer.state_dict(),
                                'epoch': self.epoch,
                                'args': self.args},
                               checkpoint_path)
                # restore train mode before resuming training.
                model_eval.trainable = True
                model_eval.train()
        if self.args.distributed:
            # wait for all processes to synchronize
            dist.barrier()
  390. def train_one_epoch(self, model):
  391. # basic parameters
  392. epoch_size = len(self.train_loader)
  393. img_size = self.args.img_size
  394. t0 = time.time()
  395. nw = epoch_size * self.args.wp_epoch
  396. # train one epoch
  397. for iter_i, (images, targets) in enumerate(self.train_loader):
  398. ni = iter_i + self.epoch * epoch_size
  399. # Warmup
  400. if ni <= nw:
  401. xi = [0, nw] # x interp
  402. for j, x in enumerate(self.optimizer.param_groups):
  403. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  404. x['lr'] = np.interp(
  405. ni, xi, [0.0, x['initial_lr'] * self.lf(self.epoch)])
  406. if 'momentum' in x:
  407. x['momentum'] = np.interp(ni, xi, [self.model_cfg['warmup_momentum'], self.model_cfg['momentum']])
  408. # To device
  409. images = images.to(self.device, non_blocking=True).float() / 255.
  410. # Multi scale
  411. if self.args.multi_scale:
  412. images, targets, img_size = self.rescale_image_targets(
  413. images, targets, model.stride, self.args.min_box_size, self.model_cfg['multi_scale'])
  414. else:
  415. targets = self.refine_targets(targets, self.args.min_box_size, img_size)
  416. # Visualize targets
  417. if self.args.vis_tgt:
  418. vis_data(images*255, targets)
  419. # Inference
  420. with torch.cuda.amp.autocast(enabled=self.args.fp16):
  421. outputs = model(images)
  422. # Compute loss
  423. loss_dict = self.criterion(outputs=outputs, targets=targets)
  424. losses = loss_dict['losses']
  425. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  426. # Backward
  427. self.scaler.scale(losses).backward()
  428. # Optimize
  429. if self.model_cfg['clip_grad'] > 0:
  430. # unscale gradients
  431. self.scaler.unscale_(self.optimizer)
  432. # clip gradients
  433. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.model_cfg['clip_grad'])
  434. self.scaler.step(self.optimizer)
  435. self.scaler.update()
  436. self.optimizer.zero_grad()
  437. # Model EMA
  438. if self.model_ema is not None:
  439. self.model_ema.update(model)
  440. self.last_opt_step = ni
  441. # Log
  442. if distributed_utils.is_main_process() and iter_i % 10 == 0:
  443. t1 = time.time()
  444. cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
  445. # basic infor
  446. log = '[Epoch: {}/{}]'.format(self.epoch+1, self.args.max_epoch)
  447. log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
  448. log += '[lr: {:.6f}]'.format(cur_lr[0])
  449. # loss infor
  450. for k in loss_dict_reduced.keys():
  451. if self.args.vis_aux_loss:
  452. log += '[{}: {:.2f}]'.format(k, loss_dict[k])
  453. else:
  454. if k in ['loss_cls', 'loss_bbox', 'loss_giou', 'losses']:
  455. log += '[{}: {:.2f}]'.format(k, loss_dict[k])
  456. # other infor
  457. log += '[time: {:.2f}]'.format(t1 - t0)
  458. log += '[size: {}]'.format(img_size)
  459. # print log infor
  460. print(log, flush=True)
  461. t0 = time.time()
  462. # LR Scheduler
  463. self.lr_scheduler.step()
  464. self.epoch += 1
  465. def refine_targets(self, targets, min_box_size, img_size):
  466. # rescale targets
  467. for tgt in targets:
  468. boxes = tgt["boxes"]
  469. labels = tgt["labels"]
  470. # refine tgt
  471. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  472. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  473. keep = (min_tgt_size >= min_box_size)
  474. # xyxy -> cxcywh
  475. new_boxes = torch.zeros_like(boxes)
  476. new_boxes[..., :2] = (boxes[..., 2:] + boxes[..., :2]) * 0.5
  477. new_boxes[..., 2:] = (boxes[..., 2:] - boxes[..., :2])
  478. # normalize
  479. new_boxes /= img_size
  480. del boxes
  481. tgt["boxes"] = new_boxes[keep]
  482. tgt["labels"] = labels[keep]
  483. return targets
  484. def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
  485. """
  486. Deployed for Multi scale trick.
  487. """
  488. if isinstance(stride, int):
  489. max_stride = stride
  490. elif isinstance(stride, list):
  491. max_stride = max(stride)
  492. # During training phase, the shape of input image is square.
  493. old_img_size = images.shape[-1]
  494. new_img_size = random.randrange(old_img_size * multi_scale_range[0], old_img_size * multi_scale_range[1] + max_stride)
  495. new_img_size = new_img_size // max_stride * max_stride # size
  496. if new_img_size / old_img_size != 1:
  497. # interpolate
  498. images = torch.nn.functional.interpolate(
  499. input=images,
  500. size=new_img_size,
  501. mode='bilinear',
  502. align_corners=False)
  503. # rescale targets
  504. for tgt in targets:
  505. boxes = tgt["boxes"].clone()
  506. labels = tgt["labels"].clone()
  507. boxes = torch.clamp(boxes, 0, old_img_size)
  508. # rescale box
  509. boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
  510. boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
  511. # refine tgt
  512. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  513. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  514. keep = (min_tgt_size >= min_box_size)
  515. # xyxy -> cxcywh
  516. new_boxes = torch.zeros_like(boxes)
  517. new_boxes[..., :2] = (boxes[..., 2:] + boxes[..., :2]) * 0.5
  518. new_boxes[..., 2:] = (boxes[..., 2:] - boxes[..., :2])
  519. # normalize
  520. new_boxes /= new_img_size
  521. del boxes
  522. tgt["boxes"] = new_boxes[keep]
  523. tgt["labels"] = labels[keep]
  524. return images, targets, new_img_size
  525. # Build Trainer
  526. def build_trainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion):
  527. if model_cfg['trainer_type'] == 'yolo':
  528. return YoloTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion)
  529. elif model_cfg['trainer_type'] == 'detr':
  530. return DetrTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion)
  531. else:
  532. raise NotImplementedError