|
|
@@ -68,12 +68,12 @@ class SingleLevelHead(nn.Module):
|
|
|
|
|
|
|
|
|
class MultiLevelHead(nn.Module):
|
|
|
- def __init__(self, cfg, in_dims, out_dim, num_classes=80):
|
|
|
+ def __init__(self, cfg, in_dims, out_dim, num_classes=80, num_levels=3):
|
|
|
super().__init__()
|
|
|
## ----------- Network Parameters -----------
|
|
|
self.multi_level_heads = nn.ModuleList(
|
|
|
[SingleLevelHead(
|
|
|
- in_dim,
|
|
|
+ in_dims[level],
|
|
|
out_dim,
|
|
|
num_classes,
|
|
|
cfg['num_cls_head'],
|
|
|
@@ -81,7 +81,7 @@ class MultiLevelHead(nn.Module):
|
|
|
cfg['head_act'],
|
|
|
cfg['head_norm'],
|
|
|
cfg['head_depthwise'])
|
|
|
- for in_dim in in_dims
|
|
|
+ for level in range(num_levels)
|
|
|
])
|
|
|
# --------- Basic Parameters ----------
|
|
|
self.in_dims = in_dims
|
|
|
@@ -108,8 +108,8 @@ class MultiLevelHead(nn.Module):
|
|
|
|
|
|
|
|
|
# build detection head
|
|
|
-def build_det_head(cfg, in_dim, out_dim, num_classes=80):
|
|
|
+def build_det_head(cfg, in_dim, out_dim, num_classes=80, num_levels=3):
|
|
|
if cfg['head'] == 'decoupled_head':
|
|
|
- head = MultiLevelHead(cfg, in_dim, out_dim, num_classes)
|
|
|
+ head = MultiLevelHead(cfg, in_dim, out_dim, num_classes, num_levels)
|
|
|
|
|
|
return head
|