yjh0410 1 rok pred
rodič
commit
efc97238a0

+ 7 - 0
yolo/config/gelan_config.py

@@ -16,10 +16,12 @@ class GElanBaseConfig(object):
         self.max_stride = 32
         self.num_levels = 3
         ## Backbone
+        self.backbone = 'gelan_c'
         self.bk_act = 'silu'
         self.bk_norm = 'BN'
         self.bk_depthwise = False
         self.bk_down_pooling = True
+        self.use_pretrained = True
         self.backbone_feats = {
             "c1": [64],
             "c2": [128, [128, 64],  256],
@@ -27,6 +29,7 @@ class GElanBaseConfig(object):
             "c4": [512, [512, 256], 512],
             "c5": [512, [512, 256], 512],
         }
+        self.scale = "l"
         self.backbone_depth = 1
         ## Neck
         self.neck           = 'spp_elan'
@@ -136,6 +139,10 @@ class GElanBaseConfig(object):
 class GElanCConfig(GElanBaseConfig):
     def __init__(self) -> None:
         super().__init__()
+        self.backbone = 'gelan_c'
+        self.use_pretrained = True
+        self.scale = "l"
+     
         # ---------------- Data process config ----------------
         self.mosaic_prob = 1.0
         self.mixup_prob  = 0.1

+ 47 - 8
yolo/models/gelan/gelan_backbone.py

@@ -6,12 +6,20 @@ try:
 except:
     from  gelan_basic import BasicConv, RepGElanLayer, ADown
 
+# IN1K pretrained weight
+pretrained_urls = {
+    's': None,
+    'm': None,
+    'l': None,
+    'x': None,
+}
 
 # ---------------------------- Basic functions ----------------------------
-class GElanBackbone(nn.Module):
+class GElanCBackbone(nn.Module):
     def __init__(self, cfg):
-        super(GElanBackbone, self).__init__()
+        super(GElanCBackbone, self).__init__()
         # ------------------ Basic setting ------------------
+        self.model_scale = cfg.scale
         self.feat_dims = [cfg.backbone_feats["c1"][-1],  # 64
                           cfg.backbone_feats["c2"][-1],  # 128
                           cfg.backbone_feats["c3"][-1],  # 256
@@ -80,7 +88,11 @@ class GElanBackbone(nn.Module):
 
         # Initialize all layers
         self.init_weights()
-        
+
+        # Load imagenet pretrained weight
+        if cfg.use_pretrained:
+            self.load_pretrained()
+
     def init_weights(self):
         """Initialize the parameters."""
         for m in self.modules():
@@ -89,6 +101,31 @@ class GElanBackbone(nn.Module):
                 # reset the Conv2d initialization parameters
                 m.reset_parameters()
 
+    def load_pretrained(self):
+        url = pretrained_urls[self.model_scale]
+        if url is not None:
+            print('Loading backbone pretrained weight from : {}'.format(url))
+            # checkpoint state dict
+            checkpoint = torch.hub.load_state_dict_from_url(
+                url=url, map_location="cpu", check_hash=True)
+            checkpoint_state_dict = checkpoint.pop("model")
+            # model state dict
+            model_state_dict = self.state_dict()
+            # check
+            for k in list(checkpoint_state_dict.keys()):
+                if k in model_state_dict:
+                    shape_model = tuple(model_state_dict[k].shape)
+                    shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
+                    if shape_model != shape_checkpoint:
+                        checkpoint_state_dict.pop(k)
+                else:
+                    checkpoint_state_dict.pop(k)
+                    print('Unused key: ', k)
+            # load the weight
+            self.load_state_dict(checkpoint_state_dict)
+        else:
+            print('No pretrained weight for model scale: {}.'.format(self.model_scale))
+
     def forward(self, x):
         c1 = self.layer_1(x)
         c2 = self.layer_2(c1)
@@ -104,7 +141,10 @@ class GElanBackbone(nn.Module):
 ## build Yolo's Backbone
 def build_backbone(cfg): 
     # model
-    backbone = GElanBackbone(cfg)
+    if   cfg.backbone == "gelan_c":
+        backbone = GElanCBackbone(cfg)
+    else:
+        raise NotImplementedError("Unknown gelan backbone: {}".format(cfg.backbone))
         
     return backbone
 
@@ -112,12 +152,10 @@ def build_backbone(cfg):
 if __name__ == '__main__':
     import time
     from thop import profile
-    base_config = {
-        "bk_act": "silu",
-        "bk_norm": "BN"
-    }
     class BaseConfig(object):
         def __init__(self) -> None:
+            self.backbone = 'gelan_c'
+            self.use_pretrained = True
             self.bk_act = 'silu'
             self.bk_norm = 'BN'
             self.bk_depthwise = False
@@ -128,6 +166,7 @@ if __name__ == '__main__':
                 "c4": [512, [512, 256], 512],
                 "c5": [512, [512, 256], 512],
             }
+            self.scale = "l"
             self.backbone_depth = 1
 
     cfg = BaseConfig()