yjh0410 преди 1 година
родител
ревизия
068e3b3ff4
променени са 2 файла, в които са добавени 18 реда и са изтрити 44 реда
  1. 14 40
      models/detectors/rtdetr/basic_modules/basic.py
  2. 4 4
      models/detectors/rtdetr/basic_modules/fpn.py

+ 14 - 40
models/detectors/rtdetr/basic_modules/basic.py

@@ -194,52 +194,26 @@ class BasicConv(nn.Module):
                  stride=1,                 # padding
                  act_type  :str = 'lrelu', # activation
                  norm_type :str = 'BN',    # normalization
+                 depthwise :bool = False
                 ):
         super(BasicConv, self).__init__()
         add_bias = False if norm_type else True
-        self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, g=1, bias=add_bias)
-        self.norm = get_norm(norm_type, out_dim)
-        self.act  = get_activation(act_type)
-
-    def forward(self, x):
-        return self.act(self.norm(self.conv(x)))
-
-class DepthwiseConv(nn.Module):
-    def __init__(self, 
-                 in_dim,                 # in channels
-                 out_dim,                # out channels 
-                 kernel_size=1,          # kernel size 
-                 padding=0,              # padding
-                 stride=1,               # padding
-                 act_type  :str = None,  # activation
-                 norm_type :str = 'BN',  # normalization
-                ):
-        super(DepthwiseConv, self).__init__()
-        assert in_dim == out_dim
-        add_bias = False if norm_type else True
-        self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, g=out_dim, bias=add_bias)
-        self.norm = get_norm(norm_type, out_dim)
-        self.act  = get_activation(act_type)
-
-    def forward(self, x):
-        return self.act(self.norm(self.conv(x)))
-
-class PointwiseConv(nn.Module):
-    def __init__(self, 
-                 in_dim,                   # in channels
-                 out_dim,                  # out channels 
-                 act_type  :str = 'lrelu', # activation
-                 norm_type :str = 'BN',    # normalization
-                ):
-        super(DepthwiseConv, self).__init__()
-        assert in_dim == out_dim
-        add_bias = False if norm_type else True
-        self.conv = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, g=1, bias=add_bias)
-        self.norm = get_norm(norm_type, out_dim)
+        self.depthwise = depthwise
+        if not depthwise:
+            self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, g=1, bias=add_bias)
+            self.norm = get_norm(norm_type, out_dim)
+        else:
+            self.conv1 = get_conv2d(in_dim, in_dim, k=kernel_size, p=padding, s=stride, g=1, bias=add_bias)
+            self.norm1 = get_norm(norm_type, in_dim)
+            self.conv2 = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, g=1, bias=add_bias)
+            self.norm2 = get_norm(norm_type, out_dim)
         self.act  = get_activation(act_type)
 
     def forward(self, x):
-        return self.act(self.norm(self.conv(x)))
+        if not self.depthwise:
+            return self.act(self.norm(self.conv(x)))
+        else:
+            return self.act(self.norm2(self.conv2(self.norm1(self.conv1(x)))))
 
 
 # ----------------- CNN Modules -----------------

+ 4 - 4
models/detectors/rtdetr/basic_modules/fpn.py

@@ -151,20 +151,20 @@ class HybridEncoder(nn.Module):
         ## P5 -> P4
         p5_in = self.reduce_layer_1(p5)
         p5_up = F.interpolate(p5_in, scale_factor=2.0)
-        p4 = self.top_down_layer_1(torch.cat([p4, p5_up], dim=1))
+        p4    = self.top_down_layer_1(torch.cat([p4, p5_up], dim=1))
 
         ## P4 -> P3
         p4_in = self.reduce_layer_2(p4)
         p4_up = F.interpolate(p4_in, scale_factor=2.0)
-        p3 = self.top_down_layer_2(torch.cat([p3, p4_up], dim=1))
+        p3    = self.top_down_layer_2(torch.cat([p3, p4_up], dim=1))
 
         # -------- Bottom up PAN --------
         ## P3 -> P4
         p3_ds = self.dowmsample_layer_1(p3)
-        p4 = self.bottom_up_layer_1(torch.cat([p4_in, p3_ds], dim=1))
+        p4    = self.bottom_up_layer_1(torch.cat([p4_in, p3_ds], dim=1))
 
         p4_ds = self.dowmsample_layer_2(p4)
-        p5 = self.bottom_up_layer_2(torch.cat([p5_in, p4_ds], dim=1))
+        p5    = self.bottom_up_layer_2(torch.cat([p5_in, p4_ds], dim=1))
 
         out_feats = [p3, p4, p5]