Browse Source

update GELAN-S

yjh0410 1 year ago
parent
commit
0df860879c
2 changed files with 5 additions and 3 deletions
  1. 3 1
      yolo/test.py
  2. 2 2
      yolo/utils/misc.py

+ 3 - 1
yolo/test.py

@@ -41,6 +41,8 @@ def parse_args():
                         type=str, help='Trained state_dict file path to open')
     parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
                         help='fuse Conv & BN')
+    parser.add_argument('--rep_conv', action='store_true', default=False,
+                        help='fuse Rep VGG block')
 
     # Data setting
     parser.add_argument('--root', default='D:/python_work/dataset/COCO/',
@@ -127,7 +129,7 @@ if __name__ == '__main__':
     model = build_model(args, cfg, is_val=False)
 
     # Load trained weight
-    model = load_weight(model, args.weight, args.fuse_conv_bn)
+    model = load_weight(model, args.weight, args.fuse_conv_bn, args.rep_conv)
     model.to(device).eval()
 
     # Compute FLOPs and Params

+ 2 - 2
yolo/utils/misc.py

@@ -343,7 +343,7 @@ def compute_flops(model, img_size, device):
     print('Params : {:.2f} M'.format(params / 1e6))
 
 ## load trained weight
-def load_weight(model, path_to_ckpt, fuse_cbn=False):
+def load_weight(model, path_to_ckpt, fuse_cbn=False, rep_conv=False):
     # Check ckpt file
     if path_to_ckpt is None:
         print('no weight file ...')
@@ -365,7 +365,7 @@ def load_weight(model, path_to_ckpt, fuse_cbn=False):
         model = fuse_conv_bn(model)
 
     # Fuse RepConv
-    if hasattr(model, "switch_to_deploy"):
+    if hasattr(model, "switch_to_deploy") and rep_conv:
         print("Reparam ...")
         model.switch_to_deploy()