main.py

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
import random
import argparse
import numpy as np
from copy import deepcopy

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from utils import distributed_utils
from utils.misc import compute_flops, collate_fn
from utils.misc import get_param_dict, ModelEMA
from utils.optimizer import build_optimizer
from utils.lr_scheduler import build_wp_lr_scheduler, build_lr_scheduler

from config import build_config
from evaluator import build_evluator
from datasets import build_dataset, build_dataloader, build_transform
from models.detectors import build_model
from engine import train_one_epoch


def parse_args():
    parser = argparse.ArgumentParser('General 2D Object Detection', add_help=False)
    # Random seed
    parser.add_argument('--seed', default=42, type=int)
    # GPU
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='use cuda.')
    # Batch size
    parser.add_argument('-bs', '--batch_size', default=16, type=int,
                        help='total batch size across all GPUs.')
    # Model
    parser.add_argument('-m', '--model', default='yolof_r18_c5_1x',
                        help='build object detector')
    parser.add_argument('-p', '--pretrained', default=None, type=str,
                        help='load pretrained weights')
    parser.add_argument('-r', '--resume', default=None, type=str,
                        help='resume training from the given checkpoint')
    parser.add_argument('--ema', default=None, type=str,
                        help='use the Model EMA trick.')
    # Dataset
    parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/COCO/',
                        help='data root')
    parser.add_argument('-d', '--dataset', default='coco',
                        help='coco, voc, widerface, crowdhuman')
    parser.add_argument('--vis_tgt', action="store_true", default=False,
                        help="visualize input data.")
    # Dataloader
    parser.add_argument('--num_workers', default=2, type=int,
                        help='number of workers used in dataloading')
    # Epoch
    parser.add_argument('--eval_epoch', default=2, type=int,
                        help='interval (in epochs) between evaluations')
    parser.add_argument('--save_folder', default='weights/', type=str,
                        help='path to save weights')
    parser.add_argument('--eval_first', action="store_true", default=False,
                        help="evaluate the model before training.")
    # DDP train
    parser.add_argument('-dist', '--distributed', action='store_true', default=False,
                        help='distributed training')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--sybn', action='store_true', default=False,
                        help='use SyncBatchNorm.')
    parser.add_argument('--find_unused_parameters', action='store_true', default=False,
                        help='set find_unused_parameters to True in DDP.')
    # Debug setting
    parser.add_argument('--debug', action='store_true', default=False,
                        help='debug mode.')

    return parser.parse_args()
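
# Example invocation (a sketch, not from the source; the flags are defined above
# and the dataset path is illustrative):
#   python main.py --cuda -d coco --root /path/to/COCO -m yolof_r18_c5_1x -bs 16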


def fix_random_seed(args):
    # Offset the seed by the process rank so that, under DDP, each worker
    # draws a different (but still reproducible) random stream.
    seed = args.seed + distributed_utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


def main():
    args = parse_args()
    print("Setting Arguments: ", args)
    print("----------------------------------------------------------")

    # Path to save the model
    path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
    os.makedirs(path_to_save, exist_ok=True)

    # ---------------------------- Build DDP ----------------------------
    distributed_utils.init_distributed_mode(args)
    print("git:\n {}\n".format(distributed_utils.get_sha()))
    world_size = distributed_utils.get_world_size()
    print('World size: {}'.format(world_size))
    per_gpu_batch = args.batch_size // world_size
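    # --batch_size is the global batch size: with the default -bs 16 on
    # 4 GPUs, for example, each process loads 16 // 4 = 4 images per step.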

    # ---------------------------- Build CUDA ----------------------------
    if args.cuda and torch.cuda.is_available():
        print('use cuda')
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # ---------------------------- Fix random seed ----------------------------
    fix_random_seed(args)

    # ---------------------------- Build config ----------------------------
    cfg = build_config(args)
    print('Model config: ', cfg)

    # ---------------------------- Build Dataset ----------------------------
    transforms = build_transform(cfg, is_train=True)
    dataset, dataset_info = build_dataset(args, transforms, is_train=True)

    # ---------------------------- Build Dataloader ----------------------------
    train_loader = build_dataloader(args, dataset, per_gpu_batch, collate_fn, is_train=True)

    # ---------------------------- Build model ----------------------------
    ## Build model
    model, criterion = build_model(args, cfg, dataset_info['num_classes'], is_val=True)
    model.to(device)
    model_without_ddp = model

    ## Calculate Params & GFLOPs
    if distributed_utils.is_main_process():
        model_copy = deepcopy(model_without_ddp)
        model_copy.trainable = False
        model_copy.eval()
        compute_flops(model=model_copy,
                      min_size=cfg['test_min_size'],
                      max_size=cfg['test_max_size'],
                      device=device)
        del model_copy
    if args.distributed:
        dist.barrier()

    # ---------------------------- Build Optimizer ----------------------------
    cfg['base_lr'] = cfg['base_lr'] * args.batch_size
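    # Linear LR scaling rule (a sketch of the arithmetic, assuming cfg['base_lr']
    # is specified per sample): e.g. a per-sample lr of 0.0001 with the default
    # global batch size of 16 gives an effective base lr of 0.0016.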
    param_dicts = None
    if 'param_dict_type' in cfg.keys() and cfg['param_dict_type'] != 'default':
        print("- Param dict type: {}".format(cfg['param_dict_type']))
        param_dicts = get_param_dict(model_without_ddp, cfg)
    optimizer, start_epoch = build_optimizer(cfg, model_without_ddp, param_dicts, args.resume)

    # ---------------------------- Build LR Scheduler ----------------------------
    wp_lr_scheduler = build_wp_lr_scheduler(cfg, cfg['base_lr'])
    lr_scheduler = build_lr_scheduler(cfg, optimizer, args.resume)

    # ---------------------------- Build Model EMA ----------------------------
    model_ema = None
    if 'use_ema' in cfg.keys() and cfg['use_ema']:
        print("Build Model EMA for {}".format(args.model))
        model_ema = ModelEMA(cfg, model, start_epoch * len(train_loader))
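        # ModelEMA's implementation is not shown here; a typical EMA update
        # (an assumption, not necessarily this repo's exact rule) is:
        #   ema_param = decay * ema_param + (1 - decay) * param
        # The third argument resumes the EMA update counter at the correct
        # iteration when training restarts from start_epoch.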

    # ---------------------------- Build DDP model ----------------------------
    if args.distributed:
        model = DDP(model, device_ids=[args.gpu], find_unused_parameters=args.find_unused_parameters)
        model_without_ddp = model.module
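    # Example DDP launch (a sketch; init_distributed_mode is assumed to read the
    # env:// variables that torchrun exports, as the --dist_url default suggests):
    #   torchrun --nproc_per_node=8 main.py -dist --cuda -bs 128 -m yolof_r18_c5_1x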

    # ---------------------------- Build Evaluator ----------------------------
    evaluator = build_evluator(args, cfg, device)

    # ----------------------- Eval before training -----------------------
    if args.eval_first:
        # Evaluate on the main process only; every rank returns here so that
        # non-main processes do not fall through into the training loop.
        if distributed_utils.is_main_process():
            evaluator.evaluate(model_without_ddp)
        return

    # ----------------------- Training -----------------------
    print("Start training")
    best_map = -1.
    for epoch in range(start_epoch, cfg['max_epoch']):
        if args.distributed:
            train_loader.batch_sampler.sampler.set_epoch(epoch)

        # Train one epoch
        train_one_epoch(cfg,
                        model,
                        criterion,
                        train_loader,
                        optimizer,
                        device,
                        epoch,
                        args.vis_tgt,
                        wp_lr_scheduler,
                        dataset_info['class_labels'],
                        model_ema=model_ema,
                        debug=args.debug)

        # LR Scheduler
        lr_scheduler.step()

        # Evaluate
        if distributed_utils.is_main_process():
            model_eval = model_ema.ema if model_ema is not None else model_without_ddp
            if (epoch % args.eval_epoch) == 0 or (epoch == cfg['max_epoch'] - 1):
                if evaluator is None:
                    cur_map = 0.
                else:
                    evaluator.evaluate(model_eval)
                    cur_map = evaluator.map

                # Save the model whenever the mAP improves
                if cur_map > best_map:
                    # Update best mAP
                    best_map = cur_map
                    # Save the checkpoint
                    print('Saving state, epoch:', epoch + 1)
                    torch.save({'model': model_eval.state_dict(),
                                'optimizer': optimizer.state_dict(),
                                'lr_scheduler': lr_scheduler.state_dict(),
                                'mAP': round(cur_map * 100, 1),
                                'epoch': epoch,
                                'args': args},
                               os.path.join(path_to_save, '{}_best.pth'.format(args.model)))

        # Keep all ranks in sync at the end of each epoch; the barrier must sit
        # outside the is_main_process() branch, or the other ranks never reach it.
        if args.distributed:
            dist.barrier()

        if args.debug:
            print("In debug mode, we only train the model for 1 epoch.")
            exit(0)


if __name__ == '__main__':
    main()