yjh0410 9 сар өмнө
parent
commit
329eea4b8e

+ 12 - 8
yolo/models/yolov10/modules.py

@@ -118,28 +118,32 @@ class SCDown(nn.Module):
 class Attention(nn.Module):
     def __init__(self, dim, num_heads=8, attn_ratio=0.5):
         super().__init__()
-        self.num_heads = num_heads
-        self.head_dim = dim // num_heads
-        self.key_dim = int(self.head_dim * attn_ratio)
+        self.num_heads = num_heads                      # number of the attention heads
+        self.head_dim = dim // num_heads                # per head dim of v
+        self.key_dim = int(self.head_dim * attn_ratio)  # per head dim of qk
         self.scale = self.key_dim**-0.5
         
-        nh_kd = self.key_dim * num_heads
-        h = dim + nh_kd * 2
-        self.qkv  = ConvModule(dim, h, kernel_size=1, use_act=False)
-        self.proj = ConvModule(dim, dim, kernel_size=1, use_act=False)
-        self.pe   = ConvModule(dim, dim, kernel_size=3, groups=dim, use_act=False)
+        qkv_dims = dim + self.key_dim * num_heads * 2   # total dims of qkv
+        self.qkv  = ConvModule(dim, qkv_dims, kernel_size=1, use_act=False)  # qkv projection
+        self.proj = ConvModule(dim, dim, kernel_size=1, use_act=False)       # output projection
+        self.pe   = ConvModule(dim, dim, kernel_size=3, groups=dim, use_act=False)  # position embedding conv
 
     def forward(self, x):
         bs, c, h, w = x.shape
         seq_len = h * w
 
         qkv = self.qkv(x)
+
+        # q, k -> [bs, nh, c_kdh, hw]; v -> [bs, nh, c_vh, hw]
         q, k, v = qkv.view(bs, self.num_heads, self.key_dim * 2 + self.head_dim, seq_len).split(
             [self.key_dim, self.key_dim, self.head_dim], dim=2
         )
 
+        # [bs, nh, hw(q), c_kdh] x [bs, nh, c_kdh, hw(k)] -> [bs, nh, hw(q), hw(k)]
         attn = (q.transpose(-2, -1) @ k) * self.scale
         attn = attn.softmax(dim=-1)
+
+        # [bs, nh, c_vh, hw(v)] x [bs, nh, hw(k), hw(q)] -> [bs, nh, c_vh, hw]
         x = (v @ attn.transpose(-2, -1)).view(bs, c, h, w) + self.pe(v.reshape(bs, c, h, w))
         x = self.proj(x)
 

+ 0 - 7
yolo/models/yolov10/yolov10.py

@@ -9,9 +9,6 @@ from .yolov10_pafpn    import Yolov10PaFPN
 from .yolov10_head     import Yolov10DetHead
 from .yolov10_pred     import Yolov10DetPredLayer
 
-# --------------- External components ---------------
-from utils.misc import multiclass_nms
-
 
 # YOLOv10
 class Yolov10(nn.Module):
@@ -115,10 +112,6 @@ class Yolov10(nn.Module):
         scores = scores.cpu().numpy()
         labels = labels.cpu().numpy()
         bboxes = bboxes.cpu().numpy()
-
-        # # nms
-        # scores, labels, bboxes = multiclass_nms(
-        #     scores, labels, bboxes, self.nms_thresh, self.num_classes)
         
         # keep top-300 results
         scores = scores[:300]