yjh0410 committed 2 years ago
Parent commit: edb51fed2b
2 changed files with 14 additions and 11 deletions
  1. models/detectors/yolovx/yolovx.py (+2 -1)
  2. models/detectors/yolovx/yolovx_head.py (+12 -10)

models/detectors/yolovx/yolovx.py (+2 -1)

@@ -55,7 +55,8 @@ class YOLOvx(nn.Module):
 
         ## ----------- Preds -----------
         self.pred_layers = build_pred_layer(
-            self.head_dim, self.head_dim, self.stride, num_classes, num_coords=4, num_levels=len(self.stride))
+            self.det_heads.cls_head_dim, self.det_heads.reg_head_dim,
+            self.stride, num_classes, num_coords=4, num_levels=len(self.stride))
 
 
     ## post-process
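Note on this hunk: build_pred_layer now receives the head's actual per-branch widths (det_heads.cls_head_dim and det_heads.reg_head_dim) instead of a single shared head_dim, since the two branches may differ in width after the head change below. A minimal sketch of a builder with that call signature, assuming one 1x1 projection per pyramid level; the repo's real build_pred_layer may be organized differently:

    import torch.nn as nn

    def build_pred_layer(cls_dim, reg_dim, strides, num_classes, num_coords=4, num_levels=3):
        # Hypothetical stand-in matching the call site above, not the repo's code:
        # one 1x1 classification and one 1x1 box-regression projection per level.
        # `strides` is kept only for signature parity and is unused in this sketch.
        cls_preds = nn.ModuleList(
            nn.Conv2d(cls_dim, num_classes, kernel_size=1) for _ in range(num_levels))
        reg_preds = nn.ModuleList(
            nn.Conv2d(reg_dim, num_coords, kernel_size=1) for _ in range(num_levels))
        return nn.ModuleDict({'cls_preds': cls_preds, 'reg_preds': reg_preds})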

models/detectors/yolovx/yolovx_head.py (+12 -10)

@@ -19,36 +19,36 @@ class SingleLevelHead(nn.Module):
         # --------- Network Parameters ----------
         ## cls head
         cls_feats = []
-        self.cls_out_dim = out_dim
+        self.cls_head_dim = max(out_dim, num_classes)
         for i in range(num_cls_head):
             if i == 0:
                 cls_feats.append(
-                    Conv(in_dim, self.cls_out_dim, k=3, p=1, s=1, 
+                    Conv(in_dim, self.cls_head_dim, k=3, p=1, s=1, 
                          act_type=act_type,
                          norm_type=norm_type,
                          depthwise=depthwise)
                         )
             else:
                 cls_feats.append(
-                    Conv(self.cls_out_dim, self.cls_out_dim, k=3, p=1, s=1, 
+                    Conv(self.cls_head_dim, self.cls_head_dim, k=3, p=1, s=1, 
                         act_type=act_type,
                         norm_type=norm_type,
                         depthwise=depthwise)
                         )      
         ## reg head
         reg_feats = []
-        self.reg_out_dim = out_dim
+        self.reg_head_dim = out_dim
         for i in range(num_reg_head):
             if i == 0:
                 reg_feats.append(
-                    Conv(in_dim, self.reg_out_dim, k=3, p=1, s=1, 
+                    Conv(in_dim, self.reg_head_dim, k=3, p=1, s=1, 
                          act_type=act_type,
                          norm_type=norm_type,
                          depthwise=depthwise)
                         )
             else:
                 reg_feats.append(
-                    Conv(self.reg_out_dim, self.reg_out_dim, k=3, p=1, s=1, 
+                    Conv(self.reg_head_dim, self.reg_head_dim, k=3, p=1, s=1, 
                          act_type=act_type,
                          norm_type=norm_type,
                          depthwise=depthwise)
@@ -70,10 +70,6 @@ class SingleLevelHead(nn.Module):
 class MultiLevelHead(nn.Module):
     def __init__(self, cfg, in_dims, out_dim, num_classes=80):
         super().__init__()
-        # --------- Basic Parameters ----------
-        self.in_dims = in_dims
-        self.num_classes = num_classes
-
         ## ----------- Network Parameters -----------
         self.multi_level_heads = nn.ModuleList(
             [SingleLevelHead(
@@ -87,6 +83,12 @@ class MultiLevelHead(nn.Module):
                 cfg['head_depthwise'])
                 for in_dim in in_dims
             ])
+        # --------- Basic Parameters ----------
+        self.in_dims = in_dims
+        self.num_classes = num_classes
+
+        self.cls_head_dim = self.multi_level_heads[0].cls_head_dim
+        self.reg_head_dim = self.multi_level_heads[0].reg_head_dim
 
 
     def forward(self, feats):
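
The head-side change behind that call site: SingleLevelHead widens its classification branch to at least num_classes channels (cls_head_dim = max(out_dim, num_classes)) while the regression branch keeps out_dim, and MultiLevelHead re-exposes both widths from its first level so the detector can read them. A small self-contained sketch of the width rule and of how a caller can size its 1x1 prediction convolutions from the exposed dims (everything other than the two dim names is illustrative, not repo code):

    import torch
    import torch.nn as nn

    out_dim, num_classes, num_coords = 64, 80, 4

    # Width rule introduced in SingleLevelHead:
    cls_head_dim = max(out_dim, num_classes)   # -> 80: cls branch never narrower than the class count
    reg_head_dim = out_dim                     # -> 64: box branch keeps the configured width

    # With the widths exposed on the head, prediction projections can be sized directly:
    cls_pred = nn.Conv2d(cls_head_dim, num_classes, kernel_size=1)
    reg_pred = nn.Conv2d(reg_head_dim, num_coords, kernel_size=1)

    cls_feat = torch.randn(1, cls_head_dim, 20, 20)
    reg_feat = torch.randn(1, reg_head_dim, 20, 20)
    print(cls_pred(cls_feat).shape)   # torch.Size([1, 80, 20, 20])
    print(reg_pred(reg_feat).shape)   # torch.Size([1, 4, 20, 20])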