# train.py
  1. from __future__ import division
  2. import os
  3. import random
  4. import numpy as np
  5. import argparse
  6. from copy import deepcopy
  7. # ----------------- Torch Components -----------------
  8. import torch
  9. import torch.distributed as dist
  10. from torch.nn.parallel import DistributedDataParallel as DDP
  11. # ----------------- Extra Components -----------------
  12. from utils import distributed_utils
  13. from utils.misc import compute_flops, build_dataloader, CollateFunc
  14. from utils.ema import ModelEMA
  15. # ----------------- Config Components -----------------
  16. from config import build_config
  17. # ----------------- Data Components -----------------
  18. from dataset.build import build_dataset, build_transform
  19. # ----------------- Evaluator Components -----------------
  20. from evaluator.map_evaluator import MapEvaluator
  21. # ----------------- Model Components -----------------
  22. from models import build_model
  23. # ----------------- Train Components -----------------
  24. from engine import build_trainer
  25. def parse_args():
  26. parser = argparse.ArgumentParser(description='Real-time Object Detection LAB')
  27. # Random seed
  28. parser.add_argument('--seed', default=42, type=int)
  29. # GPU
  30. parser.add_argument('--cuda', action='store_true', default=False,
  31. help='use cuda.')
  32. # Image size
  33. parser.add_argument('--eval_first', action='store_true', default=False,
  34. help='evaluate model before training.')
  35. # Outputs
  36. parser.add_argument('--tfboard', action='store_true', default=False,
  37. help='use tensorboard')
  38. parser.add_argument('--save_folder', default='weights/', type=str,
  39. help='path to save weight')
  40. parser.add_argument('--vis_tgt', action="store_true", default=False,
  41. help="visualize training data.")
  42. parser.add_argument('--vis_aux_loss', action="store_true", default=False,
  43. help="visualize aux loss.")
  44. # Mixing precision
  45. parser.add_argument('--fp16', dest="fp16", action="store_true", default=False,
  46. help="Adopting mix precision training.")
  47. # Batchsize
  48. parser.add_argument('-bs', '--batch_size', default=16, type=int,
  49. help='batch size on all the GPUs.')
  50. # Model
  51. parser.add_argument('-m', '--model', default='yolo_n', type=str,
  52. help='build yolo')
  53. parser.add_argument('-p', '--pretrained', default=None, type=str,
  54. help='load pretrained weight')
  55. parser.add_argument('-r', '--resume', default=None, type=str,
  56. help='keep training')
  57. # Dataset
  58. parser.add_argument('--root', default='D:/python_work/dataset/VOCdevkit/',
  59. help='data root')
  60. parser.add_argument('-d', '--dataset', default='coco',
  61. help='coco, voc')
  62. parser.add_argument('--num_workers', default=4, type=int,
  63. help='Number of workers used in dataloading')
  64. # DDP train
  65. parser.add_argument('-dist', '--distributed', action='store_true', default=False,
  66. help='distributed training')
  67. parser.add_argument('--dist_url', default='env://',
  68. help='url used to set up distributed training')
  69. parser.add_argument('--world_size', default=1, type=int,
  70. help='number of distributed processes')
  71. parser.add_argument('--sybn', action='store_true', default=False,
  72. help='use sybn.')
  73. parser.add_argument('--find_unused_parameters', action='store_true', default=False,
  74. help='set find_unused_parameters as True.')
  75. # Debug mode
  76. parser.add_argument('--debug', action='store_true', default=False,
  77. help='debug mode.')
  78. return parser.parse_args()
  79. def fix_random_seed(args):
  80. seed = args.seed + distributed_utils.get_rank()
  81. torch.manual_seed(seed)
  82. np.random.seed(seed)
  83. random.seed(seed)
  84. def train():
  85. args = parse_args()
  86. print("Setting Arguments.. : ", args)
  87. print("----------------------------------------------------------")
  88. # ---------------------------- Build DDP ----------------------------
  89. local_rank = local_process_rank = -1
  90. if args.distributed:
  91. distributed_utils.init_distributed_mode(args)
  92. print("git:\n {}\n".format(distributed_utils.get_sha()))
  93. try:
  94. # Multiple Mechine & Multiple GPUs (world size > 8)
  95. local_rank = torch.distributed.get_rank()
  96. local_process_rank = int(os.getenv('LOCAL_PROCESS_RANK', '0'))
  97. except:
  98. # Single Mechine & Multiple GPUs (world size <= 8)
  99. local_rank = local_process_rank = torch.distributed.get_rank()
  100. world_size = distributed_utils.get_world_size()
  101. print("LOCAL RANK: ", local_rank)
  102. print("LOCAL_PROCESS_RANL: ", local_process_rank)
  103. print('WORLD SIZE: {}'.format(world_size))
  104. # ---------------------------- Build CUDA ----------------------------
  105. if args.cuda and torch.cuda.is_available():
  106. print('use cuda')
  107. device = torch.device("cuda")
  108. else:
  109. device = torch.device("cpu")
  110. # ---------------------------- Fix random seed ----------------------------
  111. fix_random_seed(args)
  112. # ---------------------------- Build config ----------------------------
  113. cfg = build_config(args)
  114. # ---------------------------- Build Transform ----------------------------
  115. train_transform = build_transform(cfg, is_train=True)
  116. val_transform = build_transform(cfg, is_train=False)
  117. # ---------------------------- Build Dataset & Dataloader ----------------------------
  118. dataset = build_dataset(args, cfg, train_transform, is_train=True)
  119. train_loader = build_dataloader(args, dataset, args.batch_size // world_size, CollateFunc())
  120. # ---------------------------- Build Evaluator ----------------------------
  121. evaluator = MapEvaluator(cfg = cfg,
  122. dataset_name = args.dataset,
  123. data_dir = args.root,
  124. device = device,
  125. transform = val_transform
  126. )
  127. # ---------------------------- Build model ----------------------------
  128. ## Build model
  129. model, criterion = build_model(args, cfg, is_val=True)
  130. model = model.to(device).train()
  131. model_without_ddp = model
  132. # ---------------------------- Build Model-EMA ----------------------------
  133. if cfg.use_ema and distributed_utils.get_rank() in [-1, 0]:
  134. print('Build ModelEMA for {} ...'.format(args.model))
  135. model_ema = ModelEMA(model, cfg.ema_decay, cfg.ema_tau, args.resume)
  136. else:
  137. model_ema = None
  138. ## Calcute Params & GFLOPs
  139. if distributed_utils.is_main_process:
  140. model_copy = deepcopy(model_without_ddp)
  141. model_copy.trainable = False
  142. model_copy.eval()
  143. compute_flops(model=model_copy,
  144. img_size=cfg.test_img_size,
  145. device=device)
  146. del model_copy
  147. if args.distributed:
  148. dist.barrier()
  149. ## Build DDP model
  150. if args.distributed:
  151. model = DDP(model, device_ids=[args.gpu], find_unused_parameters=args.find_unused_parameters)
  152. if args.sybn:
  153. print('use SyncBatchNorm ...')
  154. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
  155. model_without_ddp = model.module
  156. if args.distributed:
  157. dist.barrier()
  158. # ---------------------------- Build Trainer ----------------------------
  159. trainer = build_trainer(args, cfg, device, model, model_ema, criterion, train_transform, val_transform, dataset, train_loader, evaluator)
  160. ## Eval before training
  161. if args.eval_first and distributed_utils.is_main_process():
  162. # to check whether the evaluator can work
  163. model_eval = model_without_ddp
  164. trainer.eval(model_eval)
  165. return
  166. # garbage = torch.randn(640, 1024, 73, 73).to(device) # 15 G
  167. # ---------------------------- Train pipeline ----------------------------
  168. trainer.train(model)
  169. # Empty cache after train loop
  170. del trainer
  171. del garbage
  172. if args.cuda:
  173. torch.cuda.empty_cache()
# Script entry point: run the full training pipeline when executed directly.
if __name__ == '__main__':
    train()