# transformer.py

import math
import copy
import warnings
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..basic.mlp import FFN, MLP
from ..basic.conv import LayerNorm2D, BasicConv


# ----------------- Basic Ops -----------------
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Truncated normal initialization (copied from timm)."""
    def norm_cdf(x):
        # CDF of the standard normal distribution
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)
    with torch.no_grad():
        # Sample uniformly within the truncated CDF range, then map the
        # values back through the inverse error function.
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)
        tensor.uniform_(2 * l - 1, 2 * u - 1)
        tensor.erfinv_()
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)
        tensor.clamp_(min=a, max=b)
        return tensor
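
# Illustrative usage (an assumption, not from the original file): ViT-style
# weight init, sampling N(0, 0.02^2) truncated to [a, b] = [-2, 2]:
#   trunc_normal_(nn.Linear(256, 256).weight, std=0.02)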


def get_clones(module, N):
    """Return a ModuleList of N deep copies of `module` (None if N <= 0)."""
    if N <= 0:
        return None
    else:
        return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def inverse_sigmoid(x, eps=1e-5):
    """Numerically stable logit: the inverse of torch.sigmoid."""
    x = x.clamp(min=0., max=1.)
    return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))
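
# Sanity check (illustrative): inverse_sigmoid is the logit function, so it
# round-trips with torch.sigmoid away from the clamped endpoints:
#   x = torch.tensor([0.25, 0.50, 0.75])
#   assert torch.allclose(torch.sigmoid(inverse_sigmoid(x)), x, atol=1e-4)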


def build_transformer(cfg, num_classes=80, return_intermediate=False):
    if cfg['transformer'] == 'plain_detr_transformer':
        return PlainDETRTransformer(d_model=cfg['hidden_dim'],
                                    num_heads=cfg['de_num_heads'],
                                    ffn_dim=cfg['de_ffn_dim'],
                                    dropout=cfg['de_dropout'],
                                    act_type=cfg['de_act'],
                                    pre_norm=cfg['de_pre_norm'],
                                    rpe_hidden_dim=cfg['rpe_hidden_dim'],
                                    feature_stride=cfg['out_stride'],
                                    num_layers=cfg['de_num_layers'],
                                    return_intermediate=return_intermediate,
                                    use_checkpoint=cfg['use_checkpoint'],
                                    num_queries_one2one=cfg['num_queries_one2one'],
                                    num_queries_one2many=cfg['num_queries_one2many'],
                                    proposal_feature_levels=cfg['proposal_feature_levels'],
                                    proposal_in_stride=cfg['out_stride'],
                                    proposal_tgt_strides=cfg['proposal_tgt_strides'],
                                    )
    elif cfg['transformer'] == 'rtdetr_transformer':
        return RTDETRTransformer(in_dims=cfg['backbone_feat_dims'],
                                 hidden_dim=cfg['hidden_dim'],
                                 strides=cfg['out_stride'],
                                 num_classes=num_classes,
                                 num_queries=cfg['num_queries'],
                                 num_heads=cfg['de_num_heads'],
                                 num_layers=cfg['de_num_layers'],
                                 num_levels=3,
                                 num_points=cfg['de_num_points'],
                                 ffn_dim=cfg['de_ffn_dim'],
                                 dropout=cfg['de_dropout'],
                                 act_type=cfg['de_act'],
                                 pre_norm=cfg['de_pre_norm'],
                                 return_intermediate=return_intermediate,
                                 num_denoising=cfg['dn_num_denoising'],
                                 label_noise_ratio=cfg['dn_label_noise_ratio'],
                                 box_noise_scale=cfg['dn_box_noise_scale'],
                                 learnt_init_query=cfg['learnt_init_query'],
                                 )
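
# Illustrative config for the 'rtdetr_transformer' branch above. The keys match
# those read by build_transformer; the *values* are assumptions for a sketch,
# not the project's actual defaults:
#   cfg = {
#       'transformer': 'rtdetr_transformer',
#       'backbone_feat_dims': [512, 1024, 2048],
#       'hidden_dim': 256,
#       'out_stride': [8, 16, 32],
#       'num_queries': 300,
#       'de_num_heads': 8,
#       'de_num_layers': 6,
#       'de_num_points': 4,
#       'de_ffn_dim': 1024,
#       'de_dropout': 0.0,
#       'de_act': 'relu',
#       'de_pre_norm': False,
#       'dn_num_denoising': 100,
#       'dn_label_noise_ratio': 0.5,
#       'dn_box_noise_scale': 1.0,
#       'learnt_init_query': False,
#   }
#   transformer = build_transformer(cfg, num_classes=80)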


# ----------------- Transformer Encoder -----------------
class TransformerEncoderLayer(nn.Module):
    def __init__(self,
                 d_model   :int   = 256,
                 num_heads :int   = 8,
                 ffn_dim   :int   = 1024,
                 dropout   :float = 0.1,
                 act_type  :str   = "relu",
                 pre_norm  :bool  = False,
                 ):
        super().__init__()
        # ----------- Basic parameters -----------
        self.d_model = d_model
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim
        self.dropout_rate = dropout  # keep the rate; self.dropout below is the nn.Dropout module
        self.act_type = act_type
        self.pre_norm = pre_norm
        # ----------- Network parameters -----------
        # Multi-head Self-Attention
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)
        # Feedforward Network
        self.ffn = FFN(d_model, ffn_dim, dropout, act_type)
    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_pre_norm(self, src, pos_embed):
        """
        Input:
            src:       [torch.Tensor] -> [B, N, C]
            pos_embed: [torch.Tensor] -> [B, N, C]
        Output:
            src:       [torch.Tensor] -> [B, N, C]
        """
        src2 = self.norm(src)
        q = k = self.with_pos_embed(src2, pos_embed)
        # -------------- MHSA --------------
        src2 = self.self_attn(q, k, value=src2)[0]
        # Pre-norm: the residual adds back the un-normalized input
        src = src + self.dropout(src2)
        # -------------- FFN --------------
        src = self.ffn(src)
        return src
    def forward_post_norm(self, src, pos_embed):
        """
        Input:
            src:       [torch.Tensor] -> [B, N, C]
            pos_embed: [torch.Tensor] -> [B, N, C]
        Output:
            src:       [torch.Tensor] -> [B, N, C]
        """
        q = k = self.with_pos_embed(src, pos_embed)
        # -------------- MHSA --------------
        src2 = self.self_attn(q, k, value=src)[0]
        src = src + self.dropout(src2)
        src = self.norm(src)
        # -------------- FFN --------------
        src = self.ffn(src)
        return src

    def forward(self, src, pos_embed):
        if self.pre_norm:
            return self.forward_pre_norm(src, pos_embed)
        else:
            return self.forward_post_norm(src, pos_embed)


class TransformerEncoder(nn.Module):
    def __init__(self,
                 d_model        :int   = 256,
                 num_heads      :int   = 8,
                 num_layers     :int   = 1,
                 ffn_dim        :int   = 1024,
                 pe_temperature :float = 10000.,
                 dropout        :float = 0.1,
                 act_type       :str   = "relu",
                 pre_norm       :bool  = False,
                 ):
        super().__init__()
        # ----------- Basic parameters -----------
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.ffn_dim = ffn_dim
        self.dropout = dropout
        self.act_type = act_type
        self.pre_norm = pre_norm
        self.pe_temperature = pe_temperature
        self.pos_embed = None  # cached position embedding, built lazily in forward()
        # ----------- Network parameters -----------
        self.encoder_layers = get_clones(
            TransformerEncoderLayer(d_model, num_heads, ffn_dim, dropout, act_type, pre_norm), num_layers)
    def build_2d_sincos_position_embedding(self, device, w, h, embed_dim=256, temperature=10000.):
        assert embed_dim % 4 == 0, \
            'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
        # ----------- Check cached pos_embed -----------
        # The cache is stored flattened as [1, N, C] with N = H x W,
        # so compare against the flattened length.
        if self.pos_embed is not None and \
           self.pos_embed.shape[1] == int(h) * int(w):
            return self.pos_embed
        # ----------- Generate grid coords -----------
        grid_w = torch.arange(int(w), dtype=torch.float32)
        grid_h = torch.arange(int(h), dtype=torch.float32)
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing='ij')  # shape: [W, H]
        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1. / (temperature ** omega)
        out_w = grid_w.flatten()[..., None] @ omega[None]  # shape: [N, C//4]
        out_h = grid_h.flatten()[..., None] @ omega[None]  # shape: [N, C//4]
        # shape: [1, N, C]
        pos_embed = torch.cat([torch.sin(out_w), torch.cos(out_w),
                               torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
        pos_embed = pos_embed.to(device)
        self.pos_embed = pos_embed
        return pos_embed
    def forward(self, src):
        """
        Input:
            src: [torch.Tensor] -> [B, C, H, W]
        Output:
            src: [torch.Tensor] -> [B, C, H, W]
        """
        # -------- Transformer encoder --------
        channels, fmp_h, fmp_w = src.shape[1:]
        # [B, C, H, W] -> [B, N, C], N = H x W
        src_flatten = src.flatten(2).permute(0, 2, 1).contiguous()
        memory = src_flatten
        # PosEmbed: [1, N, C]
        pos_embed = self.build_2d_sincos_position_embedding(
            src.device, fmp_w, fmp_h, channels, self.pe_temperature)
        # Transformer encoder layers
        for encoder in self.encoder_layers:
            memory = encoder(memory, pos_embed=pos_embed)
        # Output: [B, N, C] -> [B, C, N] -> [B, C, H, W]
        src = memory.permute(0, 2, 1).contiguous()
        src = src.view([-1, channels, fmp_h, fmp_w])
        return src
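

if __name__ == '__main__':
    # Minimal smoke test (a sketch; assumes the relative imports above resolve,
    # i.e. the module is run from inside its package): the encoder should
    # preserve the [B, C, H, W] shape of its input feature map.
    encoder = TransformerEncoder(d_model=256, num_heads=8, num_layers=1, ffn_dim=1024)
    x = torch.randn(2, 256, 20, 20)
    y = encoder(x)
    print(y.shape)  # expected: torch.Size([2, 256, 20, 20])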