yjh0410 há 2 anos atrás
pai
commit
16906f013c
1 ficheiros alterados com 2 adições e 1 exclusões
  1. 2 1
      models/detectors/rtdetr/rtdetr_decoder.py

+ 2 - 1
models/detectors/rtdetr/rtdetr_decoder.py

@@ -14,13 +14,14 @@ class TransformerDecoder(nn.Module):
         self.num_queries = cfg['num_queries']
         self.num_deocder_layers = cfg['num_decoder_layers']
         self.return_intermediate = return_intermediate
+        self.ffn_dim = round(cfg['de_dim_feedforward']*cfg['width'])
 
         # -------------------- Network Parameters ---------------------
         ## Decoder
         decoder_layer = TRDecoderLayer(
             d_model=in_dim,
+            dim_feedforward=self.ffn_dim,
             num_heads=cfg['de_num_heads'],
-            dim_feedforward=cfg['de_dim_feedforward'],
             dropout=cfg['de_dropout'],
             act_type=cfg['de_act']
         )