|
|
@@ -14,13 +14,14 @@ class TransformerDecoder(nn.Module):
|
|
|
self.num_queries = cfg['num_queries']
|
|
|
self.num_deocder_layers = cfg['num_decoder_layers']
|
|
|
self.return_intermediate = return_intermediate
|
|
|
+ self.ffn_dim = round(cfg['de_dim_feedforward']*cfg['width'])
|
|
|
|
|
|
# -------------------- Network Parameters ---------------------
|
|
|
## Decoder
|
|
|
decoder_layer = TRDecoderLayer(
|
|
|
d_model=in_dim,
|
|
|
+ dim_feedforward=self.ffn_dim,
|
|
|
num_heads=cfg['de_num_heads'],
|
|
|
- dim_feedforward=cfg['de_dim_feedforward'],
|
|
|
dropout=cfg['de_dropout'],
|
|
|
act_type=cfg['de_act']
|
|
|
)
|