train.py

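"""Training entry point: parses command-line arguments, builds the dataset /
model / transform configs, constructs the detector and its criterion, optionally
wraps the model in DistributedDataParallel, and hands everything to the trainer
returned by `build_trainer`."""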
from __future__ import division
import argparse
from copy import deepcopy
# ----------------- Torch Components -----------------
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
# ----------------- Extra Components -----------------
from utils import distributed_utils
from utils.misc import compute_flops
# ----------------- Config Components -----------------
from config import build_dataset_config, build_model_config, build_trans_config
# ----------------- Model Components -----------------
from models.detectors import build_model
# ----------------- Train Components -----------------
from engine import build_trainer


def parse_args():
    parser = argparse.ArgumentParser(description='Real-time Object Detection LAB')
    # Basic
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='use CUDA.')
    parser.add_argument('-size', '--img_size', default=640, type=int,
                        help='input image size.')
    parser.add_argument('--num_workers', default=4, type=int,
                        help='number of workers used in dataloading.')
    parser.add_argument('--tfboard', action='store_true', default=False,
                        help='use TensorBoard.')
    parser.add_argument('--save_folder', default='weights/', type=str,
                        help='path to save weights.')
    parser.add_argument('--eval_first', action='store_true', default=False,
                        help='evaluate the model before training.')
    parser.add_argument('--fp16', dest="fp16", action="store_true", default=False,
                        help="adopt mixed-precision training.")
    parser.add_argument('--vis_tgt', action="store_true", default=False,
                        help="visualize training data.")
    parser.add_argument('--vis_aux_loss', action="store_true", default=False,
                        help="visualize aux loss.")
    # Batch size
    parser.add_argument('-bs', '--batch_size', default=16, type=int,
                        help='total batch size across all GPUs.')
    # Epoch
    parser.add_argument('--max_epoch', default=150, type=int,
                        help='max number of training epochs.')
    parser.add_argument('--wp_epoch', default=1, type=int,
                        help='number of warmup epochs.')
    parser.add_argument('--eval_epoch', default=10, type=int,
                        help='evaluate the model on the val dataset every eval_epoch epochs.')
    parser.add_argument('--no_aug_epoch', default=20, type=int,
                        help='cancel strong augmentation.')
    # Model
    parser.add_argument('-m', '--model', default='yolov1', type=str,
                        help='name of the YOLO model to build.')
    parser.add_argument('-ct', '--conf_thresh', default=0.005, type=float,
                        help='confidence threshold.')
    parser.add_argument('-nt', '--nms_thresh', default=0.6, type=float,
                        help='NMS threshold.')
    parser.add_argument('--topk', default=1000, type=int,
                        help='top-k candidate detections kept per level before NMS.')
    parser.add_argument('-p', '--pretrained', default=None, type=str,
                        help='path to pretrained weights to load.')
    parser.add_argument('-r', '--resume', default=None, type=str,
                        help='checkpoint to resume training from.')
    parser.add_argument('--nms_class_agnostic', action='store_true', default=False,
                        help='perform NMS regardless of category.')
    # Dataset
    parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/',
                        help='data root.')
    parser.add_argument('-d', '--dataset', default='coco',
                        help='coco, voc, widerface, crowdhuman.')
    parser.add_argument('--load_cache', action='store_true', default=False,
                        help='load data into memory.')
    # Train tricks
    parser.add_argument('-ms', '--multi_scale', action='store_true', default=False,
                        help='multi-scale training.')
    parser.add_argument('--ema', action='store_true', default=False,
                        help='model EMA.')
    parser.add_argument('--min_box_size', default=8.0, type=float,
                        help='minimum size of a target bounding box.')
    parser.add_argument('--mosaic', default=None, type=float,
                        help='mosaic augmentation.')
    parser.add_argument('--mixup', default=None, type=float,
                        help='mixup augmentation.')
    parser.add_argument('--grad_accumulate', default=1, type=int,
                        help='number of gradient accumulation steps.')
    # DDP train
    parser.add_argument('-dist', '--distributed', action='store_true', default=False,
                        help='distributed training.')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training.')
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes.')
    parser.add_argument('--sybn', action='store_true', default=False,
                        help='use SyncBatchNorm.')

    return parser.parse_args()
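
# Example invocations (illustrative only: the dataset path is a placeholder, and the
# multi-GPU launch assumes `distributed_utils.init_distributed_mode` reads the env://
# variables set by torch.distributed.run):
#   Single GPU:
#     python train.py --cuda -d coco --root /path/to/dataset -m yolov1 -bs 16 --ema --fp16
#   Multi-GPU (DDP):
#     python -m torch.distributed.run --nproc_per_node=8 train.py -dist --cuda --sybn \
#         -d coco --root /path/to/dataset -m yolov1 -bs 16 --ema --fp16
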
def train():
    args = parse_args()
    print("Setting Arguments: ", args)
    print("----------------------------------------------------------")

    # Build DDP
    if args.distributed:
        distributed_utils.init_distributed_mode(args)
        print("git:\n {}\n".format(distributed_utils.get_sha()))
    world_size = distributed_utils.get_world_size()
    print('World size: {}'.format(world_size))

    # Build CUDA
    if args.cuda:
        print('use cuda')
        # cudnn.benchmark = True
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Build Dataset & Model & Trans. Config
    data_cfg = build_dataset_config(args)
    model_cfg = build_model_config(args)
    trans_cfg = build_trans_config(model_cfg['trans_type'])

    # Build Model
    model, criterion = build_model(args, model_cfg, device, data_cfg['num_classes'], True)

    # Keep training from a resume checkpoint
    if distributed_utils.is_main_process() and args.resume is not None:
        print('keep training: ', args.resume)
        checkpoint = torch.load(args.resume, map_location='cpu')
        # checkpoint state dict
        checkpoint_state_dict = checkpoint.pop("model")
        model.load_state_dict(checkpoint_state_dict)

    model = model.to(device).train()
    model_without_ddp = model

    # Convert BatchNorm to SyncBatchNorm for DDP training
    if args.sybn and args.distributed:
        print('use SyncBatchNorm ...')
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Wrap the model with DDP; keep a handle to the unwrapped module
    if args.distributed:
        model = DDP(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    # Calculate Params & GFLOPs
    if distributed_utils.is_main_process():
        model_copy = deepcopy(model_without_ddp)
        model_copy.trainable = False
        model_copy.eval()
        compute_flops(model=model_copy,
                      img_size=args.img_size,
                      device=device)
        del model_copy
    if args.distributed:
        dist.barrier()

    # Build Trainer
    trainer = build_trainer(args, data_cfg, model_cfg, trans_cfg, device, model_without_ddp, criterion, world_size)

    # --------------------------------- Train: Start ---------------------------------
    ## Eval before training
    if args.eval_first and distributed_utils.is_main_process():
        # to check whether the evaluator can work
        model_eval = model_without_ddp
        trainer.eval(model_eval)

    ## Start training
    trainer.train(model)
    # --------------------------------- Train: End ---------------------------------

    # Empty cache after the train loop
    del trainer
    if args.cuda:
        torch.cuda.empty_cache()


if __name__ == '__main__':
    train()