train.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. import os
  3. import random
  4. import argparse
  5. import numpy as np
  6. from copy import deepcopy
  7. import torch
  8. import torch.distributed as dist
  9. from torch.nn.parallel import DistributedDataParallel as DDP
  10. from utils import distributed_utils
  11. from utils.misc import compute_flops, collate_fn
  12. from utils.optimizer import build_optimizer
  13. from utils.lr_scheduler import build_wp_lr_scheduler, build_lr_scheduler
  14. from config import build_config
  15. from evaluator import build_evluator
  16. from datasets import build_dataset, build_dataloader, build_transform
  17. from models.detectors import build_model
  18. from engine import train_one_epoch
  19. def parse_args():
  20. parser = argparse.ArgumentParser('General 2D Object Detection', add_help=False)
  21. # Random seed
  22. parser.add_argument('--seed', default=42, type=int)
  23. # GPU
  24. parser.add_argument('--cuda', action='store_true', default=False,
  25. help='use cuda.')
  26. # Batch size
  27. parser.add_argument('-bs', '--batch_size', default=16, type=int,
  28. help='total batch size on all GPUs.')
  29. # Model
  30. parser.add_argument('-m', '--model', default='yolof_r18_c5_1x',
  31. help='build object detector')
  32. parser.add_argument('-r', '--resume', default=None, type=str,
  33. help='keep training')
  34. # Dataset
  35. parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/COCO/',
  36. help='data root')
  37. parser.add_argument('-d', '--dataset', default='coco',
  38. help='coco, voc, widerface, crowdhuman')
  39. parser.add_argument('--vis_tgt', action="store_true", default=False,
  40. help="visualize input data.")
  41. # Dataloader
  42. parser.add_argument('--num_workers', default=2, type=int,
  43. help='Number of workers used in dataloading')
  44. # Epoch
  45. parser.add_argument('--save_folder', default='weights/', type=str,
  46. help='path to save weight')
  47. parser.add_argument('--eval_first', action="store_true", default=False,
  48. help="visualize input data.")
  49. # DDP train
  50. parser.add_argument('-dist', '--distributed', action='store_true', default=False,
  51. help='distributed training')
  52. parser.add_argument('--dist_url', default='env://',
  53. help='url used to set up distributed training')
  54. parser.add_argument('--world_size', default=1, type=int,
  55. help='number of distributed processes')
  56. parser.add_argument('--sybn', action='store_true', default=False,
  57. help='use sybn.')
  58. # Debug setting
  59. parser.add_argument('--debug', action='store_true', default=False,
  60. help='debug codes.')
  61. return parser.parse_args()
  62. def fix_random_seed(args):
  63. seed = args.seed + distributed_utils.get_rank()
  64. torch.manual_seed(seed)
  65. np.random.seed(seed)
  66. random.seed(seed)
  67. def main():
  68. args = parse_args()
  69. print("Setting Arguments.. : ", args)
  70. print("----------------------------------------------------------")
  71. # path to save model
  72. path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
  73. os.makedirs(path_to_save, exist_ok=True)
  74. # ---------------------------- Build DDP ----------------------------
  75. local_rank = local_process_rank = -1
  76. if args.distributed:
  77. distributed_utils.init_distributed_mode(args)
  78. print("git:\n {}\n".format(distributed_utils.get_sha()))
  79. try:
  80. # Multiple Mechine & Multiple GPUs (world size > 8)
  81. local_rank = torch.distributed.get_rank()
  82. local_process_rank = int(os.getenv('LOCAL_PROCESS_RANK', '0'))
  83. except:
  84. # Single Mechine & Multiple GPUs (world size <= 8)
  85. local_rank = local_process_rank = torch.distributed.get_rank()
  86. world_size = distributed_utils.get_world_size()
  87. per_gpu_batch = args.batch_size // world_size
  88. print("LOCAL RANK: ", local_rank)
  89. print("LOCAL_PROCESS_RANL: ", local_process_rank)
  90. print('WORLD SIZE: {}'.format(world_size))
  91. # ---------------------------- Build CUDA ----------------------------
  92. if args.cuda and torch.cuda.is_available():
  93. print('use cuda')
  94. device = torch.device("cuda")
  95. else:
  96. device = torch.device("cpu")
  97. # ---------------------------- Fix random seed ----------------------------
  98. fix_random_seed(args)
  99. # ---------------------------- Build config ----------------------------
  100. cfg = build_config(args)
  101. # ---------------------------- Build Dataset ----------------------------
  102. transforms = build_transform(cfg, is_train=True)
  103. dataset = build_dataset(args, cfg, transforms, is_train=True)
  104. # ---------------------------- Build Dataloader ----------------------------
  105. train_loader = build_dataloader(args, dataset, per_gpu_batch, collate_fn, is_train=True)
  106. # ---------------------------- Build model ----------------------------
  107. ## Build model
  108. model, criterion = build_model(args, cfg, is_val=True)
  109. model.to(device)
  110. model_without_ddp = model
  111. ## Calcute Params & GFLOPs
  112. if distributed_utils.is_main_process():
  113. model_copy = deepcopy(model_without_ddp)
  114. model_copy.trainable = False
  115. model_copy.eval()
  116. compute_flops(model=model_copy,
  117. min_size=cfg.test_min_size,
  118. max_size=cfg.test_max_size,
  119. device=device)
  120. del model_copy
  121. if args.distributed:
  122. dist.barrier()
  123. # ---------------------------- Build Optimizer ----------------------------
  124. cfg.grad_accumulate = max(cfg.batch_size_base // args.batch_size, 1)
  125. cfg.base_lr = cfg.per_image_lr * args.batch_size * cfg.grad_accumulate
  126. optimizer, start_epoch = build_optimizer(cfg, model_without_ddp, args.resume)
  127. # ---------------------------- Build LR Scheduler ----------------------------
  128. wp_lr_scheduler = build_wp_lr_scheduler(cfg)
  129. lr_scheduler = build_lr_scheduler(cfg, optimizer, args.resume)
  130. # ---------------------------- Build DDP model ----------------------------
  131. if args.distributed:
  132. model = DDP(model, device_ids=[args.gpu])
  133. model_without_ddp = model.module
  134. # ---------------------------- Build Evaluator ----------------------------
  135. evaluator = build_evluator(args, cfg, device)
  136. # ----------------------- Eval before training -----------------------
  137. if args.eval_first and distributed_utils.is_main_process():
  138. evaluator.evaluate(model_without_ddp)
  139. return
  140. # ----------------------- Training -----------------------
  141. print("Start training")
  142. best_map = -1.
  143. for epoch in range(start_epoch, cfg.max_epoch):
  144. if args.distributed:
  145. train_loader.batch_sampler.sampler.set_epoch(epoch)
  146. # Train one epoch
  147. train_one_epoch(cfg,
  148. model,
  149. criterion,
  150. train_loader,
  151. optimizer,
  152. device,
  153. epoch,
  154. args.vis_tgt,
  155. wp_lr_scheduler,
  156. debug=args.debug)
  157. # LR Scheduler
  158. lr_scheduler.step()
  159. # Evaluate
  160. if distributed_utils.is_main_process():
  161. model_eval = model_without_ddp
  162. to_save = False
  163. if (epoch % cfg.eval_epoch) == 0 or (epoch == cfg.max_epoch - 1):
  164. if evaluator is None:
  165. to_save = True
  166. else:
  167. evaluator.evaluate(model_eval)
  168. # Save model
  169. if evaluator.map >= best_map:
  170. best_map = evaluator.map
  171. to_save = True
  172. if to_save:
  173. # save model
  174. print('Saving state, epoch:', epoch)
  175. torch.save({'model': model_eval.state_dict(),
  176. 'optimizer': optimizer.state_dict(),
  177. 'lr_scheduler': lr_scheduler.state_dict(),
  178. 'mAP': round(best_map*100, 1),
  179. 'epoch': epoch,
  180. 'args': args},
  181. os.path.join(path_to_save, '{}_best.pth'.format(args.model)))
  182. if args.distributed:
  183. dist.barrier()
  184. if args.debug:
  185. print("For debug mode, we only train the model with 1 epoch.")
  186. exit(0)
  187. if __name__ == '__main__':
  188. main()