# train.py — single-GPU / distributed (DDP) training entry script.
  1. from __future__ import division
  2. import os
  3. import random
  4. import numpy as np
  5. import argparse
  6. from copy import deepcopy
  7. # ----------------- Torch Components -----------------
  8. import torch
  9. import torch.distributed as dist
  10. from torch.nn.parallel import DistributedDataParallel as DDP
  11. # ----------------- Extra Components -----------------
  12. from utils import distributed_utils
  13. from utils.misc import compute_flops
  14. # ----------------- Config Components -----------------
  15. from config import build_dataset_config, build_model_config, build_trans_config
  16. # ----------------- Model Components -----------------
  17. from models.detectors import build_model
  18. # ----------------- Train Components -----------------
  19. from engine import build_trainer
  20. def parse_args():
  21. parser = argparse.ArgumentParser(description='Real-time Object Detection LAB')
  22. # Random seed
  23. parser.add_argument('--seed', default=42, type=int)
  24. # GPU
  25. parser.add_argument('--cuda', action='store_true', default=False,
  26. help='use cuda.')
  27. # Image size
  28. parser.add_argument('-size', '--img_size', default=640, type=int,
  29. help='input image size')
  30. parser.add_argument('--eval_first', action='store_true', default=False,
  31. help='evaluate model before training.')
  32. # Outputs
  33. parser.add_argument('--tfboard', action='store_true', default=False,
  34. help='use tensorboard')
  35. parser.add_argument('--save_folder', default='weights/', type=str,
  36. help='path to save weight')
  37. parser.add_argument('--vis_tgt', action="store_true", default=False,
  38. help="visualize training data.")
  39. parser.add_argument('--vis_aux_loss', action="store_true", default=False,
  40. help="visualize aux loss.")
  41. # Mixing precision
  42. parser.add_argument('--fp16', dest="fp16", action="store_true", default=False,
  43. help="Adopting mix precision training.")
  44. # Batchsize
  45. parser.add_argument('-bs', '--batch_size', default=16, type=int,
  46. help='batch size on all the GPUs.')
  47. # Epoch
  48. parser.add_argument('--max_epoch', default=150, type=int,
  49. help='max epoch.')
  50. parser.add_argument('--wp_epoch', default=1, type=int,
  51. help='warmup epoch.')
  52. parser.add_argument('--eval_epoch', default=10, type=int,
  53. help='after eval epoch, the model is evaluated on val dataset.')
  54. parser.add_argument('--no_aug_epoch', default=20, type=int,
  55. help='cancel strong augmentation.')
  56. # Model
  57. parser.add_argument('-m', '--model', default='yolov1', type=str,
  58. help='build yolo')
  59. parser.add_argument('-ct', '--conf_thresh', default=0.001, type=float,
  60. help='confidence threshold')
  61. parser.add_argument('-nt', '--nms_thresh', default=0.7, type=float,
  62. help='NMS threshold')
  63. parser.add_argument('--topk', default=1000, type=int,
  64. help='topk candidates dets of each level before NMS')
  65. parser.add_argument('-p', '--pretrained', default=None, type=str,
  66. help='load pretrained weight')
  67. parser.add_argument('-r', '--resume', default=None, type=str,
  68. help='keep training')
  69. parser.add_argument('--no_multi_labels', action='store_true', default=False,
  70. help='Perform NMS operations regardless of category.')
  71. parser.add_argument('--nms_class_agnostic', action='store_true', default=False,
  72. help='Perform NMS operations regardless of category.')
  73. # Dataset
  74. parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/',
  75. help='data root')
  76. parser.add_argument('-d', '--dataset', default='coco',
  77. help='coco, voc, widerface, crowdhuman')
  78. parser.add_argument('--load_cache', action='store_true', default=False,
  79. help='Path to the cached data.')
  80. parser.add_argument('--num_workers', default=4, type=int,
  81. help='Number of workers used in dataloading')
  82. # Train trick
  83. parser.add_argument('-ms', '--multi_scale', action='store_true', default=False,
  84. help='Multi scale')
  85. parser.add_argument('--ema', action='store_true', default=False,
  86. help='Model EMA')
  87. parser.add_argument('--min_box_size', default=8.0, type=float,
  88. help='min size of target bounding box.')
  89. parser.add_argument('--mosaic', default=None, type=float,
  90. help='mosaic augmentation.')
  91. parser.add_argument('--mixup', default=None, type=float,
  92. help='mixup augmentation.')
  93. parser.add_argument('--grad_accumulate', default=1, type=int,
  94. help='gradient accumulation')
  95. # DDP train
  96. parser.add_argument('-dist', '--distributed', action='store_true', default=False,
  97. help='distributed training')
  98. parser.add_argument('--dist_url', default='env://',
  99. help='url used to set up distributed training')
  100. parser.add_argument('--world_size', default=1, type=int,
  101. help='number of distributed processes')
  102. parser.add_argument('--sybn', action='store_true', default=False,
  103. help='use sybn.')
  104. parser.add_argument('--find_unused_parameters', default=False, type=bool,
  105. help='set find_unused_parameters as True.')
  106. # Debug mode
  107. parser.add_argument('--debug', action='store_true', default=False,
  108. help='debug mode.')
  109. return parser.parse_args()
  110. def fix_random_seed(args):
  111. seed = args.seed + distributed_utils.get_rank()
  112. torch.manual_seed(seed)
  113. np.random.seed(seed)
  114. random.seed(seed)
  115. def train():
  116. args = parse_args()
  117. print("Setting Arguments.. : ", args)
  118. print("----------------------------------------------------------")
  119. # ---------------------------- Build DDP ----------------------------
  120. local_rank = local_process_rank = -1
  121. if args.distributed:
  122. distributed_utils.init_distributed_mode(args)
  123. print("git:\n {}\n".format(distributed_utils.get_sha()))
  124. try:
  125. # Multiple Mechine & Multiple GPUs (world size > 8)
  126. local_rank = torch.distributed.get_rank()
  127. local_process_rank = int(os.getenv('LOCAL_PROCESS_RANK', '0'))
  128. except:
  129. # Single Mechine & Multiple GPUs (world size <= 8)
  130. local_rank = local_process_rank = torch.distributed.get_rank()
  131. world_size = distributed_utils.get_world_size()
  132. print("LOCAL RANK: ", local_rank)
  133. print("LOCAL_PROCESS_RANL: ", local_process_rank)
  134. print('WORLD SIZE: {}'.format(world_size))
  135. # ---------------------------- Build CUDA ----------------------------
  136. if args.cuda and torch.cuda.is_available():
  137. print('use cuda')
  138. device = torch.device("cuda")
  139. else:
  140. device = torch.device("cpu")
  141. # ---------------------------- Fix random seed ----------------------------
  142. fix_random_seed(args)
  143. # ---------------------------- Build config ----------------------------
  144. data_cfg = build_dataset_config(args)
  145. model_cfg = build_model_config(args)
  146. trans_cfg = build_trans_config(model_cfg['trans_type'])
  147. # ---------------------------- Build model ----------------------------
  148. ## Build model
  149. model, criterion = build_model(args, model_cfg, device, data_cfg['num_classes'], True)
  150. model = model.to(device).train()
  151. model_without_ddp = model
  152. if args.distributed:
  153. model = DDP(model, device_ids=[args.gpu], find_unused_parameters=args.find_unused_parameters)
  154. if args.sybn:
  155. print('use SyncBatchNorm ...')
  156. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
  157. model_without_ddp = model.module
  158. ## Calcute Params & GFLOPs
  159. if distributed_utils.is_main_process:
  160. model_copy = deepcopy(model_without_ddp)
  161. model_copy.trainable = False
  162. model_copy.eval()
  163. compute_flops(model=model_copy,
  164. img_size=args.img_size,
  165. device=device)
  166. del model_copy
  167. if args.distributed:
  168. dist.barrier()
  169. # ---------------------------- Build Trainer ----------------------------
  170. trainer = build_trainer(args, data_cfg, model_cfg, trans_cfg, device, model_without_ddp, criterion, world_size)
  171. # --------------------------------- Train: Start ---------------------------------
  172. ## Eval before training
  173. if args.eval_first and distributed_utils.is_main_process():
  174. # to check whether the evaluator can work
  175. model_eval = model_without_ddp
  176. trainer.eval(model_eval)
  177. return
  178. ## Satrt Training
  179. trainer.train(model)
  180. # --------------------------------- Train: End ---------------------------------
  181. # Empty cache after train loop
  182. del trainer
  183. if args.cuda:
  184. torch.cuda.empty_cache()
# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    train()