# train.py (~11 KB)
  1. from __future__ import division
  2. import os
  3. import argparse
  4. from copy import deepcopy
  5. # ----------------- Torch Components -----------------
  6. import torch
  7. import torch.distributed as dist
  8. from torch.nn.parallel import DistributedDataParallel as DDP
  9. # ----------------- Extra Components -----------------
  10. from utils import distributed_utils
  11. from utils.misc import compute_flops
  12. from utils.misc import ModelEMA, CollateFunc, build_dataset, build_dataloader
  13. # ----------------- Evaluator Components -----------------
  14. from evaluator.build import build_evluator
  15. # ----------------- Optimizer & LrScheduler Components -----------------
  16. from utils.solver.optimizer import build_optimizer
  17. from utils.solver.lr_scheduler import build_lr_scheduler
  18. from engine import train_one_epoch, val_one_epoch
  19. # ----------------- Config Components -----------------
  20. from config import build_dataset_config, build_model_config, build_trans_config
  21. # ----------------- Dataset Components -----------------
  22. from dataset.build import build_dataset, build_transform
  23. # ----------------- Model Components -----------------
  24. from models.detectors import build_model
  25. def parse_args():
  26. parser = argparse.ArgumentParser(description='YOLO-Tutorial')
  27. # basic
  28. parser.add_argument('--cuda', action='store_true', default=False,
  29. help='use cuda.')
  30. parser.add_argument('-size', '--img_size', default=640, type=int,
  31. help='input image size')
  32. parser.add_argument('--num_workers', default=4, type=int,
  33. help='Number of workers used in dataloading')
  34. parser.add_argument('--tfboard', action='store_true', default=False,
  35. help='use tensorboard')
  36. parser.add_argument('--save_folder', default='weights/', type=str,
  37. help='path to save weight')
  38. parser.add_argument('--eval_first', action='store_true', default=False,
  39. help='evaluate model before training.')
  40. parser.add_argument('--fp16', dest="fp16", action="store_true", default=False,
  41. help="Adopting mix precision training.")
  42. parser.add_argument('--vis_tgt', action="store_true", default=False,
  43. help="visualize training data.")
  44. # Batchsize
  45. parser.add_argument('-bs', '--batch_size', default=16, type=int,
  46. help='batch size on all the GPUs.')
  47. # Epoch
  48. parser.add_argument('--max_epoch', default=150, type=int,
  49. help='max epoch.')
  50. parser.add_argument('--wp_epoch', default=1, type=int,
  51. help='warmup epoch.')
  52. parser.add_argument('--eval_epoch', default=10, type=int,
  53. help='after eval epoch, the model is evaluated on val dataset.')
  54. parser.add_argument('--step_epoch', nargs='+', default=[90, 120], type=int,
  55. help='lr epoch to decay')
  56. # model
  57. parser.add_argument('-m', '--model', default='yolov1', type=str,
  58. help='build yolo')
  59. parser.add_argument('-ct', '--conf_thresh', default=0.005, type=float,
  60. help='confidence threshold')
  61. parser.add_argument('-nt', '--nms_thresh', default=0.6, type=float,
  62. help='NMS threshold')
  63. parser.add_argument('--topk', default=1000, type=int,
  64. help='topk candidates for evaluation')
  65. parser.add_argument('-p', '--pretrained', default=None, type=str,
  66. help='load pretrained weight')
  67. parser.add_argument('-r', '--resume', default=None, type=str,
  68. help='keep training')
  69. # dataset
  70. parser.add_argument('--root', default='/mnt/share/ssd2/dataset',
  71. help='data root')
  72. parser.add_argument('-d', '--dataset', default='coco',
  73. help='coco, voc, widerface, crowdhuman')
  74. # train trick
  75. parser.add_argument('-ms', '--multi_scale', action='store_true', default=False,
  76. help='Multi scale')
  77. parser.add_argument('--ema', action='store_true', default=False,
  78. help='Model EMA')
  79. parser.add_argument('--min_box_size', default=8.0, type=float,
  80. help='min size of target bounding box.')
  81. parser.add_argument('--mosaic', default=None, type=float,
  82. help='mosaic augmentation.')
  83. parser.add_argument('--mixup', default=None, type=float,
  84. help='mixup augmentation.')
  85. # DDP train
  86. parser.add_argument('-dist', '--distributed', action='store_true', default=False,
  87. help='distributed training')
  88. parser.add_argument('--dist_url', default='env://',
  89. help='url used to set up distributed training')
  90. parser.add_argument('--world_size', default=1, type=int,
  91. help='number of distributed processes')
  92. parser.add_argument('--sybn', action='store_true', default=False,
  93. help='use sybn.')
  94. return parser.parse_args()
  95. def train():
  96. args = parse_args()
  97. print("Setting Arguments.. : ", args)
  98. print("----------------------------------------------------------")
  99. # dist
  100. world_size = distributed_utils.get_world_size()
  101. per_gpu_batch = args.batch_size // world_size
  102. print('World size: {}'.format(world_size))
  103. if args.distributed:
  104. distributed_utils.init_distributed_mode(args)
  105. print("git:\n {}\n".format(distributed_utils.get_sha()))
  106. # path to save model
  107. path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
  108. os.makedirs(path_to_save, exist_ok=True)
  109. # cuda
  110. if args.cuda:
  111. print('use cuda')
  112. # cudnn.benchmark = True
  113. device = torch.device("cuda")
  114. else:
  115. device = torch.device("cpu")
  116. # Dataset & Model & Trans Config
  117. data_cfg = build_dataset_config(args)
  118. model_cfg = build_model_config(args)
  119. trans_cfg = build_trans_config(model_cfg['trans_type'])
  120. # Transform
  121. train_transform, trans_config = build_transform(
  122. args=args, trans_config=trans_cfg, max_stride=model_cfg['max_stride'], is_train=True)
  123. val_transform, _ = build_transform(
  124. args=args, max_stride=model_cfg['max_stride'], is_train=False)
  125. # Dataset
  126. dataset, dataset_info = build_dataset(args, data_cfg, trans_config, train_transform, is_train=True)
  127. # Dataloader
  128. dataloader = build_dataloader(args, dataset, per_gpu_batch, CollateFunc())
  129. # Evaluator
  130. evaluator = build_evluator(args, data_cfg, val_transform, device)
  131. # Build model
  132. model, criterion = build_model(
  133. args=args,
  134. model_cfg=model_cfg,
  135. device=device,
  136. num_classes=dataset_info['num_classes'],
  137. trainable=True,
  138. )
  139. model = model.to(device).train()
  140. # DDP
  141. model_without_ddp = model
  142. if args.distributed:
  143. model = DDP(model, device_ids=[args.gpu])
  144. model_without_ddp = model.module
  145. # SyncBatchNorm
  146. if args.sybn and args.distributed:
  147. print('use SyncBatchNorm ...')
  148. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
  149. # compute FLOPs and Params
  150. if distributed_utils.is_main_process:
  151. model_copy = deepcopy(model_without_ddp)
  152. model_copy.trainable = False
  153. model_copy.eval()
  154. compute_flops(model=model_copy,
  155. img_size=args.img_size,
  156. device=device)
  157. del model_copy
  158. if args.distributed:
  159. # wait for all processes to synchronize
  160. dist.barrier()
  161. # amp
  162. scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
  163. # batch size
  164. total_bs = args.batch_size
  165. accumulate = max(1, round(64 / total_bs))
  166. print('Grad_Accumulate: ', accumulate)
  167. # optimizer
  168. model_cfg['weight_decay'] *= total_bs * accumulate / 64
  169. optimizer, start_epoch = build_optimizer(model_cfg, model_without_ddp, model_cfg['lr0'], args.resume)
  170. # Scheduler
  171. total_epochs = args.max_epoch + args.wp_epoch
  172. scheduler, lf = build_lr_scheduler(model_cfg, optimizer, total_epochs)
  173. scheduler.last_epoch = start_epoch - 1 # do not move
  174. if args.resume:
  175. scheduler.step()
  176. # EMA
  177. if args.ema and distributed_utils.get_rank() in [-1, 0]:
  178. print('Build ModelEMA ...')
  179. ema = ModelEMA(model, decay=model_cfg['ema_decay'], tau=model_cfg['ema_tau'], updates=start_epoch * len(dataloader))
  180. else:
  181. ema = None
  182. # start training loop
  183. best_map = -1.0
  184. last_opt_step = -1
  185. heavy_eval = False
  186. optimizer.zero_grad()
  187. # eval before training
  188. if args.eval_first and distributed_utils.is_main_process():
  189. # to check whether the evaluator can work
  190. model_eval = ema.ema if ema else model_without_ddp
  191. val_one_epoch(
  192. args=args, model=model_eval, evaluator=evaluator, optimizer=optimizer,
  193. epoch=0, best_map=best_map, path_to_save=path_to_save)
  194. # start to train
  195. for epoch in range(start_epoch, total_epochs):
  196. if args.distributed:
  197. dataloader.batch_sampler.sampler.set_epoch(epoch)
  198. # check second stage
  199. if epoch >= (total_epochs - model_cfg['no_aug_epoch'] - 1):
  200. # close mosaic augmentation
  201. if dataloader.dataset.mosaic_prob > 0.:
  202. print('close Mosaic Augmentation ...')
  203. dataloader.dataset.mosaic_prob = 0.
  204. heavy_eval = True
  205. # close mixup augmentation
  206. if dataloader.dataset.mixup_prob > 0.:
  207. print('close Mixup Augmentation ...')
  208. dataloader.dataset.mixup_prob = 0.
  209. heavy_eval = True
  210. # train one epoch
  211. last_opt_step = train_one_epoch(
  212. epoch=epoch,
  213. total_epochs=total_epochs,
  214. args=args,
  215. device=device,
  216. ema=ema,
  217. model=model,
  218. criterion=criterion,
  219. cfg=model_cfg,
  220. dataloader=dataloader,
  221. optimizer=optimizer,
  222. scheduler=scheduler,
  223. lf=lf,
  224. scaler=scaler,
  225. last_opt_step=last_opt_step)
  226. # eval
  227. if heavy_eval:
  228. best_map = val_one_epoch(
  229. args=args,
  230. model=ema.ema if ema else model_without_ddp,
  231. evaluator=evaluator,
  232. optimizer=optimizer,
  233. epoch=epoch,
  234. best_map=best_map,
  235. path_to_save=path_to_save)
  236. else:
  237. if (epoch % args.eval_epoch) == 0 or (epoch == total_epochs - 1):
  238. best_map = val_one_epoch(
  239. args=args,
  240. model=ema.ema if ema else model_without_ddp,
  241. evaluator=evaluator,
  242. optimizer=optimizer,
  243. epoch=epoch,
  244. best_map=best_map,
  245. path_to_save=path_to_save)
  246. # Empty cache after train loop
  247. if args.cuda:
  248. torch.cuda.empty_cache()
  249. if __name__ == '__main__':
  250. train()