train.py

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
import random
import argparse
import numpy as np
from copy import deepcopy

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from utils import distributed_utils
from utils.misc import compute_flops, collate_fn
from utils.optimizer import build_optimizer
from utils.lr_scheduler import build_wp_lr_scheduler, build_lr_scheduler

from config import build_config
from evaluator import build_evluator
from datasets import build_dataset, build_dataloader, build_transform
from models.detectors import build_model
from engine import train_one_epoch

def parse_args():
    parser = argparse.ArgumentParser('General 2D Object Detection', add_help=False)
    # Random seed
    parser.add_argument('--seed', default=42, type=int)
    # GPU
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='use cuda.')
    # Batch size
    parser.add_argument('-bs', '--batch_size', default=16, type=int,
                        help='total batch size on all GPUs.')
    # Model
    parser.add_argument('-m', '--model', default='yolof_r18_c5_1x',
                        help='build object detector')
    parser.add_argument('-r', '--resume', default=None, type=str,
                        help='checkpoint to keep training from')
    # Dataset
    parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/COCO/',
                        help='data root')
    parser.add_argument('-d', '--dataset', default='coco',
                        help='coco, voc, widerface, crowdhuman')
    parser.add_argument('--vis_tgt', action="store_true", default=False,
                        help="visualize input data.")
    # Dataloader
    parser.add_argument('--num_workers', default=2, type=int,
                        help='number of workers used in dataloading')
    # Epoch
    parser.add_argument('--save_folder', default='weights/', type=str,
                        help='path to save weights')
    parser.add_argument('--eval_first', action="store_true", default=False,
                        help="evaluate the model before training.")
    # DDP train
    parser.add_argument('-dist', '--distributed', action='store_true', default=False,
                        help='distributed training')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--sybn', action='store_true', default=False,
                        help='use synchronized batch normalization.')
    # Debug setting
    parser.add_argument('--debug', action='store_true', default=False,
                        help='debug mode.')

    return parser.parse_args()
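
# Typical invocations (illustrative; the launcher and flag values below are
# assumptions, not mandated by this script):
#   Single GPU:
#     python train.py --cuda -d coco -m yolof_r18_c5_1x -bs 16
#   Multi-GPU (DDP):
#     python -m torch.distributed.run --nproc_per_node=8 train.py --cuda -dist -bs 16
# With the default --dist_url 'env://', init_distributed_mode() is expected to
# read the RANK/WORLD_SIZE/LOCAL_RANK variables set by the launcher.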

def fix_random_seed(args):
    seed = args.seed + distributed_utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
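
# Note: adding the process rank to the base seed gives each DDP worker a
# distinct but reproducible random stream, so data augmentation is not
# duplicated across GPUs. Full determinism would additionally require
# torch.cuda.manual_seed_all() and the cuDNN deterministic flags, which this
# script does not set.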

def main():
    args = parse_args()
    print("Setting Arguments.. : ", args)
    print("----------------------------------------------------------")

    # Path to save the model
    path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
    os.makedirs(path_to_save, exist_ok=True)

    # ---------------------------- Build DDP ----------------------------
    local_rank = local_process_rank = -1
    if args.distributed:
        distributed_utils.init_distributed_mode(args)
        print("git:\n {}\n".format(distributed_utils.get_sha()))
        try:
            # Multiple machines & multiple GPUs (world size > 8)
            local_rank = torch.distributed.get_rank()
            local_process_rank = int(os.getenv('LOCAL_PROCESS_RANK', '0'))
        except Exception:
            # Single machine & multiple GPUs (world size <= 8)
            local_rank = local_process_rank = torch.distributed.get_rank()
    world_size = distributed_utils.get_world_size()
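    # The total batch size (args.batch_size) is split evenly across all
    # processes, e.g. '-bs 16' on 4 GPUs yields 4 images per GPU per step,
    # so args.batch_size should be divisible by the world size.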
    per_gpu_batch = args.batch_size // world_size
    print("LOCAL RANK: ", local_rank)
    print("LOCAL_PROCESS_RANK: ", local_process_rank)
    print('WORLD SIZE: {}'.format(world_size))

    # ---------------------------- Build CUDA ----------------------------
    if args.cuda and torch.cuda.is_available():
        print('use cuda')
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # ---------------------------- Fix random seed ----------------------------
    fix_random_seed(args)

    # ---------------------------- Build config ----------------------------
    cfg = build_config(args)

    # ---------------------------- Build Dataset ----------------------------
    transforms = build_transform(cfg, is_train=True)
    dataset = build_dataset(args, cfg, transforms, is_train=True)

    # ---------------------------- Build Dataloader ----------------------------
    train_loader = build_dataloader(args, dataset, per_gpu_batch, collate_fn, is_train=True)

    # ---------------------------- Build Model ----------------------------
    model, criterion = build_model(args, cfg, is_val=True)
    model.to(device)
    criterion.to(device)
    model_without_ddp = model

    ## Calculate Params & GFLOPs on the main process only
    if distributed_utils.is_main_process():
        model_copy = deepcopy(model_without_ddp)
        model_copy.trainable = False
        model_copy.eval()
        compute_flops(model=model_copy,
                      min_size=cfg.test_min_size,
                      max_size=cfg.test_max_size,
                      device=device)
        del model_copy
    if args.distributed:
        dist.barrier()
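    # Note: the barrier above parks the non-main ranks until rank 0 has
    # finished the Params/GFLOPs measurement, so all processes move on together.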

    # ---------------------------- Build Optimizer ----------------------------
    cfg.grad_accumulate = max(cfg.batch_size_base // args.batch_size, 1)
    cfg.base_lr = cfg.per_image_lr * args.batch_size * cfg.grad_accumulate
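    # Linear LR scaling rule: the base LR grows with the effective batch size
    # (per-step batch x gradient-accumulation steps). With a hypothetical
    # per_image_lr = 0.00125 and batch_size_base = 64, running '-bs 16' gives
    # grad_accumulate = 64 // 16 = 4 and base_lr = 0.00125 * 16 * 4 = 0.08.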
    optimizer, start_epoch = build_optimizer(cfg, model_without_ddp, args.resume)

    # ---------------------------- Build LR Scheduler ----------------------------
    wp_lr_scheduler = build_wp_lr_scheduler(cfg)
    lr_scheduler = build_lr_scheduler(cfg, optimizer, args.resume)
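    # The warmup scheduler is handed to train_one_epoch (presumably stepped
    # per iteration there), while lr_scheduler.step() below advances the main
    # schedule once per epoch.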

    # ---------------------------- Build DDP Model ----------------------------
    if args.distributed:
        model = DDP(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    # ---------------------------- Build Evaluator ----------------------------
    evaluator = build_evluator(args, cfg, device)

    # ----------------------- Eval before training -----------------------
    if args.eval_first and distributed_utils.is_main_process():
        evaluator.evaluate(model_without_ddp)
        return

    # ----------------------- Training -----------------------
    print("Start training")
    best_map = cfg.best_map
    for epoch in range(start_epoch, cfg.max_epoch):
        if args.distributed:
            train_loader.batch_sampler.sampler.set_epoch(epoch)
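            # set_epoch() reseeds the DistributedSampler's shuffle; without it,
            # every epoch would visit the data in the same order on each rank.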

        # Train one epoch
        train_one_epoch(cfg,
                        model,
                        criterion,
                        train_loader,
                        optimizer,
                        device,
                        epoch,
                        args.vis_tgt,
                        wp_lr_scheduler,
                        debug=args.debug)

        # LR scheduler
        lr_scheduler.step()

        # Evaluate
        if distributed_utils.is_main_process():
            model_eval = model_without_ddp
            to_save = False
            if (epoch % cfg.eval_epoch) == 0 or (epoch == cfg.max_epoch - 1):
                if evaluator is None:
                    to_save = True
                else:
                    evaluator.evaluate(model_eval)
                    # Save the model whose mAP matches or beats the best so far
                    if evaluator.map >= best_map:
                        best_map = evaluator.map
                        to_save = True
            if to_save:
                # Save the checkpoint
                print('Saving state, epoch:', epoch)
                torch.save({'model': model_eval.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'lr_scheduler': lr_scheduler.state_dict(),
                            'mAP': round(best_map * 100, 3),
                            'epoch': epoch,
                            'args': args},
                           os.path.join(path_to_save, '{}_best.pth'.format(args.model)))
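                # The checkpoint bundles model, optimizer and LR-scheduler
                # states; build_optimizer()/build_lr_scheduler() presumably
                # restore them when --resume points at this file.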

        if args.distributed:
            dist.barrier()

        if args.debug:
            print("In debug mode, we only train the model for one epoch.")
            exit(0)


if __name__ == '__main__':
    main()