# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# https://github.com/facebookresearch/detr
import torch
import torch.nn as nn

try:
    from .utils import get_clones, get_activation_fn
except ImportError:
    from utils import get_clones, get_activation_fn

class TransformerDecoder(nn.Module):
    """DETR-style Transformer decoder: a stack of `num_layers` cloned decoder layers.

    If `return_intermediate` is True, the output of every layer is collected and
    stacked along a new leading dimension.
    """
    def __init__(self,
                 decoder_layer,
                 num_layers,
                 norm=None,
                 return_intermediate=False):
        super().__init__()
        # --------- Basic parameters ---------
        self.num_layers = num_layers
        self.return_intermediate = return_intermediate
        # --------- Model parameters ---------
        self.layers = get_clones(decoder_layer, num_layers)
        self.norm = norm

    def forward(self,
                tgt,
                tgt_mask,
                memory,
                memory_mask,
                memory_pos,
                query_pos):
        output = tgt

        intermediate = []
        for layer in self.layers:
            output = layer(output,
                           tgt_mask,
                           memory,
                           memory_mask,
                           memory_pos,
                           query_pos)
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)  # [M, N, B, C]


class TransformerDecoderLayer(nn.Module):
    """A single decoder layer: multi-head self-attention over the object queries,
    multi-head cross-attention from the queries to the image features (memory),
    and a feed-forward network. Supports post-norm (default) and pre-norm variants.
    """
    def __init__(self,
                 hidden_dim,
                 num_heads,
                 ffn_dim=2048,
                 dropout=0.1,
                 act_type="relu",
                 pre_norm=False):
        super().__init__()
        # ---------- Basic parameters ----------
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim
        self.act_type = act_type
        self.pre_norm = pre_norm
        # ---------- Model parameters ----------
        ## MHSA for object queries
        self.self_attn = nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(hidden_dim)
        ## MHCA for object queries
        self.multihead_attn = nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(hidden_dim)
        ## Feedforward network
        self.linear1 = nn.Linear(hidden_dim, ffn_dim)
        self.activation = get_activation_fn(act_type)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(ffn_dim, hidden_dim)
        self.dropout3 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(hidden_dim)

    def with_pos_embed(self, tensor, pos_embed):
        return tensor if pos_embed is None else tensor + pos_embed

    def forward_post(self,
                     tgt,
                     tgt_mask,
                     memory,
                     memory_mask,
                     memory_pos,
                     query_pos,
                     ):
        # MHSA for object queries
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, tgt, attn_mask=tgt_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        # MHCA between object queries and image features
        q = self.with_pos_embed(tgt, query_pos)
        k = self.with_pos_embed(memory, memory_pos)
        tgt2 = self.multihead_attn(q, k, memory, key_padding_mask=memory_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        # FFN
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(self,
                    tgt,
                    tgt_mask,
                    memory,
                    memory_mask,
                    memory_pos,
                    query_pos,
                    ):
        # MHSA for object queries
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, tgt2, attn_mask=tgt_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        # MHCA between object queries and image features
        q = self.with_pos_embed(tgt2, query_pos)
        k = self.with_pos_embed(memory, memory_pos)
        tgt2 = self.multihead_attn(q, k, memory, key_padding_mask=memory_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        # FFN
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(self,
                tgt,
                tgt_mask,
                memory,
                memory_mask,
                memory_pos,
                query_pos):
        if self.pre_norm:
            return self.forward_pre(tgt, tgt_mask, memory, memory_mask, memory_pos, query_pos)
        else:
            return self.forward_post(tgt, tgt_mask, memory, memory_mask, memory_pos, query_pos)


if __name__ == "__main__":
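    # A minimal smoke test (not part of the original file; the sizes below are
    # assumptions chosen for illustration): build one decoder layer, clone it
    # into a 6-layer decoder, and push random tensors through it. Shapes follow
    # the DETR convention used above: object queries are [Nq, B, C] and the
    # flattened image features (memory) are [HW, B, C].
    num_queries, hw, bs, dim = 100, 64, 2, 256
    decoder_layer = TransformerDecoderLayer(hidden_dim=dim, num_heads=8)
    decoder = TransformerDecoder(decoder_layer,
                                 num_layers=6,
                                 norm=nn.LayerNorm(dim),
                                 return_intermediate=True)

    tgt = torch.zeros(num_queries, bs, dim)        # initial object queries
    query_pos = torch.randn(num_queries, bs, dim)  # query position embeddings (random here)
    memory = torch.randn(hw, bs, dim)              # flattened image features
    memory_pos = torch.randn(hw, bs, dim)          # positional encodings for memory

    out = decoder(tgt, None, memory, None, memory_pos, query_pos)
    print(out.shape)  # [num_layers, Nq, B, C] -> torch.Size([6, 100, 2, 256])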