| 123456789101112131415161718192021 |
- # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
- # https://github.com/facebookresearch/detr
- import copy
- import torch.nn as nn
- import torch.nn.functional as F
def get_clones(module, N):
    """Return an ``nn.ModuleList`` containing *N* deep copies of *module*.

    Every entry is produced by ``copy.deepcopy``, so the copies are distinct
    objects from each other and from *module*.

    Args:
        module: The module to replicate.
        N: How many copies to create; a non-positive value yields an empty list.

    Returns:
        An ``nn.ModuleList`` of the copies, in creation order.
    """
    copies = []
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    return nn.ModuleList(copies)
def get_activation_fn(activation):
    """Return the ``torch.nn.functional`` activation named by *activation*.

    Args:
        activation: One of ``"relu"``, ``"gelu"`` or ``"glu"``.

    Returns:
        ``F.relu``, ``F.gelu`` or ``F.glu`` respectively (the function object,
        not a module instance).

    Raises:
        RuntimeError: If *activation* is not one of the supported names.
    """
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        # NOTE(review): per the PyTorch docs, F.glu halves its input along
        # dim=-1 by default; confirm callers size their layers accordingly.
        return F.glu
    # The message lists every accepted name, including "glu", which the
    # previous text omitted.
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
-
|