yjh0410 2 anni fa
parent
commit
dd78b5564b
2 ha cambiato i file con 13 aggiunte e 4 eliminazioni
  1. 4 4
      config/model_config/yolovx_config.py
  2. 9 0
      train.py

+ 4 - 4
config/model_config/yolovx_config.py

@@ -40,7 +40,7 @@ yolovx_cfg = {
         'reg_max': 16,
         # ---------------- Train config ----------------
         ## Input
-        'multi_scale': [0.5, 1.25],   # 320 -> 800
+        'multi_scale': [0.5, 1.5],   # 320 -> 960
         'trans_type': 'yolovx_nano',
         # ---------------- Assignment config ----------------
         ## Matcher
@@ -94,7 +94,7 @@ yolovx_cfg = {
         'reg_max': 16,
         # ---------------- Train config ----------------
         ## Input
-        'multi_scale': [0.5, 1.25],   # 320 -> 800
+        'multi_scale': [0.5, 1.5],   # 320 -> 960
         'trans_type': 'yolovx_nano',
         # ---------------- Assignment config ----------------
         ## Matcher
@@ -148,7 +148,7 @@ yolovx_cfg = {
         'reg_max': 16,
         # ---------------- Train config ----------------
         ## Input
-        'multi_scale': [0.5, 1.25],   # 320 -> 800
+        'multi_scale': [0.5, 1.5],   # 320 -> 960
         'trans_type': 'yolovx_small',
         # ---------------- Assignment config ----------------
         ## Matcher
@@ -202,7 +202,7 @@ yolovx_cfg = {
         'reg_max': 16,
         # ---------------- Train config ----------------
         ## Input
-        'multi_scale': [0.5, 1.25],   # 320 -> 800
+        'multi_scale': [0.5, 1.5],   # 320 -> 960
         'trans_type': 'yolovx_medium',
         # ---------------- Assignment config ----------------
         ## Matcher

+ 9 - 0
train.py

@@ -130,6 +130,15 @@ def train():
 
     # Build Model
     model, criterion = build_model(args, model_cfg, device, data_cfg['num_classes'], True)
+
+    # Keep training
+    if distributed_utils.is_main_process and args.resume is not None:
+        print('keep training: ', args.resume)
+        checkpoint = torch.load(args.resume, map_location='cpu')
+        # checkpoint state dict
+        checkpoint_state_dict = checkpoint.pop("model")
+        model.load_state_dict(checkpoint_state_dict)
+
     model = model.to(device).train()
     model_without_ddp = model
     if args.sybn and args.distributed: