# transformer.py

import math
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from .basic import FFN
except ImportError:
    from basic import FFN


def get_clones(module, N):
    """Return a ModuleList of N deep copies of `module` (None if N <= 0)."""
    if N <= 0:
        return None
    else:
        return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def inverse_sigmoid(x, eps=1e-5):
    """Numerically stable logit: the inverse of torch.sigmoid."""
    x = x.clamp(min=0., max=1.)
    return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))
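
# `inverse_sigmoid` is the (clamped) logit function, i.e. the inverse of torch.sigmoid:
# inverse_sigmoid(torch.sigmoid(x)) ~= x for x away from the clamping range. The decoder
# below uses it to refine reference boxes in unnormalized (logit) space before squashing
# them back to [0, 1] with a sigmoid.
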
# ----------------- Basic Transformer Ops -----------------
def multi_scale_deformable_attn_pytorch(
    value: torch.Tensor,
    value_spatial_shapes: torch.Tensor,
    sampling_locations: torch.Tensor,
    attention_weights: torch.Tensor,
) -> torch.Tensor:
    bs, _, num_heads, embed_dims = value.shape
    _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
    # grid_sample expects sampling coordinates in [-1, 1]
    sampling_grids = 2 * sampling_locations - 1
    sampling_value_list = []
    for level, (H_, W_) in enumerate(value_spatial_shapes):
        # bs, H_*W_, num_heads, embed_dims ->
        # bs, H_*W_, num_heads*embed_dims ->
        # bs, num_heads*embed_dims, H_*W_ ->
        # bs*num_heads, embed_dims, H_, W_
        value_l_ = (
            value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)
        )
        # bs, num_queries, num_heads, num_points, 2 ->
        # bs, num_heads, num_queries, num_points, 2 ->
        # bs*num_heads, num_queries, num_points, 2
        sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
        # bs*num_heads, embed_dims, num_queries, num_points
        sampling_value_l_ = F.grid_sample(
            value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
        )
        sampling_value_list.append(sampling_value_l_)
    # (bs, num_queries, num_heads, num_levels, num_points) ->
    # (bs, num_heads, num_queries, num_levels, num_points) ->
    # (bs*num_heads, 1, num_queries, num_levels*num_points)
    attention_weights = attention_weights.transpose(1, 2).reshape(
        bs * num_heads, 1, num_queries, num_levels * num_points
    )
    output = (
        (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
        .sum(-1)
        .view(bs, num_heads * embed_dims, num_queries)
    )
    return output.transpose(1, 2).contiguous()
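
# Shape summary for the pure-PyTorch fallback above (`head_dim` denotes embed_dim // num_heads):
#   value              : [bs, sum(H_l * W_l), num_heads, head_dim]
#   sampling_locations : [bs, num_queries, num_heads, num_levels, num_points, 2], in [0, 1]
#   attention_weights  : [bs, num_queries, num_heads, num_levels, num_points]
#   returns            : [bs, num_queries, num_heads * head_dim]
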
class MSDeformableAttention(nn.Module):
    def __init__(self,
                 embed_dim=256,
                 num_heads=8,
                 num_levels=4,
                 num_points=4):
        """
        Multi-Scale Deformable Attention Module
        """
        super(MSDeformableAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_levels = num_levels
        self.num_points = num_points
        self.total_points = num_heads * num_levels * num_points

        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        self.sampling_offsets = nn.Linear(embed_dim, self.total_points * 2)
        self.attention_weights = nn.Linear(embed_dim, self.total_points)
        self.value_proj = nn.Linear(embed_dim, embed_dim)
        self.output_proj = nn.Linear(embed_dim, embed_dim)

        try:
            # use the CUDA op if it is available
            from deformable_detr_ops import ms_deformable_attn
            self.ms_deformable_attn_core = ms_deformable_attn
        except ImportError:
            # fall back to the pure-PyTorch implementation
            self.ms_deformable_attn_core = multi_scale_deformable_attn_pytorch

        self._reset_parameters()

    def _reset_parameters(self):
        """
        Default initialization for Parameters of Module.
        """
        # sampling offsets: zero weights, biases laid out on a per-head direction grid
        nn.init.constant_(self.sampling_offsets.weight.data, 0.0)
        thetas = torch.arange(self.num_heads, dtype=torch.float32) * (
            2.0 * math.pi / self.num_heads
        )
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (
            (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
            .view(self.num_heads, 1, 1, 2)
            .repeat(1, self.num_levels, self.num_points, 1)
        )
        for i in range(self.num_points):
            grid_init[:, :, i, :] *= i + 1
        with torch.no_grad():
            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
        # attention weight
        nn.init.constant_(self.attention_weights.weight, 0.0)
        nn.init.constant_(self.attention_weights.bias, 0.0)
        # proj
        nn.init.xavier_uniform_(self.value_proj.weight)
        nn.init.constant_(self.value_proj.bias, 0.0)
        nn.init.xavier_uniform_(self.output_proj.weight)
        nn.init.constant_(self.output_proj.bias, 0.0)

    def forward(self,
                query,
                reference_points,
                value,
                value_spatial_shapes,
                value_mask=None):
        """
        Args:
            query (Tensor): [bs, query_length, C]
            reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0, 0),
                bottom-right (1, 1), including padding area
            value (Tensor): [bs, value_length, C]
            value_spatial_shapes (Tensor): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
            value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements
        Returns:
            output (Tensor): [bs, query_length, C]
        """
        bs, num_query = query.shape[:2]
        num_value = value.shape[1]
        assert sum([s[0] * s[1] for s in value_spatial_shapes]) == num_value

        # Value projection
        value = self.value_proj(value)
        # fill "0" for the padding part
        if value_mask is not None:
            value_mask = value_mask.to(value.dtype).unsqueeze(-1)
            value *= value_mask

        # [bs, all_hw, C] -> [bs, all_hw, num_head, head_dim]
        value = value.reshape([bs, num_value, self.num_heads, -1])

        # [bs, num_query, num_head, num_level, num_point, 2]
        sampling_offsets = self.sampling_offsets(query).reshape(
            [bs, num_query, self.num_heads, self.num_levels, self.num_points, 2])
        # [bs, num_query, num_head, num_level*num_point]
        attention_weights = self.attention_weights(query).reshape(
            [bs, num_query, self.num_heads, self.num_levels * self.num_points])
        # [bs, num_query, num_head, num_level, num_point]
        attention_weights = attention_weights.softmax(-1).reshape(
            [bs, num_query, self.num_heads, self.num_levels, self.num_points])

        # sampling_locations: [bs, num_query, num_heads, num_levels, num_points, 2]
        if reference_points.shape[-1] == 2:
            # reference_points  [bs, num_query, num_level, 2] -> [bs, num_query, 1, num_level, 1, 2]
            # sampling_offsets  [bs, num_query, num_head, num_level, num_point, 2]
            # offset_normalizer [num_level, 2] -> [1, 1, 1, num_level, 1, 2]
            # sampling_locations = reference_points + normalized sampling_offsets
            offset_normalizer = value_spatial_shapes.flip([1]).reshape(
                [1, 1, 1, self.num_levels, 1, 2])
            sampling_locations = (
                reference_points[:, :, None, :, None, :]
                + sampling_offsets / offset_normalizer
            )
        elif reference_points.shape[-1] == 4:
            sampling_locations = (
                reference_points[:, :, None, :, None, :2]
                + sampling_offsets
                / self.num_points
                * reference_points[:, :, None, :, None, 2:]
                * 0.5)
        else:
            raise ValueError(
                "Last dim of reference_points must be 2 or 4, but got {} instead.".format(
                    reference_points.shape[-1]))

        # Multi-scale deformable attention
        output = self.ms_deformable_attn_core(
            value, value_spatial_shapes, sampling_locations, attention_weights)

        # Output projection
        output = self.output_proj(output)

        return output
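
# Illustrative call of MSDeformableAttention (a sketch only; the concrete sizes are
# assumptions, not values exercised elsewhere in this file):
#   attn = MSDeformableAttention(embed_dim=256, num_heads=8, num_levels=4, num_points=4)
#   query  : [bs, num_queries, 256]
#   refs   : [bs, num_queries, 4, 2] in [0, 1]
#   value  : [bs, sum(H_l * W_l), 256]
#   shapes : LongTensor of shape [4, 2] holding (H_l, W_l) per level
#   out = attn(query, refs, value, shapes)   # [bs, num_queries, 256]
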
# ----------------- Transformer modules -----------------
## Transformer Encoder layer
class TransformerEncoderLayer(nn.Module):
    def __init__(self,
                 d_model   :int   = 256,
                 num_heads :int   = 8,
                 ffn_dim   :int   = 1024,
                 dropout   :float = 0.1,
                 act_type  :str   = "relu",
                 ):
        super().__init__()
        # ----------- Basic parameters -----------
        self.d_model = d_model
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim
        self.act_type = act_type
        # ----------- Network parameters -----------
        # Multi-head Self-Attn
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)
        # Feedforward Network
        self.ffn = FFN(d_model, ffn_dim, dropout, act_type)

    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward(self, src, pos_embed):
        """
        Input:
            src:       [torch.Tensor] -> [B, N, C]
            pos_embed: [torch.Tensor] -> [B, N, C]
        Output:
            src:       [torch.Tensor] -> [B, N, C]
        """
        q = k = self.with_pos_embed(src, pos_embed)

        # -------------- MHSA --------------
        src2 = self.self_attn(q, k, value=src)[0]
        src = src + self.dropout(src2)
        src = self.norm(src)

        # -------------- FFN --------------
        src = self.ffn(src)

        return src


## Transformer Encoder
class TransformerEncoder(nn.Module):
    def __init__(self,
                 d_model        :int   = 256,
                 num_heads      :int   = 8,
                 num_layers     :int   = 1,
                 ffn_dim        :int   = 1024,
                 pe_temperature :float = 10000.,
                 dropout        :float = 0.1,
                 act_type       :str   = "relu",
                 ):
        super().__init__()
        # ----------- Basic parameters -----------
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.ffn_dim = ffn_dim
        self.dropout = dropout
        self.act_type = act_type
        self.pe_temperature = pe_temperature
        self.pos_embed = None
        # ----------- Network parameters -----------
        self.encoder_layers = get_clones(
            TransformerEncoderLayer(d_model, num_heads, ffn_dim, dropout, act_type), num_layers)

    def build_2d_sincos_position_embedding(self, device, w, h, embed_dim=256, temperature=10000.):
        assert embed_dim % 4 == 0, \
            'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
        # ----------- Check cached pos_embed -----------
        if self.pos_embed is not None and \
           self.pos_embed.shape[1] == int(w) * int(h):
            return self.pos_embed

        # ----------- Generate grid coords -----------
        grid_w = torch.arange(int(w), dtype=torch.float32)
        grid_h = torch.arange(int(h), dtype=torch.float32)
        grid_w, grid_h = torch.meshgrid([grid_w, grid_h], indexing="ij")  # shape: [W, H]

        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1. / (temperature ** omega)

        out_w = grid_w.flatten()[..., None] @ omega[None]  # shape: [N, pos_dim]
        out_h = grid_h.flatten()[..., None] @ omega[None]  # shape: [N, pos_dim]

        # shape: [1, N, C]
        pos_embed = torch.cat(
            [torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
        pos_embed = pos_embed.to(device)
        self.pos_embed = pos_embed

        return pos_embed

    def forward(self, src):
        """
        Input:
            src: [torch.Tensor] -> [B, C, H, W]
        Output:
            src: [torch.Tensor] -> [B, C, H, W]
        """
        # -------- Transformer encoder --------
        channels, fmp_h, fmp_w = src.shape[1:]
        # [B, C, H, W] -> [B, N, C], N = H x W
        src_flatten = src.flatten(2).permute(0, 2, 1)
        memory = src_flatten

        # PosEmbed: [1, N, C]
        pos_embed = self.build_2d_sincos_position_embedding(
            src.device, fmp_w, fmp_h, channels, self.pe_temperature)

        # Transformer Encoder layers
        for encoder in self.encoder_layers:
            memory = encoder(memory, pos_embed=pos_embed)

        # Output: [B, N, C] -> [B, C, N] -> [B, C, H, W]
        src = memory.permute(0, 2, 1).reshape([-1, channels, fmp_h, fmp_w])

        return src
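
# Note on the position embedding above: each of the W and H coordinates is mapped to
# embed_dim // 4 sine and embed_dim // 4 cosine channels, giving the
# [sin(w * omega), cos(w * omega), sin(h * omega), cos(h * omega)] layout of size
# embed_dim per token. The embedding is cached and only rebuilt when the number of
# tokens H * W changes.
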

## Transformer Decoder layer
class DeformableTransformerDecoderLayer(nn.Module):
    def __init__(self,
                 d_model    :int   = 256,
                 num_heads  :int   = 8,
                 num_levels :int   = 3,
                 num_points :int   = 4,
                 ffn_dim    :int   = 1024,
                 dropout    :float = 0.1,
                 act_type   :str   = "relu",
                 ):
        super().__init__()
        # ----------- Basic parameters -----------
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_levels = num_levels
        self.num_points = num_points
        self.ffn_dim = ffn_dim
        self.dropout = dropout
        self.act_type = act_type
        # ---------------- Network parameters ----------------
        ## Multi-head Self-Attn
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)
        ## Cross-Attention (multi-scale deformable)
        self.cross_attn = MSDeformableAttention(d_model, num_heads, num_levels, num_points)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)
        ## FFN
        self.ffn = FFN(d_model, ffn_dim, dropout, act_type)

    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward(self,
                tgt,
                reference_points,
                memory,
                memory_spatial_shapes,
                attn_mask=None,
                memory_mask=None,
                query_pos_embed=None):
        # ---------------- MHSA for Object Query -----------------
        q = k = self.with_pos_embed(tgt, query_pos_embed)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=attn_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # ---------------- CMHA for Object Query and Image-feature -----------------
        tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos_embed),
                               reference_points,
                               memory,
                               memory_spatial_shapes,
                               memory_mask)
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)

        # ---------------- FeedForward Network -----------------
        tgt = self.ffn(tgt)

        return tgt
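
# Data flow of the decoder layer above: object queries first attend to one another
# (self-attention, optionally masked by `attn_mask`), then sample multi-scale memory
# features around their reference points via deformable cross-attention, and finally
# pass through the FFN. Both attention stages use residual connections followed by
# LayerNorm.
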

## Transformer Decoder
class DeformableTransformerDecoder(nn.Module):
    def __init__(self,
                 d_model             :int   = 256,
                 num_heads           :int   = 8,
                 num_layers          :int   = 1,
                 num_levels          :int   = 3,
                 num_points          :int   = 4,
                 ffn_dim             :int   = 1024,
                 dropout             :float = 0.1,
                 act_type            :str   = "relu",
                 return_intermediate :bool  = False,
                 ):
        super().__init__()
        # ----------- Basic parameters -----------
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.ffn_dim = ffn_dim
        self.dropout = dropout
        self.act_type = act_type
        self.return_intermediate = return_intermediate
        # ----------- Network parameters -----------
        self.decoder_layers = get_clones(
            DeformableTransformerDecoderLayer(d_model, num_heads, num_levels, num_points, ffn_dim, dropout, act_type),
            num_layers)

    def forward(self,
                tgt,
                ref_points_unact,
                memory,
                memory_spatial_shapes,
                bbox_head,
                score_head,
                query_pos_head,
                attn_mask=None,
                memory_mask=None):
        output = tgt
        dec_out_bboxes = []
        dec_out_logits = []
        ref_points = None
        ref_points_detach = torch.sigmoid(ref_points_unact)
        for i, layer in enumerate(self.decoder_layers):
            ref_points_input = ref_points_detach.unsqueeze(2)
            query_pos_embed = query_pos_head(ref_points_detach)

            output = layer(output, ref_points_input, memory,
                           memory_spatial_shapes, attn_mask,
                           memory_mask, query_pos_embed)

            # refine the reference boxes in logit space, then squash back to [0, 1]
            inter_ref_bbox = torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points_detach))

            dec_out_logits.append(score_head[i](output))
            if i == 0:
                dec_out_bboxes.append(inter_ref_bbox)
            else:
                dec_out_bboxes.append(
                    torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points)))

            ref_points = inter_ref_bbox
            ref_points_detach = inter_ref_bbox.detach() if self.training else inter_ref_bbox

        return torch.stack(dec_out_bboxes), torch.stack(dec_out_logits)
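

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch only; the sizes below are arbitrary
    # assumptions, not values used elsewhere in this repository).
    device = torch.device("cpu")

    # 1) Encoder: [B, C, H, W] in -> [B, C, H, W] out.
    feat = torch.randn(2, 256, 20, 20, device=device)
    encoder = TransformerEncoder(d_model=256, num_heads=8, num_layers=1, ffn_dim=1024)
    print(encoder(feat).shape)  # expected: torch.Size([2, 256, 20, 20])

    # 2) Multi-scale deformable attention over two feature levels.
    spatial_shapes = torch.as_tensor([[20, 20], [10, 10]], dtype=torch.long)
    num_value = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())
    query = torch.randn(2, 100, 256)
    value = torch.randn(2, num_value, 256)
    refs = torch.rand(2, 100, 2, 2)  # [bs, num_queries, num_levels, 2] in [0, 1]
    attn = MSDeformableAttention(embed_dim=256, num_heads=8, num_levels=2, num_points=4)
    print(attn(query, refs, value, spatial_shapes).shape)  # expected: torch.Size([2, 100, 256])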