yjh0410 2 éve
szülő
commit
ed58a94572
3 módosított fájl, 15 hozzáadás és 5 törlés
  1. 2 2
      config/model_config/yolov7_config.py
  2. 10 0
      test.py
  3. 3 3
      utils/misc.py

+ 2 - 2
config/model_config/yolov7_config.py

@@ -54,7 +54,7 @@ yolov7_cfg = {
         # ---------------- Model config ----------------
         ## Backbone
         'backbone': 'elannet_large',
-        'pretrained': True,
+        'pretrained': False,
         'bk_act': 'silu',
         'bk_norm': 'BN',
         'bk_dpw': False,
@@ -103,7 +103,7 @@ yolov7_cfg = {
         # ---------------- Model config ----------------
         ## Backbone
         'backbone': 'elannet_huge',
-        'pretrained': True,
+        'pretrained': False,
         'bk_act': 'silu',
         'bk_norm': 'BN',
         'bk_dpw': False,

+ 10 - 0
test.py

@@ -210,6 +210,16 @@ if __name__ == '__main__':
         device=device)
     del model_copy
 
+    # resave model weight
+    if args.resave:
+        print('Resave: {}'.format(args.model.upper()))
+        checkpoint = torch.load(args.weight, map_location='cpu')
+        checkpoint_path = 'weights/{}/{}/{}_pure.pth'.format(args.dataset, args.model, args.model)
+        torch.save({'model': model.state_dict(),
+                    'mAP': checkpoint.pop("mAP"),
+                    'epoch': checkpoint.pop("epoch")}, 
+                    checkpoint_path)
+        
     print("================= DETECT =================")
     # run
     test(args=args,

+ 3 - 3
utils/misc.py

@@ -180,10 +180,10 @@ def load_weight(model, path_to_ckpt, fuse_cbn=False):
         checkpoint = torch.load(path_to_ckpt, map_location='cpu')
         print('--------------------------------------')
         print('Best model infor:')
-        print('Epoch: {}'.format(checkpoint.pop("epoch")))
-        print('mAP: {}'.format(checkpoint.pop("mAP")))
+        print('Epoch: {}'.format(checkpoint["epoch"]))
+        print('mAP: {}'.format(checkpoint["mAP"]))
         print('--------------------------------------')
-        checkpoint_state_dict = checkpoint.pop("model")
+        checkpoint_state_dict = checkpoint["model"]
         model.load_state_dict(checkpoint_state_dict)
 
         print('Finished loading model!')