|
|
@@ -38,7 +38,7 @@ class CenterNet(nn.Module):
|
|
|
self.deploy = deploy
|
|
|
self.no_multi_labels = no_multi_labels
|
|
|
self.nms_class_agnostic = nms_class_agnostic
|
|
|
- self.head_dims = [round(512 * cfg['width']), round(256 * cfg['width']), round(128 * cfg['width'])]
|
|
|
+ self.head_dim = round(256 * cfg['width'])
|
|
|
|
|
|
# ---------------- Network Parameters ----------------
|
|
|
## Encoder
|
|
|
@@ -49,17 +49,17 @@ class CenterNet(nn.Module):
|
|
|
self.feat_dim = self.neck.out_dim
|
|
|
|
|
|
## Decoder
|
|
|
- self.decoder = build_decoder(cfg, self.feat_dim, self.head_dims)
|
|
|
+ self.decoder = build_decoder(cfg, self.feat_dim, self.head_dim)
|
|
|
|
|
|
## Head
|
|
|
self.det_head = nn.Sequential(
|
|
|
- build_det_head(cfg, self.head_dims[-1], self.head_dims[-1]),
|
|
|
- build_det_pred(self.head_dims[-1], self.head_dims[-1], self.stride, num_classes, 4)
|
|
|
+ build_det_head(cfg, self.head_dim, self.head_dim),
|
|
|
+ build_det_pred(self.head_dim, self.head_dim, self.stride, num_classes, 4)
|
|
|
)
|
|
|
## Aux Head
|
|
|
self.aux_det_head = nn.Sequential(
|
|
|
- build_det_head(cfg, self.head_dims[-1], self.head_dims[-1]),
|
|
|
- build_det_pred(self.head_dims[-1], self.head_dims[-1], self.stride, num_classes, 4)
|
|
|
+ build_det_head(cfg, self.head_dim, self.head_dim),
|
|
|
+ build_det_pred(self.head_dim, self.head_dim, self.stride, num_classes, 4)
|
|
|
)
|
|
|
|
|
|
# Post process
|