|
|
@@ -165,7 +165,6 @@ class DSBlock(nn.Module):
|
|
|
self.sm1 = Conv(inter_dim, inter_dim, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
|
|
|
self.sm2 = Conv(inter_dim, inter_dim, k=5, p=2, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
|
|
|
self.sm3 = Conv(inter_dim, inter_dim, k=7, p=3, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
|
|
|
- self.sm_aggregation = Conv(inter_dim*3, inter_dim*3, k=1, act_type=act_type, norm_type=norm_type)
|
|
|
|
|
|
|
|
|
def channel_shuffle(self, x, groups):
|
|
|
@@ -196,7 +195,6 @@ class DSBlock(nn.Module):
|
|
|
x1 = self.maxpool(x1)
|
|
|
# branch-2
|
|
|
x2 = torch.cat([self.sm1(x2), self.sm2(x2), self.sm3(x2)], dim=1)
|
|
|
- x2 = self.sm_aggregation(x2)
|
|
|
# channel shuffle
|
|
|
out = torch.cat([x1, x2], dim=1)
|
|
|
out = self.channel_shuffle(out, groups=4)
|