yjh0410 1 year ago
parent
commit
992e3a01e8

+ 0 - 0
yolo/config/detr_config.py


+ 30 - 30
yolo/models/yolo11/README.md

@@ -1,56 +1,56 @@
-# YOLOv7:
-
-|    Model    |   Backbone    | Batch | Scale | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |
-|-------------|---------------|-------|-------|------------------------|-------------------|-------------------|--------------------|--------|
-| YOLOv7-Tiny | ELANNet-Tiny  | 8xb16 |  640  |         39.5           |       58.5        |   22.6            |   7.9              | [ckpt](https://github.com/yjh0410/RT-ODLab/releases/download/yolo_tutorial_ckpt/yolov7_tiny_coco.pth) |
-| YOLOv7      | ELANNet-Large | 8xb16 |  640  |         49.5           |       68.8        |   144.6           |   44.0             | [ckpt](https://github.com/yjh0410/RT-ODLab/releases/download/yolo_tutorial_ckpt/yolov7_coco.pth) |
-| YOLOv7-X    | ELANNet-Huge  |       |  640  |                        |                   |                   |                    |  |
-
-- For training, we train `YOLOv7` and `YOLOv7-Tiny` for 300 epochs on 8 GPUs.
-- For data augmentation, we use the [YOLOX-style](https://github.com/Megvii-BaseDetection/YOLOX) augmentation including the large scale jitter (LSJ), Mosaic augmentation and Mixup augmentation.
-- For optimizer, we use `AdamW` with weight decay 0.05 and per image learning rate 0.001 / 64.
-- For learning rate scheduler, we use Cosine decay scheduler.
-- For YOLOv7's structure, we replace the coupled head with the YOLOX-style decoupled head.
-- I think YOLOv7 relies on too many training tricks, such as `anchor box`, `AuxiliaryHead`, `RepConv`, `Mosaic9x` and so on, which makes the YOLO design overly complicated and goes against the development philosophy of the YOLO series; otherwise, why not just use the DETR family, since this amounts to little more than acceleration optimizations on top of DETR? Staying true to my own technical aesthetics, I implemented a cleaner and simpler YOLOv7, but without all of those tricks I could not reproduce the full performance, which is a pity.
-- I have no more GPUs to train my `YOLOv7-X`.
-
-## Train YOLOv7
+# YOLO11:
+
+- VOC
+
+|     Model   | Batch | Scale | AP<sup>val<br>0.5 | Weight |  Logs  |
+|-------------|-------|-------|-------------------|--------|--------|
+| YOLO11-S    | 1xb16 |  640  |      83.6     | [ckpt](https://github.com/yjh0410/YOLO-Tutorial-v5/releases/download/yolo_tutorial_ckpt/yolo11_s_voc.pth) | [log](https://github.com/yjh0410/YOLO-Tutorial-v5/releases/download/yolo_tutorial_ckpt/YOLO11-S-VOC.txt) |
+
+- COCO
+
+|    Model    | Batch | Scale | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |  Logs  |
+|-------------|-------|-------|------------------------|-------------------|-------------------|--------------------|--------|--------|
+| YOLO11-S    | 1xb16 |  640  |                    |               |   26.9            |   8.9             |  |  |
+
+
+
+## Train YOLO11
 ### Single GPU
-Taking training YOLOv7-Tiny on COCO as the example,
+Taking training YOLO11-S on COCO as the example,
 ```Shell
-python train.py --cuda -d coco --root path/to/coco -m yolov7_tiny -bs 16 -size 640 --wp_epoch 3 --max_epoch 300 --eval_epoch 10 --no_aug_epoch 20 --ema --fp16 --multi_scale 
+python train.py --cuda -d coco --root path/to/coco -m yolo11_s -bs 16 --fp16 
 ```
 
 ### Multi GPU
-Taking training YOLOv7-Tiny on COCO as the example,
+Taking training YOLO11-S on COCO as the example,
 ```Shell
-python -m torch.distributed.run --nproc_per_node=8 train.py --cuda -dist -d coco --root /data/datasets/ -m yolov7_tiny -bs 128 -size 640 --wp_epoch 3 --max_epoch 300  --eval_epoch 10 --no_aug_epoch 20 --ema --fp16 --sybn --multi_scale --save_folder weights/ 
+python -m torch.distributed.run --nproc_per_node=8 train.py --cuda --distributed -d coco --root path/to/coco -m yolo11_s -bs 256 --fp16 
 ```
 
-## Test YOLOv7
-Taking testing YOLOv7-Tiny on COCO-val as the example,
+## Test YOLO11
+Taking testing YOLO11-S on COCO-val as the example,
 ```Shell
-python test.py --cuda -d coco --root path/to/coco -m yolov7_tiny --weight path/to/yolov7_tiny.pth -size 640 -vt 0.4 --show 
+python test.py --cuda -d coco --root path/to/coco -m yolo11_s --weight path/to/yolo11.pth --show 
 ```
 
-## Evaluate YOLOv7
-Taking evaluating YOLOv7-Tiny on COCO-val as the example,
+## Evaluate YOLO11
+Taking evaluating YOLO11-S on COCO-val as the example,
 ```Shell
-python eval.py --cuda -d coco-val --root path/to/coco -m yolov7_tiny --weight path/to/yolov7_tiny.pth 
+python eval.py --cuda -d coco --root path/to/coco -m yolo11_s --weight path/to/yolo11.pth 
 ```
 
 ## Demo
 ### Detect with Image
 ```Shell
-python demo.py --mode image --path_to_img path/to/image_dirs/ --cuda -m yolov7_tiny --weight path/to/weight -size 640 -vt 0.4 --show
+python demo.py --mode image --path_to_img path/to/image_dirs/ --cuda -m yolo11_s --weight path/to/weight --show
 ```
 
 ### Detect with Video
 ```Shell
-python demo.py --mode video --path_to_vid path/to/video --cuda -m yolov7_tiny --weight path/to/weight -size 640 -vt 0.4 --show --gif
+python demo.py --mode video --path_to_vid path/to/video --cuda -m yolo11_s --weight path/to/weight --show --gif
 ```
 
 ### Detect with Camera
 ```Shell
-python demo.py --mode camera --cuda -m yolov7_tiny --weight path/to/weight -size 640 -vt 0.4 --show --gif
+python demo.py --mode camera --cuda -m yolo11_s --weight path/to/weight --show --gif
 ```

+ 8 - 50
yolo/models/yolo11/build.py

@@ -1,66 +1,24 @@
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-
-
-import torch
 import torch.nn as nn
 
-from .loss import build_criterion
-from .yolo11 import YOLOv7
+from .loss import SetCriterion
+from .yolo11 import Yolo11
 
 
 # build object detector
-def build_yolov7(args, cfg, device, num_classes=80, trainable=False, deploy=False):
-    print('==============================')
-    print('Build {} ...'.format(args.model.upper()))
-    
-    print('==============================')
-    print('Model Configuration: \n', cfg)
-    
+def build_yolo11(cfg, is_val=False):
     # -------------- Build YOLO --------------
-    model = YOLOv7(cfg                = cfg,
-                   device             = device, 
-                   num_classes        = num_classes,
-                   trainable          = trainable,
-                   conf_thresh        = args.conf_thresh,
-                   nms_thresh         = args.nms_thresh,
-                   topk               = args.topk,
-                   deploy             = deploy,
-                   no_multi_labels    = args.no_multi_labels,
-                   nms_class_agnostic = args.nms_class_agnostic
-                   )
+    model = Yolo11(cfg, is_val)
 
     # -------------- Initialize YOLO --------------
     for m in model.modules():
         if isinstance(m, nn.BatchNorm2d):
             m.eps = 1e-3
             m.momentum = 0.03    
-    # Init bias
-    init_prob = 0.01
-    bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
-    # obj pred
-    for obj_pred in model.obj_preds:
-        b = obj_pred.bias.view(1, -1)
-        b.data.fill_(bias_value.item())
-        obj_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-    # cls pred
-    for cls_pred in model.cls_preds:
-        b = cls_pred.bias.view(1, -1)
-        b.data.fill_(bias_value.item())
-        cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-    # reg pred
-    for reg_pred in model.reg_preds:
-        b = reg_pred.bias.view(-1, )
-        b.data.fill_(1.0)
-        reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-        w = reg_pred.weight
-        w.data.fill_(0.)
-        reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
-
-
+            
     # -------------- Build criterion --------------
     criterion = None
-    if trainable:
+    if is_val:
         # build criterion for training
-        criterion = build_criterion(args, cfg, device, num_classes)
-
+        criterion = SetCriterion(cfg)
+        
     return model, criterion
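
The BatchNorm override in `build_yolo11` above is just a sweep over `model.modules()`. A minimal standalone sketch of the same idiom, using a toy `nn.Sequential` as a stand-in for the real detector:

```python
import torch.nn as nn

# Toy stand-in for the detector; only the BatchNorm layers matter here.
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.BatchNorm2d(16),
    nn.SiLU(inplace=True),
)

# Same override as in build_yolo11: looser eps, slower running-stat momentum.
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.eps = 1e-3
        m.momentum = 0.03

print(model[1].eps, model[1].momentum)  # 0.001 0.03
```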

+ 125 - 158
yolo/models/yolo11/loss.py

@@ -1,212 +1,179 @@
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
-from .matcher import SimOTA
-from utils.box_ops import get_ious
+
+from utils.box_ops import bbox2dist, bbox_iou
 from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
 
+from .matcher import TaskAlignedAssigner
 
 
-class Criterion(object):
-    def __init__(self,
-                 args,
-                 cfg, 
-                 device, 
-                 num_classes=80):
-        self.args = args
+class SetCriterion(object):
+    def __init__(self, cfg):
+        # --------------- Basic parameters ---------------
         self.cfg = cfg
-        self.device = device
-        self.num_classes = num_classes
-        self.max_epoch = args.max_epoch
-        self.no_aug_epoch = args.no_aug_epoch
-        self.aux_bbox_loss = False
-        # loss weight
-        self.loss_obj_weight = cfg['loss_obj_weight']
-        self.loss_cls_weight = cfg['loss_cls_weight']
-        self.loss_box_weight = cfg['loss_box_weight']
-        # matcher
-        matcher_config = cfg['matcher']
-        self.matcher = SimOTA(
-            num_classes=num_classes,
-            center_sampling_radius=matcher_config['center_sampling_radius'],
-            topk_candidate=matcher_config['topk_candicate']
-            )
-
-
-    def loss_objectness(self, pred_obj, gt_obj):
-        loss_obj = F.binary_cross_entropy_with_logits(pred_obj, gt_obj, reduction='none')
-
-        return loss_obj
-    
-
-    def loss_classes(self, pred_cls, gt_label):
-        loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_label, reduction='none')
+        self.reg_max = cfg.reg_max
+        self.num_classes = cfg.num_classes
+        # --------------- Loss config ---------------
+        self.loss_cls_weight = cfg.loss_cls
+        self.loss_box_weight = cfg.loss_box
+        self.loss_dfl_weight = cfg.loss_dfl
+        # --------------- Matcher config ---------------
+        self.matcher = TaskAlignedAssigner(num_classes     = cfg.num_classes,
+                                           topk_candidates = cfg.tal_topk_candidates,
+                                           alpha           = cfg.tal_alpha,
+                                           beta            = cfg.tal_beta
+                                           )
+
+    def loss_classes(self, pred_cls, gt_score):
+        # compute bce loss
+        loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_score, reduction='none')
 
         return loss_cls
-
-
-    def loss_bboxes(self, pred_box, gt_box):
+    
+    def loss_bboxes(self, pred_box, gt_box, bbox_weight):
         # regression loss
-        ious = get_ious(pred_box, gt_box, "xyxy", 'giou')
-        loss_box = 1.0 - ious
+        ious = bbox_iou(pred_box, gt_box, xywh=False, CIoU=True)
+        loss_box = (1.0 - ious.squeeze(-1)) * bbox_weight
 
         return loss_box
+    
+    def loss_dfl(self, pred_reg, gt_box, anchor, stride, bbox_weight=None):
+        # rescale coords by stride
+        gt_box_s = gt_box / stride
+        anchor_s = anchor / stride
+
+        # compute deltas
+        gt_ltrb_s = bbox2dist(anchor_s, gt_box_s, self.reg_max - 1)
+
+        gt_left = gt_ltrb_s.to(torch.long)
+        gt_right = gt_left + 1
+
+        weight_left = gt_right.to(torch.float) - gt_ltrb_s
+        weight_right = 1 - weight_left
+
+        # loss left
+        loss_left = F.cross_entropy(
+            pred_reg.view(-1, self.reg_max),
+            gt_left.view(-1),
+            reduction='none').view(gt_left.shape) * weight_left
+        # loss right
+        loss_right = F.cross_entropy(
+            pred_reg.view(-1, self.reg_max),
+            gt_right.view(-1),
+            reduction='none').view(gt_left.shape) * weight_right
+
+        loss_dfl = (loss_left + loss_right).mean(-1)
+        
+        if bbox_weight is not None:
+            loss_dfl *= bbox_weight
 
+        return loss_dfl
 
-    def loss_bboxes_aux(self, pred_reg, gt_box, anchors, stride_tensors):
-        # xyxy -> cxcy&bwbh
-        gt_cxcy = (gt_box[..., :2] + gt_box[..., 2:]) * 0.5
-        gt_bwbh = gt_box[..., 2:] - gt_box[..., :2]
-        # encode gt box
-        gt_cxcy_encode = (gt_cxcy - anchors) / stride_tensors
-        gt_bwbh_encode = torch.log(gt_bwbh / stride_tensors)
-        gt_box_encode = torch.cat([gt_cxcy_encode, gt_bwbh_encode], dim=-1)
-        # l1 loss
-        loss_box_aux = F.l1_loss(pred_reg, gt_box_encode, reduction='none')
-
-        return loss_box_aux
-
-
-    def __call__(self, outputs, targets, epoch=0):        
+    def __call__(self, outputs, targets):        
         """
-            outputs['pred_obj']: List(Tensor) [B, M, 1]
             outputs['pred_cls']: List(Tensor) [B, M, C]
+            outputs['pred_reg']: List(Tensor) [B, M, 4*(reg_max+1)]
             outputs['pred_box']: List(Tensor) [B, M, 4]
-            outputs['pred_box']: List(Tensor) [B, M, 4]
+            outputs['anchors']: List(Tensor) [M, 2]
             outputs['strides']: List(Int) [8, 16, 32] output stride
+            outputs['stride_tensors']: List(Tensor) [M, 1]
             targets: (List) [dict{'boxes': [...], 
                                  'labels': [...], 
                                  'orig_size': ...}, ...]
         """
-        bs = outputs['pred_cls'][0].shape[0]
-        device = outputs['pred_cls'][0].device
-        fpn_strides = outputs['strides']
-        anchors = outputs['anchors']
         # preds: [B, M, C]
-        obj_preds = torch.cat(outputs['pred_obj'], dim=1)
         cls_preds = torch.cat(outputs['pred_cls'], dim=1)
+        reg_preds = torch.cat(outputs['pred_reg'], dim=1)
         box_preds = torch.cat(outputs['pred_box'], dim=1)
-
-        # label assignment
-        cls_targets = []
-        box_targets = []
-        obj_targets = []
+        bs, num_anchors = cls_preds.shape[:2]
+        device = cls_preds.device
+        anchors = torch.cat(outputs['anchors'], dim=0)
+        
+        # --------------- label assignment ---------------
+        gt_score_targets = []
+        gt_bbox_targets = []
         fg_masks = []
-
         for batch_idx in range(bs):
-            tgt_labels = targets[batch_idx]["labels"].to(device)
-            tgt_bboxes = targets[batch_idx]["boxes"].to(device)
+            tgt_labels = targets[batch_idx]["labels"].to(device)     # [Mp,]
+            tgt_boxs = targets[batch_idx]["boxes"].to(device)        # [Mp, 4]
 
             # check target
-            if len(tgt_labels) == 0 or tgt_bboxes.max().item() == 0.:
-                num_anchors = sum([ab.shape[0] for ab in anchors])
+            if len(tgt_labels) == 0 or tgt_boxs.max().item() == 0.:
                 # There is no valid gt
-                cls_target = obj_preds.new_zeros((0, self.num_classes))
-                box_target = obj_preds.new_zeros((0, 4))
-                obj_target = obj_preds.new_zeros((num_anchors, 1))
-                fg_mask = obj_preds.new_zeros(num_anchors).bool()
+                fg_mask  = cls_preds.new_zeros(1, num_anchors).bool()               #[1, M,]
+                gt_score = cls_preds.new_zeros((1, num_anchors, self.num_classes)) #[1, M, C]
+                gt_box   = cls_preds.new_zeros((1, num_anchors, 4))                  #[1, M, 4]
             else:
+                tgt_labels = tgt_labels[None, :, None]      # [1, Mp, 1]
+                tgt_boxs = tgt_boxs[None]                   # [1, Mp, 4]
                 (
-                    fg_mask,
-                    assigned_labels,
-                    assigned_ious,
-                    assigned_indexs
+                    _,
+                    gt_box,     # [1, M, 4]
+                    gt_score,   # [1, M, C]
+                    fg_mask,    # [1, M,]
+                    _
                 ) = self.matcher(
-                    fpn_strides = fpn_strides,
-                    anchors = anchors,
-                    pred_obj = obj_preds[batch_idx],
-                    pred_cls = cls_preds[batch_idx], 
-                    pred_box = box_preds[batch_idx],
-                    tgt_labels = tgt_labels,
-                    tgt_bboxes = tgt_bboxes
+                    pd_scores = cls_preds[batch_idx:batch_idx+1].detach().sigmoid(), 
+                    pd_bboxes = box_preds[batch_idx:batch_idx+1].detach(),
+                    anc_points = anchors,
+                    gt_labels = tgt_labels,
+                    gt_bboxes = tgt_boxs
                     )
-
-                obj_target = fg_mask.unsqueeze(-1)
-                cls_target = F.one_hot(assigned_labels.long(), self.num_classes)
-                cls_target = cls_target * assigned_ious.unsqueeze(-1)
-                box_target = tgt_bboxes[assigned_indexs]
-
-            cls_targets.append(cls_target)
-            box_targets.append(box_target)
-            obj_targets.append(obj_target)
+            gt_score_targets.append(gt_score)
+            gt_bbox_targets.append(gt_box)
             fg_masks.append(fg_mask)
 
-        cls_targets = torch.cat(cls_targets, 0)
-        box_targets = torch.cat(box_targets, 0)
-        obj_targets = torch.cat(obj_targets, 0)
-        fg_masks = torch.cat(fg_masks, 0)
-        num_fgs = fg_masks.sum()
-
+        # List[B, 1, M, C] -> Tensor[B, M, C] -> Tensor[BM, C]
+        fg_masks = torch.cat(fg_masks, 0).view(-1)                                    # [BM,]
+        gt_score_targets = torch.cat(gt_score_targets, 0).view(-1, self.num_classes)  # [BM, C]
+        gt_bbox_targets = torch.cat(gt_bbox_targets, 0).view(-1, 4)                   # [BM, 4]
+        num_fgs = gt_score_targets.sum()
+        
+        # Average loss normalizer across all the GPUs
         if is_dist_avail_and_initialized():
             torch.distributed.all_reduce(num_fgs)
         num_fgs = (num_fgs / get_world_size()).clamp(1.0)
 
-        # ------------------ Objectness loss ------------------
-        loss_obj = self.loss_objectness(obj_preds.view(-1, 1), obj_targets.float())
-        loss_obj = loss_obj.sum() / num_fgs
-        
         # ------------------ Classification loss ------------------
-        cls_preds_pos = cls_preds.view(-1, self.num_classes)[fg_masks]
-        loss_cls = self.loss_classes(cls_preds_pos, cls_targets)
+        cls_preds = cls_preds.view(-1, self.num_classes)
+        loss_cls = self.loss_classes(cls_preds, gt_score_targets)
         loss_cls = loss_cls.sum() / num_fgs
 
         # ------------------ Regression loss ------------------
         box_preds_pos = box_preds.view(-1, 4)[fg_masks]
-        loss_box = self.loss_bboxes(box_preds_pos, box_targets)
+        box_targets_pos = gt_bbox_targets.view(-1, 4)[fg_masks]
+        bbox_weight = gt_score_targets[fg_masks].sum(-1)
+        loss_box = self.loss_bboxes(box_preds_pos, box_targets_pos, bbox_weight)
         loss_box = loss_box.sum() / num_fgs
 
+        # ------------------ Distribution focal loss  ------------------
+        ## process anchors
+        anchors = anchors[None].repeat(bs, 1, 1).view(-1, 2)
+        ## process stride tensors
+        strides = torch.cat(outputs['stride_tensors'], dim=0)
+        strides = strides.unsqueeze(0).repeat(bs, 1, 1).view(-1, 1)
+        ## fg preds
+        reg_preds_pos = reg_preds.view(-1, 4*self.reg_max)[fg_masks]
+        anchors_pos = anchors[fg_masks]
+        strides_pos = strides[fg_masks]
+        ## compute dfl
+        loss_dfl = self.loss_dfl(reg_preds_pos, box_targets_pos, anchors_pos, strides_pos, bbox_weight)
+        loss_dfl = loss_dfl.sum() / num_fgs
+
         # total loss
-        losses = self.loss_obj_weight * loss_obj + \
-                 self.loss_cls_weight * loss_cls + \
-                 self.loss_box_weight * loss_box
-
-        # ------------------ Aux regression loss ------------------
-        loss_box_aux = None
-        if epoch >= (self.max_epoch - self.no_aug_epoch - 1):
-            ## reg_preds
-            reg_preds = torch.cat(outputs['pred_reg'], dim=1)
-            reg_preds_pos = reg_preds.view(-1, 4)[fg_masks]
-            ## anchor tensors
-            anchors_tensors = torch.cat(outputs['anchors'], dim=0)[None].repeat(bs, 1, 1)
-            anchors_tensors_pos = anchors_tensors.view(-1, 2)[fg_masks]
-            ## stride tensors
-            stride_tensors = torch.cat(outputs['stride_tensors'], dim=0)[None].repeat(bs, 1, 1)
-            stride_tensors_pos = stride_tensors.view(-1, 1)[fg_masks]
-            ## aux loss
-            loss_box_aux = self.loss_bboxes_aux(reg_preds_pos, box_targets, anchors_tensors_pos, stride_tensors_pos)
-            loss_box_aux = loss_box_aux.sum() / num_fgs
-
-            losses += loss_box_aux
-
-        # Loss dict
-        if loss_box_aux is None:
-            loss_dict = dict(
-                    loss_obj = loss_obj,
-                    loss_cls = loss_cls,
-                    loss_box = loss_box,
-                    losses = losses
-            )
-        else:
-            loss_dict = dict(
-                    loss_obj = loss_obj,
-                    loss_cls = loss_cls,
-                    loss_box = loss_box,
-                    loss_box_aux = loss_box_aux,
-                    losses = losses
-                    )
+        losses = loss_cls * self.loss_cls_weight + \
+                 loss_box * self.loss_box_weight + \
+                 loss_dfl * self.loss_dfl_weight
+        loss_dict = dict(
+                loss_cls = loss_cls,
+                loss_box = loss_box,
+                loss_dfl = loss_dfl,
+                losses = losses
+        )
 
         return loss_dict
     
 
-def build_criterion(args, cfg, device, num_classes):
-    criterion = Criterion(
-        args=args,
-        cfg=cfg,
-        device=device,
-        num_classes=num_classes
-        )
-
-    return criterion
-
-
 if __name__ == "__main__":
     pass
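
For reference, `loss_dfl` above spreads each continuous ltrb regression target over its two nearest integer bins and applies a cross-entropy to each, weighted by proximity. A self-contained sketch on dummy tensors (`bbox2dist` from `utils.box_ops` is not part of this commit, so a simplified stand-in is used; shapes and values are arbitrary):

```python
import torch
import torch.nn.functional as F

def bbox2dist_sketch(anchor, box, dist_max):
    # Simplified stand-in for utils.box_ops.bbox2dist: (l, t, r, b) distances
    # from anchor centers to box edges, clamped below dist_max.
    lt = anchor - box[..., :2]
    rb = box[..., 2:] - anchor
    return torch.cat([lt, rb], dim=-1).clamp_(0, dist_max - 0.01)

reg_max  = 16
num_pos  = 3
pred_reg = torch.randn(num_pos, 4 * reg_max)                      # per-bin logits
anchor   = torch.tensor([[8., 8.], [24., 8.], [8., 24.]])         # anchor centers
gt_box   = torch.tensor([[0., 0., 32., 32.]]).repeat(num_pos, 1)  # matched gt boxes
stride   = torch.full((num_pos, 1), 8.)

# Rescale to feature-map units, then build the two integer bin targets per side.
gt_ltrb  = bbox2dist_sketch(anchor / stride, gt_box / stride, reg_max - 1)
gt_left  = gt_ltrb.long()
gt_right = gt_left + 1
w_left   = gt_right.float() - gt_ltrb
w_right  = 1.0 - w_left

loss_left  = F.cross_entropy(pred_reg.view(-1, reg_max), gt_left.view(-1),
                             reduction='none').view(gt_left.shape) * w_left
loss_right = F.cross_entropy(pred_reg.view(-1, reg_max), gt_right.view(-1),
                             reduction='none').view(gt_left.shape) * w_right
loss_dfl = (loss_left + loss_right).mean(-1)   # one value per positive anchor
print(loss_dfl.shape)                          # torch.Size([3])
```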

+ 192 - 177
yolo/models/yolo11/matcher.py

@@ -1,187 +1,202 @@
-# ---------------------------------------------------------------------
-# Copyright (c) Megvii Inc. All rights reserved.
-# ---------------------------------------------------------------------
-
-
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
-from utils.box_ops import *
+from utils.box_ops import bbox_iou
 
 
-class SimOTA(object):
+# -------------------------- Task Aligned Assigner --------------------------
+class TaskAlignedAssigner(nn.Module):
     """
-        This code referenced to https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/models/yolo_head.py
+        This code is adapted from https://github.com/ultralytics/ultralytics
     """
-    def __init__(self, num_classes, center_sampling_radius, topk_candidate ):
+    def __init__(self,
+                 num_classes     = 80,
+                 topk_candidates = 10,
+                 alpha           = 0.5,
+                 beta            = 6.0, 
+                 eps             = 1e-9):
+        super(TaskAlignedAssigner, self).__init__()
+        self.topk_candidates = topk_candidates
         self.num_classes = num_classes
-        self.center_sampling_radius = center_sampling_radius
-        self.topk_candidate = topk_candidate
-
+        self.bg_idx = num_classes
+        self.alpha = alpha
+        self.beta = beta
+        self.eps = eps
 
     @torch.no_grad()
-    def __call__(self, 
-                 fpn_strides, 
-                 anchors, 
-                 pred_obj, 
-                 pred_cls, 
-                 pred_box, 
-                 tgt_labels,
-                 tgt_bboxes):
-        # [M,]
-        strides_tensor = torch.cat([torch.ones_like(anchor_i[:, 0]) * stride_i
-                                for stride_i, anchor_i in zip(fpn_strides, anchors)], dim=-1)
-        # List[F, M, 2] -> [M, 2]
-        anchors = torch.cat(anchors, dim=0)
-        num_anchor = anchors.shape[0]        
-        num_gt = len(tgt_labels)
-
-        # ----------------------- Find inside points -----------------------
-        fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(
-            tgt_bboxes, anchors, strides_tensor, num_anchor, num_gt)
-        obj_preds = pred_obj[fg_mask].float()   # [Mp, 1]
-        cls_preds = pred_cls[fg_mask].float()   # [Mp, C]
-        box_preds = pred_box[fg_mask].float()   # [Mp, 4]
-
-        # ----------------------- Reg cost -----------------------
-        pair_wise_ious, _ = box_iou(tgt_bboxes, box_preds)      # [N, Mp]
-        reg_cost = -torch.log(pair_wise_ious + 1e-8)            # [N, Mp]
-
-        # ----------------------- Cls cost -----------------------
-        with torch.cuda.amp.autocast(enabled=False):
-            # [Mp, C]
-            score_preds = torch.sqrt(obj_preds.sigmoid_()* cls_preds.sigmoid_())
-            # [N, Mp, C]
-            score_preds = score_preds.unsqueeze(0).repeat(num_gt, 1, 1)
-            # prepare cls_target
-            cls_targets = F.one_hot(tgt_labels.long(), self.num_classes).float()
-            cls_targets = cls_targets.unsqueeze(1).repeat(1, score_preds.size(1), 1)
-            # [N, Mp]
-            cls_cost = F.binary_cross_entropy(score_preds, cls_targets, reduction="none").sum(-1)
-        del score_preds
-
-        #----------------------- Dynamic K-Matching -----------------------
-        cost_matrix = (
-            cls_cost
-            + 3.0 * reg_cost
-            + 100000.0 * (~is_in_boxes_and_center)
-        ) # [N, Mp]
-
-        (
-            assigned_labels,         # [num_fg,]
-            assigned_ious,           # [num_fg,]
-            assigned_indexs,         # [num_fg,]
-        ) = self.dynamic_k_matching(
-            cost_matrix,
-            pair_wise_ious,
-            tgt_labels,
-            num_gt,
-            fg_mask
-            )
-        del cls_cost, cost_matrix, pair_wise_ious, reg_cost
-
-        return fg_mask, assigned_labels, assigned_ious, assigned_indexs
-
-
-    def get_in_boxes_info(
-        self,
-        gt_bboxes,   # [N, 4]
-        anchors,     # [M, 2]
-        strides,     # [M,]
-        num_anchors, # M
-        num_gt,      # N
-        ):
-        # anchor center
-        x_centers = anchors[:, 0]
-        y_centers = anchors[:, 1]
-
-        # [M,] -> [1, M] -> [N, M]
-        x_centers = x_centers.unsqueeze(0).repeat(num_gt, 1)
-        y_centers = y_centers.unsqueeze(0).repeat(num_gt, 1)
-
-        # [N,] -> [N, 1] -> [N, M]
-        gt_bboxes_l = gt_bboxes[:, 0].unsqueeze(1).repeat(1, num_anchors) # x1
-        gt_bboxes_t = gt_bboxes[:, 1].unsqueeze(1).repeat(1, num_anchors) # y1
-        gt_bboxes_r = gt_bboxes[:, 2].unsqueeze(1).repeat(1, num_anchors) # x2
-        gt_bboxes_b = gt_bboxes[:, 3].unsqueeze(1).repeat(1, num_anchors) # y2
-
-        b_l = x_centers - gt_bboxes_l
-        b_r = gt_bboxes_r - x_centers
-        b_t = y_centers - gt_bboxes_t
-        b_b = gt_bboxes_b - y_centers
-        bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)
-
-        is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
-        is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
-        # in fixed center
-        center_radius = self.center_sampling_radius
-
-        # [N, 2]
-        gt_centers = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) * 0.5
-        
-        # [1, M]
-        center_radius_ = center_radius * strides.unsqueeze(0)
-
-        gt_bboxes_l = gt_centers[:, 0].unsqueeze(1).repeat(1, num_anchors) - center_radius_ # x1
-        gt_bboxes_t = gt_centers[:, 1].unsqueeze(1).repeat(1, num_anchors) - center_radius_ # y1
-        gt_bboxes_r = gt_centers[:, 0].unsqueeze(1).repeat(1, num_anchors) + center_radius_ # x2
-        gt_bboxes_b = gt_centers[:, 1].unsqueeze(1).repeat(1, num_anchors) + center_radius_ # y2
-
-        c_l = x_centers - gt_bboxes_l
-        c_r = gt_bboxes_r - x_centers
-        c_t = y_centers - gt_bboxes_t
-        c_b = gt_bboxes_b - y_centers
-        center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
-        is_in_centers = center_deltas.min(dim=-1).values > 0.0
-        is_in_centers_all = is_in_centers.sum(dim=0) > 0
-
-        # in boxes and in centers
-        is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
-
-        is_in_boxes_and_center = (
-            is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
-        )
-        return is_in_boxes_anchor, is_in_boxes_and_center
-    
+    def forward(self,
+                pd_scores,
+                pd_bboxes,
+                anc_points,
+                gt_labels,
+                gt_bboxes):
+        self.bs = pd_scores.size(0)
+        self.n_max_boxes = gt_bboxes.size(1)
+
+        mask_pos, align_metric, overlaps = self.get_pos_mask(
+            pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points)
+
+        target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(
+            mask_pos, overlaps, self.n_max_boxes)
+
+        # Assigned target
+        target_labels, target_bboxes, target_scores = self.get_targets(
+            gt_labels, gt_bboxes, target_gt_idx, fg_mask)
+
+        # normalize
+        align_metric *= mask_pos
+        pos_align_metrics = align_metric.amax(axis=-1, keepdim=True)  # b, max_num_obj
+        pos_overlaps = (overlaps * mask_pos).amax(axis=-1, keepdim=True)  # b, max_num_obj
+        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
+        target_scores = target_scores * norm_align_metric
+
+        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
+
+    def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points):
+        # get in_gts mask, (b, max_num_obj, h*w)
+        mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)
+        # get anchor_align metric, (b, max_num_obj, h*w)
+        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts)
+        # get topk_metric mask, (b, max_num_obj, h*w)
+        mask_topk = self.select_topk_candidates(align_metric)
+        # merge all mask to a final mask, (b, max_num_obj, h*w)
+        mask_pos = mask_topk * mask_in_gts
+
+        return mask_pos, align_metric, overlaps
+
+    def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts):
+        """Compute alignment metric given predicted and ground truth bounding boxes."""
+        na = pd_bboxes.shape[-2]
+        mask_in_gts = mask_in_gts.bool()  # b, max_num_obj, h*w
+        overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
+        bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)
+
+        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
+        ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes)  # b, max_num_obj
+        ind[1] = gt_labels.squeeze(-1)  # b, max_num_obj
+        # Get the scores of each grid for each gt cls
+        bbox_scores[mask_in_gts] = pd_scores[ind[0], :, ind[1]][mask_in_gts]  # b, max_num_obj, h*w
+
+        # (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
+        pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_in_gts]
+        gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_in_gts]
+        overlaps[mask_in_gts] = bbox_iou(gt_boxes, pd_boxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
+
+        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
+        return align_metric, overlaps
+
+    def select_topk_candidates(self, metrics, largest=True):
+        """
+        Args:
+            metrics: (b, max_num_obj, h*w).
+            largest: if True, select the candidates with the largest metric values.
+        """
+        # (b, max_num_obj, topk)
+        topk_metrics, topk_idxs = torch.topk(metrics, self.topk_candidates, dim=-1, largest=largest)
+        topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
+        # (b, max_num_obj, topk)
+        topk_idxs.masked_fill_(~topk_mask, 0)
+
+        # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
+        count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
+        ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
+        for k in range(self.topk_candidates):
+            # Expand topk_idxs for each value of k and add 1 at the specified positions
+            count_tensor.scatter_add_(-1, topk_idxs[:, :, k:k + 1], ones)
+        # count_tensor.scatter_add_(-1, topk_idxs, torch.ones_like(topk_idxs, dtype=torch.int8, device=topk_idxs.device))
+        # Filter invalid bboxes
+        count_tensor.masked_fill_(count_tensor > 1, 0)
+
+        return count_tensor.to(metrics.dtype)
+
+    def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
+        # Assigned target labels, (b, 1)
+        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
+        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (b, h*w)
+        target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)
+
+        # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4)
+        target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx]
+
+        # Assigned target scores
+        target_labels.clamp_(0)
+
+        # 10x faster than F.one_hot()
+        target_scores = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.num_classes),
+                                    dtype=torch.int64,
+                                    device=target_labels.device)  # (b, h*w, 80)
+        target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)
+
+        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)  # (b, h*w, 80)
+        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)
+
+        return target_labels, target_bboxes, target_scores
     
-    def dynamic_k_matching(
-        self, 
-        cost, 
-        pair_wise_ious, 
-        gt_classes, 
-        num_gt, 
-        fg_mask
-        ):
-        # Dynamic K
-        # ---------------------------------------------------------------
-        matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
-
-        ious_in_boxes_matrix = pair_wise_ious
-        n_candidate_k = min(self.topk_candidate, ious_in_boxes_matrix.size(1))
-        topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)
-        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
-        dynamic_ks = dynamic_ks.tolist()
-        for gt_idx in range(num_gt):
-            _, pos_idx = torch.topk(
-                cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
-            )
-            matching_matrix[gt_idx][pos_idx] = 1
-
-        del topk_ious, dynamic_ks, pos_idx
-
-        anchor_matching_gt = matching_matrix.sum(0)
-        if (anchor_matching_gt > 1).sum() > 0:
-            _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
-            matching_matrix[:, anchor_matching_gt > 1] *= 0
-            matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
-        fg_mask_inboxes = matching_matrix.sum(0) > 0
-
-        fg_mask[fg_mask.clone()] = fg_mask_inboxes
-
-        assigned_indexs = matching_matrix[:, fg_mask_inboxes].argmax(0)
-        assigned_labels = gt_classes[assigned_indexs]
-
-        assigned_ious = (matching_matrix * pair_wise_ious).sum(0)[
-            fg_mask_inboxes
-        ]
-        return assigned_labels, assigned_ious, assigned_indexs
-    
+
+# -------------------------- Basic Functions --------------------------
+def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
+    """Select the anchors whose centers lie inside the gt boxes.
+    Args:
+        xy_centers (Tensor): shape(num_total_anchors, 2)
+        gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
+    Return:
+        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    """
+    n_anchors = xy_centers.size(0)
+    bs, n_max_boxes, _ = gt_bboxes.size()
+    _gt_bboxes = gt_bboxes.reshape([-1, 4])
+    xy_centers = xy_centers.unsqueeze(0).repeat(bs * n_max_boxes, 1, 1)
+    gt_bboxes_lt = _gt_bboxes[:, 0:2].unsqueeze(1).repeat(1, n_anchors, 1)
+    gt_bboxes_rb = _gt_bboxes[:, 2:4].unsqueeze(1).repeat(1, n_anchors, 1)
+    b_lt = xy_centers - gt_bboxes_lt
+    b_rb = gt_bboxes_rb - xy_centers
+    bbox_deltas = torch.cat([b_lt, b_rb], dim=-1)
+    bbox_deltas = bbox_deltas.reshape([bs, n_max_boxes, n_anchors, -1])
+    return (bbox_deltas.min(axis=-1)[0] > eps).to(gt_bboxes.dtype)
+
+def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
+    """If an anchor box is assigned to multiple gts,
+        only the gt with the highest IoU is kept.
+    Args:
+        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+        overlaps (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    Return:
+        target_gt_idx (Tensor): shape(bs, num_total_anchors)
+        fg_mask (Tensor): shape(bs, num_total_anchors)
+        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    """
+    fg_mask = mask_pos.sum(-2)
+    if fg_mask.max() > 1:  # one anchor is assigned to multiple gt_bboxes
+        mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)
+        max_overlaps_idx = overlaps.argmax(1)  # (b, h*w)
+
+        is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
+        is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
+
+        mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()  # (b, n_max_boxes, h*w)
+        fg_mask = mask_pos.sum(-2)
+    # Find which gt (index) each grid cell is assigned to
+    target_gt_idx = mask_pos.argmax(-2)  # (b, h*w)
+
+    return target_gt_idx, fg_mask, mask_pos
+
+def iou_calculator(box1, box2, eps=1e-9):
+    """Calculate iou for batch
+    Args:
+        box1 (Tensor): shape(bs, n_max_boxes, 1, 4)
+        box2 (Tensor): shape(bs, 1, num_total_anchors, 4)
+    Return:
+        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    """
+    box1 = box1.unsqueeze(2)  # [N, M1, 4] -> [N, M1, 1, 4]
+    box2 = box2.unsqueeze(1)  # [N, M2, 4] -> [N, 1, M2, 4]
+    px1y1, px2y2 = box1[:, :, :, 0:2], box1[:, :, :, 2:4]
+    gx1y1, gx2y2 = box2[:, :, :, 0:2], box2[:, :, :, 2:4]
+    x1y1 = torch.maximum(px1y1, gx1y1)
+    x2y2 = torch.minimum(px2y2, gx2y2)
+    overlap = (x2y2 - x1y1).clip(0).prod(-1)
+    area1 = (px2y2 - px1y1).clip(0).prod(-1)
+    area2 = (gx2y2 - gx1y1).clip(0).prod(-1)
+    union = area1 + area2 - overlap + eps
+
+    return overlap / union
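
The heart of `TaskAlignedAssigner` is the alignment metric `score(gt class)^alpha * IoU(pred, gt)^beta`, followed by a per-gt top-k selection. A small sketch of that step on random tensors (the fancy-indexing trick mirrors `get_box_metrics` above; all values are dummies):

```python
import torch

alpha, beta, topk = 0.5, 6.0, 10
bs, num_gts, num_anchors, num_classes = 1, 2, 100, 80

cls_scores = torch.rand(bs, num_anchors, num_classes)       # post-sigmoid class scores
ious       = torch.rand(bs, num_gts, num_anchors)           # IoU(gt box, pred box)
gt_labels  = torch.randint(0, num_classes, (bs, num_gts))   # class id of each gt

# Score of every anchor for the class of each gt: (bs, num_gts, num_anchors)
batch_idx   = torch.arange(bs).view(-1, 1)
bbox_scores = cls_scores[batch_idx, :, gt_labels]

# Task-aligned metric and per-gt top-k candidate anchors
align_metric = bbox_scores.pow(alpha) * ious.pow(beta)
topk_metrics, topk_idxs = torch.topk(align_metric, topk, dim=-1)

print(align_metric.shape, topk_idxs.shape)  # torch.Size([1, 2, 100]) torch.Size([1, 2, 10])
```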

+ 117 - 310
yolo/models/yolo11/modules.py

@@ -1,338 +1,145 @@
-import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
+from typing import List
 
 
-# ---------------------------- 2D CNN ----------------------------
-class SiLU(nn.Module):
-    """export-friendly version of nn.SiLU()"""
-
-    @staticmethod
-    def forward(x):
-        return x * torch.sigmoid(x)
-
-def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
-    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
-
-    return conv
-
-def get_activation(act_type=None):
-    if act_type == 'relu':
-        return nn.ReLU(inplace=True)
-    elif act_type == 'lrelu':
-        return nn.LeakyReLU(0.1, inplace=True)
-    elif act_type == 'mish':
-        return nn.Mish(inplace=True)
-    elif act_type == 'silu':
-        return nn.SiLU(inplace=True)
-
-def get_norm(norm_type, dim):
-    if norm_type == 'BN':
-        return nn.BatchNorm2d(dim)
-    elif norm_type == 'GN':
-        return nn.GroupNorm(num_groups=32, num_channels=dim)
-
-## Basic conv layer
-class Conv(nn.Module):
+# ----------------- CNN modules -----------------
+class ConvModule(nn.Module):
     def __init__(self, 
-                 c1,                   # in channels
-                 c2,                   # out channels 
-                 k=1,                  # kernel size 
-                 p=0,                  # padding
-                 s=1,                  # stride
-                 d=1,                  # dilation
-                 act_type='lrelu',     # activation
-                 norm_type='BN',       # normalization
-                 depthwise=False):
-        super(Conv, self).__init__()
-        convs = []
-        add_bias = False if norm_type else True
-        if depthwise:
-            convs.append(get_conv2d(c1, c1, k=k, p=p, s=s, d=d, g=c1, bias=add_bias))
-            # depthwise conv
-            if norm_type:
-                convs.append(get_norm(norm_type, c1))
-            if act_type:
-                convs.append(get_activation(act_type))
-            # pointwise conv
-            convs.append(get_conv2d(c1, c2, k=1, p=0, s=1, d=d, g=1, bias=add_bias))
-            if norm_type:
-                convs.append(get_norm(norm_type, c2))
-            if act_type:
-                convs.append(get_activation(act_type))
-
-        else:
-            convs.append(get_conv2d(c1, c2, k=k, p=p, s=s, d=d, g=1, bias=add_bias))
-            if norm_type:
-                convs.append(get_norm(norm_type, c2))
-            if act_type:
-                convs.append(get_activation(act_type))
-            
-        self.convs = nn.Sequential(*convs)
-
+                 in_dim,        # in channels
+                 out_dim,       # out channels 
+                 kernel_size=1, # kernel size 
+                 stride=1,      # stride
+                 groups=1,      # groups
+                 use_act: bool = True,
+                ):
+        super(ConvModule, self).__init__()
+        self.conv = nn.Conv2d(in_dim, out_dim, kernel_size, padding=kernel_size//2, stride=stride, groups=groups, bias=False)
+        self.norm = nn.BatchNorm2d(out_dim)
+        self.act = nn.SiLU(inplace=True) if use_act else nn.Identity()
 
     def forward(self, x):
-        return self.convs(x)
-
-
-# ---------------------------- YOLOv7 Modules ----------------------------
-## ELAN-Block proposed by YOLOv7
-class ELANBlock(nn.Module):
-    def __init__(self, in_dim, out_dim, squeeze_ratio=0.5, branch_depth :int=2, act_type='silu', norm_type='BN', depthwise=False):
-        super(ELANBlock, self).__init__()
-        inter_dim = int(in_dim * squeeze_ratio)
-        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
-        self.cv2 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
-        self.cv3 = nn.Sequential(*[
-            Conv(inter_dim, inter_dim, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-            for _ in range(round(branch_depth))
-        ])
-        self.cv4 = nn.Sequential(*[
-            Conv(inter_dim, inter_dim, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-            for _ in range(round(branch_depth))
-        ])
-
-        self.out = Conv(inter_dim*4, out_dim, k=1, act_type=act_type, norm_type=norm_type)
-
-
+        return self.act(self.norm(self.conv(x)))
+
+class Bottleneck(nn.Module):
+    def __init__(self,
+                 in_dim      :int,
+                 out_dim     :int,
+                 kernel_size :List = [3, 3],
+                 shortcut    :bool = False,
+                 expansion   :float = 0.5,
+                 ) -> None:
+        super(Bottleneck, self).__init__()
+        # ----------------- Network setting -----------------
+        inter_dim = int(out_dim * expansion)
+        self.cv1 = ConvModule(in_dim,  inter_dim, kernel_size=kernel_size[0], stride=1)
+        self.cv2 = ConvModule(inter_dim, out_dim, kernel_size=kernel_size[1], stride=1)
+        self.shortcut = shortcut and in_dim == out_dim
 
     def forward(self, x):
-        x1 = self.cv1(x)
-        x2 = self.cv2(x)
-        x3 = self.cv3(x2)
-        x4 = self.cv4(x3)
-        out = self.out(torch.cat([x1, x2, x3, x4], dim=1))
-
-        return out
-
-## PaFPN's ELAN-Block proposed by YOLOv7
-class ELANBlockFPN(nn.Module):
-    def __init__(self, in_dim, out_dim, squeeze_ratio=0.5, branch_width :int=4, branch_depth :int=1, act_type='silu', norm_type='BN', depthwise=False):
-        super(ELANBlockFPN, self).__init__()
-        # Basic parameters
-        inter_dim = int(in_dim * squeeze_ratio)
-        inter_dim2 = int(inter_dim * squeeze_ratio) 
-        # Network structure
-        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
-        self.cv2 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
-        self.cv3 = nn.ModuleList()
-        for idx in range(round(branch_width)):
-            if idx == 0:
-                cvs = [Conv(inter_dim, inter_dim2, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)]
-            else:
-                cvs = [Conv(inter_dim2, inter_dim2, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)]
-            # deeper
-            if round(branch_depth) > 1:
-                for _ in range(1, round(branch_depth)):
-                    cvs.append(Conv(inter_dim2, inter_dim2, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise))
-                self.cv3.append(nn.Sequential(*cvs))
-            else:
-                self.cv3.append(cvs[0])
-
-        self.out = Conv(inter_dim*2+inter_dim2*len(self.cv3), out_dim, k=1, act_type=act_type, norm_type=norm_type)
-
-
-    def forward(self, x):
-        x1 = self.cv1(x)
-        x2 = self.cv2(x)
-        inter_outs = [x1, x2]
-        for m in self.cv3:
-            y1 = inter_outs[-1]
-            y2 = m(y1)
-            inter_outs.append(y2)
-        out = self.out(torch.cat(inter_outs, dim=1))
-
-        return out
-
-## DownSample Block proposed by YOLOv7
-class DownSample(nn.Module):
-    def __init__(self, in_dim, out_dim, act_type='silu', norm_type='BN', depthwise=False):
+        h = self.cv2(self.cv1(x))
+
+        return x + h if self.shortcut else h
+
+class C3kBlock(nn.Module):
+    def __init__(self,
+                 in_dim: int,
+                 out_dim: int,
+                 num_blocks: int = 1,
+                 shortcut: bool = True,
+                 expansion: float = 0.5,
+                 ):
         super().__init__()
-        inter_dim = out_dim // 2
-        self.mp = nn.MaxPool2d((2, 2), 2)
-        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
-        self.cv2 = nn.Sequential(
-            Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type),
-            Conv(inter_dim, inter_dim, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        )
+        inter_dim = int(out_dim * expansion)  # hidden channels
+        self.cv1 = ConvModule(in_dim, inter_dim, kernel_size=1)
+        self.cv2 = ConvModule(in_dim, inter_dim, kernel_size=1)
+        self.cv3 = ConvModule(2 * inter_dim, out_dim, kernel_size=1)  # optional act=FReLU(c2)
+        self.m = nn.Sequential(*[
+            Bottleneck(in_dim      = inter_dim,
+                       out_dim     = inter_dim,
+                       kernel_size = [3, 3],
+                       shortcut    = shortcut,
+                       expansion   = 1.0,
+                       ) for _ in range(num_blocks)])
 
     def forward(self, x):
-        x1 = self.cv1(self.mp(x))
-        x2 = self.cv2(x)
-        out = torch.cat([x1, x2], dim=1)
-
-        return out
-
-
-# ---------------------------- RepConv Modules ----------------------------
-class RepConv(nn.Module):
-    """
-        The code referenced to https://github.com/WongKinYiu/yolov7/models/common.py
-    """
-    # Represented convolution
-    # https://arxiv.org/abs/2101.03697
-
-    def __init__(self, c1, c2, k=3, s=1, p=1, g=1, act_type='silu', deploy=False):
-        super(RepConv, self).__init__()
-        # -------------- Basic parameters --------------
-        self.deploy = deploy
-        self.groups = g
-        self.in_channels = c1
-        self.out_channels = c2
-
-        # -------------- Network parameters --------------
-        if deploy:
-            self.rbr_reparam = nn.Conv2d(c1, c2, k, s, p, groups=g, bias=True)
+        return self.cv3(torch.cat([self.m(self.cv1(x)), self.cv2(x)], dim=1))
 
+class C3k2fBlock(nn.Module):
+    def __init__(self, in_dim, out_dim, num_blocks=1, use_c3k=True, expansion=0.5, shortcut=True):
+        super().__init__()
+        inter_dim = int(out_dim * expansion)  # hidden channels
+        self.cv1 = ConvModule(in_dim, 2 * inter_dim, kernel_size=1)
+        self.cv2 = ConvModule((2 + num_blocks) * inter_dim, out_dim, kernel_size=1)
+
+        if use_c3k:
+            self.m = nn.ModuleList(
+                C3kBlock(inter_dim, inter_dim, 2, shortcut)
+                for _ in range(num_blocks)
+            )
         else:
-            self.rbr_identity = (nn.BatchNorm2d(num_features=c1) if c2 == c1 and s == 1 else None)
-
-            self.rbr_dense = nn.Sequential(
-                nn.Conv2d(c1, c2, k, s, p, groups=g, bias=False),
-                nn.BatchNorm2d(num_features=c2),
+            self.m = nn.ModuleList(
+                Bottleneck(inter_dim, inter_dim, [3, 3], shortcut, expansion=0.5)
+                for _ in range(num_blocks)
             )
 
-            self.rbr_1x1 = nn.Sequential(
-                nn.Conv2d(c1, c2, kernel_size=1, stride=s, bias=False),
-                nn.BatchNorm2d(num_features=c2),
-            )
-        self.act = get_activation(act_type)
+    def _forward_impl(self, x):
+        # Input proj
+        x1, x2 = torch.chunk(self.cv1(x), 2, dim=1)
+        out = list([x1, x2])
 
+        # Bottleneck blocks
+        out.extend(m(out[-1]) for m in self.m)
 
-    def forward(self, inputs):
-        if hasattr(self, "rbr_reparam"):
-            return self.act(self.rbr_reparam(inputs))
+        # Output proj
+        out = self.cv2(torch.cat(out, dim=1))
 
-        if self.rbr_identity is None:
-            id_out = 0
-        else:
-            id_out = self.rbr_identity(inputs)
+        return out
 
-        return self.act(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)
-    
-    def get_equivalent_kernel_bias(self):
-        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
-        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
-        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
-        return (
-            kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid,
-            bias3x3 + bias1x1 + biasid,
-        )
+    def forward(self, x):
+        return self._forward_impl(x)
 
-    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
-        if kernel1x1 is None:
-            return 0
-        else:
-            return nn.functional.pad(kernel1x1, [1, 1, 1, 1])
+# ----------------- Attention modules  -----------------
+class Attention(nn.Module):
+    def __init__(self, dim, num_heads=8, attn_ratio=0.5):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.key_dim = int(self.head_dim * attn_ratio)
+        self.scale = self.key_dim**-0.5
+        
+        nh_kd = self.key_dim * num_heads
+        h = dim + nh_kd * 2
+        self.qkv  = ConvModule(dim, h, kernel_size=1, use_act=False)
+        self.proj = ConvModule(dim, dim, kernel_size=1, use_act=False)
+        self.pe   = ConvModule(dim, dim, kernel_size=3, groups=dim, use_act=False)
 
-    def _fuse_bn_tensor(self, branch):
-        if branch is None:
-            return 0, 0
-        if isinstance(branch, nn.Sequential):
-            kernel = branch[0].weight
-            running_mean = branch[1].running_mean
-            running_var = branch[1].running_var
-            gamma = branch[1].weight
-            beta = branch[1].bias
-            eps = branch[1].eps
-        else:
-            assert isinstance(branch, nn.BatchNorm2d)
-            if not hasattr(self, "id_tensor"):
-                input_dim = self.in_channels // self.groups
-                kernel_value = np.zeros(
-                    (self.in_channels, input_dim, 3, 3), dtype=np.float32
-                )
-                for i in range(self.in_channels):
-                    kernel_value[i, i % input_dim, 1, 1] = 1
-                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
-            kernel = self.id_tensor
-            running_mean = branch.running_mean
-            running_var = branch.running_var
-            gamma = branch.weight
-            beta = branch.bias
-            eps = branch.eps
-        std = (running_var + eps).sqrt()
-        t = (gamma / std).reshape(-1, 1, 1, 1)
-        return kernel * t, beta - running_mean * gamma / std
+    def forward(self, x):
+        bs, c, h, w = x.shape
+        seq_len = h * w
 
-    def repvgg_convert(self):
-        kernel, bias = self.get_equivalent_kernel_bias()
-        return (
-            kernel.detach().cpu().numpy(),
-            bias.detach().cpu().numpy(),
+        qkv = self.qkv(x)
+        q, k, v = qkv.view(bs, self.num_heads, self.key_dim * 2 + self.head_dim, seq_len).split(
+            [self.key_dim, self.key_dim, self.head_dim], dim=2
         )
 
-    def fuse_conv_bn(self, conv, bn):
-
-        std = (bn.running_var + bn.eps).sqrt()
-        bias = bn.bias - bn.running_mean * bn.weight / std
-
-        t = (bn.weight / std).reshape(-1, 1, 1, 1)
-        weights = conv.weight * t
-
-        bn = nn.Identity()
-        conv = nn.Conv2d(in_channels = conv.in_channels,
-                              out_channels = conv.out_channels,
-                              kernel_size = conv.kernel_size,
-                              stride=conv.stride,
-                              padding = conv.padding,
-                              dilation = conv.dilation,
-                              groups = conv.groups,
-                              bias = True,
-                              padding_mode = conv.padding_mode)
-
-        conv.weight = torch.nn.Parameter(weights)
-        conv.bias = torch.nn.Parameter(bias)
-        return conv
-
-    def fuse_repvgg_block(self):    
-        if self.deploy:
-            return
-                
-        self.rbr_dense = self.fuse_conv_bn(self.rbr_dense[0], self.rbr_dense[1])
-        
-        self.rbr_1x1 = self.fuse_conv_bn(self.rbr_1x1[0], self.rbr_1x1[1])
-        rbr_1x1_bias = self.rbr_1x1.bias
-        weight_1x1_expanded = torch.nn.functional.pad(self.rbr_1x1.weight, [1, 1, 1, 1])
-        
-        # Fuse self.rbr_identity
-        if (isinstance(self.rbr_identity, nn.BatchNorm2d) or isinstance(self.rbr_identity, nn.modules.batchnorm.SyncBatchNorm)):
-            identity_conv_1x1 = nn.Conv2d(
-                    in_channels=self.in_channels,
-                    out_channels=self.out_channels,
-                    kernel_size=1,
-                    stride=1,
-                    padding=0,
-                    groups=self.groups, 
-                    bias=False)
-            identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.to(self.rbr_1x1.weight.data.device)
-            identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.squeeze().squeeze()
-
-            identity_conv_1x1.weight.data.fill_(0.0)
-            identity_conv_1x1.weight.data.fill_diagonal_(1.0)
-            identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.unsqueeze(2).unsqueeze(3)
+        attn = (q.transpose(-2, -1) @ k) * self.scale
+        attn = attn.softmax(dim=-1)
+        x = (v @ attn.transpose(-2, -1)).view(bs, c, h, w) + self.pe(v.reshape(bs, c, h, w))
+        x = self.proj(x)
 
-            identity_conv_1x1 = self.fuse_conv_bn(identity_conv_1x1, self.rbr_identity)
-            bias_identity_expanded = identity_conv_1x1.bias
-            weight_identity_expanded = torch.nn.functional.pad(identity_conv_1x1.weight, [1, 1, 1, 1])            
-        else:
-            bias_identity_expanded = torch.nn.Parameter( torch.zeros_like(rbr_1x1_bias) )
-            weight_identity_expanded = torch.nn.Parameter( torch.zeros_like(weight_1x1_expanded) )            
-        
-        self.rbr_dense.weight = torch.nn.Parameter(self.rbr_dense.weight + weight_1x1_expanded + weight_identity_expanded)
-        self.rbr_dense.bias = torch.nn.Parameter(self.rbr_dense.bias + rbr_1x1_bias + bias_identity_expanded)
-                
-        self.rbr_reparam = self.rbr_dense
-        self.deploy = True
+        return x
 
-        if self.rbr_identity is not None:
-            del self.rbr_identity
-            self.rbr_identity = None
-
-        if self.rbr_1x1 is not None:
-            del self.rbr_1x1
-            self.rbr_1x1 = None
+class PSABlock(nn.Module):
+    def __init__(self, in_dim, attn_ratio=0.5, num_heads=4, shortcut=True):
+        super().__init__()
+        self.attn = Attention(in_dim, attn_ratio=attn_ratio, num_heads=num_heads)
+        self.ffn = nn.Sequential(ConvModule(in_dim, in_dim * 2, kernel_size=1),
+                                 ConvModule(in_dim * 2, in_dim, kernel_size=1, use_act=False))
+        self.add = shortcut
 
-        if self.rbr_dense is not None:
-            del self.rbr_dense
-            self.rbr_dense = None
+    def forward(self, x):
+        x = x + self.attn(x) if self.add else self.attn(x)
+        x = x + self.ffn(x)  if self.add else self.ffn(x)
+        return x
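
A quick, self-contained shape check of the `PSABlock` defined above. The import path is an assumption based on this file's location (`yolo/models/yolo11/modules.py`) and may need adjusting to your checkout.

```python
import torch
from yolo.models.yolo11.modules import PSABlock  # assumed import path

# PSABlock preserves the spatial resolution and channel count of its input.
x = torch.randn(1, 128, 20, 20)                  # [B, C, H, W], a P5-like feature map
block = PSABlock(in_dim=128, attn_ratio=0.5, num_heads=128 // 64, shortcut=True)
with torch.no_grad():
    y = block(x)
print(y.shape)                                   # torch.Size([1, 128, 20, 20])
```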

+ 69 - 214
yolo/models/yolo11/yolo11.py

@@ -1,112 +1,73 @@
+# --------------- Torch components ---------------
 import torch
 import torch.nn as nn
 
-from utils.misc import multiclass_nms
+# --------------- Model components ---------------
+from .yolo11_backbone import Yolo11Backbone
+from .yolo11_neck     import SPPF, C2PSA
+from .yolo11_pafpn    import Yolo11PaFPN
+from .yolo11_head     import Yolo11DetHead
+from .yolo11_pred     import Yolo11DetPredLayer
 
-from .yolo11_backbone import build_backbone
-from .yolo11_neck import build_neck
-from .yolo11_pafpn import build_fpn
-from .yolo11_head import build_head
+# --------------- External components ---------------
+from utils.misc import multiclass_nms
 
 
-# YOLOv7
-class YOLOv7(nn.Module):
+# YOLO11
+class Yolo11(nn.Module):
     def __init__(self,
                  cfg,
-                 device,
-                 num_classes=20,
-                 conf_thresh=0.01,
-                 topk=100,
-                 nms_thresh=0.5,
-                 trainable=False,
-                 deploy = False,
-                 no_multi_labels = False,
-                 nms_class_agnostic = False):
-        super(YOLOv7, self).__init__()
-        # ------------------- Basic parameters -------------------
-        self.cfg = cfg                                 # model config
-        self.device = device                           # cuda or cpu
-        self.num_classes = num_classes                 # number of classes
-        self.trainable = trainable                     # training flag
-        self.conf_thresh = conf_thresh                 # score threshold
-        self.nms_thresh = nms_thresh                   # NMS threshold
-        self.topk_candidates = topk                    # topk
-        self.stride = [8, 16, 32]                      # output strides of the network
-        self.num_levels = 3
-        self.deploy = deploy
-        self.no_multi_labels = no_multi_labels
-        self.nms_class_agnostic = nms_class_agnostic
-        # ------------------- Network Structure -------------------
-        ## Backbone
-        self.backbone, feats_dim = build_backbone(cfg, trainable&cfg['pretrained'])
-
-        ## Neck: SPP module
-        self.neck = build_neck(cfg, in_dim=feats_dim[-1], out_dim=feats_dim[-1]//2)
-        feats_dim[-1] = self.neck.out_dim
-
-        ## Neck: feature pyramid
-        self.fpn = build_fpn(cfg=cfg, in_dims=feats_dim, out_dim=round(256*cfg['channel_width']))
-        self.head_dim = self.fpn.out_dim
-
-        ## Detection head
-        self.non_shared_heads = nn.ModuleList(
-            [build_head(cfg, head_dim, head_dim, num_classes) 
-            for head_dim in self.head_dim
-            ])
+                 is_val = False,
+                 ) -> None:
+        super(Yolo11, self).__init__()
+        # ---------------------- Basic setting ----------------------
+        self.cfg = cfg
+        self.num_classes = cfg.num_classes
+        ## Post-process parameters
+        self.topk_candidates  = cfg.val_topk        if is_val else cfg.test_topk
+        self.conf_thresh      = cfg.val_conf_thresh if is_val else cfg.test_conf_thresh
+        self.nms_thresh       = cfg.val_nms_thresh  if is_val else cfg.test_nms_thresh
+        self.no_multi_labels  = False if is_val else True
+        
+        # ---------------------- Network Parameters ----------------------
+        ## Backbone
+        self.backbone = Yolo11Backbone(cfg)
+        self.pyramid_feat_dims = self.backbone.feat_dims[-3:]
+
+        ## Neck
+        self.neck_spp  = SPPF(self.pyramid_feat_dims[-1], self.pyramid_feat_dims[-1])
+        self.neck_attn = C2PSA(self.pyramid_feat_dims[-1], self.pyramid_feat_dims[-1], num_blocks=round(2 * cfg.depth), expansion=0.5)
+        
+        ## Neck: PaFPN
+        self.fpn = Yolo11PaFPN(cfg, self.backbone.feat_dims)
 
-        ## Prediction layers
-        self.obj_preds = nn.ModuleList(
-                            [nn.Conv2d(head.reg_out_dim, 1, kernel_size=1) 
-                                for head in self.non_shared_heads
-                              ]) 
-        self.cls_preds = nn.ModuleList(
-                            [nn.Conv2d(head.cls_out_dim, self.num_classes, kernel_size=1) 
-                                for head in self.non_shared_heads
-                              ]) 
-        self.reg_preds = nn.ModuleList(
-                            [nn.Conv2d(head.reg_out_dim, 4, kernel_size=1) 
-                                for head in self.non_shared_heads
-                              ])                 
+        ## Head
+        self.head = Yolo11DetHead(cfg, self.fpn.out_dims)
 
+        ## Pred
+        self.pred = Yolo11DetPredLayer(cfg, self.head.cls_head_dim, self.head.reg_head_dim)
 
-    # ---------------------- Basic Functions ----------------------
-    ## generate anchor points
-    def generate_anchors(self, level, fmp_size):
-        """
-            fmp_size: (List) [H, W]
-        """
-        # generate grid cells
-        fmp_h, fmp_w = fmp_size
-        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
-        # [H, W, 2] -> [HW, 2]
-        anchor_xy = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
-        anchor_xy += 0.5  # add center offset
-        anchor_xy *= self.stride[level]
-        anchors = anchor_xy.to(self.device)
-
-        return anchors
-        
-    ## post-process
-    def post_process(self, obj_preds, cls_preds, box_preds):
+    def post_process(self, cls_preds, box_preds):
         """
+        We process predictions at each scale hierarchically
         Input:
-            cls_preds: List[np.array] -> [[M, C], ...]
-            box_preds: List[np.array] -> [[M, 4], ...]
-            obj_preds: List[np.array] -> [[M, 1], ...] or None
+            cls_preds: List[torch.Tensor] -> [[B, M, C], ...], B=1
+            box_preds: List[torch.Tensor] -> [[B, M, 4], ...], B=1
         Output:
             bboxes: np.array -> [N, 4]
             scores: np.array -> [N,]
             labels: np.array -> [N,]
         """
-        assert len(cls_preds) == self.num_levels
         all_scores = []
         all_labels = []
         all_bboxes = []
         
-        for obj_pred_i, cls_pred_i, box_pred_i in zip(obj_preds, cls_preds, box_preds):
+        for cls_pred_i, box_pred_i in zip(cls_preds, box_preds):
+            cls_pred_i = cls_pred_i[0]
+            box_pred_i = box_pred_i[0]
             if self.no_multi_labels:
                 # [M,]
-                scores, labels = torch.max(torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid()), dim=1)
+                scores, labels = torch.max(cls_pred_i.sigmoid(), dim=1)
 
                 # Keep top k top scoring indices only.
                 num_topk = min(self.topk_candidates, box_pred_i.size(0))
@@ -123,10 +84,9 @@ class YOLOv7(nn.Module):
 
                 labels = labels[topk_idxs]
                 bboxes = box_pred_i[topk_idxs]
-
             else:
                 # [M, C] -> [MC,]
-                scores_i = (torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid())).flatten()
+                scores_i = cls_pred_i.sigmoid().flatten()
 
                 # Keep top k top scoring indices only.
                 num_topk = min(self.topk_candidates, box_pred_i.size(0))
@@ -150,9 +110,9 @@ class YOLOv7(nn.Module):
             all_labels.append(labels)
             all_bboxes.append(bboxes)
 
-        scores = torch.cat(all_scores)
-        labels = torch.cat(all_labels)
-        bboxes = torch.cat(all_bboxes)
+        scores = torch.cat(all_scores, dim=0)
+        labels = torch.cat(all_labels, dim=0)
+        bboxes = torch.cat(all_bboxes, dim=0)
 
         # to cpu & numpy
         scores = scores.cpu().numpy()
@@ -161,142 +121,37 @@ class YOLOv7(nn.Module):
 
         # nms
         scores, labels, bboxes = multiclass_nms(
-            scores, labels, bboxes, self.nms_thresh, self.num_classes, self.nms_class_agnostic)
-
+            scores, labels, bboxes, self.nms_thresh, self.num_classes)
+        
         return bboxes, scores, labels
     
-
-    # ---------------------- Main Process for Inference ----------------------
-    @torch.no_grad()
-    def inference_single_image(self, x):
-        # Backbone
+    def forward(self, x):
+        # ---------------- Backbone ----------------
         pyramid_feats = self.backbone(x)
+        # ---------------- Neck: SPP ----------------
+        pyramid_feats[-1] = self.neck_spp(pyramid_feats[-1])
+        pyramid_feats[-1] = self.neck_attn(pyramid_feats[-1])
 
-        # Neck
-        pyramid_feats[-1] = self.neck(pyramid_feats[-1])
-
-        # Feature pyramid
+        # ---------------- Neck: PaFPN ----------------
         pyramid_feats = self.fpn(pyramid_feats)
 
-        # Detection head
-        all_obj_preds = []
-        all_cls_preds = []
-        all_box_preds = []
-        all_anchors = []
-        for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
-            cls_feat, reg_feat = head(feat)
+        # ---------------- Heads ----------------
+        cls_feats, reg_feats = self.head(pyramid_feats)
 
-            # [1, C, H, W]
-            obj_pred = self.obj_preds[level](reg_feat)
-            cls_pred = self.cls_preds[level](cls_feat)
-            reg_pred = self.reg_preds[level](reg_feat)
+        # ---------------- Preds ----------------
+        outputs = self.pred(cls_feats, reg_feats)
+        outputs['image_size'] = [x.shape[2], x.shape[3]]
 
-            # anchors: [M, 2]
-            fmp_size = cls_pred.shape[-2:]
-            anchors = self.generate_anchors(level, fmp_size)
+        if not self.training:
+            all_cls_preds = outputs['pred_cls']
+            all_box_preds = outputs['pred_box']
 
-            # [1, C, H, W] -> [H, W, C] -> [M, C]
-            obj_pred = obj_pred[0].permute(1, 2, 0).contiguous().view(-1, 1)
-            cls_pred = cls_pred[0].permute(1, 2, 0).contiguous().view(-1, self.num_classes)
-            reg_pred = reg_pred[0].permute(1, 2, 0).contiguous().view(-1, 4)
-
-            # decode bbox
-            ctr_pred = reg_pred[..., :2] * self.stride[level] + anchors[..., :2]
-            wh_pred = torch.exp(reg_pred[..., 2:]) * self.stride[level]
-            pred_x1y1 = ctr_pred - wh_pred * 0.5
-            pred_x2y2 = ctr_pred + wh_pred * 0.5
-            box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
-
-            all_obj_preds.append(obj_pred)
-            all_cls_preds.append(cls_pred)
-            all_box_preds.append(box_pred)
-            all_anchors.append(anchors)
-
-        if self.deploy:
-            obj_preds = torch.cat(all_obj_preds, dim=0)
-            cls_preds = torch.cat(all_cls_preds, dim=0)
-            box_preds = torch.cat(all_box_preds, dim=0)
-            scores = torch.sqrt(obj_preds.sigmoid() * cls_preds.sigmoid())
-            bboxes = box_preds
-            # [n_anchors_all, 4 + C]
-            outputs = torch.cat([bboxes, scores], dim=-1)
-
-        else:
             # post process
-            bboxes, scores, labels = self.post_process(
-                all_obj_preds, all_cls_preds, all_box_preds)
+            bboxes, scores, labels = self.post_process(all_cls_preds, all_box_preds)
             outputs = {
                 "scores": scores,
                 "labels": labels,
                 "bboxes": bboxes
             }
-
-        return outputs
-
-    # ---------------------- Main Process for Training ----------------------
-    def forward(self, x):
-        if not self.trainable:
-            return self.inference_single_image(x)
-        else:
-            # Backbone
-            pyramid_feats = self.backbone(x)
-
-            # Neck
-            pyramid_feats[-1] = self.neck(pyramid_feats[-1])
-
-            # Feature pyramid
-            pyramid_feats = self.fpn(pyramid_feats)
-
-            # Detection head
-            all_anchors = []
-            all_strides = []
-            all_obj_preds = []
-            all_cls_preds = []
-            all_box_preds = []
-            all_reg_preds = []
-            for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
-                cls_feat, reg_feat = head(feat)
-
-                # [B, C, H, W]
-                obj_pred = self.obj_preds[level](reg_feat)
-                cls_pred = self.cls_preds[level](cls_feat)
-                reg_pred = self.reg_preds[level](reg_feat)
-
-                B, _, H, W = cls_pred.size()
-                fmp_size = [H, W]
-                # generate anchor boxes: [M, 4]
-                anchors = self.generate_anchors(level, fmp_size)
-                
-                # stride tensor: [M, 1]
-                stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride[level]
-
-                # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
-                obj_pred = obj_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 1)
-                cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
-                reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4)
-
-                # decode bbox
-                ctr_pred = reg_pred[..., :2] * self.stride[level] + anchors[..., :2]
-                wh_pred = torch.exp(reg_pred[..., 2:]) * self.stride[level]
-                pred_x1y1 = ctr_pred - wh_pred * 0.5
-                pred_x2y2 = ctr_pred + wh_pred * 0.5
-                box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
-
-                all_obj_preds.append(obj_pred)
-                all_cls_preds.append(cls_pred)
-                all_box_preds.append(box_pred)
-                all_reg_preds.append(reg_pred)
-                all_anchors.append(anchors)
-                all_strides.append(stride_tensor)
-            
-            # output dict
-            outputs = {"pred_obj": all_obj_preds,        # List(Tensor) [B, M, 1]
-                       "pred_cls": all_cls_preds,        # List(Tensor) [B, M, C]
-                       "pred_box": all_box_preds,        # List(Tensor) [B, M, 4]
-                       "pred_reg": all_reg_preds,        # List(Tensor) [B, M, 4]
-                       "anchors": all_anchors,           # List(Tensor) [M, 2]
-                       "strides": self.stride,           # List(Int) [8, 16, 32]
-                       "stride_tensors": all_strides     # List(Tensor) [M, 1]
-                       }
-
-            return outputs 
+        
+        return outputs 
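
The index-recovery lines of the multi-label branch fall between the hunks shown above. A minimal sketch of the standard flattened top-k pattern they follow (an illustration under that assumption, not the repository's exact code):

```python
import torch

# Every (anchor, class) pair competes for the top-k slots; integer division and
# modulo then recover which anchor and which class each kept slot refers to.
num_classes, topk = 20, 300
cls_pred_i = torch.randn(8400, num_classes)      # [M, C] logits of one level
box_pred_i = torch.rand(8400, 4)                 # [M, 4] decoded boxes

scores_i = cls_pred_i.sigmoid().flatten()        # [M * C]
num_topk = min(topk, box_pred_i.size(0))
predicted_prob, topk_idxs = scores_i.sort(descending=True)
scores = predicted_prob[:num_topk]
topk_idxs = topk_idxs[:num_topk]

anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode='floor')
labels = topk_idxs % num_classes
bboxes = box_pred_i[anchor_idxs]
print(scores.shape, labels.shape, bboxes.shape)  # [300], [300], [300, 4]
```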

+ 66 - 187
yolo/models/yolo11/yolo11_backbone.py

@@ -2,153 +2,74 @@ import torch
 import torch.nn as nn
 
 try:
-    from .modules import Conv, ELANBlock, DownSample
+    from .modules import ConvModule, C3k2fBlock
 except:
-    from yolo.models.yolov10.modules import Conv, ELANBlock, DownSample
-    
+    from modules import ConvModule, C3k2fBlock
 
-model_urls = {
-    "elannet_tiny": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/yolov7_elannet_tiny.pth",
-    "elannet_large": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/yolov7_elannet_large.pth",
-    "elannet_huge": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/yolov7_elannet_huge.pth",
-}
 
-
-# --------------------- ELANNet -----------------------
-## ELANNet-Tiny
-class ELANNet_Tiny(nn.Module):
-    """
-    ELAN-Net of YOLOv7-Tiny.
-    """
-    def __init__(self, act_type='silu', norm_type='BN', depthwise=False):
-        super(ELANNet_Tiny, self).__init__()
-        # -------------- Basic parameters --------------
-        self.feat_dims = [32, 64, 128, 256, 512]
-        self.squeeze_ratios = [0.5, 0.5, 0.5, 0.5]   # Stage-1 -> Stage-4
-        self.branch_depths = [1, 1, 1, 1]            # Stage-1 -> Stage-4
+# ---------------------------- YOLO11 Backbone ----------------------------
+class Yolo11Backbone(nn.Module):
+    def __init__(self, cfg):
+        super(Yolo11Backbone, self).__init__()
+        # ------------------ Basic setting ------------------
+        self.model_scale = cfg.model_scale
+        self.feat_dims = [int(512 * cfg.width), int(512 * cfg.width), int(512 * cfg.width * cfg.ratio)]
         
-        # -------------- Network parameters --------------
-        ## P1/2
-        self.layer_1 = Conv(3, self.feat_dims[0], k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        ## P2/4: Stage-1
-        self.layer_2 = nn.Sequential(   
-            Conv(self.feat_dims[0], self.feat_dims[1], k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),             
-            ELANBlock(self.feat_dims[1], self.feat_dims[1], self.squeeze_ratios[0], self.branch_depths[0], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        )
-        ## P3/8: Stage-2
-        self.layer_3 = nn.Sequential(
-            nn.MaxPool2d((2, 2), 2),             
-            ELANBlock(self.feat_dims[1], self.feat_dims[2], self.squeeze_ratios[1], self.branch_depths[1], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        )
-        ## P4/16: Stage-3
-        self.layer_4 = nn.Sequential(
-            nn.MaxPool2d((2, 2), 2),             
-            ELANBlock(self.feat_dims[2], self.feat_dims[3], self.squeeze_ratios[2], self.branch_depths[2], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        )
-        ## P5/32: Stage-4
-        self.layer_5 = nn.Sequential(
-            nn.MaxPool2d((2, 2), 2),             
-            ELANBlock(self.feat_dims[3], self.feat_dims[4], self.squeeze_ratios[3], self.branch_depths[3], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        )
-
-
-    def forward(self, x):
-        c1 = self.layer_1(x)
-        c2 = self.layer_2(c1)
-        c3 = self.layer_3(c2)
-        c4 = self.layer_4(c3)
-        c5 = self.layer_5(c4)
-
-        outputs = [c3, c4, c5]
-
-        return outputs
-
-## ELANNet-Large
-class ELANNet_Lagre(nn.Module):
-    def __init__(self, act_type='silu', norm_type='BN', depthwise=False):
-        super(ELANNet_Lagre, self).__init__()
-        # -------------------- Basic parameters --------------------
-        self.feat_dims = [32, 64, 128, 256, 512, 1024, 1024]
-        self.squeeze_ratios = [0.5, 0.5, 0.5, 0.25]  # Stage-1 -> Stage-4
-        self.branch_depths = [2, 2, 2, 2]            # Stage-1 -> Stage-4
-
-        # -------------------- Network parameters --------------------
-        ## P1/2
-        self.layer_1 = nn.Sequential(
-            Conv(3, self.feat_dims[0], k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise),      
-            Conv(self.feat_dims[0], self.feat_dims[1], k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            Conv(self.feat_dims[1], self.feat_dims[1], k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        )
-        ## P2/4: Stage-1
-        self.layer_2 = nn.Sequential(   
-            Conv(self.feat_dims[1], self.feat_dims[2], k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),             
-            ELANBlock(self.feat_dims[2], self.feat_dims[3], self.squeeze_ratios[0], self.branch_depths[0], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        )
-        ## P3/8: Stage-2
-        self.layer_3 = nn.Sequential(
-            DownSample(self.feat_dims[3], self.feat_dims[3], act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            ELANBlock(self.feat_dims[3], self.feat_dims[4], self.squeeze_ratios[1], self.branch_depths[1], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        )
-        ## P4/16: Stage-3
-        self.layer_4 = nn.Sequential(
-            DownSample(self.feat_dims[4], self.feat_dims[4], act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            ELANBlock(self.feat_dims[4], self.feat_dims[5], self.squeeze_ratios[2], self.branch_depths[2], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        )
-        ## P5/32: Stage-4
-        self.layer_5 = nn.Sequential(
-            DownSample(self.feat_dims[5], self.feat_dims[5], act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            ELANBlock(self.feat_dims[5], self.feat_dims[6], self.squeeze_ratios[3], self.branch_depths[3], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        )
-
-
-    def forward(self, x):
-        c1 = self.layer_1(x)
-        c2 = self.layer_2(c1)
-        c3 = self.layer_3(c2)
-        c4 = self.layer_4(c3)
-        c5 = self.layer_5(c4)
-
-        outputs = [c3, c4, c5]
-
-        return outputs
-
-## ELANNet-Huge
-class ELANNet_Huge(nn.Module):
-    def __init__(self, act_type='silu', norm_type='BN', depthwise=False):
-        super(ELANNet_Huge, self).__init__()
-        # -------------------- Basic parameters --------------------
-        self.feat_dims = [40, 80, 160, 320, 640, 1280, 1280]
-        self.squeeze_ratios = [0.5, 0.5, 0.5, 0.25]  # Stage-1 -> Stage-4
-        self.branch_depths = [3, 3, 3, 3]            # Stage-1 -> Stage-4
-
-        # -------------------- Network parameters --------------------
+        # ------------------ Network setting ------------------
         ## P1/2
-        self.layer_1 = nn.Sequential(
-            Conv(3, self.feat_dims[0], k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise),      
-            Conv(self.feat_dims[0], self.feat_dims[1], k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            Conv(self.feat_dims[1], self.feat_dims[1], k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.layer_1 = ConvModule(3, int(64 * cfg.width), kernel_size=3, stride=2)
+        # P2/4
+        self.layer_2 = nn.Sequential(
+            ConvModule(int(64 * cfg.width), int(128 * cfg.width), kernel_size=3, stride=2),
+            C3k2fBlock(in_dim     = int(128 * cfg.width),
+                      out_dim    = int(256 * cfg.width),
+                      num_blocks = round(2*cfg.depth),
+                      shortcut   = True,
+                      expansion  = 0.25,
+                      use_c3k    = False if self.model_scale in "ns" else True,
+                      )
         )
-        ## P2/4: Stage-1
-        self.layer_2 = nn.Sequential(   
-            Conv(self.feat_dims[1], self.feat_dims[2], k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),             
-            ELANBlock(self.feat_dims[2], self.feat_dims[3], self.squeeze_ratios[0], self.branch_depths[0], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        )
-        ## P3/8: Stage-2
+        # P3/8
         self.layer_3 = nn.Sequential(
-            DownSample(self.feat_dims[3], self.feat_dims[3], act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            ELANBlock(self.feat_dims[3], self.feat_dims[4], self.squeeze_ratios[1], self.branch_depths[1], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+            ConvModule(int(256 * cfg.width), int(256 * cfg.width), kernel_size=3, stride=2),
+            C3k2fBlock(in_dim     = int(256 * cfg.width),
+                      out_dim    = int(512 * cfg.width),
+                      num_blocks = round(2*cfg.depth),
+                      shortcut   = True,
+                      expansion  = 0.25,
+                      use_c3k    = False if self.model_scale in "ns" else True,
+                      )
         )
-        ## P4/16: Stage-3
+        # P4/16
         self.layer_4 = nn.Sequential(
-            DownSample(self.feat_dims[4], self.feat_dims[4], act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            ELANBlock(self.feat_dims[4], self.feat_dims[5], self.squeeze_ratios[2], self.branch_depths[2], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+            ConvModule(int(512 * cfg.width), int(512 * cfg.width), kernel_size=3, stride=2),
+            C3k2fBlock(in_dim     = int(512 * cfg.width),
+                      out_dim    = int(512 * cfg.width),
+                      num_blocks = round(2*cfg.depth),
+                      shortcut   = True,
+                      expansion  = 0.5,
+                      use_c3k    = True,
+                      )
         )
-        ## P5/32: Stage-4
+        # P5/32
         self.layer_5 = nn.Sequential(
-            DownSample(self.feat_dims[5], self.feat_dims[5], act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            ELANBlock(self.feat_dims[5], self.feat_dims[6], self.squeeze_ratios[3], self.branch_depths[3], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+            ConvModule(int(512 * cfg.width), int(512 * cfg.width * cfg.ratio), kernel_size=3, stride=2),
+            C3k2fBlock(in_dim     = int(512 * cfg.width * cfg.ratio),
+                      out_dim    = int(512 * cfg.width * cfg.ratio),
+                      num_blocks = round(2*cfg.depth),
+                      shortcut   = True,
+                      expansion  = 0.5,
+                      use_c3k    = True,
+                      )
         )
 
+        # Initialize all layers
+        self.init_weights()
+        
+    def init_weights(self):
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                m.reset_parameters()
 
     def forward(self, x):
         c1 = self.layer_1(x)
@@ -156,65 +77,23 @@ class ELANNet_Huge(nn.Module):
         c3 = self.layer_3(c2)
         c4 = self.layer_4(c3)
         c5 = self.layer_5(c4)
-
         outputs = [c3, c4, c5]
 
         return outputs
 
 
-# --------------------- Functions -----------------------
-## build backbone
-def build_backbone(cfg, pretrained=False): 
-    # build backbone
-    if cfg['backbone'] == 'elannet_huge':
-        backbone = ELANNet_Huge(cfg['bk_act'], cfg['bk_norm'], cfg['bk_dpw'])
-    elif cfg['backbone'] == 'elannet_large':
-        backbone = ELANNet_Lagre(cfg['bk_act'], cfg['bk_norm'], cfg['bk_dpw'])
-    elif cfg['backbone'] == 'elannet_tiny':
-        backbone = ELANNet_Tiny(cfg['bk_act'], cfg['bk_norm'], cfg['bk_dpw'])
-    # pyramid feat dims
-    feat_dims = backbone.feat_dims[-3:]
-
-    # load imagenet pretrained weight
-    if pretrained:
-        url = model_urls[cfg['backbone']]
-        if url is not None:
-            print('Loading pretrained weight for {}.'.format(cfg['backbone'].upper()))
-            checkpoint = torch.hub.load_state_dict_from_url(
-                url=url, map_location="cpu", check_hash=True)
-            # checkpoint state dict
-            checkpoint_state_dict = checkpoint.pop("model")
-            # model state dict
-            model_state_dict = backbone.state_dict()
-            # check
-            for k in list(checkpoint_state_dict.keys()):
-                if k in model_state_dict:
-                    shape_model = tuple(model_state_dict[k].shape)
-                    shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
-                    if shape_model != shape_checkpoint:
-                        checkpoint_state_dict.pop(k)
-                else:
-                    checkpoint_state_dict.pop(k)
-                    print('Unused key: ', k)
-
-            backbone.load_state_dict(checkpoint_state_dict)
-        else:
-            print('No backbone pretrained: ELANNet')        
-
-    return backbone, feat_dims
-
-
 if __name__ == '__main__':
     import time
     from thop import profile
-    cfg = {
-        'pretrained': False,
-        'backbone': 'elannet_tiny',
-        'bk_act': 'silu',
-        'bk_norm': 'BN',
-        'bk_dpw': False,
-    }
-    model, feats = build_backbone(cfg)
+    class BaseConfig(object):
+        def __init__(self) -> None:
+            self.width = 0.25
+            self.depth = 0.34
+            self.ratio = 2.0
+            self.model_scale = "n"
+            
+    cfg = BaseConfig()
+    model = Yolo11Backbone(cfg)
     x = torch.randn(1, 3, 640, 640)
     t0 = time.time()
     outputs = model(x)
@@ -223,8 +102,8 @@ if __name__ == '__main__':
     for out in outputs:
         print(out.shape)
 
+    x = torch.randn(1, 3, 640, 640)
     print('==============================')
     flops, params = profile(model, inputs=(x, ), verbose=False)
-    print('==============================')
     print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))
+    print('Params : {:.2f} M'.format(params / 1e6))
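
As a sanity check on the scaling rules above, the per-stage channel widths for the `"n"` configuration used in the `__main__` test work out as follows (a worked example only):

```python
# cfg.width = 0.25, cfg.ratio = 2.0 ("n" scale) -> Yolo11Backbone.feat_dims
width, ratio = 0.25, 2.0
p3 = int(512 * width)            # 128 channels at stride 8
p4 = int(512 * width)            # 128 channels at stride 16
p5 = int(512 * width * ratio)    # 256 channels at stride 32
print([p3, p4, p5])              # [128, 128, 256]
```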

+ 121 - 46
yolo/models/yolo11/yolo11_head.py

@@ -1,61 +1,64 @@
 import torch
 import torch.nn as nn
+from typing import List
 
-from .modules import Conv
+try:
+    from .modules import ConvModule
+except ImportError:
+    from modules import ConvModule
 
 
-class DecoupledHead(nn.Module):
-    def __init__(self, cfg, in_dim, out_dim, num_classes=80):
+# -------------------- Detection Head --------------------
+## Single-level Detection Head
+class DetHead(nn.Module):
+    def __init__(self,
+                 in_dim       :int  = 256,
+                 cls_head_dim :int  = 256,
+                 reg_head_dim :int  = 256,
+                 num_cls_head :int  = 2,
+                 num_reg_head :int  = 2,
+                 ):
         super().__init__()
-        print('==============================')
-        print('Head: Decoupled Head')
+        # --------- Basic Parameters ----------
         self.in_dim = in_dim
-        self.num_cls_head=cfg['num_cls_head']
-        self.num_reg_head=cfg['num_reg_head']
-        self.act_type=cfg['head_act']
-        self.norm_type=cfg['head_norm']
-
-        # cls head
+        self.num_cls_head = num_cls_head
+        self.num_reg_head = num_reg_head
+        
+        # --------- Network Parameters ----------
+        ## classification head
         cls_feats = []
-        self.cls_out_dim = max(out_dim, num_classes)
-        for i in range(cfg['num_cls_head']):
+        self.cls_head_dim = cls_head_dim
+        for i in range(num_cls_head):
             if i == 0:
-                cls_feats.append(
-                    Conv(in_dim, self.cls_out_dim, k=3, p=1, s=1, 
-                        act_type=self.act_type,
-                        norm_type=self.norm_type,
-                        depthwise=cfg['head_depthwise'])
-                        )
+                cls_feats.append(nn.Sequential(
+                    ConvModule(in_dim, in_dim, kernel_size=3, stride=1, groups=in_dim),
+                    ConvModule(in_dim, self.cls_head_dim, kernel_size=1),
+                ))
             else:
-                cls_feats.append(
-                    Conv(self.cls_out_dim, self.cls_out_dim, k=3, p=1, s=1, 
-                        act_type=self.act_type,
-                        norm_type=self.norm_type,
-                        depthwise=cfg['head_depthwise'])
-                        )
-                
-        # reg head
+                cls_feats.append(nn.Sequential(
+                    ConvModule(self.cls_head_dim, self.cls_head_dim, kernel_size=3, stride=1, groups=self.cls_head_dim),
+                    ConvModule(self.cls_head_dim, self.cls_head_dim, kernel_size=1),
+                ))
+        
+        ## bbox regression head
         reg_feats = []
-        self.reg_out_dim = max(out_dim, 64)
-        for i in range(cfg['num_reg_head']):
+        self.reg_head_dim = reg_head_dim
+        for i in range(num_reg_head):
             if i == 0:
-                reg_feats.append(
-                    Conv(in_dim, self.reg_out_dim, k=3, p=1, s=1, 
-                        act_type=self.act_type,
-                        norm_type=self.norm_type,
-                        depthwise=cfg['head_depthwise'])
-                        )
+                reg_feats.append(ConvModule(in_dim, self.reg_head_dim, kernel_size=3, stride=1))
             else:
-                reg_feats.append(
-                    Conv(self.reg_out_dim, self.reg_out_dim, k=3, p=1, s=1, 
-                        act_type=self.act_type,
-                        norm_type=self.norm_type,
-                        depthwise=cfg['head_depthwise'])
-                        )
-
+                reg_feats.append(ConvModule(self.reg_head_dim, self.reg_head_dim, kernel_size=3, stride=1))
+        
         self.cls_feats = nn.Sequential(*cls_feats)
         self.reg_feats = nn.Sequential(*reg_feats)
 
+        self.init_weights()
+        
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                m.reset_parameters()
 
     def forward(self, x):
         """
@@ -66,9 +69,81 @@ class DecoupledHead(nn.Module):
 
         return cls_feats, reg_feats
     
+## Multi-level Detection Head
+class Yolo11DetHead(nn.Module):
+    def __init__(self, cfg, in_dims: List = [256, 512, 1024]):
+        super().__init__()
+        self.num_levels = len(cfg.out_stride)
+        ## ----------- Network Parameters -----------
+        self.multi_level_heads = nn.ModuleList(
+            [DetHead(in_dim       = in_dims[level],
+                     cls_head_dim = max(in_dims[0], min(cfg.num_classes, 128)),
+                     reg_head_dim = max(in_dims[0]//4, 16, 4*cfg.reg_max),
+                     num_cls_head = cfg.num_cls_head,
+                     num_reg_head = cfg.num_reg_head,
+                     ) for level in range(self.num_levels)])
+        # --------- Basic Parameters ----------
+        self.in_dims = in_dims
+        self.cls_head_dim = self.multi_level_heads[0].cls_head_dim
+        self.reg_head_dim = self.multi_level_heads[0].reg_head_dim
+
+    def forward(self, feats):
+        """
+            feats: List[(Tensor)] [[B, C, H, W], ...]
+        """
+        cls_feats = []
+        reg_feats = []
+        for feat, head in zip(feats, self.multi_level_heads):
+            # ---------------- Pred ----------------
+            cls_feat, reg_feat = head(feat)
+
+            cls_feats.append(cls_feat)
+            reg_feats.append(reg_feat)
+
+        return cls_feats, reg_feats
+
+
+if __name__=='__main__':
+    import time
+    from thop import profile
+    
+    # YOLO11-Base config
+    class Yolo11BaseConfig(object):
+        def __init__(self) -> None:
+            # ---------------- Model config ----------------
+            self.width    = 0.50
+            self.depth    = 0.34
+            self.ratio    = 2.0
+            self.reg_max  = 16
+            self.out_stride = [8, 16, 32]
+            self.max_stride = 32
+            self.num_levels = 3
+            ## Head
+            self.num_cls_head = 2
+            self.num_reg_head = 2
+
+    cfg = Yolo11BaseConfig()
+    cfg.num_classes = 20
+
+    # Build a head
+    fpn_dims = [128, 256, 512]
+    pyramid_feats = [torch.randn(1, fpn_dims[0], 80, 80),
+                     torch.randn(1, fpn_dims[1], 40, 40),
+                     torch.randn(1, fpn_dims[2], 20, 20)]
+    head = Yolo11DetHead(cfg, fpn_dims)
+
 
-# build detection head
-def build_head(cfg, in_dim, out_dim, num_classes=80):
-    head = DecoupledHead(cfg, in_dim, out_dim, num_classes) 
+    # Inference
+    t0 = time.time()
+    cls_feats, reg_feats = head(pyramid_feats)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    print("====== Yolo11 Head output ======")
+    for level, (cls_f, reg_f) in enumerate(zip(cls_feats, reg_feats)):
+        print("- Level-{} : ".format(level), cls_f.shape, reg_f.shape)
 
-    return head
+    flops, params = profile(head, inputs=(pyramid_feats, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))
+    
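
For reference, the head widths that `Yolo11DetHead` picks for the `__main__` configuration above follow directly from its `max`/`min` expressions (a worked example only):

```python
# fpn_dims = [128, 256, 512], num_classes = 20, reg_max = 16 (values from the test above)
in_dims, num_classes, reg_max = [128, 256, 512], 20, 16
cls_head_dim = max(in_dims[0], min(num_classes, 128))   # max(128, 20)    -> 128
reg_head_dim = max(in_dims[0] // 4, 16, 4 * reg_max)    # max(32, 16, 64) -> 64
print(cls_head_dim, reg_head_dim)                       # 128 64
```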

+ 28 - 81
yolo/models/yolo11/yolo11_neck.py

@@ -1,20 +1,22 @@
 import torch
 import torch.nn as nn
-from .modules import Conv
+
+try:
+    from .modules import ConvModule, PSABlock
+except ImportError:
+    from modules import ConvModule, PSABlock
 
 
-# Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
 class SPPF(nn.Module):
-    """
-        This code referenced to https://github.com/ultralytics/yolov5
-    """
-    def __init__(self, in_dim, out_dim, expand_ratio=0.5, pooling_size=5, act_type='lrelu', norm_type='BN'):
+    def __init__(self, in_dim, out_dim, spp_pooling_size: int = 5, neck_expand_ratio:float = 0.5):
         super().__init__()
-        inter_dim = int(in_dim * expand_ratio)
+        ## ----------- Basic Parameters -----------
+        inter_dim = round(in_dim * neck_expand_ratio)
         self.out_dim = out_dim
-        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
-        self.cv2 = Conv(inter_dim * 4, out_dim, k=1, act_type=act_type, norm_type=norm_type)
-        self.m = nn.MaxPool2d(kernel_size=pooling_size, stride=1, padding=pooling_size // 2)
+        ## ----------- Network Parameters -----------
+        self.cv1 = ConvModule(in_dim, inter_dim, kernel_size=1, stride=1)
+        self.cv2 = ConvModule(inter_dim * 4, out_dim, kernel_size=1, stride=1)
+        self.m = nn.MaxPool2d(kernel_size=spp_pooling_size, stride=1, padding=spp_pooling_size // 2)
 
     def forward(self, x):
         x = self.cv1(x)
@@ -22,77 +24,22 @@ class SPPF(nn.Module):
         y2 = self.m(y1)
 
         return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
+    
+class C2PSA(nn.Module):
+    def __init__(self, in_dim, out_dim, num_blocks=1, expansion=0.5):
+        super().__init__()
+        assert in_dim == out_dim
+        inter_dim = int(in_dim * expansion)
+        self.cv1 = ConvModule(in_dim, 2 * inter_dim, kernel_size=1)
+        self.cv2 = ConvModule(2 * inter_dim, in_dim, kernel_size=1)
+        self.m = nn.Sequential(*[
+            PSABlock(in_dim     = inter_dim,
+                     attn_ratio = 0.5,
+                     num_heads  = inter_dim // 64
+                     ) for _ in range(num_blocks)])
 
-
-# SPPF block with CSP module
-class SPPFBlockCSP(nn.Module):
-    """
-        CSP Spatial Pyramid Pooling Block
-    """
-    def __init__(self,
-                 in_dim,
-                 out_dim,
-                 expand_ratio=0.5,
-                 pooling_size=5,
-                 act_type='lrelu',
-                 norm_type='BN',
-                 depthwise=False
-                 ):
-        super(SPPFBlockCSP, self).__init__()
-        inter_dim = int(in_dim * expand_ratio)
-        self.out_dim = out_dim
-        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
-        self.cv2 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
-        self.m = nn.Sequential(
-            Conv(inter_dim, inter_dim, k=3, p=1, 
-                 act_type=act_type, norm_type=norm_type, 
-                 depthwise=depthwise),
-            SPPF(inter_dim, 
-                 inter_dim, 
-                 expand_ratio=1.0, 
-                 pooling_size=pooling_size, 
-                 act_type=act_type, 
-                 norm_type=norm_type),
-            Conv(inter_dim, inter_dim, k=3, p=1, 
-                 act_type=act_type, norm_type=norm_type, 
-                 depthwise=depthwise)
-        )
-        self.cv3 = Conv(inter_dim * 2, self.out_dim, k=1, act_type=act_type, norm_type=norm_type)
-
-        
     def forward(self, x):
-        x1 = self.cv1(x)
-        x2 = self.cv2(x)
-        x3 = self.m(x2)
-        y = self.cv3(torch.cat([x1, x3], dim=1))
-
-        return y
-
-
-def build_neck(cfg, in_dim, out_dim):
-    model = cfg['neck']
-    print('==============================')
-    print('Neck: {}'.format(model))
-    # build neck
-    if model == 'sppf':
-        neck = SPPF(
-            in_dim=in_dim,
-            out_dim=out_dim,
-            expand_ratio=cfg['expand_ratio'], 
-            pooling_size=cfg['pooling_size'],
-            act_type=cfg['neck_act'],
-            norm_type=cfg['neck_norm']
-            )
-    elif model == 'csp_sppf':
-        neck = SPPFBlockCSP(
-            in_dim=in_dim,
-            out_dim=out_dim,
-            expand_ratio=cfg['expand_ratio'], 
-            pooling_size=cfg['pooling_size'],
-            act_type=cfg['neck_act'],
-            norm_type=cfg['neck_norm'],
-            depthwise=cfg['neck_depthwise']
-            )
+        x1, x2 = torch.chunk(self.cv1(x), chunks=2, dim=1)
+        x2 = self.m(x2)
 
-    return neck
-        
+        return self.cv2(torch.cat([x1, x2], dim=1))
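
`SPPF` reuses one 5x5 max-pool three times, so its concatenated branches see effective 5x5, 9x9 and 13x13 pooling windows, reproducing the classic SPP layer at lower cost. A small, repository-independent check of that equivalence:

```python
import torch
import torch.nn as nn

# Two / three consecutive 5x5 max-pools (stride 1) equal single 9x9 / 13x13 pools.
x = torch.randn(1, 8, 32, 32)
pool5  = nn.MaxPool2d(kernel_size=5,  stride=1, padding=2)
pool9  = nn.MaxPool2d(kernel_size=9,  stride=1, padding=4)
pool13 = nn.MaxPool2d(kernel_size=13, stride=1, padding=6)

y1 = pool5(x)
y2 = pool5(y1)
y3 = pool5(y2)
print(torch.equal(y2, pool9(x)), torch.equal(y3, pool13(x)))   # True True
```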

+ 101 - 120
yolo/models/yolo11/yolo11_pafpn.py

@@ -1,146 +1,127 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from .modules import Conv, ELANBlockFPN, DownSample
+from typing import List
 
+try:
+    from .modules import ConvModule, C3k2fBlock
+except ImportError:
+    from modules import ConvModule, C3k2fBlock
 
-# PaFPN-ELAN (YOLOv7's)
-class Yolov7PaFPN(nn.Module):
-    def __init__(self, 
-                 in_dims=[512, 1024, 512],
-                 out_dim=None,
-                 channel_width : float = 1.0,
-                 branch_width  : int   = 4.0,
-                 branch_depth  : int   = 1.0,
-                 act_type='silu',
-                 norm_type='BN',
-                 depthwise=False):
-        super(Yolov7PaFPN, self).__init__()
-        # ----------------------------- Basic parameters -----------------------------
-        self.fpn_dims = in_dims
-        self.channel_width = channel_width
-        self.branch_width = branch_width
-        self.branch_depth = branch_depth
-        c3, c4, c5 = self.fpn_dims
 
-        # ----------------------------- Top-down FPN -----------------------------
+class Yolo11PaFPN(nn.Module):
+    def __init__(self, cfg, in_dims :List = [256, 512, 1024]):
+        super(Yolo11PaFPN, self).__init__()
+        # --------------------------- Basic Parameters ---------------------------
+        self.model_scale = cfg.model_scale
+        self.in_dims = in_dims[::-1]
+        self.out_dims = [round(256*cfg.width), round(512*cfg.width), round(512*cfg.width*cfg.ratio)]
+
+        # ----------------------------- Yolo11's Top-down FPN -----------------------------
         ## P5 -> P4
-        self.reduce_layer_1 = Conv(c5, round(256*channel_width), k=1, norm_type=norm_type, act_type=act_type)
-        self.reduce_layer_2 = Conv(c4, round(256*channel_width), k=1, norm_type=norm_type, act_type=act_type)
-        self.top_down_layer_1 = ELANBlockFPN(in_dim=round(256*channel_width) + round(256*channel_width),
-                                             out_dim=round(256*channel_width),
-                                             squeeze_ratio=0.5,
-                                             branch_width=branch_width,
-                                             branch_depth=branch_depth,
-                                             act_type=act_type,
-                                             norm_type=norm_type,
-                                             depthwise=depthwise
-                                             )
+        self.top_down_layer_1 = C3k2fBlock(in_dim     = self.in_dims[0] + self.in_dims[1],
+                                          out_dim    = round(512*cfg.width),
+                                          num_blocks = round(2 * cfg.depth),
+                                          shortcut   = True,
+                                          expansion  = 0.5,
+                                          use_c3k    = False if self.model_scale in "ns" else True,
+                                          )
         ## P4 -> P3
-        self.reduce_layer_3 = Conv(round(256*channel_width), round(128*channel_width), k=1, norm_type=norm_type, act_type=act_type)
-        self.reduce_layer_4 = Conv(c3, round(128*channel_width), k=1, norm_type=norm_type, act_type=act_type)
-        self.top_down_layer_2 = ELANBlockFPN(in_dim=round(128*channel_width) + round(128*channel_width),
-                                             out_dim=round(128*channel_width),
-                                             squeeze_ratio=0.5,
-                                             branch_width=branch_width,
-                                             branch_depth=branch_depth,
-                                             act_type=act_type,
-                                             norm_type=norm_type,
-                                             depthwise=depthwise
-                                             )
-        # ----------------------------- Bottom-up FPN -----------------------------
+        self.top_down_layer_2 = C3k2fBlock(in_dim     = self.in_dims[2] + round(512*cfg.width),
+                                          out_dim    = round(256*cfg.width),
+                                          num_blocks = round(2 * cfg.depth),
+                                          shortcut   = True,
+                                          expansion  = 0.5,
+                                          use_c3k    = False if self.model_scale in "ns" else True,
+                                          )
+        # ----------------------------- Yolo11's Bottom-up PAN -----------------------------
         ## P3 -> P4
-        self.downsample_layer_1 = DownSample(round(128*channel_width), round(256*channel_width), act_type, norm_type, depthwise)
-        self.bottom_up_layer_1 = ELANBlockFPN(in_dim=round(256*channel_width) + round(256*channel_width),
-                                              out_dim=round(256*channel_width),
-                                              squeeze_ratio=0.5,
-                                              branch_width=branch_width,
-                                              branch_depth=branch_depth,
-                                              act_type=act_type,
-                                              norm_type=norm_type,
-                                              depthwise=depthwise
-                                              )
+        self.downsample_layer_1 = ConvModule(round(256*cfg.width), round(256*cfg.width), kernel_size=3, stride=2)
+        self.bottom_up_layer_1 = C3k2fBlock(in_dim     = round(256*cfg.width) + round(512*cfg.width),
+                                           out_dim    = round(512*cfg.width),
+                                           num_blocks = round(2 * cfg.depth),
+                                           shortcut   = True,
+                                           expansion  = 0.5,
+                                           use_c3k    = False if self.model_scale in "ns" else True,
+                                           )
         ## P4 -> P5
-        self.downsample_layer_2 = DownSample(round(256*channel_width), round(512*channel_width), act_type, norm_type, depthwise)
-        self.bottom_up_layer_2 = ELANBlockFPN(in_dim=round(512*channel_width) + c5,
-                                              out_dim=round(512*channel_width),
-                                              squeeze_ratio=0.5,
-                                              branch_width=branch_width,
-                                              branch_depth=branch_depth,
-                                              act_type=act_type,
-                                              norm_type=norm_type,
-                                              depthwise=depthwise
-                                              )
-        # ----------------------------- Output Proj -----------------------------
-        ## Head convs
-        self.head_conv_1 = Conv(round(128*channel_width), round(256*channel_width), k=3, s=1, p=1, act_type=act_type, norm_type=norm_type)
-        self.head_conv_2 = Conv(round(256*channel_width), round(512*channel_width), k=3, s=1, p=1, act_type=act_type, norm_type=norm_type)
-        self.head_conv_3 = Conv(round(512*channel_width), round(1024*channel_width), k=3, s=1, p=1, act_type=act_type, norm_type=norm_type)
-        ## Output projs
-        if out_dim is not None:
-            self.out_layers = nn.ModuleList([
-                Conv(in_dim, out_dim, k=1, act_type=act_type, norm_type=norm_type)
-                for in_dim in [round(256*channel_width), round(512*channel_width), round(1024*channel_width)]
-                ])
-            self.out_dim = [out_dim] * 3
-        else:
-            self.out_layers = None
-            self.out_dim = [round(256*channel_width), round(512*channel_width), round(1024*channel_width)]
+        self.downsample_layer_2 = ConvModule(round(512*cfg.width), round(512*cfg.width), kernel_size=3, stride=2)
+        self.bottom_up_layer_2 = C3k2fBlock(in_dim     = round(512*cfg.width) + self.in_dims[0],
+                                           out_dim    = round(512*cfg.width*cfg.ratio),
+                                           num_blocks = round(2 * cfg.depth),
+                                           shortcut   = True,
+                                           expansion  = 0.5,
+                                           use_c3k    = True,
+                                           )
 
+        self.init_weights()
+        
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                m.reset_parameters()
 
     def forward(self, features):
         c3, c4, c5 = features
 
-        # Top down
+        # ------------------ Top down FPN ------------------
         ## P5 -> P4
-        c6 = self.reduce_layer_1(c5)
-        c7 = F.interpolate(c6, scale_factor=2.0)
-        c8 = torch.cat([c7, self.reduce_layer_2(c4)], dim=1)
-        c9 = self.top_down_layer_1(c8)
+        p5_up = F.interpolate(c5, scale_factor=2.0)
+        p4 = self.top_down_layer_1(torch.cat([p5_up, c4], dim=1))
+
         ## P4 -> P3
-        c10 = self.reduce_layer_3(c9)
-        c11 = F.interpolate(c10, scale_factor=2.0)
-        c12 = torch.cat([c11, self.reduce_layer_4(c3)], dim=1)
-        c13 = self.top_down_layer_2(c12)
+        p4_up = F.interpolate(p4, scale_factor=2.0)
+        p3 = self.top_down_layer_2(torch.cat([p4_up, c3], dim=1))
 
-        # Bottom up
+        # ------------------ Bottom up FPN ------------------
         ## p3 -> P4
-        c14 = self.downsample_layer_1(c13)
-        c15 = torch.cat([c14, c9], dim=1)
-        c16 = self.bottom_up_layer_1(c15)
-        ## P4 -> P5
-        c17 = self.downsample_layer_2(c16)
-        c18 = torch.cat([c17, c5], dim=1)
-        c19 = self.bottom_up_layer_2(c18)
+        p3_ds = self.downsample_layer_1(p3)
+        p4 = self.bottom_up_layer_1(torch.cat([p3_ds, p4], dim=1))
 
-        c20 = self.head_conv_1(c13)
-        c21 = self.head_conv_2(c16)
-        c22 = self.head_conv_3(c19)
-        out_feats = [c20, c21, c22] # [P3, P4, P5]
-        
-        # output proj layers
-        if self.out_layers is not None:
-            out_feats_proj = []
-            for feat, layer in zip(out_feats, self.out_layers):
-                out_feats_proj.append(layer(feat))
-            return out_feats_proj
+        ## P4 -> P5
+        p4_ds = self.downsample_layer_2(p4)
+        p5 = self.bottom_up_layer_2(torch.cat([p4_ds, c5], dim=1))
 
+        out_feats = [p3, p4, p5] # [P3, P4, P5]
+                
         return out_feats
+    
 
+if __name__=='__main__':
+    import time
+    from thop import profile
+    # Model config
+    
+    # YOLO11-Base config
+    class Yolo11BaseConfig(object):
+        def __init__(self) -> None:
+            # ---------------- Model config ----------------
+            self.width    = 0.50
+            self.depth    = 0.34
+            self.ratio    = 2.0
+            self.out_stride = [8, 16, 32]
+            self.max_stride = 32
+            self.model_scale = "s"
 
-def build_fpn(cfg, in_dims, out_dim=None):
-    model = cfg['fpn']
-    # build pafpn
-    if model == 'yolov7_pafpn':
-        fpn_net = Yolov7PaFPN(in_dims       = in_dims,
-                              out_dim       = out_dim,
-                              channel_width = cfg['channel_width'],
-                              branch_width  = cfg['branch_width'],
-                              branch_depth  = cfg['branch_depth'],
-                              act_type      = cfg['fpn_act'],
-                              norm_type     = cfg['fpn_norm'],
-                              depthwise     = cfg['fpn_depthwise']
-                              )
+    cfg = Yolo11BaseConfig()
+    # Build a head
+    in_dims  = [128, 256, 512]
+    fpn = Yolo11PaFPN(cfg, in_dims)
 
+    # Inference
+    x = [torch.randn(1, in_dims[0], 80, 80),
+         torch.randn(1, in_dims[1], 40, 40),
+         torch.randn(1, in_dims[2], 20, 20)]
+    t0 = time.time()
+    output = fpn(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    print('====== FPN output ====== ')
+    for level, feat in enumerate(output):
+        print("- Level-{} : ".format(level), feat.shape)
 
-    return fpn_net
+    flops, params = profile(fpn, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))
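
As with the backbone, the PaFPN output widths follow from `cfg.width` and `cfg.ratio`; for the `"s"`-style test config above they reproduce the `in_dims` expected by the detection head (a worked example only):

```python
# cfg.width = 0.50, cfg.ratio = 2.0 -> Yolo11PaFPN.out_dims
width, ratio = 0.50, 2.0
out_dims = [round(256 * width), round(512 * width), round(512 * width * ratio)]
print(out_dims)   # [128, 256, 512]
```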

+ 207 - 0
yolo/models/yolo11/yolo11_pred.py

@@ -0,0 +1,207 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+# -------------------- Detection Pred Layer --------------------
+## Single-level pred layer
+class DetPredLayer(nn.Module):
+    def __init__(self,
+                 cls_dim     :int = 256,
+                 reg_dim     :int = 256,
+                 stride      :int = 32,
+                 reg_max     :int = 16,
+                 num_classes :int = 80,
+                 num_coords  :int = 4):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.stride = stride
+        self.cls_dim = cls_dim
+        self.reg_dim = reg_dim
+        self.reg_max = reg_max
+        self.num_classes = num_classes
+        self.num_coords = num_coords
+
+        # --------- Network Parameters ----------
+        self.cls_pred = nn.Conv2d(cls_dim, num_classes, kernel_size=1)
+        self.reg_pred = nn.Conv2d(reg_dim, num_coords, kernel_size=1)                
+
+        self.init_bias()
+        
+    def init_bias(self):
+        # cls pred bias
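+        # The bias is set so that the initial foreground probability amounts to roughly
+        # 5 positive (cell, class) activations per 640x640 image at this stride, which
+        # keeps the early classification loss from being dominated by background.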
+        b = self.cls_pred.bias.view(1, -1)
+        b.data.fill_(math.log(5 / self.num_classes / (640. / self.stride) ** 2))
+        self.cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        # reg pred bias
+        b = self.reg_pred.bias.view(-1, )
+        b.data.fill_(1.0)
+        self.reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        w = self.reg_pred.weight
+        w.data.fill_(0.)
+        self.reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
+
+    def generate_anchors(self, fmp_size):
+        """
+            fmp_size: (List) [H, W]
+        """
+        # generate grid cells
+        fmp_h, fmp_w = fmp_size
+        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
+        # [H, W, 2] -> [HW, 2]
+        anchors = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
+        anchors += 0.5  # add center offset
+        anchors *= self.stride
+
+        return anchors
+        
+    def forward(self, cls_feat, reg_feat):
+        # pred
+        cls_pred = self.cls_pred(cls_feat)
+        reg_pred = self.reg_pred(reg_feat)
+
+        # generate anchor boxes: [M, 4]
+        B, _, H, W = cls_pred.size()
+        fmp_size = [H, W]
+        anchors = self.generate_anchors(fmp_size)
+        anchors = anchors.to(cls_pred.device)
+        # stride tensor: [M, 1]
+        stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride
+        
+        # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
+        cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
+        reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)
+        
+        # output dict
+        outputs = {"pred_cls": cls_pred,            # List(Tensor) [B, M, C]
+                   "pred_reg": reg_pred,            # List(Tensor) [B, M, 4*(reg_max)]
+                   "anchors": anchors,              # List(Tensor) [M, 2]
+                   "strides": self.stride,          # List(Int) = [8, 16, 32]
+                   "stride_tensor": stride_tensor   # List(Tensor) [M, 1]
+                   }
+
+        return outputs
+
+## Multi-level pred layer
+class Yolo11DetPredLayer(nn.Module):
+    def __init__(self, cfg, cls_dim: int, reg_dim: int):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.cfg = cfg
+        self.cls_dim = cls_dim
+        self.reg_dim = reg_dim
+        self.num_levels = len(cfg.out_stride)
+
+        # ----------- Network Parameters -----------
+        ## pred layers
+        self.multi_level_preds = nn.ModuleList(
+            [DetPredLayer(cls_dim     = cls_dim,
+                          reg_dim     = reg_dim,
+                          stride      = cfg.out_stride[level],
+                          reg_max     = cfg.reg_max,
+                          num_classes = cfg.num_classes,
+                          num_coords  = 4 * cfg.reg_max)
+                          for level in range(self.num_levels)
+                          ])
+        ## proj conv
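+        ## a fixed 1x1 conv with weights [0, 1, ..., reg_max-1]: applied to the
+        ## softmax over the reg_max bins, it returns the expected value of the
+        ## discrete distance distribution (DFL-style decoding)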
+        proj_init = torch.arange(cfg.reg_max, dtype=torch.float)
+        self.proj_conv = nn.Conv2d(cfg.reg_max, 1, kernel_size=1, bias=False).requires_grad_(False)
+        self.proj_conv.weight.data[:] = nn.Parameter(proj_init.view([1, cfg.reg_max, 1, 1]), requires_grad=False)
+
+    def forward(self, cls_feats, reg_feats):
+        all_anchors = []
+        all_strides = []
+        all_cls_preds = []
+        all_reg_preds = []
+        all_box_preds = []
+        for level in range(self.num_levels):
+            # -------------- Single-level prediction --------------
+            outputs = self.multi_level_preds[level](cls_feats[level], reg_feats[level])
+
+            # -------------- Decode bbox --------------
+            B, M = outputs["pred_reg"].shape[:2]
+            # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max]
+            delta_pred = outputs["pred_reg"].reshape([B, M, 4, self.cfg.reg_max])
+            # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
+            delta_pred = delta_pred.permute(0, 3, 2, 1).contiguous()
+            # [B, reg_max, 4, M] -> [B, 1, 4, M]
+            delta_pred = self.proj_conv(F.softmax(delta_pred, dim=1))
+            # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
+            delta_pred = delta_pred.view(B, 4, M).permute(0, 2, 1).contiguous()
+            ## decode (l, t, r, b) distances into (x1, y1, x2, y2) boxes
+            x1y1_pred = outputs["anchors"][None] - delta_pred[..., :2] * self.cfg.out_stride[level]
+            x2y2_pred = outputs["anchors"][None] + delta_pred[..., 2:] * self.cfg.out_stride[level]
+            box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
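+            ## box_pred is expressed in input-image pixels, since the predicted
+            ## cell-unit distances are scaled by this level's stride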
+
+            # collect results
+            all_cls_preds.append(outputs["pred_cls"])
+            all_reg_preds.append(outputs["pred_reg"])
+            all_box_preds.append(box_pred)
+            all_anchors.append(outputs["anchors"])
+            all_strides.append(outputs["stride_tensor"])
+        
+        # output dict
+        outputs = {"pred_cls":      all_cls_preds,         # List(Tensor) [B, M, C]
+                   "pred_reg":      all_reg_preds,         # List(Tensor) [B, M, 4*(reg_max)]
+                   "pred_box":      all_box_preds,         # List(Tensor) [B, M, 4]
+                   "anchors":       all_anchors,           # List(Tensor) [M, 2]
+                   "stride_tensor": all_strides,           # List(Tensor) [M, 1]
+                   "strides":       self.cfg.out_stride,   # List(Int) = [8, 16, 32]
+                   }
+
+        return outputs
+
+
+if __name__=='__main__':
+    import time
+    from thop import profile
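+    # thop ('pip install thop') is used only to estimate FLOPs and parameter counts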
+    # Model config
+    
+    # YOLO11-Base config
+    class Yolo11BaseConfig(object):
+        def __init__(self) -> None:
+            # ---------------- Model config ----------------
+            self.width    = 1.0
+            self.depth    = 1.0
+            self.ratio    = 1.0
+            self.reg_max  = 16
+            self.out_stride = [8, 16, 32]
+            self.max_stride = 32
+            self.num_levels = 3
+            ## Head
+
+    cfg = Yolo11BaseConfig()
+    cfg.num_classes = 20
+    cls_dim = 128
+    reg_dim = 64
+    # Build a pred layer
+    pred = Yolo11DetPredLayer(cfg, cls_dim, reg_dim)
+
+    # Inference
+    cls_feats = [torch.randn(1, cls_dim, 80, 80),
+                 torch.randn(1, cls_dim, 40, 40),
+                 torch.randn(1, cls_dim, 20, 20),]
+    reg_feats = [torch.randn(1, reg_dim, 80, 80),
+                 torch.randn(1, reg_dim, 40, 40),
+                 torch.randn(1, reg_dim, 20, 20),]
+    t0 = time.time()
+    output = pred(cls_feats, reg_feats)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    print('====== Pred output ======= ')
+    pred_cls = output["pred_cls"]
+    pred_reg = output["pred_reg"]
+    pred_box = output["pred_box"]
+    anchors  = output["anchors"]
+    
+    for level in range(cfg.num_levels):
+        print("- Level-{} : classification   -> {}".format(level, pred_cls[level].shape))
+        print("- Level-{} : delta regression -> {}".format(level, pred_reg[level].shape))
+        print("- Level-{} : bbox regression  -> {}".format(level, pred_box[level].shape))
+        print("- Level-{} : anchor boxes     -> {}".format(level, anchors[level].shape))
+
+    flops, params = profile(pred, inputs=(cls_feats, reg_feats, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))