Browse Source

redesign SMBLock

yjh0410 2 years ago
parent
commit
704eac0ec6
2 changed files with 8 additions and 12 deletions
  1. 1 1
      models/detectors/lodet/lodet_backbone.py
  2. 7 11
      models/detectors/lodet/lodet_basic.py

+ 1 - 1
models/detectors/lodet/lodet_backbone.py

@@ -39,7 +39,7 @@ class ScaleModulationNet(nn.Module):
         # P5/32
         # P5/32
         self.layer_5 = nn.Sequential(
         self.layer_5 = nn.Sequential(
             nn.MaxPool2d((2, 2), stride=2),             
             nn.MaxPool2d((2, 2), stride=2),             
-            SMBlock(256, 256, 0.25, act_type, norm_type, depthwise)
+            SMBlock(256, 256, 0.5, act_type, norm_type, depthwise)
         )
         )
 
 
 
 

+ 7 - 11
models/detectors/lodet/lodet_basic.py

@@ -93,10 +93,8 @@ class SMBlock(nn.Module):
         self.expand_ratio = expand_ratio
         self.expand_ratio = expand_ratio
         self.inter_dim = round(in_dim * expand_ratio)
         self.inter_dim = round(in_dim * expand_ratio)
         # -------------- Network parameters --------------
         # -------------- Network parameters --------------
-        ## Input proj
-        self.cv1 = Conv(in_dim, self.inter_dim, k=1, act_type=act_type, norm_type=norm_type)
-        self.cv2 = Conv(in_dim, self.inter_dim, k=1, act_type=act_type, norm_type=norm_type)
         ## Scale Modulation
         ## Scale Modulation
+        self.sm0 = Conv(self.inter_dim, self.inter_dim, k=1, act_type=act_type, norm_type=norm_type)
         self.sm1 = Conv(self.inter_dim, self.inter_dim, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
         self.sm1 = Conv(self.inter_dim, self.inter_dim, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
         self.sm2 = Conv(self.inter_dim, self.inter_dim, k=5, p=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
         self.sm2 = Conv(self.inter_dim, self.inter_dim, k=5, p=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
         self.sm3 = Conv(self.inter_dim, self.inter_dim, k=7, p=3, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
         self.sm3 = Conv(self.inter_dim, self.inter_dim, k=7, p=3, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
@@ -121,16 +119,14 @@ class SMBlock(nn.Module):
     
     
 
 
     def forward(self, x):
     def forward(self, x):
-        x1 = self.cv1(x)
-        x2 = self.sm1(self.cv2(x))
-        x3 = self.sm2(x2)
-        x4 = self.sm3(x3)
-        out = torch.cat([x1, x2, x3, x4], dim=1)
-        out = self.channel_shuffle(out, groups=4)
-
+        x1, x2 = torch.chunk(x, 2, dim=1)
+        x3 = self.sm1(self.sm0(x2))
+        x4 = self.sm2(x3)
+        x5 = self.sm3(x4)
+        out = torch.cat([x1, x3, x4, x5], dim=1)
         out = self.cv3(out)
         out = self.cv3(out)
 
 
-        return out
+        return self.channel_shuffle(out, groups=4)
 
 
 
 
 # ---------------------------- FPN Modules ----------------------------
 # ---------------------------- FPN Modules ----------------------------