Browse Source

debug rtpdetr_r50

yjh0410 1 year ago
parent
commit
fc1fe0b69f

+ 4 - 4
config/model_config/rtpdetr_config.py

@@ -19,14 +19,14 @@ rtpdetr_cfg = {
         'hidden_dim': 256,
         'en_num_heads': 8,
         'en_num_layers': 6,
-        'en_ffn_dim': 1024,
+        'en_ffn_dim': 2048,
         'en_dropout': 0.0,
         'en_act': 'gelu',
         # Transformer Decoder
         'transformer': 'plain_detr_transformer',
         'de_num_heads': 8,
         'de_num_layers': 6,
-        'de_ffn_dim': 1024,
+        'de_ffn_dim': 2048,
         'de_dropout': 0.0,
         'de_act': 'gelu',
         'de_pre_norm': True,
@@ -49,9 +49,9 @@ rtpdetr_cfg = {
         # ---------------- Train config ----------------
         ## input
         'multi_scale': [0.5, 1.25],   # 320 -> 800
-        'trans_type': 'rtdetr_base',
+        'trans_type': 'rtdetr_l',
         # ---------------- Train config ----------------
-        'trainer_type': 'rtpdetr',
+        'trainer_type': 'rtdetr',
     },
 
 }

+ 2 - 2
models/detectors/rtpdetr/rtpdetr.py

@@ -368,7 +368,7 @@ if __name__ == '__main__':
         'hidden_dim': 256,
         'en_num_heads': 8,
         'en_num_layers': 6,
-        'en_mlp_ratio': 4.0,
+        'en_ffn_dim': 2048,
         'en_dropout': 0.0,
         'en_act': 'gelu',
         # Transformer Decoder
@@ -376,7 +376,7 @@ if __name__ == '__main__':
         'hidden_dim': 256,
         'de_num_heads': 8,
         'de_num_layers': 6,
-        'de_mlp_ratio': 4.0,
+        'de_ffn_dim': 2048,
         'de_dropout': 0.0,
         'de_act': 'gelu',
         'de_pre_norm': True,

+ 10 - 10
models/detectors/rtpdetr/rtpdetr_decoder.py

@@ -25,7 +25,7 @@ def build_transformer(cfg, return_intermediate=False):
                                     return_intermediate = return_intermediate,
                                     use_checkpoint      = cfg['use_checkpoint'],
                                     num_queries_one2one = cfg['num_queries_one2one'],
-                                    num_queries_one2many = cfg['num_queries_one2many'],
+                                    num_queries_one2many    = cfg['num_queries_one2many'],
                                     proposal_feature_levels = cfg['proposal_feature_levels'],
                                     proposal_in_stride      = cfg['out_stride'],
                                     proposal_tgt_strides    = cfg['proposal_tgt_strides'],
@@ -39,7 +39,7 @@ class PlainDETRTransformer(nn.Module):
                  # Decoder layer params
                  d_model        :int   = 256,
                  num_heads      :int   = 8,
-                 ffn_dim        :int = 1024,
+                 ffn_dim        :int   = 1024,
                  dropout        :float = 0.1,
                  act_type       :str   = "relu",
                  pre_norm       :bool  = False,
@@ -47,13 +47,13 @@ class PlainDETRTransformer(nn.Module):
                  feature_stride :int   = 16,
                  num_layers     :int   = 6,
                  # Decoder params
-                 return_intermediate :bool = False,
-                 use_checkpoint      :bool = False,
-                 num_queries_one2one :int = 300,
-                 num_queries_one2many :int = 1500,
-                 proposal_feature_levels :int = 3,
-                 proposal_in_stride      :int = 16,
-                 proposal_tgt_strides    :int = [8, 16, 32],
+                 return_intermediate     :bool = False,
+                 use_checkpoint          :bool = False,
+                 num_queries_one2one     :int  = 300,
+                 num_queries_one2many    :int  = 1500,
+                 proposal_feature_levels :int  = 3,
+                 proposal_in_stride      :int  = 16,
+                 proposal_tgt_strides    :int  = [8, 16, 32],
                  ):
         super().__init__()
         # ------------ Basic setting ------------
@@ -251,7 +251,7 @@ class PlainDETRTransformer(nn.Module):
         ))
 
         topk = self.two_stage_num_proposals
-        topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
+        topk_proposals = torch.topk(enc_outputs_class.max(-1)[0], topk, dim=1)[1]
         topk_coords_unact = torch.gather(
             enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
         )