__init__.py 836 B

1234567891011121314151617
  1. from .transformer import DETRTransformer
  2. def build_transformer(cfg, return_intermediate_dec):
  3. if cfg.transformer == "detr_transformer":
  4. return DETRTransformer(hidden_dim = cfg.hidden_dim,
  5. num_heads = cfg.num_heads,
  6. ffn_dim = cfg.feedforward_dim,
  7. num_enc_layers = cfg.num_enc_layers,
  8. num_dec_layers = cfg.num_dec_layers,
  9. dropout = cfg.dropout,
  10. act_type = cfg.tr_act,
  11. pre_norm = cfg.pre_norm,
  12. return_intermediate_dec=return_intermediate_dec)
  13. else:
  14. raise NotImplementedError("Unknown transformer: {}".format(cfg.transformer))