yjh0410 před 2 roky
rodič
revize
656ddb429b

+ 2 - 1
models/detectors/yolovx/yolovx.py

@@ -51,7 +51,8 @@ class YOLOvx(nn.Module):
         self.fpn_dims = self.fpn.out_dim
 
         ## ----------- Heads -----------
-        self.det_heads = build_det_head(cfg, self.fpn_dims, self.head_dim, num_classes)
+        self.det_heads = build_det_head(
+            cfg, self.fpn_dims, self.head_dim, num_classes, num_levels=len(self.stride))
 
         ## ----------- Preds -----------
         self.pred_layers = build_pred_layer(

+ 5 - 5
models/detectors/yolovx/yolovx_head.py

@@ -68,12 +68,12 @@ class SingleLevelHead(nn.Module):
     
 
 class MultiLevelHead(nn.Module):
-    def __init__(self, cfg, in_dims, out_dim, num_classes=80):
+    def __init__(self, cfg, in_dims, out_dim, num_classes=80, num_levels=3):
         super().__init__()
         ## ----------- Network Parameters -----------
         self.multi_level_heads = nn.ModuleList(
             [SingleLevelHead(
-                in_dim,
+                in_dims[level],
                 out_dim,
                 num_classes,
                 cfg['num_cls_head'],
@@ -81,7 +81,7 @@ class MultiLevelHead(nn.Module):
                 cfg['head_act'],
                 cfg['head_norm'],
                 cfg['head_depthwise'])
-                for in_dim in in_dims
+                for level in range(num_levels)
             ])
         # --------- Basic Parameters ----------
         self.in_dims = in_dims
@@ -108,8 +108,8 @@ class MultiLevelHead(nn.Module):
     
 
 # build detection head
-def build_det_head(cfg, in_dim, out_dim, num_classes=80):
+def build_det_head(cfg, in_dim, out_dim, num_classes=80, num_levels=3):
     if cfg['head'] == 'decoupled_head':
-        head = MultiLevelHead(cfg, in_dim, out_dim, num_classes) 
+        head = MultiLevelHead(cfg, in_dim, out_dim, num_classes, num_levels) 
 
     return head