@@ -2,9 +2,9 @@ import torch
 import torch.nn as nn

 try:
-    from .yolov1_basic import BasicConv
+    from .modules import BasicConv
 except:
-    from yolov1_basic import BasicConv
+    from modules import BasicConv


 class Yolov1DetHead(nn.Module):
@@ -16,60 +16,30 @@ class Yolov1DetHead(nn.Module):
         self.reg_head_dim = cfg.head_dim
         self.num_cls_head = cfg.num_cls_head
         self.num_reg_head = cfg.num_reg_head
-        self.act_type = cfg.head_act
-        self.norm_type = cfg.head_norm
-        self.depthwise = cfg.head_depthwise

         # --------- Network Parameters ----------
         ## cls head
         cls_feats = []
         for i in range(self.num_cls_head):
             if i == 0:
-                cls_feats.append(
-                    BasicConv(in_dim, self.cls_head_dim,
-                              kernel_size=3, padding=1, stride=1,
-                              act_type = self.act_type,
-                              norm_type = self.norm_type,
-                              depthwise = self.depthwise)
-                              )
+                cls_feats.append(BasicConv(in_dim, self.cls_head_dim, kernel_size=3, padding=1, stride=1))
             else:
-                cls_feats.append(
-                    BasicConv(self.cls_head_dim, self.cls_head_dim,
-                              kernel_size=3, padding=1, stride=1,
-                              act_type = self.act_type,
-                              norm_type = self.norm_type,
-                              depthwise = self.depthwise)
-                              )
+                cls_feats.append(BasicConv(self.cls_head_dim, self.cls_head_dim, kernel_size=3, padding=1, stride=1))
         ## reg head
         reg_feats = []
         for i in range(self.num_reg_head):
             if i == 0:
-                reg_feats.append(
-                    BasicConv(in_dim, self.reg_head_dim,
-                              kernel_size=3, padding=1, stride=1,
-                              act_type = self.act_type,
-                              norm_type = self.norm_type,
-                              depthwise = self.depthwise)
-                              )
+                reg_feats.append(BasicConv(in_dim, self.reg_head_dim, kernel_size=3, padding=1, stride=1))
             else:
-                reg_feats.append(
-                    BasicConv(self.reg_head_dim, self.reg_head_dim,
-                              kernel_size=3, padding=1, stride=1,
-                              act_type = self.act_type,
-                              norm_type = self.norm_type,
-                              depthwise = self.depthwise)
-                              )
+                reg_feats.append(BasicConv(self.reg_head_dim, self.reg_head_dim, kernel_size=3, padding=1, stride=1))
         self.cls_feats = nn.Sequential(*cls_feats)
         self.reg_feats = nn.Sequential(*reg_feats)

         self.init_weights()

     def init_weights(self):
-        """Initialize the parameters."""
         for m in self.modules():
             if isinstance(m, torch.nn.Conv2d):
-                # In order to be consistent with the source code,
-                # reset the Conv2d initialization parameters
                 m.reset_parameters()

     def forward(self, x):
@@ -92,9 +62,6 @@ if __name__=='__main__':
             self.out_stride = 32
             self.max_stride = 32
             ## Head
-            self.head_act = 'lrelu'
-            self.head_norm = 'BN'
-            self.head_depthwise = False
             self.head_dim = 256
             self.num_cls_head = 2
             self.num_reg_head = 2
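
Note: the simplified call sites above only work if the BasicConv exported by the new modules file carries its own activation and normalization defaults, since the head no longer reads head_act, head_norm, or head_depthwise from the config. The sketch below is a hypothetical stand-in for that block, not the actual contents of modules.py; it bakes in the values the old test config set explicitly (BN norm, LeakyReLU activation, no depthwise conv).

# Hypothetical sketch of a simplified BasicConv (assumption, not the real modules.py):
# Conv2d -> BatchNorm2d -> LeakyReLU, with the settings the head previously passed
# explicitly (norm_type='BN', act_type='lrelu', depthwise=False) fixed as defaults.
import torch
import torch.nn as nn

class BasicConv(nn.Module):
    def __init__(self, in_dim, out_dim, kernel_size=1, padding=0, stride=1):
        super().__init__()
        self.conv = nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size,
                              padding=padding, stride=stride, bias=False)
        self.norm = nn.BatchNorm2d(out_dim)
        self.act  = nn.LeakyReLU(0.1, inplace=True)

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))

if __name__ == '__main__':
    # Quick shape check: a 3x3 conv with padding 1 and stride 1 keeps the spatial size,
    # which is what the stacked cls/reg head layers rely on.
    x = torch.randn(1, 512, 13, 13)
    layer = BasicConv(512, 256, kernel_size=3, padding=1, stride=1)
    print(layer(x).shape)  # torch.Size([1, 256, 13, 13])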