
modify v10

yjh0410 10 months ago
parent
commit
b5e70703cd

+ 3 - 0
yolo/config/__init__.py

@@ -9,6 +9,7 @@ from .yolov6_config  import build_yolov6_config
 from .yolov7_config  import build_yolov7_config
 from .yolov8_config  import build_yolov8_config
 from .yolov9_config  import build_yolov9_config
+from .yolov10_config  import build_yolov10_config
 from .yolo11_config  import build_yolo11_config
 
 from .yolof_config   import build_yolof_config
@@ -40,6 +41,8 @@ def build_config(args):
         cfg = build_yolov8_config(args)
     elif 'yolov9' in args.model:
         cfg = build_yolov9_config(args)
+    elif 'yolov10' in args.model:
+        cfg = build_yolov10_config(args)
     elif 'yolo11' in args.model:
         cfg = build_yolo11_config(args)
         

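A minimal usage sketch of the new dispatch path, assuming the repository root is on `PYTHONPATH`; `args` here is a stand-in argparse-style namespace, not the project's real CLI parser.

```python
from types import SimpleNamespace

from yolo.config.yolov10_config import build_yolov10_config

# Stand-in for the parsed CLI arguments; only `model` is needed by the builder.
args = SimpleNamespace(model="yolov10_s")

cfg = build_yolov10_config(args)   # the branch that build_config() now falls through to
print(type(cfg).__name__)          # -> "Yolov10SConfig"
```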
+ 179 - 0
yolo/config/yolov10_config.py

@@ -0,0 +1,179 @@
+# YOLOv10 config
+
+
+def build_yolov10_config(args):
+    if   args.model == 'yolov10_n':
+        return Yolov10NConfig()
+    elif args.model == 'yolov10_s':
+        return Yolov10SConfig()
+    elif args.model == 'yolov10_m':
+        return Yolov10MConfig()
+    elif args.model == 'yolov10_l':
+        return Yolov10LConfig()
+    elif args.model == 'yolov10_x':
+        return Yolov10XConfig()
+    else:
+        raise NotImplementedError("No config for model: {}".format(args.model))
+    
+# YOLOv10-Base config
+class Yolov10BaseConfig(object):
+    def __init__(self) -> None:
+        # ---------------- Model config ----------------
+        self.model_scale = "l"
+        self.width   = 1.0
+        self.depth   = 1.0
+        self.ratio   = 1.0
+        self.reg_max = 16
+
+        self.out_stride = [8, 16, 32]
+        self.max_stride = 32
+
+        ## Head
+        self.num_cls_head = 2
+        self.num_reg_head = 2
+
+        # ---------------- Post-process config ----------------
+        ## Post process
+        self.val_topk = 1000
+        self.val_conf_thresh = 0.001
+        self.val_nms_thresh  = 0.7
+        self.test_topk = 100
+        self.test_conf_thresh = 0.2
+        self.test_nms_thresh  = 0.5
+
+        # ---------------- Assignment config ----------------
+        ## Matcher
+        self.tal_topk_candidates = 10
+        self.tal_alpha = 0.5
+        self.tal_beta  = 6.0
+        ## Loss weight
+        self.loss_cls = 0.5
+        self.loss_box = 7.5
+        self.loss_dfl = 1.5
+
+        # ---------------- ModelEMA config ----------------
+        self.use_ema = True
+        self.ema_decay = 0.9998
+        self.ema_tau   = 2000
+
+        # ---------------- Optimizer config ----------------
+        self.trainer      = 'yolo'
+        self.optimizer    = 'adamw'
+        self.base_lr      = 0.001     # base_lr = per_image_lr * batch_size
+        self.min_lr_ratio = 0.01      # min_lr  = base_lr * min_lr_ratio
+        self.batch_size_base = 64
+        self.momentum     = 0.9
+        self.weight_decay = 0.05
+        self.clip_max_norm   = 35.0
+        self.warmup_bias_lr  = 0.1
+        self.warmup_momentum = 0.8
+
+        # ---------------- Lr Scheduler config ----------------
+        self.warmup_epoch = 3
+        self.lr_scheduler = "cosine"
+        self.max_epoch    = 500
+        self.eval_epoch   = 10
+        self.no_aug_epoch = 20
+
+        # ---------------- Data process config ----------------
+        self.aug_type = 'yolo'
+        self.mosaic_prob = 0.0
+        self.mixup_prob  = 0.0
+        self.copy_paste  = 0.0           # approximated by YOLOX-style MixUp
+        self.multi_scale = [0.5, 1.5]   # multi scale: [img_size * 0.5, img_size * 1.5]
+        ## Pixel mean & std
+        self.pixel_mean = [0., 0., 0.]
+        self.pixel_std  = [255., 255., 255.]
+        ## Transforms
+        self.train_img_size = 640
+        self.test_img_size  = 640
+        self.affine_params = {
+            'degrees': 0.0,
+            'translate': 0.2,
+            'scale': [0.1, 2.0],
+            'shear': 0.0,
+            'perspective': 0.0,
+            'hsv_h': 0.015,
+            'hsv_s': 0.7,
+            'hsv_v': 0.4,
+        }
+
+    def print_config(self):
+        config_dict = {key: value for key, value in self.__dict__.items() if not key.startswith('__')}
+        for k, v in config_dict.items():
+            print("{} : {}".format(k, v))
+
+# YOLOv10-N
+class Yolov10NConfig(Yolov10BaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # ---------------- Model config ----------------
+        self.model_scale = "n"
+        self.width = 0.25
+        self.depth = 0.34
+        self.ratio = 2.0
+
+        # ---------------- Data process config ----------------
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.0
+        self.copy_paste  = 0.5
+
+# YOLOv10-S
+class Yolov10SConfig(Yolov10BaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # ---------------- Model config ----------------
+        self.model_scale = "s"
+        self.width = 0.50
+        self.depth = 0.34
+        self.ratio = 2.0
+
+        # ---------------- Data process config ----------------
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.0
+        self.copy_paste  = 0.5
+
+# YOLOv10-M
+class Yolov10MConfig(Yolov10BaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # ---------------- Model config ----------------
+        self.model_scale = "m"
+        self.width = 0.75
+        self.depth = 0.67
+        self.ratio = 1.5
+
+        # ---------------- Data process config ----------------
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.1
+        self.copy_paste  = 0.5
+
+# YOLOv10-L
+class Yolov10LConfig(Yolov10BaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # ---------------- Model config ----------------
+        self.model_scale = "l"
+        self.width = 1.0
+        self.depth = 1.0
+        self.ratio = 1.0
+
+        # ---------------- Data process config ----------------
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.1
+        self.copy_paste  = 0.5
+
+# YOLOv10-X
+class Yolov10XConfig(Yolov10BaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # ---------------- Model config ----------------
+        self.model_scale = "x"
+        self.width = 1.25
+        self.depth = 1.0
+        self.ratio = 1.0
+
+        # ---------------- Data process config ----------------
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.1
+        self.copy_paste  = 0.5

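Each scale variant only overrides the width/depth/ratio multipliers and the augmentation probabilities of the base config. Below is a hedged sketch of how the width multiplier would typically be applied to a channel spec; the base channel numbers are illustrative, not taken from this repository.

```python
from yolo.config.yolov10_config import Yolov10NConfig

cfg = Yolov10NConfig()

# Hypothetical stage widths, shown only to illustrate the effect of cfg.width.
base_channels = [64, 128, 256, 512]
scaled = [round(c * cfg.width) for c in base_channels]

print(cfg.model_scale, scaled)   # 'n' [16, 32, 64, 128]
cfg.print_config()               # dumps every config field
```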
+ 5 - 0
yolo/models/__init__.py

@@ -12,6 +12,7 @@ from .yolov6.build import build_yolov6
 from .yolov7.build import build_yolov7
 from .yolov8.build import build_yolov8
 from .yolov9.build import build_gelan
+from .yolov10.build import build_yolov10
 from .yolo11.build import build_yolo11
 
 from .yolof.build  import build_yolof
@@ -52,6 +53,10 @@ def build_model(args, cfg, is_val=False):
     ## GElan
     elif 'yolov9' in args.model:
         model, criterion = build_gelan(cfg, is_val)
+    ## YOLOv10
+    elif 'yolov10' in args.model:
+        model, criterion = build_yolov10(cfg, is_val)
+
     ## YOLO11
     elif 'yolo11' in args.model:
         model, criterion = build_yolo11(cfg, is_val)

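`build_model()` dispatches on substring matching, so branch order matters. A self-contained toy mirror of that pattern is shown below; the earlier branches for older YOLO versions are not visible in this hunk, so whether a shorter prefix such as `'yolov1'` could shadow `'yolov10'` is only a caution, not a confirmed issue.

```python
# Toy mirror of the `in`-based dispatch used in build_model(); not the project's code.
def dispatch(model_name: str) -> str:
    if 'yolov9' in model_name:
        return 'build_gelan'
    elif 'yolov10' in model_name:
        return 'build_yolov10'
    elif 'yolo11' in model_name:
        return 'build_yolo11'
    raise NotImplementedError(model_name)

assert dispatch('yolov10_s') == 'build_yolov10'
assert dispatch('yolo11_m') == 'build_yolo11'
```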
+ 30 - 30
yolo/models/yolov10/README.md

@@ -1,56 +1,56 @@
-# YOLOv7:
-
-|    Model    |   Backbone    | Batch | Scale | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |
-|-------------|---------------|-------|-------|------------------------|-------------------|-------------------|--------------------|--------|
-| YOLOv7-Tiny | ELANNet-Tiny  | 8xb16 |  640  |         39.5           |       58.5        |   22.6            |   7.9              | [ckpt](https://github.com/yjh0410/RT-ODLab/releases/download/yolo_tutorial_ckpt/yolov7_tiny_coco.pth) |
-| YOLOv7      | ELANNet-Large | 8xb16 |  640  |         49.5           |       68.8        |   144.6           |   44.0             | [ckpt](https://github.com/yjh0410/RT-ODLab/releases/download/yolo_tutorial_ckpt/yolov7_coco.pth) |
-| YOLOv7-X    | ELANNet-Huge  |       |  640  |                        |                   |                   |                    |  |
-
-- For training, we train `YOLOv7` and `YOLOv7-Tiny` with 300 epochs on 8 GPUs.
-- For data augmentation, we use the [YOLOX-style](https://github.com/Megvii-BaseDetection/YOLOX) augmentation including the large scale jitter (LSJ), Mosaic augmentation and Mixup augmentation.
-- For optimizer, we use `AdamW` with weight decay 0.05 and per image learning rate 0.001 / 64.
-- For learning rate scheduler, we use Cosine decay scheduler.
-- For YOLOv7's structure, we replace the coupled head with the YOLOX-style decoupled head.
-- I think YOLOv7 uses too many training tricks, such as `anchor box`, `AuxiliaryHead`, `RepConv`, `Mosaic9x` and so on, making the picture of YOLO too complicated, which is against the development concept of the YOLO series. Otherwise, why don't we use the DETR series? It's nothing more than doing some acceleration optimization on DETR. Therefore, I was faithful to my own technical aesthetics and realized a cleaner and simpler YOLOv7, but without the blessing of so many tricks, I did not reproduce all the performance, which is a pity.
-- I have no more GPUs to train my `YOLOv7-X`.
-
-## Train YOLOv7
+# YOLOv8:
+
+- VOC
+
+|     Model   | Batch | Scale | AP<sup>val<br>0.5 | Weight |  Logs  |
+|-------------|-------|-------|-------------------|--------|--------|
+| YOLOv8-S    | 1xb16 |  640  |      83.6     | [ckpt](https://github.com/yjh0410/YOLO-Tutorial-v5/releases/download/yolo_tutorial_ckpt/yolov8_s_voc.pth) | [log](https://github.com/yjh0410/YOLO-Tutorial-v5/releases/download/yolo_tutorial_ckpt/YOLOv8-S-VOC.txt) |
+
+- COCO
+
+|    Model    | Batch | Scale | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |  Logs  |
+|-------------|-------|-------|------------------------|-------------------|-------------------|--------------------|--------|--------|
+| YOLOv8-S    | 1xb16 |  640  |                    |               |   26.9            |   8.9             |  |  |
+
+
+
+## Train YOLOv8
 ### Single GPU
-Taking training YOLOv7-Tiny on COCO as the example,
+Taking training YOLOv8-S on COCO as an example,
 ```Shell
-python train.py --cuda -d coco --root path/to/coco -m yolov7_tiny -bs 16 -size 640 --wp_epoch 3 --max_epoch 300 --eval_epoch 10 --no_aug_epoch 20 --ema --fp16 --multi_scale 
+python train.py --cuda -d coco --root path/to/coco -m yolov8_s -bs 16 --fp16 
 ```
 
 ### Multi GPU
-Taking training YOLOv7-Tiny on COCO as the example,
+Taking training YOLOv8-S on COCO as an example,
 ```Shell
-python -m torch.distributed.run --nproc_per_node=8 train.py --cuda -dist -d coco --root /data/datasets/ -m yolov7_tiny -bs 128 -size 640 --wp_epoch 3 --max_epoch 300  --eval_epoch 10 --no_aug_epoch 20 --ema --fp16 --sybn --multi_scale --save_folder weights/ 
+python -m torch.distributed.run --nproc_per_node=8 train.py --cuda --distributed -d coco --root path/to/coco -m yolov8_s -bs 256 --fp16 
 ```
 
-## Test YOLOv7
-Taking testing YOLOv7-Tiny on COCO-val as the example,
+## Test YOLOv8
+Taking testing YOLOv8-S on COCO-val as an example,
 ```Shell
-python test.py --cuda -d coco --root path/to/coco -m yolov7_tiny --weight path/to/yolov7_tiny.pth -size 640 -vt 0.4 --show 
+python test.py --cuda -d coco --root path/to/coco -m yolov8_s --weight path/to/yolov8.pth --show 
 ```
 
-## Evaluate YOLOv7
-Taking evaluating YOLOv7-Tiny on COCO-val as the example,
+## Evaluate YOLOv8
+Taking evaluating YOLOv8-S on COCO-val as an example,
 ```Shell
-python eval.py --cuda -d coco-val --root path/to/coco -m yolov7_tiny --weight path/to/yolov7_tiny.pth 
+python eval.py --cuda -d coco --root path/to/coco -m yolov8_s --weight path/to/yolov8.pth 
 ```
 
 ## Demo
 ### Detect with Image
 ```Shell
-python demo.py --mode image --path_to_img path/to/image_dirs/ --cuda -m yolov7_tiny --weight path/to/weight -size 640 -vt 0.4 --show
+python demo.py --mode image --path_to_img path/to/image_dirs/ --cuda -m yolov8_s --weight path/to/weight --show
 ```
 
 ### Detect with Video
 ```Shell
-python demo.py --mode video --path_to_vid path/to/video --cuda -m yolov7_tiny --weight path/to/weight -size 640 -vt 0.4 --show --gif
+python demo.py --mode video --path_to_vid path/to/video --cuda -m yolov8_s --weight path/to/weight --show --gif
 ```
 
 ### Detect with Camera
 ```Shell
-python demo.py --mode camera --cuda -m yolov7_tiny --weight path/to/weight -size 640 -vt 0.4 --show --gif
+python demo.py --mode camera --cuda -m yolov8_s --weight path/to/weight --show --gif
 ```

+ 8 - 50
yolo/models/yolov10/build.py

@@ -1,66 +1,24 @@
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-
-
-import torch
 import torch.nn as nn
 
-from .loss import build_criterion
-from .yolov10 import YOLOv7
+from .loss import SetCriterion
+from .yolov10 import Yolov10
 
 
 # build object detector
-def build_yolov7(args, cfg, device, num_classes=80, trainable=False, deploy=False):
-    print('==============================')
-    print('Build {} ...'.format(args.model.upper()))
-    
-    print('==============================')
-    print('Model Configuration: \n', cfg)
-    
+def build_yolov10(cfg, is_val=False):
     # -------------- Build YOLO --------------
-    model = YOLOv7(cfg                = cfg,
-                   device             = device, 
-                   num_classes        = num_classes,
-                   trainable          = trainable,
-                   conf_thresh        = args.conf_thresh,
-                   nms_thresh         = args.nms_thresh,
-                   topk               = args.topk,
-                   deploy             = deploy,
-                   no_multi_labels    = args.no_multi_labels,
-                   nms_class_agnostic = args.nms_class_agnostic
-                   )
+    model = Yolov10(cfg, is_val)
 
     # -------------- Initialize YOLO --------------
     for m in model.modules():
         if isinstance(m, nn.BatchNorm2d):
             m.eps = 1e-3
             m.momentum = 0.03    
-    # Init bias
-    init_prob = 0.01
-    bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
-    # obj pred
-    for obj_pred in model.obj_preds:
-        b = obj_pred.bias.view(1, -1)
-        b.data.fill_(bias_value.item())
-        obj_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-    # cls pred
-    for cls_pred in model.cls_preds:
-        b = cls_pred.bias.view(1, -1)
-        b.data.fill_(bias_value.item())
-        cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-    # reg pred
-    for reg_pred in model.reg_preds:
-        b = reg_pred.bias.view(-1, )
-        b.data.fill_(1.0)
-        reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-        w = reg_pred.weight
-        w.data.fill_(0.)
-        reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
-
-
+            
     # -------------- Build criterion --------------
     criterion = None
-    if trainable:
+    if is_val:
         # build criterion for training
-        criterion = build_criterion(args, cfg, device, num_classes)
-
+        criterion = SetCriterion(cfg)
+        
     return model, criterion

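The simplified builder keeps only the BatchNorm re-initialisation: every `BatchNorm2d` gets `eps=1e-3` and `momentum=0.03`, and the criterion is built only when `is_val` is set. A minimal sketch of that loop on a toy module (illustrative, not the project's detector):

```python
import torch.nn as nn

# Toy stand-in for the detector; only the BatchNorm layers matter here.
model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.SiLU())

for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.eps = 1e-3       # same values as in build_yolov10()
        m.momentum = 0.03

print(model[1].eps, model[1].momentum)   # 0.001 0.03
```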
+ 125 - 158
yolo/models/yolov10/loss.py

@@ -1,212 +1,179 @@
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
-from .matcher import SimOTA
-from utils.box_ops import get_ious
+
+from utils.box_ops import bbox2dist, bbox_iou
 from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
 
+from .matcher import TaskAlignedAssigner
 
 
-class Criterion(object):
-    def __init__(self,
-                 args,
-                 cfg, 
-                 device, 
-                 num_classes=80):
-        self.args = args
+class SetCriterion(object):
+    def __init__(self, cfg):
+        # --------------- Basic parameters ---------------
         self.cfg = cfg
-        self.device = device
-        self.num_classes = num_classes
-        self.max_epoch = args.max_epoch
-        self.no_aug_epoch = args.no_aug_epoch
-        self.aux_bbox_loss = False
-        # loss weight
-        self.loss_obj_weight = cfg['loss_obj_weight']
-        self.loss_cls_weight = cfg['loss_cls_weight']
-        self.loss_box_weight = cfg['loss_box_weight']
-        # matcher
-        matcher_config = cfg['matcher']
-        self.matcher = SimOTA(
-            num_classes=num_classes,
-            center_sampling_radius=matcher_config['center_sampling_radius'],
-            topk_candidate=matcher_config['topk_candicate']
-            )
-
-
-    def loss_objectness(self, pred_obj, gt_obj):
-        loss_obj = F.binary_cross_entropy_with_logits(pred_obj, gt_obj, reduction='none')
-
-        return loss_obj
-    
-
-    def loss_classes(self, pred_cls, gt_label):
-        loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_label, reduction='none')
+        self.reg_max = cfg.reg_max
+        self.num_classes = cfg.num_classes
+        # --------------- Loss config ---------------
+        self.loss_cls_weight = cfg.loss_cls
+        self.loss_box_weight = cfg.loss_box
+        self.loss_dfl_weight = cfg.loss_dfl
+        # --------------- Matcher config ---------------
+        self.matcher = TaskAlignedAssigner(num_classes     = cfg.num_classes,
+                                           topk_candidates = cfg.tal_topk_candidates,
+                                           alpha           = cfg.tal_alpha,
+                                           beta            = cfg.tal_beta
+                                           )
+
+    def loss_classes(self, pred_cls, gt_score):
+        # compute bce loss
+        loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_score, reduction='none')
 
         return loss_cls
-
-
-    def loss_bboxes(self, pred_box, gt_box):
+    
+    def loss_bboxes(self, pred_box, gt_box, bbox_weight):
         # regression loss
-        ious = get_ious(pred_box, gt_box, "xyxy", 'giou')
-        loss_box = 1.0 - ious
+        ious = bbox_iou(pred_box, gt_box, xywh=False, CIoU=True)
+        loss_box = (1.0 - ious.squeeze(-1)) * bbox_weight
 
         return loss_box
+    
+    def loss_dfl(self, pred_reg, gt_box, anchor, stride, bbox_weight=None):
+        # rescale coords by stride
+        gt_box_s = gt_box / stride
+        anchor_s = anchor / stride
+
+        # compute deltas
+        gt_ltrb_s = bbox2dist(anchor_s, gt_box_s, self.reg_max - 1)
+
+        gt_left = gt_ltrb_s.to(torch.long)
+        gt_right = gt_left + 1
+
+        weight_left = gt_right.to(torch.float) - gt_ltrb_s
+        weight_right = 1 - weight_left
+
+        # loss left
+        loss_left = F.cross_entropy(
+            pred_reg.view(-1, self.reg_max),
+            gt_left.view(-1),
+            reduction='none').view(gt_left.shape) * weight_left
+        # loss right
+        loss_right = F.cross_entropy(
+            pred_reg.view(-1, self.reg_max),
+            gt_right.view(-1),
+            reduction='none').view(gt_left.shape) * weight_right
+
+        loss_dfl = (loss_left + loss_right).mean(-1)
+        
+        if bbox_weight is not None:
+            loss_dfl *= bbox_weight
 
+        return loss_dfl
 
-    def loss_bboxes_aux(self, pred_reg, gt_box, anchors, stride_tensors):
-        # xyxy -> cxcy&bwbh
-        gt_cxcy = (gt_box[..., :2] + gt_box[..., 2:]) * 0.5
-        gt_bwbh = gt_box[..., 2:] - gt_box[..., :2]
-        # encode gt box
-        gt_cxcy_encode = (gt_cxcy - anchors) / stride_tensors
-        gt_bwbh_encode = torch.log(gt_bwbh / stride_tensors)
-        gt_box_encode = torch.cat([gt_cxcy_encode, gt_bwbh_encode], dim=-1)
-        # l1 loss
-        loss_box_aux = F.l1_loss(pred_reg, gt_box_encode, reduction='none')
-
-        return loss_box_aux
-
-
-    def __call__(self, outputs, targets, epoch=0):        
+    def __call__(self, outputs, targets):        
         """
-            outputs['pred_obj']: List(Tensor) [B, M, 1]
             outputs['pred_cls']: List(Tensor) [B, M, C]
+            outputs['pred_reg']: List(Tensor) [B, M, 4*reg_max]
             outputs['pred_box']: List(Tensor) [B, M, 4]
-            outputs['pred_box']: List(Tensor) [B, M, 4]
+            outputs['anchors']: List(Tensor) [M, 2]
             outputs['strides']: List(Int) [8, 16, 32] output stride
+            outputs['stride_tensor']: List(Tensor) [M, 1]
             targets: (List) [dict{'boxes': [...], 
                                  'labels': [...], 
                                  'orig_size': ...}, ...]
         """
-        bs = outputs['pred_cls'][0].shape[0]
-        device = outputs['pred_cls'][0].device
-        fpn_strides = outputs['strides']
-        anchors = outputs['anchors']
         # preds: [B, M, C]
-        obj_preds = torch.cat(outputs['pred_obj'], dim=1)
         cls_preds = torch.cat(outputs['pred_cls'], dim=1)
+        reg_preds = torch.cat(outputs['pred_reg'], dim=1)
         box_preds = torch.cat(outputs['pred_box'], dim=1)
-
-        # label assignment
-        cls_targets = []
-        box_targets = []
-        obj_targets = []
+        bs, num_anchors = cls_preds.shape[:2]
+        device = cls_preds.device
+        anchors = torch.cat(outputs['anchors'], dim=0)
+        
+        # --------------- label assignment ---------------
+        gt_score_targets = []
+        gt_bbox_targets = []
         fg_masks = []
-
         for batch_idx in range(bs):
-            tgt_labels = targets[batch_idx]["labels"].to(device)
-            tgt_bboxes = targets[batch_idx]["boxes"].to(device)
+            tgt_labels = targets[batch_idx]["labels"].to(device)     # [Mp,]
+            tgt_boxs = targets[batch_idx]["boxes"].to(device)        # [Mp, 4]
 
             # check target
-            if len(tgt_labels) == 0 or tgt_bboxes.max().item() == 0.:
-                num_anchors = sum([ab.shape[0] for ab in anchors])
+            if len(tgt_labels) == 0 or tgt_boxs.max().item() == 0.:
                 # There is no valid gt
-                cls_target = obj_preds.new_zeros((0, self.num_classes))
-                box_target = obj_preds.new_zeros((0, 4))
-                obj_target = obj_preds.new_zeros((num_anchors, 1))
-                fg_mask = obj_preds.new_zeros(num_anchors).bool()
+                fg_mask  = cls_preds.new_zeros(1, num_anchors).bool()               #[1, M,]
+                gt_score = cls_preds.new_zeros((1, num_anchors, self.num_classes)) #[1, M, C]
+                gt_box   = cls_preds.new_zeros((1, num_anchors, 4))                  #[1, M, 4]
             else:
+                tgt_labels = tgt_labels[None, :, None]      # [1, Mp, 1]
+                tgt_boxs = tgt_boxs[None]                   # [1, Mp, 4]
                 (
-                    fg_mask,
-                    assigned_labels,
-                    assigned_ious,
-                    assigned_indexs
+                    _,
+                    gt_box,     # [1, M, 4]
+                    gt_score,   # [1, M, C]
+                    fg_mask,    # [1, M,]
+                    _
                 ) = self.matcher(
-                    fpn_strides = fpn_strides,
-                    anchors = anchors,
-                    pred_obj = obj_preds[batch_idx],
-                    pred_cls = cls_preds[batch_idx], 
-                    pred_box = box_preds[batch_idx],
-                    tgt_labels = tgt_labels,
-                    tgt_bboxes = tgt_bboxes
+                    pd_scores = cls_preds[batch_idx:batch_idx+1].detach().sigmoid(), 
+                    pd_bboxes = box_preds[batch_idx:batch_idx+1].detach(),
+                    anc_points = anchors,
+                    gt_labels = tgt_labels,
+                    gt_bboxes = tgt_boxs
                     )
-
-                obj_target = fg_mask.unsqueeze(-1)
-                cls_target = F.one_hot(assigned_labels.long(), self.num_classes)
-                cls_target = cls_target * assigned_ious.unsqueeze(-1)
-                box_target = tgt_bboxes[assigned_indexs]
-
-            cls_targets.append(cls_target)
-            box_targets.append(box_target)
-            obj_targets.append(obj_target)
+            gt_score_targets.append(gt_score)
+            gt_bbox_targets.append(gt_box)
             fg_masks.append(fg_mask)
 
-        cls_targets = torch.cat(cls_targets, 0)
-        box_targets = torch.cat(box_targets, 0)
-        obj_targets = torch.cat(obj_targets, 0)
-        fg_masks = torch.cat(fg_masks, 0)
-        num_fgs = fg_masks.sum()
-
+        # List[B, 1, M, C] -> Tensor[B, M, C] -> Tensor[BM, C]
+        fg_masks = torch.cat(fg_masks, 0).view(-1)                                    # [BM,]
+        gt_score_targets = torch.cat(gt_score_targets, 0).view(-1, self.num_classes)  # [BM, C]
+        gt_bbox_targets = torch.cat(gt_bbox_targets, 0).view(-1, 4)                   # [BM, 4]
+        num_fgs = gt_score_targets.sum()
+        
+        # Average loss normalizer across all the GPUs
         if is_dist_avail_and_initialized():
             torch.distributed.all_reduce(num_fgs)
         num_fgs = (num_fgs / get_world_size()).clamp(1.0)
 
-        # ------------------ Objecntness loss ------------------
-        loss_obj = self.loss_objectness(obj_preds.view(-1, 1), obj_targets.float())
-        loss_obj = loss_obj.sum() / num_fgs
-        
         # ------------------ Classification loss ------------------
-        cls_preds_pos = cls_preds.view(-1, self.num_classes)[fg_masks]
-        loss_cls = self.loss_classes(cls_preds_pos, cls_targets)
+        cls_preds = cls_preds.view(-1, self.num_classes)
+        loss_cls = self.loss_classes(cls_preds, gt_score_targets)
         loss_cls = loss_cls.sum() / num_fgs
 
         # ------------------ Regression loss ------------------
         box_preds_pos = box_preds.view(-1, 4)[fg_masks]
-        loss_box = self.loss_bboxes(box_preds_pos, box_targets)
+        box_targets_pos = gt_bbox_targets.view(-1, 4)[fg_masks]
+        bbox_weight = gt_score_targets[fg_masks].sum(-1)
+        loss_box = self.loss_bboxes(box_preds_pos, box_targets_pos, bbox_weight)
         loss_box = loss_box.sum() / num_fgs
 
+        # ------------------ Distribution focal loss  ------------------
+        ## process anchors
+        anchors = anchors[None].repeat(bs, 1, 1).view(-1, 2)
+        ## process stride tensors
+        strides = torch.cat(outputs['stride_tensor'], dim=0)
+        strides = strides.unsqueeze(0).repeat(bs, 1, 1).view(-1, 1)
+        ## fg preds
+        reg_preds_pos = reg_preds.view(-1, 4*self.reg_max)[fg_masks]
+        anchors_pos = anchors[fg_masks]
+        strides_pos = strides[fg_masks]
+        ## compute dfl
+        loss_dfl = self.loss_dfl(reg_preds_pos, box_targets_pos, anchors_pos, strides_pos, bbox_weight)
+        loss_dfl = loss_dfl.sum() / num_fgs
+
         # total loss
-        losses = self.loss_obj_weight * loss_obj + \
-                 self.loss_cls_weight * loss_cls + \
-                 self.loss_box_weight * loss_box
-
-        # ------------------ Aux regression loss ------------------
-        loss_box_aux = None
-        if epoch >= (self.max_epoch - self.no_aug_epoch - 1):
-            ## reg_preds
-            reg_preds = torch.cat(outputs['pred_reg'], dim=1)
-            reg_preds_pos = reg_preds.view(-1, 4)[fg_masks]
-            ## anchor tensors
-            anchors_tensors = torch.cat(outputs['anchors'], dim=0)[None].repeat(bs, 1, 1)
-            anchors_tensors_pos = anchors_tensors.view(-1, 2)[fg_masks]
-            ## stride tensors
-            stride_tensors = torch.cat(outputs['stride_tensors'], dim=0)[None].repeat(bs, 1, 1)
-            stride_tensors_pos = stride_tensors.view(-1, 1)[fg_masks]
-            ## aux loss
-            loss_box_aux = self.loss_bboxes_aux(reg_preds_pos, box_targets, anchors_tensors_pos, stride_tensors_pos)
-            loss_box_aux = loss_box_aux.sum() / num_fgs
-
-            losses += loss_box_aux
-
-        # Loss dict
-        if loss_box_aux is None:
-            loss_dict = dict(
-                    loss_obj = loss_obj,
-                    loss_cls = loss_cls,
-                    loss_box = loss_box,
-                    losses = losses
-            )
-        else:
-            loss_dict = dict(
-                    loss_obj = loss_obj,
-                    loss_cls = loss_cls,
-                    loss_box = loss_box,
-                    loss_box_aux = loss_box_aux,
-                    losses = losses
-                    )
+        losses = loss_cls * self.loss_cls_weight + \
+                 loss_box * self.loss_box_weight + \
+                 loss_dfl * self.loss_dfl_weight
+        loss_dict = dict(
+                loss_cls = loss_cls,
+                loss_box = loss_box,
+                loss_dfl = loss_dfl,
+                losses = losses
+        )
 
         return loss_dict
     
 
-def build_criterion(args, cfg, device, num_classes):
-    criterion = Criterion(
-        args=args,
-        cfg=cfg,
-        device=device,
-        num_classes=num_classes
-        )
-
-    return criterion
-
-
 if __name__ == "__main__":
     pass

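The new `loss_dfl` spreads each ground-truth ltrb distance over the two adjacent integer bins and applies a linearly weighted cross-entropy on both. The self-contained sketch below reproduces that computation; it re-implements the assumed behaviour of `utils.box_ops.bbox2dist` locally (clamped ltrb distances from the anchor point to the box) so the snippet runs without the repository.

```python
import torch
import torch.nn.functional as F

def bbox2dist_local(anchor, box, reg_max):
    # anchor: [M, 2] (cx, cy); box: [M, 4] (x1, y1, x2, y2) -> ltrb distances in [0, reg_max]
    lt = anchor - box[..., :2]
    rb = box[..., 2:] - anchor
    return torch.cat([lt, rb], dim=-1).clamp_(0, reg_max - 0.01)

reg_max  = 16
pred_reg = torch.randn(5, 4 * reg_max)                 # per-anchor logits: 4 sides x 16 bins
anchor   = torch.tensor([[8., 8.]]).repeat(5, 1)       # anchor points (already divided by stride)
gt_box   = torch.tensor([[2., 3., 14., 13.]]).repeat(5, 1)

gt_ltrb  = bbox2dist_local(anchor, gt_box, reg_max - 1)
gt_left  = gt_ltrb.long()                              # lower bin index
gt_right = gt_left + 1                                 # upper bin index
w_left   = gt_right.float() - gt_ltrb                  # linear interpolation weights
w_right  = 1.0 - w_left

loss_left  = F.cross_entropy(pred_reg.view(-1, reg_max), gt_left.view(-1),
                             reduction='none').view(gt_left.shape) * w_left
loss_right = F.cross_entropy(pred_reg.view(-1, reg_max), gt_right.view(-1),
                             reduction='none').view(gt_left.shape) * w_right
loss_dfl   = (loss_left + loss_right).mean(-1)         # one value per anchor
print(loss_dfl.shape)                                  # torch.Size([5])
```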
+ 192 - 177
yolo/models/yolov10/matcher.py

@@ -1,187 +1,202 @@
-# ---------------------------------------------------------------------
-# Copyright (c) Megvii Inc. All rights reserved.
-# ---------------------------------------------------------------------
-
-
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
-from utils.box_ops import *
+from utils.box_ops import bbox_iou
 
 
-class SimOTA(object):
+# -------------------------- Task Aligned Assigner --------------------------
+class TaskAlignedAssigner(nn.Module):
     """
-        This code referenced to https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/models/yolo_head.py
+        This code is adapted from https://github.com/ultralytics/ultralytics
     """
-    def __init__(self, num_classes, center_sampling_radius, topk_candidate ):
+    def __init__(self,
+                 num_classes     = 80,
+                 topk_candidates = 10,
+                 alpha           = 0.5,
+                 beta            = 6.0, 
+                 eps             = 1e-9):
+        super(TaskAlignedAssigner, self).__init__()
+        self.topk_candidates = topk_candidates
         self.num_classes = num_classes
-        self.center_sampling_radius = center_sampling_radius
-        self.topk_candidate = topk_candidate
-
+        self.bg_idx = num_classes
+        self.alpha = alpha
+        self.beta = beta
+        self.eps = eps
 
     @torch.no_grad()
-    def __call__(self, 
-                 fpn_strides, 
-                 anchors, 
-                 pred_obj, 
-                 pred_cls, 
-                 pred_box, 
-                 tgt_labels,
-                 tgt_bboxes):
-        # [M,]
-        strides_tensor = torch.cat([torch.ones_like(anchor_i[:, 0]) * stride_i
-                                for stride_i, anchor_i in zip(fpn_strides, anchors)], dim=-1)
-        # List[F, M, 2] -> [M, 2]
-        anchors = torch.cat(anchors, dim=0)
-        num_anchor = anchors.shape[0]        
-        num_gt = len(tgt_labels)
-
-        # ----------------------- Find inside points -----------------------
-        fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(
-            tgt_bboxes, anchors, strides_tensor, num_anchor, num_gt)
-        obj_preds = pred_obj[fg_mask].float()   # [Mp, 1]
-        cls_preds = pred_cls[fg_mask].float()   # [Mp, C]
-        box_preds = pred_box[fg_mask].float()   # [Mp, 4]
-
-        # ----------------------- Reg cost -----------------------
-        pair_wise_ious, _ = box_iou(tgt_bboxes, box_preds)      # [N, Mp]
-        reg_cost = -torch.log(pair_wise_ious + 1e-8)            # [N, Mp]
-
-        # ----------------------- Cls cost -----------------------
-        with torch.cuda.amp.autocast(enabled=False):
-            # [Mp, C]
-            score_preds = torch.sqrt(obj_preds.sigmoid_()* cls_preds.sigmoid_())
-            # [N, Mp, C]
-            score_preds = score_preds.unsqueeze(0).repeat(num_gt, 1, 1)
-            # prepare cls_target
-            cls_targets = F.one_hot(tgt_labels.long(), self.num_classes).float()
-            cls_targets = cls_targets.unsqueeze(1).repeat(1, score_preds.size(1), 1)
-            # [N, Mp]
-            cls_cost = F.binary_cross_entropy(score_preds, cls_targets, reduction="none").sum(-1)
-        del score_preds
-
-        #----------------------- Dynamic K-Matching -----------------------
-        cost_matrix = (
-            cls_cost
-            + 3.0 * reg_cost
-            + 100000.0 * (~is_in_boxes_and_center)
-        ) # [N, Mp]
-
-        (
-            assigned_labels,         # [num_fg,]
-            assigned_ious,           # [num_fg,]
-            assigned_indexs,         # [num_fg,]
-        ) = self.dynamic_k_matching(
-            cost_matrix,
-            pair_wise_ious,
-            tgt_labels,
-            num_gt,
-            fg_mask
-            )
-        del cls_cost, cost_matrix, pair_wise_ious, reg_cost
-
-        return fg_mask, assigned_labels, assigned_ious, assigned_indexs
-
-
-    def get_in_boxes_info(
-        self,
-        gt_bboxes,   # [N, 4]
-        anchors,     # [M, 2]
-        strides,     # [M,]
-        num_anchors, # M
-        num_gt,      # N
-        ):
-        # anchor center
-        x_centers = anchors[:, 0]
-        y_centers = anchors[:, 1]
-
-        # [M,] -> [1, M] -> [N, M]
-        x_centers = x_centers.unsqueeze(0).repeat(num_gt, 1)
-        y_centers = y_centers.unsqueeze(0).repeat(num_gt, 1)
-
-        # [N,] -> [N, 1] -> [N, M]
-        gt_bboxes_l = gt_bboxes[:, 0].unsqueeze(1).repeat(1, num_anchors) # x1
-        gt_bboxes_t = gt_bboxes[:, 1].unsqueeze(1).repeat(1, num_anchors) # y1
-        gt_bboxes_r = gt_bboxes[:, 2].unsqueeze(1).repeat(1, num_anchors) # x2
-        gt_bboxes_b = gt_bboxes[:, 3].unsqueeze(1).repeat(1, num_anchors) # y2
-
-        b_l = x_centers - gt_bboxes_l
-        b_r = gt_bboxes_r - x_centers
-        b_t = y_centers - gt_bboxes_t
-        b_b = gt_bboxes_b - y_centers
-        bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)
-
-        is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
-        is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
-        # in fixed center
-        center_radius = self.center_sampling_radius
-
-        # [N, 2]
-        gt_centers = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) * 0.5
-        
-        # [1, M]
-        center_radius_ = center_radius * strides.unsqueeze(0)
-
-        gt_bboxes_l = gt_centers[:, 0].unsqueeze(1).repeat(1, num_anchors) - center_radius_ # x1
-        gt_bboxes_t = gt_centers[:, 1].unsqueeze(1).repeat(1, num_anchors) - center_radius_ # y1
-        gt_bboxes_r = gt_centers[:, 0].unsqueeze(1).repeat(1, num_anchors) + center_radius_ # x2
-        gt_bboxes_b = gt_centers[:, 1].unsqueeze(1).repeat(1, num_anchors) + center_radius_ # y2
-
-        c_l = x_centers - gt_bboxes_l
-        c_r = gt_bboxes_r - x_centers
-        c_t = y_centers - gt_bboxes_t
-        c_b = gt_bboxes_b - y_centers
-        center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
-        is_in_centers = center_deltas.min(dim=-1).values > 0.0
-        is_in_centers_all = is_in_centers.sum(dim=0) > 0
-
-        # in boxes and in centers
-        is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
-
-        is_in_boxes_and_center = (
-            is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
-        )
-        return is_in_boxes_anchor, is_in_boxes_and_center
-    
+    def forward(self,
+                pd_scores,
+                pd_bboxes,
+                anc_points,
+                gt_labels,
+                gt_bboxes):
+        self.bs = pd_scores.size(0)
+        self.n_max_boxes = gt_bboxes.size(1)
+
+        mask_pos, align_metric, overlaps = self.get_pos_mask(
+            pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points)
+
+        target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(
+            mask_pos, overlaps, self.n_max_boxes)
+
+        # Assigned target
+        target_labels, target_bboxes, target_scores = self.get_targets(
+            gt_labels, gt_bboxes, target_gt_idx, fg_mask)
+
+        # normalize
+        align_metric *= mask_pos
+        pos_align_metrics = align_metric.amax(axis=-1, keepdim=True)  # b, max_num_obj
+        pos_overlaps = (overlaps * mask_pos).amax(axis=-1, keepdim=True)  # b, max_num_obj
+        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
+        target_scores = target_scores * norm_align_metric
+
+        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
+
+    def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points):
+        # get in_gts mask, (b, max_num_obj, h*w)
+        mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)
+        # get anchor_align metric, (b, max_num_obj, h*w)
+        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts)
+        # get topk_metric mask, (b, max_num_obj, h*w)
+        mask_topk = self.select_topk_candidates(align_metric)
+        # merge all mask to a final mask, (b, max_num_obj, h*w)
+        mask_pos = mask_topk * mask_in_gts
+
+        return mask_pos, align_metric, overlaps
+
+    def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts):
+        """Compute alignment metric given predicted and ground truth bounding boxes."""
+        na = pd_bboxes.shape[-2]
+        mask_in_gts = mask_in_gts.bool()  # b, max_num_obj, h*w
+        overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
+        bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)
+
+        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
+        ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes)  # b, max_num_obj
+        ind[1] = gt_labels.squeeze(-1)  # b, max_num_obj
+        # Get the scores of each grid for each gt cls
+        bbox_scores[mask_in_gts] = pd_scores[ind[0], :, ind[1]][mask_in_gts]  # b, max_num_obj, h*w
+
+        # (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
+        pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_in_gts]
+        gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_in_gts]
+        overlaps[mask_in_gts] = bbox_iou(gt_boxes, pd_boxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
+
+        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
+        return align_metric, overlaps
+
+    def select_topk_candidates(self, metrics, largest=True):
+        """
+        Args:
+            metrics: (b, max_num_obj, h*w).
+            largest: (bool) if True, select the top-k largest metrics.
+        """
+        # (b, max_num_obj, topk)
+        topk_metrics, topk_idxs = torch.topk(metrics, self.topk_candidates, dim=-1, largest=largest)
+        topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
+        # (b, max_num_obj, topk)
+        topk_idxs.masked_fill_(~topk_mask, 0)
+
+        # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
+        count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
+        ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
+        for k in range(self.topk_candidates):
+            # Expand topk_idxs for each value of k and add 1 at the specified positions
+            count_tensor.scatter_add_(-1, topk_idxs[:, :, k:k + 1], ones)
+        # count_tensor.scatter_add_(-1, topk_idxs, torch.ones_like(topk_idxs, dtype=torch.int8, device=topk_idxs.device))
+        # Filter invalid bboxes
+        count_tensor.masked_fill_(count_tensor > 1, 0)
+
+        return count_tensor.to(metrics.dtype)
+
+    def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
+        # Assigned target labels, (b, 1)
+        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
+        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (b, h*w)
+        target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)
+
+        # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4)
+        target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx]
+
+        # Assigned target scores
+        target_labels.clamp_(0)
+
+        # 10x faster than F.one_hot()
+        target_scores = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.num_classes),
+                                    dtype=torch.int64,
+                                    device=target_labels.device)  # (b, h*w, 80)
+        target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)
+
+        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)  # (b, h*w, 80)
+        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)
+
+        return target_labels, target_bboxes, target_scores
     
-    def dynamic_k_matching(
-        self, 
-        cost, 
-        pair_wise_ious, 
-        gt_classes, 
-        num_gt, 
-        fg_mask
-        ):
-        # Dynamic K
-        # ---------------------------------------------------------------
-        matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
-
-        ious_in_boxes_matrix = pair_wise_ious
-        n_candidate_k = min(self.topk_candidate, ious_in_boxes_matrix.size(1))
-        topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)
-        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
-        dynamic_ks = dynamic_ks.tolist()
-        for gt_idx in range(num_gt):
-            _, pos_idx = torch.topk(
-                cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
-            )
-            matching_matrix[gt_idx][pos_idx] = 1
-
-        del topk_ious, dynamic_ks, pos_idx
-
-        anchor_matching_gt = matching_matrix.sum(0)
-        if (anchor_matching_gt > 1).sum() > 0:
-            _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
-            matching_matrix[:, anchor_matching_gt > 1] *= 0
-            matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
-        fg_mask_inboxes = matching_matrix.sum(0) > 0
-
-        fg_mask[fg_mask.clone()] = fg_mask_inboxes
-
-        assigned_indexs = matching_matrix[:, fg_mask_inboxes].argmax(0)
-        assigned_labels = gt_classes[assigned_indexs]
-
-        assigned_ious = (matching_matrix * pair_wise_ious).sum(0)[
-            fg_mask_inboxes
-        ]
-        return assigned_labels, assigned_ious, assigned_indexs
-    
+
+# -------------------------- Basic Functions --------------------------
+def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
+    """select the positive anchors's center in gt
+    Args:
+        xy_centers (Tensor): shape(bs*n_max_boxes, num_total_anchors, 4)
+        gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
+    Return:
+        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    """
+    n_anchors = xy_centers.size(0)
+    bs, n_max_boxes, _ = gt_bboxes.size()
+    _gt_bboxes = gt_bboxes.reshape([-1, 4])
+    xy_centers = xy_centers.unsqueeze(0).repeat(bs * n_max_boxes, 1, 1)
+    gt_bboxes_lt = _gt_bboxes[:, 0:2].unsqueeze(1).repeat(1, n_anchors, 1)
+    gt_bboxes_rb = _gt_bboxes[:, 2:4].unsqueeze(1).repeat(1, n_anchors, 1)
+    b_lt = xy_centers - gt_bboxes_lt
+    b_rb = gt_bboxes_rb - xy_centers
+    bbox_deltas = torch.cat([b_lt, b_rb], dim=-1)
+    bbox_deltas = bbox_deltas.reshape([bs, n_max_boxes, n_anchors, -1])
+    return (bbox_deltas.min(axis=-1)[0] > eps).to(gt_bboxes.dtype)
+
+def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
+    """if an anchor box is assigned to multiple gts,
+        the one with the highest iou will be selected.
+    Args:
+        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+        overlaps (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    Return:
+        target_gt_idx (Tensor): shape(bs, num_total_anchors)
+        fg_mask (Tensor): shape(bs, num_total_anchors)
+        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    """
+    fg_mask = mask_pos.sum(-2)
+    if fg_mask.max() > 1:  # one anchor is assigned to multiple gt_bboxes
+        mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)
+        max_overlaps_idx = overlaps.argmax(1)  # (b, h*w)
+
+        is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
+        is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
+
+        mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()  # (b, n_max_boxes, h*w)
+        fg_mask = mask_pos.sum(-2)
+    # Find which gt each grid cell is assigned to (index)
+    target_gt_idx = mask_pos.argmax(-2)  # (b, h*w)
+
+    return target_gt_idx, fg_mask, mask_pos
+
+def iou_calculator(box1, box2, eps=1e-9):
+    """Calculate iou for batch
+    Args:
+        box1 (Tensor): shape(bs, n_max_boxes, 4)
+        box2 (Tensor): shape(bs, num_total_anchors, 4)
+    Return:
+        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    """
+    box1 = box1.unsqueeze(2)  # [N, M1, 4] -> [N, M1, 1, 4]
+    box2 = box2.unsqueeze(1)  # [N, M2, 4] -> [N, 1, M2, 4]
+    px1y1, px2y2 = box1[:, :, :, 0:2], box1[:, :, :, 2:4]
+    gx1y1, gx2y2 = box2[:, :, :, 0:2], box2[:, :, :, 2:4]
+    x1y1 = torch.maximum(px1y1, gx1y1)
+    x2y2 = torch.minimum(px2y2, gx2y2)
+    overlap = (x2y2 - x1y1).clip(0).prod(-1)
+    area1 = (px2y2 - px1y1).clip(0).prod(-1)
+    area2 = (gx2y2 - gx1y1).clip(0).prod(-1)
+    union = area1 + area2 - overlap + eps
+
+    return overlap / union

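A usage sketch for the new `TaskAlignedAssigner` with dummy tensors, assuming the repository root is on `PYTHONPATH`. Shapes follow the call site in `SetCriterion`: a batch of 1 image, `M` anchors, `Mp` ground-truth boxes; scores are post-sigmoid class probabilities and boxes are xyxy in pixels.

```python
import torch

from yolo.models.yolov10.matcher import TaskAlignedAssigner

num_classes, M, Mp = 80, 8400, 3
assigner = TaskAlignedAssigner(num_classes=num_classes, topk_candidates=10,
                               alpha=0.5, beta=6.0)

pd_scores  = torch.rand(1, M, num_classes)              # sigmoid(cls logits)
xy         = torch.rand(1, M, 2) * 600                  # synthetic, well-formed xyxy boxes
wh         = torch.rand(1, M, 2) * 40 + 1
pd_bboxes  = torch.cat([xy, xy + wh], dim=-1)
anc_points = torch.rand(M, 2) * 640                     # anchor centers
gt_labels  = torch.randint(0, num_classes, (1, Mp, 1))
gt_bboxes  = torch.tensor([[[ 10.,  10., 200., 200.],
                            [ 50.,  60., 300., 400.],
                            [100.,   5., 600., 500.]]])

labels, boxes, scores, fg_mask, gt_idx = assigner(
    pd_scores=pd_scores, pd_bboxes=pd_bboxes, anc_points=anc_points,
    gt_labels=gt_labels, gt_bboxes=gt_bboxes)

print(scores.shape, fg_mask.sum().item())               # torch.Size([1, 8400, 80]), #positives
```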
+ 189 - 301
yolo/models/yolov10/modules.py

@@ -1,338 +1,226 @@
-import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
+from typing import List
 
 
-# ---------------------------- 2D CNN ----------------------------
-class SiLU(nn.Module):
-    """export-friendly version of nn.SiLU()"""
-
-    @staticmethod
-    def forward(x):
-        return x * torch.sigmoid(x)
-
-def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
-    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
-
-    return conv
-
-def get_activation(act_type=None):
-    if act_type == 'relu':
-        return nn.ReLU(inplace=True)
-    elif act_type == 'lrelu':
-        return nn.LeakyReLU(0.1, inplace=True)
-    elif act_type == 'mish':
-        return nn.Mish(inplace=True)
-    elif act_type == 'silu':
-        return nn.SiLU(inplace=True)
-
-def get_norm(norm_type, dim):
-    if norm_type == 'BN':
-        return nn.BatchNorm2d(dim)
-    elif norm_type == 'GN':
-        return nn.GroupNorm(num_groups=32, num_channels=dim)
-
-## Basic conv layer
-class Conv(nn.Module):
+# --------------------- Basic modules ---------------------
+class ConvModule(nn.Module):
     def __init__(self, 
-                 c1,                   # in channels
-                 c2,                   # out channels 
-                 k=1,                  # kernel size 
-                 p=0,                  # padding
-                 s=1,                  # padding
-                 d=1,                  # dilation
-                 act_type='lrelu',     # activation
-                 norm_type='BN',       # normalization
-                 depthwise=False):
-        super(Conv, self).__init__()
-        convs = []
-        add_bias = False if norm_type else True
-        if depthwise:
-            convs.append(get_conv2d(c1, c1, k=k, p=p, s=s, d=d, g=c1, bias=add_bias))
-            # depthwise conv
-            if norm_type:
-                convs.append(get_norm(norm_type, c1))
-            if act_type:
-                convs.append(get_activation(act_type))
-            # pointwise conv
-            convs.append(get_conv2d(c1, c2, k=1, p=0, s=1, d=d, g=1, bias=add_bias))
-            if norm_type:
-                convs.append(get_norm(norm_type, c2))
-            if act_type:
-                convs.append(get_activation(act_type))
+                 in_dim,
+                 out_dim,
+                 kernel_size=1,
+                 stride=1,
+                 groups=1,
+                 use_act=True,
+                ):
+        super(ConvModule, self).__init__()
+        self.conv = nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, stride=stride, padding=kernel_size//2, groups=groups, bias=False)
+        self.norm = nn.BatchNorm2d(out_dim)
+        self.act  = nn.SiLU(inplace=True) if use_act else nn.Identity()
 
-        else:
-            convs.append(get_conv2d(c1, c2, k=k, p=p, s=s, d=d, g=1, bias=add_bias))
-            if norm_type:
-                convs.append(get_norm(norm_type, c2))
-            if act_type:
-                convs.append(get_activation(act_type))
-            
-        self.convs = nn.Sequential(*convs)
+    def forward(self, x):
+        return self.act(self.norm(self.conv(x)))
+
+class YoloBottleneck(nn.Module):
+    def __init__(self,
+                 in_dim      :int,
+                 out_dim     :int,
+                 kernel_size :List  = [1, 3],
+                 expansion   :float = 0.5,
+                 shortcut    :bool  = False,
+                 ):
+        super(YoloBottleneck, self).__init__()
+        inter_dim = int(out_dim * expansion)
+        # ----------------- Network setting -----------------
+        self.conv_layer1 = ConvModule(in_dim, inter_dim, kernel_size=kernel_size[0], stride=1)
+        self.conv_layer2 = ConvModule(inter_dim, out_dim, kernel_size=kernel_size[1], stride=1)
+        self.shortcut = shortcut and in_dim == out_dim
 
+    def forward(self, x):
+        h = self.conv_layer2(self.conv_layer1(x))
+
+        return x + h if self.shortcut else h
+
+class CIBBlock(nn.Module):
+    def __init__(self,
+                 in_dim   :int,
+                 out_dim  :int,
+                 shortcut :bool  = False,
+                 ) -> None:
+        super(CIBBlock, self).__init__()
+        # ----------------- Network setting -----------------
+        self.cv1 = ConvModule(in_dim, in_dim, kernel_size=3, groups=in_dim)
+        self.cv2 = ConvModule(in_dim, in_dim * 2, kernel_size=1)
+        self.cv3 = ConvModule(in_dim * 2, in_dim * 2, kernel_size=3, groups=in_dim * 2)
+        self.cv4 = ConvModule(in_dim * 2, out_dim, kernel_size=1)
+        self.cv5 = ConvModule(out_dim, out_dim, kernel_size=3, groups=out_dim)
+        self.shortcut = shortcut and in_dim == out_dim
 
     def forward(self, x):
-        return self.convs(x)
+        h = self.cv5(self.cv4(self.cv3(self.cv2(self.cv1(x)))))
+
+        return x + h if self.shortcut else h
+
+
+# --------------------- Yolov10 modules ---------------------
+class C2fBlock(nn.Module):
+    def __init__(self,
+                 in_dim: int,
+                 out_dim: int,
+                 expansion : float = 0.5,
+                 num_blocks : int = 1,
+                 shortcut: bool = False,
+                 use_cib: bool = False,
+                 ):
+        super(C2fBlock, self).__init__()
+        inter_dim = round(out_dim * expansion)
+        self.input_proj  = ConvModule(in_dim, inter_dim * 2, kernel_size=1)
+        self.output_proj = ConvModule((2 + num_blocks) * inter_dim, out_dim, kernel_size=1)
+
+        if use_cib:
+            self.blocks = nn.ModuleList([
+                CIBBlock(in_dim = inter_dim,
+                         out_dim = inter_dim,
+                         shortcut = shortcut,
+                         ) for _ in range(num_blocks)])
+        else:
+            self.blocks = nn.ModuleList([
+                YoloBottleneck(in_dim = inter_dim,
+                               out_dim = inter_dim,
+                               kernel_size = [3, 3],
+                               expansion = 1.0,
+                               shortcut = shortcut,
+                               ) for _ in range(num_blocks)])
 
+    def forward(self, x):
+        # Input proj
+        x1, x2 = torch.chunk(self.input_proj(x), 2, dim=1)
+        out = [x1, x2]
 
-# ---------------------------- YOLOv7 Modules ----------------------------
-## ELAN-Block proposed by YOLOv7
-class ELANBlock(nn.Module):
-    def __init__(self, in_dim, out_dim, squeeze_ratio=0.5, branch_depth :int=2, act_type='silu', norm_type='BN', depthwise=False):
-        super(ELANBlock, self).__init__()
-        inter_dim = int(in_dim * squeeze_ratio)
-        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
-        self.cv2 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
-        self.cv3 = nn.Sequential(*[
-            Conv(inter_dim, inter_dim, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-            for _ in range(round(branch_depth))
-        ])
-        self.cv4 = nn.Sequential(*[
-            Conv(inter_dim, inter_dim, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-            for _ in range(round(branch_depth))
-        ])
+        # Bottleneck blocks
+        out.extend(m(out[-1]) for m in self.blocks)
 
-        self.out = Conv(inter_dim*4, out_dim, k=1, act_type=act_type, norm_type=norm_type)
+        # Output proj
+        out = self.output_proj(torch.cat(out, dim=1))
 
+        return out
 
+class SCDown(nn.Module):
+    def __init__(self, in_dim, out_dim, kernel_size: int = 3, stride: int = 2):
+        super().__init__()
+        self.cv1 = ConvModule(in_dim, out_dim, kernel_size=1)
+        self.cv2 = ConvModule(out_dim, out_dim, kernel_size=kernel_size, stride=stride, groups=out_dim, use_act=False)
 
     def forward(self, x):
-        x1 = self.cv1(x)
-        x2 = self.cv2(x)
-        x3 = self.cv3(x2)
-        x4 = self.cv4(x3)
-        out = self.out(torch.cat([x1, x2, x3, x4], dim=1))
+        return self.cv2(self.cv1(x))
 
-        return out
+class Attention(nn.Module):
+    def __init__(self, dim, num_heads=8, attn_ratio=0.5):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.key_dim = int(self.head_dim * attn_ratio)
+        self.scale = self.key_dim**-0.5
+        
+        nh_kd = self.key_dim * num_heads
+        h = dim + nh_kd * 2
+        self.qkv  = ConvModule(dim, h, kernel_size=1, use_act=False)
+        self.proj = ConvModule(dim, dim, kernel_size=1, use_act=False)
+        self.pe   = ConvModule(dim, dim, kernel_size=3, groups=dim, use_act=False)
 
-## PaFPN's ELAN-Block proposed by YOLOv7
-class ELANBlockFPN(nn.Module):
-    def __init__(self, in_dim, out_dim, squeeze_ratio=0.5, branch_width :int=4, branch_depth :int=1, act_type='silu', norm_type='BN', depthwise=False):
-        super(ELANBlockFPN, self).__init__()
-        # Basic parameters
-        inter_dim = int(in_dim * squeeze_ratio)
-        inter_dim2 = int(inter_dim * squeeze_ratio) 
-        # Network structure
-        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
-        self.cv2 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
-        self.cv3 = nn.ModuleList()
-        for idx in range(round(branch_width)):
-            if idx == 0:
-                cvs = [Conv(inter_dim, inter_dim2, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)]
-            else:
-                cvs = [Conv(inter_dim2, inter_dim2, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)]
-            # deeper
-            if round(branch_depth) > 1:
-                for _ in range(1, round(branch_depth)):
-                    cvs.append(Conv(inter_dim2, inter_dim2, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise))
-                self.cv3.append(nn.Sequential(*cvs))
-            else:
-                self.cv3.append(cvs[0])
-
-        self.out = Conv(inter_dim*2+inter_dim2*len(self.cv3), out_dim, k=1, act_type=act_type, norm_type=norm_type)
+    def forward(self, x):
+        bs, c, h, w = x.shape
+        seq_len = h * w
 
+        qkv = self.qkv(x)
+        q, k, v = qkv.view(bs, self.num_heads, self.key_dim * 2 + self.head_dim, seq_len).split(
+            [self.key_dim, self.key_dim, self.head_dim], dim=2
+        )
 
-    def forward(self, x):
-        x1 = self.cv1(x)
-        x2 = self.cv2(x)
-        inter_outs = [x1, x2]
-        for m in self.cv3:
-            y1 = inter_outs[-1]
-            y2 = m(y1)
-            inter_outs.append(y2)
-        out = self.out(torch.cat(inter_outs, dim=1))
+        attn = (q.transpose(-2, -1) @ k) * self.scale
+        attn = attn.softmax(dim=-1)
+        x = (v @ attn.transpose(-2, -1)).view(bs, c, h, w) + self.pe(v.reshape(bs, c, h, w))
+        x = self.proj(x)
 
-        return out
+        return x
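
For shape bookkeeping, the attention math above can be reproduced with plain tensors; the sketch below is illustrative only (the 1x1 qkv/proj convs and the depthwise positional-encoding conv are replaced by a random tensor, and dim=128 with a 20x20 map are assumed values):

import torch

bs, dim, h, w = 1, 128, 20, 20            # dim matches PSABlock's inter_dim at a small scale
num_heads  = dim // 64                    # 2 heads
head_dim   = dim // num_heads             # 64
key_dim    = int(head_dim * 0.5)          # attn_ratio = 0.5 -> 32
seq_len    = h * w

# the qkv conv outputs dim + 2 * num_heads * key_dim channels (128 + 128 = 256 here)
qkv = torch.randn(bs, dim + 2 * num_heads * key_dim, h, w)
q, k, v = qkv.view(bs, num_heads, key_dim * 2 + head_dim, seq_len).split(
    [key_dim, key_dim, head_dim], dim=2)

attn = (q.transpose(-2, -1) @ k) * key_dim ** -0.5        # [bs, heads, seq_len, seq_len]
attn = attn.softmax(dim=-1)
out  = (v @ attn.transpose(-2, -1)).view(bs, dim, h, w)   # back to a [bs, C, H, W] map
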
 
-## DownSample Block proposed by YOLOv7
-class DownSample(nn.Module):
-    def __init__(self, in_dim, out_dim, act_type='silu', norm_type='BN', depthwise=False):
+class PSABlock(nn.Module):
+    def __init__(self, in_dim, out_dim, expansion=0.5):
         super().__init__()
-        inter_dim = out_dim // 2
-        self.mp = nn.MaxPool2d((2, 2), 2)
-        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
-        self.cv2 = nn.Sequential(
-            Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type),
-            Conv(inter_dim, inter_dim, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        assert(in_dim == out_dim)
+        self.inter_dim = int(in_dim * expansion)
+        self.cv1 = ConvModule(in_dim, 2 * self.inter_dim, kernel_size=1)
+        self.cv2 = ConvModule(2 * self.inter_dim, in_dim, kernel_size=1)
+        
+        self.attn = Attention(self.inter_dim, attn_ratio=0.5, num_heads=self.inter_dim // 64)
+        self.ffn = nn.Sequential(
+            ConvModule(self.inter_dim, self.inter_dim * 2, kernel_size=1),
+            ConvModule(self.inter_dim * 2, self.inter_dim, kernel_size=1, use_act=False)
         )
-
+        
     def forward(self, x):
-        x1 = self.cv1(self.mp(x))
-        x2 = self.cv2(x)
-        out = torch.cat([x1, x2], dim=1)
-
-        return out
+        a, b = self.cv1(x).split((self.inter_dim, self.inter_dim), dim=1)
+        b = b + self.attn(b)
+        b = b + self.ffn(b)
+        return self.cv2(torch.cat((a, b), 1))
 
-
-# ---------------------------- RepConv Modules ----------------------------
-class RepConv(nn.Module):
+class SPPF(nn.Module):
     """
-        The code referenced to https://github.com/WongKinYiu/yolov7/models/common.py
+        This code is adapted from https://github.com/ultralytics/yolov5
     """
-    # Represented convolution
-    # https://arxiv.org/abs/2101.03697
-
-    def __init__(self, c1, c2, k=3, s=1, p=1, g=1, act_type='silu', deploy=False):
-        super(RepConv, self).__init__()
-        # -------------- Basic parameters --------------
-        self.deploy = deploy
-        self.groups = g
-        self.in_channels = c1
-        self.out_channels = c2
+    def __init__(self, in_dim, out_dim):
+        super().__init__()
+        ## ----------- Basic Parameters -----------
+        inter_dim = in_dim // 2
+        self.out_dim = out_dim
+        ## ----------- Network Parameters -----------
+        self.cv1 = ConvModule(in_dim, inter_dim, kernel_size=1, stride=1)
+        self.cv2 = ConvModule(inter_dim * 4, out_dim, kernel_size=1, stride=1)
+        self.m = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
+
+        # Initialize all layers
+        self.init_weights()
+
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                m.reset_parameters()
 
-        # -------------- Network parameters --------------
-        if deploy:
-            self.rbr_reparam = nn.Conv2d(c1, c2, k, s, p, groups=g, bias=True)
+    def forward(self, x):
+        x = self.cv1(x)
+        y1 = self.m(x)
+        y2 = self.m(y1)
 
-        else:
-            self.rbr_identity = (nn.BatchNorm2d(num_features=c1) if c2 == c1 and s == 1 else None)
+        return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
 
-            self.rbr_dense = nn.Sequential(
-                nn.Conv2d(c1, c2, k, s, p, groups=g, bias=False),
-                nn.BatchNorm2d(num_features=c2),
-            )
+class DflLayer(nn.Module):
+    def __init__(self, reg_max=16):
+        """Initialize a convolutional layer with a given number of input channels."""
+        super().__init__()
+        self.reg_max = reg_max
+        proj_init = torch.arange(reg_max, dtype=torch.float)
+        self.proj_weight = nn.Parameter(proj_init.view([1, reg_max, 1, 1]), requires_grad=False)
 
-            self.rbr_1x1 = nn.Sequential(
-                nn.Conv2d(c1, c2, kernel_size=1, stride=s, bias=False),
-                nn.BatchNorm2d(num_features=c2),
-            )
-        self.act = get_activation(act_type)
+    def forward(self, pred_reg, anchor, stride):
+        bs, hw = pred_reg.shape[:2]
+        # [bs, hw, 4*rm] -> [bs, 4*rm, hw] -> [bs, 4, rm, hw]
+        pred_reg = pred_reg.permute(0, 2, 1).reshape(bs, 4, -1, hw)
 
+        # [bs, 4, rm, hw] -> [bs, rm, 4, hw]
+        pred_reg = pred_reg.permute(0, 2, 1, 3).contiguous()
 
-    def forward(self, inputs):
-        if hasattr(self, "rbr_reparam"):
-            return self.act(self.rbr_reparam(inputs))
+        # [bs, rm, 4, hw] -> [bs, 1, 4, hw]
+        delta_pred = F.conv2d(F.softmax(pred_reg, dim=1), self.proj_weight)
 
-        if self.rbr_identity is None:
-            id_out = 0
-        else:
-            id_out = self.rbr_identity(inputs)
-
-        return self.act(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)
-    
-    def get_equivalent_kernel_bias(self):
-        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
-        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
-        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
-        return (
-            kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid,
-            bias3x3 + bias1x1 + biasid,
-        )
+        # [bs, 1, 4, hw] -> [bs, 4, hw] -> [bs, hw, 4]
+        delta_pred = delta_pred.view(bs, 4, hw).permute(0, 2, 1).contiguous()
+        delta_pred *= stride
 
-    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
-        if kernel1x1 is None:
-            return 0
-        else:
-            return nn.functional.pad(kernel1x1, [1, 1, 1, 1])
-
-    def _fuse_bn_tensor(self, branch):
-        if branch is None:
-            return 0, 0
-        if isinstance(branch, nn.Sequential):
-            kernel = branch[0].weight
-            running_mean = branch[1].running_mean
-            running_var = branch[1].running_var
-            gamma = branch[1].weight
-            beta = branch[1].bias
-            eps = branch[1].eps
-        else:
-            assert isinstance(branch, nn.BatchNorm2d)
-            if not hasattr(self, "id_tensor"):
-                input_dim = self.in_channels // self.groups
-                kernel_value = np.zeros(
-                    (self.in_channels, input_dim, 3, 3), dtype=np.float32
-                )
-                for i in range(self.in_channels):
-                    kernel_value[i, i % input_dim, 1, 1] = 1
-                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
-            kernel = self.id_tensor
-            running_mean = branch.running_mean
-            running_var = branch.running_var
-            gamma = branch.weight
-            beta = branch.bias
-            eps = branch.eps
-        std = (running_var + eps).sqrt()
-        t = (gamma / std).reshape(-1, 1, 1, 1)
-        return kernel * t, beta - running_mean * gamma / std
-
-    def repvgg_convert(self):
-        kernel, bias = self.get_equivalent_kernel_bias()
-        return (
-            kernel.detach().cpu().numpy(),
-            bias.detach().cpu().numpy(),
-        )
+        # Decode bbox: tlbr -> xyxy
+        x1y1_pred = anchor - delta_pred[..., :2]
+        x2y2_pred = anchor + delta_pred[..., 2:]
+        box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
 
-    def fuse_conv_bn(self, conv, bn):
-
-        std = (bn.running_var + bn.eps).sqrt()
-        bias = bn.bias - bn.running_mean * bn.weight / std
-
-        t = (bn.weight / std).reshape(-1, 1, 1, 1)
-        weights = conv.weight * t
-
-        bn = nn.Identity()
-        conv = nn.Conv2d(in_channels = conv.in_channels,
-                              out_channels = conv.out_channels,
-                              kernel_size = conv.kernel_size,
-                              stride=conv.stride,
-                              padding = conv.padding,
-                              dilation = conv.dilation,
-                              groups = conv.groups,
-                              bias = True,
-                              padding_mode = conv.padding_mode)
-
-        conv.weight = torch.nn.Parameter(weights)
-        conv.bias = torch.nn.Parameter(bias)
-        return conv
-
-    def fuse_repvgg_block(self):    
-        if self.deploy:
-            return
-                
-        self.rbr_dense = self.fuse_conv_bn(self.rbr_dense[0], self.rbr_dense[1])
-        
-        self.rbr_1x1 = self.fuse_conv_bn(self.rbr_1x1[0], self.rbr_1x1[1])
-        rbr_1x1_bias = self.rbr_1x1.bias
-        weight_1x1_expanded = torch.nn.functional.pad(self.rbr_1x1.weight, [1, 1, 1, 1])
-        
-        # Fuse self.rbr_identity
-        if (isinstance(self.rbr_identity, nn.BatchNorm2d) or isinstance(self.rbr_identity, nn.modules.batchnorm.SyncBatchNorm)):
-            identity_conv_1x1 = nn.Conv2d(
-                    in_channels=self.in_channels,
-                    out_channels=self.out_channels,
-                    kernel_size=1,
-                    stride=1,
-                    padding=0,
-                    groups=self.groups, 
-                    bias=False)
-            identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.to(self.rbr_1x1.weight.data.device)
-            identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.squeeze().squeeze()
-
-            identity_conv_1x1.weight.data.fill_(0.0)
-            identity_conv_1x1.weight.data.fill_diagonal_(1.0)
-            identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.unsqueeze(2).unsqueeze(3)
-
-            identity_conv_1x1 = self.fuse_conv_bn(identity_conv_1x1, self.rbr_identity)
-            bias_identity_expanded = identity_conv_1x1.bias
-            weight_identity_expanded = torch.nn.functional.pad(identity_conv_1x1.weight, [1, 1, 1, 1])            
-        else:
-            bias_identity_expanded = torch.nn.Parameter( torch.zeros_like(rbr_1x1_bias) )
-            weight_identity_expanded = torch.nn.Parameter( torch.zeros_like(weight_1x1_expanded) )            
-        
-        self.rbr_dense.weight = torch.nn.Parameter(self.rbr_dense.weight + weight_1x1_expanded + weight_identity_expanded)
-        self.rbr_dense.bias = torch.nn.Parameter(self.rbr_dense.bias + rbr_1x1_bias + bias_identity_expanded)
-                
-        self.rbr_reparam = self.rbr_dense
-        self.deploy = True
-
-        if self.rbr_identity is not None:
-            del self.rbr_identity
-            self.rbr_identity = None
-
-        if self.rbr_1x1 is not None:
-            del self.rbr_1x1
-            self.rbr_1x1 = None
-
-        if self.rbr_dense is not None:
-            del self.rbr_dense
-            self.rbr_dense = None
+        return box_pred
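
The decoding performed by DflLayer can be reproduced with plain tensors; the sketch below is illustrative (batch size, grid size, and stride are assumed values) and follows the shape comments above: softmax over the reg_max bins per box side, an expectation via the fixed arange projection, then anchor -/+ distances to obtain xyxy boxes.

import torch
import torch.nn.functional as F

bs, hw, reg_max, stride = 2, 6400, 16, 8       # e.g. an 80x80 stride-8 level
pred_reg = torch.randn(bs, hw, 4 * reg_max)    # raw logits from the regression head
anchor   = torch.rand(1, hw, 2) * 640          # anchor centers in image coordinates
proj     = torch.arange(reg_max, dtype=torch.float).view(1, reg_max, 1, 1)

# [bs, hw, 4*rm] -> [bs, 4, rm, hw] -> [bs, rm, 4, hw]
x = pred_reg.permute(0, 2, 1).reshape(bs, 4, reg_max, hw).permute(0, 2, 1, 3)
# expectation over the bins: [bs, 1, 4, hw] -> [bs, hw, 4], scaled by the stride
dist = F.conv2d(F.softmax(x, dim=1), proj).view(bs, 4, hw).permute(0, 2, 1) * stride

boxes = torch.cat([anchor - dist[..., :2], anchor + dist[..., 2:]], dim=-1)  # [bs, hw, 4] xyxy
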

+ 47 - 224
yolo/models/yolov10/yolov10.py

@@ -1,112 +1,52 @@
+# --------------- Torch components ---------------
 import torch
 import torch.nn as nn
 
-from utils.misc import multiclass_nms
-
-from .yolov10_backbone import build_backbone
-from .yolov10_neck import build_neck
-from .yolov10_pafpn import build_fpn
-from .yolov10_head import build_head
-
-
-# YOLOv7
-class YOLOv7(nn.Module):
-    def __init__(self,
-                 cfg,
-                 device,
-                 num_classes=20,
-                 conf_thresh=0.01,
-                 topk=100,
-                 nms_thresh=0.5,
-                 trainable=False,
-                 deploy = False,
-                 no_multi_labels = False,
-                 nms_class_agnostic = False):
-        super(YOLOv7, self).__init__()
-        # ------------------- Basic parameters -------------------
-        self.cfg = cfg                                 # model config file
-        self.device = device                           # cuda or cpu
-        self.num_classes = num_classes                 # number of classes
-        self.trainable = trainable                     # training flag
-        self.conf_thresh = conf_thresh                 # score threshold
-        self.nms_thresh = nms_thresh                   # NMS threshold
-        self.topk_candidates = topk                    # topk
-        self.stride = [8, 16, 32]                      # output strides of the network
-        self.num_levels = 3
-        self.deploy = deploy
-        self.no_multi_labels = no_multi_labels
-        self.nms_class_agnostic = nms_class_agnostic
-        # ------------------- Network Structure -------------------
-        ## Backbone
-        self.backbone, feats_dim = build_backbone(cfg, trainable&cfg['pretrained'])
-
-        ## Neck: SPP module
-        self.neck = build_neck(cfg, in_dim=feats_dim[-1], out_dim=feats_dim[-1]//2)
-        feats_dim[-1] = self.neck.out_dim
+# --------------- Model components ---------------
+from .yolov10_backbone import Yolov10Backbone
+from .yolov10_pafpn    import Yolov10PaFPN
+from .yolov10_head     import Yolov10DetHead
 
-        ## Neck: feature pyramid
-        self.fpn = build_fpn(cfg=cfg, in_dims=feats_dim, out_dim=round(256*cfg['channel_width']))
-        self.head_dim = self.fpn.out_dim
-
-        ## Detection head
-        self.non_shared_heads = nn.ModuleList(
-            [build_head(cfg, head_dim, head_dim, num_classes) 
-            for head_dim in self.head_dim
-            ])
-
-        ## Prediction layers
-        self.obj_preds = nn.ModuleList(
-                            [nn.Conv2d(head.reg_out_dim, 1, kernel_size=1) 
-                                for head in self.non_shared_heads
-                              ]) 
-        self.cls_preds = nn.ModuleList(
-                            [nn.Conv2d(head.cls_out_dim, self.num_classes, kernel_size=1) 
-                                for head in self.non_shared_heads
-                              ]) 
-        self.reg_preds = nn.ModuleList(
-                            [nn.Conv2d(head.reg_out_dim, 4, kernel_size=1) 
-                                for head in self.non_shared_heads
-                              ])                 
-
-
-    # ---------------------- Basic Functions ----------------------
-    ## generate anchor points
-    def generate_anchors(self, level, fmp_size):
-        """
-            fmp_size: (List) [H, W]
-        """
-        # generate grid cells
-        fmp_h, fmp_w = fmp_size
-        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
-        # [H, W, 2] -> [HW, 2]
-        anchor_xy = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
-        anchor_xy += 0.5  # add center offset
-        anchor_xy *= self.stride[level]
-        anchors = anchor_xy.to(self.device)
+from utils.misc import multiclass_nms
 
-        return anchors
-        
-    ## post-process
-    def post_process(self, obj_preds, cls_preds, box_preds):
+# YOLOv10
+class Yolov10(nn.Module):
+    def __init__(self, cfg, is_val = False) -> None:
+        super(Yolov10, self).__init__()
+        # ---------------------- Basic setting ----------------------
+        self.cfg = cfg
+        self.num_classes = cfg.num_classes
+        ## Post-process parameters
+        self.topk_candidates  = cfg.val_topk        if is_val else cfg.test_topk
+        self.conf_thresh      = cfg.val_conf_thresh if is_val else cfg.test_conf_thresh
+        self.nms_thresh       = cfg.val_nms_thresh  if is_val else cfg.test_nms_thresh
+        self.no_multi_labels  = False if is_val else True
+
+        self.backbone = Yolov10Backbone(cfg)
+        self.pafpn    = Yolov10PaFPN(cfg, self.backbone.feat_dims[-3:])
+        self.det_head = Yolov10DetHead(cfg, self.pafpn.out_dims)
+
+    def post_process(self, cls_preds, box_preds):
         """
+        We process predictions at each scale hierarchically
         Input:
-            cls_preds: List[np.array] -> [[M, C], ...]
-            box_preds: List[np.array] -> [[M, 4], ...]
-            obj_preds: List[np.array] -> [[M, 1], ...] or None
+            cls_preds: List[torch.Tensor] -> [[B, M, C], ...], B=1
+            box_preds: List[torch.Tensor] -> [[B, M, 4], ...], B=1
         Output:
             bboxes: np.array -> [N, 4]
             scores: np.array -> [N,]
             labels: np.array -> [N,]
         """
-        assert len(cls_preds) == self.num_levels
         all_scores = []
         all_labels = []
         all_bboxes = []
         
-        for obj_pred_i, cls_pred_i, box_pred_i in zip(obj_preds, cls_preds, box_preds):
+        for cls_pred_i, box_pred_i in zip(cls_preds, box_preds):
+            cls_pred_i = cls_pred_i[0]
+            box_pred_i = box_pred_i[0]
             if self.no_multi_labels:
                 # [M,]
-                scores, labels = torch.max(torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid()), dim=1)
+                scores, labels = torch.max(cls_pred_i.sigmoid(), dim=1)
 
                 # Keep top k top scoring indices only.
                 num_topk = min(self.topk_candidates, box_pred_i.size(0))
@@ -123,10 +63,9 @@ class YOLOv7(nn.Module):
 
                 labels = labels[topk_idxs]
                 bboxes = box_pred_i[topk_idxs]
-
             else:
                 # [M, C] -> [MC,]
-                scores_i = (torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid())).flatten()
+                scores_i = cls_pred_i.sigmoid().flatten()
 
                 # Keep top k top scoring indices only.
                 num_topk = min(self.topk_candidates, box_pred_i.size(0))
@@ -150,9 +89,9 @@ class YOLOv7(nn.Module):
             all_labels.append(labels)
             all_bboxes.append(bboxes)
 
-        scores = torch.cat(all_scores)
-        labels = torch.cat(all_labels)
-        bboxes = torch.cat(all_bboxes)
+        scores = torch.cat(all_scores, dim=0)
+        labels = torch.cat(all_labels, dim=0)
+        bboxes = torch.cat(all_bboxes, dim=0)
 
         # to cpu & numpy
         scores = scores.cpu().numpy()
@@ -161,142 +100,26 @@ class YOLOv7(nn.Module):
 
         # nms
         scores, labels, bboxes = multiclass_nms(
-            scores, labels, bboxes, self.nms_thresh, self.num_classes, self.nms_class_agnostic)
-
+            scores, labels, bboxes, self.nms_thresh, self.num_classes)
+        
         return bboxes, scores, labels
     
-
-    # ---------------------- Main Process for Inference ----------------------
-    @torch.no_grad()
-    def inference_single_image(self, x):
-        # Backbone
+    def forward(self, x):
         pyramid_feats = self.backbone(x)
+        pyramid_feats = self.pafpn(pyramid_feats)
+        outputs = self.det_head(pyramid_feats)
+        outputs['image_size'] = [x.shape[2], x.shape[3]]
 
-        # Neck
-        pyramid_feats[-1] = self.neck(pyramid_feats[-1])
-
-        # Feature pyramid
-        pyramid_feats = self.fpn(pyramid_feats)
-
-        # Detection head
-        all_obj_preds = []
-        all_cls_preds = []
-        all_box_preds = []
-        all_anchors = []
-        for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
-            cls_feat, reg_feat = head(feat)
+        if not self.training:
+            all_cls_preds = outputs['pred_cls']
+            all_box_preds = outputs['pred_box']
 
-            # [1, C, H, W]
-            obj_pred = self.obj_preds[level](reg_feat)
-            cls_pred = self.cls_preds[level](cls_feat)
-            reg_pred = self.reg_preds[level](reg_feat)
-
-            # anchors: [M, 2]
-            fmp_size = cls_pred.shape[-2:]
-            anchors = self.generate_anchors(level, fmp_size)
-
-            # [1, C, H, W] -> [H, W, C] -> [M, C]
-            obj_pred = obj_pred[0].permute(1, 2, 0).contiguous().view(-1, 1)
-            cls_pred = cls_pred[0].permute(1, 2, 0).contiguous().view(-1, self.num_classes)
-            reg_pred = reg_pred[0].permute(1, 2, 0).contiguous().view(-1, 4)
-
-            # decode bbox
-            ctr_pred = reg_pred[..., :2] * self.stride[level] + anchors[..., :2]
-            wh_pred = torch.exp(reg_pred[..., 2:]) * self.stride[level]
-            pred_x1y1 = ctr_pred - wh_pred * 0.5
-            pred_x2y2 = ctr_pred + wh_pred * 0.5
-            box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
-
-            all_obj_preds.append(obj_pred)
-            all_cls_preds.append(cls_pred)
-            all_box_preds.append(box_pred)
-            all_anchors.append(anchors)
-
-        if self.deploy:
-            obj_preds = torch.cat(all_obj_preds, dim=0)
-            cls_preds = torch.cat(all_cls_preds, dim=0)
-            box_preds = torch.cat(all_box_preds, dim=0)
-            scores = torch.sqrt(obj_preds.sigmoid() * cls_preds.sigmoid())
-            bboxes = box_preds
-            # [n_anchors_all, 4 + C]
-            outputs = torch.cat([bboxes, scores], dim=-1)
-
-        else:
             # post process
-            bboxes, scores, labels = self.post_process(
-                all_obj_preds, all_cls_preds, all_box_preds)
+            bboxes, scores, labels = self.post_process(all_cls_preds, all_box_preds)
             outputs = {
                 "scores": scores,
                 "labels": labels,
                 "bboxes": bboxes
             }
-
-        return outputs
-
-    # ---------------------- Main Process for Training ----------------------
-    def forward(self, x):
-        if not self.trainable:
-            return self.inference_single_image(x)
-        else:
-            # Backbone
-            pyramid_feats = self.backbone(x)
-
-            # Neck
-            pyramid_feats[-1] = self.neck(pyramid_feats[-1])
-
-            # Feature pyramid
-            pyramid_feats = self.fpn(pyramid_feats)
-
-            # Detection head
-            all_anchors = []
-            all_strides = []
-            all_obj_preds = []
-            all_cls_preds = []
-            all_box_preds = []
-            all_reg_preds = []
-            for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
-                cls_feat, reg_feat = head(feat)
-
-                # [B, C, H, W]
-                obj_pred = self.obj_preds[level](reg_feat)
-                cls_pred = self.cls_preds[level](cls_feat)
-                reg_pred = self.reg_preds[level](reg_feat)
-
-                B, _, H, W = cls_pred.size()
-                fmp_size = [H, W]
-                # generate anchor boxes: [M, 4]
-                anchors = self.generate_anchors(level, fmp_size)
-                
-                # stride tensor: [M, 1]
-                stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride[level]
-
-                # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
-                obj_pred = obj_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 1)
-                cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
-                reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4)
-
-                # decode bbox
-                ctr_pred = reg_pred[..., :2] * self.stride[level] + anchors[..., :2]
-                wh_pred = torch.exp(reg_pred[..., 2:]) * self.stride[level]
-                pred_x1y1 = ctr_pred - wh_pred * 0.5
-                pred_x2y2 = ctr_pred + wh_pred * 0.5
-                box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
-
-                all_obj_preds.append(obj_pred)
-                all_cls_preds.append(cls_pred)
-                all_box_preds.append(box_pred)
-                all_reg_preds.append(reg_pred)
-                all_anchors.append(anchors)
-                all_strides.append(stride_tensor)
-            
-            # output dict
-            outputs = {"pred_obj": all_obj_preds,        # List(Tensor) [B, M, 1]
-                       "pred_cls": all_cls_preds,        # List(Tensor) [B, M, C]
-                       "pred_box": all_box_preds,        # List(Tensor) [B, M, 4]
-                       "pred_reg": all_reg_preds,        # List(Tensor) [B, M, 4]
-                       "anchors": all_anchors,           # List(Tensor) [M, 2]
-                       "strides": self.stride,           # List(Int) [8, 16, 32]
-                       "stride_tensors": all_strides     # List(Tensor) [M, 1]
-                       }
-
-            return outputs 
+        
+        return outputs 
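
A hypothetical smoke test for the new model class, assuming the package layout in this diff, that the script is run from the repository root, and that num_classes is attached to the config elsewhere (e.g., by the trainer); none of this is part of the commit itself:

import torch
from yolo.config.yolov10_config import build_yolov10_config
from yolo.models.yolov10.yolov10 import Yolov10

class Args:                          # minimal stand-in for the argparse namespace
    model = "yolov10_s"

cfg = build_yolov10_config(Args())
cfg.num_classes = 80                 # assumption: COCO-style class count

model = Yolov10(cfg, is_val=False).train()
x = torch.randn(1, 3, 640, 640)
outputs = model(x)                   # training mode returns raw multi-level predictions
for cls_p, box_p in zip(outputs["pred_cls"], outputs["pred_box"]):
    print(cls_p.shape, box_p.shape)  # e.g. [1, 6400, 80] and [1, 6400, 4] at stride 8
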

+ 98 - 185
yolo/models/yolov10/yolov10_backbone.py

@@ -2,55 +2,88 @@ import torch
 import torch.nn as nn
 
 try:
-    from .modules import Conv, ELANBlock, DownSample
+    from .modules import ConvModule, C2fBlock, SCDown, SPPF, PSABlock
 except:
-    from yolo.models.yolov10.modules import Conv, ELANBlock, DownSample
-    
-
-model_urls = {
-    "elannet_tiny": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/yolov7_elannet_tiny.pth",
-    "elannet_large": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/yolov7_elannet_large.pth",
-    "elannet_huge": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/yolov7_elannet_huge.pth",
-}
-
-
-# --------------------- ELANNet -----------------------
-## ELANNet-Tiny
-class ELANNet_Tiny(nn.Module):
-    """
-    ELAN-Net of YOLOv7-Tiny.
-    """
-    def __init__(self, act_type='silu', norm_type='BN', depthwise=False):
-        super(ELANNet_Tiny, self).__init__()
-        # -------------- Basic parameters --------------
-        self.feat_dims = [32, 64, 128, 256, 512]
-        self.squeeze_ratios = [0.5, 0.5, 0.5, 0.5]   # Stage-1 -> Stage-4
-        self.branch_depths = [1, 1, 1, 1]            # Stage-1 -> Stage-4
+    from  modules import ConvModule, C2fBlock, SCDown, SPPF, PSABlock
+
+
+# ---------------------------- Basic functions ----------------------------
+class Yolov10Backbone(nn.Module):
+    def __init__(self, cfg):
+        super(Yolov10Backbone, self).__init__()
+        # ------------------ Basic setting ------------------
+        self.model_scale = cfg.model_scale
+        self.feat_dims = [round(64  * cfg.width),
+                          round(128 * cfg.width),
+                          round(256 * cfg.width),
+                          round(512 * cfg.width),
+                          round(512 * cfg.width * cfg.ratio)]
         
-        # -------------- Network parameters --------------
+        # ------------------ Network setting ------------------
         ## P1/2
-        self.layer_1 = Conv(3, self.feat_dims[0], k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        ## P2/4: Stage-1
-        self.layer_2 = nn.Sequential(   
-            Conv(self.feat_dims[0], self.feat_dims[1], k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),             
-            ELANBlock(self.feat_dims[1], self.feat_dims[1], self.squeeze_ratios[0], self.branch_depths[0], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.layer_1 = ConvModule(3, self.feat_dims[0], kernel_size=3, stride=2)
+        # P2/4
+        self.layer_2 = nn.Sequential(
+            ConvModule(self.feat_dims[0], self.feat_dims[1], kernel_size=3, stride=2),
+            C2fBlock(in_dim     = self.feat_dims[1],
+                     out_dim    = self.feat_dims[1],
+                     num_blocks = round(3*cfg.depth),
+                     expansion  = 0.5,
+                     shortcut   = True,
+                     use_cib    = False,
+                     )
         )
-        ## P3/8: Stage-2
+        # P3/8
         self.layer_3 = nn.Sequential(
-            nn.MaxPool2d((2, 2), 2),             
-            ELANBlock(self.feat_dims[1], self.feat_dims[2], self.squeeze_ratios[1], self.branch_depths[1], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+            ConvModule(self.feat_dims[1], self.feat_dims[2], kernel_size=3, stride=2),
+            C2fBlock(in_dim     = self.feat_dims[2],
+                     out_dim    = self.feat_dims[2],
+                     num_blocks = round(6*cfg.depth),
+                     expansion  = 0.5,
+                     shortcut   = True,
+                     use_cib    = False,
+                     )
         )
-        ## P4/16: Stage-3
+        # P4/16
         self.layer_4 = nn.Sequential(
-            nn.MaxPool2d((2, 2), 2),             
-            ELANBlock(self.feat_dims[2], self.feat_dims[3], self.squeeze_ratios[2], self.branch_depths[2], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+            SCDown(self.feat_dims[2], self.feat_dims[3], kernel_size=3, stride=2),
+            C2fBlock(in_dim     = self.feat_dims[3],
+                     out_dim    = self.feat_dims[3],
+                     num_blocks = round(6*cfg.depth),
+                     expansion  = 0.5,
+                     shortcut   = True,
+                     use_cib    = False,
+                     )
         )
-        ## P5/32: Stage-4
+        # P5/32
         self.layer_5 = nn.Sequential(
-            nn.MaxPool2d((2, 2), 2),             
-            ELANBlock(self.feat_dims[3], self.feat_dims[4], self.squeeze_ratios[3], self.branch_depths[3], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+            SCDown(self.feat_dims[3], self.feat_dims[4], kernel_size=3, stride=2),
+            C2fBlock(in_dim     = self.feat_dims[4],
+                     out_dim    = self.feat_dims[4],
+                     num_blocks = round(3*cfg.depth),
+                     expansion  = 0.5,
+                     shortcut   = True,
+                     use_cib    = True if self.model_scale in "smlx" else False,
+                     )
         )
 
+        # Extra module (no pretrained weight)
+        self.layer_6 = SPPF(in_dim  = int(512 * cfg.width * cfg.ratio),
+                            out_dim = int(512 * cfg.width * cfg.ratio),
+                            )
+        self.layer_7 = PSABlock(in_dim  = int(512 * cfg.width * cfg.ratio),
+                                out_dim = int(512 * cfg.width * cfg.ratio),
+                                expansion = 0.5,
+                                )
+
+        # Initialize all layers
+        self.init_weights()
+                
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                m.reset_parameters()
 
     def forward(self, x):
         c1 = self.layer_1(x)
@@ -59,162 +92,41 @@ class ELANNet_Tiny(nn.Module):
         c4 = self.layer_4(c3)
         c5 = self.layer_5(c4)
 
-        outputs = [c3, c4, c5]
-
-        return outputs
-
-## ELANNet-Large
-class ELANNet_Lagre(nn.Module):
-    def __init__(self, act_type='silu', norm_type='BN', depthwise=False):
-        super(ELANNet_Lagre, self).__init__()
-        # -------------------- Basic parameters --------------------
-        self.feat_dims = [32, 64, 128, 256, 512, 1024, 1024]
-        self.squeeze_ratios = [0.5, 0.5, 0.5, 0.25]  # Stage-1 -> Stage-4
-        self.branch_depths = [2, 2, 2, 2]            # Stage-1 -> Stage-4
-
-        # -------------------- Network parameters --------------------
-        ## P1/2
-        self.layer_1 = nn.Sequential(
-            Conv(3, self.feat_dims[0], k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise),      
-            Conv(self.feat_dims[0], self.feat_dims[1], k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            Conv(self.feat_dims[1], self.feat_dims[1], k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        )
-        ## P2/4: Stage-1
-        self.layer_2 = nn.Sequential(   
-            Conv(self.feat_dims[1], self.feat_dims[2], k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),             
-            ELANBlock(self.feat_dims[2], self.feat_dims[3], self.squeeze_ratios[0], self.branch_depths[0], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        )
-        ## P3/8: Stage-2
-        self.layer_3 = nn.Sequential(
-            DownSample(self.feat_dims[3], self.feat_dims[3], act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            ELANBlock(self.feat_dims[3], self.feat_dims[4], self.squeeze_ratios[1], self.branch_depths[1], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        )
-        ## P4/16: Stage-3
-        self.layer_4 = nn.Sequential(
-            DownSample(self.feat_dims[4], self.feat_dims[4], act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            ELANBlock(self.feat_dims[4], self.feat_dims[5], self.squeeze_ratios[2], self.branch_depths[2], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        )
-        ## P5/32: Stage-4
-        self.layer_5 = nn.Sequential(
-            DownSample(self.feat_dims[5], self.feat_dims[5], act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            ELANBlock(self.feat_dims[5], self.feat_dims[6], self.squeeze_ratios[3], self.branch_depths[3], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        )
-
-
-    def forward(self, x):
-        c1 = self.layer_1(x)
-        c2 = self.layer_2(c1)
-        c3 = self.layer_3(c2)
-        c4 = self.layer_4(c3)
-        c5 = self.layer_5(c4)
-
-        outputs = [c3, c4, c5]
-
-        return outputs
-
-## ELANNet-Huge
-class ELANNet_Huge(nn.Module):
-    def __init__(self, act_type='silu', norm_type='BN', depthwise=False):
-        super(ELANNet_Huge, self).__init__()
-        # -------------------- Basic parameters --------------------
-        self.feat_dims = [40, 80, 160, 320, 640, 1280, 1280]
-        self.squeeze_ratios = [0.5, 0.5, 0.5, 0.25]  # Stage-1 -> Stage-4
-        self.branch_depths = [3, 3, 3, 3]            # Stage-1 -> Stage-4
-
-        # -------------------- Network parameters --------------------
-        ## P1/2
-        self.layer_1 = nn.Sequential(
-            Conv(3, self.feat_dims[0], k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise),      
-            Conv(self.feat_dims[0], self.feat_dims[1], k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            Conv(self.feat_dims[1], self.feat_dims[1], k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        )
-        ## P2/4: Stage-1
-        self.layer_2 = nn.Sequential(   
-            Conv(self.feat_dims[1], self.feat_dims[2], k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),             
-            ELANBlock(self.feat_dims[2], self.feat_dims[3], self.squeeze_ratios[0], self.branch_depths[0], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        )
-        ## P3/8: Stage-2
-        self.layer_3 = nn.Sequential(
-            DownSample(self.feat_dims[3], self.feat_dims[3], act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            ELANBlock(self.feat_dims[3], self.feat_dims[4], self.squeeze_ratios[1], self.branch_depths[1], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        )
-        ## P4/16: Stage-3
-        self.layer_4 = nn.Sequential(
-            DownSample(self.feat_dims[4], self.feat_dims[4], act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            ELANBlock(self.feat_dims[4], self.feat_dims[5], self.squeeze_ratios[2], self.branch_depths[2], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        )
-        ## P5/32: Stage-4
-        self.layer_5 = nn.Sequential(
-            DownSample(self.feat_dims[5], self.feat_dims[5], act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            ELANBlock(self.feat_dims[5], self.feat_dims[6], self.squeeze_ratios[3], self.branch_depths[3], act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        )
-
-
-    def forward(self, x):
-        c1 = self.layer_1(x)
-        c2 = self.layer_2(c1)
-        c3 = self.layer_3(c2)
-        c4 = self.layer_4(c3)
-        c5 = self.layer_5(c4)
+        c5 = self.layer_6(c5)
+        c5 = self.layer_7(c5)
 
         outputs = [c3, c4, c5]
 
         return outputs
 
 
-# --------------------- Functions -----------------------
-## build backbone
-def build_backbone(cfg, pretrained=False): 
-    # build backbone
-    if cfg['backbone'] == 'elannet_huge':
-        backbone = ELANNet_Huge(cfg['bk_act'], cfg['bk_norm'], cfg['bk_dpw'])
-    elif cfg['backbone'] == 'elannet_large':
-        backbone = ELANNet_Lagre(cfg['bk_act'], cfg['bk_norm'], cfg['bk_dpw'])
-    elif cfg['backbone'] == 'elannet_tiny':
-        backbone = ELANNet_Tiny(cfg['bk_act'], cfg['bk_norm'], cfg['bk_dpw'])
-    # pyramid feat dims
-    feat_dims = backbone.feat_dims[-3:]
-
-    # load imagenet pretrained weight
-    if pretrained:
-        url = model_urls[cfg['backbone']]
-        if url is not None:
-            print('Loading pretrained weight for {}.'.format(cfg['backbone'].upper()))
-            checkpoint = torch.hub.load_state_dict_from_url(
-                url=url, map_location="cpu", check_hash=True)
-            # checkpoint state dict
-            checkpoint_state_dict = checkpoint.pop("model")
-            # model state dict
-            model_state_dict = backbone.state_dict()
-            # check
-            for k in list(checkpoint_state_dict.keys()):
-                if k in model_state_dict:
-                    shape_model = tuple(model_state_dict[k].shape)
-                    shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
-                    if shape_model != shape_checkpoint:
-                        checkpoint_state_dict.pop(k)
-                else:
-                    checkpoint_state_dict.pop(k)
-                    print('Unused key: ', k)
-
-            backbone.load_state_dict(checkpoint_state_dict)
-        else:
-            print('No backbone pretrained: ELANNet')        
-
-    return backbone, feat_dims
-
-
 if __name__ == '__main__':
     import time
     from thop import profile
-    cfg = {
-        'pretrained': False,
-        'backbone': 'elannet_tiny',
-        'bk_act': 'silu',
-        'bk_norm': 'BN',
-        'bk_dpw': False,
-    }
-    model, feats = build_backbone(cfg)
+    class BaseConfig(object):
+        def __init__(self) -> None:
+            # Only one scale's settings should be active; "l" is enabled below,
+            # and the other presets are kept as commented references.
+            ## "n" scale
+            # self.width = 0.25
+            # self.depth = 0.34
+            # self.ratio = 2.0
+            # self.model_scale = "n"
+
+            ## "s" scale
+            # self.width = 0.50
+            # self.depth = 0.34
+            # self.ratio = 2.0
+            # self.model_scale = "s"
+
+            ## "m" scale
+            # self.width = 0.75
+            # self.depth = 0.67
+            # self.ratio = 1.5
+            # self.model_scale = "m"
+
+            ## "l" scale
+            self.width = 1.0
+            self.depth = 1.0
+            self.ratio = 1.0
+            self.model_scale = "l"
+
+    cfg = BaseConfig()
+    model = Yolov10Backbone(cfg)
     x = torch.randn(1, 3, 640, 640)
     t0 = time.time()
     outputs = model(x)
@@ -223,6 +135,7 @@ if __name__ == '__main__':
     for out in outputs:
         print(out.shape)
 
+    x = torch.randn(1, 3, 640, 640)
     print('==============================')
     flops, params = profile(model, inputs=(x, ), verbose=False)
     print('==============================')
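
As a quick sanity check, the channel widths implied by the (width, ratio) presets in the test above can be tabulated; the presets listed here mirror that test (the "x" scale is not shown in this diff and is omitted):

presets = {"n": (0.25, 2.0), "s": (0.50, 2.0), "m": (0.75, 1.5), "l": (1.0, 1.0)}
for scale, (width, ratio) in presets.items():
    feat_dims = [round(64 * width), round(128 * width), round(256 * width),
                 round(512 * width), round(512 * width * ratio)]
    print(scale, feat_dims)
# n [16, 32, 64, 128, 256]
# s [32, 64, 128, 256, 512]
# m [48, 96, 192, 384, 576]
# l [64, 128, 256, 512, 512]
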

+ 152 - 63
yolo/models/yolov10/yolov10_head.py

@@ -1,74 +1,163 @@
+import math
 import torch
 import torch.nn as nn
+from typing import List
 
-from .modules import Conv
+try:
+    from .modules import ConvModule, DflLayer
+except:
+    from  modules import ConvModule, DflLayer
 
 
-class DecoupledHead(nn.Module):
-    def __init__(self, cfg, in_dim, out_dim, num_classes=80):
+# YOLOv10 detection head
+class Yolov10DetHead(nn.Module):
+    def __init__(self, cfg, fpn_dims: List = [64, 128, 256]):
         super().__init__()
-        print('==============================')
-        print('Head: Decoupled Head')
-        self.in_dim = in_dim
-        self.num_cls_head=cfg['num_cls_head']
-        self.num_reg_head=cfg['num_reg_head']
-        self.act_type=cfg['head_act']
-        self.norm_type=cfg['head_norm']
-
-        # cls head
-        cls_feats = []
-        self.cls_out_dim = max(out_dim, num_classes)
-        for i in range(cfg['num_cls_head']):
-            if i == 0:
-                cls_feats.append(
-                    Conv(in_dim, self.cls_out_dim, k=3, p=1, s=1, 
-                        act_type=self.act_type,
-                        norm_type=self.norm_type,
-                        depthwise=cfg['head_depthwise'])
-                        )
-            else:
-                cls_feats.append(
-                    Conv(self.cls_out_dim, self.cls_out_dim, k=3, p=1, s=1, 
-                        act_type=self.act_type,
-                        norm_type=self.norm_type,
-                        depthwise=cfg['head_depthwise'])
-                        )
-                
-        # reg head
-        reg_feats = []
-        self.reg_out_dim = max(out_dim, 64)
-        for i in range(cfg['num_reg_head']):
-            if i == 0:
-                reg_feats.append(
-                    Conv(in_dim, self.reg_out_dim, k=3, p=1, s=1, 
-                        act_type=self.act_type,
-                        norm_type=self.norm_type,
-                        depthwise=cfg['head_depthwise'])
-                        )
-            else:
-                reg_feats.append(
-                    Conv(self.reg_out_dim, self.reg_out_dim, k=3, p=1, s=1, 
-                        act_type=self.act_type,
-                        norm_type=self.norm_type,
-                        depthwise=cfg['head_depthwise'])
-                        )
-
-        self.cls_feats = nn.Sequential(*cls_feats)
-        self.reg_feats = nn.Sequential(*reg_feats)
-
-
-    def forward(self, x):
+        self.out_stride = cfg.out_stride
+        self.reg_max = cfg.reg_max
+        self.num_classes = cfg.num_classes
+
+        self.cls_dim = max(fpn_dims[0], min(cfg.num_classes, 128))
+        self.reg_dim = max(fpn_dims[0]//4, 16, 4*cfg.reg_max)
+
+        # classification head
+        self.cls_heads = nn.ModuleList(
+            nn.Sequential(
+                nn.Sequential(ConvModule(dim, dim, kernel_size=3, stride=1, groups=dim),
+                              ConvModule(dim, self.cls_dim, kernel_size=1)),
+                nn.Sequential(ConvModule(self.cls_dim, self.cls_dim, kernel_size=3, stride=1, groups=self.cls_dim),
+                              ConvModule(self.cls_dim, self.cls_dim, kernel_size=1)),
+                nn.Conv2d(self.cls_dim, cfg.num_classes, kernel_size=1),
+            )
+            for dim in fpn_dims
+        )
+
+        # bbox regression head
+        self.reg_heads = nn.ModuleList(
+            nn.Sequential(
+                ConvModule(dim, self.reg_dim, kernel_size=3, stride=1),
+                ConvModule(self.reg_dim, self.reg_dim, kernel_size=3, stride=1),
+                nn.Conv2d(self.reg_dim, 4*cfg.reg_max, kernel_size=1),
+            )
+            for dim in fpn_dims
+        )
+
+        # DFL layer for decoding bbox
+        self.dfl_layer = DflLayer(cfg.reg_max)
+        for p in self.dfl_layer.parameters():
+            p.requires_grad = False
+
+        self.init_bias()
+        
+    def init_bias(self):
+        # cls pred
+        for i, m in enumerate(self.cls_heads):
+            b = m[-1].bias.view(1, -1)
+            b.data.fill_(math.log(5 / self.num_classes / (640. / self.out_stride[i]) ** 2))
+            m[-1].bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+
+        # reg pred
+        for m in self.reg_heads:
+            b = m[-1].bias.view(-1, )
+            b.data.fill_(1.0)
+            m[-1].bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+            
+            w = m[-1].weight
+            w.data.fill_(0.)
+            m[-1].weight = torch.nn.Parameter(w, requires_grad=True)
+
+    def generate_anchors(self, fmp_size, level):
         """
-            in_feats: (Tensor) [B, C, H, W]
+            fmp_size: (List) [H, W]
         """
-        cls_feats = self.cls_feats(x)
-        reg_feats = self.reg_feats(x)
+        # generate grid cells
+        fmp_h, fmp_w = fmp_size
+        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
+        # [H, W, 2] -> [HW, 2]
+        anchors = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
+        anchors += 0.5  # add center offset
+        anchors *= self.out_stride[level]
+
+        return anchors
+
+    def forward(self, fpn_feats):
+        anchors = []
+        strides = []
+        cls_preds = []
+        reg_preds = []
+        box_preds = []
+
+        for lvl, (feat, cls_head, reg_head) in enumerate(zip(fpn_feats, self.cls_heads, self.reg_heads)):
+            bs, c, h, w = feat.size()
+            device = feat.device
+            
+            # Prediction
+            cls_pred = cls_head(feat)
+            reg_pred = reg_head(feat)
+
+            # [bs, c, h, w] -> [bs, c, hw] -> [bs, hw, c]
+            cls_pred = cls_pred.flatten(2).permute(0, 2, 1).contiguous()
+            reg_pred = reg_pred.flatten(2).permute(0, 2, 1).contiguous()
+
+            # anchor points: [M, 2]
+            anchor = self.generate_anchors(fmp_size=[h, w], level=lvl).to(device)
+            stride = torch.ones_like(anchor[..., :1]) * self.out_stride[lvl]
+
+            # Decode bbox coords
+            box_pred = self.dfl_layer(reg_pred, anchor[None], self.out_stride[lvl])
+
+            # collect results
+            anchors.append(anchor)
+            strides.append(stride)
+            cls_preds.append(cls_pred)
+            reg_preds.append(reg_pred)
+            box_preds.append(box_pred)
+
+        # output dict
+        outputs = {"pred_cls":       cls_preds,        # List(Tensor) [B, M, C]
+                   "pred_reg":       reg_preds,        # List(Tensor) [B, M, 4*(reg_max)]
+                   "pred_box":       box_preds,        # List(Tensor) [B, M, 4]
+                   "anchors":        anchors,          # List(Tensor) [M, 2]
+                   "stride_tensors": strides,          # List(Tensor) [M, 1]
+                   "strides":        self.out_stride,  # List(Int) = [8, 16, 32]
+                   }
+
+        return outputs
+
+
+if __name__=='__main__':
+    from thop import profile
+
+    # YOLOv10-Base config
+    class Yolov10BaseConfig(object):
+        def __init__(self) -> None:
+            # ---------------- Model config ----------------
+            self.width    = 0.50
+            self.depth    = 0.34
+            self.ratio    = 2.0
+            
+            self.reg_max  = 16
+            self.out_stride = [8, 16, 32]
+            self.max_stride = 32
+            self.num_levels = 3
+            self.num_classes = 80
+
+    cfg = Yolov10BaseConfig()
+
+    # Random data
+    fpn_dims = [256, 512, 512]
+    x = [torch.randn(1, fpn_dims[0], 80, 80),
+         torch.randn(1, fpn_dims[1], 40, 40),
+         torch.randn(1, fpn_dims[2], 20, 20)]
 
-        return cls_feats, reg_feats
-    
+    # Head model
+    model = Yolov10DetHead(cfg, fpn_dims)
 
-# build detection head
-def build_head(cfg, in_dim, out_dim, num_classes=80):
-    head = DecoupledHead(cfg, in_dim, out_dim, num_classes) 
+    # Inference
+    outputs = model(x)
 
-    return head
+    print('============ FLOPs & Params ===========')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))
+    

+ 0 - 98
yolo/models/yolov10/yolov10_neck.py

@@ -1,98 +0,0 @@
-import torch
-import torch.nn as nn
-from .modules import Conv
-
-
-# Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
-class SPPF(nn.Module):
-    """
-        This code referenced to https://github.com/ultralytics/yolov5
-    """
-    def __init__(self, in_dim, out_dim, expand_ratio=0.5, pooling_size=5, act_type='lrelu', norm_type='BN'):
-        super().__init__()
-        inter_dim = int(in_dim * expand_ratio)
-        self.out_dim = out_dim
-        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
-        self.cv2 = Conv(inter_dim * 4, out_dim, k=1, act_type=act_type, norm_type=norm_type)
-        self.m = nn.MaxPool2d(kernel_size=pooling_size, stride=1, padding=pooling_size // 2)
-
-    def forward(self, x):
-        x = self.cv1(x)
-        y1 = self.m(x)
-        y2 = self.m(y1)
-
-        return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
-
-
-# SPPF block with CSP module
-class SPPFBlockCSP(nn.Module):
-    """
-        CSP Spatial Pyramid Pooling Block
-    """
-    def __init__(self,
-                 in_dim,
-                 out_dim,
-                 expand_ratio=0.5,
-                 pooling_size=5,
-                 act_type='lrelu',
-                 norm_type='BN',
-                 depthwise=False
-                 ):
-        super(SPPFBlockCSP, self).__init__()
-        inter_dim = int(in_dim * expand_ratio)
-        self.out_dim = out_dim
-        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
-        self.cv2 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
-        self.m = nn.Sequential(
-            Conv(inter_dim, inter_dim, k=3, p=1, 
-                 act_type=act_type, norm_type=norm_type, 
-                 depthwise=depthwise),
-            SPPF(inter_dim, 
-                 inter_dim, 
-                 expand_ratio=1.0, 
-                 pooling_size=pooling_size, 
-                 act_type=act_type, 
-                 norm_type=norm_type),
-            Conv(inter_dim, inter_dim, k=3, p=1, 
-                 act_type=act_type, norm_type=norm_type, 
-                 depthwise=depthwise)
-        )
-        self.cv3 = Conv(inter_dim * 2, self.out_dim, k=1, act_type=act_type, norm_type=norm_type)
-
-        
-    def forward(self, x):
-        x1 = self.cv1(x)
-        x2 = self.cv2(x)
-        x3 = self.m(x2)
-        y = self.cv3(torch.cat([x1, x3], dim=1))
-
-        return y
-
-
-def build_neck(cfg, in_dim, out_dim):
-    model = cfg['neck']
-    print('==============================')
-    print('Neck: {}'.format(model))
-    # build neck
-    if model == 'sppf':
-        neck = SPPF(
-            in_dim=in_dim,
-            out_dim=out_dim,
-            expand_ratio=cfg['expand_ratio'], 
-            pooling_size=cfg['pooling_size'],
-            act_type=cfg['neck_act'],
-            norm_type=cfg['neck_norm']
-            )
-    elif model == 'csp_sppf':
-        neck = SPPFBlockCSP(
-            in_dim=in_dim,
-            out_dim=out_dim,
-            expand_ratio=cfg['expand_ratio'], 
-            pooling_size=cfg['pooling_size'],
-            act_type=cfg['neck_act'],
-            norm_type=cfg['neck_norm'],
-            depthwise=cfg['neck_depthwise']
-            )
-
-    return neck
-        

+ 123 - 124
yolo/models/yolov10/yolov10_pafpn.py

@@ -1,146 +1,145 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from .modules import Conv, ELANBlockFPN, DownSample
-
-
-# PaFPN-ELAN (YOLOv7's)
-class Yolov7PaFPN(nn.Module):
-    def __init__(self, 
-                 in_dims=[512, 1024, 512],
-                 out_dim=None,
-                 channel_width : float = 1.0,
-                 branch_width  : int   = 4.0,
-                 branch_depth  : int   = 1.0,
-                 act_type='silu',
-                 norm_type='BN',
-                 depthwise=False):
-        super(Yolov7PaFPN, self).__init__()
-        # ----------------------------- Basic parameters -----------------------------
-        self.fpn_dims = in_dims
-        self.channel_width = channel_width
-        self.branch_width = branch_width
-        self.branch_depth = branch_depth
-        c3, c4, c5 = self.fpn_dims
-
-        # ----------------------------- Top-down FPN -----------------------------
+from typing import List
+
+try:
+    from .modules import ConvModule, C2fBlock, SCDown
+except:
+    from  modules import ConvModule, C2fBlock, SCDown
+
+
+# YOLOv10's PaFPN
+class Yolov10PaFPN(nn.Module):
+    def __init__(self, cfg, in_dims :List = [256, 512, 1024]) -> None:
+        super(Yolov10PaFPN, self).__init__()
+        # --------------------------- Basic Parameters ---------------------------
+        self.model_scale = cfg.model_scale
+        self.in_dims = in_dims[::-1]
+        self.out_dims = [round(256*cfg.width), round(512*cfg.width), round(512*cfg.width*cfg.ratio)]
+
+        # ----------------------------- Yolov10's Top-down FPN -----------------------------
         ## P5 -> P4
-        self.reduce_layer_1 = Conv(c5, round(256*channel_width), k=1, norm_type=norm_type, act_type=act_type)
-        self.reduce_layer_2 = Conv(c4, round(256*channel_width), k=1, norm_type=norm_type, act_type=act_type)
-        self.top_down_layer_1 = ELANBlockFPN(in_dim=round(256*channel_width) + round(256*channel_width),
-                                             out_dim=round(256*channel_width),
-                                             squeeze_ratio=0.5,
-                                             branch_width=branch_width,
-                                             branch_depth=branch_depth,
-                                             act_type=act_type,
-                                             norm_type=norm_type,
-                                             depthwise=depthwise
-                                             )
+        self.top_down_layer_1 = C2fBlock(in_dim     = self.in_dims[0] + self.in_dims[1],
+                                         out_dim    = round(512*cfg.width),
+                                         expansion  = 0.5,
+                                         num_blocks = round(3 * cfg.depth),
+                                         shortcut   = False,
+                                         use_cib    = True if self.model_scale in "mlx" else False
+                                         )
         ## P4 -> P3
-        self.reduce_layer_3 = Conv(round(256*channel_width), round(128*channel_width), k=1, norm_type=norm_type, act_type=act_type)
-        self.reduce_layer_4 = Conv(c3, round(128*channel_width), k=1, norm_type=norm_type, act_type=act_type)
-        self.top_down_layer_2 = ELANBlockFPN(in_dim=round(128*channel_width) + round(128*channel_width),
-                                             out_dim=round(128*channel_width),
-                                             squeeze_ratio=0.5,
-                                             branch_width=branch_width,
-                                             branch_depth=branch_depth,
-                                             act_type=act_type,
-                                             norm_type=norm_type,
-                                             depthwise=depthwise
-                                             )
-        # ----------------------------- Bottom-up FPN -----------------------------
+        self.top_down_layer_2 = C2fBlock(in_dim     = self.in_dims[2] + round(512*cfg.width),
+                                         out_dim    = round(256*cfg.width),
+                                         expansion  = 0.5,
+                                         num_blocks = round(3 * cfg.depth),
+                                         shortcut   = False,
+                                         use_cib    = False
+                                         )
+        # ----------------------------- Yolov10's Bottom-up PAN -----------------------------
         ## P3 -> P4
-        self.downsample_layer_1 = DownSample(round(128*channel_width), round(256*channel_width), act_type, norm_type, depthwise)
-        self.bottom_up_layer_1 = ELANBlockFPN(in_dim=round(256*channel_width) + round(256*channel_width),
-                                              out_dim=round(256*channel_width),
-                                              squeeze_ratio=0.5,
-                                              branch_width=branch_width,
-                                              branch_depth=branch_depth,
-                                              act_type=act_type,
-                                              norm_type=norm_type,
-                                              depthwise=depthwise
-                                              )
+        self.dowmsample_layer_1 = SCDown(round(256*cfg.width), round(256*cfg.width), kernel_size=3, stride=2)
+        self.bottom_up_layer_1 = C2fBlock(in_dim     = round(256*cfg.width) + round(512*cfg.width),
+                                          out_dim    = round(512*cfg.width),
+                                          expansion  = 0.5,
+                                          num_blocks = round(3 * cfg.depth),
+                                          shortcut   = False,
+                                          use_cib    = True if self.model_scale in "mlx" else False
+                                          )
         ## P4 -> P5
-        self.downsample_layer_2 = DownSample(round(256*channel_width), round(512*channel_width), act_type, norm_type, depthwise)
-        self.bottom_up_layer_2 = ELANBlockFPN(in_dim=round(512*channel_width) + c5,
-                                              out_dim=round(512*channel_width),
-                                              squeeze_ratio=0.5,
-                                              branch_width=branch_width,
-                                              branch_depth=branch_depth,
-                                              act_type=act_type,
-                                              norm_type=norm_type,
-                                              depthwise=depthwise
-                                              )
-        # ----------------------------- Output Proj -----------------------------
-        ## Head convs
-        self.head_conv_1 = Conv(round(128*channel_width), round(256*channel_width), k=3, s=1, p=1, act_type=act_type, norm_type=norm_type)
-        self.head_conv_2 = Conv(round(256*channel_width), round(512*channel_width), k=3, s=1, p=1, act_type=act_type, norm_type=norm_type)
-        self.head_conv_3 = Conv(round(512*channel_width), round(1024*channel_width), k=3, s=1, p=1, act_type=act_type, norm_type=norm_type)
-        ## Output projs
-        if out_dim is not None:
-            self.out_layers = nn.ModuleList([
-                Conv(in_dim, out_dim, k=1, act_type=act_type, norm_type=norm_type)
-                for in_dim in [round(256*channel_width), round(512*channel_width), round(1024*channel_width)]
-                ])
-            self.out_dim = [out_dim] * 3
-        else:
-            self.out_layers = None
-            self.out_dim = [round(256*channel_width), round(512*channel_width), round(1024*channel_width)]
+        self.downsample_layer_2 = SCDown(round(512*cfg.width), round(512*cfg.width), kernel_size=3, stride=2)
+        self.bottom_up_layer_2 = C2fBlock(in_dim     = round(512*cfg.width) + self.in_dims[0],
+                                          out_dim    = round(512*cfg.width*cfg.ratio),
+                                          expansion  = 0.5,
+                                          num_blocks = round(3 * cfg.depth),
+                                          shortcut   = False,
+                                          use_cib    = True if self.model_scale in "mlx" else False
+                                          )
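+        # The deepest output keeps cfg.ratio in its width (round(512*width*ratio) channels),
+        # mirroring the YOLOv8/YOLOv10-style backbone whose last stage is widened by the same ratio.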
 
+        self.init_weights()
+        
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                m.reset_parameters()
 
     def forward(self, features):
         c3, c4, c5 = features
 
-        # Top down
+        # ------------------ Top down FPN ------------------
         ## P5 -> P4
-        c6 = self.reduce_layer_1(c5)
-        c7 = F.interpolate(c6, scale_factor=2.0)
-        c8 = torch.cat([c7, self.reduce_layer_2(c4)], dim=1)
-        c9 = self.top_down_layer_1(c8)
+        p5_up = F.interpolate(c5, scale_factor=2.0)
+        p4 = self.top_down_layer_1(torch.cat([p5_up, c4], dim=1))
+
         ## P4 -> P3
-        c10 = self.reduce_layer_3(c9)
-        c11 = F.interpolate(c10, scale_factor=2.0)
-        c12 = torch.cat([c11, self.reduce_layer_4(c3)], dim=1)
-        c13 = self.top_down_layer_2(c12)
+        p4_up = F.interpolate(p4, scale_factor=2.0)
+        p3 = self.top_down_layer_2(torch.cat([p4_up, c3], dim=1))
 
-        # Bottom up
+        # ------------------ Bottom up FPN ------------------
         ## p3 -> P4
-        c14 = self.downsample_layer_1(c13)
-        c15 = torch.cat([c14, c9], dim=1)
-        c16 = self.bottom_up_layer_1(c15)
-        ## P4 -> P5
-        c17 = self.downsample_layer_2(c16)
-        c18 = torch.cat([c17, c5], dim=1)
-        c19 = self.bottom_up_layer_2(c18)
-
-        c20 = self.head_conv_1(c13)
-        c21 = self.head_conv_2(c16)
-        c22 = self.head_conv_3(c19)
-        out_feats = [c20, c21, c22] # [P3, P4, P5]
-        
-        # output proj layers
-        if self.out_layers is not None:
-            out_feats_proj = []
-            for feat, layer in zip(out_feats, self.out_layers):
-                out_feats_proj.append(layer(feat))
-            return out_feats_proj
+        p3_ds = self.downsample_layer_1(p3)
+        p4 = self.bottom_up_layer_1(torch.cat([p3_ds, p4], dim=1))
+
+        ## P4 -> P5
+        p4_ds = self.downsample_layer_2(p4)
+        p5 = self.bottom_up_layer_2(torch.cat([p4_ds, c5], dim=1))
 
+        out_feats = [p3, p4, p5]  # [P3, P4, P5], strides [8, 16, 32]
+
         return out_feats
+
+
+if __name__=='__main__':
+    import time
+    from thop import profile
+    # Model config
+    
+    # YOLOv10-Base config
+    class Yolov10BaseConfig(object):
+        def __init__(self) -> None:
+            # ---------------- Model config ----------------
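+            # NOTE: the scale presets below simply overwrite one another, so only the
+            # last block ('l': width=1.0, depth=1.0, ratio=1.0) takes effect; comment
+            # out the later blocks to profile a smaller scale.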
+            self.width = 0.25
+            self.depth = 0.34
+            self.ratio = 2.0
+            self.model_scale = "n"
+
+            self.width = 0.50
+            self.depth = 0.34
+            self.ratio = 2.0
+            self.model_scale = "s"
+
+            self.width = 0.75
+            self.depth = 0.67
+            self.ratio = 1.5
+            self.model_scale = "m"
+
+            self.width = 1.0
+            self.depth = 1.0
+            self.ratio = 1.0
+            self.model_scale = "l"
 
+            self.out_stride = [8, 16, 32]
+            self.max_stride = 32
+            self.num_levels = 3
 
-def build_fpn(cfg, in_dims, out_dim=None):
-    model = cfg['fpn']
-    # build pafpn
-    if model == 'yolov7_pafpn':
-        fpn_net = Yolov7PaFPN(in_dims       = in_dims,
-                              out_dim       = out_dim,
-                              channel_width = cfg['channel_width'],
-                              branch_width  = cfg['branch_width'],
-                              branch_depth  = cfg['branch_depth'],
-                              act_type      = cfg['fpn_act'],
-                              norm_type     = cfg['fpn_norm'],
-                              depthwise     = cfg['fpn_depthwise']
-                              )
+    cfg = Yolov10BaseConfig()
+    # Build the PaFPN
+    in_dims  = [64, 128, 256]
+    fpn = Yolov10PaFPN(cfg, in_dims)
 
+    # Inference
+    x = [torch.randn(1, in_dims[0], 80, 80),
+         torch.randn(1, in_dims[1], 40, 40),
+         torch.randn(1, in_dims[2], 20, 20)]
+    t0 = time.time()
+    output = fpn(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    print('====== FPN output ====== ')
+    for level, feat in enumerate(output):
+        print("- Level-{} : ".format(level), feat.shape)
 
-    return fpn_net
+    flops, params = profile(fpn, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))  # thop reports MACs, so x2 for FLOPs
+    print('Params : {:.2f} M'.format(params / 1e6))
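+    # Given the out_dim arguments above, P3/P4/P5 should come out with round(256*width),
+    # round(512*width) and round(512*width*ratio) channels, i.e. 256/512/512 for the
+    # active 'l' preset, at the 80x80 / 40x40 / 20x20 resolutions of the dummy inputs.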