# train.py
  1. from __future__ import division
  2. import os
  3. import argparse
  4. from copy import deepcopy
  5. # ----------------- Torch Components -----------------
  6. import torch
  7. import torch.distributed as dist
  8. from torch.nn.parallel import DistributedDataParallel as DDP
  9. # ----------------- Extra Components -----------------
  10. from utils import distributed_utils
  11. from utils.misc import compute_flops
  12. from utils.misc import ModelEMA, CollateFunc, build_dataloader
  13. # ----------------- Evaluator Components -----------------
  14. from evaluator.build import build_evluator
  15. # ----------------- Optimizer & LrScheduler Components -----------------
  16. from utils.solver.optimizer import build_optimizer
  17. from utils.solver.lr_scheduler import build_lr_scheduler
  18. # ----------------- Config Components -----------------
  19. from config import build_dataset_config, build_model_config, build_trans_config
  20. # ----------------- Dataset Components -----------------
  21. from dataset.build import build_dataset, build_transform
  22. # ----------------- Model Components -----------------
  23. from models.detectors import build_model
  24. # ----------------- Train Components -----------------
  25. from engine import Trainer
  26. def parse_args():
  27. parser = argparse.ArgumentParser(description='YOLO-Tutorial')
  28. # basic
  29. parser.add_argument('--cuda', action='store_true', default=False,
  30. help='use cuda.')
  31. parser.add_argument('-size', '--img_size', default=640, type=int,
  32. help='input image size')
  33. parser.add_argument('--num_workers', default=4, type=int,
  34. help='Number of workers used in dataloading')
  35. parser.add_argument('--tfboard', action='store_true', default=False,
  36. help='use tensorboard')
  37. parser.add_argument('--save_folder', default='weights/', type=str,
  38. help='path to save weight')
  39. parser.add_argument('--eval_first', action='store_true', default=False,
  40. help='evaluate model before training.')
  41. parser.add_argument('--fp16', dest="fp16", action="store_true", default=False,
  42. help="Adopting mix precision training.")
  43. parser.add_argument('--vis_tgt', action="store_true", default=False,
  44. help="visualize training data.")
  45. # Batchsize
  46. parser.add_argument('-bs', '--batch_size', default=16, type=int,
  47. help='batch size on all the GPUs.')
  48. # Epoch
  49. parser.add_argument('--max_epoch', default=150, type=int,
  50. help='max epoch.')
  51. parser.add_argument('--wp_epoch', default=1, type=int,
  52. help='warmup epoch.')
  53. parser.add_argument('--eval_epoch', default=10, type=int,
  54. help='after eval epoch, the model is evaluated on val dataset.')
  55. parser.add_argument('--step_epoch', nargs='+', default=[90, 120], type=int,
  56. help='lr epoch to decay')
  57. # model
  58. parser.add_argument('-m', '--model', default='yolov1', type=str,
  59. help='build yolo')
  60. parser.add_argument('-ct', '--conf_thresh', default=0.005, type=float,
  61. help='confidence threshold')
  62. parser.add_argument('-nt', '--nms_thresh', default=0.6, type=float,
  63. help='NMS threshold')
  64. parser.add_argument('--topk', default=1000, type=int,
  65. help='topk candidates for evaluation')
  66. parser.add_argument('-p', '--pretrained', default=None, type=str,
  67. help='load pretrained weight')
  68. parser.add_argument('-r', '--resume', default=None, type=str,
  69. help='keep training')
  70. # dataset
  71. parser.add_argument('--root', default='/mnt/share/ssd2/dataset',
  72. help='data root')
  73. parser.add_argument('-d', '--dataset', default='coco',
  74. help='coco, voc, widerface, crowdhuman')
  75. # train trick
  76. parser.add_argument('-ms', '--multi_scale', action='store_true', default=False,
  77. help='Multi scale')
  78. parser.add_argument('--ema', action='store_true', default=False,
  79. help='Model EMA')
  80. parser.add_argument('--min_box_size', default=8.0, type=float,
  81. help='min size of target bounding box.')
  82. parser.add_argument('--mosaic', default=None, type=float,
  83. help='mosaic augmentation.')
  84. parser.add_argument('--mixup', default=None, type=float,
  85. help='mixup augmentation.')
  86. # DDP train
  87. parser.add_argument('-dist', '--distributed', action='store_true', default=False,
  88. help='distributed training')
  89. parser.add_argument('--dist_url', default='env://',
  90. help='url used to set up distributed training')
  91. parser.add_argument('--world_size', default=1, type=int,
  92. help='number of distributed processes')
  93. parser.add_argument('--sybn', action='store_true', default=False,
  94. help='use sybn.')
  95. return parser.parse_args()
  96. def train():
  97. args = parse_args()
  98. print("Setting Arguments.. : ", args)
  99. print("----------------------------------------------------------")
  100. # ---------------------------- Build DDP ----------------------------
  101. world_size = distributed_utils.get_world_size()
  102. per_gpu_batch = args.batch_size // world_size
  103. print('World size: {}'.format(world_size))
  104. if args.distributed:
  105. distributed_utils.init_distributed_mode(args)
  106. print("git:\n {}\n".format(distributed_utils.get_sha()))
  107. # ---------------------------- Build CUDA ----------------------------
  108. if args.cuda:
  109. print('use cuda')
  110. # cudnn.benchmark = True
  111. device = torch.device("cuda")
  112. else:
  113. device = torch.device("cpu")
  114. # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
  115. data_cfg = build_dataset_config(args)
  116. model_cfg = build_model_config(args)
  117. trans_cfg = build_trans_config(model_cfg['trans_type'])
  118. # ---------------------------- Build Transform ----------------------------
  119. train_transform, trans_cfg = build_transform(
  120. args=args, trans_config=trans_cfg, max_stride=model_cfg['max_stride'], is_train=True)
  121. val_transform, _ = build_transform(
  122. args=args, trans_config=trans_cfg, max_stride=model_cfg['max_stride'], is_train=False)
  123. # ---------------------------- Build Dataset & Dataloader ----------------------------
  124. dataset, dataset_info = build_dataset(args, data_cfg, trans_cfg, train_transform, is_train=True)
  125. train_loader = build_dataloader(args, dataset, per_gpu_batch, CollateFunc())
  126. # ---------------------------- Build Evaluator ----------------------------
  127. evaluator = build_evluator(args, data_cfg, val_transform, device)
  128. # ---------------------------- Build Model ----------------------------
  129. model, criterion = build_model(args, model_cfg, device, dataset_info['num_classes'], True)
  130. model = model.to(device).train()
  131. if args.sybn and args.distributed:
  132. print('use SyncBatchNorm ...')
  133. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
  134. # ---------------------------- Build DDP Model ----------------------------
  135. model_without_ddp = model
  136. if args.distributed:
  137. model = DDP(model, device_ids=[args.gpu])
  138. model_without_ddp = model.module
  139. # ---------------------------- Calcute Params & GFLOPs ----------------------------
  140. if distributed_utils.is_main_process:
  141. model_copy = deepcopy(model_without_ddp)
  142. model_copy.trainable = False
  143. model_copy.eval()
  144. compute_flops(model=model_copy,
  145. img_size=args.img_size,
  146. device=device)
  147. del model_copy
  148. if args.distributed:
  149. # wait for all processes to synchronize
  150. dist.barrier()
  151. dist.barrier()
  152. # ---------------------------- Build Grad. Scaler ----------------------------
  153. scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
  154. # ---------------------------- Build Optimizer ----------------------------
  155. accumulate = max(1, round(64 / args.batch_size))
  156. print('Grad_Accumulate: ', accumulate)
  157. model_cfg['weight_decay'] *= args.batch_size * accumulate / 64
  158. optimizer, start_epoch = build_optimizer(model_cfg, model_without_ddp, model_cfg['lr0'], args.resume)
  159. # ---------------------------- Build LR Scheduler ----------------------------
  160. args.max_epoch += args.wp_epoch
  161. lr_scheduler, lf = build_lr_scheduler(model_cfg, optimizer, args.max_epoch)
  162. lr_scheduler.last_epoch = start_epoch - 1 # do not move
  163. if args.resume:
  164. lr_scheduler.step()
  165. # ---------------------------- Build Model-EMA ----------------------------
  166. if args.ema and distributed_utils.get_rank() in [-1, 0]:
  167. print('Build ModelEMA ...')
  168. model_ema = ModelEMA(model, model_cfg['ema_decay'], model_cfg['ema_tau'], start_epoch * len(train_loader))
  169. else:
  170. model_ema = None
  171. # ---------------------------- Build Trainer ----------------------------
  172. trainer = Trainer(args, device, model_cfg, model_ema, optimizer, lf, lr_scheduler, criterion, scaler)
  173. # start training loop
  174. heavy_eval = False
  175. optimizer.zero_grad()
  176. # --------------------------------- Main process for training ---------------------------------
  177. ## Eval before training
  178. if args.eval_first and distributed_utils.is_main_process():
  179. # to check whether the evaluator can work
  180. model_eval = model_without_ddp
  181. trainer.eval_one_epoch(model_eval, evaluator)
  182. ## Satrt Training
  183. for epoch in range(start_epoch, args.max_epoch):
  184. if args.distributed:
  185. train_loader.batch_sampler.sampler.set_epoch(epoch)
  186. # check second stage
  187. if epoch >= (args.max_epoch - model_cfg['no_aug_epoch'] - 1):
  188. # close mosaic augmentation
  189. if train_loader.dataset.mosaic_prob > 0.:
  190. print('close Mosaic Augmentation ...')
  191. train_loader.dataset.mosaic_prob = 0.
  192. heavy_eval = True
  193. # close mixup augmentation
  194. if train_loader.dataset.mixup_prob > 0.:
  195. print('close Mixup Augmentation ...')
  196. train_loader.dataset.mixup_prob = 0.
  197. heavy_eval = True
  198. # train one epoch
  199. trainer.train_one_epoch(model, train_loader)
  200. # eval one epoch
  201. if heavy_eval:
  202. trainer.eval_one_epoch(model_without_ddp, evaluator)
  203. else:
  204. if (epoch % args.eval_epoch) == 0 or (epoch == args.max_epoch - 1):
  205. trainer.eval_one_epoch(model_without_ddp, evaluator)
  206. # Empty cache after train loop
  207. del trainer
  208. if args.cuda:
  209. torch.cuda.empty_cache()
  210. if __name__ == '__main__':
  211. train()