yjh0410 2 年之前
父节点
当前提交
702a36acf2
共有 2 个文件被更改,包括 4 次插入4 次删除
  1. 2 2
      models/detectors/yolovx/yolovx_head.py
  2. 2 2
      models/detectors/yolovx/yolovx_pred.py

+ 2 - 2
models/detectors/yolovx/yolovx_head.py

@@ -75,7 +75,7 @@ class MultiLevelHead(nn.Module):
         self.num_classes = num_classes
 
         ## ----------- Network Parameters -----------
-        self.det_heads = nn.ModuleList(
+        self.multi_level_heads = nn.ModuleList(
             [SingleLevelHead(
                 in_dim,
                 out_dim,
@@ -95,7 +95,7 @@ class MultiLevelHead(nn.Module):
         """
         cls_feats = []
         reg_feats = []
-        for feat, head in zip(feats, self.det_heads):
+        for feat, head in zip(feats, self.multi_level_heads):
             # ---------------- Pred ----------------
             cls_feat, reg_feat = head(feat)
 

+ 2 - 2
models/detectors/yolovx/yolovx_pred.py

@@ -51,7 +51,7 @@ class SingleLevelPredLayer(nn.Module):
         return obj_pred, cls_pred, reg_pred
     
 
-class MultiLevelHead(nn.Module):
+class MultiLevelPredLayer(nn.Module):
     def __init__(self, cls_dim, reg_dim, strides, num_classes, num_coords=4, num_levels=3):
         super().__init__()
         # --------- Basic Parameters ----------
@@ -138,6 +138,6 @@ class MultiLevelHead(nn.Module):
 
 # build detection head
 def build_pred_layer(cls_dim, reg_dim, strides, num_classes, num_coords=4, num_levels=3):
-    pred_layers = MultiLevelHead(cls_dim, reg_dim, strides, num_classes, num_coords, num_levels) 
+    pred_layers = MultiLevelPredLayer(cls_dim, reg_dim, strides, num_classes, num_coords, num_levels) 
 
     return pred_layers