|
|
@@ -12,11 +12,11 @@ pretrained_urls = {
|
|
|
'c': "https://github.com/yjh0410/ICLab/releases/download/in1k_pretrained/gelan_c_in1k_76.7.pth",
|
|
|
}
|
|
|
|
|
|
-# ---------------------------- Basic functions ----------------------------
|
|
|
+# ----------------- GELAN backbone proposed by YOLOv9 -----------------
|
|
|
class GElanBackbone(nn.Module):
|
|
|
def __init__(self, cfg):
|
|
|
super(GElanBackbone, self).__init__()
|
|
|
- # ------------------ Basic setting ------------------
|
|
|
+ # ---------- Basic setting ----------
|
|
|
self.model_scale = cfg.scale
|
|
|
self.feat_dims = [cfg.backbone_feats["c1"][-1], # 64
|
|
|
cfg.backbone_feats["c2"][-1], # 128
|
|
|
@@ -25,7 +25,7 @@ class GElanBackbone(nn.Module):
|
|
|
cfg.backbone_feats["c5"][-1], # 512
|
|
|
]
|
|
|
|
|
|
- # ------------------ Network setting ------------------
|
|
|
+ # ---------- Network setting ----------
|
|
|
## P1/2
|
|
|
self.layer_1 = BasicConv(3, cfg.backbone_feats["c1"][0],
|
|
|
kernel_size=3, padding=1, stride=2,
|
|
|
@@ -95,8 +95,6 @@ class GElanBackbone(nn.Module):
|
|
|
"""Initialize the parameters."""
|
|
|
for m in self.modules():
|
|
|
if isinstance(m, torch.nn.Conv2d):
|
|
|
- # In order to be consistent with the source code,
|
|
|
- # reset the Conv2d initialization parameters
|
|
|
m.reset_parameters()
|
|
|
|
|
|
def load_pretrained(self):
|
|
|
@@ -135,8 +133,7 @@ class GElanBackbone(nn.Module):
|
|
|
return outputs
|
|
|
|
|
|
|
|
|
-# ---------------------------- Functions ----------------------------
|
|
|
-## build Yolo's Backbone
|
|
|
+# ------------ Functions ------------
|
|
|
def build_backbone(cfg):
|
|
|
# model
|
|
|
if cfg.backbone == "gelan":
|
|
|
@@ -177,20 +174,25 @@ if __name__ == '__main__':
|
|
|
}
|
|
|
self.scale = "s"
|
|
|
self.backbone_depth = 3
|
|
|
-
|
|
|
+ # 定义模型配置文件
|
|
|
cfg = BaseConfig()
|
|
|
+
|
|
|
+ # 构建GELAN主干网络
|
|
|
model = build_backbone(cfg)
|
|
|
+
|
|
|
+ # 随机生成输入数据
|
|
|
x = torch.randn(1, 3, 640, 640)
|
|
|
- t0 = time.time()
|
|
|
+
|
|
|
+ # 前向推理
|
|
|
outputs = model(x)
|
|
|
- t1 = time.time()
|
|
|
- print('Time: ', t1 - t0)
|
|
|
+
|
|
|
+ # 打印输出中的shape
|
|
|
for out in outputs:
|
|
|
print(out.shape)
|
|
|
|
|
|
- print('==============================')
|
|
|
+ # 计算模型的参数量和理论计算量
|
|
|
+ print('============ Params & FLOPs ============')
|
|
|
flops, params = profile(model, inputs=(x, ), verbose=False)
|
|
|
- print('==============================')
|
|
|
print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
|
|
|
print('Params : {:.2f} M'.format(params / 1e6))
|
|
|
|