|
|
@@ -6,11 +6,10 @@ from .rtcdet_v2_basic import Conv
|
|
|
|
|
|
# Single-level Head
|
|
|
class SingleLevelHead(nn.Module):
|
|
|
- def __init__(self, in_dim, out_dim, num_classes, num_cls_head, num_reg_head, act_type, norm_type, depthwise):
|
|
|
+ def __init__(self, in_dim, cls_head_dim, reg_head_dim, num_cls_head, num_reg_head, act_type, norm_type, depthwise):
|
|
|
super().__init__()
|
|
|
# --------- Basic Parameters ----------
|
|
|
self.in_dim = in_dim
|
|
|
- self.num_classes = num_classes
|
|
|
self.num_cls_head = num_cls_head
|
|
|
self.num_reg_head = num_reg_head
|
|
|
self.act_type = act_type
|
|
|
@@ -20,7 +19,7 @@ class SingleLevelHead(nn.Module):
|
|
|
# --------- Network Parameters ----------
|
|
|
## cls head
|
|
|
cls_feats = []
|
|
|
- self.cls_head_dim = max(out_dim, num_classes)
|
|
|
+ self.cls_head_dim = cls_head_dim
|
|
|
for i in range(num_cls_head):
|
|
|
if i == 0:
|
|
|
cls_feats.append(
|
|
|
@@ -38,7 +37,7 @@ class SingleLevelHead(nn.Module):
|
|
|
)
|
|
|
## reg head
|
|
|
reg_feats = []
|
|
|
- self.reg_head_dim = out_dim
|
|
|
+ self.reg_head_dim = reg_head_dim
|
|
|
for i in range(num_reg_head):
|
|
|
if i == 0:
|
|
|
reg_feats.append(
|
|
|
@@ -70,14 +69,14 @@ class SingleLevelHead(nn.Module):
|
|
|
|
|
|
# Multi-level Head
|
|
|
class MultiLevelHead(nn.Module):
|
|
|
- def __init__(self, cfg, in_dims, out_dim, num_classes=80, num_levels=3):
|
|
|
+ def __init__(self, cfg, in_dims, out_dim, num_classes=80, reg_max=16, num_levels=3):
|
|
|
super().__init__()
|
|
|
## ----------- Network Parameters -----------
|
|
|
self.multi_level_heads = nn.ModuleList(
|
|
|
[SingleLevelHead(
|
|
|
in_dims[level],
|
|
|
- out_dim,
|
|
|
- num_classes,
|
|
|
+ max(out_dim, num_classes), # cls head dim
|
|
|
+ max(out_dim//4, 4*reg_max), # reg head dim
|
|
|
cfg['num_cls_head'],
|
|
|
cfg['num_reg_head'],
|
|
|
cfg['head_act'],
|
|
|
@@ -110,8 +109,8 @@ class MultiLevelHead(nn.Module):
|
|
|
|
|
|
|
|
|
# build detection head
|
|
|
-def build_det_head(cfg, in_dim, out_dim, num_classes=80, num_levels=3):
|
|
|
+def build_det_head(cfg, in_dim, out_dim, num_classes=80, reg_max=16, num_levels=3):
|
|
|
if cfg['head'] == 'decoupled_head':
|
|
|
- head = MultiLevelHead(cfg, in_dim, out_dim, num_classes, num_levels)
|
|
|
+ head = MultiLevelHead(cfg, in_dim, out_dim, num_classes, reg_max, num_levels)
|
|
|
|
|
|
return head
|