|
|
@@ -93,13 +93,14 @@ class SMBlock(nn.Module):
|
|
|
self.expand_ratio = expand_ratio
|
|
|
self.inter_dim = round(in_dim * expand_ratio)
|
|
|
# -------------- Network parameters --------------
|
|
|
+ self.cv1 = Conv(self.inter_dim, self.inter_dim, k=1, act_type=act_type, norm_type=norm_type)
|
|
|
+ self.cv2 = Conv(self.inter_dim, self.inter_dim, k=1, act_type=act_type, norm_type=norm_type)
|
|
|
## Scale Modulation
|
|
|
- self.sm0 = Conv(self.inter_dim, self.inter_dim, k=1, act_type=act_type, norm_type=norm_type)
|
|
|
self.sm1 = Conv(self.inter_dim, self.inter_dim, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
|
|
|
self.sm2 = Conv(self.inter_dim, self.inter_dim, k=5, p=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
|
|
|
self.sm3 = Conv(self.inter_dim, self.inter_dim, k=7, p=3, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
|
|
|
## Output proj
|
|
|
- self.cv3 = Conv(self.inter_dim*4, out_dim, k=1, act_type=act_type, norm_type=norm_type)
|
|
|
+ self.out_proj = Conv(self.inter_dim*4, self.out_dim, k=1, act_type=act_type, norm_type=norm_type)
|
|
|
|
|
|
|
|
|
def channel_shuffle(self, x, groups):
|
|
|
@@ -120,13 +121,17 @@ class SMBlock(nn.Module):
|
|
|
|
|
|
def forward(self, x):
|
|
|
x1, x2 = torch.chunk(x, 2, dim=1)
|
|
|
- x3 = self.sm1(self.sm0(x2))
|
|
|
+ x1 = self.cv1(x1)
|
|
|
+ x2 = self.cv2(x2)
|
|
|
+
|
|
|
+ x3 = self.sm1(x2)
|
|
|
x4 = self.sm2(x3)
|
|
|
x5 = self.sm3(x4)
|
|
|
- out = torch.cat([x1, x3, x4, x5], dim=1)
|
|
|
- out = self.cv3(out)
|
|
|
+ out = self.out_proj(torch.cat([x1, x3, x4, x5], dim=1))
|
|
|
+
|
|
|
+ out = self.channel_shuffle(out, groups=4)
|
|
|
|
|
|
- return self.channel_shuffle(out, groups=4)
|
|
|
+ return out
|
|
|
|
|
|
|
|
|
# ---------------------------- FPN Modules ----------------------------
|