
modify RepRTCBlock

yjh0410 · 1 year ago · commit 21edb1acd4

+ 0 - 57
config/model_config/rtdetr_config.py

@@ -117,61 +117,4 @@ rtdetr_cfg = {
         'trainer_type': 'rtdetr',
     },
 
-    'rtdetr_r101':{
-        # ---------------- Model config ----------------
-        ## Image Encoder - Backbone
-        'backbone': 'resnet101',
-        'backbone_norm': 'FrozeBN',
-        'pretrained': True,
-        'pretrained_weight': 'imagenet1k_v2',
-        'freeze_at': 0,
-        'freeze_stem_only': False,
-        'out_stride': [8, 16, 32],
-        'max_stride': 32,
-        ## Image Encoder - FPN
-        'fpn': 'hybrid_encoder',
-        'fpn_num_blocks': 4,
-        'fpn_act': 'silu',
-        'fpn_norm': 'BN',
-        'fpn_depthwise': False,
-        'hidden_dim': 384,
-        'en_num_heads': 8,
-        'en_num_layers': 1,
-        'en_ffn_dim': 2048,
-        'en_dropout': 0.0,
-        'pe_temperature': 10000.,
-        'en_act': 'gelu',
-        # Transformer Decoder
-        'transformer': 'rtdetr_transformer',
-        'de_num_heads': 8,
-        'de_num_layers': 6,
-        'de_ffn_dim': 2048,
-        'de_dropout': 0.0,
-        'de_act': 'relu',
-        'de_num_points': 4,
-        'num_queries': 300,
-        'learnt_init_query': False,
-        'pe_temperature': 10000.,
-        'dn_num_denoising': 100,
-        'dn_label_noise_ratio': 0.5,
-        'dn_box_noise_scale': 1,
-        # Head
-        'det_head': 'dino_head',
-        # ---------------- Assignment config ----------------
-        'matcher_hpy': {'cost_class': 2.0,
-                        'cost_bbox': 5.0,
-                        'cost_giou': 2.0,},
-        # ---------------- Loss config ----------------
-        'use_vfl': True,
-        'loss_coeff': {'class': 1,
-                       'bbox': 5,
-                       'giou': 2,},
-        # ---------------- Train config ----------------
-        ## input
-        'multi_scale': [0.5, 1.25],   # 320 -> 800
-        'trans_type': 'rtdetr_base',
-        # ---------------- Train config ----------------
-        'trainer_type': 'rtdetr',
-    },
-
 }
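
Note: the deleted 'rtdetr_r101' entry followed the same schema as the entries that remain in rtdetr_cfg, so after this commit only lookups of that exact key break. A minimal sketch of the failure mode, assuming the dict keeps at least one other entry (the 'rtdetr_r50' key below is an assumption, not shown in this diff):

# Hedged sketch: rtdetr_cfg is a plain dict keyed by model name, so the
# removal only affects callers that request 'rtdetr_r101' explicitly.
from config.model_config.rtdetr_config import rtdetr_cfg

cfg = rtdetr_cfg.get('rtdetr_r101')    # None after this commit
if cfg is None:
    # 'rtdetr_r50' is assumed to remain; fallback shown for illustration only
    cfg = rtdetr_cfg['rtdetr_r50']
print(cfg['backbone'], cfg['hidden_dim'])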

+ 14 - 11
models/detectors/rtdetr/basic_modules/basic.py

@@ -249,7 +249,7 @@ class RepVggBlock(nn.Module):
         self.in_dim = in_dim
         self.out_dim = out_dim
         self.conv1 = BasicConv(in_dim, out_dim, kernel_size=3, padding=1, act_type=None, norm_type=norm_type)
-        self.conv2 = BasicConv(in_dim, out_dim, kernel_size=3, padding=1, act_type=None, norm_type=norm_type)
+        self.conv2 = BasicConv(in_dim, out_dim, kernel_size=1, padding=0, act_type=None, norm_type=norm_type)
         self.act = get_activation(act_type) 
 
     def forward(self, x):
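
Note: changing conv2 from 3x3 to 1x1 turns RepVggBlock into the canonical RepVGG pair — a 3x3 branch plus a 1x1 branch — which can be merged into a single 3x3 conv at inference by zero-padding the 1x1 kernel. A minimal sketch of that fusion identity (bias and BN folding, which the real block would also need, are omitted):

# Why the 1x1 branch matters: the two branches collapse into one 3x3 conv.
import torch
import torch.nn.functional as F

w3 = torch.randn(8, 8, 3, 3)             # weight of the 3x3 branch
w1 = torch.randn(8, 8, 1, 1)             # weight of the 1x1 branch
w_fused = w3 + F.pad(w1, [1, 1, 1, 1])   # pad the 1x1 kernel to 3x3, then add

x = torch.randn(1, 8, 16, 16)
y_two_branch = F.conv2d(x, w3, padding=1) + F.conv2d(x, w1)
y_fused = F.conv2d(x, w_fused, padding=1)
assert torch.allclose(y_two_branch, y_fused, atol=1e-5)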
@@ -305,22 +305,25 @@ class RepRTCBlock(nn.Module):
                  ) -> None:
         super(RepRTCBlock, self).__init__()
         self.inter_dim = round(out_dim * expansion)
-        self.input_proj = BasicConv(in_dim, self.inter_dim * 2, kernel_size=1, act_type=act_type, norm_type=norm_type)
-        self.m = nn.Sequential(*(
-            RepVggBlock(self.inter_dim, self.inter_dim, act_type, norm_type)
-            for _ in range(num_blocks)))
-        self.output_proj = BasicConv((2 + num_blocks) * self.inter_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.conv1 = BasicConv(in_dim, self.inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.conv2 = BasicConv(in_dim, self.inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.module = nn.ModuleList([RepVggBlock(self.inter_dim, self.inter_dim, act_type, norm_type)
+                                     for _ in range(num_blocks)])
+        self.conv3 = BasicConv(self.inter_dim, out_dim, kernel_size=3, padding=1, act_type=act_type, norm_type=norm_type)
 
     def forward(self, x):
         # Input proj
-        x1, x2 = torch.chunk(self.input_proj(x), 2, dim=1)
-        out = list([x1, x2])
+        x1 = self.conv1(x)
+        x2 = self.conv2(x)
 
-        # Bottleneck
-        out.extend(m(out[-1]) for m in self.m)
+        # Core module
+        out = [x1]
+        for m in self.module:
+            x2 = m(x2)
+            out.append(x2)
 
         # Output proj
-        out = self.output_proj(torch.cat(out, dim=1))
+        out = self.conv3(sum(out))
 
         return out
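
Note: the rewrite replaces the chunk-and-concat layout (one input proj split in two, all intermediates concatenated into a wide 1x1 output proj) with two separate 1x1 stems; the second stem's output plus every intermediate RepVggBlock feature are fused by summation and projected with a 3x3 conv. A standalone sketch of the new dataflow, with plain convs standing in for BasicConv/RepVggBlock (attribute names mirror the diff; everything else is a stand-in):

import torch
import torch.nn as nn

class RepRTCSketch(nn.Module):
    def __init__(self, in_dim, out_dim, num_blocks=3, expansion=1.0):
        super().__init__()
        inter_dim = round(out_dim * expansion)
        self.conv1 = nn.Conv2d(in_dim, inter_dim, kernel_size=1)   # shortcut branch
        self.conv2 = nn.Conv2d(in_dim, inter_dim, kernel_size=1)   # main branch
        self.module = nn.ModuleList(
            nn.Conv2d(inter_dim, inter_dim, kernel_size=3, padding=1)
            for _ in range(num_blocks))
        self.conv3 = nn.Conv2d(inter_dim, out_dim, kernel_size=3, padding=1)

    def forward(self, x):
        x1, x2 = self.conv1(x), self.conv2(x)
        out = [x1]
        for m in self.module:          # keep every intermediate feature
            x2 = m(x2)
            out.append(x2)
        return self.conv3(sum(out))    # fuse by summation, then 3x3 projection

x = torch.randn(1, 64, 32, 32)
print(RepRTCSketch(64, 128)(x).shape)  # torch.Size([1, 128, 32, 32])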
 

+ 49 - 34
models/detectors/rtdetr/basic_modules/fpn.py

@@ -62,13 +62,11 @@ class HybridEncoder(nn.Module):
         c3, c4, c5 = in_dims
 
         # ---------------- Input projs ----------------
-        self.reduce_layer_1 = BasicConv(c5, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
-        self.reduce_layer_2 = BasicConv(c4, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
-        self.reduce_layer_3 = BasicConv(c3, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.input_proj_1 = BasicConv(c5, self.out_dim, kernel_size=1, act_type=None, norm_type=norm_type)
+        self.input_proj_2 = BasicConv(c4, self.out_dim, kernel_size=1, act_type=None, norm_type=norm_type)
+        self.input_proj_3 = BasicConv(c3, self.out_dim, kernel_size=1, act_type=None, norm_type=norm_type)
 
         # ---------------- Downsample ----------------
-        self.dowmsample_layer_1 = BasicConv(self.out_dim, self.out_dim, kernel_size=3, padding=1, stride=2, act_type=act_type, norm_type=norm_type)
-        self.dowmsample_layer_2 = BasicConv(self.out_dim, self.out_dim, kernel_size=3, padding=1, stride=2, act_type=act_type, norm_type=norm_type)
 
         # ---------------- Transformer Encoder ----------------
         self.transformer_encoder = TransformerEncoder(d_model        = self.out_dim,
@@ -82,38 +80,50 @@ class HybridEncoder(nn.Module):
 
         # ---------------- Top down FPN ----------------
         ## P5 -> P4
+        self.reduce_layer_1 = BasicConv(self.out_dim, self.out_dim,
+                                        kernel_size=1, padding=0, stride=1,
+                                        act_type=act_type, norm_type=norm_type)
         self.top_down_layer_1 = RepRTCBlock(in_dim     = self.out_dim * 2,
-                                            out_dim     = self.out_dim,
-                                            num_blocks  = num_blocks,
-                                            expansion   = expansion,
-                                            act_type    = act_type,
-                                            norm_type   = norm_type,
+                                            out_dim    = self.out_dim,
+                                            num_blocks = num_blocks,
+                                            expansion  = expansion,
+                                            act_type   = act_type,
+                                            norm_type  = norm_type,
                                            )
         ## P4 -> P3
+        self.reduce_layer_2 = BasicConv(self.out_dim, self.out_dim,
+                                        kernel_size=1, padding=0, stride=1,
+                                        act_type=act_type, norm_type=norm_type)
         self.top_down_layer_2 = RepRTCBlock(in_dim     = self.out_dim * 2,
-                                            out_dim     = self.out_dim,
-                                            num_blocks  = num_blocks,
-                                            expansion   = expansion,
-                                            act_type    = act_type,
-                                            norm_type   = norm_type,
+                                            out_dim    = self.out_dim,
+                                            num_blocks = num_blocks,
+                                            expansion  = expansion,
+                                            act_type   = act_type,
+                                            norm_type  = norm_type,
                                             )
         
         # ---------------- Bottom up PAN ----------------
         ## P3 -> P4
-        self.bottom_up_layer_1 = RepRTCBlock(in_dim      = self.out_dim * 2,
-                                             out_dim     = self.out_dim,
-                                             num_blocks  = num_blocks,
-                                             expansion   = expansion,
-                                             act_type    = act_type,
-                                             norm_type   = norm_type,
+        self.dowmsample_layer_1 = BasicConv(self.out_dim, self.out_dim,
+                                            kernel_size=3, padding=1, stride=2,
+                                            act_type=act_type, norm_type=norm_type)
+        self.bottom_up_layer_1 = RepRTCBlock(in_dim     = self.out_dim * 2,
+                                             out_dim    = self.out_dim,
+                                             num_blocks = num_blocks,
+                                             expansion  = expansion,
+                                             act_type   = act_type,
+                                             norm_type  = norm_type,
                                              )
         ## P4 -> P5
-        self.bottom_up_layer_2 = RepRTCBlock(in_dim      = self.out_dim * 2,
-                                             out_dim     = self.out_dim,
-                                             num_blocks  = num_blocks,
-                                             expansion   = expansion,
-                                             act_type    = act_type,
-                                             norm_type   = norm_type,
+        self.dowmsample_layer_2 = BasicConv(self.out_dim, self.out_dim,
+                                            kernel_size=3, padding=1, stride=2,
+                                            act_type=act_type, norm_type=norm_type)
+        self.bottom_up_layer_2 = RepRTCBlock(in_dim     = self.out_dim * 2,
+                                             out_dim    = self.out_dim,
+                                             num_blocks = num_blocks,
+                                             expansion  = expansion,
+                                             act_type   = act_type,
+                                             norm_type  = norm_type,
                                              )
 
         self.init_weights()
@@ -130,26 +140,31 @@ class HybridEncoder(nn.Module):
         c3, c4, c5 = features
 
         # -------- Input projs --------
-        p5 = self.reduce_layer_1(c5)
-        p4 = self.reduce_layer_2(c4)
-        p3 = self.reduce_layer_3(c3)
+        p5 = self.input_proj_1(c5)
+        p4 = self.input_proj_2(c4)
+        p3 = self.input_proj_3(c3)
 
         # -------- Transformer encoder --------
         p5 = self.transformer_encoder(p5)
 
         # -------- Top down FPN --------
-        p5_up = F.interpolate(p5, scale_factor=2.0)
+        ## P5 -> P4
+        p5_in = self.reduce_layer_1(p5)
+        p5_up = F.interpolate(p5_in, scale_factor=2.0)
         p4 = self.top_down_layer_1(torch.cat([p4, p5_up], dim=1))
 
-        p4_up = F.interpolate(p4, scale_factor=2.0)
+        ## P4 -> P3
+        p4_in = self.reduce_layer_2(p4)
+        p4_up = F.interpolate(p4_in, scale_factor=2.0)
         p3 = self.top_down_layer_2(torch.cat([p3, p4_up], dim=1))
 
         # -------- Bottom up PAN --------
+        ## P3 -> P4
         p3_ds = self.dowmsample_layer_1(p3)
-        p4 = self.bottom_up_layer_1(torch.cat([p4, p3_ds], dim=1))
+        p4 = self.bottom_up_layer_1(torch.cat([p4_in, p3_ds], dim=1))
 
         p4_ds = self.dowmsample_layer_2(p4)
-        p5 = self.bottom_up_layer_2(torch.cat([p5, p4_ds], dim=1))
+        p5 = self.bottom_up_layer_2(torch.cat([p5_in, p4_ds], dim=1))
 
         out_feats = [p3, p4, p5]
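
Net effect in fpn.py: the old input reduce layers become activation-free input projections, new 1x1 reduce layers sit inside the top-down path, and the bottom-up laterals switch from the fused p4/p5 to the reduced p4_in/p5_in. A hedged shape-trace of the new wiring with placeholder convs (the dims and stand-in modules below are assumptions, not the repo's BasicConv/RepRTCBlock):

import torch
import torch.nn as nn
import torch.nn.functional as F

d = 256  # assumed hidden_dim
reduce1, reduce2 = (nn.Conv2d(d, d, 1) for _ in range(2))                       # 1x1 reduce
down1, down2 = (nn.Conv2d(d, d, 3, stride=2, padding=1) for _ in range(2))      # stride-2 downsample
td1, td2, bu1, bu2 = (nn.Conv2d(2 * d, d, 1) for _ in range(4))                 # cat-fusion stand-ins

p3, p4, p5 = (torch.randn(1, d, s, s) for s in (80, 40, 20))

# Top-down FPN: reduce before upsampling, keep p5_in / p4_in around
p5_in = reduce1(p5)
p4 = td1(torch.cat([p4, F.interpolate(p5_in, scale_factor=2.0)], dim=1))
p4_in = reduce2(p4)
p3 = td2(torch.cat([p3, F.interpolate(p4_in, scale_factor=2.0)], dim=1))

# Bottom-up PAN: the laterals are now p4_in / p5_in, not the fused p4 / p5
p4 = bu1(torch.cat([p4_in, down1(p3)], dim=1))
p5 = bu2(torch.cat([p5_in, down2(p4)], dim=1))
print([t.shape[-1] for t in (p3, p4, p5)])  # [80, 40, 20]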