import math
import torch
import torch.nn as nn
import torch.nn.functional as F


# ----------------- BoxRPM Cross Attention Ops -----------------
class GlobalCrossAttention(nn.Module):
    def __init__(
        self,
        dim: int = 256,
        num_heads: int = 8,
        qkv_bias: bool = True,
        qk_scale: float = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        rpe_hidden_dim: int = 512,
        feature_stride: int = 16,
    ):
        super().__init__()
        # --------- Basic parameters ---------
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.feature_stride = feature_stride

        # --------- Network parameters ---------
        # Two continuous-position-bias MLPs: one maps x-offsets and one maps
        # y-offsets between reference-box edges and pixel centers to per-head bias terms.
        self.cpb_mlp1 = self.build_cpb_mlp(2, rpe_hidden_dim, num_heads)
        self.cpb_mlp2 = self.build_cpb_mlp(2, rpe_hidden_dim, num_heads)
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)
    def build_cpb_mlp(self, in_dim, hidden_dim, out_dim):
        cpb_mlp = nn.Sequential(nn.Linear(in_dim, hidden_dim, bias=True),
                                nn.ReLU(inplace=True),
                                nn.Linear(hidden_dim, out_dim, bias=False))
        return cpb_mlp
    def forward(self,
                query,
                reference_points,
                k_input_flatten,
                v_input_flatten,
                input_spatial_shapes,
                input_padding_mask=None,
                ):
        assert input_spatial_shapes.size(0) == 1, 'This is designed for single-scale decoder.'
        h, w = input_spatial_shapes[0]
        stride = self.feature_stride

        # Convert reference boxes from (cx, cy, w, h) to (x1, y1, x2, y2) corners.
        ref_pts = torch.cat([
            reference_points[:, :, :, :2] - reference_points[:, :, :, 2:] / 2,
            reference_points[:, :, :, :2] + reference_points[:, :, :, 2:] / 2,
        ], dim=-1)  # B, nQ, 1, 4

        # Pixel-center coordinates of the single-scale feature map, in image space.
        pos_x = torch.linspace(0.5, w - 0.5, w, dtype=torch.float32, device=w.device)[None, None, :, None] * stride  # 1, 1, w, 1
        pos_y = torch.linspace(0.5, h - 0.5, h, dtype=torch.float32, device=h.device)[None, None, :, None] * stride  # 1, 1, h, 1

        # Axis-decomposed offsets between box edges and pixel centers, mapped to
        # per-head biases and combined into a (B, nheads, nQ, h*w) bias map.
        delta_x = ref_pts[..., 0::2] - pos_x  # B, nQ, w, 2
        delta_y = ref_pts[..., 1::2] - pos_y  # B, nQ, h, 2
        rpe_x, rpe_y = self.cpb_mlp1(delta_x), self.cpb_mlp2(delta_y)  # B, nQ, w/h, nheads
        rpe = (rpe_x[:, :, None] + rpe_y[:, :, :, None]).flatten(2, 3)  # B, nQ, h, w, nheads -> B, nQ, h*w, nheads
        rpe = rpe.permute(0, 3, 1, 2)  # B, nheads, nQ, h*w
        # Multi-head projections: k/v over the flattened feature map, q over the queries.
        B_, N, C = k_input_flatten.shape
        k = self.k(k_input_flatten).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        v = self.v(v_input_flatten).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        B_, N, C = query.shape
        q = self.q(query).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        q = q * self.scale

        attn = q @ k.transpose(-2, -1)
        attn += rpe  # add the box-to-pixel relative position bias to the attention logits
        if input_padding_mask is not None:
            # Suppress attention to padded pixels with a large negative bias.
            attn += input_padding_mask[:, None, None] * -100

        # Clamp logits to the finite range of the dtype so the softmax stays well-defined.
        fmin, fmax = torch.finfo(attn.dtype).min, torch.finfo(attn.dtype).max
        torch.clip_(attn, min=fmin, max=fmax)

        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        x = attn @ v
        x = x.transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x
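

# ----------------- Usage sketch (illustrative, not from the original source) -----------------
# A minimal smoke test, assuming a single-scale feature map of stride 16 and
# reference boxes given as (cx, cy, w, h) in input-image pixels so that they are
# comparable with the pixel-center grid built inside forward(). All shapes and
# values below are hypothetical.
if __name__ == "__main__":
    B, nQ, dim, h, w = 2, 300, 256, 25, 34
    attn_layer = GlobalCrossAttention(dim=dim, num_heads=8, feature_stride=16)

    query = torch.randn(B, nQ, dim)                         # decoder queries
    reference_points = torch.rand(B, nQ, 1, 4) * 256        # dummy boxes in image pixels
    memory = torch.randn(B, h * w, dim)                     # flattened encoder features
    spatial_shapes = torch.tensor([[h, w]])                 # single scale: (1, 2)
    padding_mask = torch.zeros(B, h * w, dtype=torch.bool)  # no padded pixels

    out = attn_layer(query, reference_points, memory, memory, spatial_shapes, padding_mask)
    print(out.shape)  # expected: torch.Size([2, 300, 256])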