@@ -118,28 +118,32 @@ class SCDown(nn.Module):
 
 
 class Attention(nn.Module):
     def __init__(self, dim, num_heads=8, attn_ratio=0.5):
         super().__init__()
-        self.num_heads = num_heads
-        self.head_dim = dim // num_heads
-        self.key_dim = int(self.head_dim * attn_ratio)
+        self.num_heads = num_heads  # number of attention heads
+        self.head_dim = dim // num_heads  # per-head dim of v
+        self.key_dim = int(self.head_dim * attn_ratio)  # per-head dim of q and k
         self.scale = self.key_dim**-0.5
-        nh_kd = self.key_dim * num_heads
-        h = dim + nh_kd * 2
-        self.qkv = ConvModule(dim, h, kernel_size=1, use_act=False)
-        self.proj = ConvModule(dim, dim, kernel_size=1, use_act=False)
-        self.pe = ConvModule(dim, dim, kernel_size=3, groups=dim, use_act=False)
+        qkv_dims = dim + self.key_dim * num_heads * 2  # total dims of q, k, v
+        self.qkv = ConvModule(dim, qkv_dims, kernel_size=1, use_act=False)  # qkv projection
+        self.proj = ConvModule(dim, dim, kernel_size=1, use_act=False)  # output projection
+        self.pe = ConvModule(dim, dim, kernel_size=3, groups=dim, use_act=False)  # positional encoding via depthwise conv
 
     def forward(self, x):
         bs, c, h, w = x.shape
         seq_len = h * w
         qkv = self.qkv(x)
+
+        # q, k -> [bs, nh, c_kdh, hw]; v -> [bs, nh, c_vh, hw]
         q, k, v = qkv.view(bs, self.num_heads, self.key_dim * 2 + self.head_dim, seq_len).split(
             [self.key_dim, self.key_dim, self.head_dim], dim=2
         )
+        # [bs, nh, hw(q), c_kdh] x [bs, nh, c_kdh, hw(k)] -> [bs, nh, hw(q), hw(k)]
         attn = (q.transpose(-2, -1) @ k) * self.scale
         attn = attn.softmax(dim=-1)
+
+        # [bs, nh, c_vh, hw(v)] x [bs, nh, hw(k), hw(q)] -> [bs, nh, c_vh, hw]
         x = (v @ attn.transpose(-2, -1)).view(bs, c, h, w) + self.pe(v.reshape(bs, c, h, w))
         x = self.proj(x)
         return x
 
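A quick shape check of the patched block, as a minimal sketch rather than part of the change: it assumes the post-patch Attention class above is defined in the same module, and it substitutes a hypothetical ConvModule (Conv2d + BatchNorm2d + optional SiLU) for the real one, which is not shown in this hunk. The dim, num_heads, and attn_ratio values are arbitrary.

import torch
import torch.nn as nn


class ConvModule(nn.Module):
    # Hypothetical stand-in for the real ConvModule (not part of this hunk):
    # Conv2d + BatchNorm2d with an optional SiLU, matching the keyword
    # arguments used in the diff (kernel_size, groups, use_act).
    def __init__(self, in_ch, out_ch, kernel_size=1, groups=1, use_act=True):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size,
                              padding=kernel_size // 2, groups=groups, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.act = nn.SiLU() if use_act else nn.Identity()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))


# With dim=256, num_heads=8, attn_ratio=0.5: head_dim = 256 // 8 = 32,
# key_dim = int(32 * 0.5) = 16, so qkv_dims = 256 + 16 * 8 * 2 = 512.
attn = Attention(dim=256, num_heads=8, attn_ratio=0.5)  # patched class from the diff above
x = torch.randn(2, 256, 20, 20)
out = attn(x)
print(out.shape)  # expected: torch.Size([2, 256, 20, 20])

With auto-padded convolutions, as assumed in the stand-in, the depthwise pe branch keeps the spatial size, so adding it to the reshaped attention output stays shape-consistent.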