# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# https://github.com/facebookresearch/detr
import torch
import torch.nn as nn

try:
    from .utils import get_clones, get_activation_fn
except ImportError:
    from utils import get_clones, get_activation_fn


class TransformerEncoder(nn.Module):
    def __init__(self,
                 encoder_layer,
                 num_layers,
                 norm=None):
        super().__init__()
        # -------- Basic parameters --------
        self.num_layers = num_layers
        # -------- Model parameters --------
        self.layers = get_clones(encoder_layer, num_layers)
        self.norm = norm

    def forward(self, src, src_mask, pos_embed):
        # src: [N, B, C] sequence-first features; src_mask: [B, N] key padding mask;
        # pos_embed: [N, B, C] positional encoding added to queries/keys in each layer.
        output = src
        for layer in self.layers:
            output = layer(output, src_mask, pos_embed)

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerEncoderLayer(nn.Module):
    def __init__(self,
                 hidden_dim: int = 256,
                 num_heads: int = 8,
                 ffn_dim: int = 2048,
                 dropout: float = 0.1,
                 act_type: str = "relu",
                 pre_norm: bool = False,
                 ):
        super().__init__()
        # ---------- Basic parameters ----------
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim
        self.act_type = act_type
        self.pre_norm = pre_norm
        # ---------- Model parameters ----------
        # Multi-head self-attention
        self.self_attn = nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(hidden_dim)
        # Feedforward network
        self.linear1 = nn.Linear(hidden_dim, ffn_dim)
        self.activation = get_activation_fn(act_type)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(ffn_dim, hidden_dim)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(hidden_dim)

    def with_pos_embed(self, tensor, pos_embed):
        # Add the positional encoding to the tensor (used for queries and keys only).
        return tensor if pos_embed is None else tensor + pos_embed

    def forward_post(self, src, src_mask, pos_embed):
        # Post-norm variant: attention/FFN -> residual add -> LayerNorm.
        # Multi-head self-attention
        q = k = self.with_pos_embed(src, pos_embed)
        src2 = self.self_attn(q, k, value=src, key_padding_mask=src_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        # Feedforward network
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)

        return src

    def forward_pre(self, src, src_mask, pos_embed):
        # Pre-norm variant: LayerNorm -> attention/FFN -> residual add.
        # Multi-head self-attention
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos_embed)
        src2 = self.self_attn(q, k, value=src2, key_padding_mask=src_mask)[0]
        src = src + self.dropout1(src2)
        # Feedforward network
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)

        return src

    def forward(self, src, src_mask, pos_embed):
        if self.pre_norm:
            return self.forward_pre(src, src_mask, pos_embed)
        else:
            return self.forward_post(src, src_mask, pos_embed)


if __name__ == "__main__":
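    # Minimal smoke-test sketch, not part of the upstream DETR file. It assumes
    # `utils.get_clones` / `utils.get_activation_fn` behave like the DETR helpers
    # (deep-copied layers, "relu" activation) and uses made-up shapes:
    # a sequence of N=100 tokens, batch B=2, hidden_dim C=256, with no padding.
    encoder_layer = TransformerEncoderLayer(hidden_dim=256, num_heads=8, pre_norm=False)
    encoder = TransformerEncoder(encoder_layer, num_layers=2, norm=None)

    N, B, C = 100, 2, 256
    src = torch.randn(N, B, C)                      # sequence-first features [N, B, C]
    pos_embed = torch.randn(N, B, C)                # positional encoding, same shape as src
    src_mask = torch.zeros(B, N, dtype=torch.bool)  # key padding mask: False = keep token

    with torch.no_grad():
        out = encoder(src, src_mask, pos_embed)
    print(out.shape)  # expected: torch.Size([100, 2, 256])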