Bläddra i källkod

modify the image encoder of RT-DETR

yjh0410 2 år sedan
förälder
incheckning
ed5dc01d8c

+ 0 - 16
models/detectors/rtdetr/image_encoder/cnn_basic.py

@@ -115,18 +115,11 @@ class ELANBlock(nn.Module):
 
 
     def forward(self, x):
-        """
-        Input:
-            x: [B, C_in, H, W]
-        Output:
-            out: [B, C_out, H, W]
-        """
         x1 = self.cv1(x)
         x2 = self.cv2(x)
         x3 = self.cv3(x2)
         x4 = self.cv4(x3)
 
-        # [B, C, H, W] -> [B, 2C, H, W]
         out = self.out(torch.cat([x1, x2, x3, x4], dim=1))
 
         return out
@@ -144,17 +137,8 @@ class DownSample(nn.Module):
         )
 
     def forward(self, x):
-        """
-        Input:
-            x: [B, C, H, W]
-        Output:
-            out: [B, C, H//2, W//2]
-        """
-        # [B, C, H, W] -> [B, C//2, H//2, W//2]
         x1 = self.cv1(self.mp(x))
         x2 = self.cv2(x)
-
-        # [B, C, H//2, W//2]
         out = torch.cat([x1, x2], dim=1)
 
         return out

+ 44 - 40
models/detectors/rtdetr/image_encoder/cnn_pafpn.py

