Ver código fonte

add a function to fix seed

冬落 2 anos atrás
pai
commit
92752325d4
1 arquivos alterados com 35 adições e 15 exclusões
  1. 35 15
      train.py

+ 35 - 15
train.py

@@ -1,6 +1,8 @@
 from __future__ import division
 
 import os
+import random
+import numpy as np
 import argparse
 from copy import deepcopy
 
@@ -25,26 +27,33 @@ from engine import build_trainer
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Real-time Object Detection LAB')
-    # Basic
+    # Random seed
+    parser.add_argument('--seed', default=42, type=int)
+
+    # GPU
     parser.add_argument('--cuda', action='store_true', default=False,
                         help='use cuda.')
+    
+    # Image size
     parser.add_argument('-size', '--img_size', default=640, type=int, 
                         help='input image size')
-    parser.add_argument('--num_workers', default=4, type=int, 
-                        help='Number of workers used in dataloading')
+    parser.add_argument('--eval_first', action='store_true', default=False,
+                        help='evaluate model before training.')
+    
+    # Outputs
     parser.add_argument('--tfboard', action='store_true', default=False,
                         help='use tensorboard')
     parser.add_argument('--save_folder', default='weights/', type=str, 
                         help='path to save weight')
-    parser.add_argument('--eval_first', action='store_true', default=False,
-                        help='evaluate model before training.')
-    parser.add_argument('--fp16', dest="fp16", action="store_true", default=False,
-                        help="Adopting mix precision training.")
     parser.add_argument('--vis_tgt', action="store_true", default=False,
                         help="visualize training data.")
     parser.add_argument('--vis_aux_loss', action="store_true", default=False,
                         help="visualize aux loss.")
     
+    # Mixing precision
+    parser.add_argument('--fp16', dest="fp16", action="store_true", default=False,
+                        help="Adopting mix precision training.")
+    
     # Batchsize
     parser.add_argument('-bs', '--batch_size', default=16, type=int, 
                         help='batch size on all the GPUs.')
@@ -82,7 +91,8 @@ def parse_args():
                         help='coco, voc, widerface, crowdhuman')
     parser.add_argument('--load_cache', action='store_true', default=False,
                         help='load data into memory.')
-    
+    parser.add_argument('--num_workers', default=4, type=int, 
+                        help='Number of workers used in dataloading')
     # Train trick
     parser.add_argument('-ms', '--multi_scale', action='store_true', default=False,
                         help='Multi scale')
@@ -114,12 +124,19 @@ def parse_args():
     return parser.parse_args()
 
 
+def fix_random_seed(args):
+    seed = args.seed + distributed_utils.get_rank()
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+
+
 def train():
     args = parse_args()
     print("Setting Arguments.. : ", args)
     print("----------------------------------------------------------")
 
-    # Build DDP
+    # ---------------------------- Build DDP ----------------------------
     local_rank = local_process_rank = -1
     if args.distributed:
         distributed_utils.init_distributed_mode(args)
@@ -136,19 +153,23 @@ def train():
     print("LOCAL_PROCESS_RANL: ", local_process_rank)
     print('WORLD SIZE: {}'.format(world_size))
 
-    # Build CUDA
+    # ---------------------------- Build CUDA ----------------------------
     if args.cuda and torch.cuda.is_available():
         print('use cuda')
         device = torch.device("cuda")
     else:
         device = torch.device("cpu")
 
-    # Build Dataset & Model & Trans. Config
+    # ---------------------------- Fix random seed ----------------------------
+    fix_random_seed(args)
+
+    # ---------------------------- Build config ----------------------------
     data_cfg = build_dataset_config(args)
     model_cfg = build_model_config(args)
     trans_cfg = build_trans_config(model_cfg['trans_type'])
 
-    # Build Model
+    # ---------------------------- Build model ----------------------------
+    ## Build model
     model, criterion = build_model(args, model_cfg, device, data_cfg['num_classes'], True)
     model = model.to(device).train()
     model_without_ddp = model
@@ -158,8 +179,7 @@ def train():
     if args.distributed:
         model = DDP(model, device_ids=[args.gpu])
         model_without_ddp = model.module
-
-    # Calcute Params & GFLOPs
+    ## Calcute Params & GFLOPs
     if distributed_utils.is_main_process:
         model_copy = deepcopy(model_without_ddp)
         model_copy.trainable = False
@@ -171,7 +191,7 @@ def train():
     if args.distributed:
         dist.barrier()
 
-    # Build Trainer
+    # ---------------------------- Build Trainer ----------------------------
     trainer = build_trainer(args, data_cfg, model_cfg, trans_cfg, device, model_without_ddp, criterion, world_size)
 
     # --------------------------------- Train: Start ---------------------------------