yjh0410 1 år sedan
förälder
incheckning
3ee2312780
1 ändrade filer med 8 tillägg och 15 borttagningar
  1. 8 15
      yolo/utils/misc.py

+ 8 - 15
yolo/utils/misc.py

@@ -335,12 +335,7 @@ def replace_module(module, replaced_module_type, new_module_type, replace_func=N
     return model
 
 ## compute FLOPs & Parameters
-def compute_flops(model, img_size, device):
-    # Reparam
-    for m in model.modules():
-        if hasattr(m, 'fuse_convs'):
-            m.fuse_convs()
-            
+def compute_flops(model, img_size, device):            
     x = torch.randn(1, 3, img_size, img_size).to(device)
     print('==============================')
     flops, params = profile(model, inputs=(x, ), verbose=False)
@@ -349,7 +344,7 @@ def compute_flops(model, img_size, device):
 
 ## load trained weight
 def load_weight(model, path_to_ckpt, fuse_cbn=False, fuse_rep_conv=False):
-    # check ckpt file
+    # Check ckpt file
     if path_to_ckpt is None:
         print('no weight file ...')
     else:
@@ -364,18 +359,16 @@ def load_weight(model, path_to_ckpt, fuse_cbn=False, fuse_rep_conv=False):
 
         print('Finished loading model!')
 
-    # fuse rep conv
-    if fuse_rep_conv:
-        print("Fusing RepConv ...")
-        for m in model.modules():
-            if hasattr(m, 'fuse_convs'):
-                m.fuse_convs()
-
-    # fuse conv & bn
+    # Fuse conv & bn
     if fuse_cbn:
         print('Fusing Conv & BN ...')
         model = fuse_conv_bn(model)
 
+    # Fuse RepConv
+    if hasattr(model, "switch_deploy"):
+        print("Reparam ...")
+        model.switch_deploy()
+
     return model
 
 ## Model EMA