浏览代码

keep training YOLOv5-L from 211 epoch

yjh0410 2 年之前
父节点
当前提交
40891e6e48
共有 1 个文件被更改,包括 19 次插入19 次删除
  1. 19 19
      models/__init__.py

+ 19 - 19
models/__init__.py

@@ -61,32 +61,18 @@ def build_model(args,
             # model state dict
             model_state_dict = model.state_dict()
             # check
-            new_checkpoint_state_dict = {}
-
             for k in list(checkpoint_state_dict.keys()):
-                v = checkpoint_state_dict[k]
-                if 'reduce_layer_3' in k:
-                    k_new = k.split('.')
-                    k_new[1] = 'downsample_layer_1'
-                    k = k_new[0] + '.' + k_new[1] + '.' + k_new[2] + '.' + k_new[3] + '.' + k_new[4]
-                elif 'reduce_layer_4' in k:
-                    k_new = k.split('.')
-                    k_new[1] = 'downsample_layer_2'
-                    k = k_new[0] + '.' + k_new[1] + '.' + k_new[2] + '.' + k_new[3] + '.' + k_new[4]
-                new_checkpoint_state_dict[k] = v
-
-            for k in list(new_checkpoint_state_dict.keys()):
                 if k in model_state_dict:
                     shape_model = tuple(model_state_dict[k].shape)
-                    shape_checkpoint = tuple(new_checkpoint_state_dict[k].shape)
+                    shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
                     if shape_model != shape_checkpoint:
-                        new_checkpoint_state_dict.pop(k)
+                        checkpoint_state_dict.pop(k)
                         print(k)
                 else:
-                    new_checkpoint_state_dict.pop(k)
+                    checkpoint_state_dict.pop(k)
                     print(k)
 
-            model.load_state_dict(new_checkpoint_state_dict, strict=False)
+            model.load_state_dict(checkpoint_state_dict, strict=False)
 
         # keep training
         if args.resume is not None:
@@ -94,7 +80,21 @@ def build_model(args,
             checkpoint = torch.load(args.resume, map_location='cpu')
             # checkpoint state dict
             checkpoint_state_dict = checkpoint.pop("model")
-            model.load_state_dict(checkpoint_state_dict)
+            # check
+            new_checkpoint_state_dict = {}
+
+            for k in list(checkpoint_state_dict.keys()):
+                v = checkpoint_state_dict[k]
+                if 'reduce_layer_3' in k:
+                    k_new = k.split('.')
+                    k_new[1] = 'downsample_layer_1'
+                    k = k_new[0] + '.' + k_new[1] + '.' + k_new[2] + '.' + k_new[3] + '.' + k_new[4]
+                elif 'reduce_layer_4' in k:
+                    k_new = k.split('.')
+                    k_new[1] = 'downsample_layer_2'
+                    k = k_new[0] + '.' + k_new[1] + '.' + k_new[2] + '.' + k_new[3] + '.' + k_new[4]
+                new_checkpoint_state_dict[k] = v
+            model.load_state_dict(new_checkpoint_state_dict)
 
         return model, criterion