# rtdetr_basic.py

import copy
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

# ------------------------------- Basic Modules -------------------------------
def get_activation(act_type=None):
    if act_type == 'relu':
        return nn.ReLU(inplace=True)
    elif act_type == 'gelu':
        return nn.GELU()
    elif act_type == 'lrelu':
        return nn.LeakyReLU(0.1, inplace=True)
    elif act_type == 'mish':
        return nn.Mish(inplace=True)
    elif act_type == 'silu':
        return nn.SiLU(inplace=True)
    else:
        raise NotImplementedError("Unknown activation type: {}".format(act_type))

def get_norm(norm_type, dim):
    if norm_type == 'BN':
        return nn.BatchNorm2d(dim)
    elif norm_type == 'GN':
        return nn.GroupNorm(num_groups=32, num_channels=dim)
    elif norm_type == 'LN':
        return nn.LayerNorm(dim)
    else:
        raise NotImplementedError("Unknown norm type: {}".format(norm_type))

def get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

def build_multi_head_attention(d_model, num_heads, dropout, attn_type='mhsa'):
    if attn_type == 'mhsa':
        attn_layer = MultiHeadAttention(d_model, num_heads, dropout)
    elif attn_type == 's_mhsa':
        attn_layer = None
    else:
        raise NotImplementedError("Unknown attention type: {}".format(attn_type))
    return attn_layer

# ------------------------------- MLP -------------------------------
class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""
    def __init__(self, in_dim, hidden_dim, out_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([in_dim] + h, h + [out_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
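
# Example (illustrative sketch, not part of the original file): a 3-layer MLP
# used as a small prediction head; the dimensions below are assumptions.
#   box_head = MLP(in_dim=256, hidden_dim=256, out_dim=4, num_layers=3)
#   boxes = box_head(torch.randn(2, 100, 256))   # -> [2, 100, 4]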

# ------------------------------- Transformer Modules -------------------------------
## Vanilla Multi-Head Attention
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.) -> None:
        super().__init__()
        # --------------- Basic parameters ---------------
        self.d_model = d_model
        self.num_heads = num_heads
        self.scale = (d_model // num_heads) ** -0.5
        # --------------- Network parameters ---------------
        self.q_proj = nn.Linear(d_model, d_model, bias=False)  # W_q
        self.k_proj = nn.Linear(d_model, d_model, bias=False)  # W_k
        self.v_proj = nn.Linear(d_model, d_model, bias=False)  # W_v
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value):
        """
        Inputs:
            query : (Tensor) -> [B, Nq, C]
            key   : (Tensor) -> [B, Nk, C]
            value : (Tensor) -> [B, Nk, C]
        """
        bs = query.shape[0]
        Nq = query.shape[1]
        Nk = key.shape[1]
        # ----------------- Input proj -----------------
        query = self.q_proj(query)
        key = self.k_proj(key)
        value = self.v_proj(value)
        # ----------------- Multi-head Attn -----------------
        ## [B, N, C] -> [B, N, H, C_h] -> [B, H, N, C_h]
        query = query.view(bs, Nq, self.num_heads, self.d_model // self.num_heads)
        query = query.permute(0, 2, 1, 3).contiguous()
        key = key.view(bs, Nk, self.num_heads, self.d_model // self.num_heads)
        key = key.permute(0, 2, 1, 3).contiguous()
        value = value.view(bs, Nk, self.num_heads, self.d_model // self.num_heads)
        value = value.permute(0, 2, 1, 3).contiguous()
        # Attention: scaled dot-product over the token dimension
        ## [B, H, Nq, C_h] x [B, H, C_h, Nk] = [B, H, Nq, Nk]
        sim_matrix = torch.matmul(query, key.transpose(-1, -2)) * self.scale
        sim_matrix = torch.softmax(sim_matrix, dim=-1)
        sim_matrix = self.dropout(sim_matrix)  # dropout on the attention weights
        # ----------------- Output -----------------
        out = torch.matmul(sim_matrix, value)  # [B, H, Nq, C_h]
        out = out.permute(0, 2, 1, 3).contiguous().view(bs, Nq, -1)
        out = self.out_proj(out)
        return out
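
# Example (illustrative sketch, not part of the original file): cross-attention
# between 100 query tokens and 400 memory tokens; d_model=256 and num_heads=8
# are assumed values.
#   attn = MultiHeadAttention(d_model=256, num_heads=8, dropout=0.1)
#   q  = torch.randn(2, 100, 256)   # [B, Nq, C]
#   kv = torch.randn(2, 400, 256)   # [B, Nk, C]
#   out = attn(q, kv, kv)           # -> [2, 100, 256]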

## Transformer Encoder layer
class TREncoderLayer(nn.Module):
    def __init__(self,
                 d_model,
                 num_heads,
                 dim_feedforward=2048,
                 dropout=0.1,
                 act_type="relu",
                 attn_type='mhsa'
                 ):
        super().__init__()
        # Multi-head Self-Attn
        self.self_attn = build_multi_head_attention(d_model, num_heads, dropout, attn_type)
        # Feedforward Network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = get_activation(act_type)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward(self, src, pos):
        """
        Input:
            src: [torch.Tensor] -> [B, N, C]
            pos: [torch.Tensor] -> [B, N, C]
        Output:
            src: [torch.Tensor] -> [B, N, C]
        """
        q = k = self.with_pos_embed(src, pos)
        # self-attn
        src2 = self.self_attn(q, k, value=src)
        # residual connection + norm
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        # ffn
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src
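
# Example (illustrative sketch, not part of the original file): one encoder layer
# over a flattened 20x20 feature map (N = 400 tokens); all sizes are assumptions.
#   layer = TREncoderLayer(d_model=256, num_heads=8)
#   src = torch.randn(2, 400, 256)   # [B, N, C]
#   pos = torch.randn(2, 400, 256)   # positional embedding, same shape as src
#   out = layer(src, pos)            # -> [2, 400, 256]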

## Transformer Decoder layer
class TRDecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, dim_feedforward=2048, dropout=0.1, act_type="relu", attn_type='mhsa'):
        super().__init__()
        # Multi-head Self-Attn
        self.self_attn = build_multi_head_attention(d_model, num_heads, dropout, attn_type)
        self.cross_attn = build_multi_head_attention(d_model, num_heads, dropout)
        # Feedforward Network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.activation = get_activation(act_type)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward(self, tgt, tgt_query_pos, memory, memory_pos):
        """
        Input:
            tgt           : [torch.Tensor] -> [B, Nq, C], decoder queries
            tgt_query_pos : [torch.Tensor] -> [B, Nq, C], query positional embedding
            memory        : [torch.Tensor] -> [B, Nk, C], encoder output
            memory_pos    : [torch.Tensor] -> [B, Nk, C], memory positional embedding
        """
        # self attention
        tgt2 = self.self_attn(
            query=self.with_pos_embed(tgt, tgt_query_pos),
            key=self.with_pos_embed(tgt, tgt_query_pos),
            value=tgt)
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        # cross attention
        tgt2 = self.cross_attn(
            query=self.with_pos_embed(tgt, tgt_query_pos),
            key=self.with_pos_embed(memory, memory_pos),
            value=memory)
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        # ffn
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt
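
# ------------------------------- Quick self-test -------------------------------
# Minimal smoke test (illustrative sketch, not part of the original file): builds
# one encoder layer and one decoder layer with assumed sizes (d_model=256, 8 heads,
# 400 memory tokens, 100 queries) and checks that output shapes match the inputs.
if __name__ == "__main__":
    bs, num_tokens, num_queries, d_model = 2, 400, 100, 256
    memory = torch.randn(bs, num_tokens, d_model)          # flattened encoder features
    memory_pos = torch.randn(bs, num_tokens, d_model)      # positional embedding for memory
    tgt = torch.randn(bs, num_queries, d_model)            # object queries
    tgt_query_pos = torch.randn(bs, num_queries, d_model)  # query positional embedding

    encoder_layer = TREncoderLayer(d_model, num_heads=8)
    decoder_layer = TRDecoderLayer(d_model, num_heads=8)

    memory = encoder_layer(memory, memory_pos)
    out = decoder_layer(tgt, tgt_query_pos, memory, memory_pos)
    print("encoder output:", memory.shape)   # expected: torch.Size([2, 400, 256])
    print("decoder output:", out.shape)      # expected: torch.Size([2, 100, 256])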