attn.py

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


# ----------------- BoxRPM Cross Attention Ops -----------------
class GlobalCrossAttention(nn.Module):
    """Single-scale global cross attention with a box-conditioned relative
    position bias: two small MLPs turn the distances between each reference
    box's edges and the feature-map grid into per-head attention biases."""

    def __init__(
        self,
        dim: int = 256,
        num_heads: int = 8,
        qkv_bias: bool = True,
        qk_scale: float = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        rpe_hidden_dim: int = 512,
        feature_stride: int = 16,
    ):
        super().__init__()
        # --------- Basic parameters ---------
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.feature_stride = feature_stride

        # --------- Network parameters ---------
        self.cpb_mlp1 = self.build_cpb_mlp(2, rpe_hidden_dim, num_heads)  # x-axis bias MLP
        self.cpb_mlp2 = self.build_cpb_mlp(2, rpe_hidden_dim, num_heads)  # y-axis bias MLP
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def build_cpb_mlp(self, in_dim, hidden_dim, out_dim):
        cpb_mlp = nn.Sequential(nn.Linear(in_dim, hidden_dim, bias=True),
                                nn.ReLU(inplace=True),
                                nn.Linear(hidden_dim, out_dim, bias=False))
        return cpb_mlp

    def forward(self,
                query,
                reference_points,
                k_input_flatten,
                v_input_flatten,
                input_spatial_shapes,
                input_padding_mask=None,
                ):
        assert input_spatial_shapes.size(0) == 1, 'This is designed for single-scale decoder.'
        h, w = input_spatial_shapes[0]
        stride = self.feature_stride

        # Convert (cx, cy, w, h) reference points to (x1, y1, x2, y2) corners.
        ref_pts = torch.cat([
            reference_points[:, :, :, :2] - reference_points[:, :, :, 2:] / 2,
            reference_points[:, :, :, :2] + reference_points[:, :, :, 2:] / 2,
        ], dim=-1)  # B, nQ, 1, 4
        # Pixel-center coordinates of the feature-map grid, scaled by the feature stride.
        pos_x = torch.linspace(0.5, w - 0.5, w, dtype=torch.float32, device=w.device)[None, None, :, None] * stride  # 1, 1, w, 1
        pos_y = torch.linspace(0.5, h - 0.5, h, dtype=torch.float32, device=h.device)[None, None, :, None] * stride  # 1, 1, h, 1

        # Signed distances from each column/row center to the box's left/right and top/bottom edges.
        delta_x = ref_pts[..., 0::2] - pos_x  # B, nQ, w, 2
        delta_y = ref_pts[..., 1::2] - pos_y  # B, nQ, h, 2

        # Per-axis biases, summed into a per-head bias over all h*w key positions.
        rpe_x, rpe_y = self.cpb_mlp1(delta_x), self.cpb_mlp2(delta_y)  # B, nQ, w/h, nheads
        rpe = (rpe_x[:, :, None] + rpe_y[:, :, :, None]).flatten(2, 3)  # B, nQ, h, w, nheads -> B, nQ, h*w, nheads
        rpe = rpe.permute(0, 3, 1, 2)  # B, nheads, nQ, h*w

        B_, N, C = k_input_flatten.shape
        k = self.k(k_input_flatten).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        v = self.v(v_input_flatten).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        B_, N, C = query.shape
        q = self.q(query).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        q = q * self.scale

        attn = q @ k.transpose(-2, -1)
        attn += rpe
        if input_padding_mask is not None:
            attn += input_padding_mask[:, None, None] * -100
        # Clamp to the finite range of the dtype to guard against overflow before softmax.
        fmin, fmax = torch.finfo(attn.dtype).min, torch.finfo(attn.dtype).max
        torch.clip_(attn, min=fmin, max=fmax)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        x = attn @ v
        x = x.transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x
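

# Minimal usage sketch (not part of the original module). The tensor shapes and the
# pixel-unit box convention below are assumptions inferred from the forward pass:
# pos_x / pos_y are in input-image pixels, so the illustrative reference boxes are
# drawn in pixel units to keep the edge deltas on a comparable scale.
if __name__ == "__main__":
    B, nQ, dim, h, w = 2, 100, 256, 32, 32

    attn = GlobalCrossAttention(dim=dim, num_heads=8, feature_stride=16)

    query = torch.randn(B, nQ, dim)                       # decoder queries
    reference_points = torch.rand(B, nQ, 1, 4) * (h * 16) # one (cx, cy, w, h) box per query, pixel units (assumed)
    memory = torch.randn(B, h * w, dim)                   # flattened single-scale encoder features
    spatial_shapes = torch.as_tensor([[h, w]])            # single scale -> size(0) == 1
    padding_mask = torch.zeros(B, h * w, dtype=torch.bool)

    out = attn(query, reference_points, memory, memory, spatial_shapes, padding_mask)
    print(out.shape)  # expected: torch.Size([2, 100, 256])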