Browse Source

we find NMS is beneficial

yjh0410 1 year ago
parent
commit
bf6fb21912
1 changed files with 6 additions and 4 deletions
  1. 6 4
      models/detectors/rtpdetr/rtpdetr_decoder.py

+ 6 - 4
models/detectors/rtpdetr/rtpdetr_decoder.py

@@ -342,25 +342,26 @@ if __name__ == '__main__':
 
     cfg = {
         'out_stride': 16,
+        'hidden_dim': 256,
         # Transformer Decoder
         'transformer': 'plain_detr_transformer',
-        'hidden_dim': 256,
-        'num_queries': 300,
         'de_num_heads': 8,
         'de_num_layers': 6,
         'de_mlp_ratio': 4.0,
-        'de_dropout': 0.1,
+        'de_dropout': 0.0,
         'de_act': 'gelu',
         'de_pre_norm': True,
         'rpe_hidden_dim': 512,
         'use_checkpoint': False,
         'proposal_feature_levels': 3,
         'proposal_tgt_strides': [8, 16, 32],
+        'num_queries_one2one': 300,
+        'num_queries_one2many': 100,
     }
     feat = torch.randn(1, cfg['hidden_dim'], 40, 40)
     mask = torch.zeros(1, 40, 40)
     pos_embed = torch.randn(1, cfg['hidden_dim'], 40, 40)
-    query_embed = torch.randn(cfg['num_queries'], cfg['hidden_dim'])
+    query_embed = torch.randn(cfg['num_queries_one2one'] + cfg['num_queries_one2many'], cfg['hidden_dim'])
 
     model = build_transformer(cfg, True)
 
@@ -396,6 +397,7 @@ if __name__ == '__main__':
 
     print('==============================')
     model.eval()
+    query_embed = torch.randn(cfg['num_queries_one2one'], cfg['hidden_dim'])
     flops, params = profile(model, inputs=(feat, mask, pos_embed, query_embed, ), verbose=False)
     print('==============================')
     print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))