train.py

from __future__ import division

import os
import argparse
from copy import deepcopy

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from utils import distributed_utils
from utils.com_flops_params import FLOPs_and_Params
from utils.misc import ModelEMA, CollateFunc, build_dataset, build_dataloader
from utils.solver.optimizer import build_optimizer
from utils.solver.warmup_schedule import build_warmup
from utils.solver.lr_scheduler import build_lr_scheduler

from engine import train_one_epoch, val_one_epoch

from config import build_model_config, build_trans_config
from models import build_model

def parse_args():
    parser = argparse.ArgumentParser(description='YOLO-Tutorial')
    # basic
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='use cuda.')
    parser.add_argument('-size', '--img_size', default=640, type=int,
                        help='input image size')
    parser.add_argument('--num_workers', default=4, type=int,
                        help='number of workers used in dataloading')
    parser.add_argument('--tfboard', action='store_true', default=False,
                        help='use tensorboard')
    parser.add_argument('--save_folder', default='weights/', type=str,
                        help='path to save weights')
    parser.add_argument('--eval_first', action='store_true', default=False,
                        help='evaluate the model before training.')
    parser.add_argument('--fp16', dest="fp16", action="store_true", default=False,
                        help="adopt mixed-precision training.")
    parser.add_argument('--vis_tgt', action="store_true", default=False,
                        help="visualize training targets.")

    # batch size
    parser.add_argument('-bs', '--batch_size', default=16, type=int,
                        help='total batch size across all GPUs.')

    # epoch
    parser.add_argument('--max_epoch', default=150, type=int,
                        help='max epoch.')
    parser.add_argument('--wp_epoch', default=1, type=int,
                        help='warmup epochs.')
    parser.add_argument('--eval_epoch', default=10, type=int,
                        help='evaluate the model on the val dataset every eval_epoch epochs.')
    parser.add_argument('--step_epoch', nargs='+', default=[90, 120], type=int,
                        help='epochs at which the lr decays')

    # model
    parser.add_argument('-m', '--model', default='yolov1', type=str,
                        choices=['yolov1', 'yolov2', 'yolov3', 'yolov4'],
                        help='build yolo')
    parser.add_argument('-ct', '--conf_thresh', default=0.005, type=float,
                        help='confidence threshold')
    parser.add_argument('-nt', '--nms_thresh', default=0.6, type=float,
                        help='NMS threshold')
    parser.add_argument('--topk', default=1000, type=int,
                        help='top-k candidates for evaluation')
    parser.add_argument('-p', '--pretrained', default=None, type=str,
                        help='load pretrained weights')
    parser.add_argument('-r', '--resume', default=None, type=str,
                        help='resume training from a checkpoint')

    # dataset
    parser.add_argument('--root', default='/mnt/share/ssd2/dataset',
                        help='data root')
    parser.add_argument('-d', '--dataset', default='coco',
                        help='coco, voc, widerface, crowdhuman')

    # training tricks
    parser.add_argument('-ms', '--multi_scale', action='store_true', default=False,
                        help='multi-scale training')
    parser.add_argument('--ema', action='store_true', default=False,
                        help='model EMA')
    parser.add_argument('--min_box_size', default=8.0, type=float,
                        help='min size of a target bounding box.')
    parser.add_argument('--mosaic', default=None, type=float,
                        help='mosaic augmentation.')
    parser.add_argument('--mixup', default=None, type=float,
                        help='mixup augmentation.')

    # DDP training
    parser.add_argument('-dist', '--distributed', action='store_true', default=False,
                        help='distributed training')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--sybn', action='store_true', default=False,
                        help='use SyncBatchNorm.')

    return parser.parse_args()

def train():
    args = parse_args()
    print("Setting Arguments.. : ", args)
    print("----------------------------------------------------------")

    # distributed setup: initialize the process group first so that the world
    # size, and hence the per-GPU batch size, is computed correctly
    if args.distributed:
        distributed_utils.init_distributed_mode(args)
        print("git:\n {}\n".format(distributed_utils.get_sha()))
    world_size = distributed_utils.get_world_size()
    per_gpu_batch = args.batch_size // world_size
    print('World size: {}'.format(world_size))

    # path to save the model
    path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
    os.makedirs(path_to_save, exist_ok=True)
    # cuda
    if args.cuda:
        print('use cuda')
        # cudnn.benchmark = True
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # config
    model_cfg = build_model_config(args)
    trans_cfg = build_trans_config(model_cfg['trans_type'])

    # dataset and evaluator
    dataset, dataset_info, evaluator = build_dataset(args, trans_cfg, device, is_train=True)
    num_classes = dataset_info[0]

    # dataloader
    dataloader = build_dataloader(args, dataset, per_gpu_batch, CollateFunc())
    # build model
    model, criterion = build_model(
        args=args,
        model_cfg=model_cfg,
        device=device,
        num_classes=num_classes,
        trainable=True,
        )
    model = model.to(device).train()
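    # Keep a handle to the unwrapped model: DDP adds a .module wrapper, while
    # optimizer construction and evaluation below need the underlying model.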
    # DDP
    model_without_ddp = model
    if args.distributed:
        model = DDP(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    # SyncBatchNorm
    if args.sybn and args.distributed:
        print('use SyncBatchNorm ...')
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    # compute FLOPs and Params
    if distributed_utils.is_main_process():
        model_copy = deepcopy(model_without_ddp)
        model_copy.trainable = False
        model_copy.eval()
        FLOPs_and_Params(model=model_copy,
                         img_size=args.img_size,
                         device=device)
        del model_copy
    if args.distributed:
        # wait for all processes to synchronize
        dist.barrier()
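    # Gradient accumulation: choose the number of accumulation steps so that the
    # effective batch size (batch_size * accumulate) stays close to 64, and scale
    # weight decay linearly with that effective batch size.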
    # batch size
    total_bs = args.batch_size
    accumulate = max(1, round(64 / total_bs))
    print('Grad_Accumulate: ', accumulate)

    # optimizer
    model_cfg['weight_decay'] *= total_bs * accumulate / 64
    optimizer, start_epoch = build_optimizer(
        model_cfg, model_without_ddp, model_cfg['lr0'], args.resume)
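    # Warmup runs for the first wp_epoch epochs, i.e. len(dataloader) * wp_epoch
    # iterations, before the regular lr schedule takes over.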
    # warmup scheduler
    warmup_scheduler = build_warmup(
        model_cfg, model_cfg['lr0'], len(dataloader) * args.wp_epoch)

    # lr scheduler
    lr_scheduler = build_lr_scheduler(
        args, model_cfg, optimizer, args.max_epoch)
    lr_scheduler.last_epoch = start_epoch - 1  # do not move
    if args.resume:
        lr_scheduler.step()
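    # Model EMA keeps an exponential moving average of the weights; it is only
    # maintained on the main process (rank -1 or 0), and the averaged weights
    # are the ones used for evaluation below.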
    # EMA
    if args.ema and distributed_utils.get_rank() in [-1, 0]:
        print('Build ModelEMA ...')
        ema = ModelEMA(model, model_cfg['ema_decay'], start_epoch * len(dataloader))
    else:
        ema = None

    # start training loop
    best_map = -1.0
    last_opt_step = -1
    total_epochs = args.max_epoch
    heavy_eval = False
    optimizer.zero_grad()
    # eval before training
    if args.eval_first and distributed_utils.is_main_process():
        # to check whether the evaluator can work
        model_eval = ema.ema if ema else model_without_ddp
        val_one_epoch(
            args=args, model=model_eval, evaluator=evaluator, optimizer=optimizer,
            epoch=0, best_map=best_map, path_to_save=path_to_save)
    # start to train
    for epoch in range(start_epoch, total_epochs):
        if args.distributed:
            dataloader.batch_sampler.sampler.set_epoch(epoch)
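        # Second stage: for the last no_aug_epoch epochs, mosaic and mixup are
        # switched off and the model is evaluated every epoch (heavy_eval).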
        # check second stage
        if epoch >= (total_epochs - model_cfg['no_aug_epoch'] - 1):
            # close mosaic augmentation
            if dataloader.dataset.mosaic_prob > 0.:
                print('close Mosaic Augmentation ...')
                dataloader.dataset.mosaic_prob = 0.
                heavy_eval = True
            # close mixup augmentation
            if dataloader.dataset.mixup_prob > 0.:
                print('close Mixup Augmentation ...')
                dataloader.dataset.mixup_prob = 0.
                heavy_eval = True
        # train one epoch
        last_opt_step = train_one_epoch(
            epoch=epoch,
            total_epochs=total_epochs,
            args=args,
            device=device,
            ema=ema,
            model=model,
            criterion=criterion,
            dataloader=dataloader,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            warmup_scheduler=warmup_scheduler,
            last_opt_step=last_opt_step
            )
        # eval
        if heavy_eval:
            best_map = val_one_epoch(
                args=args,
                model=ema.ema if ema else model_without_ddp,
                evaluator=evaluator,
                optimizer=optimizer,
                epoch=epoch,
                best_map=best_map,
                path_to_save=path_to_save)
        else:
            if (epoch % args.eval_epoch) == 0 or (epoch == total_epochs - 1):
                best_map = val_one_epoch(
                    args=args,
                    model=ema.ema if ema else model_without_ddp,
                    evaluator=evaluator,
                    optimizer=optimizer,
                    epoch=epoch,
                    best_map=best_map,
                    path_to_save=path_to_save)

        # empty cache after train loop
        if args.cuda:
            torch.cuda.empty_cache()

if __name__ == '__main__':
    train()