# transformer.py

import math
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint

try:
    from .basic import FFN, GlobalCrossAttention
    from .basic import trunc_normal_
except ImportError:
    from basic import FFN, GlobalCrossAttention
    from basic import trunc_normal_


def get_clones(module, N):
    if N <= 0:
        return None
    else:
        return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def inverse_sigmoid(x, eps=1e-5):
    x = x.clamp(min=0., max=1.)
    return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))
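
# Illustrative sketch (added for clarity; `_check_inverse_sigmoid` is not part of the
# original module): inverse_sigmoid undoes torch.sigmoid up to the eps clamping,
# which is how normalized box coordinates are mapped back to logits.
def _check_inverse_sigmoid():
    x = torch.tensor([-4.0, 0.0, 4.0])
    assert torch.allclose(inverse_sigmoid(torch.sigmoid(x)), x, atol=1e-3)
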
# ----------------- Transformer modules -----------------
## Transformer Encoder layer
class TransformerEncoderLayer(nn.Module):
    def __init__(self,
                 d_model   :int   = 256,
                 num_heads :int   = 8,
                 ffn_dim   :int   = 1024,
                 dropout   :float = 0.1,
                 act_type  :str   = "relu",
                 ):
        super().__init__()
        # ----------- Basic parameters -----------
        self.d_model = d_model
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim
        self.dropout = dropout
        self.act_type = act_type
        # ----------- Network parameters -----------
        # Multi-head Self-Attn
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)
        # Feedforward Network
        self.ffn = FFN(d_model, ffn_dim, dropout, act_type)

    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward(self, src, pos_embed):
        """
        Input:
            src:       [torch.Tensor] -> [B, N, C]
            pos_embed: [torch.Tensor] -> [B, N, C]
        Output:
            src:       [torch.Tensor] -> [B, N, C]
        """
        q = k = self.with_pos_embed(src, pos_embed)
        # -------------- MHSA --------------
        src2 = self.self_attn(q, k, value=src)[0]
        src = src + self.dropout(src2)
        src = self.norm(src)
        # -------------- FFN --------------
        src = self.ffn(src)
        return src

## Transformer Encoder
class TransformerEncoder(nn.Module):
    def __init__(self,
                 d_model        :int   = 256,
                 num_heads      :int   = 8,
                 num_layers     :int   = 1,
                 ffn_dim        :int   = 1024,
                 pe_temperature :float = 10000.,
                 dropout        :float = 0.1,
                 act_type       :str   = "relu",
                 ):
        super().__init__()
        # ----------- Basic parameters -----------
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.ffn_dim = ffn_dim
        self.dropout = dropout
        self.act_type = act_type
        self.pe_temperature = pe_temperature
        self.pos_embed = None
        # ----------- Network parameters -----------
        self.encoder_layers = get_clones(
            TransformerEncoderLayer(d_model, num_heads, ffn_dim, dropout, act_type), num_layers)

    def build_2d_sincos_position_embedding(self, device, w, h, embed_dim=256, temperature=10000.):
        assert embed_dim % 4 == 0, \
            'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
        # ----------- Check cached pos_embed -----------
        # The cache holds a [1, N, C] embedding, so compare its token count against h*w.
        if self.pos_embed is not None and \
           self.pos_embed.shape[1] == int(h) * int(w):
            return self.pos_embed.to(device)
        # ----------- Generate grid coords -----------
        grid_w = torch.arange(int(w), dtype=torch.float32)
        grid_h = torch.arange(int(h), dtype=torch.float32)
        grid_w, grid_h = torch.meshgrid([grid_w, grid_h])  # shape: [W, H]
        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1. / (temperature ** omega)
        out_w = grid_w.flatten()[..., None] @ omega[None]  # shape: [N, C//4]
        out_h = grid_h.flatten()[..., None] @ omega[None]  # shape: [N, C//4]
        # shape: [1, N, C]
        pos_embed = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
        pos_embed = pos_embed.to(device)
        self.pos_embed = pos_embed
        return pos_embed

    def forward(self, src):
        """
        Input:
            src: [torch.Tensor] -> [B, C, H, W]
        Output:
            src: [torch.Tensor] -> [B, C, H, W]
        """
        # -------- Transformer encoder --------
        channels, fmp_h, fmp_w = src.shape[1:]
        # [B, C, H, W] -> [B, N, C], N=HxW
        src_flatten = src.flatten(2).permute(0, 2, 1)
        memory = src_flatten
        # PosEmbed: [1, N, C]
        pos_embed = self.build_2d_sincos_position_embedding(
            src.device, fmp_w, fmp_h, channels, self.pe_temperature)
        # Transformer Encoder layer
        for encoder in self.encoder_layers:
            memory = encoder(memory, pos_embed=pos_embed)
        # Output: [B, N, C] -> [B, C, N] -> [B, C, H, W]
        src = memory.permute(0, 2, 1).reshape([-1, channels, fmp_h, fmp_w])
        return src

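
# Illustrative sketch (added for clarity; `_demo_encoder` is not part of the original
# file and assumes the sibling `basic` module providing FFN is importable): the encoder
# consumes a [B, C, H, W] feature map and returns a tensor of the same shape, building
# the 2D sin-cos position embedding lazily from the feature map size.
def _demo_encoder():
    encoder = TransformerEncoder(d_model=256, num_heads=8, num_layers=1, ffn_dim=1024)
    feat = torch.randn(2, 256, 8, 8)  # [B, C, H, W], C must equal d_model
    out = encoder(feat)               # -> [2, 256, 8, 8]
    return out
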
## PlainDETR's Decoder layer
class GlobalDecoderLayer(nn.Module):
    def __init__(self,
                 d_model        :int   = 256,
                 num_heads      :int   = 8,
                 ffn_dim        :int   = 1024,
                 dropout        :float = 0.1,
                 act_type       :str   = "relu",
                 pre_norm       :bool  = False,
                 rpe_hidden_dim :int   = 512,
                 feature_stride :int   = 16,
                 ) -> None:
        super().__init__()
        # ------------ Basic parameters ------------
        self.d_model = d_model
        self.num_heads = num_heads
        self.rpe_hidden_dim = rpe_hidden_dim
        self.ffn_dim = ffn_dim
        self.act_type = act_type
        self.pre_norm = pre_norm
        # ------------ Network parameters ------------
        ## Multi-head Self-Attn
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)
        ## Box-reparam Global Cross-Attn
        self.cross_attn = GlobalCrossAttention(d_model, num_heads, rpe_hidden_dim=rpe_hidden_dim, feature_stride=feature_stride)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)
        ## FFN
        self.ffn = FFN(d_model, ffn_dim, dropout, act_type, pre_norm)

    @staticmethod
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_pre_norm(self,
                         tgt,
                         query_pos,
                         reference_points,
                         src,
                         src_pos_embed,
                         src_spatial_shapes,
                         src_padding_mask=None,
                         self_attn_mask=None,
                         ):
        # ----------- Multi-head self attention -----------
        tgt1 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt1, query_pos)
        tgt1 = self.self_attn(q.transpose(0, 1),     # [B, N, C] -> [N, B, C], batch_first = False
                              k.transpose(0, 1),     # [B, N, C] -> [N, B, C], batch_first = False
                              tgt1.transpose(0, 1),  # [B, N, C] -> [N, B, C], batch_first = False
                              attn_mask=self_attn_mask,
                              )[0].transpose(0, 1)   # [N, B, C] -> [B, N, C]
        tgt = tgt + self.dropout1(tgt1)
        # ----------- Global cross attention -----------
        tgt1 = self.norm2(tgt)
        tgt1 = self.cross_attn(self.with_pos_embed(tgt1, query_pos),
                               reference_points,
                               self.with_pos_embed(src, src_pos_embed),
                               src,
                               src_spatial_shapes,
                               src_padding_mask,
                               )
        tgt = tgt + self.dropout2(tgt1)
        # ----------- FeedForward Network -----------
        tgt = self.ffn(tgt)
        return tgt

    def forward_post_norm(self,
                          tgt,
                          query_pos,
                          reference_points,
                          src,
                          src_pos_embed,
                          src_spatial_shapes,
                          src_padding_mask=None,
                          self_attn_mask=None,
                          ):
        # ----------- Multi-head self attention -----------
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt1 = self.self_attn(q.transpose(0, 1),    # [B, N, C] -> [N, B, C], batch_first = False
                              k.transpose(0, 1),    # [B, N, C] -> [N, B, C], batch_first = False
                              tgt.transpose(0, 1),  # [B, N, C] -> [N, B, C], batch_first = False
                              attn_mask=self_attn_mask,
                              )[0].transpose(0, 1)  # [N, B, C] -> [B, N, C]
        tgt = tgt + self.dropout1(tgt1)
        tgt = self.norm1(tgt)
        # ----------- Global cross attention -----------
        tgt1 = self.cross_attn(self.with_pos_embed(tgt, query_pos),
                               reference_points,
                               self.with_pos_embed(src, src_pos_embed),
                               src,
                               src_spatial_shapes,
                               src_padding_mask,
                               )
        tgt = tgt + self.dropout2(tgt1)
        tgt = self.norm2(tgt)
        # ----------- FeedForward Network -----------
        tgt = self.ffn(tgt)
        return tgt

    def forward(self,
                tgt,
                query_pos,
                reference_points,
                src,
                src_pos_embed,
                src_spatial_shapes,
                src_padding_mask=None,
                self_attn_mask=None,
                ):
        if self.pre_norm:
            return self.forward_pre_norm(tgt, query_pos, reference_points, src, src_pos_embed, src_spatial_shapes,
                                         src_padding_mask, self_attn_mask)
        else:
            return self.forward_post_norm(tgt, query_pos, reference_points, src, src_pos_embed, src_spatial_shapes,
                                          src_padding_mask, self_attn_mask)

## PlainDETR's Decoder
class GlobalDecoder(nn.Module):
    def __init__(self,
                 # Decoder layer params
                 d_model             :int   = 256,
                 num_heads           :int   = 8,
                 ffn_dim             :int   = 1024,
                 dropout             :float = 0.1,
                 act_type            :str   = "relu",
                 pre_norm            :bool  = False,
                 rpe_hidden_dim      :int   = 512,
                 feature_stride      :int   = 16,
                 num_layers          :int   = 6,
                 # Decoder params
                 return_intermediate :bool  = False,
                 use_checkpoint      :bool  = False,
                 ):
        super().__init__()
        # ------------ Basic parameters ------------
        self.d_model = d_model
        self.num_heads = num_heads
        self.rpe_hidden_dim = rpe_hidden_dim
        self.ffn_dim = ffn_dim
        self.act_type = act_type
        self.num_layers = num_layers
        self.return_intermediate = return_intermediate
        self.use_checkpoint = use_checkpoint
        # ------------ Network parameters ------------
        decoder_layer = GlobalDecoderLayer(
            d_model, num_heads, ffn_dim, dropout, act_type, pre_norm, rpe_hidden_dim, feature_stride,)
        self.layers = get_clones(decoder_layer, num_layers)
        self.bbox_embed = None
        self.class_embed = None
        if pre_norm:
            self.final_layer_norm = nn.LayerNorm(d_model)
        else:
            self.final_layer_norm = None

    def _reset_parameters(self):
        # stolen from Swin Transformer
        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=0.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        self.apply(_init_weights)

    def inverse_sigmoid(self, x, eps=1e-5):
        x = x.clamp(min=0, max=1)
        x1 = x.clamp(min=eps)
        x2 = (1 - x).clamp(min=eps)
        return torch.log(x1 / x2)

    def box_xyxy_to_cxcywh(self, x):
        x0, y0, x1, y1 = x.unbind(-1)
        b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
        return torch.stack(b, dim=-1)

    def delta2bbox(self, proposals,
                   deltas,
                   max_shape=None,
                   wh_ratio_clip=16 / 1000,
                   clip_border=True,
                   add_ctr_clamp=False,
                   ctr_clamp=32):
        dxy = deltas[..., :2]
        dwh = deltas[..., 2:]
        # Center (cx, cy) and size (w, h) of each proposal
        pxy = proposals[..., :2]
        pwh = proposals[..., 2:]
        dxy_wh = pwh * dxy
        max_ratio = abs(math.log(wh_ratio_clip))
        if add_ctr_clamp:
            dxy_wh = torch.clamp(dxy_wh, max=ctr_clamp, min=-ctr_clamp)
            dwh = torch.clamp(dwh, max=max_ratio)
        else:
            dwh = dwh.clamp(min=-max_ratio, max=max_ratio)
        gxy = pxy + dxy_wh
        gwh = pwh * dwh.exp()
        x1y1 = gxy - (gwh * 0.5)
        x2y2 = gxy + (gwh * 0.5)
        bboxes = torch.cat([x1y1, x2y2], dim=-1)
        if clip_border and max_shape is not None:
            bboxes[..., 0::2].clamp_(min=0).clamp_(max=max_shape[1])
            bboxes[..., 1::2].clamp_(min=0).clamp_(max=max_shape[0])
        return bboxes

    def forward(self,
                tgt,
                reference_points,
                src,
                src_pos_embed,
                src_spatial_shapes,
                query_pos=None,
                src_padding_mask=None,
                self_attn_mask=None,
                max_shape=None,
                ):
        output = tgt
        intermediate = []
        intermediate_reference_points = []
        for lid, layer in enumerate(self.layers):
            reference_points_input = reference_points[:, :, None]
            if self.use_checkpoint:
                output = checkpoint.checkpoint(
                    layer,
                    output,
                    query_pos,
                    reference_points_input,
                    src,
                    src_pos_embed,
                    src_spatial_shapes,
                    src_padding_mask,
                    self_attn_mask,
                )
            else:
                output = layer(
                    output,
                    query_pos,
                    reference_points_input,
                    src,
                    src_pos_embed,
                    src_spatial_shapes,
                    src_padding_mask,
                    self_attn_mask,
                )
            if self.final_layer_norm is not None:
                output_after_norm = self.final_layer_norm(output)
            else:
                output_after_norm = output
            # hack implementation for iterative bounding box refinement
            if self.bbox_embed is not None:
                tmp = self.bbox_embed[lid](output_after_norm)
                new_reference_points = self.box_xyxy_to_cxcywh(
                    self.delta2bbox(reference_points, tmp, max_shape))
                reference_points = new_reference_points.detach()
            if self.return_intermediate:
                intermediate.append(output_after_norm)
                # Fall back to the un-refined reference points if no bbox_embed is attached.
                intermediate_reference_points.append(
                    new_reference_points if self.bbox_embed is not None else reference_points)
        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(intermediate_reference_points)
        return output_after_norm, reference_points
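
# Illustrative sketch (added for clarity; `_demo_box_refine` is not part of the original
# file): delta2bbox interprets each proposal as (cx, cy, w, h), applies the predicted
# deltas, and returns corner-format (x1, y1, x2, y2) boxes; box_xyxy_to_cxcywh converts
# them back. With zero deltas the round trip reproduces the proposal. `decoder` is any
# constructed GlobalDecoder instance.
def _demo_box_refine(decoder):
    proposals = torch.tensor([[10., 10., 20., 40.]])    # (cx, cy, w, h)
    deltas = torch.zeros_like(proposals)                # no refinement
    boxes_xyxy = decoder.delta2bbox(proposals, deltas)  # -> [[0., -10., 20., 30.]]
    return decoder.box_xyxy_to_cxcywh(boxes_xyxy)       # -> [[10., 10., 20., 40.]]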