main.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. import os
  3. import random
  4. import argparse
  5. import numpy as np
  6. from copy import deepcopy
  7. import torch
  8. import torch.distributed as dist
  9. from torch.nn.parallel import DistributedDataParallel as DDP
  10. from utils import distributed_utils
  11. from utils.misc import compute_flops, collate_fn
  12. from utils.optimizer import build_optimizer
  13. from utils.lr_scheduler import build_wp_lr_scheduler, build_lr_scheduler
  14. from config import build_config
  15. from evaluator import build_evluator
  16. from datasets import build_dataset, build_dataloader, build_transform
  17. from models.detectors import build_model
  18. from engine import train_one_epoch
  19. def parse_args():
  20. parser = argparse.ArgumentParser('General 2D Object Detection', add_help=False)
  21. # Random seed
  22. parser.add_argument('--seed', default=42, type=int)
  23. # GPU
  24. parser.add_argument('--cuda', action='store_true', default=False,
  25. help='use cuda.')
  26. # Batch size
  27. parser.add_argument('-bs', '--batch_size', default=16, type=int,
  28. help='total batch size on all GPUs.')
  29. # Model
  30. parser.add_argument('-m', '--model', default='yolof_r18_c5_1x',
  31. help='build object detector')
  32. parser.add_argument('-r', '--resume', default=None, type=str,
  33. help='keep training')
  34. # Dataset
  35. parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/COCO/',
  36. help='data root')
  37. parser.add_argument('-d', '--dataset', default='coco',
  38. help='coco, voc, widerface, crowdhuman')
  39. parser.add_argument('--vis_tgt', action="store_true", default=False,
  40. help="visualize input data.")
  41. # Dataloader
  42. parser.add_argument('--num_workers', default=2, type=int,
  43. help='Number of workers used in dataloading')
  44. # Epoch
  45. parser.add_argument('--save_folder', default='weights/', type=str,
  46. help='path to save weight')
  47. parser.add_argument('--eval_first', action="store_true", default=False,
  48. help="visualize input data.")
  49. # DDP train
  50. parser.add_argument('-dist', '--distributed', action='store_true', default=False,
  51. help='distributed training')
  52. parser.add_argument('--dist_url', default='env://',
  53. help='url used to set up distributed training')
  54. parser.add_argument('--world_size', default=1, type=int,
  55. help='number of distributed processes')
  56. parser.add_argument('--sybn', action='store_true', default=False,
  57. help='use sybn.')
  58. # Debug setting
  59. parser.add_argument('--debug', action='store_true', default=False,
  60. help='debug codes.')
  61. return parser.parse_args()
  62. def fix_random_seed(args):
  63. seed = args.seed + distributed_utils.get_rank()
  64. torch.manual_seed(seed)
  65. np.random.seed(seed)
  66. random.seed(seed)
  67. def main():
  68. args = parse_args()
  69. print("Setting Arguments.. : ", args)
  70. print("----------------------------------------------------------")
  71. # path to save model
  72. path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
  73. os.makedirs(path_to_save, exist_ok=True)
  74. # ---------------------------- Build DDP ----------------------------
  75. distributed_utils.init_distributed_mode(args)
  76. print("git:\n {}\n".format(distributed_utils.get_sha()))
  77. world_size = distributed_utils.get_world_size()
  78. print('World size: {}'.format(world_size))
  79. per_gpu_batch = args.batch_size // world_size
  80. # ---------------------------- Build CUDA ----------------------------
  81. if args.cuda and torch.cuda.is_available():
  82. print('use cuda')
  83. device = torch.device("cuda")
  84. else:
  85. device = torch.device("cpu")
  86. # ---------------------------- Fix random seed ----------------------------
  87. fix_random_seed(args)
  88. # ---------------------------- Build config ----------------------------
  89. cfg = build_config(args)
  90. print('Model config: ', cfg)
  91. # ---------------------------- Build Dataset ----------------------------
  92. transforms = build_transform(cfg, is_train=True)
  93. dataset = build_dataset(args, cfg, transforms, is_train=True)
  94. # ---------------------------- Build Dataloader ----------------------------
  95. train_loader = build_dataloader(args, dataset, per_gpu_batch, collate_fn, is_train=True)
  96. # ---------------------------- Build model ----------------------------
  97. ## Build model
  98. model, criterion = build_model(args, cfg, cfg.num_classes, is_val=True)
  99. model.to(device)
  100. model_without_ddp = model
  101. ## Calcute Params & GFLOPs
  102. if distributed_utils.is_main_process():
  103. model_copy = deepcopy(model_without_ddp)
  104. model_copy.trainable = False
  105. model_copy.eval()
  106. compute_flops(model=model_copy,
  107. min_size=cfg.test_min_size,
  108. max_size=cfg.test_max_size,
  109. device=device)
  110. del model_copy
  111. if args.distributed:
  112. dist.barrier()
  113. # ---------------------------- Build Optimizer ----------------------------
  114. cfg.grad_accumulate = max(16 // args.batch_size, 1)
  115. cfg.base_lr = cfg.per_image_lr * args.batch_size * cfg.grad_accumulate
  116. optimizer, start_epoch = build_optimizer(cfg, model_without_ddp, args.resume)
  117. # ---------------------------- Build LR Scheduler ----------------------------
  118. wp_lr_scheduler = build_wp_lr_scheduler(cfg, cfg.base_lr)
  119. lr_scheduler = build_lr_scheduler(cfg, optimizer, args.resume)
  120. # ---------------------------- Build DDP model ----------------------------
  121. if args.distributed:
  122. model = DDP(model, device_ids=[args.gpu])
  123. model_without_ddp = model.module
  124. # ---------------------------- Build Evaluator ----------------------------
  125. evaluator = build_evluator(args, cfg, device)
  126. # ----------------------- Eval before training -----------------------
  127. if args.eval_first and distributed_utils.is_main_process():
  128. evaluator.evaluate(model_without_ddp)
  129. return
  130. # ----------------------- Training -----------------------
  131. print("Start training")
  132. best_map = -1.
  133. for epoch in range(start_epoch, cfg['max_epoch']):
  134. if args.distributed:
  135. train_loader.batch_sampler.sampler.set_epoch(epoch)
  136. # Train one epoch
  137. train_one_epoch(cfg,
  138. model,
  139. criterion,
  140. train_loader,
  141. optimizer,
  142. device,
  143. epoch,
  144. args.vis_tgt,
  145. wp_lr_scheduler,
  146. debug=args.debug)
  147. # LR Scheduler
  148. lr_scheduler.step()
  149. # Evaluate
  150. if distributed_utils.is_main_process():
  151. model_eval = model_without_ddp
  152. to_save = False
  153. if (epoch % args.eval_epoch) == 0 or (epoch == cfg['max_epoch'] - 1):
  154. if evaluator is None:
  155. to_save = True
  156. else:
  157. evaluator.evaluate(model_eval)
  158. # Save model
  159. if evaluator.map >= best_map:
  160. best_map = evaluator.map
  161. to_save = True
  162. if to_save:
  163. # save model
  164. print('Saving state, epoch:', epoch)
  165. torch.save({'model': model_eval.state_dict(),
  166. 'optimizer': optimizer.state_dict(),
  167. 'lr_scheduler': lr_scheduler.state_dict(),
  168. 'mAP': round(best_map*100, 1),
  169. 'epoch': epoch,
  170. 'args': args},
  171. os.path.join(path_to_save, '{}_best.pth'.format(args.model)))
  172. if args.distributed:
  173. dist.barrier()
  174. if args.debug:
  175. print("For debug mode, we only train the model with 1 epoch.")
  176. exit(0)
  177. if __name__ == '__main__':
  178. main()