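"""
engine.py: training engines for this detection repo.

Defines three trainers that share the same skeleton (build transforms,
dataset/dataloader, evaluator, AMP grad scaler, optimizer, LR scheduler, and
optional ModelEMA, then loop train_one_epoch + eval):

- YoloTrainer: YOLOv8-style training with gradient accumulation toward an
  effective batch size of 64.
- RTMTrainer:  RTMDet-style training that steps the optimizer every iteration
  and linearly scales lr0 by batch_size / 64.
- DetrTrainer: DETR-style training with normalized cxcywh targets and lr0
  scaled by batch_size / 16.

build_trainer() dispatches on model_cfg['trainer_type'].
"""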
import os
import random
import time

import numpy as np
import torch
import torch.distributed as dist

# ----------------- Extra Components -----------------
from utils import distributed_utils
from utils.misc import ModelEMA, CollateFunc, build_dataloader
from utils.vis_tools import vis_data

# ----------------- Evaluator Components -----------------
from evaluator.build import build_evluator

# ----------------- Optimizer & LrScheduler Components -----------------
from utils.solver.optimizer import build_yolo_optimizer, build_detr_optimizer
from utils.solver.lr_scheduler import build_lr_scheduler

# ----------------- Dataset Components -----------------
from dataset.build import build_dataset, build_transform


# Trainer modeled on YOLOv8
class YoloTrainer(object):
    def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion):
        # ------------------- Basic parameters -------------------
        self.args = args
        self.epoch = 0
        self.best_map = -1.
        self.last_opt_step = 0
        self.device = device
        self.criterion = criterion
        self.heavy_eval = False

        # ---------------------------- Dataset & Model & Transform configs ----------------------------
        self.data_cfg = data_cfg
        self.model_cfg = model_cfg
        self.trans_cfg = trans_cfg

        # ---------------------------- Build Transform ----------------------------
        self.train_transform, self.trans_cfg = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
        self.val_transform, _ = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)

        # ---------------------------- Build Dataset & Dataloader ----------------------------
        self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
        world_size = distributed_utils.get_world_size()
        self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // world_size, CollateFunc())

        # ---------------------------- Build Evaluator ----------------------------
        self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)

        # ---------------------------- Build Grad. Scaler ----------------------------
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)

        # ---------------------------- Build Optimizer ----------------------------
        accumulate = max(1, round(64 / self.args.batch_size))
        self.model_cfg['weight_decay'] *= self.args.batch_size * accumulate / 64
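        # A worked example of the scaling above (values illustrative, not from
        # any config in this repo): with batch_size=16, accumulate =
        # round(64 / 16) = 4, so gradients are accumulated over 4 iterations
        # for an effective batch of 16 * 4 = 64; weight_decay is then
        # multiplied by 16 * 4 / 64 = 1.0, i.e. it only changes when the
        # effective batch differs from 64.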
        self.optimizer, self.start_epoch = build_yolo_optimizer(self.model_cfg, model, self.args.resume)

        # ---------------------------- Build LR Scheduler ----------------------------
        self.args.max_epoch += self.args.wp_epoch
        self.lr_scheduler, self.lf = build_lr_scheduler(self.model_cfg, self.optimizer, self.args.max_epoch)
        self.lr_scheduler.last_epoch = self.start_epoch - 1  # do not move
        if self.args.resume:
            self.lr_scheduler.step()

        # ---------------------------- Build Model-EMA ----------------------------
        if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
            print('Build ModelEMA ...')
            self.model_ema = ModelEMA(
                model,
                self.model_cfg['ema_decay'],
                self.model_cfg['ema_tau'],
                self.start_epoch * len(self.train_loader))
        else:
            self.model_ema = None

    def train(self, model):
        for epoch in range(self.start_epoch, self.args.max_epoch):
            if self.args.distributed:
                self.train_loader.batch_sampler.sampler.set_epoch(epoch)

            # Check second stage
            if epoch >= (self.args.max_epoch - self.model_cfg['no_aug_epoch'] - 1):
                # Close mosaic augmentation
                if self.train_loader.dataset.mosaic_prob > 0.:
                    print('close Mosaic Augmentation ...')
                    self.train_loader.dataset.mosaic_prob = 0.
                    self.heavy_eval = True
                # Close mixup augmentation
                if self.train_loader.dataset.mixup_prob > 0.:
                    print('close Mixup Augmentation ...')
                    self.train_loader.dataset.mixup_prob = 0.
                    self.heavy_eval = True

            # Train one epoch
            self.train_one_epoch(model)

            # Eval one epoch
            model_eval = model.module if self.args.distributed else model
            if self.heavy_eval:
                self.eval(model_eval)
            elif (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
                self.eval(model_eval)

    def eval(self, model):
        # Check model
        model_eval = model if self.model_ema is None else self.model_ema.ema

        # Path to save model
        path_to_save = os.path.join(self.args.save_folder, self.args.dataset, self.args.model)
        os.makedirs(path_to_save, exist_ok=True)

        if distributed_utils.is_main_process():
            # Check evaluator
            if self.evaluator is None:
                print('No evaluator ... save model and go on training.')
                print('Saving state, epoch: {}'.format(self.epoch + 1))
                weight_name = '{}_no_eval.pth'.format(self.args.model)
                checkpoint_path = os.path.join(path_to_save, weight_name)
                torch.save({'model': model_eval.state_dict(),
                            'mAP': -1.,
                            'optimizer': self.optimizer.state_dict(),
                            'epoch': self.epoch,
                            'args': self.args},
                           checkpoint_path)
            else:
                print('eval ...')
                # Set eval mode
                model_eval.trainable = False
                model_eval.eval()

                # Evaluate
                with torch.no_grad():
                    self.evaluator.evaluate(model_eval)

                # Save model if the current mAP is the best so far
                cur_map = self.evaluator.map
                if cur_map > self.best_map:
                    # Update best mAP
                    self.best_map = cur_map
                    print('Saving state, epoch: {}'.format(self.epoch + 1))
                    weight_name = '{}_best.pth'.format(self.args.model)
                    checkpoint_path = os.path.join(path_to_save, weight_name)
                    torch.save({'model': model_eval.state_dict(),
                                'mAP': round(self.best_map * 100, 1),
                                'optimizer': self.optimizer.state_dict(),
                                'epoch': self.epoch,
                                'args': self.args},
                               checkpoint_path)

                # Set train mode
                model_eval.trainable = True
                model_eval.train()

        if self.args.distributed:
            # Wait for all processes to synchronize
            dist.barrier()

    def train_one_epoch(self, model):
        # Basic parameters
        epoch_size = len(self.train_loader)
        img_size = self.args.img_size
        t0 = time.time()
        nw = epoch_size * self.args.wp_epoch
        accumulate = max(1, round(64 / self.args.batch_size))

        # Train one epoch
        for iter_i, (images, targets) in enumerate(self.train_loader):
            ni = iter_i + self.epoch * epoch_size

            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                accumulate = max(1, np.interp(ni, xi, [1, 64 / self.args.batch_size]).round())
                for j, x in enumerate(self.optimizer.param_groups):
                    # Bias lr falls from warmup_bias_lr to lr0; all other lrs rise from 0.0 to lr0
                    x['lr'] = np.interp(
                        ni, xi, [self.model_cfg['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xi, [self.model_cfg['warmup_momentum'], self.model_cfg['momentum']])
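            # Concrete endpoints for the interpolation above (illustrative
            # values, assuming warmup_bias_lr=0.1): at ni=0 the bias group
            # starts at 0.1 and every other group at 0.0; both move linearly
            # until ni=nw, where all groups meet the scheduled value
            # x['initial_lr'] * self.lf(self.epoch).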

            # To device
            images = images.to(self.device, non_blocking=True).float() / 255.

            # Multi scale
            if self.args.multi_scale:
                images, targets, img_size = self.rescale_image_targets(
                    images, targets, model.stride, self.args.min_box_size, self.model_cfg['multi_scale'])
            else:
                targets = self.refine_targets(targets, self.args.min_box_size)

            # Visualize train targets
            if self.args.vis_tgt:
                vis_data(images * 255, targets)

            # Inference
            with torch.cuda.amp.autocast(enabled=self.args.fp16):
                outputs = model(images)
                # Compute loss
                loss_dict = self.criterion(outputs=outputs, targets=targets)
                losses = loss_dict['losses']
                losses *= images.shape[0]  # loss * bs
                loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
                if self.args.distributed:
                    # Gradients are averaged between devices in DDP mode
                    losses *= distributed_utils.get_world_size()
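            # Why the two multiplications above: the criterion is assumed to
            # return a batch-mean loss, so scaling by images.shape[0] turns it
            # back into a per-batch sum, and since DistributedDataParallel
            # averages gradients across processes, scaling by world_size
            # restores the summed-gradient magnitude of single-GPU training.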

            # Backward
            self.scaler.scale(losses).backward()

            # Optimize
            if ni - self.last_opt_step >= accumulate:
                if self.model_cfg['clip_grad'] > 0:
                    # Unscale gradients
                    self.scaler.unscale_(self.optimizer)
                    # Clip gradients
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.model_cfg['clip_grad'])
                # optimizer.step
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()
                # EMA update
                if self.model_ema is not None:
                    self.model_ema.update(model)
                self.last_opt_step = ni
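            # Note on the AMP ordering above: gradients live in scaled form
            # after backward(), so scaler.unscale_() must run before
            # clip_grad_norm_ clips them against a real-valued threshold;
            # scaler.step() then skips the update if any gradient overflowed,
            # and scaler.update() adjusts the loss scale for the next step.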

            # Logs
            if distributed_utils.is_main_process() and iter_i % 10 == 0:
                t1 = time.time()
                cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
                # Basic info
                log = '[Epoch: {}/{}]'.format(self.epoch + 1, self.args.max_epoch)
                log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
                log += '[lr: {:.6f}]'.format(cur_lr[2])
                # Loss info
                for k in loss_dict_reduced.keys():
                    if k == 'losses' and self.args.distributed:
                        world_size = distributed_utils.get_world_size()
                        log += '[{}: {:.2f}]'.format(k, loss_dict[k] / world_size)
                    else:
                        log += '[{}: {:.2f}]'.format(k, loss_dict[k])
                # Other info
                log += '[time: {:.2f}]'.format(t1 - t0)
                log += '[size: {}]'.format(img_size)
                # Print log info
                print(log, flush=True)
                t0 = time.time()

        # LR schedule
        self.lr_scheduler.step()
        self.epoch += 1

    def refine_targets(self, targets, min_box_size):
        # Filter out tiny target boxes
        for tgt in targets:
            boxes = tgt["boxes"].clone()
            labels = tgt["labels"].clone()
            # Refine tgt
            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
            keep = (min_tgt_size >= min_box_size)
            tgt["boxes"] = boxes[keep]
            tgt["labels"] = labels[keep]

        return targets
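
    # Toy illustration of refine_targets (hypothetical numbers): with
    # min_box_size=8, an xyxy box (10, 10, 14, 30) has wh = (4, 20), so its
    # smaller side is 4 and the box is dropped; (10, 10, 20, 30) has
    # wh = (10, 20) and is kept along with its label.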

    def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
        """
        Rescale images and targets for the multi-scale training trick.
        """
        if isinstance(stride, int):
            max_stride = stride
        elif isinstance(stride, list):
            max_stride = max(stride)

        # During the training phase, the shape of the input image is square.
        old_img_size = images.shape[-1]
        new_img_size = random.randrange(int(old_img_size * multi_scale_range[0]), int(old_img_size * multi_scale_range[1] + max_stride))
        new_img_size = new_img_size // max_stride * max_stride  # size
        if new_img_size / old_img_size != 1:
            # Interpolate
            images = torch.nn.functional.interpolate(
                input=images,
                size=new_img_size,
                mode='bilinear',
                align_corners=False)

        # Rescale targets
        for tgt in targets:
            boxes = tgt["boxes"].clone()
            labels = tgt["labels"].clone()
            boxes = torch.clamp(boxes, 0, old_img_size)
            # Rescale box
            boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
            boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
            # Refine tgt
            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
            keep = (min_tgt_size >= min_box_size)
            tgt["boxes"] = boxes[keep]
            tgt["labels"] = labels[keep]

        return images, targets, new_img_size
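
    # Worked size example (illustrative numbers): with old_img_size=640,
    # multi_scale_range=[0.5, 1.5] and max_stride=32, randrange draws from
    # [320, 992); flooring to a multiple of 32 then yields one of
    # 320, 352, ..., 960, so every sampled size stays stride-aligned.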


# Trainer modeled on RTMDet
class RTMTrainer(object):
    def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion):
        # ------------------- Basic parameters -------------------
        self.args = args
        self.epoch = 0
        self.best_map = -1.
        self.device = device
        self.criterion = criterion
        self.heavy_eval = False

        # ---------------------------- Dataset & Model & Transform configs ----------------------------
        self.data_cfg = data_cfg
        self.model_cfg = model_cfg
        self.trans_cfg = trans_cfg

        # ---------------------------- Build Transform ----------------------------
        self.train_transform, self.trans_cfg = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
        self.val_transform, _ = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)

        # ---------------------------- Build Dataset & Dataloader ----------------------------
        self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
        world_size = distributed_utils.get_world_size()
        self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // world_size, CollateFunc())

        # ---------------------------- Build Evaluator ----------------------------
        self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)

        # ---------------------------- Build Grad. Scaler ----------------------------
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)

        # ---------------------------- Build Optimizer ----------------------------
        self.model_cfg['lr0'] *= self.args.batch_size / 64
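        # Linear lr scaling (illustrative numbers): with a base lr0 of 0.004
        # tuned for batch 64, training at batch_size=32 rescales it to
        # 0.004 * 32 / 64 = 0.002, keeping the per-sample step size roughly
        # constant across batch sizes.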
        self.optimizer, self.start_epoch = build_yolo_optimizer(self.model_cfg, model, self.args.resume)

        # ---------------------------- Build LR Scheduler ----------------------------
        self.args.max_epoch += self.args.wp_epoch
        self.lr_scheduler, self.lf = build_lr_scheduler(self.model_cfg, self.optimizer, self.args.max_epoch)
        self.lr_scheduler.last_epoch = self.start_epoch - 1  # do not move
        if self.args.resume:
            self.lr_scheduler.step()

        # ---------------------------- Build Model-EMA ----------------------------
        if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
            print('Build ModelEMA ...')
            self.model_ema = ModelEMA(
                model,
                self.model_cfg['ema_decay'],
                self.model_cfg['ema_tau'],
                self.start_epoch * len(self.train_loader))
        else:
            self.model_ema = None

    def train(self, model):
        for epoch in range(self.start_epoch, self.args.max_epoch):
            if self.args.distributed:
                self.train_loader.batch_sampler.sampler.set_epoch(epoch)

            # Check second stage
            if epoch >= (self.args.max_epoch - self.model_cfg['no_aug_epoch'] - 1):
                # Close mosaic augmentation
                if self.train_loader.dataset.mosaic_prob > 0.:
                    print('close Mosaic Augmentation ...')
                    self.train_loader.dataset.mosaic_prob = 0.
                    self.heavy_eval = True
                # Close mixup augmentation
                if self.train_loader.dataset.mixup_prob > 0.:
                    print('close Mixup Augmentation ...')
                    self.train_loader.dataset.mixup_prob = 0.
                    self.heavy_eval = True

            # Train one epoch
            self.train_one_epoch(model)

            # Eval one epoch
            model_eval = model.module if self.args.distributed else model
            if self.heavy_eval:
                self.eval(model_eval)
            elif (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
                self.eval(model_eval)

    def eval(self, model):
        # Check model
        model_eval = model if self.model_ema is None else self.model_ema.ema

        # Path to save model
        path_to_save = os.path.join(self.args.save_folder, self.args.dataset, self.args.model)
        os.makedirs(path_to_save, exist_ok=True)

        if distributed_utils.is_main_process():
            # Check evaluator
            if self.evaluator is None:
                print('No evaluator ... save model and go on training.')
                print('Saving state, epoch: {}'.format(self.epoch + 1))
                weight_name = '{}_no_eval.pth'.format(self.args.model)
                checkpoint_path = os.path.join(path_to_save, weight_name)
                torch.save({'model': model_eval.state_dict(),
                            'mAP': -1.,
                            'optimizer': self.optimizer.state_dict(),
                            'epoch': self.epoch,
                            'args': self.args},
                           checkpoint_path)
            else:
                print('eval ...')
                # Set eval mode
                model_eval.trainable = False
                model_eval.eval()

                # Evaluate
                with torch.no_grad():
                    self.evaluator.evaluate(model_eval)

                # Save model if the current mAP is the best so far
                cur_map = self.evaluator.map
                if cur_map > self.best_map:
                    # Update best mAP
                    self.best_map = cur_map
                    print('Saving state, epoch: {}'.format(self.epoch + 1))
                    weight_name = '{}_best.pth'.format(self.args.model)
                    checkpoint_path = os.path.join(path_to_save, weight_name)
                    torch.save({'model': model_eval.state_dict(),
                                'mAP': round(self.best_map * 100, 1),
                                'optimizer': self.optimizer.state_dict(),
                                'epoch': self.epoch,
                                'args': self.args},
                               checkpoint_path)

                # Set train mode
                model_eval.trainable = True
                model_eval.train()

        if self.args.distributed:
            # Wait for all processes to synchronize
            dist.barrier()

    def train_one_epoch(self, model):
        # Basic parameters
        epoch_size = len(self.train_loader)
        img_size = self.args.img_size
        t0 = time.time()
        nw = epoch_size * self.args.wp_epoch

        # Train one epoch
        for iter_i, (images, targets) in enumerate(self.train_loader):
            ni = iter_i + self.epoch * epoch_size

            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                for j, x in enumerate(self.optimizer.param_groups):
                    # Bias lr falls from warmup_bias_lr to lr0; all other lrs rise from 0.0 to lr0
                    x['lr'] = np.interp(
                        ni, xi, [self.model_cfg['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xi, [self.model_cfg['warmup_momentum'], self.model_cfg['momentum']])

            # To device
            images = images.to(self.device, non_blocking=True).float() / 255.

            # Multi scale
            if self.args.multi_scale:
                images, targets, img_size = self.rescale_image_targets(
                    images, targets, model.stride, self.args.min_box_size, self.model_cfg['multi_scale'])
            else:
                targets = self.refine_targets(targets, self.args.min_box_size)

            # Visualize train targets
            if self.args.vis_tgt:
                vis_data(images * 255, targets)

            # Inference
            with torch.cuda.amp.autocast(enabled=self.args.fp16):
                outputs = model(images)
                # Compute loss
                loss_dict = self.criterion(outputs=outputs, targets=targets)
                losses = loss_dict['losses']
                loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)

            # Backward
            self.scaler.scale(losses).backward()

            # Optimize (every iteration; no gradient accumulation here)
            if self.model_cfg['clip_grad'] > 0:
                # Unscale gradients
                self.scaler.unscale_(self.optimizer)
                # Clip gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.model_cfg['clip_grad'])
            # optimizer.step
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad()

            # EMA update
            if self.model_ema is not None:
                self.model_ema.update(model)

            # Logs
            if distributed_utils.is_main_process() and iter_i % 10 == 0:
                t1 = time.time()
                cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
                # Basic info
                log = '[Epoch: {}/{}]'.format(self.epoch + 1, self.args.max_epoch)
                log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
                log += '[lr: {:.6f}]'.format(cur_lr[2])
                # Loss info
                for k in loss_dict_reduced.keys():
                    if k == 'losses' and self.args.distributed:
                        world_size = distributed_utils.get_world_size()
                        log += '[{}: {:.2f}]'.format(k, loss_dict[k] / world_size)
                    else:
                        log += '[{}: {:.2f}]'.format(k, loss_dict[k])
                # Other info
                log += '[time: {:.2f}]'.format(t1 - t0)
                log += '[size: {}]'.format(img_size)
                # Print log info
                print(log, flush=True)
                t0 = time.time()

        # LR schedule
        self.lr_scheduler.step()
        self.epoch += 1

    def refine_targets(self, targets, min_box_size):
        # Filter out tiny target boxes
        for tgt in targets:
            boxes = tgt["boxes"].clone()
            labels = tgt["labels"].clone()
            # Refine tgt
            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
            keep = (min_tgt_size >= min_box_size)
            tgt["boxes"] = boxes[keep]
            tgt["labels"] = labels[keep]

        return targets

    def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
        """
        Rescale images and targets for the multi-scale training trick.
        """
        if isinstance(stride, int):
            max_stride = stride
        elif isinstance(stride, list):
            max_stride = max(stride)

        # During the training phase, the shape of the input image is square.
        old_img_size = images.shape[-1]
        new_img_size = random.randrange(int(old_img_size * multi_scale_range[0]), int(old_img_size * multi_scale_range[1] + max_stride))
        new_img_size = new_img_size // max_stride * max_stride  # size
        if new_img_size / old_img_size != 1:
            # Interpolate
            images = torch.nn.functional.interpolate(
                input=images,
                size=new_img_size,
                mode='bilinear',
                align_corners=False)

        # Rescale targets
        for tgt in targets:
            boxes = tgt["boxes"].clone()
            labels = tgt["labels"].clone()
            boxes = torch.clamp(boxes, 0, old_img_size)
            # Rescale box
            boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
            boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
            # Refine tgt
            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
            keep = (min_tgt_size >= min_box_size)
            tgt["boxes"] = boxes[keep]
            tgt["labels"] = labels[keep]

        return images, targets, new_img_size


# Trainer for DETR
class DetrTrainer(object):
    def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion):
        # ------------------- Basic parameters -------------------
        self.args = args
        self.epoch = 0
        self.best_map = -1.
        self.last_opt_step = 0
        self.device = device
        self.criterion = criterion
        self.heavy_eval = False

        # ---------------------------- Dataset & Model & Transform configs ----------------------------
        self.data_cfg = data_cfg
        self.model_cfg = model_cfg
        self.trans_cfg = trans_cfg

        # ---------------------------- Build Transform ----------------------------
        self.train_transform, self.trans_cfg = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
        self.val_transform, _ = build_transform(
            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)

        # ---------------------------- Build Dataset & Dataloader ----------------------------
        self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
        world_size = distributed_utils.get_world_size()
        self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // world_size, CollateFunc())

        # ---------------------------- Build Evaluator ----------------------------
        self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)

        # ---------------------------- Build Grad. Scaler ----------------------------
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)

        # ---------------------------- Build Optimizer ----------------------------
        self.model_cfg['lr0'] *= self.args.batch_size / 16.
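        # Same linear scaling as the other trainers but anchored at batch 16,
        # the customary DETR training batch; e.g. with an illustrative
        # lr0=1e-4, batch_size=32 gives 1e-4 * 32 / 16 = 2e-4.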
        self.optimizer, self.start_epoch = build_detr_optimizer(self.model_cfg, model, self.args.resume)

        # ---------------------------- Build LR Scheduler ----------------------------
        self.args.max_epoch += self.args.wp_epoch
        self.lr_scheduler, self.lf = build_lr_scheduler(self.model_cfg, self.optimizer, self.args.max_epoch)
        self.lr_scheduler.last_epoch = self.start_epoch - 1  # do not move
        if self.args.resume:
            self.lr_scheduler.step()

        # ---------------------------- Build Model-EMA ----------------------------
        if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
            print('Build ModelEMA ...')
            self.model_ema = ModelEMA(
                model,
                self.model_cfg['ema_decay'],
                self.model_cfg['ema_tau'],
                self.start_epoch * len(self.train_loader))
        else:
            self.model_ema = None

    def train(self, model):
        for epoch in range(self.start_epoch, self.args.max_epoch):
            if self.args.distributed:
                self.train_loader.batch_sampler.sampler.set_epoch(epoch)

            # Check second stage
            if epoch >= (self.args.max_epoch - self.model_cfg['no_aug_epoch'] - 1):
                # Close mosaic augmentation
                if self.train_loader.dataset.mosaic_prob > 0.:
                    print('close Mosaic Augmentation ...')
                    self.train_loader.dataset.mosaic_prob = 0.
                    self.heavy_eval = True
                # Close mixup augmentation
                if self.train_loader.dataset.mixup_prob > 0.:
                    print('close Mixup Augmentation ...')
                    self.train_loader.dataset.mixup_prob = 0.
                    self.heavy_eval = True

            # Train one epoch
            self.train_one_epoch(model)

            # Eval one epoch
            model_eval = model.module if self.args.distributed else model
            if self.heavy_eval:
                self.eval(model_eval)
            elif (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
                self.eval(model_eval)

    def eval(self, model):
        # Check model
        model_eval = model if self.model_ema is None else self.model_ema.ema

        # Path to save model
        path_to_save = os.path.join(self.args.save_folder, self.args.dataset, self.args.model)
        os.makedirs(path_to_save, exist_ok=True)

        if distributed_utils.is_main_process():
            # Check evaluator
            if self.evaluator is None:
                print('No evaluator ... save model and go on training.')
                print('Saving state, epoch: {}'.format(self.epoch + 1))
                weight_name = '{}_no_eval.pth'.format(self.args.model)
                checkpoint_path = os.path.join(path_to_save, weight_name)
                torch.save({'model': model_eval.state_dict(),
                            'mAP': -1.,
                            'optimizer': self.optimizer.state_dict(),
                            'epoch': self.epoch,
                            'args': self.args},
                           checkpoint_path)
            else:
                print('eval ...')
                # Set eval mode
                model_eval.trainable = False
                model_eval.eval()

                # Evaluate
                with torch.no_grad():
                    self.evaluator.evaluate(model_eval)

                # Save model if the current mAP is the best so far
                cur_map = self.evaluator.map
                if cur_map > self.best_map:
                    # Update best mAP
                    self.best_map = cur_map
                    print('Saving state, epoch: {}'.format(self.epoch + 1))
                    weight_name = '{}_best.pth'.format(self.args.model)
                    checkpoint_path = os.path.join(path_to_save, weight_name)
                    torch.save({'model': model_eval.state_dict(),
                                'mAP': round(self.best_map * 100, 1),
                                'optimizer': self.optimizer.state_dict(),
                                'epoch': self.epoch,
                                'args': self.args},
                               checkpoint_path)

                # Set train mode
                model_eval.trainable = True
                model_eval.train()

        if self.args.distributed:
            # Wait for all processes to synchronize
            dist.barrier()

    def train_one_epoch(self, model):
        # Basic parameters
        epoch_size = len(self.train_loader)
        img_size = self.args.img_size
        t0 = time.time()
        nw = epoch_size * self.args.wp_epoch

        # Train one epoch
        for iter_i, (images, targets) in enumerate(self.train_loader):
            ni = iter_i + self.epoch * epoch_size

            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                for j, x in enumerate(self.optimizer.param_groups):
                    # All lrs rise from 0.0 to the scheduled lr during warmup
                    x['lr'] = np.interp(
                        ni, xi, [0.0, x['initial_lr'] * self.lf(self.epoch)])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xi, [self.model_cfg['warmup_momentum'], self.model_cfg['momentum']])

            # To device
            images = images.to(self.device, non_blocking=True).float() / 255.

            # Multi scale
            if self.args.multi_scale:
                images, targets, img_size = self.rescale_image_targets(
                    images, targets, model.max_stride, self.args.min_box_size, self.model_cfg['multi_scale'])
            else:
                targets = self.refine_targets(targets, self.args.min_box_size, img_size)

            # Visualize targets
            if self.args.vis_tgt:
                vis_data(images * 255, targets)

            # Inference
            with torch.cuda.amp.autocast(enabled=self.args.fp16):
                outputs = model(images)
                # Compute loss
                loss_dict = self.criterion(outputs=outputs, targets=targets)
                losses = loss_dict['losses']
                loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)

            # Backward
            self.scaler.scale(losses).backward()

            # Optimize (every iteration; no gradient accumulation here)
            if self.model_cfg['clip_grad'] > 0:
                # Unscale gradients
                self.scaler.unscale_(self.optimizer)
                # Clip gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.model_cfg['clip_grad'])
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad()

            # Model EMA
            if self.model_ema is not None:
                self.model_ema.update(model)
            self.last_opt_step = ni

            # Log
            if distributed_utils.is_main_process() and iter_i % 10 == 0:
                t1 = time.time()
                cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
                # Basic info
                log = '[Epoch: {}/{}]'.format(self.epoch + 1, self.args.max_epoch)
                log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
                log += '[lr: {:.6f}]'.format(cur_lr[0])
                # Loss info
                for k in loss_dict_reduced.keys():
                    if self.args.vis_aux_loss:
                        log += '[{}: {:.2f}]'.format(k, loss_dict[k])
                    elif k in ['loss_cls', 'loss_bbox', 'loss_giou', 'losses']:
                        log += '[{}: {:.2f}]'.format(k, loss_dict[k])
                # Other info
                log += '[time: {:.2f}]'.format(t1 - t0)
                log += '[size: {}]'.format(img_size)
                # Print log info
                print(log, flush=True)
                t0 = time.time()

        # LR scheduler
        self.lr_scheduler.step()
        self.epoch += 1

    def refine_targets(self, targets, min_box_size, img_size):
        # Filter out tiny boxes and convert to normalized cxcywh
        for tgt in targets:
            boxes = tgt["boxes"]
            labels = tgt["labels"]
            # Refine tgt
            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
            keep = (min_tgt_size >= min_box_size)
            # xyxy -> cxcywh
            new_boxes = torch.zeros_like(boxes)
            new_boxes[..., :2] = (boxes[..., 2:] + boxes[..., :2]) * 0.5
            new_boxes[..., 2:] = (boxes[..., 2:] - boxes[..., :2])
            # Normalize
            new_boxes /= img_size
            del boxes
            tgt["boxes"] = new_boxes[keep]
            tgt["labels"] = labels[keep]

        return targets
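
    # Worked conversion example (hypothetical numbers): an xyxy box
    # (100, 200, 300, 400) at img_size=640 becomes cxcywh (200, 300, 200, 200),
    # which normalizes to (0.3125, 0.46875, 0.3125, 0.3125), matching the
    # normalized cxcywh format the DETR-style criterion consumes.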

    def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
        """
        Rescale images and targets for the multi-scale training trick.
        """
        if isinstance(stride, int):
            max_stride = stride
        elif isinstance(stride, list):
            max_stride = max(stride)

        # During the training phase, the shape of the input image is square.
        old_img_size = images.shape[-1]
        new_img_size = random.randrange(int(old_img_size * multi_scale_range[0]), int(old_img_size * multi_scale_range[1] + max_stride))
        new_img_size = new_img_size // max_stride * max_stride  # size
        if new_img_size / old_img_size != 1:
            # Interpolate
            images = torch.nn.functional.interpolate(
                input=images,
                size=new_img_size,
                mode='bilinear',
                align_corners=False)

        # Rescale targets
        for tgt in targets:
            boxes = tgt["boxes"].clone()
            labels = tgt["labels"].clone()
            boxes = torch.clamp(boxes, 0, old_img_size)
            # Rescale box
            boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
            boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
            # Refine tgt
            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
            keep = (min_tgt_size >= min_box_size)
            # xyxy -> cxcywh
            new_boxes = torch.zeros_like(boxes)
            new_boxes[..., :2] = (boxes[..., 2:] + boxes[..., :2]) * 0.5
            new_boxes[..., 2:] = (boxes[..., 2:] - boxes[..., :2])
            # Normalize
            new_boxes /= new_img_size
            del boxes
            tgt["boxes"] = new_boxes[keep]
            tgt["labels"] = labels[keep]

        return images, targets, new_img_size


# Build Trainer
def build_trainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion):
    if model_cfg['trainer_type'] == 'yolo':
        return YoloTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion)
    elif model_cfg['trainer_type'] == 'rtmdet':
        return RTMTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion)
    elif model_cfg['trainer_type'] == 'detr':
        return DetrTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion)
    else:
        raise NotImplementedError('Unknown trainer type: {}'.format(model_cfg['trainer_type']))
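

# Minimal usage sketch. `parse_args` and `build_model` are hypothetical names
# (the real entry script builds args, configs, model, and criterion itself);
# only build_trainer / trainer.train below come from this file:
#
#   args = parse_args()
#   model, criterion = build_model(args, model_cfg, device)
#   trainer = build_trainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion)
#   trainer.train(model)  # loops train_one_epoch + eval until args.max_epoch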