|
|
@@ -72,14 +72,14 @@ class SingleLevelHead(nn.Module):
|
|
|
|
|
|
# Multi-level Head
|
|
|
class MultiLevelHead(nn.Module):
|
|
|
- def __init__(self, cfg, in_dims, cls_out_dim, reg_out_dim, num_levels=3):
|
|
|
+ def __init__(self, cfg, in_dims, num_levels=3, num_classes=80, reg_max=16):
|
|
|
super().__init__()
|
|
|
## ----------- Network Parameters -----------
|
|
|
self.multi_level_heads = nn.ModuleList(
|
|
|
[SingleLevelHead(
|
|
|
in_dims[level],
|
|
|
- cls_out_dim, # cls head dim
|
|
|
- reg_out_dim, # reg head dim
|
|
|
+ max(in_dims[0], num_classes), # cls head out_dim
|
|
|
+ max(in_dims[0]//4, 16, 4*reg_max), # reg head out_dim
|
|
|
cfg['num_cls_head'],
|
|
|
cfg['num_reg_head'],
|
|
|
cfg['head_act'],
|
|
|
@@ -111,9 +111,9 @@ class MultiLevelHead(nn.Module):
|
|
|
|
|
|
|
|
|
# build detection head
|
|
|
-def build_det_head(cfg, in_dims, cls_out_dim, reg_out_dim, num_levels=3):
|
|
|
+def build_det_head(cfg, in_dims, num_levels=3, num_classes=80, reg_max=16):
|
|
|
if cfg['head'] == 'decoupled_head':
|
|
|
- head = MultiLevelHead(cfg, in_dims, cls_out_dim, reg_out_dim, num_levels)
|
|
|
+ head = MultiLevelHead(cfg, in_dims, num_levels, num_classes, reg_max)
|
|
|
|
|
|
return head
|
|
|
|
|
|
@@ -134,7 +134,7 @@ if __name__ == '__main__':
|
|
|
cls_out_dim = 256
|
|
|
reg_out_dim = 64
|
|
|
# Head-1
|
|
|
- model = build_det_head(cfg, fpn_dims, cls_out_dim, reg_out_dim, num_levels=3)
|
|
|
+ model = build_det_head(cfg, fpn_dims, num_levels=3, num_classes=80, reg_max=16)
|
|
|
print(model)
|
|
|
fpn_feats = [torch.randn(1, fpn_dims[0], 80, 80), torch.randn(1, fpn_dims[1], 40, 40), torch.randn(1, fpn_dims[2], 20, 20)]
|
|
|
t0 = time.time()
|