transformer_encoder.py

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# https://github.com/facebookresearch/detr
import torch
import torch.nn as nn

try:
    # Package-relative import when this file is used as part of a package.
    from .utils import get_clones, get_activation_fn
except ImportError:
    # Fallback for running this file as a standalone script.
    from utils import get_clones, get_activation_fn


class TransformerEncoder(nn.Module):
    """Stack of identical encoder layers, following the DETR encoder."""
    def __init__(self,
                 encoder_layer,
                 num_layers,
                 norm=None):
        super().__init__()
        # -------- Basic parameters --------
        self.num_layers = num_layers
        # -------- Model parameters --------
        self.layers = get_clones(encoder_layer, num_layers)
        self.norm = norm

    def forward(self, src, src_mask, pos_embed):
        output = src
        for layer in self.layers:
            output = layer(output, src_mask, pos_embed)
        if self.norm is not None:
            output = self.norm(output)
        return output


class TransformerEncoderLayer(nn.Module):
    def __init__(self,
                 hidden_dim: int = 256,
                 num_heads: int = 8,
                 ffn_dim: int = 2048,
                 dropout: float = 0.1,
                 act_type: str = "relu",
                 pre_norm: bool = False,
                 ):
        super().__init__()
        # ---------- Basic parameters ----------
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim
        self.act_type = act_type
        self.pre_norm = pre_norm
        # ---------- Model parameters ----------
        # Multi-head self-attention
        self.self_attn = nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(hidden_dim)
        # Feed-forward network
        self.linear1 = nn.Linear(hidden_dim, ffn_dim)
        self.activation = get_activation_fn(act_type)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(ffn_dim, hidden_dim)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(hidden_dim)

    def with_pos_embed(self, tensor, pos_embed):
        # Following DETR, the positional embedding is added to queries and keys only.
        return tensor if pos_embed is None else tensor + pos_embed

    def forward_post(self, src, src_mask, pos_embed):
        # Post-norm variant: attention / FFN, residual add, then LayerNorm.
        # MHSA
        q = k = self.with_pos_embed(src, pos_embed)
        src2 = self.self_attn(q, k, value=src, key_padding_mask=src_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        # FFN
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(self, src, src_mask, pos_embed):
        # Pre-norm variant: LayerNorm first, then attention / FFN and residual add.
        # MHSA
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos_embed)
        src2 = self.self_attn(q, k, value=src2, key_padding_mask=src_mask)[0]
        src = src + self.dropout1(src2)
        # FFN
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(self, src, src_mask, pos_embed):
        if self.pre_norm:
            return self.forward_pre(src, src_mask, pos_embed)
        else:
            return self.forward_post(src, src_mask, pos_embed)


if __name__ == "__main__":
    pass
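    # --- Minimal smoke test: a sketch added for illustration, not part of the original file. ---
    # It assumes `get_clones` and `get_activation_fn` are provided by `utils`, and uses the
    # DETR sequence-first layout: tensors of shape (seq_len, batch, hidden_dim).
    encoder_layer = TransformerEncoderLayer(hidden_dim=256, num_heads=8, pre_norm=False)
    encoder = TransformerEncoder(encoder_layer, num_layers=6)
    src = torch.randn(100, 2, 256)   # e.g. a flattened feature map: (seq_len, batch, hidden_dim)
    pos = torch.randn(100, 2, 256)   # positional embedding, same shape as src
    out = encoder(src, src_mask=None, pos_embed=pos)
    print(out.shape)                 # expected: torch.Size([100, 2, 256])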