train.py

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
import random
import argparse
import numpy as np
from copy import deepcopy

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from utils import distributed_utils
from utils.misc import compute_flops, collate_fn
from utils.optimizer import build_optimizer
from utils.lr_scheduler import build_wp_lr_scheduler, build_lr_scheduler
from config import build_config
from evaluator import build_evluator
from datasets import build_dataset, build_dataloader, build_transform
from models.detectors import build_model
from engine import train_one_epoch

def parse_args():
    parser = argparse.ArgumentParser('General 2D Object Detection', add_help=False)
    # Random seed
    parser.add_argument('--seed', default=42, type=int)
    # GPU
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='use cuda.')
    # Batch size
    parser.add_argument('-bs', '--batch_size', default=16, type=int,
                        help='total batch size on all GPUs.')
    # Model
    parser.add_argument('-m', '--model', default='yolof_r18_c5_1x',
                        help='build object detector')
    parser.add_argument('-r', '--resume', default=None, type=str,
                        help='keep training')
    # Dataset
    parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/COCO/',
                        help='data root')
    parser.add_argument('-d', '--dataset', default='coco',
                        help='coco, voc, widerface, crowdhuman')
    parser.add_argument('--vis_tgt', action="store_true", default=False,
                        help="visualize input data.")
    # Dataloader
    parser.add_argument('--num_workers', default=2, type=int,
                        help='number of workers used in dataloading')
    # Epoch
    parser.add_argument('--save_folder', default='weights/', type=str,
                        help='path to save weights')
    parser.add_argument('--eval_first', action="store_true", default=False,
                        help="evaluate the model before training.")
    # args.eval_epoch is read in the training loop below; the default of 1
    # (evaluate every epoch) is an assumption.
    parser.add_argument('--eval_epoch', default=1, type=int,
                        help='evaluate the model every N epochs.')
    # DDP train
    parser.add_argument('-dist', '--distributed', action='store_true', default=False,
                        help='distributed training')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--sybn', action='store_true', default=False,
                        help='use SyncBatchNorm.')
    # Debug setting
    parser.add_argument('--debug', action='store_true', default=False,
                        help='debug codes.')

    return parser.parse_args()
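
# Example invocations (the dataset root and model name are illustrative,
# not the only valid values):
#
#   # Single GPU:
#   python train.py --cuda -d coco --root /path/to/COCO -m yolof_r18_c5_1x -bs 16
#
#   # DDP on 8 GPUs:
#   python -m torch.distributed.run --nproc_per_node=8 train.py \
#       --cuda -dist -d coco --root /path/to/COCO -m yolof_r18_c5_1x -bs 16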

def fix_random_seed(args):
    # Offset the seed by the process rank so that each DDP worker draws
    # different random augmentations.
    seed = args.seed + distributed_utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

def main():
    args = parse_args()
    print("Setting arguments: ", args)
    print("----------------------------------------------------------")

    # Path to save the model
    path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
    os.makedirs(path_to_save, exist_ok=True)

    # ---------------------------- Build DDP ----------------------------
    distributed_utils.init_distributed_mode(args)
    print("git:\n {}\n".format(distributed_utils.get_sha()))
    world_size = distributed_utils.get_world_size()
    print('World size: {}'.format(world_size))
    per_gpu_batch = args.batch_size // world_size
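    # --batch_size is the total batch size across all GPUs; it is assumed to be
    # divisible by the world size, otherwise the integer division above silently
    # drops the remainder.
    assert args.batch_size % world_size == 0, \
        'total batch size must be divisible by the number of GPUs'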

    # ---------------------------- Build CUDA ----------------------------
    if args.cuda and torch.cuda.is_available():
        print('use cuda')
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # ---------------------------- Fix random seed ----------------------------
    fix_random_seed(args)

    # ---------------------------- Build config ----------------------------
    cfg = build_config(args)
    print('Model config: ', cfg)

    # ---------------------------- Build Dataset ----------------------------
    transforms = build_transform(cfg, is_train=True)
    dataset = build_dataset(args, cfg, transforms, is_train=True)

    # ---------------------------- Build Dataloader ----------------------------
    train_loader = build_dataloader(args, dataset, per_gpu_batch, collate_fn, is_train=True)

    # ---------------------------- Build model ----------------------------
    ## Build model
    model, criterion = build_model(args, cfg, is_val=True)
    model.to(device)
    model_without_ddp = model

    ## Calculate Params & GFLOPs
    if distributed_utils.is_main_process():
        model_copy = deepcopy(model_without_ddp)
        model_copy.trainable = False
        model_copy.eval()
        compute_flops(model=model_copy,
                      min_size=cfg.test_min_size,
                      max_size=cfg.test_max_size,
                      device=device)
        del model_copy
    if args.distributed:
        dist.barrier()

    # ---------------------------- Build Optimizer ----------------------------
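    # Linear LR scaling rule: gradient accumulation keeps the effective batch
    # size near a reference of 16 images per optimizer step, and the base LR
    # is scaled in proportion to that effective batch size
    # (per_image_lr * batch_size * grad_accumulate).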
    cfg.grad_accumulate = max(16 // args.batch_size, 1)
    cfg.base_lr = cfg.per_image_lr * args.batch_size * cfg.grad_accumulate
    optimizer, start_epoch = build_optimizer(cfg, model_without_ddp, args.resume)

    # ---------------------------- Build LR Scheduler ----------------------------
    cfg.warmup_iters = cfg.warmup_iters * cfg.grad_accumulate
    wp_lr_scheduler = build_wp_lr_scheduler(cfg)
    lr_scheduler = build_lr_scheduler(cfg, optimizer, args.resume)

    # ---------------------------- Build DDP model ----------------------------
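    # The --sybn flag is parsed but never used in the original script; it is
    # assumed (from its name) to request SyncBatchNorm, which must be applied
    # before the model is wrapped in DDP.
    if args.sybn and args.distributed:
        print('use SyncBatchNorm ...')
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)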
    if args.distributed:
        model = DDP(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    # ---------------------------- Build Evaluator ----------------------------
    evaluator = build_evluator(args, cfg, device)

    # ----------------------- Eval before training -----------------------
    if args.eval_first:
        # Only the main process runs the evaluation, but every rank must
        # return here; otherwise the non-main ranks would continue into the
        # training loop on their own.
        if distributed_utils.is_main_process():
            evaluator.evaluate(model_without_ddp)
        return

    # ----------------------- Training -----------------------
    print("Start training")
    best_map = -1.
    for epoch in range(start_epoch, cfg.max_epoch):
        if args.distributed:
            # Reshuffle the data differently at each epoch
            train_loader.batch_sampler.sampler.set_epoch(epoch)

        # Train one epoch
        train_one_epoch(cfg,
                        model,
                        criterion,
                        train_loader,
                        optimizer,
                        device,
                        epoch,
                        args.vis_tgt,
                        wp_lr_scheduler,
                        debug=args.debug)

        # LR Scheduler
        lr_scheduler.step()

        # Evaluate
        if distributed_utils.is_main_process():
            model_eval = model_without_ddp
            to_save = False
            if (epoch % args.eval_epoch) == 0 or (epoch == cfg.max_epoch - 1):
                if evaluator is None:
                    to_save = True
                else:
                    evaluator.evaluate(model_eval)
                    # Save model whenever it matches or beats the best mAP
                    if evaluator.map >= best_map:
                        best_map = evaluator.map
                        to_save = True
            if to_save:
                # Save model
                print('Saving state, epoch:', epoch)
                torch.save({'model': model_eval.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'lr_scheduler': lr_scheduler.state_dict(),
                            'mAP': round(best_map * 100, 1),
                            'epoch': epoch,
                            'args': args},
                           os.path.join(path_to_save, '{}_best.pth'.format(args.model)))

        if args.distributed:
            dist.barrier()

        if args.debug:
            print("In debug mode, we only train the model for 1 epoch.")
            exit(0)


if __name__ == '__main__':
    main()
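
# Resuming from a saved checkpoint (the path below is illustrative; it follows
# the save layout used above: save_folder/dataset/model/{model}_best.pth):
#
#   python train.py --cuda -d coco -m yolof_r18_c5_1x \
#       -r weights/coco/yolof_r18_c5_1x/yolof_r18_c5_1x_best.pth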