transformer.py
import math
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import constant_, xavier_uniform_, uniform_

try:
    from .basic import get_activation
except ImportError:
    from basic import get_activation


def get_clones(module, N):
    if N <= 0:
        return None
    else:
        return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def inverse_sigmoid(x, eps=1e-5):
    x = x.clamp(min=0., max=1.)
    return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))
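
# A minimal sanity-check sketch for inverse_sigmoid (illustrative only; the
# values below are arbitrary and this helper is not called anywhere in the module).
def _check_inverse_sigmoid():
    p = torch.tensor([0.1, 0.5, 0.9])
    logits = inverse_sigmoid(p)
    # sigmoid(inverse_sigmoid(p)) recovers p up to the clamping eps
    assert torch.allclose(torch.sigmoid(logits), p, atol=1e-4)
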
# ----------------- MLP modules -----------------
class MLP(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([in_dim] + h, h + [out_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


class FFN(nn.Module):
    def __init__(self, d_model=256, mlp_ratio=4.0, dropout=0., act_type='relu'):
        super().__init__()
        self.fpn_dim = round(d_model * mlp_ratio)
        self.linear1 = nn.Linear(d_model, self.fpn_dim)
        self.activation = get_activation(act_type)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(self.fpn_dim, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, src):
        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
        src = src + self.dropout3(src2)
        src = self.norm(src)
        return src
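
# Illustrative usage sketch for MLP and FFN. The shapes below are assumptions,
# not part of this module: both blocks operate on [B, N, C] token features.
def _example_mlp_ffn():
    x = torch.randn(2, 100, 256)                        # [B, N, C]
    bbox_mlp = MLP(in_dim=256, hidden_dim=256, out_dim=4, num_layers=3)
    ffn = FFN(d_model=256, mlp_ratio=4.0, dropout=0.1, act_type='relu')
    print(bbox_mlp(x).shape)                            # torch.Size([2, 100, 4])
    print(ffn(x).shape)                                 # torch.Size([2, 100, 256])
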
# ----------------- Basic Transformer Ops -----------------
def multi_scale_deformable_attn_pytorch(
    value: torch.Tensor,
    value_spatial_shapes: torch.Tensor,
    sampling_locations: torch.Tensor,
    attention_weights: torch.Tensor,
) -> torch.Tensor:
    bs, _, num_heads, embed_dims = value.shape
    _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
    sampling_grids = 2 * sampling_locations - 1
    sampling_value_list = []
    for level, (H_, W_) in enumerate(value_spatial_shapes):
        # bs, H_*W_, num_heads, embed_dims ->
        # bs, H_*W_, num_heads*embed_dims ->
        # bs, num_heads*embed_dims, H_*W_ ->
        # bs*num_heads, embed_dims, H_, W_
        value_l_ = (
            value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)
        )
        # bs, num_queries, num_heads, num_points, 2 ->
        # bs, num_heads, num_queries, num_points, 2 ->
        # bs*num_heads, num_queries, num_points, 2
        sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
        # bs*num_heads, embed_dims, num_queries, num_points
        sampling_value_l_ = F.grid_sample(
            value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
        )
        sampling_value_list.append(sampling_value_l_)
    # (bs, num_queries, num_heads, num_levels, num_points) ->
    # (bs, num_heads, num_queries, num_levels, num_points) ->
    # (bs*num_heads, 1, num_queries, num_levels*num_points)
    attention_weights = attention_weights.transpose(1, 2).reshape(
        bs * num_heads, 1, num_queries, num_levels * num_points
    )
    output = (
        (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
        .sum(-1)
        .view(bs, num_heads * embed_dims, num_queries)
    )
    return output.transpose(1, 2).contiguous()
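
# Shape-only sketch of the pure-PyTorch deformable attention kernel above.
# All sizes here are arbitrary assumptions chosen so that the shapes line up.
def _example_msda_core():
    bs, num_heads, head_dim = 2, 8, 32
    spatial_shapes = torch.tensor([[32, 32], [16, 16], [8, 8]])   # [L, 2] as (H, W)
    num_levels, num_points, num_queries = spatial_shapes.shape[0], 4, 300
    num_value = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())
    value = torch.randn(bs, num_value, num_heads, head_dim)
    sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2)
    attention_weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points)
    out = multi_scale_deformable_attn_pytorch(value, spatial_shapes, sampling_locations, attention_weights)
    print(out.shape)   # torch.Size([2, 300, 256]) -> [bs, num_queries, num_heads*head_dim]
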
class MSDeformableAttention(nn.Module):
    def __init__(self,
                 embed_dim=256,
                 num_heads=8,
                 num_levels=4,
                 num_points=4):
        """
        Multi-Scale Deformable Attention Module
        """
        super(MSDeformableAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_levels = num_levels
        self.num_points = num_points
        self.total_points = num_heads * num_levels * num_points
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        self.sampling_offsets = nn.Linear(embed_dim, self.total_points * 2)
        self.attention_weights = nn.Linear(embed_dim, self.total_points)
        self.value_proj = nn.Linear(embed_dim, embed_dim)
        self.output_proj = nn.Linear(embed_dim, embed_dim)

        try:
            # use the CUDA op if it is available
            from deformable_detr_ops import ms_deformable_attn
            self.ms_deformable_attn_core = ms_deformable_attn
        except ImportError:
            # fall back to the pure-PyTorch implementation
            self.ms_deformable_attn_core = multi_scale_deformable_attn_pytorch

        self._reset_parameters()

    def _reset_parameters(self):
        """
        Default initialization for Parameters of Module.
        """
        # Initialize the sampling offsets so that each head starts by looking in a
        # different direction, with points spread over increasing radii.
        constant_(self.sampling_offsets.weight.data, 0.0)
        thetas = torch.arange(self.num_heads, dtype=torch.float32) * (
            2.0 * math.pi / self.num_heads
        )
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (
            (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
            .view(self.num_heads, 1, 1, 2)
            .repeat(1, self.num_levels, self.num_points, 1)
        )
        for i in range(self.num_points):
            grid_init[:, :, i, :] *= i + 1
        with torch.no_grad():
            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
        constant_(self.attention_weights.weight.data, 0.0)
        constant_(self.attention_weights.bias.data, 0.0)
        xavier_uniform_(self.value_proj.weight.data)
        constant_(self.value_proj.bias.data, 0.0)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.0)
    def forward(self,
                query,
                reference_points,
                value,
                value_spatial_shapes,
                value_mask=None):
        """
        Args:
            query (Tensor): [bs, query_length, C]
            reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0, 0),
                bottom-right (1, 1), including padding area
            value (Tensor): [bs, value_length, C]
            value_spatial_shapes (Tensor): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
            value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements
        Returns:
            output (Tensor): [bs, query_length, C]
        """
        bs, num_query = query.shape[:2]
        num_value = value.shape[1]
        assert sum([s[0] * s[1] for s in value_spatial_shapes]) == num_value

        # Value projection
        value = self.value_proj(value)
        # Zero out the padding part of the value features
        if value_mask is not None:
            value_mask = value_mask.to(value.dtype).unsqueeze(-1)
            value *= value_mask
        # [bs, all_hw, C] -> [bs, all_hw, num_heads, head_dim]
        value = value.reshape([bs, num_value, self.num_heads, -1])

        # [bs, num_query, num_heads, num_levels, num_points, 2]
        sampling_offsets = self.sampling_offsets(query).reshape(
            [bs, num_query, self.num_heads, self.num_levels, self.num_points, 2])
        # [bs, num_query, num_heads, num_levels*num_points]
        attention_weights = self.attention_weights(query).reshape(
            [bs, num_query, self.num_heads, self.num_levels * self.num_points])
        attention_weights = attention_weights.softmax(-1)
        # [bs, num_query, num_heads, num_levels, num_points]
        attention_weights = attention_weights.reshape(
            [bs, num_query, self.num_heads, self.num_levels, self.num_points])

        # Sampling locations: [bs, num_query, num_heads, num_levels, num_points, 2]
        if reference_points.shape[-1] == 2:
            # reference_points:  [bs, num_query, num_levels, 2] -> [bs, num_query, 1, num_levels, 1, 2]
            # sampling_offsets:  [bs, num_query, num_heads, num_levels, num_points, 2]
            # offset_normalizer: [num_levels, 2] as (W, H) -> [1, 1, 1, num_levels, 1, 2]
            offset_normalizer = value_spatial_shapes.flip([1]).reshape(
                [1, 1, 1, self.num_levels, 1, 2])
            sampling_locations = (
                reference_points[:, :, None, :, None, :]
                + sampling_offsets / offset_normalizer
            )
        elif reference_points.shape[-1] == 4:
            # reference_points are (cx, cy, w, h); offsets are scaled by half the box size
            sampling_locations = (
                reference_points[:, :, None, :, None, :2]
                + sampling_offsets
                / self.num_points
                * reference_points[:, :, None, :, None, 2:]
                * 0.5)
        else:
            raise ValueError(
                "Last dim of reference_points must be 2 or 4, but got {} instead.".format(
                    reference_points.shape[-1]))

        # Multi-scale deformable attention
        output = self.ms_deformable_attn_core(
            value, value_spatial_shapes, sampling_locations, attention_weights)
        # Output projection
        output = self.output_proj(output)
        return output
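
# Usage sketch for MSDeformableAttention (shapes are illustrative assumptions;
# num_levels must match the number of rows in spatial_shapes).
def _example_msdeformable_attention():
    attn = MSDeformableAttention(embed_dim=256, num_heads=8, num_levels=3, num_points=4)
    spatial_shapes = torch.tensor([[32, 32], [16, 16], [8, 8]])     # [L, 2] as (H, W)
    num_value = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())
    query = torch.randn(2, 300, 256)                                # [bs, num_query, C]
    value = torch.randn(2, num_value, 256)                          # [bs, num_value, C]
    reference_points = torch.rand(2, 300, 3, 2)                     # [bs, num_query, L, 2] in [0, 1]
    out = attn(query, reference_points, value, spatial_shapes)
    print(out.shape)                                                # torch.Size([2, 300, 256])
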
# ----------------- Transformer modules -----------------
## Transformer Encoder layer
class TransformerEncoderLayer(nn.Module):
    def __init__(self,
                 d_model   :int   = 256,
                 num_heads :int   = 8,
                 mlp_ratio :float = 4.0,
                 dropout   :float = 0.1,
                 act_type  :str   = "relu",
                 ):
        super().__init__()
        # ----------- Basic parameters -----------
        self.d_model = d_model
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.act_type = act_type
        # ----------- Network parameters -----------
        # Multi-head Self-Attention
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)
        # Feedforward Network
        self.ffn = FFN(d_model, mlp_ratio, dropout, act_type)

        self._reset_parameters()

    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos

    def _reset_parameters(self):
        def linear_init_(module):
            bound = 1 / math.sqrt(module.weight.shape[0])
            uniform_(module.weight, -bound, bound)
            if hasattr(module, "bias") and module.bias is not None:
                uniform_(module.bias, -bound, bound)
        linear_init_(self.ffn.linear1)
        linear_init_(self.ffn.linear2)

    def forward(self, src, pos_embed):
        """
        Input:
            src: [torch.Tensor] -> [B, N, C]
            pos_embed: [torch.Tensor] -> [B, N, C]
        Output:
            src: [torch.Tensor] -> [B, N, C]
        """
        q = k = self.with_pos_embed(src, pos_embed)
        # -------------- MHSA --------------
        src2 = self.self_attn(q, k, value=src)[0]
        src = src + self.dropout(src2)
        src = self.norm(src)
        # -------------- FFN --------------
        src = self.ffn(src)
        return src
## Transformer Encoder
class TransformerEncoder(nn.Module):
    def __init__(self,
                 d_model        :int   = 256,
                 num_heads      :int   = 8,
                 num_layers     :int   = 1,
                 mlp_ratio      :float = 4.0,
                 pe_temperature :float = 10000.,
                 dropout        :float = 0.1,
                 act_type       :str   = "relu",
                 ):
        super().__init__()
        # ----------- Basic parameters -----------
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout
        self.act_type = act_type
        self.pe_temperature = pe_temperature
        self.pos_embed = None
        # ----------- Network parameters -----------
        self.encoder_layers = get_clones(
            TransformerEncoderLayer(d_model, num_heads, mlp_ratio, dropout, act_type), num_layers)

    def build_2d_sincos_position_embedding(self, device, w, h, embed_dim=256, temperature=10000.):
        assert embed_dim % 4 == 0, \
            'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
        # ----------- Check cached pos_embed -----------
        if self.pos_embed is not None and \
           self.pos_embed.shape[1] == int(w) * int(h):
            return self.pos_embed
        # ----------- Generate grid coords -----------
        grid_w = torch.arange(int(w), dtype=torch.float32)
        grid_h = torch.arange(int(h), dtype=torch.float32)
        grid_w, grid_h = torch.meshgrid([grid_w, grid_h])  # shape: [W, H]
        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1. / (temperature ** omega)
        out_w = grid_w.flatten()[..., None] @ omega[None]  # shape: [N, pos_dim]
        out_h = grid_h.flatten()[..., None] @ omega[None]  # shape: [N, pos_dim]
        # shape: [1, N, C]
        pos_embed = torch.cat([torch.sin(out_w), torch.cos(out_w),
                               torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
        pos_embed = pos_embed.to(device)
        self.pos_embed = pos_embed
        return pos_embed

    def forward(self, src):
        """
        Input:
            src: [torch.Tensor] -> [B, C, H, W]
        Output:
            src: [torch.Tensor] -> [B, C, H, W]
        """
        # -------- Transformer encoder --------
        channels, fmp_h, fmp_w = src.shape[1:]
        # [B, C, H, W] -> [B, N, C], N = HxW
        src_flatten = src.flatten(2).permute(0, 2, 1)
        memory = src_flatten
        # PosEmbed: [1, N, C]
        pos_embed = self.build_2d_sincos_position_embedding(
            src.device, fmp_w, fmp_h, channels, self.pe_temperature)
        # Transformer Encoder layers
        for encoder in self.encoder_layers:
            memory = encoder(memory, pos_embed=pos_embed)
        # Output: [B, N, C] -> [B, C, N] -> [B, C, H, W]
        src = memory.permute(0, 2, 1).reshape([-1, channels, fmp_h, fmp_w])
        return src
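
# Usage sketch for TransformerEncoder on a single feature map (sizes are
# illustrative assumptions; the module keeps the [B, C, H, W] layout).
def _example_transformer_encoder():
    encoder = TransformerEncoder(d_model=256, num_heads=8, num_layers=1)
    feat = torch.randn(2, 256, 20, 20)      # [B, C, H, W]
    out = encoder(feat)
    print(out.shape)                        # torch.Size([2, 256, 20, 20])
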
## Transformer Decoder layer
class DeformableTransformerDecoderLayer(nn.Module):
    def __init__(self,
                 d_model    :int   = 256,
                 num_heads  :int   = 8,
                 num_levels :int   = 3,
                 num_points :int   = 4,
                 mlp_ratio  :float = 4.0,
                 dropout    :float = 0.1,
                 act_type   :str   = "relu",
                 ):
        super().__init__()
        # ----------- Basic parameters -----------
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_levels = num_levels
        self.num_points = num_points
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout
        self.act_type = act_type
        # ---------------- Network parameters ----------------
        ## Multi-head Self-Attention
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)
        ## Cross-Attention (multi-scale deformable attention)
        self.cross_attn = MSDeformableAttention(d_model, num_heads, num_levels, num_points)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)
        ## FFN
        self.ffn = FFN(d_model, mlp_ratio, dropout, act_type)

        self._reset_parameters()

    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos

    def _reset_parameters(self):
        def linear_init_(module):
            bound = 1 / math.sqrt(module.weight.shape[0])
            uniform_(module.weight, -bound, bound)
            if hasattr(module, "bias") and module.bias is not None:
                uniform_(module.bias, -bound, bound)
        linear_init_(self.ffn.linear1)
        linear_init_(self.ffn.linear2)
        xavier_uniform_(self.ffn.linear1.weight)
        xavier_uniform_(self.ffn.linear2.weight)

    def forward(self,
                tgt,
                reference_points,
                memory,
                memory_spatial_shapes,
                attn_mask=None,
                memory_mask=None,
                query_pos_embed=None):
        # ---------------- MHSA for Object Queries -----------------
        q = k = self.with_pos_embed(tgt, query_pos_embed)
        if attn_mask is not None:
            # Convert a boolean mask (True = attend) into an additive float mask
            attn_mask = torch.where(
                attn_mask.bool(),
                torch.zeros(attn_mask.shape, dtype=tgt.dtype, device=attn_mask.device),
                torch.full(attn_mask.shape, float("-inf"), dtype=tgt.dtype, device=attn_mask.device))
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=attn_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        # ---------------- CMHA between Object Queries and Image Features -----------------
        tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos_embed),
                               reference_points,
                               memory,
                               memory_spatial_shapes,
                               memory_mask)
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        # ---------------- FeedForward Network -----------------
        tgt = self.ffn(tgt)
        return tgt
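
# Usage sketch for DeformableTransformerDecoderLayer (illustrative shapes; memory
# is a flattened multi-scale feature map, tgt holds the object queries).
def _example_decoder_layer():
    layer = DeformableTransformerDecoderLayer(d_model=256, num_heads=8, num_levels=3, num_points=4)
    spatial_shapes = torch.tensor([[32, 32], [16, 16], [8, 8]])
    num_value = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())
    tgt = torch.randn(2, 300, 256)                     # [bs, num_queries, C]
    memory = torch.randn(2, num_value, 256)            # [bs, sum(H*W), C]
    ref_points = torch.rand(2, 300, 3, 2)              # [bs, num_queries, L, 2] in [0, 1]
    out = layer(tgt, ref_points, memory, spatial_shapes)
    print(out.shape)                                   # torch.Size([2, 300, 256])
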
## Transformer Decoder
class DeformableTransformerDecoder(nn.Module):
    def __init__(self,
                 d_model             :int   = 256,
                 num_heads           :int   = 8,
                 num_layers          :int   = 1,
                 num_levels          :int   = 3,
                 num_points          :int   = 4,
                 mlp_ratio           :float = 4.0,
                 dropout             :float = 0.1,
                 act_type            :str   = "relu",
                 return_intermediate :bool  = False,
                 ):
        super().__init__()
        # ----------- Basic parameters -----------
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout
        self.act_type = act_type
        self.return_intermediate = return_intermediate
        # ----------- Network parameters -----------
        self.decoder_layers = get_clones(
            DeformableTransformerDecoderLayer(d_model, num_heads, num_levels, num_points, mlp_ratio, dropout, act_type),
            num_layers)

    def forward(self,
                tgt,
                ref_points_unact,
                memory,
                memory_spatial_shapes,
                bbox_head,
                score_head,
                query_pos_head,
                attn_mask=None,
                memory_mask=None):
        output = tgt
        dec_out_bboxes = []
        dec_out_logits = []
        ref_points = None
        ref_points_detach = torch.sigmoid(ref_points_unact)
        for i, layer in enumerate(self.decoder_layers):
            # [bs, num_queries, 4] -> [bs, num_queries, 1, 4], shared across feature levels
            ref_points_input = ref_points_detach.unsqueeze(2)
            query_pos_embed = query_pos_head(ref_points_detach)

            output = layer(output, ref_points_input, memory,
                           memory_spatial_shapes, attn_mask,
                           memory_mask, query_pos_embed)

            # Iterative box refinement on top of the detached reference points
            inter_ref_bbox = torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points_detach))

            dec_out_logits.append(score_head[i](output))
            if i == 0:
                dec_out_bboxes.append(inter_ref_bbox)
            else:
                dec_out_bboxes.append(
                    torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points)))

            ref_points = inter_ref_bbox
            ref_points_detach = inter_ref_bbox.detach()

        return torch.stack(dec_out_bboxes), torch.stack(dec_out_logits)
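
# End-to-end usage sketch for DeformableTransformerDecoder. The bbox/score/query-pos
# heads below are illustrative stand-ins (one bbox/score head per decoder layer),
# not part of this file; all sizes are assumptions.
def _example_decoder():
    num_layers, num_queries, num_classes = 3, 300, 80
    decoder = DeformableTransformerDecoder(d_model=256, num_heads=8, num_layers=num_layers,
                                           num_levels=3, num_points=4)
    bbox_head = nn.ModuleList([MLP(256, 256, 4, 3) for _ in range(num_layers)])
    score_head = nn.ModuleList([nn.Linear(256, num_classes) for _ in range(num_layers)])
    query_pos_head = MLP(4, 256, 256, 2)
    spatial_shapes = torch.tensor([[32, 32], [16, 16], [8, 8]])
    num_value = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())
    tgt = torch.randn(2, num_queries, 256)
    memory = torch.randn(2, num_value, 256)
    ref_points_unact = torch.randn(2, num_queries, 4)    # unactivated (cx, cy, w, h)
    boxes, logits = decoder(tgt, ref_points_unact, memory, spatial_shapes,
                            bbox_head, score_head, query_pos_head)
    print(boxes.shape, logits.shape)   # [num_layers, 2, 300, 4], [num_layers, 2, 300, 80]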