transformer.py

import math
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import constant_, xavier_uniform_

try:
    from .basic import get_activation, MLP, FFN
except ImportError:
    from basic import get_activation, MLP, FFN


def get_clones(module, N):
    """Return a ModuleList of N deep copies of `module` (None if N <= 0)."""
    if N <= 0:
        return None
    else:
        return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

def inverse_sigmoid(x, eps=1e-5):
    """Numerically stable inverse of the sigmoid (logit), with clamping."""
    x = x.clamp(min=0., max=1.)
    return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))
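
# Sanity-check sketch (illustrative, not part of the original module): the clamped
# logit above should invert torch.sigmoid up to the clamping tolerance.
def _check_inverse_sigmoid():
    x = torch.rand(8)
    assert torch.allclose(inverse_sigmoid(torch.sigmoid(x)), x, atol=1e-4)
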

# ----------------- Transformer modules -----------------
## Transformer Encoder layer
class TransformerEncoderLayer(nn.Module):
    def __init__(self,
                 d_model   :int   = 256,
                 num_heads :int   = 8,
                 mlp_ratio :float = 4.0,
                 dropout   :float = 0.1,
                 act_type  :str   = "relu",
                 ):
        super().__init__()
        # ----------- Basic parameters -----------
        self.d_model = d_model
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.act_type = act_type
        # ----------- Network parameters -----------
        ## Multi-head Self-Attention
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)
        ## Feedforward Network
        self.ffn = FFN(d_model, mlp_ratio, dropout, act_type)

    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward(self, src, pos_embed):
        """
        Input:
            src:       [torch.Tensor] -> [B, N, C]
            pos_embed: [torch.Tensor] -> [B, N, C]
        Output:
            src:       [torch.Tensor] -> [B, N, C]
        """
        q = k = self.with_pos_embed(src, pos_embed)
        # -------------- MHSA --------------
        src2 = self.self_attn(q, k, value=src)[0]
        src = src + self.dropout(src2)
        src = self.norm(src)
        # -------------- FFN --------------
        src = self.ffn(src)

        return src

## Transformer Encoder
class TransformerEncoder(nn.Module):
    def __init__(self,
                 d_model        :int   = 256,
                 num_heads      :int   = 8,
                 num_layers     :int   = 1,
                 mlp_ratio      :float = 4.0,
                 pe_temperature :float = 10000.,
                 dropout        :float = 0.1,
                 act_type       :str   = "relu",
                 ):
        super().__init__()
        # ----------- Basic parameters -----------
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout
        self.act_type = act_type
        self.pe_temperature = pe_temperature
        self.pos_embed = None
        # ----------- Network parameters -----------
        self.encoder_layers = get_clones(
            TransformerEncoderLayer(d_model, num_heads, mlp_ratio, dropout, act_type), num_layers)

    def build_2d_sincos_position_embedding(self, device, w, h, embed_dim=256, temperature=10000.):
        assert embed_dim % 4 == 0, \
            'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
        # ----------- Check cached pos_embed -----------
        if self.pos_embed is not None and \
           self.pos_embed.shape[1] == h * w:
            return self.pos_embed
        # ----------- Generate grid coords -----------
        grid_w = torch.arange(int(w), dtype=torch.float32)
        grid_h = torch.arange(int(h), dtype=torch.float32)
        grid_w, grid_h = torch.meshgrid([grid_w, grid_h])  # shape: [W, H]
        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1. / (temperature ** omega)
        out_w = grid_w.flatten()[..., None] @ omega[None]  # shape: [N, C/4]
        out_h = grid_h.flatten()[..., None] @ omega[None]  # shape: [N, C/4]
        # shape: [1, N, C]
        pos_embed = torch.cat([torch.sin(out_w), torch.cos(out_w),
                               torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
        pos_embed = pos_embed.to(device)
        self.pos_embed = pos_embed

        return pos_embed

    def forward(self, src):
        """
        Input:
            src: [torch.Tensor] -> [B, C, H, W]
        Output:
            src: [torch.Tensor] -> [B, C, H, W]
        """
        # -------- Transformer encoder --------
        channels, fmp_h, fmp_w = src.shape[1:]
        # [B, C, H, W] -> [B, N, C], N = HxW
        src_flatten = src.flatten(2).permute(0, 2, 1)
        memory = src_flatten
        # PosEmbed: [1, N, C]
        pos_embed = self.build_2d_sincos_position_embedding(
            src.device, fmp_w, fmp_h, channels, self.pe_temperature)
        # Transformer Encoder layers
        for encoder in self.encoder_layers:
            memory = encoder(memory, pos_embed=pos_embed)
        # Output: [B, N, C] -> [B, C, N] -> [B, C, H, W]
        src = memory.permute(0, 2, 1).reshape([-1, channels, fmp_h, fmp_w])

        return src
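
# Usage sketch (illustrative, not from the original repo): run the encoder on a
# dummy feature map and check that the [B, C, H, W] layout is preserved. The batch
# and spatial sizes below are arbitrary; only d_model must match the channel dim.
def _check_transformer_encoder():
    encoder = TransformerEncoder(d_model=256, num_heads=8, num_layers=1)
    feat = torch.randn(2, 256, 20, 20)   # [B, C, H, W]
    out = encoder(feat)
    assert out.shape == feat.shape
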

## Transformer Decoder layer
class PlainTransformerDecoderLayer(nn.Module):
    def __init__(self,
                 d_model    :int   = 256,
                 num_heads  :int   = 8,
                 num_levels :int   = 3,
                 num_points :int   = 4,
                 mlp_ratio  :float = 4.0,
                 dropout    :float = 0.1,
                 act_type   :str   = "relu",
                 ):
        super().__init__()
        # ----------- Basic parameters -----------
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_levels = num_levels
        self.num_points = num_points
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout
        self.act_type = act_type
        # ---------------- Network parameters ----------------
        ## Multi-head Self-Attention
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)
        ## Multi-head Cross-Attention
        self.cross_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)
        ## FFN
        self.ffn = FFN(d_model, mlp_ratio, dropout, act_type)

    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward(self,
                tgt,
                reference_points,
                memory,
                memory_spatial_shapes,
                attn_mask=None,
                memory_mask=None,
                query_pos_embed=None):
        # ---------------- MHSA for Object Queries -----------------
        q = k = self.with_pos_embed(tgt, query_pos_embed)
        if attn_mask is not None:
            # Convert a boolean mask (True = attend) into an additive float mask.
            attn_mask = torch.where(
                attn_mask.bool(),
                torch.zeros(attn_mask.shape, dtype=tgt.dtype, device=attn_mask.device),
                torch.full(attn_mask.shape, float("-inf"), dtype=tgt.dtype, device=attn_mask.device))
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=attn_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        # ---------------- MHCA between Object Queries and Image features -----------------
        # Plain multi-head cross-attention over the flattened encoder memory.
        # `reference_points` and `memory_spatial_shapes` are kept in the signature
        # only for interface compatibility with deformable-attention decoders.
        tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos_embed),
                               memory,
                               memory,
                               attn_mask=memory_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        # ---------------- FeedForward Network -----------------
        tgt = self.ffn(tgt)

        return tgt

## Transformer Decoder
class PlainTransformerDecoder(nn.Module):
    def __init__(self,
                 d_model             :int   = 256,
                 num_heads           :int   = 8,
                 num_layers          :int   = 1,
                 num_levels          :int   = 3,
                 num_points          :int   = 4,
                 mlp_ratio           :float = 4.0,
                 dropout             :float = 0.1,
                 act_type            :str   = "relu",
                 return_intermediate :bool  = False,
                 ):
        super().__init__()
        # ----------- Basic parameters -----------
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout
        self.act_type = act_type
        self.return_intermediate = return_intermediate
        self.pos_embed = None
        # ----------- Network parameters -----------
        self.decoder_layers = get_clones(
            PlainTransformerDecoderLayer(d_model, num_heads, num_levels, num_points, mlp_ratio, dropout, act_type), num_layers)

    def forward(self,
                tgt,
                ref_points_unact,
                memory,
                memory_spatial_shapes,
                bbox_head,
                score_head,
                query_pos_head,
                attn_mask=None,
                memory_mask=None):
        output = tgt
        dec_out_bboxes = []
        dec_out_logits = []
        ref_points_detach = torch.sigmoid(ref_points_unact)
        for i, layer in enumerate(self.decoder_layers):
            ref_points_input = ref_points_detach.unsqueeze(2)
            query_pos_embed = query_pos_head(ref_points_detach)

            output = layer(output, ref_points_input, memory,
                           memory_spatial_shapes, attn_mask,
                           memory_mask, query_pos_embed)

            # Refine the (detached) reference boxes with this layer's bbox head.
            inter_ref_bbox = torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points_detach))

            dec_out_logits.append(score_head[i](output))
            if i == 0:
                dec_out_bboxes.append(inter_ref_bbox)
            else:
                dec_out_bboxes.append(
                    torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points)))

            ref_points = inter_ref_bbox
            ref_points_detach = inter_ref_bbox.detach()

        return torch.stack(dec_out_bboxes), torch.stack(dec_out_logits)
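
# Usage sketch (illustrative, not from the original repo): drive the decoder with
# dummy queries and plain nn.Linear stand-ins for bbox_head / score_head /
# query_pos_head. The real heads (e.g. the MLP imported from `basic`) may differ;
# these stand-ins only mimic the expected input/output shapes, and num_classes=80
# is an arbitrary assumption for illustration.
def _check_plain_transformer_decoder():
    num_layers, num_queries, d_model, num_classes = 2, 100, 256, 80
    decoder = PlainTransformerDecoder(d_model=d_model, num_heads=8, num_layers=num_layers)
    tgt = torch.randn(2, num_queries, d_model)          # object queries: [B, Nq, C]
    ref_points_unact = torch.randn(2, num_queries, 4)   # unactivated reference boxes: [B, Nq, 4]
    memory = torch.randn(2, 400, d_model)                # flattened encoder output: [B, N, C]
    bbox_head = nn.ModuleList([nn.Linear(d_model, 4) for _ in range(num_layers)])
    score_head = nn.ModuleList([nn.Linear(d_model, num_classes) for _ in range(num_layers)])
    query_pos_head = nn.Linear(4, d_model)               # maps reference boxes to query pos embeddings
    boxes, logits = decoder(tgt, ref_points_unact, memory, None,
                            bbox_head, score_head, query_pos_head)
    assert boxes.shape == (num_layers, 2, num_queries, 4)
    assert logits.shape == (num_layers, 2, num_queries, num_classes)
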