@@ -85,11 +85,10 @@ class Conv(nn.Module):
 # ---------------------------- Core Modules ----------------------------
 ## Scale Modulation Block
 class SMBlock(nn.Module):
-    def __init__(self, in_dim, out_dim, act_type='silu', norm_type='BN', depthwise=False):
+    def __init__(self, in_dim, out_dim=None, act_type='silu', norm_type='BN', depthwise=False):
         super(SMBlock, self).__init__()
         # -------------- Basic parameters --------------
         self.in_dim = in_dim
-        self.out_dim = out_dim
         self.inter_dim = in_dim // 2
         # -------------- Network parameters --------------
         self.cv1 = Conv(self.inter_dim, self.inter_dim, k=1, act_type=act_type, norm_type=norm_type)
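Note that `self.inter_dim = in_dim // 2` pairs with the `torch.chunk(x, 2, dim=1)` split used later in `forward`, which only lines up when `in_dim` is even; with an odd channel count the two chunks differ in width and branch-1's 1x1 Conv would reject its input. A quick standalone illustration (toy shapes, not part of the patch):

import torch

x = torch.randn(1, 64, 8, 8)
a, b = torch.chunk(x, 2, dim=1)
print(a.shape[1], b.shape[1])   # 32 32 -> matches inter_dim = 64 // 2

x = torch.randn(1, 65, 8, 8)
a, b = torch.chunk(x, 2, dim=1)
print(a.shape[1], b.shape[1])   # 33 32 -> inter_dim = 65 // 2 = 32 no longer
                                # matches branch-1's Conv(32, 32, k=1)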
@@ -107,8 +106,13 @@ class SMBlock(nn.Module):
             Conv(self.inter_dim, self.inter_dim, k=1, act_type=act_type, norm_type=norm_type),
             Conv(self.inter_dim, self.inter_dim, k=7, p=3, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
         )
-        ## Output proj
-        self.out_proj = Conv(self.inter_dim*4, self.out_dim, k=1, act_type=act_type, norm_type=norm_type)
+        ## Aggregation proj
+        self.sm_aggregation = Conv(self.inter_dim*3, self.inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+
+        # Output proj
+        self.out_proj = None
+        if out_dim is not None:
+            self.out_proj = Conv(self.inter_dim*2, out_dim, k=1, act_type=act_type, norm_type=norm_type)
 
 
     def channel_shuffle(self, x, groups):
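`channel_shuffle` (unchanged for SMBlock, and shown in full in the DSBlock copy below) is the ShuffleNet-style reshape/transpose/flatten trick: after the two branches are concatenated, shuffling with groups=2 interleaves their channels so the halves mix without an extra projection. A minimal sketch of the effect on a toy tensor:

import torch

def channel_shuffle(x, groups):
    b, c, h, w = x.size()
    x = x.view(b, groups, c // groups, h, w)
    x = torch.transpose(x, 1, 2).contiguous()
    return x.view(b, -1, h, w)

x = torch.arange(4.0).view(1, 4, 1, 1)           # channels [0, 1, 2, 3]
print(channel_shuffle(x, 2).flatten().tolist())  # [0.0, 2.0, 1.0, 3.0]

Since DSBlock below now carries a verbatim copy of this method, hoisting it to a module-level helper might be worth considering.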
@@ -128,33 +132,74 @@ class SMBlock(nn.Module):
 
 
     def forward(self, x):
+        """
+        Input:
+            x: (Tensor) -> [B, C_in, H, W]
+        Output:
+            out: (Tensor) -> [B, C_out, H, W]
+        """
         x1, x2 = torch.chunk(x, 2, dim=1)
+        # branch-1
         x1 = self.cv1(x1)
+        # branch-2
         x2 = self.cv2(x2)
+        x2 = torch.cat([self.sm1(x2), self.sm2(x2), self.sm3(x2)], dim=1)
+        x2 = self.sm_aggregation(x2)
+        # channel shuffle
+        out = torch.cat([x1, x2], dim=1)
+        out = self.channel_shuffle(out, groups=2)
 
-        x3 = self.sm1(x2)
-        x4 = self.sm2(x3)
-        x5 = self.sm3(x4)
-        out = self.out_proj(torch.cat([x1, x3, x4, x5], dim=1))
-
-        out = self.channel_shuffle(out, groups=4)
+        if self.out_proj:
+            out = self.out_proj(out)
 
         return out
 
 ## DownSample Block
 class DSBlock(nn.Module):
-    def __init__(self, in_dim, out_dim, act_type='silu', norm_type='BN', depthwise=False):
+    def __init__(self, in_dim, act_type='silu', norm_type='BN', depthwise=False):
         super().__init__()
+        # branch-1
         self.maxpool = nn.MaxPool2d((2, 2), 2)
-        self.conv = Conv(in_dim//2, in_dim//2, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        self.out_proj = Conv(in_dim, out_dim, k=1, act_type=act_type, norm_type=norm_type)
+        # branch-2
+        inter_dim = in_dim // 2
+        self.sm1 = Conv(inter_dim, inter_dim, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.sm2 = Conv(inter_dim, inter_dim, k=5, p=2, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.sm3 = Conv(inter_dim, inter_dim, k=7, p=3, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.sm_aggregation = Conv(inter_dim*3, inter_dim*3, k=1, act_type=act_type, norm_type=norm_type)
+
+
+    def channel_shuffle(self, x, groups):
+        # type: (torch.Tensor, int) -> torch.Tensor
+        batchsize, num_channels, height, width = x.data.size()
+        per_group_dim = num_channels // groups
+
+        # reshape
+        x = x.view(batchsize, groups, per_group_dim, height, width)
+
+        x = torch.transpose(x, 1, 2).contiguous()
+
+        # flatten
+        x = x.view(batchsize, -1, height, width)
+
+        return x
+
 
     def forward(self, x):
+        """
+        Input:
+            x: (Tensor) -> [B, C, H, W]
+        Output:
+            out: (Tensor) -> [B, 2C, H/2, W/2]
+        """
         x1, x2 = torch.chunk(x, 2, dim=1)
+        # branch-1
         x1 = self.maxpool(x1)
-        x2 = self.conv(x2)
+        # branch-2
+        x2 = torch.cat([self.sm1(x2), self.sm2(x2), self.sm3(x2)], dim=1)
+        x2 = self.sm_aggregation(x2)
+        # channel shuffle
         out = torch.cat([x1, x2], dim=1)
-        out = self.out_proj(out)
+        out = self.channel_shuffle(out, groups=4)
 
         return out
 
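As a sanity check on the new channel bookkeeping, a hypothetical smoke test, not part of the patch (it assumes the `Conv` wrapper from the top of this file and the `cv2`/`sm1`/`sm2`/`sm3` definitions that sit between these hunks): SMBlock with the new default `out_dim=None` should preserve shape, while DSBlock should halve resolution and double channels.

import torch

x = torch.randn(2, 64, 32, 32)

sm = SMBlock(in_dim=64)   # out_dim=None -> no out_proj, C stays 64
print(sm(x).shape)        # expected: torch.Size([2, 64, 32, 32])

ds = DSBlock(in_dim=64)   # 32 pooled + 3*32 aggregated = 128 channels
print(ds(x).shape)        # expected: torch.Size([2, 128, 16, 16])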
@@ -182,11 +227,9 @@ def build_reduce_layer(cfg, in_dim, out_dim):
 ## build fpn's downsample layer
 def build_downsample_layer(cfg, in_dim, out_dim):
     if cfg['fpn_downsample_layer'] == 'conv':
-        layer = Conv(in_dim, out_dim, k=3, s=2, p=1, act_type=cfg['fpn_act'], norm_type=cfg['fpn_norm'])
+        layer = Conv(in_dim, out_dim, k=3, s=2, p=1, act_type=cfg['fpn_act'], norm_type=cfg['fpn_norm'], depthwise=cfg['fpn_depthwise'])
     elif cfg['fpn_downsample_layer'] == 'maxpool':
         assert in_dim == out_dim
         layer = nn.MaxPool2d((2, 2), stride=2)
-    elif cfg['fpn_downsample_layer'] == 'dsblock':
-        layer = DSBlock(in_dim, out_dim, act_type=cfg['fpn_act'], norm_type=cfg['fpn_norm'], depthwise=cfg['fpn_depthwise'])
 
     return layer
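For reference, a sketch of how this function is driven; the cfg keys are the ones read above, the values are illustrative only:

cfg = {
    'fpn_downsample_layer': 'conv',   # or 'maxpool' (requires in_dim == out_dim)
    'fpn_act': 'silu',
    'fpn_norm': 'BN',
    'fpn_depthwise': False,
}
layer = build_downsample_layer(cfg, in_dim=256, out_dim=512)

One side effect of dropping the 'dsblock' branch: an unrecognized `fpn_downsample_layer` value (including the old 'dsblock') now leaves `layer` unbound, so `return layer` raises `UnboundLocalError`; an explicit `else` with a clear error message may be worth adding.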