modify LODet

yjh0410 2 years ago
parent commit 840b09769c
3 changed files with 70 additions and 27 deletions
  1. +1  -1   config/model_config/lodet_config.py
  2. +8  -8   models/detectors/lodet/lodet_backbone.py
  3. +61 -18  models/detectors/lodet/lodet_basic.py

+ 1 - 1
config/model_config/lodet_config.py

@@ -22,7 +22,7 @@ lodet_cfg = {
     'fpn': 'lodet_pafpn',
     'fpn_core_block': 'smblock',
     'fpn_reduce_layer': 'conv',
-    'fpn_downsample_layer': 'dsblock',
+    'fpn_downsample_layer': 'conv',
     'fpn_act': 'silu',
     'fpn_norm': 'BN',
     'fpn_depthwise': True,

+ 8 - 8
models/detectors/lodet/lodet_backbone.py

@@ -26,23 +26,23 @@ class ScaleModulationNet(nn.Module):
 
         # P2/4
         self.layer_2 = nn.Sequential(   
-            DSBlock(16, 16, act_type, norm_type, depthwise),             
-            SMBlock(16, 32, act_type, norm_type, depthwise)
+            DSBlock(16, act_type, norm_type, depthwise),             
+            SMBlock(32, None, act_type, norm_type, depthwise)
         )
         # P3/8
         self.layer_3 = nn.Sequential(
-            DSBlock(32, 32, act_type, norm_type, depthwise),             
-            SMBlock(32, 64, act_type, norm_type, depthwise)
+            DSBlock(32, act_type, norm_type, depthwise),             
+            SMBlock(64, None, act_type, norm_type, depthwise)
         )
         # P4/16
         self.layer_4 = nn.Sequential(
-            DSBlock(64, 64, act_type, norm_type, depthwise),             
-            SMBlock(64, 128, act_type, norm_type, depthwise)
+            DSBlock(64, act_type, norm_type, depthwise),             
+            SMBlock(128, None, act_type, norm_type, depthwise)
         )
         # P5/32
         self.layer_5 = nn.Sequential(
-            DSBlock(128, 128, act_type, norm_type, depthwise),             
-            SMBlock(128, 256, act_type, norm_type, depthwise)
+            DSBlock(128, act_type, norm_type, depthwise),             
+            SMBlock(256, None, act_type, norm_type, depthwise)
         )
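
A quick shape check for one rewritten stage (a sketch; the import path is assumed from this commit's layout). With the new signatures, DSBlock(16) halves the spatial size while doubling channels to 32, and SMBlock(32, None) keeps its 32 input channels because no out_proj is built:

    import torch
    import torch.nn as nn
    from models.detectors.lodet.lodet_basic import DSBlock, SMBlock

    # P2/4 stage as rewritten above
    layer_2 = nn.Sequential(
        DSBlock(16, act_type='silu', norm_type='BN', depthwise=True),
        SMBlock(32, None, act_type='silu', norm_type='BN', depthwise=True),
    )
    x = torch.randn(1, 16, 64, 64)
    print(layer_2(x).shape)  # expected: torch.Size([1, 32, 32, 32])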
 
 

+ 61 - 18
models/detectors/lodet/lodet_basic.py

@@ -85,11 +85,10 @@ class Conv(nn.Module):
 # ---------------------------- Core Modules ----------------------------
 ## Scale Modulation Block
 class SMBlock(nn.Module):
-    def __init__(self, in_dim, out_dim, act_type='silu', norm_type='BN', depthwise=False):
+    def __init__(self, in_dim, out_dim=None, act_type='silu', norm_type='BN', depthwise=False):
         super(SMBlock, self).__init__()
         # -------------- Basic parameters --------------
         self.in_dim = in_dim
-        self.out_dim = out_dim
         self.inter_dim = in_dim // 2
         # -------------- Network parameters --------------
         self.cv1 = Conv(self.inter_dim, self.inter_dim, k=1, act_type=act_type, norm_type=norm_type)
@@ -107,8 +106,13 @@ class SMBlock(nn.Module):
             Conv(self.inter_dim, self.inter_dim, k=1, act_type=act_type, norm_type=norm_type),
             Conv(self.inter_dim, self.inter_dim, k=7, p=3, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
             )
-        ## Output proj
-        self.out_proj = Conv(self.inter_dim*4, self.out_dim, k=1, act_type=act_type, norm_type=norm_type)
+        ## Aggregation proj
+        self.sm_aggregation = Conv(self.inter_dim*3, self.inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+
+        # Output proj
+        self.out_proj = None
+        if out_dim is not None:
+            self.out_proj = Conv(self.inter_dim*2, out_dim, k=1, act_type=act_type, norm_type=norm_type)
 
 
     def channel_shuffle(self, x, groups):
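
As a standalone illustration of the channel_shuffle used by both blocks (plain tensor ops, nothing repo-specific assumed): with groups=2 it interleaves the channels of the two concatenated branches so they mix before any following 1x1 projection.

    import torch

    x = torch.arange(6).float().view(1, 6, 1, 1)  # channels [0..5]
    b, c, h, w = x.size()
    g = 2  # one group per branch
    x = x.view(b, g, c // g, h, w).transpose(1, 2).contiguous()
    x = x.view(b, -1, h, w)
    print(x.flatten().tolist())  # [0.0, 3.0, 1.0, 4.0, 2.0, 5.0]
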
@@ -128,33 +132,74 @@ class SMBlock(nn.Module):
     
 
     def forward(self, x):
+        """
+        Input:
+            x: (Tensor) -> [B, C_in, H, W]
+        Output:
+            out: (Tensor) -> [B, C_out, H, W]
+        """
         x1, x2 = torch.chunk(x, 2, dim=1)
+        # branch-1
         x1 = self.cv1(x1)
+        # branch-2
         x2 = self.cv2(x2)
+        x2 = torch.cat([self.sm1(x2), self.sm2(x2), self.sm3(x2)], dim=1)
+        x2 = self.sm_aggregation(x2)
+        # channel shuffle
+        out = torch.cat([x1, x2], dim=1)
+        out = self.channel_shuffle(out, groups=2)
 
-        x3 = self.sm1(x2)
-        x4 = self.sm2(x3)
-        x5 = self.sm3(x4)
-        out = self.out_proj(torch.cat([x1, x3, x4, x5], dim=1))
-        
-        out = self.channel_shuffle(out, groups=4)
+        if self.out_proj:
+            out = self.out_proj(out)
 
         return out
 
 ## DownSample Block
 class DSBlock(nn.Module):
-    def __init__(self, in_dim, out_dim, act_type='silu', norm_type='BN', depthwise=False):
+    def __init__(self, in_dim, act_type='silu', norm_type='BN', depthwise=False):
         super().__init__()
+        # branch-1
         self.maxpool = nn.MaxPool2d((2, 2), 2)
-        self.conv = Conv(in_dim//2, in_dim//2, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        self.out_proj = Conv(in_dim, out_dim, k=1, act_type=act_type, norm_type=norm_type)
+        # branch-2
+        inter_dim = in_dim // 2
+        self.sm1 = Conv(inter_dim, inter_dim, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.sm2 = Conv(inter_dim, inter_dim, k=5, p=2, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.sm3 = Conv(inter_dim, inter_dim, k=7, p=3, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.sm_aggregation = Conv(inter_dim*3, inter_dim*3, k=1, act_type=act_type, norm_type=norm_type)
+
+
+    def channel_shuffle(self, x, groups):
+        # type: (torch.Tensor, int) -> torch.Tensor
+        batchsize, num_channels, height, width = x.data.size()
+        per_group_dim = num_channels // groups
+
+        # reshape
+        x = x.view(batchsize, groups, per_group_dim, height, width)
+
+        x = torch.transpose(x, 1, 2).contiguous()
+
+        # flatten
+        x = x.view(batchsize, -1, height, width)
+
+        return x
+    
 
     def forward(self, x):
+        """
+        Input:
+            x: (Tensor) -> [B, C, H, W]
+        Output:
+            out: (Tensor) -> [B, 2C, H/2, W/2]
+        """
         x1, x2 = torch.chunk(x, 2, dim=1)
+        # branch-1
         x1 = self.maxpool(x1)
-        x2 = self.conv(x2)
+        # branch-2
+        x2 = torch.cat([self.sm1(x2), self.sm2(x2), self.sm3(x2)], dim=1)
+        x2 = self.sm_aggregation(x2)
+        # channel shuffle
         out = torch.cat([x1, x2], dim=1)
-        out = self.out_proj(out)
+        out = self.channel_shuffle(out, groups=4)
 
         return out
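
A sketch verifying the shapes in the docstring above (import path assumed); note that the groups=4 shuffle implicitly requires the doubled channel count 2C to be divisible by 4:

    import torch
    from models.detectors.lodet.lodet_basic import DSBlock

    ds = DSBlock(64, act_type='silu', norm_type='BN', depthwise=True)
    x = torch.randn(2, 64, 40, 40)
    print(ds(x).shape)  # expected: torch.Size([2, 128, 20, 20])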
 
@@ -182,11 +227,9 @@ def build_reduce_layer(cfg, in_dim, out_dim):
 ## build fpn's downsample layer
 def build_downsample_layer(cfg, in_dim, out_dim):
     if cfg['fpn_downsample_layer'] == 'conv':
-        layer = Conv(in_dim, out_dim, k=3, s=2, p=1, act_type=cfg['fpn_act'], norm_type=cfg['fpn_norm'])
+        layer = Conv(in_dim, out_dim, k=3, s=2, p=1, act_type=cfg['fpn_act'], norm_type=cfg['fpn_norm'], depthwise=cfg['fpn_depthwise'])
     elif cfg['fpn_downsample_layer'] == 'maxpool':
         assert in_dim == out_dim
         layer = nn.MaxPool2d((2, 2), stride=2)
-    elif cfg['fpn_downsample_layer'] == 'dsblock':
-        layer = DSBlock(in_dim, out_dim, act_type=cfg['fpn_act'], norm_type=cfg['fpn_norm'], depthwise=cfg['fpn_depthwise'])
         
     return layer
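
Finally, a usage sketch of the trimmed builder together with the updated config (the dims here are illustrative; import paths assumed):

    from config.model_config.lodet_config import lodet_cfg
    from models.detectors.lodet.lodet_basic import build_downsample_layer

    # 'fpn_downsample_layer' is 'conv' after this commit, so the builder
    # returns a stride-2 Conv that now also honors 'fpn_depthwise'.
    layer = build_downsample_layer(lodet_cfg, in_dim=128, out_dim=128)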