yjh0410 2 년 전
부모
커밋
0a597a5333
1개의 변경된 파일11개의 추가작업 그리고 6개의 파일을 삭제
  1. 11 6
      models/detectors/lodet/lodet_basic.py

+ 11 - 6
models/detectors/lodet/lodet_basic.py

@@ -93,13 +93,14 @@ class SMBlock(nn.Module):
         self.expand_ratio = expand_ratio
         self.inter_dim = round(in_dim * expand_ratio)
         # -------------- Network parameters --------------
+        self.cv1 = Conv(self.inter_dim, self.inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.cv2 = Conv(self.inter_dim, self.inter_dim, k=1, act_type=act_type, norm_type=norm_type)
         ## Scale Modulation
-        self.sm0 = Conv(self.inter_dim, self.inter_dim, k=1, act_type=act_type, norm_type=norm_type)
         self.sm1 = Conv(self.inter_dim, self.inter_dim, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
         self.sm2 = Conv(self.inter_dim, self.inter_dim, k=5, p=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
         self.sm3 = Conv(self.inter_dim, self.inter_dim, k=7, p=3, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
         ## Output proj
-        self.cv3 = Conv(self.inter_dim*4, out_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.out_proj = Conv(self.inter_dim*4, self.out_dim, k=1, act_type=act_type, norm_type=norm_type)
 
 
     def channel_shuffle(self, x, groups):
@@ -120,13 +121,17 @@ class SMBlock(nn.Module):
 
     def forward(self, x):
         x1, x2 = torch.chunk(x, 2, dim=1)
-        x3 = self.sm1(self.sm0(x2))
+        x1 = self.cv1(x1)
+        x2 = self.cv2(x2)
+
+        x3 = self.sm1(x2)
         x4 = self.sm2(x3)
         x5 = self.sm3(x4)
-        out = torch.cat([x1, x3, x4, x5], dim=1)
-        out = self.cv3(out)
+        out = self.out_proj(torch.cat([x1, x3, x4, x5], dim=1))
+        
+        out = self.channel_shuffle(out, groups=4)
 
-        return self.channel_shuffle(out, groups=4)
+        return out
 
 
 # ---------------------------- FPN Modules ----------------------------