|
|
@@ -7,7 +7,17 @@ except:
|
|
|
from yolov8_basic import Conv, ELAN_CSP_Block
|
|
|
|
|
|
|
|
|
-# ---------------------------- Backbones ----------------------------
|
|
|
+# ---------------------------- ImageNet pretrained weights ----------------------------
|
|
|
+model_urls = {
|
|
|
+ 'elan_cspnet_nano': "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elan_cspnet_nano.pth",
|
|
|
+ 'elan_cspnet_small': None,
|
|
|
+ 'elan_cspnet_medium': None,
|
|
|
+ 'elan_cspnet_large': "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elan_cspnet_large.pth",
|
|
|
+ 'elan_cspnet_huge': None,
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+# ---------------------------- Basic functions ----------------------------
|
|
|
## ELAN-CSPNet
|
|
|
class ELAN_CSPNet(nn.Module):
|
|
|
def __init__(self, width=1.0, depth=1.0, ratio=1.0, act_type='silu', norm_type='BN', depthwise=False):
|
|
|
@@ -56,6 +66,36 @@ class ELAN_CSPNet(nn.Module):
|
|
|
|
|
|
|
|
|
# ---------------------------- Functions ----------------------------
|
|
|
+## load pretrained weight
|
|
|
+def load_weight(model, model_name):
|
|
|
+ # load weight
|
|
|
+ print('Loading pretrained weight ...')
|
|
|
+ url = model_urls[model_name]
|
|
|
+ if url is not None:
|
|
|
+ checkpoint = torch.hub.load_state_dict_from_url(
|
|
|
+ url=url, map_location="cpu", check_hash=True)
|
|
|
+ # checkpoint state dict
|
|
|
+ checkpoint_state_dict = checkpoint.pop("model")
|
|
|
+ # model state dict
|
|
|
+ model_state_dict = model.state_dict()
|
|
|
+ # check
|
|
|
+ for k in list(checkpoint_state_dict.keys()):
|
|
|
+ if k in model_state_dict:
|
|
|
+ shape_model = tuple(model_state_dict[k].shape)
|
|
|
+ shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
|
|
|
+ if shape_model != shape_checkpoint:
|
|
|
+ checkpoint_state_dict.pop(k)
|
|
|
+ else:
|
|
|
+ checkpoint_state_dict.pop(k)
|
|
|
+ print(k)
|
|
|
+
|
|
|
+ model.load_state_dict(checkpoint_state_dict)
|
|
|
+ else:
|
|
|
+ print('No pretrained for {}'.format(model_name))
|
|
|
+
|
|
|
+ return model
|
|
|
+
|
|
|
+
|
|
|
## build ELAN-Net
|
|
|
def build_backbone(cfg):
|
|
|
# model
|
|
|
@@ -67,8 +107,20 @@ def build_backbone(cfg):
|
|
|
norm_type=cfg['bk_norm'],
|
|
|
depthwise=cfg['bk_dpw']
|
|
|
)
|
|
|
-
|
|
|
feat_dims = backbone.feat_dims
|
|
|
+
|
|
|
+ # check whether to load imagenet pretrained weight
|
|
|
+ if cfg['pretrained']:
|
|
|
+ if cfg['width'] == 0.25 and cfg['depth'] == 0.34 and cfg['ratio'] == 2.0:
|
|
|
+ backbone = load_weight(backbone, model_name='elan_cspnet_nano')
|
|
|
+ elif cfg['width'] == 0.5 and cfg['depth'] == 0.34 and cfg['ratio'] == 2.0:
|
|
|
+ backbone = load_weight(backbone, model_name='elan_cspnet_small')
|
|
|
+ elif cfg['width'] == 0.75 and cfg['depth'] == 0.67 and cfg['ratio'] == 1.5:
|
|
|
+ backbone = load_weight(backbone, model_name='elan_cspnet_medium')
|
|
|
+ elif cfg['width'] == 1.0 and cfg['depth'] == 1.0 and cfg['ratio'] == 1.0:
|
|
|
+ backbone = load_weight(backbone, model_name='elan_cspnet_large')
|
|
|
+ elif cfg['width'] == 1.25 and cfg['depth'] == 1.34 and cfg['ratio'] == 1.0:
|
|
|
+ backbone = load_weight(backbone, model_name='elan_cspnet_huge')
|
|
|
|
|
|
return backbone, feat_dims
|
|
|
|