| 1234567891011121314151617 |
- from .transformer import DETRTransformer
def build_transformer(cfg, return_intermediate_dec):
    """Build the transformer named by ``cfg.transformer``.

    Only ``"detr_transformer"`` is currently supported. For it, the model
    hyper-parameters are read from ``cfg`` and forwarded to
    ``DETRTransformer`` as keyword arguments. Note the renamed fields:
    ``cfg.feedforward_dim`` becomes ``ffn_dim`` and ``cfg.tr_act`` becomes
    ``act_type``.

    Args:
        cfg: Config object exposing ``transformer`` and, for the DETR
            variant, ``hidden_dim``, ``num_heads``, ``feedforward_dim``,
            ``num_enc_layers``, ``num_dec_layers``, ``dropout``, ``tr_act``
            and ``pre_norm`` attributes.
        return_intermediate_dec: Passed through unchanged to the
            transformer constructor.

    Returns:
        The constructed ``DETRTransformer`` instance.

    Raises:
        NotImplementedError: If ``cfg.transformer`` names an unsupported
            transformer.
    """
    # Guard clause: reject anything other than the single supported variant.
    if cfg.transformer != "detr_transformer":
        raise NotImplementedError("Unknown transformer: {}".format(cfg.transformer))

    # Config attributes are read in the same order as the constructor's
    # keyword arguments, so any missing attribute is reported identically.
    detr_kwargs = {
        "hidden_dim": cfg.hidden_dim,
        "num_heads": cfg.num_heads,
        "ffn_dim": cfg.feedforward_dim,
        "num_enc_layers": cfg.num_enc_layers,
        "num_dec_layers": cfg.num_dec_layers,
        "dropout": cfg.dropout,
        "act_type": cfg.tr_act,
        "pre_norm": cfg.pre_norm,
        "return_intermediate_dec": return_intermediate_dec,
    }
    return DETRTransformer(**detr_kwargs)
-
|