utils.py 601 B

123456789101112131415161718192021
  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. # https://github.com/facebookresearch/detr
  3. import copy
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. def get_clones(module, N):
  7. return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
  8. def get_activation_fn(activation):
  9. """Return an activation function given a string"""
  10. if activation == "relu":
  11. return F.relu
  12. if activation == "gelu":
  13. return F.gelu
  14. if activation == "glu":
  15. return F.glu
  16. raise RuntimeError(F"activation should be relu/gelu, not {activation}.")