|
|
@@ -71,16 +71,6 @@ class DetHead(nn.Module):
|
|
|
self.cls_feats = nn.Sequential(*cls_feats)
|
|
|
self.reg_feats = nn.Sequential(*reg_feats)
|
|
|
|
|
|
- self.init_weights()
|
|
|
-
|
|
|
- def init_weights(self):
|
|
|
- """Initialize the parameters."""
|
|
|
- for m in self.modules():
|
|
|
- if isinstance(m, torch.nn.Conv2d):
|
|
|
- # In order to be consistent with the source code,
|
|
|
- # reset the Conv2d initialization parameters
|
|
|
- m.reset_parameters()
|
|
|
-
|
|
|
def forward(self, x):
|
|
|
"""
|
|
|
in_feats: (Tensor) [B, C, H, W]
|
|
|
@@ -111,6 +101,16 @@ class Yolov5DetHead(nn.Module):
|
|
|
self.cls_head_dim = cfg.head_dim
|
|
|
self.reg_head_dim = cfg.head_dim
|
|
|
|
|
|
+ # Initialize all layers
|
|
|
+ self.init_weights()
|
|
|
+
|
|
|
+ def init_weights(self):
|
|
|
+ """Initialize the parameters."""
|
|
|
+ for m in self.modules():
|
|
|
+ if isinstance(m, torch.nn.Conv2d):
|
|
|
+ # In order to be consistent with the source code,
|
|
|
+ # reset the Conv2d initialization parameters
|
|
|
+ m.reset_parameters()
|
|
|
|
|
|
def forward(self, feats):
|
|
|
"""
|