yjh0410 · 11 months ago · revision 5fbdb8f277

+ 3 - 0
yolo/config/__init__.py

@@ -9,6 +9,7 @@ from .yolov6_config  import build_yolov6_config
 from .yolov7_config  import build_yolov7_config
 from .yolov8_config  import build_yolov8_config
 from .yolov9_config  import build_yolov9_config
+from .yolo11_config  import build_yolo11_config
 
 from .yolof_config   import build_yolof_config
 from .fcos_config    import build_fcos_config
@@ -39,6 +40,8 @@ def build_config(args):
         cfg = build_yolov8_config(args)
     elif 'yolov9' in args.model:
         cfg = build_yolov9_config(args)
+    elif 'yolo11' in args.model:
+        cfg = build_yolo11_config(args)
         
     # ----------- RT-DETR -----------
     elif 'yolof' in args.model:

+ 209 - 0
yolo/config/yolo11_config.py

@@ -0,0 +1,209 @@
+# YOLO11 Config
+
+
+def build_yolo11_config(args):
+    if   args.model == 'yolo11_n':
+        return Yolo11NConfig()
+    elif args.model == 'yolo11_s':
+        return Yolo11SConfig()
+    elif args.model == 'yolo11_m':
+        return Yolo11MConfig()
+    elif args.model == 'yolo11_l':
+        return Yolo11LConfig()
+    elif args.model == 'yolo11_x':
+        return Yolo11XConfig()
+    else:
+        raise NotImplementedError("No config for model: {}".format(args.model))
+    
+# YOLO11-Base config
+class Yolo11BaseConfig(object):
+    def __init__(self) -> None:
+        # ---------------- Model config ----------------
+        self.model_scale = "l"
+        self.width   = 1.0
+        self.depth   = 1.0
+        self.ratio   = 1.0
+        self.reg_max = 16
+
+        self.out_stride = [8, 16, 32]
+        self.max_stride = 32
+
+        # ---------------- Post-process config ----------------
+        ## Post process
+        self.val_topk = 1000
+        self.val_conf_thresh = 0.001
+        self.val_nms_thresh  = 0.7
+        self.test_topk = 100
+        self.test_conf_thresh = 0.2
+        self.test_nms_thresh  = 0.5
+
+        # ---------------- Assignment config ----------------
+        ## Matcher
+        self.tal_topk_candidates = 10
+        self.tal_alpha = 0.5
+        self.tal_beta  = 6.0
+        ## Loss weight
+        self.loss_cls = 0.5
+        self.loss_box = 7.5
+        self.loss_dfl = 1.5
+
+        # ---------------- ModelEMA config ----------------
+        self.use_ema = True
+        self.ema_decay = 0.9998
+        self.ema_tau   = 2000
+
+        # ---------------- Optimizer config ----------------
+        self.trainer      = 'yolo'
+        self.optimizer    = 'adamw'
+        self.base_lr      = 0.001     # base_lr = per_image_lr * batch_size
+        self.min_lr_ratio = 0.01      # min_lr  = base_lr * min_lr_ratio
+        self.batch_size_base = 64
+        self.momentum     = 0.9
+        self.weight_decay = 0.05
+        self.clip_max_norm   = 35.0
+        self.warmup_bias_lr  = 0.1
+        self.warmup_momentum = 0.8
+
+        # ---------------- Lr Scheduler config ----------------
+        self.warmup_epoch = 3
+        self.lr_scheduler = "cosine"
+        self.max_epoch    = 500
+        self.eval_epoch   = 10
+        self.no_aug_epoch = 20
+
+        # ---------------- Data process config ----------------
+        self.aug_type = 'yolo'
+        self.mosaic_prob = 0.0
+        self.mixup_prob  = 0.0
+        self.copy_paste  = 0.0           # approximated by YOLOX-style mixup
+        self.multi_scale = [0.5, 1.5]   # multi scale: [img_size * 0.5, img_size * 1.5]
+        ## Pixel mean & std
+        self.pixel_mean = [0., 0., 0.]
+        self.pixel_std  = [255., 255., 255.]
+        ## Transforms
+        self.train_img_size = 640
+        self.test_img_size  = 640
+        self.affine_params = {
+            'degrees': 0.0,
+            'translate': 0.2,
+            'scale': [0.1, 2.0],
+            'shear': 0.0,
+            'perspective': 0.0,
+            'hsv_h': 0.015,
+            'hsv_s': 0.7,
+            'hsv_v': 0.4,
+        }
+
+    def print_config(self):
+        config_dict = {key: value for key, value in self.__dict__.items() if not key.startswith('__')}
+        for k, v in config_dict.items():
+            print("{} : {}".format(k, v))
+
+# YOLO11-N
+class Yolo11NConfig(Yolo11BaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # ---------------- Model config ----------------
+        self.model_scale = "n"
+        self.width = 0.25
+        self.depth = 0.50
+        self.ratio = 2.0
+
+        # ---------------- Data process config ----------------
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.0
+        self.copy_paste  = 0.0
+
+# YOLO11-S
+class Yolo11SConfig(Yolo11BaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # ---------------- Model config ----------------
+        self.model_scale = "s"
+        self.width = 0.50
+        self.depth = 0.50
+        self.ratio = 2.0
+
+        # ---------------- Data process config ----------------
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.0
+        self.copy_paste  = 1.0
+
+# YOLO11-M
+class Yolo11MConfig(Yolo11BaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # ---------------- Model config ----------------
+        self.model_scale = "m"
+        self.width = 1.0
+        self.depth = 0.5
+        self.ratio = 1.0
+
+        # ---------------- Data process config ----------------
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.1
+        self.copy_paste  = 1.0
+
+# YOLO11-L
+class Yolo11LConfig(Yolo11BaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # ---------------- Model config ----------------
+        self.model_scale = "l"
+        self.width = 1.0
+        self.depth = 1.0
+        self.ratio = 1.0
+
+        # ---------------- Data process config ----------------
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.1
+        self.copy_paste  = 1.0
+
+        # ---------------- ModelEMA config ----------------
+        self.use_ema = True
+        self.ema_decay = 0.9999
+        self.ema_tau   = 2000
+
+        # ---------------- Optimizer config ----------------
+        self.trainer      = 'yolo'
+        self.optimizer    = 'sgd'
+        self.base_lr      = 0.01     # base_lr = per_image_lr * batch_size
+        self.min_lr_ratio = 0.01      # min_lr  = base_lr * min_lr_ratio
+        self.batch_size_base = 64
+        self.momentum     = 0.9
+        self.weight_decay = 0.0005
+        self.clip_max_norm   = 10.0
+        self.warmup_bias_lr  = 0.1
+        self.warmup_momentum = 0.8
+
+# YOLO11-X
+class Yolo11XConfig(Yolo11BaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # ---------------- Model config ----------------
+        self.model_scale = "x"
+        self.width = 1.50
+        self.depth = 1.0
+        self.ratio = 1.0
+
+        # ---------------- Data process config ----------------
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.1
+        self.copy_paste  = 1.0
+
+        # ---------------- ModelEMA config ----------------
+        self.use_ema = True
+        self.ema_decay = 0.9999
+        self.ema_tau   = 2000
+
+        # ---------------- Optimizer config ----------------
+        self.trainer      = 'yolo'
+        self.optimizer    = 'sgd'
+        self.base_lr      = 0.01     # base_lr = per_image_lr * batch_size
+        self.min_lr_ratio = 0.01      # min_lr  = base_lr * min_lr_ratio
+        self.batch_size_base = 64
+        self.momentum     = 0.9
+        self.weight_decay = 0.0005
+        self.clip_max_norm   = 10.0
+        self.warmup_bias_lr  = 0.1
+        self.warmup_momentum = 0.8

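A minimal usage sketch of the new config (a types.SimpleNamespace stands in for the project's argparse args; assumes the repo root is on PYTHONPATH):

    from types import SimpleNamespace
    from yolo.config.yolo11_config import build_yolo11_config

    # 'yolo11_s' selects Yolo11SConfig; unknown names raise NotImplementedError
    cfg = build_yolo11_config(SimpleNamespace(model='yolo11_s'))
    cfg.print_config()   # dumps every field, e.g. model_scale : s, width : 0.5, ratio : 2.0
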
+ 0 - 0
yolo/config/yolov11_config.py


+ 4 - 0
yolo/models/__init__.py

@@ -12,6 +12,7 @@ from .yolov6.build import build_yolov6
 from .yolov7.build import build_yolov7
 from .yolov8.build import build_yolov8
 from .yolov9.build import build_gelan
+from .yolo11.build import build_yolo11
 
 from .yolof.build  import build_yolof
 from .fcos.build   import build_fcos
@@ -51,6 +52,9 @@ def build_model(args, cfg, is_val=False):
     ## GElan
     elif 'yolov9' in args.model:
         model, criterion = build_gelan(cfg, is_val)
+    ## YOLO11
+    elif 'yolo11' in args.model:
+        model, criterion = build_yolo11(cfg, is_val)
 
     ## Yolof
     elif 'yolof' in args.model:

+ 100 - 30
yolo/models/yolo11/modules.py

@@ -66,39 +66,24 @@ class C3kBlock(nn.Module):
     def forward(self, x):
         return self.cv3(torch.cat([self.m(self.cv1(x)), self.cv2(x)], dim=1))
 
-class C3k2fBlock(nn.Module):
-    def __init__(self, in_dim, out_dim, num_blocks=1, use_c3k=True, expansion=0.5, shortcut=True):
+class SPPF(nn.Module):
+    def __init__(self, in_dim, out_dim, spp_pooling_size: int = 5, neck_expand_ratio:float = 0.5):
         super().__init__()
-        inter_dim = int(out_dim * expansion)  # hidden channels
-        self.cv1 = ConvModule(in_dim, 2 * inter_dim, kernel_size=1)
-        self.cv2 = ConvModule((2 + num_blocks) * inter_dim, out_dim, kernel_size=1)
-
-        if use_c3k:
-            self.m = nn.ModuleList(
-                C3kBlock(inter_dim, inter_dim, 2, shortcut)
-                for _ in range(num_blocks)
-            )
-        else:
-            self.m = nn.ModuleList(
-                Bottleneck(inter_dim, inter_dim, [3, 3], shortcut, expansion=0.5)
-                for _ in range(num_blocks)
-            )
-
-    def _forward_impl(self, x):
-        # Input proj
-        x1, x2 = torch.chunk(self.cv1(x), 2, dim=1)
-        out = list([x1, x2])
-
-        # Bottlenecl
-        out.extend(m(out[-1]) for m in self.m)
-
-        # Output proj
-        out = self.cv2(torch.cat(out, dim=1))
-
-        return out
+        ## ----------- Basic Parameters -----------
+        inter_dim = round(in_dim * neck_expand_ratio)
+        self.out_dim = out_dim
+        ## ----------- Network Parameters -----------
+        self.cv1 = ConvModule(in_dim, inter_dim, kernel_size=1, stride=1)
+        self.cv2 = ConvModule(inter_dim * 4, out_dim, kernel_size=1, stride=1)
+        self.m = nn.MaxPool2d(kernel_size=spp_pooling_size, stride=1, padding=spp_pooling_size // 2)
 
     def forward(self, x):
-        return self._forward_impl(x)
+        x = self.cv1(x)
+        y1 = self.m(x)
+        y2 = self.m(y1)
+
+        return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
+    
 
 # ----------------- Attention modules  -----------------
 class Attention(nn.Module):
@@ -143,3 +128,88 @@ class PSABlock(nn.Module):
         x = x + self.attn(x) if self.add else self.attn(x)
         x = x + self.ffn(x)  if self.add else self.ffn(x)
         return x
+
+class C2PSA(nn.Module):
+    def __init__(self, in_dim, out_dim, num_blocks=1, expansion=0.5):
+        super().__init__()
+        assert in_dim == out_dim
+        inter_dim = int(in_dim * expansion)
+        self.cv1 = ConvModule(in_dim, 2 * inter_dim, kernel_size=1)
+        self.cv2 = ConvModule(2 * inter_dim, in_dim, kernel_size=1)
+        self.m = nn.Sequential(*[
+            PSABlock(in_dim     = inter_dim,
+                     attn_ratio = 0.5,
+                     num_heads  = inter_dim // 64
+                     ) for _ in range(num_blocks)])
+
+    def forward(self, x):
+        x1, x2 = torch.chunk(self.cv1(x), chunks=2, dim=1)
+        x2 = self.m(x2)
+
+        return self.cv2(torch.cat([x1, x2], dim=1))
+
+
+# ----------------- YOLO11 components -----------------
+class YoloStage(nn.Module):
+    def __init__(self, in_dim, out_dim, num_blocks=1, use_c3k=True, expansion=0.5, shortcut=True):
+        super().__init__()
+        inter_dim = int(out_dim * expansion)  # hidden channels
+        self.cv1 = ConvModule(in_dim, 2 * inter_dim, kernel_size=1)
+        self.cv2 = ConvModule((2 + num_blocks) * inter_dim, out_dim, kernel_size=1)
+
+        if use_c3k:
+            self.m = nn.ModuleList(
+                C3kBlock(inter_dim, inter_dim, 2, shortcut)
+                for _ in range(num_blocks)
+            )
+        else:
+            self.m = nn.ModuleList(
+                Bottleneck(inter_dim, inter_dim, [3, 3], shortcut, expansion=0.5)
+                for _ in range(num_blocks)
+            )
+
+    def _forward_impl(self, x):
+        # Input proj
+        x1, x2 = torch.chunk(self.cv1(x), 2, dim=1)
+        out = [x1, x2]
+
+        # Bottleneck blocks
+        out.extend(m(out[-1]) for m in self.m)
+
+        # Output proj
+        out = self.cv2(torch.cat(out, dim=1))
+
+        return out
+
+    def forward(self, x):
+        return self._forward_impl(x)
+
+class DflLayer(nn.Module):
+    def __init__(self, reg_max=16):
+        """Initialize a convolutional layer with a given number of input channels."""
+        super().__init__()
+        self.reg_max = reg_max
+        proj_init = torch.arange(reg_max, dtype=torch.float)
+        self.proj_weight = nn.Parameter(proj_init.view([1, reg_max, 1, 1]), requires_grad=False)
+
+    def forward(self, pred_reg, anchor, stride):
+        bs, hw = pred_reg.shape[:2]
+        # [bs, hw, 4*rm] -> [bs, 4*rm, hw] -> [bs, 4, rm, hw]
+        pred_reg = pred_reg.permute(0, 2, 1).reshape(bs, 4, -1, hw)
+
+        # [bs, 4, rm, hw] -> [bs, rm, 4, hw]
+        pred_reg = pred_reg.permute(0, 2, 1, 3).contiguous()
+
+        # [bs, rm, 4, hw] -> [bs, 1, 4, hw]
+        delta_pred = F.conv2d(F.softmax(pred_reg, dim=1), self.proj_weight)
+
+        # [bs, 1, 4, hw] -> [bs, 4, hw] -> [bs, hw, 4]
+        delta_pred = delta_pred.view(bs, 4, hw).permute(0, 2, 1).contiguous()
+        delta_pred *= stride
+
+        # Decode bbox: tlbr -> xyxy
+        x1y1_pred = anchor - delta_pred[..., :2]
+        x2y2_pred = anchor + delta_pred[..., 2:]
+        box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
+
+        return box_pred

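DflLayer decodes the 4*reg_max regression logits by taking, for each of the four sides, a softmax over the reg_max bins, projecting onto the bin indices (an expected distance), scaling by the stride, and turning the l/t/r/b offsets around the anchor into xyxy boxes. A self-contained sketch of that computation with toy shapes (not the repo's own code path):

    import torch
    import torch.nn.functional as F

    reg_max, stride = 16, 8
    bs, hw = 2, 100
    pred_reg = torch.randn(bs, hw, 4 * reg_max)      # raw bin logits per side (l, t, r, b)
    anchor   = torch.rand(hw, 2) * 640               # anchor centers in image coordinates

    # expected bin index per side = sum_i softmax(logits)_i * i, then scale by the stride
    prob = F.softmax(pred_reg.view(bs, hw, 4, reg_max), dim=-1)
    bins = torch.arange(reg_max, dtype=torch.float)
    dist = (prob * bins).sum(-1) * stride            # [bs, hw, 4] distances in pixels

    boxes = torch.cat([anchor - dist[..., :2],       # x1y1
                       anchor + dist[..., 2:]], -1)  # x2y2 -> [bs, hw, 4] xyxy boxes
    print(boxes.shape)
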
+ 5 - 34
yolo/models/yolo11/yolo11.py

@@ -4,15 +4,11 @@ import torch.nn as nn
 
 # --------------- Model components ---------------
 from .yolo11_backbone import Yolo11Backbone
-from .yolo11_neck     import SPPF, C2PSA
 from .yolo11_pafpn    import Yolo11PaFPN
 from .yolo11_head     import Yolo11DetHead
-from .yolo11_pred     import Yolo11DetPredLayer
 
-# --------------- External components ---------------
 from utils.misc import multiclass_nms
 
-
 # YOLO11
 class Yolo11(nn.Module):
     def __init__(self,
@@ -28,24 +24,10 @@ class Yolo11(nn.Module):
         self.conf_thresh      = cfg.val_conf_thresh if is_val else cfg.test_conf_thresh
         self.nms_thresh       = cfg.val_nms_thresh  if is_val else cfg.test_nms_thresh
         self.no_multi_labels  = False if is_val else True
-        
-        # ---------------------- Network Parameters ----------------------
-        ## Backbone
-        self.backbone = Yolo11Backbone(cfg)
-        self.pyramid_feat_dims = self.backbone.feat_dims[-3:]
-
-        ## Neck
-        self.neck_spp  = SPPF(self.pyramid_feat_dims[-1], self.pyramid_feat_dims[-1])
-        self.neck_attn = C2PSA(self.pyramid_feat_dims[-1], self.pyramid_feat_dims[-1], num_blocks=int(2 * cfg.depth), expansion=0.5)
-        
-        ## Neck: PaFPN
-        self.fpn = Yolo11PaFPN(cfg, self.backbone.feat_dims)
 
-        ## Head
-        self.head = Yolo11DetHead(cfg, self.fpn.out_dims)
-
-        ## Pred
-        self.pred = Yolo11DetPredLayer(cfg, self.head.cls_head_dim, self.head.reg_head_dim)
+        self.backbone = Yolo11Backbone(cfg)
+        self.pafpn    = Yolo11PaFPN(cfg, self.backbone.feat_dims[-3:])
+        self.det_head = Yolo11DetHead(cfg, self.pafpn.out_dims)
 
     def post_process(self, cls_preds, box_preds):
         """
@@ -126,20 +108,9 @@ class Yolo11(nn.Module):
         return bboxes, scores, labels
     
     def forward(self, x):
-        # ---------------- Backbone ----------------
         pyramid_feats = self.backbone(x)
-        # ---------------- Neck: SPP ----------------
-        pyramid_feats[-1] = self.neck_spp(pyramid_feats[-1])
-        pyramid_feats[-1] = self.neck_attn(pyramid_feats[-1])
-
-        # ---------------- Neck: PaFPN ----------------
-        pyramid_feats = self.fpn(pyramid_feats)
-
-        # ---------------- Heads ----------------
-        cls_feats, reg_feats = self.head(pyramid_feats)
-
-        # ---------------- Preds ----------------
-        outputs = self.pred(cls_feats, reg_feats)
+        pyramid_feats = self.pafpn(pyramid_feats)
+        outputs = self.det_head(pyramid_feats)
         outputs['image_size'] = [x.shape[2], x.shape[3]]
 
         if not self.training:

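With the prediction layer folded into the head, the forward path is simply backbone -> pafpn -> det_head. A hedged end-to-end sketch (SimpleNamespace stands in for the CLI args and may lack fields a real run passes; num_classes is normally filled in from the dataset; build_model follows the signature shown in yolo/models/__init__.py above):

    import torch
    from types import SimpleNamespace
    from yolo.config import build_config
    from yolo.models import build_model

    args = SimpleNamespace(model='yolo11_n')
    cfg = build_config(args)
    cfg.num_classes = 80                      # hypothetical; usually set by the dataset loader
    model, criterion = build_model(args, cfg, is_val=True)
    model.eval()

    with torch.no_grad():
        results = model(torch.randn(1, 3, 640, 640))   # post-processed detections (see post_process)
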
+ 19 - 6
yolo/models/yolo11/yolo11_backbone.py

@@ -2,9 +2,9 @@ import torch
 import torch.nn as nn
 
 try:
-    from .modules import ConvModule, C3k2fBlock
+    from .modules import ConvModule, YoloStage, SPPF, C2PSA
 except:
-    from  modules import ConvModule, C3k2fBlock
+    from  modules import ConvModule, YoloStage, SPPF, C2PSA
 
 
 # ---------------------------- YOLO11 Backbone ----------------------------
@@ -21,7 +21,7 @@ class Yolo11Backbone(nn.Module):
         # P2/4
         self.layer_2 = nn.Sequential(
             ConvModule(int(64 * cfg.width), int(128 * cfg.width), kernel_size=3, stride=2),
-            C3k2fBlock(in_dim     = int(128 * cfg.width),
+            YoloStage(in_dim     = int(128 * cfg.width),
                       out_dim    = int(256 * cfg.width),
                       num_blocks = round(2*cfg.depth),
                       shortcut   = True,
@@ -32,7 +32,7 @@ class Yolo11Backbone(nn.Module):
         # P3/8
         self.layer_3 = nn.Sequential(
             ConvModule(int(256 * cfg.width), int(256 * cfg.width), kernel_size=3, stride=2),
-            C3k2fBlock(in_dim     = int(256 * cfg.width),
+            YoloStage(in_dim     = int(256 * cfg.width),
                       out_dim    = int(512 * cfg.width),
                       num_blocks = round(2*cfg.depth),
                       shortcut   = True,
@@ -43,7 +43,7 @@ class Yolo11Backbone(nn.Module):
         # P4/16
         self.layer_4 = nn.Sequential(
             ConvModule(int(512 * cfg.width), int(512 * cfg.width), kernel_size=3, stride=2),
-            C3k2fBlock(in_dim     = int(512 * cfg.width),
+            YoloStage(in_dim     = int(512 * cfg.width),
                       out_dim    = int(512 * cfg.width),
                       num_blocks = round(2*cfg.depth),
                       shortcut   = True,
@@ -54,7 +54,7 @@ class Yolo11Backbone(nn.Module):
         # P5/32
         self.layer_5 = nn.Sequential(
             ConvModule(int(512 * cfg.width), int(512 * cfg.width * cfg.ratio), kernel_size=3, stride=2),
-            C3k2fBlock(in_dim     = int(512 * cfg.width * cfg.ratio),
+            YoloStage(in_dim     = int(512 * cfg.width * cfg.ratio),
                       out_dim    = int(512 * cfg.width * cfg.ratio),
                       num_blocks = round(2*cfg.depth),
                       shortcut   = True,
@@ -62,6 +62,17 @@ class Yolo11Backbone(nn.Module):
                       use_c3k    = True,
                       )
         )
+        # Extra module (no pretrained weight)
+        self.layer_6 = SPPF(in_dim  = int(512 * cfg.width * cfg.ratio),
+                            out_dim = int(512 * cfg.width * cfg.ratio),
+                            spp_pooling_size = 5,
+                            neck_expand_ratio = 0.5,
+                            )
+        self.layer_7 = C2PSA(in_dim  = int(512 * cfg.width * cfg.ratio),
+                             out_dim = int(512 * cfg.width * cfg.ratio),
+                             num_blocks = round(2*cfg.depth),
+                             expansion = 0.5,
+                             )
 
         # Initialize all layers
         self.init_weights()
@@ -77,6 +88,8 @@ class Yolo11Backbone(nn.Module):
         c3 = self.layer_3(c2)
         c4 = self.layer_4(c3)
         c5 = self.layer_5(c4)
+        c5 = self.layer_6(c5)
+        c5 = self.layer_7(c5)
         outputs = [c3, c4, c5]
 
         return outputs

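Since SPPF and C2PSA now sit at the end of the backbone, c5 already carries the pooled, attention-refined features. A forward-shape sketch (assumes the repo is importable and the stem takes a 3-channel 640x640 input; the expected dims follow from the stage definitions above):

    import torch
    from yolo.config.yolo11_config import Yolo11SConfig
    from yolo.models.yolo11.yolo11_backbone import Yolo11Backbone

    backbone = Yolo11Backbone(Yolo11SConfig())
    c3, c4, c5 = backbone(torch.randn(1, 3, 640, 640))
    # "s" scale (width=0.5, ratio=2.0): roughly c3 [1, 256, 80, 80], c4 [1, 256, 40, 40], c5 [1, 512, 20, 20]
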
+ 126 - 113
yolo/models/yolo11/yolo11_head.py

@@ -1,112 +1,133 @@
+import math
 import torch
 import torch.nn as nn
 from typing import List
 
 try:
-    from .modules import ConvModule
+    from .modules import ConvModule, DflLayer
 except:
-    from  modules import ConvModule
-
-
-# -------------------- Detection Head --------------------
-## Single-level Detection Head
-class DetHead(nn.Module):
-    def __init__(self,
-                 in_dim       :int  = 256,
-                 cls_head_dim :int  = 256,
-                 reg_head_dim :int  = 256,
-                 num_cls_head :int  = 2,
-                 num_reg_head :int  = 2,
-                 ):
-        super().__init__()
-        # --------- Basic Parameters ----------
-        self.in_dim = in_dim
-        self.num_cls_head = num_cls_head
-        self.num_reg_head = num_reg_head
-        
-        # --------- Network Parameters ----------
-        ## classification head
-        cls_feats = []
-        self.cls_head_dim = cls_head_dim
-        for i in range(num_cls_head):
-            if i == 0:
-                cls_feats.append(nn.Sequential(
-                    ConvModule(in_dim, in_dim, kernel_size=3, stride=1, groups=in_dim),
-                    ConvModule(in_dim, self.cls_head_dim, kernel_size=1),
-                ))
-            else:
-                cls_feats.append(nn.Sequential(
-                    ConvModule(self.cls_head_dim, self.cls_head_dim, kernel_size=3, stride=1, groups=self.cls_head_dim),
-                    ConvModule(self.cls_head_dim, self.cls_head_dim, kernel_size=1),
-                ))
-        
-        ## bbox regression head
-        reg_feats = []
-        self.reg_head_dim = reg_head_dim
-        for i in range(num_reg_head):
-            if i == 0:
-                reg_feats.append(ConvModule(in_dim, self.reg_head_dim, kernel_size=3, stride=1))
-            else:
-                reg_feats.append(ConvModule(self.reg_head_dim, self.reg_head_dim, kernel_size=3, stride=1))
-        
-        self.cls_feats = nn.Sequential(*cls_feats)
-        self.reg_feats = nn.Sequential(*reg_feats)
+    from  modules import ConvModule, DflLayer
 
-        self.init_weights()
-        
-    def init_weights(self):
-        """Initialize the parameters."""
-        for m in self.modules():
-            if isinstance(m, torch.nn.Conv2d):
-                m.reset_parameters()
 
-    def forward(self, x):
-        """
-            in_feats: (Tensor) [B, C, H, W]
-        """
-        cls_feats = self.cls_feats(x)
-        reg_feats = self.reg_feats(x)
-
-        return cls_feats, reg_feats
-    
-## Multi-level Detection Head
+# YOLO11 detection head
 class Yolo11DetHead(nn.Module):
-    def __init__(self, cfg, in_dims: List = [256, 512, 1024]):
+    def __init__(self, cfg, fpn_dims: List = [64, 128, 256]):
         super().__init__()
-        self.num_levels = len(cfg.out_stride)
-        ## ----------- Network Parameters -----------
-        self.multi_level_heads = nn.ModuleList(
-            [DetHead(in_dim       = in_dims[level],
-                     cls_head_dim = max(in_dims[0], min(cfg.num_classes, 128)),
-                     reg_head_dim = max(in_dims[0]//4, 16, 4*cfg.reg_max),
-                     num_cls_head = cfg.num_cls_head,
-                     num_reg_head = cfg.num_reg_head,
-                     ) for level in range(self.num_levels)])
-        # --------- Basic Parameters ----------
-        self.in_dims = in_dims
-        self.cls_head_dim = self.multi_level_heads[0].cls_head_dim
-        self.reg_head_dim = self.multi_level_heads[0].reg_head_dim
-
-    def forward(self, feats):
+        self.out_stride = cfg.out_stride
+        self.reg_max = cfg.reg_max
+        self.num_classes = cfg.num_classes
+
+        self.cls_dim = max(fpn_dims[0], min(cfg.num_classes, 128))
+        self.reg_dim = max(fpn_dims[0]//4, 16, 4*cfg.reg_max)
+
+        # classification head
+        self.cls_heads = nn.ModuleList(
+            nn.Sequential(
+                nn.Sequential(ConvModule(dim, dim, kernel_size=3, stride=1, groups=dim),
+                              ConvModule(dim, self.cls_dim, kernel_size=1)),
+                nn.Sequential(ConvModule(self.cls_dim, self.cls_dim, kernel_size=3, stride=1, groups=self.cls_dim),
+                              ConvModule(self.cls_dim, self.cls_dim, kernel_size=1)),
+                nn.Conv2d(self.cls_dim, cfg.num_classes, kernel_size=1),
+            )
+            for dim in fpn_dims
+        )
+
+        # bbox regression head
+        self.reg_heads = nn.ModuleList(
+            nn.Sequential(
+                ConvModule(dim, self.reg_dim, kernel_size=3, stride=1),
+                ConvModule(self.reg_dim, self.reg_dim, kernel_size=3, stride=1),
+                nn.Conv2d(self.reg_dim, 4*cfg.reg_max, kernel_size=1),
+            )
+            for dim in fpn_dims
+        )
+
+        # DFL layer for decoding bbox
+        self.dfl_layer = DflLayer(cfg.reg_max)
+        for p in self.dfl_layer.parameters():
+            p.requires_grad = False
+
+        self.init_bias()
+        
+    def init_bias(self):
+        # cls pred
+        for i, m in enumerate(self.cls_heads):
+            b = m[-1].bias.view(1, -1)
+            b.data.fill_(math.log(5 / self.num_classes / (640. / self.out_stride[i]) ** 2))
+            m[-1].bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+
+        # reg pred
+        for m in self.reg_heads:
+            b = m[-1].bias.view(-1, )
+            b.data.fill_(1.0)
+            m[-1].bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+            
+            w = m[-1].weight
+            w.data.fill_(0.)
+            m[-1].weight = torch.nn.Parameter(w, requires_grad=True)
+
+    def generate_anchors(self, fmp_size, level):
         """
-            feats: List[(Tensor)] [[B, C, H, W], ...]
+            fmp_size: (List) [H, W]
         """
-        cls_feats = []
-        reg_feats = []
-        for feat, head in zip(feats, self.multi_level_heads):
-            # ---------------- Pred ----------------
-            cls_feat, reg_feat = head(feat)
-
-            cls_feats.append(cls_feat)
-            reg_feats.append(reg_feat)
-
-        return cls_feats, reg_feats
+        # generate grid cells
+        fmp_h, fmp_w = fmp_size
+        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
+        # [H, W, 2] -> [HW, 2]
+        anchors = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
+        anchors += 0.5  # add center offset
+        anchors *= self.out_stride[level]
+
+        return anchors
+
+    def forward(self, fpn_feats):
+        anchors = []
+        strides = []
+        cls_preds = []
+        reg_preds = []
+        box_preds = []
+
+        for lvl, (feat, cls_head, reg_head) in enumerate(zip(fpn_feats, self.cls_heads, self.reg_heads)):
+            bs, c, h, w = feat.size()
+            device = feat.device
+            
+            # Prediction
+            cls_pred = cls_head(feat)
+            reg_pred = reg_head(feat)
+
+            # [bs, c, h, w] -> [bs, c, hw] -> [bs, hw, c]
+            cls_pred = cls_pred.flatten(2).permute(0, 2, 1).contiguous()
+            reg_pred = reg_pred.flatten(2).permute(0, 2, 1).contiguous()
+
+            # anchor points: [M, 2]
+            anchor = self.generate_anchors(fmp_size=[h, w], level=lvl).to(device)
+            stride = torch.ones_like(anchor[..., :1]) * self.out_stride[lvl]
+
+            # Decode bbox coords
+            box_pred = self.dfl_layer(reg_pred, anchor[None], self.out_stride[lvl])
+
+            # collect results
+            anchors.append(anchor)
+            strides.append(stride)
+            cls_preds.append(cls_pred)
+            reg_preds.append(reg_pred)
+            box_preds.append(box_pred)
+
+        # output dict
+        outputs = {"pred_cls":       cls_preds,        # List(Tensor) [B, M, C]
+                   "pred_reg":       reg_preds,        # List(Tensor) [B, M, 4*(reg_max)]
+                   "pred_box":       box_preds,        # List(Tensor) [B, M, 4]
+                   "anchors":        anchors,          # List(Tensor) [M, 2]
+                   "stride_tensors": strides,          # List(Tensor) [M, 1]
+                   "strides":        self.out_stride,  # List(Int) = [8, 16, 32]
+                   }
+
+        return outputs
 
 
 if __name__=='__main__':
-    import time
     from thop import profile
-    
+
     # YOLO11-Base config
     class Yolo11BaseConfig(object):
         def __init__(self) -> None:
@@ -118,32 +139,24 @@ if __name__=='__main__':
             self.out_stride = [8, 16, 32]
             self.max_stride = 32
             self.num_levels = 3
-            ## Head
-            self.num_cls_head = 2
-            self.num_reg_head = 2
+            self.num_classes = 80
 
     cfg = Yolo11BaseConfig()
-    cfg.num_classes = 20
 
-    # Build a head
-    fpn_dims = [128, 256, 512]
-    pyramid_feats = [torch.randn(1, fpn_dims[0], 80, 80),
-                     torch.randn(1, fpn_dims[1], 40, 40),
-                     torch.randn(1, fpn_dims[2], 20, 20)]
-    head = Yolo11DetHead(cfg, fpn_dims)
+    # Random data
+    fpn_dims = [256, 512, 512]
+    x = [torch.randn(1, fpn_dims[0], 80, 80),
+         torch.randn(1, fpn_dims[1], 40, 40),
+         torch.randn(1, fpn_dims[2], 20, 20)]
 
+    # Head model
+    model = Yolo11DetHead(cfg, fpn_dims)
 
     # Inference
-    t0 = time.time()
-    cls_feats, reg_feats = head(pyramid_feats)
-    t1 = time.time()
-    print('Time: ', t1 - t0)
-    print("====== Yolo11 Head output ======")
-    for level, (cls_f, reg_f) in enumerate(zip(cls_feats, reg_feats)):
-        print("- Level-{} : ".format(level), cls_f.shape, reg_f.shape)
-
-    flops, params = profile(head, inputs=(pyramid_feats, ), verbose=False)
-    print('==============================')
+    outputs = model(x)
+
+    print('============ FLOPs & Params ===========')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
     print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
     print('Params : {:.2f} M'.format(params / 1e6))
     

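The classification bias set in init_bias() encodes a prior of roughly 5 objects per 640x640 image spread over each level's grid cells. A quick standalone check of the resulting values:

    import math

    num_classes = 80
    for stride in (8, 16, 32):
        prior = 5 / num_classes / (640.0 / stride) ** 2   # expected positives per cell
        print(stride, round(math.log(prior), 2))          # 8 -> -11.54, 16 -> -10.15, 32 -> -8.76
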
+ 0 - 45
yolo/models/yolo11/yolo11_neck.py

@@ -1,45 +0,0 @@
-import torch
-import torch.nn as nn
-
-try:
-    from .modules import ConvModule, PSABlock
-except:
-    from  modules import ConvModule, PSABlock
-
-
-class SPPF(nn.Module):
-    def __init__(self, in_dim, out_dim, spp_pooling_size: int = 5, neck_expand_ratio:float = 0.5):
-        super().__init__()
-        ## ----------- Basic Parameters -----------
-        inter_dim = round(in_dim * neck_expand_ratio)
-        self.out_dim = out_dim
-        ## ----------- Network Parameters -----------
-        self.cv1 = ConvModule(in_dim, inter_dim, kernel_size=1, stride=1)
-        self.cv2 = ConvModule(inter_dim * 4, out_dim, kernel_size=1, stride=1)
-        self.m = nn.MaxPool2d(kernel_size=spp_pooling_size, stride=1, padding=spp_pooling_size // 2)
-
-    def forward(self, x):
-        x = self.cv1(x)
-        y1 = self.m(x)
-        y2 = self.m(y1)
-
-        return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
-    
-class C2PSA(nn.Module):
-    def __init__(self, in_dim, out_dim, num_blocks=1, expansion=0.5):
-        super().__init__()
-        assert in_dim == out_dim
-        inter_dim = int(in_dim * expansion)
-        self.cv1 = ConvModule(in_dim, 2 * inter_dim, kernel_size=1)
-        self.cv2 = ConvModule(2 * inter_dim, in_dim, kernel_size=1)
-        self.m = nn.Sequential(*[
-            PSABlock(in_dim     = inter_dim,
-                     attn_ratio = 0.5,
-                     num_heads  = inter_dim // 64
-                     ) for _ in range(num_blocks)])
-
-    def forward(self, x):
-        x1, x2 = torch.chunk(self.cv1(x), chunks=2, dim=1)
-        x2 = self.m(x2)
-
-        return self.cv2(torch.cat([x1, x2], dim=1))

+ 6 - 6
yolo/models/yolo11/yolo11_pafpn.py

@@ -4,9 +4,9 @@ import torch.nn.functional as F
 from typing import List
 
 try:
-    from .modules import ConvModule, C3k2fBlock
+    from .modules import ConvModule, YoloStage
 except:
-    from  modules import ConvModule, C3k2fBlock
+    from  modules import ConvModule, YoloStage
 
 
 class Yolo11PaFPN(nn.Module):
@@ -19,7 +19,7 @@ class Yolo11PaFPN(nn.Module):
 
         # ----------------------------- Yolo11's Top-down FPN -----------------------------
         ## P5 -> P4
-        self.top_down_layer_1 = C3k2fBlock(in_dim     = self.in_dims[0] + self.in_dims[1],
+        self.top_down_layer_1 = YoloStage(in_dim     = self.in_dims[0] + self.in_dims[1],
                                           out_dim    = round(512*cfg.width),
                                           num_blocks = round(2 * cfg.depth),
                                           shortcut   = True,
@@ -27,7 +27,7 @@ class Yolo11PaFPN(nn.Module):
                                           use_c3k    = False if self.model_scale in "ns" else True,
                                           )
         ## P4 -> P3
-        self.top_down_layer_2 = C3k2fBlock(in_dim     = self.in_dims[2] + round(512*cfg.width),
+        self.top_down_layer_2 = YoloStage(in_dim     = self.in_dims[2] + round(512*cfg.width),
                                           out_dim    = round(256*cfg.width),
                                           num_blocks = round(2 * cfg.depth),
                                           shortcut   = True,
@@ -37,7 +37,7 @@ class Yolo11PaFPN(nn.Module):
         # ----------------------------- Yolo11's Bottom-up PAN -----------------------------
         ## P3 -> P4
         self.dowmsample_layer_1 = ConvModule(round(256*cfg.width), round(256*cfg.width), kernel_size=3, stride=2)
-        self.bottom_up_layer_1 = C3k2fBlock(in_dim     = round(256*cfg.width) + round(512*cfg.width),
+        self.bottom_up_layer_1 = YoloStage(in_dim     = round(256*cfg.width) + round(512*cfg.width),
                                            out_dim    = round(512*cfg.width),
                                            num_blocks = round(2 * cfg.depth),
                                            shortcut   = True,
@@ -46,7 +46,7 @@ class Yolo11PaFPN(nn.Module):
                                            )
         ## P4 -> P5
         self.dowmsample_layer_2 = ConvModule(round(512*cfg.width), round(512*cfg.width), kernel_size=3, stride=2)
-        self.bottom_up_layer_2 = C3k2fBlock(in_dim     = round(512*cfg.width) + self.in_dims[0],
+        self.bottom_up_layer_2 = YoloStage(in_dim     = round(512*cfg.width) + self.in_dims[0],
                                            out_dim    = round(512*cfg.width*cfg.ratio),
                                            num_blocks = round(2 * cfg.depth),
                                            shortcut   = True,

+ 0 - 207
yolo/models/yolo11/yolo11_pred.py

@@ -1,207 +0,0 @@
-import math
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-# -------------------- Detection Pred Layer --------------------
-## Single-level pred layer
-class DetPredLayer(nn.Module):
-    def __init__(self,
-                 cls_dim     :int = 256,
-                 reg_dim     :int = 256,
-                 stride      :int = 32,
-                 reg_max     :int = 16,
-                 num_classes :int = 80,
-                 num_coords  :int = 4):
-        super().__init__()
-        # --------- Basic Parameters ----------
-        self.stride = stride
-        self.cls_dim = cls_dim
-        self.reg_dim = reg_dim
-        self.reg_max = reg_max
-        self.num_classes = num_classes
-        self.num_coords = num_coords
-
-        # --------- Network Parameters ----------
-        self.cls_pred = nn.Conv2d(cls_dim, num_classes, kernel_size=1)
-        self.reg_pred = nn.Conv2d(reg_dim, num_coords, kernel_size=1)                
-
-        self.init_bias()
-        
-    def init_bias(self):
-        # cls pred bias
-        b = self.cls_pred.bias.view(1, -1)
-        b.data.fill_(math.log(5 / self.num_classes / (640. / self.stride) ** 2))
-        self.cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-        # reg pred bias
-        b = self.reg_pred.bias.view(-1, )
-        b.data.fill_(1.0)
-        self.reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-        w = self.reg_pred.weight
-        w.data.fill_(0.)
-        self.reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
-
-    def generate_anchors(self, fmp_size):
-        """
-            fmp_size: (List) [H, W]
-        """
-        # generate grid cells
-        fmp_h, fmp_w = fmp_size
-        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
-        # [H, W, 2] -> [HW, 2]
-        anchors = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
-        anchors += 0.5  # add center offset
-        anchors *= self.stride
-
-        return anchors
-        
-    def forward(self, cls_feat, reg_feat):
-        # pred
-        cls_pred = self.cls_pred(cls_feat)
-        reg_pred = self.reg_pred(reg_feat)
-
-        # generate anchor boxes: [M, 4]
-        B, _, H, W = cls_pred.size()
-        fmp_size = [H, W]
-        anchors = self.generate_anchors(fmp_size)
-        anchors = anchors.to(cls_pred.device)
-        # stride tensor: [M, 1]
-        stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride
-        
-        # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
-        cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
-        reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)
-        
-        # output dict
-        outputs = {"pred_cls": cls_pred,            # List(Tensor) [B, M, C]
-                   "pred_reg": reg_pred,            # List(Tensor) [B, M, 4*(reg_max)]
-                   "anchors": anchors,              # List(Tensor) [M, 2]
-                   "strides": self.stride,          # List(Int) = [8, 16, 32]
-                   "stride_tensor": stride_tensor   # List(Tensor) [M, 1]
-                   }
-
-        return outputs
-
-## Multi-level pred layer
-class Yolo11DetPredLayer(nn.Module):
-    def __init__(self, cfg, cls_dim: int, reg_dim: int):
-        super().__init__()
-        # --------- Basic Parameters ----------
-        self.cfg = cfg
-        self.cls_dim = cls_dim
-        self.reg_dim = reg_dim
-        self.num_levels = len(cfg.out_stride)
-
-        # ----------- Network Parameters -----------
-        ## pred layers
-        self.multi_level_preds = nn.ModuleList(
-            [DetPredLayer(cls_dim     = cls_dim,
-                          reg_dim     = reg_dim,
-                          stride      = cfg.out_stride[level],
-                          reg_max     = cfg.reg_max,
-                          num_classes = cfg.num_classes,
-                          num_coords  = 4 * cfg.reg_max)
-                          for level in range(self.num_levels)
-                          ])
-        ## proj conv
-        proj_init = torch.arange(cfg.reg_max, dtype=torch.float)
-        self.proj_conv = nn.Conv2d(cfg.reg_max, 1, kernel_size=1, bias=False).requires_grad_(False)
-        self.proj_conv.weight.data[:] = nn.Parameter(proj_init.view([1, cfg.reg_max, 1, 1]), requires_grad=False)
-
-    def forward(self, cls_feats, reg_feats):
-        all_anchors = []
-        all_strides = []
-        all_cls_preds = []
-        all_reg_preds = []
-        all_box_preds = []
-        for level in range(self.num_levels):
-            # -------------- Single-level prediction --------------
-            outputs = self.multi_level_preds[level](cls_feats[level], reg_feats[level])
-
-            # -------------- Decode bbox --------------
-            B, M = outputs["pred_reg"].shape[:2]
-            # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max]
-            delta_pred = outputs["pred_reg"].reshape([B, M, 4, self.cfg.reg_max])
-            # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
-            delta_pred = delta_pred.permute(0, 3, 2, 1).contiguous()
-            # [B, reg_max, 4, M] -> [B, 1, 4, M]
-            delta_pred = self.proj_conv(F.softmax(delta_pred, dim=1))
-            # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
-            delta_pred = delta_pred.view(B, 4, M).permute(0, 2, 1).contiguous()
-            ## tlbr -> xyxy
-            x1y1_pred = outputs["anchors"][None] - delta_pred[..., :2] * self.cfg.out_stride[level]
-            x2y2_pred = outputs["anchors"][None] + delta_pred[..., 2:] * self.cfg.out_stride[level]
-            box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
-
-            # collect results
-            all_cls_preds.append(outputs["pred_cls"])
-            all_reg_preds.append(outputs["pred_reg"])
-            all_box_preds.append(box_pred)
-            all_anchors.append(outputs["anchors"])
-            all_strides.append(outputs["stride_tensor"])
-        
-        # output dict
-        outputs = {"pred_cls":      all_cls_preds,         # List(Tensor) [B, M, C]
-                   "pred_reg":      all_reg_preds,         # List(Tensor) [B, M, 4*(reg_max)]
-                   "pred_box":      all_box_preds,         # List(Tensor) [B, M, 4]
-                   "anchors":       all_anchors,           # List(Tensor) [M, 2]
-                   "stride_tensor": all_strides,           # List(Tensor) [M, 1]
-                   "strides":       self.cfg.out_stride,   # List(Int) = [8, 16, 32]
-                   }
-
-        return outputs
-
-
-if __name__=='__main__':
-    import time
-    from thop import profile
-    # Model config
-    
-    # YOLO11-Base config
-    class Yolo11BaseConfig(object):
-        def __init__(self) -> None:
-            # ---------------- Model config ----------------
-            self.width    = 1.0
-            self.depth    = 1.0
-            self.ratio    = 1.0
-            self.reg_max  = 16
-            self.out_stride = [8, 16, 32]
-            self.max_stride = 32
-            self.num_levels = 3
-            ## Head
-
-    cfg = Yolo11BaseConfig()
-    cfg.num_classes = 20
-    cls_dim = 128
-    reg_dim = 64
-    # Build a pred layer
-    pred = Yolo11DetPredLayer(cfg, cls_dim, reg_dim)
-
-    # Inference
-    cls_feats = [torch.randn(1, cls_dim, 80, 80),
-                 torch.randn(1, cls_dim, 40, 40),
-                 torch.randn(1, cls_dim, 20, 20),]
-    reg_feats = [torch.randn(1, reg_dim, 80, 80),
-                 torch.randn(1, reg_dim, 40, 40),
-                 torch.randn(1, reg_dim, 20, 20),]
-    t0 = time.time()
-    output = pred(cls_feats, reg_feats)
-    t1 = time.time()
-    print('Time: ', t1 - t0)
-    print('====== Pred output ======= ')
-    pred_cls = output["pred_cls"]
-    pred_reg = output["pred_reg"]
-    pred_box = output["pred_box"]
-    anchors  = output["anchors"]
-    
-    for level in range(cfg.num_levels):
-        print("- Level-{} : classification   -> {}".format(level, pred_cls[level].shape))
-        print("- Level-{} : delta regression -> {}".format(level, pred_reg[level].shape))
-        print("- Level-{} : bbox regression  -> {}".format(level, pred_box[level].shape))
-        print("- Level-{} : anchor boxes     -> {}".format(level, anchors[level].shape))
-
-    flops, params = profile(pred, inputs=(cls_feats, reg_feats, ), verbose=False)
-    print('==============================')
-    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))