yjh0410 1 year ago
parent
commit
8afe92f802
1 changed files with 15 additions and 13 deletions
  1. 15 13
      yolo/models/gelan/gelan_backbone.py

+ 15 - 13
yolo/models/gelan/gelan_backbone.py

@@ -12,11 +12,11 @@ pretrained_urls = {
     'c': "https://github.com/yjh0410/ICLab/releases/download/in1k_pretrained/gelan_c_in1k_76.7.pth",
 }
 
-# ---------------------------- Basic functions ----------------------------
+# ----------------- GELAN backbone proposed by YOLOv9 -----------------
 class GElanBackbone(nn.Module):
     def __init__(self, cfg):
         super(GElanBackbone, self).__init__()
-        # ------------------ Basic setting ------------------
+        # ---------- Basic setting ----------
         self.model_scale = cfg.scale
         self.feat_dims = [cfg.backbone_feats["c1"][-1],  # 64
                           cfg.backbone_feats["c2"][-1],  # 128
@@ -25,7 +25,7 @@ class GElanBackbone(nn.Module):
                           cfg.backbone_feats["c5"][-1],  # 512
                           ]
         
-        # ------------------ Network setting ------------------
+        # ---------- Network setting ----------
         ## P1/2
         self.layer_1 = BasicConv(3, cfg.backbone_feats["c1"][0],
                                  kernel_size=3, padding=1, stride=2,
@@ -95,8 +95,6 @@ class GElanBackbone(nn.Module):
         """Initialize the parameters."""
         for m in self.modules():
             if isinstance(m, torch.nn.Conv2d):
-                # In order to be consistent with the source code,
-                # reset the Conv2d initialization parameters
                 m.reset_parameters()
 
     def load_pretrained(self):
@@ -135,8 +133,7 @@ class GElanBackbone(nn.Module):
         return outputs
 
 
-# ---------------------------- Functions ----------------------------
-## build Yolo's Backbone
+# ------------ Functions ------------
 def build_backbone(cfg): 
     # model
     if   cfg.backbone == "gelan":
@@ -177,20 +174,25 @@ if __name__ == '__main__':
             }
             self.scale = "s"
             self.backbone_depth = 3
-
+    # 定义模型配置文件
     cfg = BaseConfig()
+
+    # 构建GELAN主干网络
     model = build_backbone(cfg)
+
+    # 随机生成输入数据
     x = torch.randn(1, 3, 640, 640)
-    t0 = time.time()
+
+    # 前向推理
     outputs = model(x)
-    t1 = time.time()
-    print('Time: ', t1 - t0)
+
+    # 打印输出中的shape
     for out in outputs:
         print(out.shape)
 
-    print('==============================')
+    # 计算模型的参数量和理论计算量
+    print('============ Params & FLOPs ============')
     flops, params = profile(model, inputs=(x, ), verbose=False)
-    print('==============================')
     print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
     print('Params : {:.2f} M'.format(params / 1e6))