Browse Source

try to reproduce GELAN-S

yjh0410 1 year ago
parent
commit
c87c471506
3 changed files with 79 additions and 8 deletions
  1. 70 4
      yolo/config/gelan_config.py
  2. 4 4
      yolo/models/gelan/gelan_backbone.py
  3. 5 0
      yolo/utils/misc.py

+ 70 - 4
yolo/config/gelan_config.py

@@ -2,7 +2,9 @@
 
 
 def build_gelan_config(args):
-    if   args.model == 'gelan_c':
+    if   args.model == 'gelan_s':
+        return GElanSConfig()
+    elif args.model == 'gelan_c':
         return GElanCConfig()
     else:
         raise NotImplementedError("No config for model: {}".format(args.model))
@@ -16,7 +18,7 @@ class GElanBaseConfig(object):
         self.max_stride = 32
         self.num_levels = 3
         ## Backbone
-        self.backbone = 'gelan_c'
+        self.backbone = 'gelan'
         self.bk_act = 'silu'
         self.bk_norm = 'BN'
         self.bk_depthwise = False
@@ -139,11 +141,75 @@ class GElanBaseConfig(object):
 class GElanCConfig(GElanBaseConfig):
     def __init__(self) -> None:
         super().__init__()
-        self.backbone = 'gelan_c'
+        self.backbone = 'gelan'
         self.use_pretrained = True
         self.scale = "l"
      
         # ---------------- Data process config ----------------
         self.mosaic_prob = 1.0
         self.mixup_prob  = 0.1
-        self.copy_paste  = 0.5
+        self.copy_paste  = 0.5
+
+# GELAN-S
+class GElanSConfig(GElanBaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # ---------------- Model config ----------------
+        self.reg_max  = 16
+        self.out_stride = [8, 16, 32]
+        self.max_stride = 32
+        self.num_levels = 3
+        ## Backbone
+        self.backbone = 'gelan'
+        self.bk_act   = 'silu'
+        self.bk_norm  = 'BN'
+        self.bk_depthwise = False
+        self.use_pretrained = True
+        self.backbone_feats = {
+            "c1": [32],
+            "c2": [64,  [64, 32],   64],
+            "c3": [64,  [64, 32],   128],
+            "c4": [128, [128, 64],  256],
+            "c5": [256, [256, 128], 256],
+        }
+        self.scale = "l"
+        self.backbone_depth = 3
+        ## Neck
+        self.neck           = 'spp_elan'
+        self.neck_act       = 'silu'
+        self.neck_norm      = 'BN'
+        self.spp_pooling_size  = 5
+        self.spp_inter_dim     = 128
+        self.spp_out_dim       = 256
+        ## FPN
+        self.fpn      = 'gelan_pafpn'
+        self.fpn_act  = 'silu'
+        self.fpn_norm = 'BN'
+        self.fpn_depthwise = False
+        self.fpn_depth    = 3
+        self.fpn_feats_td = {
+            "p4": [[256, 128], 256],
+            "p3": [[128, 64],  128],
+        }
+        self.fpn_feats_bu = {
+            "p4": [[256, 128], 256],
+            "p5": [[256, 128], 256],
+        }
+        ## Head
+        self.head      = 'gelan_head'
+        self.head_act  = 'silu'
+        self.head_norm = 'BN'
+        self.head_depthwise = False
+        self.num_cls_head   = 2
+        self.num_reg_head   = 2
+
+
+        # ---------------- Data process config ----------------
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.0
+        self.copy_paste  = 0.5           # approximated by the YOLOX's mixup
+
+    def print_config(self):
+        config_dict = {key: value for key, value in self.__dict__.items() if not key.startswith('__')}
+        for k, v in config_dict.items():
+            print("{} : {}".format(k, v))

+ 4 - 4
yolo/models/gelan/gelan_backbone.py

@@ -15,9 +15,9 @@ pretrained_urls = {
 }
 
 # ---------------------------- Basic functions ----------------------------
-class GElanCBackbone(nn.Module):
+class GElanBackbone(nn.Module):
     def __init__(self, cfg):
-        super(GElanCBackbone, self).__init__()
+        super(GElanBackbone, self).__init__()
         # ------------------ Basic setting ------------------
         self.model_scale = cfg.scale
         self.feat_dims = [cfg.backbone_feats["c1"][-1],  # 64
@@ -141,8 +141,8 @@ class GElanCBackbone(nn.Module):
 ## build Yolo's Backbone
 def build_backbone(cfg): 
     # model
-    if   cfg.backbone == "gelan_c":
-        backbone = GElanCBackbone(cfg)
+    if   cfg.backbone == "gelan":
+        backbone = GElanBackbone(cfg)
     else:
         raise NotImplementedError("Unknown gelan backbone: {}".format(cfg.backbone))
         

+ 5 - 0
yolo/utils/misc.py

@@ -336,6 +336,11 @@ def replace_module(module, replaced_module_type, new_module_type, replace_func=N
 
 ## compute FLOPs & Parameters
 def compute_flops(model, img_size, device):
+    # Reparam
+    for m in model.modules():
+        if hasattr(m, 'fuse_convs'):
+            m.fuse_convs()
+            
     x = torch.randn(1, 3, img_size, img_size).to(device)
     print('==============================')
     flops, params = profile(model, inputs=(x, ), verbose=False)