# train.py

from __future__ import division

import os
import argparse
from copy import deepcopy

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from utils import distributed_utils
from utils.com_flops_params import FLOPs_and_Params
from utils.misc import ModelEMA, CollateFunc, build_dataset, build_dataloader
from utils.solver.optimizer import build_optimizer
from utils.solver.lr_scheduler import build_lr_scheduler

from engine import train_one_epoch, val_one_epoch

from config import build_model_config, build_trans_config
from models import build_model


def parse_args():
    parser = argparse.ArgumentParser(description='YOLO-Tutorial')
    # basic
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='use cuda.')
    parser.add_argument('-size', '--img_size', default=640, type=int,
                        help='input image size')
    parser.add_argument('--num_workers', default=4, type=int,
                        help='number of workers used in dataloading')
    parser.add_argument('--tfboard', action='store_true', default=False,
                        help='use tensorboard')
    parser.add_argument('--save_folder', default='weights/', type=str,
                        help='path to save weights')
    parser.add_argument('--eval_first', action='store_true', default=False,
                        help='evaluate the model before training.')
    parser.add_argument('--fp16', dest='fp16', action='store_true', default=False,
                        help='adopt mixed-precision training.')
    parser.add_argument('--vis_tgt', action='store_true', default=False,
                        help='visualize training data.')
    # batch size
    parser.add_argument('-bs', '--batch_size', default=16, type=int,
                        help='total batch size across all GPUs.')
    # epoch
    parser.add_argument('--max_epoch', default=150, type=int,
                        help='max epoch.')
    parser.add_argument('--wp_epoch', default=1, type=int,
                        help='warmup epoch.')
    parser.add_argument('--eval_epoch', default=10, type=int,
                        help='evaluate the model on the val dataset every eval_epoch epochs.')
    parser.add_argument('--step_epoch', nargs='+', default=[90, 120], type=int,
                        help='epochs at which to decay the lr')
    # model
    parser.add_argument('-m', '--model', default='yolov1', type=str,
                        choices=['yolov1', 'yolov2', 'yolov3', 'yolov4', 'yolox'],
                        help='build yolo')
    parser.add_argument('-ct', '--conf_thresh', default=0.005, type=float,
                        help='confidence threshold')
    parser.add_argument('-nt', '--nms_thresh', default=0.6, type=float,
                        help='NMS threshold')
    parser.add_argument('--topk', default=1000, type=int,
                        help='topk candidates for evaluation')
    parser.add_argument('-p', '--pretrained', default=None, type=str,
                        help='load pretrained weights')
    parser.add_argument('-r', '--resume', default=None, type=str,
                        help='checkpoint to resume training from')
    # dataset
    parser.add_argument('--root', default='/mnt/share/ssd2/dataset',
                        help='data root')
    parser.add_argument('-d', '--dataset', default='coco',
                        help='coco, voc, widerface, crowdhuman')
    # train tricks
    parser.add_argument('-ms', '--multi_scale', action='store_true', default=False,
                        help='multi-scale training')
    parser.add_argument('--ema', action='store_true', default=False,
                        help='model EMA')
    parser.add_argument('--min_box_size', default=8.0, type=float,
                        help='min size of a target bounding box.')
    parser.add_argument('--mosaic', default=None, type=float,
                        help='mosaic augmentation probability.')
    parser.add_argument('--mixup', default=None, type=float,
                        help='mixup augmentation probability.')
    # DDP train
    parser.add_argument('-dist', '--distributed', action='store_true', default=False,
                        help='distributed training')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--sybn', action='store_true', default=False,
                        help='use SyncBatchNorm.')

    return parser.parse_args()
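
# Example invocations (a sketch; the dataset root, GPU count, and batch sizes
# below are placeholders, not values prescribed by this repo):
#
#   # single-GPU training
#   python train.py --cuda -d voc --root /path/to/dataset -m yolov1 -bs 16 --ema --fp16
#
#   # multi-GPU DDP training, assuming init_distributed_mode() reads the
#   # env:// variables set by torchrun
#   torchrun --nproc_per_node=2 train.py --cuda -dist --sybn -d coco --root /path/to/dataset -m yolov3 -bs 32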


def train():
    args = parse_args()
    print("Setting Arguments.. : ", args)
    print("----------------------------------------------------------")

    # dist
    if args.distributed:
        distributed_utils.init_distributed_mode(args)
        print("git:\n {}\n".format(distributed_utils.get_sha()))
    # query the world size only after the process group is initialized,
    # otherwise it would still report 1
    world_size = distributed_utils.get_world_size()
    per_gpu_batch = args.batch_size // world_size
    print('World size: {}'.format(world_size))

    # path to save model
    path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
    os.makedirs(path_to_save, exist_ok=True)

    # cuda
    if args.cuda:
        print('use cuda')
        # cudnn.benchmark = True
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # config
    model_cfg = build_model_config(args)
    trans_cfg = build_trans_config(model_cfg['trans_type'])

    # dataset and evaluator
    dataset, dataset_info, evaluator = build_dataset(args, trans_cfg, device, is_train=True)
    num_classes = dataset_info[0]

    # dataloader
    dataloader = build_dataloader(args, dataset, per_gpu_batch, CollateFunc())

    # build model
    model, criterion = build_model(
        args=args,
        model_cfg=model_cfg,
        device=device,
        num_classes=num_classes,
        trainable=True,
        )
    model = model.to(device).train()

    # SyncBatchNorm: convert BN layers before wrapping with DDP, so that
    # model_without_ddp refers to the converted modules
    if args.sybn and args.distributed:
        print('use SyncBatchNorm ...')
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # DDP
    model_without_ddp = model
    if args.distributed:
        model = DDP(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    # compute FLOPs and Params on the main process only
    if distributed_utils.is_main_process():
        model_copy = deepcopy(model_without_ddp)
        model_copy.trainable = False
        model_copy.eval()
        FLOPs_and_Params(model=model_copy,
                         img_size=args.img_size,
                         device=device)
        del model_copy
    if args.distributed:
        # wait for all processes to synchronize
        dist.barrier()
    # amp
    scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)

    # batch size and gradient accumulation
    total_bs = args.batch_size
    accumulate = max(1, round(64 / total_bs))
    print('Grad_Accumulate: ', accumulate)
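    # With gradient accumulation, the optimizer steps once every `accumulate`
    # iterations, so the effective batch size is total_bs * accumulate ~= 64
    # (e.g. -bs 16 -> accumulate 4). train_one_epoch is assumed to follow the
    # standard AMP pattern with this scaler; a sketch, not the engine's exact code:
    #   with torch.cuda.amp.autocast(enabled=args.fp16):
    #       losses = criterion(model(images), targets)
    #   scaler.scale(losses / accumulate).backward()
    #   if ni - last_opt_step >= accumulate:
    #       scaler.step(optimizer); scaler.update(); optimizer.zero_grad()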
    # optimizer: scale weight decay linearly with the effective batch size, so
    # regularization strength stays consistent with the nominal batch of 64
    model_cfg['weight_decay'] *= total_bs * accumulate / 64
    optimizer, start_epoch = build_optimizer(model_cfg, model_without_ddp, model_cfg['lr0'], args.resume)

    # scheduler: rewind last_epoch so a resumed run continues at the right lr
    scheduler, lf = build_lr_scheduler(model_cfg, optimizer, args.max_epoch)
    scheduler.last_epoch = start_epoch - 1  # do not move
    if args.resume:
        scheduler.step()
    # EMA: kept on the main process only (rank -1 means non-distributed)
    if args.ema and distributed_utils.get_rank() in [-1, 0]:
        print('Build ModelEMA ...')
        ema = ModelEMA(model, decay=model_cfg['ema_decay'], tau=model_cfg['ema_tau'],
                       updates=start_epoch * len(dataloader))
    else:
        ema = None
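    # ModelEMA is assumed to follow the YOLOv5-style EMA (a sketch of the
    # update rule, not necessarily this repo's exact implementation): after
    # each optimizer step,
    #   d = decay * (1 - exp(-updates / tau))   # warms up from 0 toward decay
    #   ema_w = d * ema_w + (1 - d) * w
    # `updates` is seeded with start_epoch * len(dataloader) so a resumed run
    # continues the warmup where it left off.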
    # start training loop
    best_map = -1.
    last_opt_step = -1
    total_epochs = args.max_epoch
    heavy_eval = False
    optimizer.zero_grad()

    # eval before training
    if args.eval_first and distributed_utils.is_main_process():
        # to check whether the evaluator can work
        model_eval = ema.ema if ema else model_without_ddp
        val_one_epoch(
            args=args, model=model_eval, evaluator=evaluator, optimizer=optimizer,
            epoch=0, best_map=best_map, path_to_save=path_to_save)

    # start to train
    for epoch in range(start_epoch, total_epochs):
        if args.distributed:
            # reshuffle the data differently at every epoch on all ranks
            dataloader.batch_sampler.sampler.set_epoch(epoch)

        # second stage: close the strong augmentations for the last epochs
        if epoch >= (total_epochs - model_cfg['no_aug_epoch'] - 1):
            # close mosaic augmentation
            if dataloader.dataset.mosaic_prob > 0.:
                print('close Mosaic Augmentation ...')
                dataloader.dataset.mosaic_prob = 0.
                heavy_eval = True
            # close mixup augmentation
            if dataloader.dataset.mixup_prob > 0.:
                print('close Mixup Augmentation ...')
                dataloader.dataset.mixup_prob = 0.
                heavy_eval = True
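            # Disabling mosaic/mixup for the last no_aug_epoch epochs mirrors
            # the YOLOX recipe: the final epochs train on clean images so the
            # model adapts to the real data distribution, and heavy_eval
            # switches to evaluating after every remaining epoch.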
        # train one epoch
        last_opt_step = train_one_epoch(
            epoch=epoch,
            total_epochs=total_epochs,
            args=args,
            device=device,
            ema=ema,
            model=model,
            criterion=criterion,
            cfg=model_cfg,
            dataloader=dataloader,
            optimizer=optimizer,
            scheduler=scheduler,
            lf=lf,
            scaler=scaler,
            last_opt_step=last_opt_step)

        # eval: every epoch once the strong augmentations are closed,
        # otherwise every eval_epoch epochs and on the last epoch
        if heavy_eval:
            best_map = val_one_epoch(
                args=args,
                model=ema.ema if ema else model_without_ddp,
                evaluator=evaluator,
                optimizer=optimizer,
                epoch=epoch,
                best_map=best_map,
                path_to_save=path_to_save)
        else:
            if (epoch % args.eval_epoch) == 0 or (epoch == total_epochs - 1):
                best_map = val_one_epoch(
                    args=args,
                    model=ema.ema if ema else model_without_ddp,
                    evaluator=evaluator,
                    optimizer=optimizer,
                    epoch=epoch,
                    best_map=best_map,
                    path_to_save=path_to_save)

        # empty the CUDA cache after each training epoch
        if args.cuda:
            torch.cuda.empty_cache()


if __name__ == '__main__':
    train()