@@ -6,69 +6,73 @@ from .cnn_basic import (Conv, build_reduce_layer, build_downsample_layer, build_
 
 
 # YOLO-Style PaFPN
-class YoloPaFPN(nn.Module):
-    def __init__(self, cfg, in_dims=[256, 512, 1024], out_dim=None):
-        super(YoloPaFPN, self).__init__()
+class YolovxPaFPN(nn.Module):
+    def __init__(self, cfg, in_dims=[512, 1024, 1024], out_dim=None, input_proj=False):
+        super(YolovxPaFPN, self).__init__()
         # --------------------------- Basic Parameters ---------------------------
         self.in_dims = in_dims
-        c3, c4, c5 = in_dims
-        width = cfg['width']
-
-        # --------------------------- Network Parameters ---------------------------
-        ## top dwon
-        ### P5 -> P4
-        self.reduce_layer_1 = build_reduce_layer(cfg, c5, round(512*width))
-        self.reduce_layer_2 = build_reduce_layer(cfg, c4, round(512*width))
-        self.top_down_layer_1 = build_fpn_block(cfg, round(512*width) + round(512*width), round(512*width))
-
-        ### P4 -> P3
-        self.reduce_layer_3 = build_reduce_layer(cfg, round(512*width), round(256*width))
-        self.reduce_layer_4 = build_reduce_layer(cfg, c3, round(256*width))
-        self.top_down_layer_2 = build_fpn_block(cfg, round(256*width) + round(256*width), round(256*width))
-
-        ## bottom up
-        ### P3 -> P4
-        self.downsample_layer_1 = build_downsample_layer(cfg, round(256*width), round(256*width))
-        self.bottom_up_layer_1 = build_fpn_block(cfg, round(256*width) + round(256*width), round(512*width))
-
-        ### P4 -> P5
-        self.downsample_layer_2 = build_downsample_layer(cfg, round(512*width), round(512*width))
-        self.bottom_up_layer_2 = build_fpn_block(cfg, round(512*width) + round(512*width), round(1024*width))
+        if input_proj:
+            self.fpn_dims = [round(256*cfg['width']), round(512*cfg['width']), round(1024*cfg['width'])]
+        else:
+            self.fpn_dims = in_dims
+
+        # --------------------------- Input proj ---------------------------
+        self.input_projs = nn.ModuleList([nn.Conv2d(in_dim, fpn_dim, kernel_size=1)
+                                          for in_dim, fpn_dim in zip(in_dims, self.fpn_dims)])
+        
+        # --------------------------- Top-down FPN---------------------------
+        ## P5 -> P4
+        self.reduce_layer_1 = build_reduce_layer(cfg, self.fpn_dims[2], self.fpn_dims[2]//2)
+        self.top_down_layer_1 = build_fpn_block(cfg, self.fpn_dims[1] + self.fpn_dims[2]//2, self.fpn_dims[1])
+
+        ## P4 -> P3
+        self.reduce_layer_2 = build_reduce_layer(cfg, self.fpn_dims[1], self.fpn_dims[1]//2)
+        self.top_down_layer_2 = build_fpn_block(cfg, self.fpn_dims[0] + self.fpn_dims[1]//2, self.fpn_dims[0])
+
+        # --------------------------- Bottom-up FPN ---------------------------
+        ## P3 -> P4
+        self.downsample_layer_1 = build_downsample_layer(cfg, self.fpn_dims[0], self.fpn_dims[0])
+        self.bottom_up_layer_1 = build_fpn_block(cfg, self.fpn_dims[0] + self.fpn_dims[1]//2, self.fpn_dims[1])
+
+        ## P4 -> P5
+        self.downsample_layer_2 = build_downsample_layer(cfg, self.fpn_dims[1], self.fpn_dims[1])
+        self.bottom_up_layer_2 = build_fpn_block(cfg, self.fpn_dims[1] + self.fpn_dims[2]//2, self.fpn_dims[2])
                 
-        ## output proj layers
+        # --------------------------- Output proj ---------------------------
         if out_dim is not None:
             self.out_layers = nn.ModuleList([
                 Conv(in_dim, out_dim, k=1,
                      act_type=cfg['fpn_act'], norm_type=cfg['fpn_norm'])
-                     for in_dim in [round(256*width), round(512*width), round(1024*width)]
+                     for in_dim in self.fpn_dims
                      ])
             self.out_dim = [out_dim] * 3
         else:
             self.out_layers = None
-            self.out_dim = [round(256*width), round(512*width), round(1024*width)]
+            self.out_dim = self.fpn_dims
 
 
     def forward(self, features):
-        c3, c4, c5 = features
+        fpn_feats = [layer(feat) for feat, layer in zip(features, self.input_projs)]
+        c3, c4, c5 = fpn_feats
 
         # Top down
         ## P5 -> P4
         c6 = self.reduce_layer_1(c5)
-        c7 = self.reduce_layer_2(c4)
-        c8 = torch.cat([F.interpolate(c6, scale_factor=2.0), c7], dim=1)
+        c7 = F.interpolate(c6, scale_factor=2.0)
+        c8 = torch.cat([c7, c4], dim=1)
         c9 = self.top_down_layer_1(c8)
         ## P4 -> P3
-        c10 = self.reduce_layer_3(c9)
-        c11 = self.reduce_layer_4(c3)
-        c12 = torch.cat([F.interpolate(c10, scale_factor=2.0), c11], dim=1)
+        c10 = self.reduce_layer_2(c9)
+        c11 = F.interpolate(c10, scale_factor=2.0)
+        c12 = torch.cat([c11, c3], dim=1)
         c13 = self.top_down_layer_2(c12)
 
         # Bottom up
-        # p3 -> P4
+        ## p3 -> P4
         c14 = self.downsample_layer_1(c13)
         c15 = torch.cat([c14, c10], dim=1)
         c16 = self.bottom_up_layer_1(c15)
-        # P4 -> P5
+        ## P4 -> P5
         c17 = self.downsample_layer_2(c16)
         c18 = torch.cat([c17, c6], dim=1)
         c19 = self.bottom_up_layer_2(c18)
@@ -85,10 +89,10 @@ class YoloPaFPN(nn.Module):
         return out_feats
 
 
-def build_fpn(cfg, in_dims, out_dim=None):
+def build_fpn(cfg, in_dims, out_dim=None, input_proj=False):
     model = cfg['fpn']
     # build pafpn
-    if model == 'yolo_pafpn':
-        fpn_net = YoloPaFPN(cfg, in_dims, out_dim)
+    if model == 'YolovxPaFPN':
+        fpn_net = YolovxPaFPN(cfg, in_dims, out_dim, input_proj)
 
     return fpn_net