
add E2E-YOLOv8

yjh0410 1 year ago
parent
commit
b286568dcb

+ 12 - 9
yolo/config/__init__.py

@@ -1,13 +1,14 @@
 # ------------------ Model Config ------------------
-from .yolov1_config    import build_yolov1_config
-from .yolov2_config    import build_yolov2_config
-from .yolov3_config    import build_yolov3_config
-from .yolov5_config    import build_yolov5_config
-from .yolov5_af_config import build_yolov5af_config
-from .yolov6_config    import build_yolov6_config
-from .yolov8_config    import build_yolov8_config
-from .gelan_config     import build_gelan_config
-from .rtdetr_config    import build_rtdetr_config
+from .yolov1_config     import build_yolov1_config
+from .yolov2_config     import build_yolov2_config
+from .yolov3_config     import build_yolov3_config
+from .yolov5_config     import build_yolov5_config
+from .yolov5_af_config  import build_yolov5af_config
+from .yolov6_config     import build_yolov6_config
+from .yolov8_config     import build_yolov8_config
+from .yolov8_e2e_config import build_yolov8_e2e_config
+from .gelan_config      import build_gelan_config
+from .rtdetr_config     import build_rtdetr_config
 
 
 def build_config(args):
@@ -26,6 +27,8 @@ def build_config(args):
         cfg = build_yolov5_config(args)
     elif 'yolov6' in args.model:
         cfg = build_yolov6_config(args)
+    elif 'yolov8_e2e' in args.model:
+        cfg = build_yolov8_e2e_config(args)
     elif 'yolov8' in args.model:
         cfg = build_yolov8_config(args)
     elif 'gelan' in args.model:

+ 2 - 2
yolo/config/yolov8_config.py

@@ -88,7 +88,7 @@ class Yolov8BaseConfig(object):
         # ---------------- Lr Scheduler config ----------------
         self.warmup_epoch = 3
         self.lr_scheduler = "cosine"
-        self.max_epoch    = 300
+        self.max_epoch    = 500
         self.eval_epoch   = 10
         self.no_aug_epoch = 20
 
@@ -99,7 +99,7 @@ class Yolov8BaseConfig(object):
         self.mosaic_prob = 0.0
         self.mixup_prob  = 0.0
         self.copy_paste  = 0.0           # approximated by the YOLOX's mixup
-        self.multi_scale = [0.5, 1.25]   # multi scale: [img_size * 0.5, img_size * 1.25]
+        self.multi_scale = [0.5, 1.5]   # multi scale: [img_size * 0.5, img_size * 1.5]
         ## Pixel mean & std
         self.pixel_mean = [0., 0., 0.]
         self.pixel_std  = [255., 255., 255.]

+ 197 - 0
yolo/config/yolov8_e2e_config.py

@@ -0,0 +1,197 @@
+# YOLOv8-E2E config
+
+
+def build_yolov8_e2e_config(args):
+    if   args.model == 'yolov8_e2e_n':
+        return Yolov8E2E_N_Config()
+    elif args.model == 'yolov8_e2e_s':
+        return Yolov8E2E_S_Config()
+    elif args.model == 'yolov8_e2e_m':
+        return Yolov8E2E_M_Config()
+    elif args.model == 'yolov8_e2e_l':
+        return Yolov8E2E_L_Config()
+    elif args.model == 'yolov8_e2e_x':
+        return Yolov8E2E_X_Config()
+    else:
+        raise NotImplementedError("No config for model: {}".format(args.model))
+    
+# YOLOv8-E2E Base config
+class Yolov8E2EBaseConfig(object):
+    def __init__(self) -> None:
+        # ---------------- Model config ----------------
+        self.width    = 1.0
+        self.depth    = 1.0
+        self.ratio    = 1.0
+        self.reg_max  = 16
+        self.out_stride = [8, 16, 32]
+        self.max_stride = 32
+        self.num_levels = 3
+        self.scale      = "b"
+        ## Backbone
+        self.bk_act   = 'silu'
+        self.bk_norm  = 'BN'
+        self.bk_depthwise = False
+        self.use_pretrained = True
+        ## Neck
+        self.neck_act       = 'silu'
+        self.neck_norm      = 'BN'
+        self.neck_depthwise = False
+        self.neck_expand_ratio = 0.5
+        self.spp_pooling_size  = 5
+        ## FPN
+        self.fpn_act  = 'silu'
+        self.fpn_norm = 'BN'
+        self.fpn_depthwise = False
+        ## Head
+        self.head_act  = 'silu'
+        self.head_norm = 'BN'
+        self.head_depthwise = False
+        self.num_cls_head   = 2
+        self.num_reg_head   = 2
+
+        # ---------------- Post-process config ----------------
+        ## Post process
+        self.val_topk = 100
+        self.val_conf_thresh = 0.001
+        self.test_topk = 100
+        self.test_conf_thresh = 0.2
+
+        # ---------------- Assignment config ----------------
+        ## Matcher
+        self.tal_topk_candidates = 10
+        self.tal_alpha = 0.5
+        self.tal_beta  = 6.0
+        ## Loss weight
+        self.loss_cls = 0.5
+        self.loss_box = 7.5
+        self.loss_dfl = 1.5
+
+        # ---------------- ModelEMA config ----------------
+        self.use_ema = True
+        self.ema_decay = 0.9998
+        self.ema_tau   = 2000
+
+        # ---------------- Optimizer config ----------------
+        self.trainer      = 'yolo'
+        self.optimizer    = 'adamw'
+        self.per_image_lr = 0.001 / 64
+        self.base_lr      = None      # base_lr = per_image_lr * batch_size
+        self.min_lr_ratio = 0.01      # min_lr  = base_lr * min_lr_ratio
+        self.momentum     = 0.9
+        self.weight_decay = 0.05
+        self.clip_max_norm   = 35.0
+        self.warmup_bias_lr  = 0.1
+        self.warmup_momentum = 0.8
+
+        # ---------------- Lr Scheduler config ----------------
+        self.warmup_epoch = 3
+        self.lr_scheduler = "cosine"
+        self.max_epoch    = 500
+        self.eval_epoch   = 10
+        self.no_aug_epoch = 20
+
+        # ---------------- Data process config ----------------
+        self.aug_type = 'yolo'
+        self.box_format = 'xyxy'
+        self.normalize_coords = False
+        self.mosaic_prob = 0.0
+        self.mixup_prob  = 0.0
+        self.copy_paste  = 0.0           # approximated by the YOLOX's mixup
+        self.multi_scale = [0.5, 1.5]   # multi scale: [img_size * 0.5, img_size * 1.5]
+        ## Pixel mean & std
+        self.pixel_mean = [0., 0., 0.]
+        self.pixel_std  = [255., 255., 255.]
+        ## Transforms
+        self.train_img_size = 640
+        self.test_img_size  = 640
+        self.use_ablu = True
+        self.affine_params = {
+            'degrees': 0.0,
+            'translate': 0.2,
+            'scale': [0.1, 2.0],
+            'shear': 0.0,
+            'perspective': 0.0,
+            'hsv_h': 0.015,
+            'hsv_s': 0.7,
+            'hsv_v': 0.4,
+        }
+
+    def print_config(self):
+        config_dict = {key: value for key, value in self.__dict__.items() if not key.startswith('__')}
+        for k, v in config_dict.items():
+            print("{} : {}".format(k, v))
+
+# YOLOv8-E2E-N
+class Yolov8E2E_N_Config(Yolov8E2EBaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # ---------------- Model config ----------------
+        self.width = 0.25
+        self.depth = 0.34
+        self.ratio = 2.0
+        self.scale = "n"
+
+        # ---------------- Data process config ----------------
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.0
+        self.copy_paste  = 0.5
+
+# YOLOv8-E2E-S
+class Yolov8E2E_S_Config(Yolov8E2EBaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # ---------------- Model config ----------------
+        self.width = 0.50
+        self.depth = 0.34
+        self.ratio = 2.0
+        self.scale = "s"
+
+        # ---------------- Data process config ----------------
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.0
+        self.copy_paste  = 0.5
+
+# YOLOv8-E2E-M
+class Yolov8E2E_M_Config(Yolov8E2EBaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # ---------------- Model config ----------------
+        self.width = 0.75
+        self.depth = 0.67
+        self.ratio = 1.5
+        self.scale = "m"
+
+        # ---------------- Data process config ----------------
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.1
+        self.copy_paste  = 0.5
+
+# YOLOv8-E2E-L
+class Yolov8E2E_L_Config(Yolov8E2EBaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # ---------------- Model config ----------------
+        self.width = 1.0
+        self.depth = 1.0
+        self.ratio = 1.0
+        self.scale = "l"
+
+        # ---------------- Data process config ----------------
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.1
+        self.copy_paste  = 0.5
+
+# YOLOv8-E2E-X
+class Yolov8E2E_X_Config(Yolov8E2EBaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # ---------------- Model config ----------------
+        self.width = 1.25
+        self.depth = 1.0
+        self.ratio = 1.0
+        self.scale = "x"
+
+        # ---------------- Data process config ----------------
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.1
+        self.copy_paste  = 0.5

+ 13 - 9
yolo/models/__init__.py

@@ -2,15 +2,16 @@
 # -*- coding:utf-8 -*-
 
 import torch
-from .yolov1.build    import build_yolov1
-from .yolov2.build    import build_yolov2
-from .yolov3.build    import build_yolov3
-from .yolov5.build    import build_yolov5
-from .yolov5_af.build import build_yolov5af
-from .yolov6.build    import build_yolov6
-from .yolov8.build    import build_yolov8
-from .gelan.build     import build_gelan
-from .rtdetr.build    import build_rtdetr
+from .yolov1.build     import build_yolov1
+from .yolov2.build     import build_yolov2
+from .yolov3.build     import build_yolov3
+from .yolov5.build     import build_yolov5
+from .yolov5_af.build  import build_yolov5af
+from .yolov6.build     import build_yolov6
+from .yolov8.build     import build_yolov8
+from .yolov8_e2e.build import build_yolov8_e2e
+from .gelan.build      import build_gelan
+from .rtdetr.build     import build_rtdetr
 
 
 # build object detector
@@ -35,6 +36,9 @@ def build_model(args, cfg, is_val=False):
     elif 'yolov6' in args.model:
         model, criterion = build_yolov6(cfg, is_val)
     ## YOLOv8
+    elif 'yolov8_e2e' in args.model:
+        model, criterion = build_yolov8_e2e(cfg, is_val)
+    ## YOLOv8
     elif 'yolov8' in args.model:
         model, criterion = build_yolov8(cfg, is_val)
     ## GElan

+ 60 - 0
yolo/models/yolov8_e2e/README.md

@@ -0,0 +1,60 @@
+# End-to-End YOLOv8:
+
+Inspired by YOLOv10, I deploy two parallel detection heads: one uses a one-to-many assigner (the o2m head) and the other uses a one-to-one assigner (the o2o head). To avoid conflicts between the gradients of the two heads, the gradients from the o2o head are truncated before they reach the backbone and neck, so only the gradients from the o2m head update the backbone and neck. This is consistent with the practice of YOLOv10. For evaluation, the o2m head is removed and only the o2o head is used, without NMS.
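+
+The gradient truncation amounts to a single `detach()` on the shared pyramid features before the o2o head (see `yolov8.py` in this folder). Below is a minimal, hypothetical sketch of the idea, with toy modules standing in for the real backbone, neck and heads:
+
+```python
+import torch
+import torch.nn as nn
+
+class DualHeadSketch(nn.Module):
+    """Toy two-head detector: the o2m head updates the shared trunk, the o2o head does not."""
+    def __init__(self, num_classes=80, dim=64):
+        super().__init__()
+        self.trunk    = nn.Conv2d(3, dim, kernel_size=3, stride=2, padding=1)  # stands in for backbone + neck + FPN
+        self.head_o2m = nn.Conv2d(dim, num_classes, kernel_size=1)             # one-to-many head (training aux)
+        self.head_o2o = nn.Conv2d(dim, num_classes, kernel_size=1)             # one-to-one head (deployed, NMS-free)
+
+    def forward(self, x):
+        feat = self.trunk(x)
+        if not self.training:
+            return {"pred_o2o": self.head_o2o(feat)}        # inference: o2o head only, no NMS
+        return {
+            "pred_o2m": self.head_o2m(feat),                # gradients flow back into the trunk
+            "pred_o2o": self.head_o2o(feat.detach()),       # detach(): the trunk sees no o2o gradients
+        }
+
+model = DualHeadSketch()
+out = model(torch.randn(2, 3, 64, 64))
+print(out["pred_o2m"].shape, out["pred_o2o"].shape)
+```
+
+During training the o2m branch supervises the shared features; at inference only the o2o branch runs, so its top-k predictions are used directly without NMS.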
+
+However, I have no GPU available to train YOLOv8-E2E, so the result tables below are still empty.
+
+- VOC
+
+|     Model   | Batch | Scale | AP<sup>val<br>0.5 | Weight |  Logs  |
+|-------------|-------|-------|-------------------|--------|--------|
+| YOLOv8-E2E-S    | 1xb16 |  640  |               |  |  |
+
+- COCO
+
+|    Model    | Batch | Scale | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |  Logs  |
+|-------------|-------|-------|------------------------|-------------------|-------------------|--------------------|--------|--------|
+| YOLOv8-E2E-S    | 1xb16 |  640  |                    |               |   26.9            |   8.9             |  |  |
+
+
+
+## Train YOLOv8-E2E
+### Single GPU
+For example, to train YOLOv8-E2E-S on COCO:
+```Shell
+python train.py --cuda -d coco --root path/to/coco -m yolov8_e2e_s -bs 16 --fp16 
+```
+
+### Multi GPU
+For example, to train YOLOv8-E2E-S on COCO:
+```Shell
+python -m torch.distributed.run --nproc_per_node=8 train.py --cuda --distributed -d coco --root path/to/coco -m yolov8_e2e_s -bs 256 --fp16 
+```
+
+## Test YOLOv8-E2E
+For example, to test YOLOv8-E2E-S on COCO-val:
+```Shell
+python test.py --cuda -d coco --root path/to/coco -m yolov8_e2e_s --weight path/to/yolov8.pth --show 
+```
+
+## Evaluate YOLOv8-E2E
+For example, to evaluate YOLOv8-E2E-S on COCO-val:
+```Shell
+python eval.py --cuda -d coco --root path/to/coco -m yolov8_e2e_s --weight path/to/yolov8.pth 
+```
+
+## Demo
+### Detect with Image
+```Shell
+python demo.py --mode image --path_to_img path/to/image_dirs/ --cuda -m yolov8_e2e_s --weight path/to/weight --show
+```
+
+### Detect with Video
+```Shell
+python demo.py --mode video --path_to_vid path/to/video --cuda -m yolov8_e2e_s --weight path/to/weight --show --gif
+```
+
+### Detect with Camera
+```Shell
+python demo.py --mode camera --cuda -m yolov8_e2e_s --weight path/to/weight --show --gif
+```

+ 24 - 0
yolo/models/yolov8_e2e/build.py

@@ -0,0 +1,24 @@
+import torch.nn as nn
+
+from .loss import SetCriterion
+from .yolov8 import Yolov8E2E
+
+
+# build object detector
+def build_yolov8_e2e(cfg, is_val=False):
+    # -------------- Build YOLO --------------
+    model = Yolov8E2E(cfg, is_val)
+
+    # -------------- Initialize YOLO --------------
+    for m in model.modules():
+        if isinstance(m, nn.BatchNorm2d):
+            m.eps = 1e-3
+            m.momentum = 0.03    
+            
+    # -------------- Build criterion --------------
+    criterion = None
+    if is_val:
+        # build criterion for training
+        criterion = SetCriterion(cfg)
+        
+    return model, criterion

+ 204 - 0
yolo/models/yolov8_e2e/loss.py

@@ -0,0 +1,204 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from utils.box_ops import bbox2dist, bbox_iou
+from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
+
+from .matcher import TaskAlignedAssigner
+
+
+class SetCriterion(object):
+    def __init__(self, cfg):
+        # --------------- Basic parameters ---------------
+        self.cfg = cfg
+        self.reg_max = cfg.reg_max
+        self.num_classes = cfg.num_classes
+        # --------------- Loss config ---------------
+        self.loss_cls_weight = cfg.loss_cls
+        self.loss_box_weight = cfg.loss_box
+        self.loss_dfl_weight = cfg.loss_dfl
+        # --------------- Matcher config ---------------
+        self.matcher = TaskAlignedAssigner(num_classes     = cfg.num_classes,
+                                           topk_candidates = cfg.tal_topk_candidates,
+                                           alpha           = cfg.tal_alpha,
+                                           beta            = cfg.tal_beta
+                                           )
+
+    def loss_classes(self, pred_cls, gt_score):
+        # compute bce loss
+        loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_score, reduction='none')
+
+        return loss_cls
+    
+    def loss_bboxes(self, pred_box, gt_box, bbox_weight):
+        # regression loss
+        ious = bbox_iou(pred_box, gt_box, xywh=False, CIoU=True)
+        loss_box = (1.0 - ious.squeeze(-1)) * bbox_weight
+
+        return loss_box
+    
+    def loss_dfl(self, pred_reg, gt_box, anchor, stride, bbox_weight=None):
+        # rescale coords by stride
+        gt_box_s = gt_box / stride
+        anchor_s = anchor / stride
+
+        # compute deltas
+        gt_ltrb_s = bbox2dist(anchor_s, gt_box_s, self.reg_max - 1)
+
+        gt_left = gt_ltrb_s.to(torch.long)
+        gt_right = gt_left + 1
+
+        weight_left = gt_right.to(torch.float) - gt_ltrb_s
+        weight_right = 1 - weight_left
+
+        # loss left
+        loss_left = F.cross_entropy(
+            pred_reg.view(-1, self.reg_max),
+            gt_left.view(-1),
+            reduction='none').view(gt_left.shape) * weight_left
+        # loss right
+        loss_right = F.cross_entropy(
+            pred_reg.view(-1, self.reg_max),
+            gt_right.view(-1),
+            reduction='none').view(gt_left.shape) * weight_right
+
+        loss_dfl = (loss_left + loss_right).mean(-1)
+        
+        if bbox_weight is not None:
+            loss_dfl *= bbox_weight
+
+        return loss_dfl
+
+    def compute_loss(self, outputs, targets):
+        """
+            outputs['pred_cls']: List(Tensor) [B, M, C]
+            outputs['pred_reg']: List(Tensor) [B, M, 4*(reg_max+1)]
+            outputs['pred_box']: List(Tensor) [B, M, 4]
+            outputs['anchors']: List(Tensor) [M, 2]
+            outputs['strides']: List(Int) [8, 16, 32] output stride
+            outputs['stride_tensor']: List(Tensor) [M, 1]
+            targets: (List) [dict{'boxes': [...], 
+                                 'labels': [...], 
+                                 'orig_size': ...}, ...]
+        """
+        # preds: [B, M, C]
+        cls_preds = torch.cat(outputs['pred_cls'], dim=1)
+        reg_preds = torch.cat(outputs['pred_reg'], dim=1)
+        box_preds = torch.cat(outputs['pred_box'], dim=1)
+        bs, num_anchors = cls_preds.shape[:2]
+        device = cls_preds.device
+        anchors = torch.cat(outputs['anchors'], dim=0)
+        
+        # --------------- label assignment ---------------
+        gt_score_targets = []
+        gt_bbox_targets = []
+        fg_masks = []
+        for batch_idx in range(bs):
+            tgt_labels = targets[batch_idx]["labels"].to(device)     # [Mp,]
+            tgt_boxs = targets[batch_idx]["boxes"].to(device)        # [Mp, 4]
+
+            if self.cfg.normalize_coords:
+                img_h, img_w = outputs['image_size']
+                tgt_boxs[..., [0, 2]] *= img_w
+                tgt_boxs[..., [1, 3]] *= img_h
+            
+            if self.cfg.box_format == 'xywh':
+                tgt_boxs_x1y1 = tgt_boxs[..., :2] - 0.5 * tgt_boxs[..., 2:]
+                tgt_boxs_x2y2 = tgt_boxs[..., :2] + 0.5 * tgt_boxs[..., 2:]
+                tgt_boxs = torch.cat([tgt_boxs_x1y1, tgt_boxs_x2y2], dim=-1)
+
+            # check target
+            if len(tgt_labels) == 0 or tgt_boxs.max().item() == 0.:
+                # There is no valid gt
+                fg_mask  = cls_preds.new_zeros(1, num_anchors).bool()               #[1, M,]
+                gt_score = cls_preds.new_zeros((1, num_anchors, self.num_classes)) #[1, M, C]
+                gt_box   = cls_preds.new_zeros((1, num_anchors, 4))                  #[1, M, 4]
+            else:
+                tgt_labels = tgt_labels[None, :, None]      # [1, Mp, 1]
+                tgt_boxs = tgt_boxs[None]                   # [1, Mp, 4]
+                (
+                    _,
+                    gt_box,     # [1, M, 4]
+                    gt_score,   # [1, M, C]
+                    fg_mask,    # [1, M,]
+                    _
+                ) = self.matcher(
+                    pd_scores = cls_preds[batch_idx:batch_idx+1].detach().sigmoid(), 
+                    pd_bboxes = box_preds[batch_idx:batch_idx+1].detach(),
+                    anc_points = anchors,
+                    gt_labels = tgt_labels,
+                    gt_bboxes = tgt_boxs
+                    )
+            gt_score_targets.append(gt_score)
+            gt_bbox_targets.append(gt_box)
+            fg_masks.append(fg_mask)
+
+        # List[B, 1, M, C] -> Tensor[B, M, C] -> Tensor[BM, C]
+        fg_masks = torch.cat(fg_masks, 0).view(-1)                                    # [BM,]
+        gt_score_targets = torch.cat(gt_score_targets, 0).view(-1, self.num_classes)  # [BM, C]
+        gt_bbox_targets = torch.cat(gt_bbox_targets, 0).view(-1, 4)                   # [BM, 4]
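+        # Normalizer: sum of soft classification targets over all anchors (YOLOv8-style), averaged across GPUs below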
+        num_fgs = gt_score_targets.sum()
+        
+        # Average loss normalizer across all the GPUs
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_fgs)
+        num_fgs = (num_fgs / get_world_size()).clamp(1.0)
+
+        # ------------------ Classification loss ------------------
+        cls_preds = cls_preds.view(-1, self.num_classes)
+        loss_cls = self.loss_classes(cls_preds, gt_score_targets)
+        loss_cls = loss_cls.sum() / num_fgs
+
+        # ------------------ Regression loss ------------------
+        box_preds_pos = box_preds.view(-1, 4)[fg_masks]
+        box_targets_pos = gt_bbox_targets.view(-1, 4)[fg_masks]
+        bbox_weight = gt_score_targets[fg_masks].sum(-1)
+        loss_box = self.loss_bboxes(box_preds_pos, box_targets_pos, bbox_weight)
+        loss_box = loss_box.sum() / num_fgs
+
+        # ------------------ Distribution focal loss  ------------------
+        ## process anchors
+        anchors = anchors[None].repeat(bs, 1, 1).view(-1, 2)
+        ## process stride tensors
+        strides = torch.cat(outputs['stride_tensor'], dim=0)
+        strides = strides.unsqueeze(0).repeat(bs, 1, 1).view(-1, 1)
+        ## fg preds
+        reg_preds_pos = reg_preds.view(-1, 4*self.reg_max)[fg_masks]
+        anchors_pos = anchors[fg_masks]
+        strides_pos = strides[fg_masks]
+        ## compute dfl
+        loss_dfl = self.loss_dfl(reg_preds_pos, box_targets_pos, anchors_pos, strides_pos, bbox_weight)
+        loss_dfl = loss_dfl.sum() / num_fgs
+
+        # total loss
+        losses = loss_cls * self.loss_cls_weight + \
+                 loss_box * self.loss_box_weight + \
+                 loss_dfl * self.loss_dfl_weight
+        loss_dict = dict(
+                loss_cls = loss_cls,
+                loss_box = loss_box,
+                loss_dfl = loss_dfl,
+                losses = losses
+        )
+
+        return loss_dict
+    
+    def __call__(self, outputs, targets):
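+        # One-to-many branch: standard TAL assignment with the configured top-k candidates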
+        self.matcher.topk_candidates = self.cfg.tal_topk_candidates
+        o2m_loss_dict = self.compute_loss(outputs["outputs_o2m"], targets)
+
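+        # One-to-one branch: restrict TAL to a single candidate per GT, giving one-to-one matching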
+        self.matcher.topk_candidates = 1
+        o2o_loss_dict = self.compute_loss(outputs["outputs_o2o"], targets)
+
+        loss_dict = {}
+        loss_dict["losses"] = o2o_loss_dict["losses"] + o2m_loss_dict["losses"]
+        for k in o2m_loss_dict:
+            loss_dict['o2m_' + k] = o2m_loss_dict[k]
+        for k in o2o_loss_dict:
+            loss_dict['o2o_' + k] = o2o_loss_dict[k]
+
+        return loss_dict
+
+if __name__ == "__main__":
+    pass

+ 199 - 0
yolo/models/yolov8_e2e/matcher.py

@@ -0,0 +1,199 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from utils.box_ops import bbox_iou
+
+
+# -------------------------- Task Aligned Assigner --------------------------
+class TaskAlignedAssigner(nn.Module):
+    def __init__(self,
+                 num_classes     = 80,
+                 topk_candidates = 10,
+                 alpha           = 0.5,
+                 beta            = 6.0, 
+                 eps             = 1e-9):
+        super(TaskAlignedAssigner, self).__init__()
+        self.topk_candidates = topk_candidates
+        self.num_classes = num_classes
+        self.bg_idx = num_classes
+        self.alpha = alpha
+        self.beta = beta
+        self.eps = eps
+
+    @torch.no_grad()
+    def forward(self,
+                pd_scores,
+                pd_bboxes,
+                anc_points,
+                gt_labels,
+                gt_bboxes):
+        self.bs = pd_scores.size(0)
+        self.n_max_boxes = gt_bboxes.size(1)
+
+        mask_pos, align_metric, overlaps = self.get_pos_mask(
+            pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points)
+
+        target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(
+            mask_pos, overlaps, self.n_max_boxes)
+
+        # Assigned target
+        target_labels, target_bboxes, target_scores = self.get_targets(
+            gt_labels, gt_bboxes, target_gt_idx, fg_mask)
+
+        # normalize
+        align_metric *= mask_pos
+        pos_align_metrics = align_metric.amax(axis=-1, keepdim=True)  # b, max_num_obj
+        pos_overlaps = (overlaps * mask_pos).amax(axis=-1, keepdim=True)  # b, max_num_obj
+        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
+        target_scores = target_scores * norm_align_metric
+
+        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
+
+    def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points):
+        # get in_gts mask, (b, max_num_obj, h*w)
+        mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)
+        # get anchor_align metric, (b, max_num_obj, h*w)
+        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts)
+        # get topk_metric mask, (b, max_num_obj, h*w)
+        mask_topk = self.select_topk_candidates(align_metric)
+        # merge all mask to a final mask, (b, max_num_obj, h*w)
+        mask_pos = mask_topk * mask_in_gts
+
+        return mask_pos, align_metric, overlaps
+
+    def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts):
+        """Compute alignment metric given predicted and ground truth bounding boxes."""
+        na = pd_bboxes.shape[-2]
+        mask_in_gts = mask_in_gts.bool()  # b, max_num_obj, h*w
+        overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
+        bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)
+
+        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
+        ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes)  # b, max_num_obj
+        ind[1] = gt_labels.squeeze(-1)  # b, max_num_obj
+        # Get the scores of each grid for each gt cls
+        bbox_scores[mask_in_gts] = pd_scores[ind[0], :, ind[1]][mask_in_gts]  # b, max_num_obj, h*w
+
+        # (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
+        pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_in_gts]
+        gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_in_gts]
+        overlaps[mask_in_gts] = bbox_iou(gt_boxes, pd_boxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
+
+        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
+        return align_metric, overlaps
+
+    def select_topk_candidates(self, metrics, largest=True):
+        """
+        Args:
+            metrics: (b, max_num_obj, h*w).
+            topk_mask: (b, max_num_obj, topk) or None
+        """
+        # (b, max_num_obj, topk)
+        topk_metrics, topk_idxs = torch.topk(metrics, self.topk_candidates, dim=-1, largest=largest)
+        topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
+        # (b, max_num_obj, topk)
+        topk_idxs.masked_fill_(~topk_mask, 0)
+
+        # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
+        count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
+        ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
+        for k in range(self.topk_candidates):
+            # Expand topk_idxs for each value of k and add 1 at the specified positions
+            count_tensor.scatter_add_(-1, topk_idxs[:, :, k:k + 1], ones)
+        # count_tensor.scatter_add_(-1, topk_idxs, torch.ones_like(topk_idxs, dtype=torch.int8, device=topk_idxs.device))
+        # Filter invalid bboxes
+        count_tensor.masked_fill_(count_tensor > 1, 0)
+
+        return count_tensor.to(metrics.dtype)
+
+    def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
+        # Assigned target labels, (b, 1)
+        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
+        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (b, h*w)
+        target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)
+
+        # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4)
+        target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx]
+
+        # Assigned target scores
+        target_labels.clamp_(0)
+
+        # 10x faster than F.one_hot()
+        target_scores = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.num_classes),
+                                    dtype=torch.int64,
+                                    device=target_labels.device)  # (b, h*w, 80)
+        target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)
+
+        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)  # (b, h*w, 80)
+        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)
+
+        return target_labels, target_bboxes, target_scores
+    
+
+# -------------------------- Basic Functions --------------------------
+def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
+    """select the positive anchors's center in gt
+    Args:
+        xy_centers (Tensor): shape(bs*n_max_boxes, num_total_anchors, 4)
+        gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
+    Return:
+        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    """
+    n_anchors = xy_centers.size(0)
+    bs, n_max_boxes, _ = gt_bboxes.size()
+    _gt_bboxes = gt_bboxes.reshape([-1, 4])
+    xy_centers = xy_centers.unsqueeze(0).repeat(bs * n_max_boxes, 1, 1)
+    gt_bboxes_lt = _gt_bboxes[:, 0:2].unsqueeze(1).repeat(1, n_anchors, 1)
+    gt_bboxes_rb = _gt_bboxes[:, 2:4].unsqueeze(1).repeat(1, n_anchors, 1)
+    b_lt = xy_centers - gt_bboxes_lt
+    b_rb = gt_bboxes_rb - xy_centers
+    bbox_deltas = torch.cat([b_lt, b_rb], dim=-1)
+    bbox_deltas = bbox_deltas.reshape([bs, n_max_boxes, n_anchors, -1])
+    return (bbox_deltas.min(axis=-1)[0] > eps).to(gt_bboxes.dtype)
+
+def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
+    """if an anchor box is assigned to multiple gts,
+        the one with the highest iou will be selected.
+    Args:
+        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+        overlaps (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    Return:
+        target_gt_idx (Tensor): shape(bs, num_total_anchors)
+        fg_mask (Tensor): shape(bs, num_total_anchors)
+        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    """
+    fg_mask = mask_pos.sum(-2)
+    if fg_mask.max() > 1:  # one anchor is assigned to multiple gt_bboxes
+        mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)
+        max_overlaps_idx = overlaps.argmax(1)  # (b, h*w)
+
+        is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
+        is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
+
+        mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()  # (b, n_max_boxes, h*w)
+        fg_mask = mask_pos.sum(-2)
+    # Find which gt each grid cell is assigned to (index)
+    target_gt_idx = mask_pos.argmax(-2)  # (b, h*w)
+
+    return target_gt_idx, fg_mask, mask_pos
+
+def iou_calculator(box1, box2, eps=1e-9):
+    """Calculate iou for batch
+    Args:
+        box1 (Tensor): shape(bs, n_max_boxes, 1, 4)
+        box2 (Tensor): shape(bs, 1, num_total_anchors, 4)
+    Return:
+        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    """
+    box1 = box1.unsqueeze(2)  # [N, M1, 4] -> [N, M1, 1, 4]
+    box2 = box2.unsqueeze(1)  # [N, M2, 4] -> [N, 1, M2, 4]
+    px1y1, px2y2 = box1[:, :, :, 0:2], box1[:, :, :, 2:4]
+    gx1y1, gx2y2 = box2[:, :, :, 0:2], box2[:, :, :, 2:4]
+    x1y1 = torch.maximum(px1y1, gx1y1)
+    x2y2 = torch.minimum(px2y2, gx2y2)
+    overlap = (x2y2 - x1y1).clip(0).prod(-1)
+    area1 = (px2y2 - px1y1).clip(0).prod(-1)
+    area2 = (gx2y2 - gx1y1).clip(0).prod(-1)
+    union = area1 + area2 - overlap + eps
+
+    return overlap / union

+ 179 - 0
yolo/models/yolov8_e2e/yolov8.py

@@ -0,0 +1,179 @@
+# --------------- Torch components ---------------
+import copy
+import torch
+import torch.nn as nn
+
+# --------------- Model components ---------------
+from .yolov8_backbone import Yolov8Backbone
+from .yolov8_neck     import SPPF
+from .yolov8_pafpn    import Yolov8PaFPN
+from .yolov8_head     import Yolov8DetHead
+from .yolov8_pred     import Yolov8DetPredLayer
+
+
+# End-to-End YOLOv8
+class Yolov8E2E(nn.Module):
+    def __init__(self, cfg, is_val = False):
+        super(Yolov8E2E, self).__init__()
+        # ---------------------- Basic setting ----------------------
+        self.cfg = cfg
+        self.num_classes = cfg.num_classes
+        ## Post-process parameters
+        self.topk_candidates  = cfg.val_topk        if is_val else cfg.test_topk
+        self.conf_thresh      = cfg.val_conf_thresh if is_val else cfg.test_conf_thresh
+        self.no_multi_labels  = False if is_val else True
+        
+        # ---------------------- Model Parameters ----------------------
+        ## Backbone
+        self.backbone = Yolov8Backbone(cfg)
+        self.pyramid_feat_dims = self.backbone.feat_dims[-3:]
+        ## Neck
+        self.neck     = SPPF(cfg, self.pyramid_feat_dims[-1], self.pyramid_feat_dims[-1])
+        self.pyramid_feat_dims[-1] = self.neck.out_dim
+        ## Neck: PaFPN
+        self.fpn      = Yolov8PaFPN(cfg, self.backbone.feat_dims)
+        ## Head (one-to-one)
+        self.head_o2o = Yolov8DetHead(cfg, self.fpn.out_dims)
+        ## Pred (one-to-one)
+        self.pred_o2o = Yolov8DetPredLayer(cfg, self.head_o2o.cls_head_dim, self.head_o2o.reg_head_dim)
+
+        ## Aux head (one-to-many)
+        self.head_o2m = copy.deepcopy(self.head_o2o)
+        ## Aux Pred (one-to-many)
+        self.pred_o2m = copy.deepcopy(self.pred_o2o)
+
+    def post_process(self, cls_preds, box_preds):
+        """
+        We process predictions at each scale hierarchically
+        Input:
+            cls_preds: List[torch.Tensor] -> [[B, M, C], ...], B=1
+            box_preds: List[torch.Tensor] -> [[B, M, 4], ...], B=1
+        Output:
+            bboxes: np.array -> [N, 4]
+            scores: np.array -> [N,]
+            labels: np.array -> [N,]
+        """
+        all_scores = []
+        all_labels = []
+        all_bboxes = []
+        
+        for cls_pred_i, box_pred_i in zip(cls_preds, box_preds):
+            cls_pred_i = cls_pred_i[0]
+            box_pred_i = box_pred_i[0]
+            if self.no_multi_labels:
+                # [M,]
+                scores, labels = torch.max(cls_pred_i.sigmoid(), dim=1)
+
+                # Keep top k top scoring indices only.
+                num_topk = min(self.topk_candidates, box_pred_i.size(0))
+
+                # topk candidates
+                predicted_prob, topk_idxs = scores.sort(descending=True)
+                topk_scores = predicted_prob[:num_topk]
+                topk_idxs = topk_idxs[:num_topk]
+
+                # filter out the proposals with low confidence score
+                keep_idxs = topk_scores > self.conf_thresh
+                scores = topk_scores[keep_idxs]
+                topk_idxs = topk_idxs[keep_idxs]
+
+                labels = labels[topk_idxs]
+                bboxes = box_pred_i[topk_idxs]
+            else:
+                # [M, C] -> [MC,]
+                scores_i = cls_pred_i.sigmoid().flatten()
+
+                # Keep top k top scoring indices only.
+                num_topk = min(self.topk_candidates, box_pred_i.size(0))
+
+                # torch.sort is actually faster than .topk (at least on GPUs)
+                predicted_prob, topk_idxs = scores_i.sort(descending=True)
+                topk_scores = predicted_prob[:num_topk]
+                topk_idxs = topk_idxs[:num_topk]
+
+                # filter out the proposals with low confidence score
+                keep_idxs = topk_scores > self.conf_thresh
+                scores = topk_scores[keep_idxs]
+                topk_idxs = topk_idxs[keep_idxs]
+
+                anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
+                labels = topk_idxs % self.num_classes
+
+                bboxes = box_pred_i[anchor_idxs]
+
+            all_scores.append(scores)
+            all_labels.append(labels)
+            all_bboxes.append(bboxes)
+
+        scores = torch.cat(all_scores, dim=0)
+        labels = torch.cat(all_labels, dim=0)
+        bboxes = torch.cat(all_bboxes, dim=0)
+
+        # to cpu & numpy
+        scores = scores.cpu().numpy()
+        labels = labels.cpu().numpy()
+        bboxes = bboxes.cpu().numpy()
+
+        return bboxes, scores, labels
+    
+    def inference_o2o(self, x):
+        # ---------------- Backbone ----------------
+        pyramid_feats = self.backbone(x)
+        # ---------------- Neck: SPP ----------------
+        pyramid_feats[-1] = self.neck(pyramid_feats[-1])
+
+        # ---------------- Neck: PaFPN ----------------
+        pyramid_feats = self.fpn(pyramid_feats)
+
+        # ---------------- Heads ----------------
+        cls_feats, reg_feats = self.head_o2o(pyramid_feats)
+
+        # ---------------- Preds ----------------
+        outputs = self.pred_o2o(cls_feats, reg_feats)
+        outputs['image_size'] = [x.shape[2], x.shape[3]]
+
+        all_cls_preds = outputs['pred_cls']
+        all_box_preds = outputs['pred_box']
+
+        # post process (no NMS)
+        bboxes, scores, labels = self.post_process(all_cls_preds, all_box_preds)
+        outputs = {
+            "scores": scores,
+            "labels": labels,
+            "bboxes": bboxes
+        }
+        
+        return outputs 
+
+    def forward(self, x):
+        if not self.training:
+            return self.inference_o2o(x)
+        else:
+            # ---------------- Backbone ----------------
+            pyramid_feats = self.backbone(x)
+            # ---------------- Neck: SPP ----------------
+            pyramid_feats[-1] = self.neck(pyramid_feats[-1])
+
+            # ---------------- Neck: PaFPN ----------------
+            pyramid_feats = self.fpn(pyramid_feats)
+
+            # ---------------- Heads ----------------
+            o2m_cls_feats, o2m_reg_feats = self.head_o2m(pyramid_feats)
+
+            # ---------------- Preds ----------------
+            outputs_o2m = self.pred_o2m(o2m_cls_feats, o2m_reg_feats)
+            outputs_o2m['image_size'] = [x.shape[2], x.shape[3]]
+            
+            # ---------------- Heads (one-to-one) ----------------
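+            # detach() keeps o2o gradients from flowing back into the backbone, neck and FPN;
+            # only the o2m branch updates the shared features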
+            o2o_cls_feats, o2o_reg_feats = self.head_o2o([feat.detach() for feat in pyramid_feats])
+
+            # ---------------- Preds (one-to-one) ----------------
+            outputs_o2o = self.pred_o2o(o2o_cls_feats, o2o_reg_feats)
+            outputs_o2o['image_size'] = [x.shape[2], x.shape[3]]
+
+            outputs = {
+                "outputs_o2m": outputs_o2m,
+                "outputs_o2o": outputs_o2o,
+            }
+            
+            return outputs 

+ 181 - 0
yolo/models/yolov8_e2e/yolov8_backbone.py

@@ -0,0 +1,181 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .yolov8_basic import BasicConv, ELANLayer
+except:
+    from  yolov8_basic import BasicConv, ELANLayer
+
+# IN1K pretrained weight
+pretrained_urls = {
+    'n': "https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/in1k_pretrained_weight/elandarknet_n_in1k_62.1.pth",
+    's': "https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/in1k_pretrained_weight/elandarknet_s_in1k_71.3.pth",
+    'm': "https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/in1k_pretrained_weight/elandarknet_m_in1k_75.7.pth",
+    'l': None,
+    'x': None,
+}
+
+# ---------------------------- Basic functions ----------------------------
+class Yolov8Backbone(nn.Module):
+    def __init__(self, cfg):
+        super(Yolov8Backbone, self).__init__()
+        # ------------------ Basic setting ------------------
+        self.model_scale = cfg.scale
+        self.feat_dims = [round(64  * cfg.width),
+                          round(128 * cfg.width),
+                          round(256 * cfg.width),
+                          round(512 * cfg.width),
+                          round(512 * cfg.width * cfg.ratio)]
+        
+        # ------------------ Network setting ------------------
+        ## P1/2
+        self.layer_1 = BasicConv(3, self.feat_dims[0],
+                                 kernel_size=3, padding=1, stride=2,
+                                 act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise)
+        # P2/4
+        self.layer_2 = nn.Sequential(
+            BasicConv(self.feat_dims[0], self.feat_dims[1],
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
+            ELANLayer(in_dim     = self.feat_dims[1],
+                      out_dim    = self.feat_dims[1],
+                      num_blocks = round(3*cfg.depth),
+                      expansion  = 0.5,
+                      shortcut   = True,
+                      act_type   = cfg.bk_act,
+                      norm_type  = cfg.bk_norm,
+                      depthwise  = cfg.bk_depthwise)
+        )
+        # P3/8
+        self.layer_3 = nn.Sequential(
+            BasicConv(self.feat_dims[1], self.feat_dims[2],
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
+            ELANLayer(in_dim     = self.feat_dims[2],
+                      out_dim    = self.feat_dims[2],
+                      num_blocks = round(6*cfg.depth),
+                      expansion  = 0.5,
+                      shortcut   = True,
+                      act_type   = cfg.bk_act,
+                      norm_type  = cfg.bk_norm,
+                      depthwise  = cfg.bk_depthwise)
+        )
+        # P4/16
+        self.layer_4 = nn.Sequential(
+            BasicConv(self.feat_dims[2], self.feat_dims[3],
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
+            ELANLayer(in_dim     = self.feat_dims[3],
+                      out_dim    = self.feat_dims[3],
+                      num_blocks = round(6*cfg.depth),
+                      expansion  = 0.5,
+                      shortcut   = True,
+                      act_type   = cfg.bk_act,
+                      norm_type  = cfg.bk_norm,
+                      depthwise  = cfg.bk_depthwise)
+        )
+        # P5/32
+        self.layer_5 = nn.Sequential(
+            BasicConv(self.feat_dims[3], self.feat_dims[4],
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
+            ELANLayer(in_dim     = self.feat_dims[4],
+                      out_dim    = self.feat_dims[4],
+                      num_blocks = round(3*cfg.depth),
+                      expansion  = 0.5,
+                      shortcut   = True,
+                      act_type   = cfg.bk_act,
+                      norm_type  = cfg.bk_norm,
+                      depthwise  = cfg.bk_depthwise)
+        )
+
+        # Initialize all layers
+        self.init_weights()
+        
+        # Load imagenet pretrained weight
+        if cfg.use_pretrained:
+            self.load_pretrained()
+        
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                # In order to be consistent with the source code,
+                # reset the Conv2d initialization parameters
+                m.reset_parameters()
+
+    def load_pretrained(self):
+        url = pretrained_urls[self.model_scale]
+        if url is not None:
+            print('Loading backbone pretrained weight from : {}'.format(url))
+            # checkpoint state dict
+            checkpoint = torch.hub.load_state_dict_from_url(
+                url=url, map_location="cpu", check_hash=True)
+            checkpoint_state_dict = checkpoint.pop("model")
+            # model state dict
+            model_state_dict = self.state_dict()
+            # check
+            for k in list(checkpoint_state_dict.keys()):
+                if k in model_state_dict:
+                    shape_model = tuple(model_state_dict[k].shape)
+                    shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
+                    if shape_model != shape_checkpoint:
+                        checkpoint_state_dict.pop(k)
+                else:
+                    checkpoint_state_dict.pop(k)
+                    print('Unused key: ', k)
+            # load the weight
+            self.load_state_dict(checkpoint_state_dict)
+        else:
+            print('No pretrained weight for model scale: {}.'.format(self.model_scale))
+
+    def forward(self, x):
+        c1 = self.layer_1(x)
+        c2 = self.layer_2(c1)
+        c3 = self.layer_3(c2)
+        c4 = self.layer_4(c3)
+        c5 = self.layer_5(c4)
+        outputs = [c3, c4, c5]
+
+        return outputs
+
+
+# ---------------------------- Functions ----------------------------
+## build Yolo's Backbone
+def build_backbone(cfg): 
+    # model
+    backbone = Yolov8Backbone(cfg)
+        
+    return backbone
+
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    class BaseConfig(object):
+        def __init__(self) -> None:
+            self.bk_act = 'silu'
+            self.bk_norm = 'BN'
+            self.bk_depthwise = False
+            self.use_pretrained = True
+            self.width = 0.50
+            self.depth = 0.34
+            self.ratio = 2.0
+            self.scale = "s"
+
+    cfg = BaseConfig()
+    model = build_backbone(cfg)
+    x = torch.randn(1, 3, 640, 640)
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    for out in outputs:
+        print(out.shape)
+
+    x = torch.randn(1, 3, 640, 640)
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))

+ 172 - 0
yolo/models/yolov8_e2e/yolov8_basic.py

@@ -0,0 +1,172 @@
+import torch
+import torch.nn as nn
+from typing import List
+
+
+# --------------------- Basic modules ---------------------
+def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
+    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
+
+    return conv
+
+def get_activation(act_type=None):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+    elif act_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+        
+def get_norm(norm_type, dim):
+    if norm_type == 'BN':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'GN':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+    elif norm_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+
+class BasicConv(nn.Module):
+    def __init__(self, 
+                 in_dim,                   # in channels
+                 out_dim,                  # out channels 
+                 kernel_size=1,            # kernel size 
+                 padding=0,                # padding
+                 stride=1,                 # stride
+                 dilation=1,               # dilation
+                 act_type  :str = 'lrelu', # activation
+                 norm_type :str = 'BN',    # normalization
+                 depthwise :bool = False
+                ):
+        super(BasicConv, self).__init__()
+        self.depthwise = depthwise
+        use_bias = False if norm_type is not None else True
+        if not depthwise:
+            self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=1, bias=use_bias)
+            self.norm = get_norm(norm_type, out_dim)
+        else:
+            self.conv1 = get_conv2d(in_dim, in_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=in_dim, bias=use_bias)
+            self.norm1 = get_norm(norm_type, in_dim)
+            self.conv2 = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, d=1, g=1)
+            self.norm2 = get_norm(norm_type, out_dim)
+        self.act  = get_activation(act_type)
+
+    def forward(self, x):
+        if not self.depthwise:
+            return self.act(self.norm(self.conv(x)))
+        else:
+            # Depthwise conv
+            x = self.norm1(self.conv1(x))
+            # Pointwise conv
+            x = self.act(self.norm2(self.conv2(x)))
+            return x
+
+
+# --------------------- Yolov8 modules ---------------------
+class YoloBottleneck(nn.Module):
+    def __init__(self,
+                 in_dim      :int,
+                 out_dim     :int,
+                 kernel_size :List  = [1, 3],
+                 expansion   :float = 0.5,
+                 shortcut    :bool  = False,
+                 act_type    :str   = 'silu',
+                 norm_type   :str   = 'BN',
+                 depthwise   :bool  = False,
+                 ) -> None:
+        super(YoloBottleneck, self).__init__()
+        inter_dim = int(out_dim * expansion)
+        # ----------------- Network setting -----------------
+        self.conv_layer1 = BasicConv(in_dim, inter_dim,
+                                     kernel_size=kernel_size[0], padding=kernel_size[0]//2, stride=1,
+                                     act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.conv_layer2 = BasicConv(inter_dim, out_dim,
+                                     kernel_size=kernel_size[1], padding=kernel_size[1]//2, stride=1,
+                                     act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.shortcut = shortcut and in_dim == out_dim
+
+    def forward(self, x):
+        h = self.conv_layer2(self.conv_layer1(x))
+
+        return x + h if self.shortcut else h
+
+class CSPLayer(nn.Module):
+    # CSP Bottleneck with 3 convolutions
+    def __init__(self,
+                 in_dim      :int,
+                 out_dim     :int,
+                 num_blocks  :int   = 1,
+                 kernel_size :List = [3, 3],
+                 expansion   :float = 0.5,
+                 shortcut    :bool  = True,
+                 act_type    :str   = 'silu',
+                 norm_type   :str   = 'BN',
+                 depthwise   :bool  = False,
+                 ) -> None:
+        super().__init__()
+        inter_dim = round(out_dim * expansion)
+        self.input_proj_1 = BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.input_proj_2 = BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.output_proj  = BasicConv(2 * inter_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.module       = nn.Sequential(*[YoloBottleneck(inter_dim,
+                                                           inter_dim,
+                                                           kernel_size,
+                                                           expansion   = 1.0,
+                                                           shortcut    = shortcut,
+                                                           act_type    = act_type,
+                                                           norm_type   = norm_type,
+                                                           depthwise   = depthwise,
+                                                           ) for _ in range(num_blocks)])
+
+    def forward(self, x):
+        x1 = self.input_proj_1(x)
+        x2 = self.input_proj_2(x)
+        x2 = self.module(x2)
+        out = self.output_proj(torch.cat([x1, x2], dim=1))
+
+        return out
+
+class ELANLayer(nn.Module):
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 expansion  :float = 0.5,
+                 num_blocks :int   = 1,
+                 shortcut   :bool  = False,
+                 act_type   :str   = 'silu',
+                 norm_type  :str   = 'BN',
+                 depthwise  :bool  = False,
+                 ) -> None:
+        super(ELANLayer, self).__init__()
+        inter_dim = round(out_dim * expansion)
+        self.input_proj  = BasicConv(in_dim, inter_dim * 2, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.output_proj = BasicConv((2 + num_blocks) * inter_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.module      = nn.ModuleList([YoloBottleneck(inter_dim,
+                                                         inter_dim,
+                                                         kernel_size = [3, 3],
+                                                         expansion   = 1.0,
+                                                         shortcut    = shortcut,
+                                                         act_type    = act_type,
+                                                         norm_type   = norm_type,
+                                                         depthwise   = depthwise)
+                                                         for _ in range(num_blocks)])
+
+    def forward(self, x):
+        # Input proj
+        x1, x2 = torch.chunk(self.input_proj(x), 2, dim=1)
+        out = list([x1, x2])
+
+        # Bottleneck blocks
+        out.extend(m(out[-1]) for m in self.module)
+
+        # Output proj
+        out = self.output_proj(torch.cat(out, dim=1))
+
+        return out

+ 179 - 0
yolo/models/yolov8_e2e/yolov8_head.py

@@ -0,0 +1,179 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .yolov8_basic import BasicConv
+except:
+    from  yolov8_basic import BasicConv
+
+
+# -------------------- Detection Head --------------------
+## Single-level Detection Head
+class DetHead(nn.Module):
+    def __init__(self,
+                 in_dim       :int  = 256,
+                 cls_head_dim :int  = 256,
+                 reg_head_dim :int  = 256,
+                 num_cls_head :int  = 2,
+                 num_reg_head :int  = 2,
+                 act_type     :str  = "silu",
+                 norm_type    :str  = "BN",
+                 depthwise    :bool = False):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.in_dim = in_dim
+        self.num_cls_head = num_cls_head
+        self.num_reg_head = num_reg_head
+        self.act_type = act_type
+        self.norm_type = norm_type
+        self.depthwise = depthwise
+        
+        # --------- Network Parameters ----------
+        ## cls head
+        cls_feats = []
+        self.cls_head_dim = cls_head_dim
+        for i in range(num_cls_head):
+            if i == 0:
+                cls_feats.append(
+                    BasicConv(in_dim, self.cls_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=act_type,
+                              norm_type=norm_type,
+                              depthwise=depthwise)
+                              )
+            else:
+                cls_feats.append(
+                    BasicConv(self.cls_head_dim, self.cls_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=act_type,
+                              norm_type=norm_type,
+                              depthwise=depthwise)
+                              )
+        ## reg head
+        reg_feats = []
+        self.reg_head_dim = reg_head_dim
+        for i in range(num_reg_head):
+            if i == 0:
+                reg_feats.append(
+                    BasicConv(in_dim, self.reg_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=act_type,
+                              norm_type=norm_type,
+                              depthwise=depthwise)
+                              )
+            else:
+                reg_feats.append(
+                    BasicConv(self.reg_head_dim, self.reg_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=act_type,
+                              norm_type=norm_type,
+                              depthwise=depthwise)
+                              )
+        self.cls_feats = nn.Sequential(*cls_feats)
+        self.reg_feats = nn.Sequential(*reg_feats)
+
+        self.init_weights()
+        
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                # In order to be consistent with the source code,
+                # reset the Conv2d initialization parameters
+                m.reset_parameters()
+
+    def forward(self, x):
+        """
+            x: (Tensor) [B, C, H, W], single-level input feature map
+        """
+        cls_feats = self.cls_feats(x)
+        reg_feats = self.reg_feats(x)
+
+        return cls_feats, reg_feats
+    
+## Multi-level Detection Head
+class Yolov8DetHead(nn.Module):
+    def __init__(self, cfg, in_dims):
+        super().__init__()
+        ## ----------- Network Parameters -----------
+        self.multi_level_heads = nn.ModuleList(
+            [DetHead(in_dim       = in_dims[level],
+                     cls_head_dim = max(in_dims[0], min(cfg.num_classes, 128)),
+                     reg_head_dim = max(in_dims[0]//4, 16, 4*cfg.reg_max),
+                     num_cls_head = cfg.num_cls_head,
+                     num_reg_head = cfg.num_reg_head,
+                     act_type     = cfg.head_act,
+                     norm_type    = cfg.head_norm,
+                     depthwise    = cfg.head_depthwise)
+                     for level in range(cfg.num_levels)
+                     ])
+        # --------- Basic Parameters ----------
+        self.in_dims = in_dims
+        self.cls_head_dim = self.multi_level_heads[0].cls_head_dim
+        self.reg_head_dim = self.multi_level_heads[0].reg_head_dim
+
+
+    def forward(self, feats):
+        """
+            feats: List[(Tensor)] [[B, C, H, W], ...]
+        """
+        cls_feats = []
+        reg_feats = []
+        for feat, head in zip(feats, self.multi_level_heads):
+            # ---------------- Pred ----------------
+            cls_feat, reg_feat = head(feat)
+
+            cls_feats.append(cls_feat)
+            reg_feats.append(reg_feat)
+
+        return cls_feats, reg_feats
+
+
+if __name__=='__main__':
+    import time
+    from thop import profile
+    # Model config
+    
+    # YOLOv8-Base config
+    class Yolov8BaseConfig(object):
+        def __init__(self) -> None:
+            # ---------------- Model config ----------------
+            self.width    = 0.50
+            self.depth    = 0.34
+            self.ratio    = 2.0
+            self.reg_max  = 16
+            self.out_stride = [8, 16, 32]
+            self.max_stride = 32
+            self.num_levels = 3
+            ## Head
+            self.head_act  = 'lrelu'
+            self.head_norm = 'BN'
+            self.head_depthwise = False
+            self.num_cls_head   = 2
+            self.num_reg_head   = 2
+
+    cfg = Yolov8BaseConfig()
+    cfg.num_classes = 20
+
+    # Build a head
+    fpn_dims = [128, 256, 512]
+    pyramid_feats = [torch.randn(1, fpn_dims[0], 80, 80),
+                     torch.randn(1, fpn_dims[1], 40, 40),
+                     torch.randn(1, fpn_dims[2], 20, 20)]
+    head = Yolov8DetHead(cfg, fpn_dims)
+
+
+    # Inference
+    t0 = time.time()
+    cls_feats, reg_feats = head(pyramid_feats)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    print("====== Yolov8 Head output ======")
+    for level, (cls_f, reg_f) in enumerate(zip(cls_feats, reg_feats)):
+        print("- Level-{} : ".format(level), cls_f.shape, reg_f.shape)
+
+    flops, params = profile(head, inputs=(pyramid_feats, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))
+    
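
For reference, with the test config above (num_classes=20, reg_max=16, fpn_dims=[128, 256, 512]) the per-level head widths resolve to cls_head_dim=128 and reg_head_dim=64. A quick check of the same two formulas (not part of the diff):

num_classes, reg_max = 20, 16
fpn_dims = [128, 256, 512]
cls_head_dim = max(fpn_dims[0], min(num_classes, 128))   # max(128, 20)    -> 128
reg_head_dim = max(fpn_dims[0] // 4, 16, 4 * reg_max)    # max(32, 16, 64) -> 64
print(cls_head_dim, reg_head_dim)                        # 128 64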

+ 85 - 0
yolo/models/yolov8_e2e/yolov8_neck.py

@@ -0,0 +1,85 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .yolov8_basic import BasicConv
+except:
+    from  yolov8_basic import BasicConv
+    
+
+# Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
+class SPPF(nn.Module):
+    """
+        This code referenced to https://github.com/ultralytics/yolov5
+    """
+    def __init__(self, cfg, in_dim, out_dim):
+        super().__init__()
+        ## ----------- Basic Parameters -----------
+        inter_dim = round(in_dim * cfg.neck_expand_ratio)
+        self.out_dim = out_dim
+        ## ----------- Network Parameters -----------
+        self.cv1 = BasicConv(in_dim, inter_dim,
+                             kernel_size=1, padding=0, stride=1,
+                             act_type=cfg.neck_act, norm_type=cfg.neck_norm)
+        self.cv2 = BasicConv(inter_dim * 4, out_dim,
+                             kernel_size=1, padding=0, stride=1,
+                             act_type=cfg.neck_act, norm_type=cfg.neck_norm)
+        self.m = nn.MaxPool2d(kernel_size=cfg.spp_pooling_size,
+                              stride=1,
+                              padding=cfg.spp_pooling_size // 2)
+
+        # Initialize all layers
+        self.init_weights()
+
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                # In order to be consistent with the source code,
+                # reset the Conv2d initialization parameters
+                m.reset_parameters()
+
+    def forward(self, x):
+        x = self.cv1(x)
+        y1 = self.m(x)
+        y2 = self.m(y1)
+
+        return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
+
+
+if __name__=='__main__':
+    import time
+    from thop import profile
+    # Model config
+    
+    # YOLOv8-Base config
+    class Yolov8BaseConfig(object):
+        def __init__(self) -> None:
+            # ---------------- Model config ----------------
+            self.out_stride = 32
+            self.max_stride = 32
+            ## Neck
+            self.neck_act       = 'lrelu'
+            self.neck_norm      = 'BN'
+            self.neck_depthwise = False
+            self.neck_expand_ratio = 0.5
+            self.spp_pooling_size  = 5
+
+    cfg = Yolov8BaseConfig()
+    # Build a neck
+    in_dim  = 512
+    out_dim = 512
+    neck = SPPF(cfg, in_dim, out_dim)
+
+    # Inference
+    x = torch.randn(1, in_dim, 20, 20)
+    t0 = time.time()
+    output = neck(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    print('Neck output: ', output.shape)
+
+    flops, params = profile(neck, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))
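
Note: chaining the same stride-1 5x5 max-pool three times is what makes SPPF cheaper than the original SPP while producing the same result as parallel 5/9/13 pooling branches. A quick equivalence check under the padding scheme used above (not part of the diff):

import torch
import torch.nn as nn

x = torch.randn(1, 8, 32, 32)
m5  = nn.MaxPool2d(kernel_size=5,  stride=1, padding=2)
m9  = nn.MaxPool2d(kernel_size=9,  stride=1, padding=4)
m13 = nn.MaxPool2d(kernel_size=13, stride=1, padding=6)
print(torch.equal(m5(m5(x)), m9(x)))        # True: two 5x5 pools cover a 9x9 window
print(torch.equal(m5(m5(m5(x))), m13(x)))   # True: three 5x5 pools cover a 13x13 window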

+ 152 - 0
yolo/models/yolov8_e2e/yolov8_pafpn.py

@@ -0,0 +1,152 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import List
+
+try:
+    from .yolov8_basic import BasicConv, ELANLayer
+except:
+    from  yolov8_basic import BasicConv, ELANLayer
+
+
+# YOLOv8's PaFPN
+class Yolov8PaFPN(nn.Module):
+    def __init__(self,
+                 cfg,
+                 in_dims :List = [256, 512, 1024],
+                 ) -> None:
+        super(Yolov8PaFPN, self).__init__()
+        print('==============================')
+        print('FPN: {}'.format("Yolo PaFPN"))
+        # --------------------------- Basic Parameters ---------------------------
+        self.in_dims = in_dims[::-1]
+        self.out_dims = [round(256*cfg.width), round(512*cfg.width), round(512*cfg.width*cfg.ratio)]
+
+        # ----------------------------- Yolov8's Top-down FPN -----------------------------
+        ## P5 -> P4
+        self.top_down_layer_1 = ELANLayer(in_dim     = self.in_dims[0] + self.in_dims[1],
+                                          out_dim    = round(512*cfg.width),
+                                          expansion  = 0.5,
+                                          num_blocks = round(3 * cfg.depth),
+                                          shortcut   = False,
+                                          act_type   = cfg.fpn_act,
+                                          norm_type  = cfg.fpn_norm,
+                                          depthwise  = cfg.fpn_depthwise,
+                                          )
+        ## P4 -> P3
+        self.top_down_layer_2 = ELANLayer(in_dim     = self.in_dims[2] + round(512*cfg.width),
+                                          out_dim    = round(256*cfg.width),
+                                          expansion  = 0.5,
+                                          num_blocks = round(3 * cfg.depth),
+                                          shortcut   = False,
+                                          act_type   = cfg.fpn_act,
+                                          norm_type  = cfg.fpn_norm,
+                                          depthwise  = cfg.fpn_depthwise,
+                                          )
+        # ----------------------------- Yolov8's Bottom-up PAN -----------------------------
+        ## P3 -> P4
+        self.dowmsample_layer_1 = BasicConv(round(256*cfg.width), round(256*cfg.width),
+                                            kernel_size=3, padding=1, stride=2,
+                                            act_type=cfg.fpn_act, norm_type=cfg.fpn_norm, depthwise=cfg.fpn_depthwise)
+        self.bottom_up_layer_1 = ELANLayer(in_dim     = round(256*cfg.width) + round(512*cfg.width),
+                                           out_dim    = round(512*cfg.width),
+                                           expansion  = 0.5,
+                                           num_blocks = round(3 * cfg.depth),
+                                           shortcut   = False,
+                                           act_type   = cfg.fpn_act,
+                                           norm_type  = cfg.fpn_norm,
+                                           depthwise  = cfg.fpn_depthwise,
+                                           )
+        ## P4 -> P5
+        self.dowmsample_layer_2 = BasicConv(round(512*cfg.width), round(512*cfg.width),
+                                            kernel_size=3, padding=1, stride=2,
+                                            act_type=cfg.fpn_act, norm_type=cfg.fpn_norm, depthwise=cfg.fpn_depthwise)
+        self.bottom_up_layer_2 = ELANLayer(in_dim     = round(512*cfg.width) + self.in_dims[0],
+                                           out_dim    = round(512*cfg.width*cfg.ratio),
+                                           expansion  = 0.5,
+                                           num_blocks = round(3 * cfg.depth),
+                                           shortcut   = False,
+                                           act_type   = cfg.fpn_act,
+                                           norm_type  = cfg.fpn_norm,
+                                           depthwise  = cfg.fpn_depthwise,
+                                           )
+
+        self.init_weights()
+        
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                # In order to be consistent with the source code,
+                # reset the Conv2d initialization parameters
+                m.reset_parameters()
+
+    def forward(self, features):
+        c3, c4, c5 = features
+
+        # ------------------ Top down FPN ------------------
+        ## P5 -> P4
+        p5_up = F.interpolate(c5, scale_factor=2.0)
+        p4 = self.top_down_layer_1(torch.cat([p5_up, c4], dim=1))
+
+        ## P4 -> P3
+        p4_up = F.interpolate(p4, scale_factor=2.0)
+        p3 = self.top_down_layer_2(torch.cat([p4_up, c3], dim=1))
+
+        # ------------------ Bottom up FPN ------------------
+        ## P3 -> P4
+        p3_ds = self.dowmsample_layer_1(p3)
+        p4 = self.bottom_up_layer_1(torch.cat([p3_ds, p4], dim=1))
+
+        ## P4 -> P5
+        p4_ds = self.dowmsample_layer_2(p4)
+        p5 = self.bottom_up_layer_2(torch.cat([p4_ds, c5], dim=1))
+
+        out_feats = [p3, p4, p5] # [P3, P4, P5]
+                
+        return out_feats
+    
+
+if __name__=='__main__':
+    import time
+    from thop import profile
+    # Model config
+    
+    # YOLOv8-Base config
+    class Yolov8BaseConfig(object):
+        def __init__(self) -> None:
+            # ---------------- Model config ----------------
+            self.width    = 0.50
+            self.depth    = 0.34
+            self.ratio    = 2.0
+            self.out_stride = [8, 16, 32]
+            self.max_stride = 32
+            self.num_levels = 3
+            ## FPN
+            self.fpn_act  = 'silu'
+            self.fpn_norm = 'BN'
+            self.fpn_depthwise = False
+            ## Head
+            self.head_dim = 256
+
+    cfg = Yolov8BaseConfig()
+    # Build a PaFPN
+    in_dims  = [128, 256, 512]
+    fpn = Yolov8PaFPN(cfg, in_dims)
+
+    # Inference
+    x = [torch.randn(1, in_dims[0], 80, 80),
+         torch.randn(1, in_dims[1], 40, 40),
+         torch.randn(1, in_dims[2], 20, 20)]
+    t0 = time.time()
+    output = fpn(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    print('====== FPN output ====== ')
+    for level, feat in enumerate(output):
+        print("- Level-{} : ".format(level), feat.shape)
+
+    flops, params = profile(fpn, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))
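
For the width/ratio used in this test (width=0.50, ratio=2.0), the out_dims formula in __init__ gives [128, 256, 512], which is exactly the fpn_dims fed to the head test in yolov8_head.py. A quick check (not part of the diff):

width, ratio = 0.50, 2.0
out_dims = [round(256 * width), round(512 * width), round(512 * width * ratio)]
print(out_dims)   # [128, 256, 512]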

+ 210 - 0
yolo/models/yolov8_e2e/yolov8_pred.py

@@ -0,0 +1,210 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+# -------------------- Detection Pred Layer --------------------
+## Single-level pred layer
+class DetPredLayer(nn.Module):
+    def __init__(self,
+                 cls_dim     :int = 256,
+                 reg_dim     :int = 256,
+                 stride      :int = 32,
+                 reg_max     :int = 16,
+                 num_classes :int = 80,
+                 num_coords  :int = 4):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.stride = stride
+        self.cls_dim = cls_dim
+        self.reg_dim = reg_dim
+        self.reg_max = reg_max
+        self.num_classes = num_classes
+        self.num_coords = num_coords
+
+        # --------- Network Parameters ----------
+        self.cls_pred = nn.Conv2d(cls_dim, num_classes, kernel_size=1)
+        self.reg_pred = nn.Conv2d(reg_dim, num_coords, kernel_size=1)                
+
+        self.init_bias()
+        
+    def init_bias(self):
+        # cls pred bias
+        b = self.cls_pred.bias.view(1, -1)
+        b.data.fill_(math.log(5 / self.num_classes / (640. / self.stride) ** 2))
+        self.cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        # reg pred bias
+        b = self.reg_pred.bias.view(-1, )
+        b.data.fill_(1.0)
+        self.reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        w = self.reg_pred.weight
+        w.data.fill_(0.)
+        self.reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
+
+    def generate_anchors(self, fmp_size):
+        """
+            fmp_size: (List) [H, W]
+        """
+        # generate grid cells
+        fmp_h, fmp_w = fmp_size
+        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
+        # [H, W, 2] -> [HW, 2]
+        anchors = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
+        anchors += 0.5  # add center offset
+        anchors *= self.stride
+
+        return anchors
+        
+    def forward(self, cls_feat, reg_feat):
+        # pred
+        cls_pred = self.cls_pred(cls_feat)
+        reg_pred = self.reg_pred(reg_feat)
+
+        # generate anchor boxes: [M, 4]
+        B, _, H, W = cls_pred.size()
+        fmp_size = [H, W]
+        anchors = self.generate_anchors(fmp_size)
+        anchors = anchors.to(cls_pred.device)
+        # stride tensor: [M, 1]
+        stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride
+        
+        # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
+        cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
+        reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)
+        
+        # output dict
+        outputs = {"pred_cls": cls_pred,            # (Tensor) [B, M, C]
+                   "pred_reg": reg_pred,            # (Tensor) [B, M, 4*reg_max]
+                   "anchors": anchors,              # (Tensor) [M, 2]
+                   "strides": self.stride,          # (Int) stride of this level
+                   "stride_tensor": stride_tensor   # (Tensor) [M, 1]
+                   }
+
+        return outputs
+
+## Multi-level pred layer
+class Yolov8DetPredLayer(nn.Module):
+    def __init__(self,
+                 cfg,
+                 cls_dim,
+                 reg_dim,
+                 ):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.cfg = cfg
+        self.cls_dim = cls_dim
+        self.reg_dim = reg_dim
+
+        # ----------- Network Parameters -----------
+        ## pred layers
+        self.multi_level_preds = nn.ModuleList(
+            [DetPredLayer(cls_dim     = cls_dim,
+                          reg_dim     = reg_dim,
+                          stride      = cfg.out_stride[level],
+                          reg_max     = cfg.reg_max,
+                          num_classes = cfg.num_classes,
+                          num_coords  = 4 * cfg.reg_max)
+                          for level in range(cfg.num_levels)
+                          ])
+        ## proj conv
+        proj_init = torch.arange(cfg.reg_max, dtype=torch.float)
+        self.proj_conv = nn.Conv2d(cfg.reg_max, 1, kernel_size=1, bias=False).requires_grad_(False)
+        self.proj_conv.weight.data[:] = nn.Parameter(proj_init.view([1, cfg.reg_max, 1, 1]), requires_grad=False)
+
+    def forward(self, cls_feats, reg_feats):
+        all_anchors = []
+        all_strides = []
+        all_cls_preds = []
+        all_reg_preds = []
+        all_box_preds = []
+        for level in range(self.cfg.num_levels):
+            # -------------- Single-level prediction --------------
+            outputs = self.multi_level_preds[level](cls_feats[level], reg_feats[level])
+
+            # -------------- Decode bbox --------------
+            B, M = outputs["pred_reg"].shape[:2]
+            # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max]
+            delta_pred = outputs["pred_reg"].reshape([B, M, 4, self.cfg.reg_max])
+            # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
+            delta_pred = delta_pred.permute(0, 3, 2, 1).contiguous()
+            # [B, reg_max, 4, M] -> [B, 1, 4, M]
+            delta_pred = self.proj_conv(F.softmax(delta_pred, dim=1))
+            # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
+            delta_pred = delta_pred.view(B, 4, M).permute(0, 2, 1).contiguous()
+            ## tlbr -> xyxy
+            x1y1_pred = outputs["anchors"][None] - delta_pred[..., :2] * self.cfg.out_stride[level]
+            x2y2_pred = outputs["anchors"][None] + delta_pred[..., 2:] * self.cfg.out_stride[level]
+            box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
+
+            # collect results
+            all_cls_preds.append(outputs["pred_cls"])
+            all_reg_preds.append(outputs["pred_reg"])
+            all_box_preds.append(box_pred)
+            all_anchors.append(outputs["anchors"])
+            all_strides.append(outputs["stride_tensor"])
+        
+        # output dict
+        outputs = {"pred_cls":      all_cls_preds,         # List(Tensor) [B, M, C]
+                   "pred_reg":      all_reg_preds,         # List(Tensor) [B, M, 4*(reg_max)]
+                   "pred_box":      all_box_preds,         # List(Tensor) [B, M, 4]
+                   "anchors":       all_anchors,           # List(Tensor) [M, 2]
+                   "stride_tensor": all_strides,           # List(Tensor) [M, 1]
+                   "strides":       self.cfg.out_stride,   # List(Int) = [8, 16, 32]
+                   }
+
+        return outputs
+
+
+if __name__=='__main__':
+    import time
+    from thop import profile
+    # Model config
+    
+    # YOLOv8-Base config
+    class Yolov8BaseConfig(object):
+        def __init__(self) -> None:
+            # ---------------- Model config ----------------
+            self.width    = 1.0
+            self.depth    = 1.0
+            self.ratio    = 1.0
+            self.reg_max  = 16
+            self.out_stride = [8, 16, 32]
+            self.max_stride = 32
+            self.num_levels = 3
+            ## Head
+
+    cfg = Yolov8BaseConfig()
+    cfg.num_classes = 20
+    cls_dim = 128
+    reg_dim = 64
+    # Build a pred layer
+    pred = Yolov8DetPredLayer(cfg, cls_dim, reg_dim)
+
+    # Inference
+    cls_feats = [torch.randn(1, cls_dim, 80, 80),
+                 torch.randn(1, cls_dim, 40, 40),
+                 torch.randn(1, cls_dim, 20, 20),]
+    reg_feats = [torch.randn(1, reg_dim, 80, 80),
+                 torch.randn(1, reg_dim, 40, 40),
+                 torch.randn(1, reg_dim, 20, 20),]
+    t0 = time.time()
+    output = pred(cls_feats, reg_feats)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    print('====== Pred output ======= ')
+    pred_cls = output["pred_cls"]
+    pred_reg = output["pred_reg"]
+    pred_box = output["pred_box"]
+    anchors  = output["anchors"]
+    
+    for level in range(cfg.num_levels):
+        print("- Level-{} : classification   -> {}".format(level, pred_cls[level].shape))
+        print("- Level-{} : delta regression -> {}".format(level, pred_reg[level].shape))
+        print("- Level-{} : bbox regression  -> {}".format(level, pred_box[level].shape))
+        print("- Level-{} : anchor boxes     -> {}".format(level, anchors[level].shape))
+
+    flops, params = profile(pred, inputs=(cls_feats, reg_feats, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))
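
Note: the frozen proj_conv above is just a fixed-weight way of taking the expectation of each side's reg_max-bin distribution. A functional sketch of the same DFL-style decode, using a hypothetical decode_dfl helper and dummy inputs, equivalent in spirit to the proj_conv path (assumes the same [B, M, 4*reg_max] layout and anchor/stride conventions; not part of the diff):

import torch
import torch.nn.functional as F

def decode_dfl(reg_pred, anchors, stride, reg_max=16):
    # reg_pred: [B, M, 4*reg_max] raw logits, anchors: [M, 2] grid-cell centers in pixels
    B, M = reg_pred.shape[:2]
    prob = F.softmax(reg_pred.view(B, M, 4, reg_max), dim=-1)   # per-side bin distribution
    bins = torch.arange(reg_max, dtype=prob.dtype)
    dist = (prob * bins).sum(dim=-1)                            # expected (l, t, r, b) in grid units
    x1y1 = anchors[None] - dist[..., :2] * stride               # top-left corner
    x2y2 = anchors[None] + dist[..., 2:] * stride               # bottom-right corner
    return torch.cat([x1y1, x2y2], dim=-1)                      # [B, M, 4] xyxy boxes

boxes = decode_dfl(torch.randn(2, 100, 64), torch.rand(100, 2) * 640, stride=8)
print(boxes.shape)   # torch.Size([2, 100, 4])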