|
|
@@ -335,12 +335,7 @@ def replace_module(module, replaced_module_type, new_module_type, replace_func=N
|
|
|
return model
|
|
|
|
|
|
## compute FLOPs & Parameters
|
|
|
-def compute_flops(model, img_size, device):
|
|
|
- # Reparam
|
|
|
- for m in model.modules():
|
|
|
- if hasattr(m, 'fuse_convs'):
|
|
|
- m.fuse_convs()
|
|
|
-
|
|
|
+def compute_flops(model, img_size, device):
|
|
|
x = torch.randn(1, 3, img_size, img_size).to(device)
|
|
|
print('==============================')
|
|
|
flops, params = profile(model, inputs=(x, ), verbose=False)
|
|
|
@@ -349,7 +344,7 @@ def compute_flops(model, img_size, device):
|
|
|
|
|
|
## load trained weight
|
|
|
def load_weight(model, path_to_ckpt, fuse_cbn=False, fuse_rep_conv=False):
|
|
|
- # check ckpt file
|
|
|
+ # Check ckpt file
|
|
|
if path_to_ckpt is None:
|
|
|
print('no weight file ...')
|
|
|
else:
|
|
|
@@ -364,18 +359,16 @@ def load_weight(model, path_to_ckpt, fuse_cbn=False, fuse_rep_conv=False):
|
|
|
|
|
|
print('Finished loading model!')
|
|
|
|
|
|
- # fuse rep conv
|
|
|
- if fuse_rep_conv:
|
|
|
- print("Fusing RepConv ...")
|
|
|
- for m in model.modules():
|
|
|
- if hasattr(m, 'fuse_convs'):
|
|
|
- m.fuse_convs()
|
|
|
-
|
|
|
- # fuse conv & bn
|
|
|
+ # Fuse conv & bn
|
|
|
if fuse_cbn:
|
|
|
print('Fusing Conv & BN ...')
|
|
|
model = fuse_conv_bn(model)
|
|
|
|
|
|
+ # Fuse RepConv
|
|
|
+ if hasattr(model, "switch_deploy"):
|
|
|
+ print("Reparam ...")
|
|
|
+ model.switch_deploy()
|
|
|
+
|
|
|
return model
|
|
|
|
|
|
## Model EMA
|