
add YOLOv8

yjh0410 · 2 years ago
parent commit f89fa55b45

+ 6 - 6
README.md

@@ -66,13 +66,13 @@ python train.py --cuda -d voc --root path/to/VOCdevkit -v yolov1 -bs 16 --max_ep
 
 | Model  |   Backbone    | Scale |  IP  | Epoch | AP50 | FPS<sup>3090<br>FP32-bs1 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |
 |--------|---------------|-------|------|-------|------|--------------------------|-------------------|--------------------|--------|
-| YOLOv1 | ResNet-18     |  640  |  √   |  150  | 76.7 |                          |   37.8            |   21.3             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpy/yolov1_voc.pth) |
-| YOLOv2 | DarkNet-19    |  640  |  √   |  150  | 79.8 |                          |   53.9            |   30.9             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpy/yolov2_voc.pth) |
-| YOLOv3 | DarkNet-53    |  640  |  √   |  150  | 82.0 |                          |   167.4           |   54.9             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpy/yolov3_voc.pth) |
-| YOLOv4 | CSPDarkNet-53 |  640  |  √   |  150  | 83.6 |                          |   162.7           |   61.5             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpy/yolov4_voc.pth) |
+| YOLOv1 | ResNet-18     |  640  |  √   |  150  | 76.7 |                          |   37.8            |   21.3             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov1_voc.pth) |
+| YOLOv2 | DarkNet-19    |  640  |  √   |  150  | 79.8 |                          |   53.9            |   30.9             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov2_voc.pth) |
+| YOLOv3 | DarkNet-53    |  640  |  √   |  150  | 82.0 |                          |   167.4           |   54.9             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov3_voc.pth) |
+| YOLOv4 | CSPDarkNet-53 |  640  |  √   |  150  | 83.6 |                          |   162.7           |   61.5             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov4_voc.pth) |
 | YOLOv5 | CSPDarkNet-L  |  640  |  √   |  150  |      |                          |                   |                    |  |
-| YOLOX  | CSPDarkNet-L  |  640  |  √   |  150  | 84.6 |                          |                   |                    |  |
-| YOLOv7 | ELANNet       |  640  |  √   |  150  |      |                          |                   |                    |  |
+| YOLOX  | CSPDarkNet-L  |  640  |  √   |  150  | 84.6 |                          |   155.4           |   54.2             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolox_voc.pth) |
+| YOLOv7 | ELANNet       |  640  |  √   |  150  |      |                          |   144.6           |   44.0             |  |
 
 *All models are trained with ImageNet pretrained weights (IP). FLOPs are measured at a 640x640 input size on the VOC2007 test set. FPS is measured with batch size 1 on a 3090 GPU, from model inference through the NMS operation.*
 

+ 4 - 0
config/__init__.py

@@ -5,6 +5,7 @@ from .yolov3_config import yolov3_cfg
 from .yolov4_config import yolov4_cfg
 from .yolov5_config import yolov5_cfg
 from .yolov7_config import yolov7_cfg
+from .yolov8_config import yolov8_cfg
 from .yolox_config import yolox_cfg
 
 
@@ -29,6 +30,9 @@ def build_model_config(args):
     # YOLOv7
     elif args.model == 'yolov7':
         cfg = yolov7_cfg
+    # YOLOv8
+    elif args.model == 'yolov8':
+        cfg = yolov8_cfg
     # YOLOX
     elif args.model == 'yolox':
         cfg = yolox_cfg
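For readers following along, a minimal sketch of how the new branch is reached; the `SimpleNamespace` object is only a stand-in for the argparse namespace that `train.py`/`eval.py` build:

```python
from types import SimpleNamespace

from config import build_model_config

# hypothetical stand-in for the argparse namespace built in train.py / eval.py
args = SimpleNamespace(model='yolov8')
cfg = build_model_config(args)   # reaches the new 'yolov8' branch above
print(cfg['fpn'])                # 'yolov8_pafpn'
```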

+ 62 - 0
config/yolov8_config.py

@@ -0,0 +1,62 @@
+# yolov8 config
+
+yolov8_cfg = {
+    # input
+    'trans_type': 'yolov5',
+    'multi_scale': [0.5, 1.0],
+    # model
+    'backbone': 'elan_cspnet',
+    'pretrained': True,
+    'bk_act': 'silu',
+    'bk_norm': 'BN',
+    'bk_dpw': False,
+    'width': 1.0,
+    'depth': 1.0,
+    'ratio': 1.0,
+    'stride': [8, 16, 32],  # P3, P4, P5
+    # neck
+    'neck': 'sppf',
+    'expand_ratio': 0.5,
+    'pooling_size': 5,
+    'neck_act': 'silu',
+    'neck_norm': 'BN',
+    'neck_depthwise': False,
+    # fpn
+    'fpn': 'yolov8_pafpn',
+    'fpn_act': 'silu',
+    'fpn_norm': 'BN',
+    'fpn_depthwise': False,
+    # head
+    'head': 'decoupled_head',
+    'head_act': 'silu',
+    'head_norm': 'BN',
+    'num_cls_head': 2,
+    'num_reg_head': 2,
+    'head_depthwise': False,
+    'reg_max': 16,
+    # matcher
+    'matcher': {'topk': 10,
+                'alpha': 0.5,
+                'beta': 6.0},
+    # loss weight
+    'cls_loss': 'bce', # vfl (optional)
+    'loss_cls_weight': 0.5,
+    'loss_iou_weight': 7.5,
+    'loss_dfl_weight': 1.5,
+    # training configuration
+    'no_aug_epoch': 10,
+    # optimizer
+    'optimizer': 'sgd',      # optional: sgd, adamw
+    'momentum': 0.937,         # SGD: 0.937;    AdamW: invalid
+    'weight_decay': 5e-4,      # SGD: 5e-4;     AdamW: 5e-2
+    'clip_grad': 10,           # SGD: 10.0;     AdamW: -1
+    # model EMA
+    'ema_decay': 0.9999,       # SGD: 0.9999;   AdamW: 0.9998
+    'ema_tau': 2000,
+    # lr schedule
+    'scheduler': 'linear',
+    'lr0': 0.01,              # SGD: 0.01;     AdamW: 0.004
+    'lrf': 0.01,               # SGD: 0.01;     AdamW: 0.05
+    'warmup_momentum': 0.8,
+    'warmup_bias_lr': 0.1,
+}
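The `width`/`depth`/`ratio` triple is what selects the pretrained backbone preset (see `build_backbone` in `yolov8_backbone.py` below); 1.0/1.0/1.0 corresponds to the large model. A sketch of deriving a smaller variant by overriding just those three keys; the `yolov8_nano_cfg` name is hypothetical, and the preset values mirror the nano check in `build_backbone` rather than anything in this config file:

```python
from copy import deepcopy

from config.yolov8_config import yolov8_cfg

# hypothetical nano-sized variant; width/depth/ratio follow the nano preset
# recognized by build_backbone in yolov8_backbone.py
yolov8_nano_cfg = deepcopy(yolov8_cfg)
yolov8_nano_cfg.update({'width': 0.25, 'depth': 0.34, 'ratio': 2.0})
```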

+ 21 - 3
dataset/voc.py

@@ -271,22 +271,40 @@ if __name__ == "__main__":
         'hsv_v': 0.4,
         # Mosaic & Mixup
         'mosaic_prob': 1.0,
-        'mixup_prob': 0,
+        'mixup_prob': 0.15,
         'mosaic_type': 'yolov5_mosaic',
         'mixup_type': 'yolov5_mixup',
         'mixup_scale': [0.5, 1.5]
     }
+    yolox_trans_config = {
+        'aug_type': 'yolov5',
+        # Basic Augment
+        'degrees': 0.0,
+        'translate': 0.2,
+        'scale': 0.9,
+        'shear': 0.0,
+        'perspective': 0.0,
+        'hsv_h': 0.015,
+        'hsv_s': 0.7,
+        'hsv_v': 0.4,
+        # Mosaic & Mixup
+        'mosaic_prob': 1.0,
+        'mixup_prob': 1.0,
+        'mosaic_type': 'yolov5_mosaic',
+        'mixup_type': 'yolox_mixup',
+        'mixup_scale': [0.5, 1.5]
+    }
     ssd_trans_config = {
         'aug_type': 'ssd',
         'mosaic_prob': 0.0,
         'mixup_prob': 0.0
     }
-    transform = build_transform(img_size, yolov5_trans_config, is_train)
+    transform = build_transform(img_size, yolox_trans_config, is_train)
 
     dataset = VOCDetection(
         img_size=img_size,
         data_dir=args.root,
-        trans_config=yolov5_trans_config,
+        trans_config=yolox_trans_config,
         transform=transform,
         is_train=is_train
         )

+ 1 - 1
eval.py

@@ -28,7 +28,7 @@ def parse_args():
 
     # model
     parser.add_argument('-m', '--model', default='yolov1', type=str,
-                        choices=['yolov1', 'yolov2', 'yolov3', 'yolov4', 'yolov5', 'yolov7', 'yolox'], help='build yolo')
+                        choices=['yolov1', 'yolov2', 'yolov3', 'yolov4', 'yolov5', 'yolov7', 'yolov8', 'yolox'], help='build yolo')
     parser.add_argument('--weight', default=None,
                         type=str, help='Trained state_dict file path to open')
     parser.add_argument('--conf_thresh', default=0.001, type=float,

+ 6 - 1
models/__init__.py

@@ -8,6 +8,7 @@ from .yolov3.build import build_yolov3
 from .yolov4.build import build_yolov4
 from .yolov5.build import build_yolov5
 from .yolov7.build import build_yolov7
+from .yolov8.build import build_yolov8
 from .yolox.build import build_yolox
 
 
@@ -37,10 +38,14 @@ def build_model(args,
     elif args.model == 'yolov5':
         model, criterion = build_yolov5(
             args, model_cfg, device, num_classes, trainable)
-    # YOLOv5   
+    # YOLOv7
     elif args.model == 'yolov7':
         model, criterion = build_yolov7(
             args, model_cfg, device, num_classes, trainable)
+    # YOLOv8
+    elif args.model == 'yolov8':
+        model, criterion = build_yolov8(
+            args, model_cfg, device, num_classes, trainable)
     # YOLOX   
     elif args.model == 'yolox':
         model, criterion = build_yolox(

+ 30 - 0
models/yolov8/build.py

@@ -0,0 +1,30 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+from .loss import build_criterion
+from .yolov8 import YOLOv8
+
+
+# build object detector
+def build_yolov8(args, cfg, device, num_classes=80, trainable=False):
+    print('==============================')
+    print('Build {} ...'.format(args.model.upper()))
+    
+    print('==============================')
+    print('Model Configuration: \n', cfg)
+    
+    model = YOLOv8(
+        cfg=cfg,
+        device=device, 
+        num_classes=num_classes,
+        trainable=trainable,
+        conf_thresh=args.conf_thresh,
+        nms_thresh=args.nms_thresh,
+        topk=args.topk
+        )
+
+    criterion = None
+    if trainable:
+        # build criterion for training
+        criterion = build_criterion(cfg, device, num_classes)
+    return model, criterion
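A rough usage sketch, assuming the repo root is on `PYTHONPATH`; the `args` attributes are the ones `build_yolov8` reads, with placeholder values (the real values come from the training/eval argument parsers):

```python
import torch
from types import SimpleNamespace

from config import build_model_config
from models.yolov8.build import build_yolov8

# hypothetical args: attribute names match what build_yolov8 reads,
# the values here are placeholders only
args = SimpleNamespace(model='yolov8', conf_thresh=0.001, nms_thresh=0.7, topk=1000)
cfg = build_model_config(args)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model, criterion = build_yolov8(args, cfg, device, num_classes=20, trainable=True)
```

Note that `cfg['pretrained'] = True` makes the backbone try to download its ImageNet weights on first use.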

+ 288 - 0
models/yolov8/loss.py

@@ -0,0 +1,288 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .matcher import TaskAlignedAssigner
+from utils.box_ops import bbox2dist, bbox_iou
+
+
+
+class Criterion(object):
+    def __init__(self, 
+                 cfg, 
+                 device, 
+                 num_classes=80):
+        self.cfg = cfg
+        self.device = device
+        self.num_classes = num_classes
+        self.reg_max = cfg['reg_max']
+        self.use_dfl = cfg['reg_max'] > 1
+        # loss
+        self.cls_lossf = ClassificationLoss(cfg, reduction='none')
+        self.reg_lossf = RegressionLoss(num_classes, cfg['reg_max'] - 1, self.use_dfl)
+        # loss weight
+        self.loss_cls_weight = cfg['loss_cls_weight']
+        self.loss_iou_weight = cfg['loss_iou_weight']
+        self.loss_dfl_weight = cfg['loss_dfl_weight']
+        # matcher
+        matcher_config = cfg['matcher']
+        self.matcher = TaskAlignedAssigner(
+            topk=matcher_config['topk'],
+            num_classes=num_classes,
+            alpha=matcher_config['alpha'],
+            beta=matcher_config['beta']
+            )
+
+
+    def __call__(self, outputs, targets):        
+        """
+            outputs['pred_cls']: List(Tensor) [B, M, C]
+            outputs['pred_reg']: List(Tensor) [B, M, 4*reg_max]
+            outputs['pred_box']: List(Tensor) [B, M, 4]
+            outputs['anchors']: List(Tensor) [M, 2]
+            outputs['strides']: List(Int) [8, 16, 32] output stride
+            outputs['stride_tensor']: List(Tensor) [M, 1]
+            targets: (List) [dict{'boxes': [...], 
+                                 'labels': [...], 
+                                 'orig_size': ...}, ...]
+        """
+        bs = outputs['pred_cls'][0].shape[0]
+        device = outputs['pred_cls'][0].device
+        strides = outputs['stride_tensor']
+        anchors = outputs['anchors']
+        anchors = torch.cat(anchors, dim=0)
+        num_anchors = anchors.shape[0]
+
+        # preds: [B, M, C]
+        cls_preds = torch.cat(outputs['pred_cls'], dim=1)
+        reg_preds = torch.cat(outputs['pred_reg'], dim=1)
+        box_preds = torch.cat(outputs['pred_box'], dim=1)
+        
+        # label assignment
+        gt_label_targets = []
+        gt_score_targets = []
+        gt_bbox_targets = []
+        fg_masks = []
+
+        for batch_idx in range(bs):
+            tgt_labels = targets[batch_idx]["labels"].to(device)     # [Mp,]
+            tgt_boxs = targets[batch_idx]["boxes"].to(device)        # [Mp, 4]
+
+            # check target
+            if len(tgt_labels) == 0 or tgt_boxs.max().item() == 0.:
+                # There is no valid gt
+                fg_mask = cls_preds.new_zeros(1, num_anchors).bool()               #[1, M,]
+                gt_label = cls_preds.new_zeros((1, num_anchors,))                  #[1, M,]
+                gt_score = cls_preds.new_zeros((1, num_anchors, self.num_classes)) #[1, M, C]
+                gt_box = cls_preds.new_zeros((1, num_anchors, 4))                  #[1, M, 4]
+            else:
+                tgt_labels = tgt_labels[None, :, None]      # [1, Mp, 1]
+                tgt_boxs = tgt_boxs[None]                   # [1, Mp, 4]
+                (
+                    gt_label,   #[1, M,]
+                    gt_box,     #[1, M, 4]
+                    gt_score,   #[1, M, C]
+                    fg_mask     #[1, M,]
+                ) = self.matcher(
+                    pd_scores = cls_preds[batch_idx:batch_idx+1].detach().sigmoid(), 
+                    pd_bboxes = box_preds[batch_idx:batch_idx+1].detach(),
+                    anc_points = anchors,
+                    gt_labels = tgt_labels,
+                    gt_bboxes = tgt_boxs
+                    )
+            gt_label_targets.append(gt_label)
+            gt_score_targets.append(gt_score)
+            gt_bbox_targets.append(gt_box)
+            fg_masks.append(fg_mask)
+
+        # List[B, 1, M, C] -> Tensor[B, M, C] -> Tensor[BM, C]
+        fg_masks = torch.cat(fg_masks, 0).view(-1)                                    # [BM,]
+        gt_label_targets = torch.cat(gt_label_targets, 0).view(-1)                    # [BM,]
+        gt_score_targets = torch.cat(gt_score_targets, 0).view(-1, self.num_classes)  # [BM, C]
+        gt_bbox_targets = torch.cat(gt_bbox_targets, 0).view(-1, 4)                   # [BM, 4]
+        
+        # cls loss
+        cls_preds = cls_preds.view(-1, self.num_classes)
+        gt_label_targets = torch.where(
+            fg_masks > 0,
+            gt_label_targets,
+            torch.full_like(gt_label_targets, self.num_classes)
+            )
+        gt_labels_one_hot = F.one_hot(gt_label_targets.long(), self.num_classes + 1)[..., :-1]
+        loss_cls = self.cls_lossf(cls_preds, gt_score_targets, gt_labels_one_hot)
+
+        # reg loss
+        anchors = anchors[None].repeat(bs, 1, 1).view(-1, 2)                           # [BM, 2]
+        strides = torch.cat(strides, dim=0).unsqueeze(0).repeat(bs, 1, 1).view(-1, 1)  # [BM, 1]
+        bbox_weight = gt_score_targets[fg_masks].sum(-1, keepdim=True)                 # [num_fg, 1]
+        reg_preds = reg_preds.view(-1, 4*self.reg_max)                                 # [BM, 4*reg_max]
+        box_preds = box_preds.view(-1, 4)                                              # [BM, 4]
+        loss_iou, loss_dfl = self.reg_lossf(
+            pred_regs = reg_preds,
+            pred_boxs = box_preds,
+            anchors = anchors,
+            gt_boxs = gt_bbox_targets,
+            bbox_weight = bbox_weight,
+            fg_masks = fg_masks,
+            strides = strides,
+            )
+        
+        loss_cls = loss_cls.sum()
+        loss_iou = loss_iou.sum()
+        loss_dfl = loss_dfl.sum()
+        gt_score_targets_sum = gt_score_targets.sum()
+        # normalize loss
+        if gt_score_targets_sum > 0:
+            loss_cls /= gt_score_targets_sum
+            loss_iou /= gt_score_targets_sum
+            loss_dfl /= gt_score_targets_sum
+
+        # total loss
+        losses = loss_cls * self.loss_cls_weight + \
+                 loss_iou * self.loss_iou_weight
+        if self.use_dfl:
+            losses += loss_dfl * self.loss_dfl_weight
+            loss_dict = dict(
+                    loss_cls = loss_cls,
+                    loss_iou = loss_iou,
+                    loss_dfl = loss_dfl,
+                    losses = losses
+            )
+        else:
+            loss_dict = dict(
+                    loss_cls = loss_cls,
+                    loss_iou = loss_iou,
+                    losses = losses
+            )
+
+        return loss_dict
+    
+
+class ClassificationLoss(nn.Module):
+    def __init__(self, cfg, reduction='none'):
+        super(ClassificationLoss, self).__init__()
+        self.cfg = cfg
+        self.reduction = reduction
+        # For VFL
+        self.alpha = 0.75
+        self.gamma = 2.0
+
+
+    def varifocalloss(self, pred_logits, gt_score, gt_label, alpha=0.75, gamma=2.0):
+        focal_weight = alpha * pred_logits.sigmoid().pow(gamma) * (1 - gt_label) + gt_score * gt_label
+        with torch.cuda.amp.autocast(enabled=False):
+            bce_loss = F.binary_cross_entropy_with_logits(
+                pred_logits.float(), gt_score.float(), reduction='none')
+            loss = bce_loss * focal_weight
+
+            if self.reduction == 'sum':
+                loss = loss.sum()
+            elif self.reduction == 'mean':
+                loss = loss.mean()
+
+        return loss
+
+
+    def binary_cross_entropy(self, pred_logits, gt_score):
+        loss = F.binary_cross_entropy_with_logits(
+            pred_logits.float(), gt_score.float(), reduction='none')
+
+        if self.reduction == 'sum':
+            loss = loss.sum()
+        elif self.reduction == 'mean':
+            loss = loss.mean()
+
+        return loss
+
+
+    def forward(self, pred_logits, gt_score, gt_label):
+        if self.cfg['cls_loss'] == 'vfl':
+            return self.varifocalloss(pred_logits, gt_score, gt_label, self.alpha, self.gamma)
+        elif self.cfg['cls_loss'] == 'bce':
+            return self.binary_cross_entropy(pred_logits, gt_score)
+
+
+class RegressionLoss(nn.Module):
+    def __init__(self, num_classes, reg_max, use_dfl):
+        super(RegressionLoss, self).__init__()
+        self.num_classes = num_classes
+        self.reg_max = reg_max
+        self.use_dfl = use_dfl
+
+
+    def df_loss(self, pred_regs, target):
+        gt_left = target.to(torch.long)
+        gt_right = gt_left + 1
+        weight_left = gt_right.to(torch.float) - target
+        weight_right = 1 - weight_left
+        # loss left
+        loss_left = F.cross_entropy(
+            pred_regs.view(-1, self.reg_max + 1),
+            gt_left.view(-1),
+            reduction='none').view(gt_left.shape) * weight_left
+        # loss right
+        loss_right = F.cross_entropy(
+            pred_regs.view(-1, self.reg_max + 1),
+            gt_right.view(-1),
+            reduction='none').view(gt_left.shape) * weight_right
+
+        loss = (loss_left + loss_right).mean(-1, keepdim=True)
+        
+        return loss
+
+
+    def forward(self, pred_regs, pred_boxs, anchors, gt_boxs, bbox_weight, fg_masks, strides):
+        """
+        Input:
+            pred_regs: (Tensor) [BM, 4*(reg_max + 1)]
+            pred_boxs: (Tensor) [BM, 4]
+            anchors: (Tensor) [BM, 2]
+            gt_boxs: (Tensor) [BM, 4]
+            bbox_weight: (Tensor) [num_fg, 1]
+            fg_masks: (Tensor) [BM,]
+            strides: (Tensor) [BM, 1]
+        """
+        # select positive samples mask
+        num_pos = fg_masks.sum()
+
+        if num_pos > 0:
+            pred_boxs_pos = pred_boxs[fg_masks]
+            gt_boxs_pos = gt_boxs[fg_masks]
+
+            # iou loss
+            ious = bbox_iou(pred_boxs_pos,
+                            gt_boxs_pos,
+                            xywh=False,
+                            CIoU=True)
+            loss_iou = (1.0 - ious) * bbox_weight
+               
+            # dfl loss
+            if self.use_dfl:
+                pred_regs_pos = pred_regs[fg_masks]
+                gt_boxs_s = gt_boxs / strides
+                anchors_s = anchors / strides
+                gt_ltrb_s = bbox2dist(anchors_s, gt_boxs_s, self.reg_max)
+                gt_ltrb_s_pos = gt_ltrb_s[fg_masks]
+                loss_dfl = self.df_loss(pred_regs_pos, gt_ltrb_s_pos)
+                loss_dfl *= bbox_weight
+            else:
+                loss_dfl = pred_regs.sum() * 0.
+
+        else:
+            loss_iou = pred_regs.sum() * 0.
+            loss_dfl = pred_regs.sum() * 0.
+
+        return loss_iou, loss_dfl
+
+
+def build_criterion(cfg, device, num_classes):
+    criterion = Criterion(
+        cfg=cfg,
+        device=device,
+        num_classes=num_classes
+        )
+
+    return criterion
+
+
+if __name__ == "__main__":
+    pass
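As a quick sanity check of the distribution focal loss above, the target split in `df_loss` spreads a continuous distance over its two nearest integer bins with linear weights; a toy example with a made-up value:

```python
import torch

# toy check of the target split used in df_loss: a continuous distance t
# falls between two integer bins and is weighted linearly between them
t = torch.tensor([2.3])
gt_left = t.to(torch.long)           # bin 2
gt_right = gt_left + 1               # bin 3
weight_left = gt_right.float() - t   # 0.7
weight_right = 1.0 - weight_left     # 0.3
print(gt_left.item(), weight_left.item(), gt_right.item(), weight_right.item())
```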

+ 203 - 0
models/yolov8/matcher.py

@@ -0,0 +1,203 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class TaskAlignedAssigner(nn.Module):
+    def __init__(self,
+                 topk=10,
+                 num_classes=80,
+                 alpha=0.5,
+                 beta=6.0, 
+                 eps=1e-9):
+        super(TaskAlignedAssigner, self).__init__()
+        self.topk = topk
+        self.num_classes = num_classes
+        self.bg_idx = num_classes
+        self.alpha = alpha
+        self.beta = beta
+        self.eps = eps
+
+    @torch.no_grad()
+    def forward(self,
+                pd_scores,
+                pd_bboxes,
+                anc_points,
+                gt_labels,
+                gt_bboxes):
+        """This code referenced to
+           https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py
+        Args:
+            pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
+            pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
+            anc_points (Tensor): shape(num_total_anchors, 2)
+            gt_labels (Tensor): shape(bs, n_max_boxes, 1)
+            gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
+        Returns:
+            target_labels (Tensor): shape(bs, num_total_anchors)
+            target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
+            target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
+            fg_mask (Tensor): shape(bs, num_total_anchors)
+        """
+        self.bs = pd_scores.size(0)
+        self.n_max_boxes = gt_bboxes.size(1)
+
+        mask_pos, align_metric, overlaps = self.get_pos_mask(
+            pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points)
+
+        target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(
+            mask_pos, overlaps, self.n_max_boxes)
+
+        # assigned target
+        target_labels, target_bboxes, target_scores = self.get_targets(
+            gt_labels, gt_bboxes, target_gt_idx, fg_mask)
+
+        # normalize
+        align_metric *= mask_pos
+        pos_align_metrics = align_metric.max(axis=-1, keepdim=True)[0]
+        pos_overlaps = (overlaps * mask_pos).max(axis=-1, keepdim=True)[0]
+        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).max(-2)[0].unsqueeze(-1)
+        target_scores = target_scores * norm_align_metric
+
+        return target_labels, target_bboxes, target_scores, fg_mask.bool()
+
+
+    def get_pos_mask(self,
+                     pd_scores,
+                     pd_bboxes,
+                     gt_labels,
+                     gt_bboxes,
+                     anc_points):
+
+        # get anchor_align metric
+        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes)
+        # get in_gts mask
+        mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)
+        # get topk_metric mask
+        mask_topk = self.select_topk_candidates(align_metric * mask_in_gts)
+        # merge all mask to a final mask
+        mask_pos = mask_topk * mask_in_gts
+
+        return mask_pos, align_metric, overlaps
+
+
+    def get_box_metrics(self,
+                        pd_scores,
+                        pd_bboxes,
+                        gt_labels,
+                        gt_bboxes):
+
+        pd_scores = pd_scores.permute(0, 2, 1)
+        gt_labels = gt_labels.long()
+        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)
+        ind[0] = torch.arange(end=self.bs).view(-1, 1).repeat(1, self.n_max_boxes)
+        ind[1] = gt_labels.squeeze(-1)
+        bbox_scores = pd_scores[ind[0], ind[1]]
+
+        overlaps = iou_calculator(gt_bboxes, pd_bboxes)
+        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
+
+        return align_metric, overlaps
+
+
+    def select_topk_candidates(self, metrics, largest=True):
+        num_anchors = metrics.shape[-1]
+        topk_metrics, topk_idxs = torch.topk(
+            metrics, self.topk, axis=-1, largest=largest)
+        topk_mask = (topk_metrics.max(axis=-1, keepdim=True)[0] > self.eps).tile(
+            [1, 1, self.topk])
+        topk_idxs = torch.where(topk_mask, topk_idxs, torch.zeros_like(topk_idxs))
+        is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(axis=-2)
+        is_in_topk = torch.where(is_in_topk > 1,
+            torch.zeros_like(is_in_topk), is_in_topk)
+        return is_in_topk.to(metrics.dtype)
+
+
+    def get_targets(self,
+                    gt_labels,
+                    gt_bboxes,
+                    target_gt_idx,
+                    fg_mask):
+
+        # assigned target labels
+        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[...,None]
+        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes
+        target_labels = gt_labels.long().flatten()[target_gt_idx]
+
+        # assigned target boxes
+        target_bboxes = gt_bboxes.reshape([-1, 4])[target_gt_idx]
+
+        # assigned target scores
+        target_labels[target_labels<0] = 0
+        target_scores = F.one_hot(target_labels, self.num_classes)
+        fg_scores_mask  = fg_mask[:, :, None].repeat(1, 1, self.num_classes)
+        target_scores = torch.where(fg_scores_mask > 0, target_scores,
+                                        torch.full_like(target_scores, 0))
+
+        return target_labels, target_bboxes, target_scores
+    
+
+def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
+    """select the positive anchors's center in gt
+    Args:
+        xy_centers (Tensor): shape(bs*n_max_boxes, num_total_anchors, 4)
+        gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
+    Return:
+        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    """
+    n_anchors = xy_centers.size(0)
+    bs, n_max_boxes, _ = gt_bboxes.size()
+    _gt_bboxes = gt_bboxes.reshape([-1, 4])
+    xy_centers = xy_centers.unsqueeze(0).repeat(bs * n_max_boxes, 1, 1)
+    gt_bboxes_lt = _gt_bboxes[:, 0:2].unsqueeze(1).repeat(1, n_anchors, 1)
+    gt_bboxes_rb = _gt_bboxes[:, 2:4].unsqueeze(1).repeat(1, n_anchors, 1)
+    b_lt = xy_centers - gt_bboxes_lt
+    b_rb = gt_bboxes_rb - xy_centers
+    bbox_deltas = torch.cat([b_lt, b_rb], dim=-1)
+    bbox_deltas = bbox_deltas.reshape([bs, n_max_boxes, n_anchors, -1])
+    return (bbox_deltas.min(axis=-1)[0] > eps).to(gt_bboxes.dtype)
+
+
+def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
+    """if an anchor box is assigned to multiple gts,
+        the one with the highest iou will be selected.
+    Args:
+        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+        overlaps (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    Return:
+        target_gt_idx (Tensor): shape(bs, num_total_anchors)
+        fg_mask (Tensor): shape(bs, num_total_anchors)
+        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    """
+    fg_mask = mask_pos.sum(axis=-2)
+    if fg_mask.max() > 1:
+        mask_multi_gts = (fg_mask.unsqueeze(1) > 1).repeat([1, n_max_boxes, 1])
+        max_overlaps_idx = overlaps.argmax(axis=1)
+        is_max_overlaps = F.one_hot(max_overlaps_idx, n_max_boxes)
+        is_max_overlaps = is_max_overlaps.permute(0, 2, 1).to(overlaps.dtype)
+        mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos)
+        fg_mask = mask_pos.sum(axis=-2)
+    target_gt_idx = mask_pos.argmax(axis=-2)
+    return target_gt_idx, fg_mask, mask_pos
+
+
+def iou_calculator(box1, box2, eps=1e-9):
+    """Calculate iou for batch
+    Args:
+        box1 (Tensor): shape(bs, n_max_boxes, 1, 4)
+        box2 (Tensor): shape(bs, 1, num_total_anchors, 4)
+    Return:
+        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    """
+    box1 = box1.unsqueeze(2)  # [N, M1, 4] -> [N, M1, 1, 4]
+    box2 = box2.unsqueeze(1)  # [N, M2, 4] -> [N, 1, M2, 4]
+    px1y1, px2y2 = box1[:, :, :, 0:2], box1[:, :, :, 2:4]
+    gx1y1, gx2y2 = box2[:, :, :, 0:2], box2[:, :, :, 2:4]
+    x1y1 = torch.maximum(px1y1, gx1y1)
+    x2y2 = torch.minimum(px2y2, gx2y2)
+    overlap = (x2y2 - x1y1).clip(0).prod(-1)
+    area1 = (px2y2 - px1y1).clip(0).prod(-1)
+    area2 = (gx2y2 - gx1y1).clip(0).prod(-1)
+    union = area1 + area2 - overlap + eps
+
+    return overlap / union
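A toy check of `iou_calculator`, which the assigner uses in its alignment metric (`bbox_scores**alpha * overlaps**beta`); the box values are made up for illustration:

```python
import torch
from models.yolov8.matcher import iou_calculator

# toy IoU check: one gt box against two predicted boxes (xyxy format)
gt_boxes = torch.tensor([[[0., 0., 10., 10.]]])          # [bs=1, n_max_boxes=1, 4]
pd_boxes = torch.tensor([[[0., 0., 10., 10.],
                          [5., 5., 15., 15.]]])          # [bs=1, num_anchors=2, 4]
print(iou_calculator(gt_boxes, pd_boxes))                # ~[[[1.0000, 0.1429]]]
```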

+ 300 - 0
models/yolov8/yolov8.py

@@ -0,0 +1,300 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .yolov8_backbone import build_backbone
+from .yolov8_neck import build_neck
+from .yolov8_pafpn import build_fpn
+from .yolov8_head import build_head
+
+from utils.nms import multiclass_nms
+
+
+# Anchor-free YOLO
+class YOLOv8(nn.Module):
+    def __init__(self, 
+                 cfg,
+                 device, 
+                 num_classes = 20, 
+                 conf_thresh = 0.05,
+                 nms_thresh = 0.6,
+                 trainable = False, 
+                 topk = 1000):
+        super(YOLOv8, self).__init__()
+        # --------- Basic Parameters ----------
+        self.cfg = cfg
+        self.device = device
+        self.stride = cfg['stride']
+        self.reg_max = cfg['reg_max']
+        self.use_dfl = cfg['reg_max'] > 1
+        self.num_classes = num_classes
+        self.trainable = trainable
+        self.conf_thresh = conf_thresh
+        self.nms_thresh = nms_thresh
+        self.topk = topk
+        
+        # --------- Network Parameters ----------
+        self.proj_conv = nn.Conv2d(self.reg_max, 1, kernel_size=1, bias=False)
+
+        ## backbone
+        self.backbone, feats_dim = build_backbone(cfg=cfg)
+
+        ## neck
+        self.neck = build_neck(cfg=cfg, in_dim=feats_dim[-1], out_dim=feats_dim[-1])
+        feats_dim[-1] = self.neck.out_dim
+        
+        ## fpn
+        self.fpn = build_fpn(cfg=cfg, in_dims=feats_dim)
+        fpn_dims = self.fpn.out_dim
+
+        ## non-shared heads
+        self.non_shared_heads = nn.ModuleList(
+            [build_head(cfg, feat_dim, fpn_dims, num_classes) 
+            for feat_dim in fpn_dims
+            ])
+
+        ## pred
+        self.cls_preds = nn.ModuleList(
+                            [nn.Conv2d(head.cls_out_dim, self.num_classes, kernel_size=1) 
+                                for head in self.non_shared_heads
+                              ]) 
+        self.reg_preds = nn.ModuleList(
+                            [nn.Conv2d(head.reg_out_dim, 4*(cfg['reg_max']), kernel_size=1) 
+                                for head in self.non_shared_heads
+                              ])                 
+
+        # --------- Network Initialization ----------
+        # init bias
+        self.init_yolo()
+
+
+    def init_yolo(self): 
+        # Init yolo
+        for m in self.modules():
+            if isinstance(m, nn.BatchNorm2d):
+                m.eps = 1e-3
+                m.momentum = 0.03    
+        # Init bias
+        init_prob = 0.01
+        bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
+        # cls pred
+        for cls_pred in self.cls_preds:
+            b = cls_pred.bias.view(1, -1)
+            b.data.fill_(bias_value.item())
+            cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        for reg_pred in self.reg_preds:
+            b = reg_pred.bias.view(-1, )
+            b.data.fill_(1.0)
+            reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+            w = reg_pred.weight
+            w.data.fill_(0.)
+            reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
+
+        # DFL projection weights: integer bin values 0, 1, ..., reg_max - 1
+        self.proj = nn.Parameter(torch.arange(self.reg_max, dtype=torch.float), requires_grad=False)
+        self.proj_conv.weight = nn.Parameter(self.proj.view([1, self.reg_max, 1, 1]).clone().detach(),
+                                                   requires_grad=False)
+
+
+    def generate_anchors(self, level, fmp_size):
+        """
+            fmp_size: (List) [H, W]
+        """
+        # generate grid cells
+        fmp_h, fmp_w = fmp_size
+        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
+        # [H, W, 2] -> [HW, 2]
+        anchor_xy = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2) + 0.5
+        anchor_xy *= self.stride[level]
+        anchors = anchor_xy.to(self.device)
+
+        return anchors
+        
+
+    def decode_boxes(self, anchors, pred_regs, stride):
+        """
+        Input:
+            anchors:   (Tensor) [1, M, 2] or [M, 2]
+            pred_regs: (Tensor) [B, M, 4*reg_max]
+        Output:
+            pred_box: (Tensor) [B, M, 4]
+        """
+        if self.use_dfl:
+            B, M = pred_regs.shape[:2]
+            # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max] -> [B, 4, M, reg_max]
+            pred_regs = pred_regs.reshape([B, M, 4, self.reg_max])
+            # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
+            pred_regs = pred_regs.permute(0, 3, 2, 1).contiguous()
+            # [B, reg_max, 4, M] -> [B, 1, 4, M]
+            pred_regs = self.proj_conv(F.softmax(pred_regs, dim=1))
+            # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
+            pred_regs = pred_regs.view(B, 4, M).permute(0, 2, 1).contiguous()
+
+        # tlbr -> xyxy
+        pred_x1y1 = anchors - pred_regs[..., :2] * stride
+        pred_x2y2 = anchors + pred_regs[..., 2:] * stride
+        pred_box = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
+
+        return pred_box
+
+
+    def post_process(self, cls_preds, reg_preds, anchors):
+        """
+        Input:
+            cls_preds: List(Tensor) [[B, H x W, C], ...]
+            reg_preds: List(Tensor) [[B, H x W, 4*(reg_max)], ...]
+            anchors:   List(Tensor) [[H x W, 2], ...]
+        """
+        all_scores = []
+        all_labels = []
+        all_bboxes = []
+        
+        for level, (cls_pred_i, reg_pred_i, anchors_i) in enumerate(zip(cls_preds, reg_preds, anchors)):
+            # [B, M, C] -> [M, C]
+            cur_cls_pred_i = cls_pred_i[0]
+            cur_reg_pred_i = reg_pred_i[0]
+            # [MC,]
+            scores_i = cur_cls_pred_i.sigmoid().flatten()
+
+            # Keep top k top scoring indices only.
+            num_topk = min(self.topk, cur_reg_pred_i.size(0))
+
+            # torch.sort is actually faster than .topk (at least on GPUs)
+            predicted_prob, topk_idxs = scores_i.sort(descending=True)
+            scores = predicted_prob[:num_topk]
+            topk_idxs = topk_idxs[:num_topk]
+
+            anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
+            labels = topk_idxs % self.num_classes
+
+            cur_reg_pred_i = cur_reg_pred_i[anchor_idxs]
+            anchors_i = anchors_i[anchor_idxs]
+
+            # decode box: [M, 4]
+            box_pred_i = self.decode_boxes(
+                anchors_i[None], cur_reg_pred_i[None], self.stride[level])
+            bboxes = box_pred_i[0]
+
+            all_scores.append(scores)
+            all_labels.append(labels)
+            all_bboxes.append(bboxes)
+
+        scores = torch.cat(all_scores)
+        labels = torch.cat(all_labels)
+        bboxes = torch.cat(all_bboxes)
+
+        # threshold
+        keep_idxs = scores.gt(self.conf_thresh)
+        scores = scores[keep_idxs]
+        labels = labels[keep_idxs]
+        bboxes = bboxes[keep_idxs]
+
+        # to cpu & numpy
+        scores = scores.cpu().numpy()
+        labels = labels.cpu().numpy()
+        bboxes = bboxes.cpu().numpy()
+
+        # nms
+        scores, labels, bboxes = multiclass_nms(
+            scores, labels, bboxes, self.nms_thresh, self.num_classes, False)
+
+        return bboxes, scores, labels
+
+
+    @torch.no_grad()
+    def inference_single_image(self, x):
+        # backbone
+        pyramid_feats = self.backbone(x)
+
+        # neck
+        pyramid_feats[-1] = self.neck(pyramid_feats[-1])
+
+        # fpn
+        pyramid_feats = self.fpn(pyramid_feats)
+
+        # non-shared heads
+        all_cls_preds = []
+        all_reg_preds = []
+        all_anchors = []
+        for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
+            cls_feat, reg_feat = head(feat)
+
+            # pred
+            cls_pred = self.cls_preds[level](cls_feat)  # [B, C, H, W]
+            reg_pred = self.reg_preds[level](reg_feat)  # [B, 4*(reg_max), H, W]
+
+            B, _, H, W = cls_pred.size()
+            fmp_size = [H, W]
+            # [M, 2]
+            anchors = self.generate_anchors(level, fmp_size)
+
+            # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
+            cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
+            reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)
+
+            all_cls_preds.append(cls_pred)
+            all_reg_preds.append(reg_pred)
+            all_anchors.append(anchors)
+
+        # post process
+        bboxes, scores, labels = self.post_process(
+            all_cls_preds, all_reg_preds, all_anchors)
+        
+        return bboxes, scores, labels
+
+
+    def forward(self, x):
+        if not self.trainable:
+            return self.inference_single_image(x)
+        else:
+            # backbone
+            pyramid_feats = self.backbone(x)
+
+            # neck
+            pyramid_feats[-1] = self.neck(pyramid_feats[-1])
+
+            # fpn
+            pyramid_feats = self.fpn(pyramid_feats)
+
+            # non-shared heads
+            all_anchors = []
+            all_cls_preds = []
+            all_reg_preds = []
+            all_box_preds = []
+            all_strides = []
+            for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
+                cls_feat, reg_feat = head(feat)
+
+                # pred
+                cls_pred = self.cls_preds[level](cls_feat)  # [B, C, H, W]
+                reg_pred = self.reg_preds[level](reg_feat)  # [B, 4*(reg_max), H, W]
+
+                B, _, H, W = cls_pred.size()
+                fmp_size = [H, W]
+                # generate anchor boxes: [M, 2]
+                anchors = self.generate_anchors(level, fmp_size)
+                
+                # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
+                cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
+                reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)
+
+                # decode box: [B, M, 4]
+                box_pred = self.decode_boxes(anchors, reg_pred, self.stride[level])
+
+                # stride tensor: [M, 1]
+                stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride[level]
+
+                all_cls_preds.append(cls_pred)
+                all_reg_preds.append(reg_pred)
+                all_box_preds.append(box_pred)
+                all_anchors.append(anchors)
+                all_strides.append(stride_tensor)
+            
+            # output dict
+            outputs = {"pred_cls": all_cls_preds,        # List(Tensor) [B, M, C]
+                       "pred_reg": all_reg_preds,        # List(Tensor) [B, M, 4*(reg_max)]
+                       "pred_box": all_box_preds,        # List(Tensor) [B, M, 4]
+                       "anchors": all_anchors,           # List(Tensor) [M, 2]
+                       "strides": self.stride,           # List(Int) = [8, 16, 32]
+                       "stride_tensor": all_strides      # List(Tensor) [M, 1]
+                       }           
+            return outputs
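A toy reproduction of what `generate_anchors` returns: grid-cell centers shifted by 0.5 and scaled by the level stride (a 2x2 feature map at stride 8, chosen only for illustration):

```python
import torch

# toy reproduction of generate_anchors for a 2x2 feature map at stride 8
fmp_h, fmp_w, stride = 2, 2, 8
anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
anchor_xy = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2) + 0.5
print(anchor_xy * stride)   # [[4., 4.], [12., 4.], [4., 12.], [12., 12.]]
```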

+ 154 - 0
models/yolov8/yolov8_backbone.py

@@ -0,0 +1,154 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .yolov8_basic import Conv, ELAN_CSP_Block
+except:
+    from yolov8_basic import Conv, ELAN_CSP_Block
+
+
+# ---------------------------- ImageNet pretrained weights ----------------------------
+model_urls = {
+    'elan_cspnet_nano': "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elan_cspnet_nano.pth",
+    'elan_cspnet_small': None,
+    'elan_cspnet_medium': None,
+    'elan_cspnet_large': "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elan_cspnet_large.pth",
+    'elan_cspnet_huge': None,
+}
+
+
+# ---------------------------- Backbones ----------------------------
+# ELAN-CSPNet
+class ELAN_CSPNet(nn.Module):
+    def __init__(self, width=1.0, depth=1.0, ratio=1.0, act_type='silu', norm_type='BN', depthwise=False):
+        super(ELAN_CSPNet, self).__init__()
+        self.feat_dims = [int(256 * width), int(512 * width), int(512 * width * ratio)]
+        
+        # stride = 2
+        self.layer_1 =  Conv(3, int(64*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type)
+        
+        # stride = 4
+        self.layer_2 = nn.Sequential(
+            Conv(int(64*width), int(128*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
+            ELAN_CSP_Block(int(128*width), int(128*width), nblocks=int(3*depth), shortcut=True,
+                           act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+        # stride = 8
+        self.layer_3 = nn.Sequential(
+            Conv(int(128*width), int(256*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
+            ELAN_CSP_Block(int(256*width), int(256*width), nblocks=int(6*depth), shortcut=True,
+                           act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+        # stride = 16
+        self.layer_4 = nn.Sequential(
+            Conv(int(256*width), int(512*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
+            ELAN_CSP_Block(int(512*width), int(512*width), nblocks=int(6*depth), shortcut=True,
+                           act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+        # stride = 32
+        self.layer_5 = nn.Sequential(
+            Conv(int(512*width), int(512*width*ratio), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
+            ELAN_CSP_Block(int(512*width*ratio), int(512*width*ratio), nblocks=int(3*depth), shortcut=True,
+                           act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+
+
+    def forward(self, x):
+        c1 = self.layer_1(x)
+        c2 = self.layer_2(c1)
+        c3 = self.layer_3(c2)
+        c4 = self.layer_4(c3)
+        c5 = self.layer_5(c4)
+
+        outputs = [c3, c4, c5]
+
+        return outputs
+
+
+# ---------------------------- Functions ----------------------------
+## load pretrained weight
+def load_weight(model, model_name):
+    # load weight
+    print('Loading pretrained weight ...')
+    url = model_urls[model_name]
+    if url is not None:
+        checkpoint = torch.hub.load_state_dict_from_url(
+            url=url, map_location="cpu", check_hash=True)
+        # checkpoint state dict
+        checkpoint_state_dict = checkpoint.pop("model")
+        # model state dict
+        model_state_dict = model.state_dict()
+        # check
+        for k in list(checkpoint_state_dict.keys()):
+            if k in model_state_dict:
+                shape_model = tuple(model_state_dict[k].shape)
+                shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
+                if shape_model != shape_checkpoint:
+                    checkpoint_state_dict.pop(k)
+            else:
+                checkpoint_state_dict.pop(k)
+                print('Dropped checkpoint key (not in model): ', k)
+
+        model.load_state_dict(checkpoint_state_dict)
+    else:
+        print('No pretrained for {}'.format(model_name))
+
+    return model
+
+
+# build ELAN-CSPNet backbone
+def build_backbone(cfg): 
+    # model
+    backbone = ELAN_CSPNet(
+        width=cfg['width'],
+        depth=cfg['depth'],
+        ratio=cfg['ratio'],
+        act_type=cfg['bk_act'],
+        norm_type=cfg['bk_norm'],
+        depthwise=cfg['bk_dpw']
+        )
+        
+    # check whether to load imagenet pretrained weight
+    if cfg['pretrained']:
+        if cfg['width'] == 0.25 and cfg['depth'] == 0.34 and cfg['ratio'] == 2.0:
+            backbone = load_weight(backbone, model_name='elan_cspnet_nano')
+        elif cfg['width'] == 0.5 and cfg['depth'] == 0.34 and cfg['ratio'] == 2.0:
+            backbone = load_weight(backbone, model_name='elan_cspnet_small')
+        elif cfg['width'] == 0.75 and cfg['depth'] == 0.67 and cfg['ratio'] == 1.5:
+            backbone = load_weight(backbone, model_name='elan_cspnet_medium')
+        elif cfg['width'] == 1.0 and cfg['depth'] == 1.0 and cfg['ratio'] == 1.0:
+            backbone = load_weight(backbone, model_name='elan_cspnet_large')
+        elif cfg['width'] == 1.25 and cfg['depth'] == 1.34 and cfg['ratio'] == 1.0:
+            backbone = load_weight(backbone, model_name='elan_cspnet_huge')
+    feat_dims = backbone.feat_dims
+
+    return backbone, feat_dims
+
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    cfg = {
+        'pretrained': True,
+        'bk_act': 'silu',
+        'bk_norm': 'BN',
+        'bk_dpw': False,
+        'width': 1.0,
+        'depth': 1.0,
+        'ratio': 1.0,
+    }
+    model, feats = build_backbone(cfg)
+    x = torch.randn(1, 3, 640, 640)
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    for out in outputs:
+        print(out.shape)
+
+    x = torch.randn(1, 3, 640, 640)
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))

+ 132 - 0
models/yolov8/yolov8_basic.py

@@ -0,0 +1,132 @@
+import torch
+import torch.nn as nn
+
+
+class SiLU(nn.Module):
+    """export-friendly version of nn.SiLU()"""
+
+    @staticmethod
+    def forward(x):
+        return x * torch.sigmoid(x)
+
+
+def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
+    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
+
+    return conv
+
+
+def get_activation(act_type=None):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+
+
+def get_norm(norm_type, dim):
+    if norm_type == 'BN':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'GN':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+
+
+# Basic conv layer
+class Conv(nn.Module):
+    def __init__(self, 
+                 c1,                   # in channels
+                 c2,                   # out channels 
+                 k=1,                  # kernel size 
+                 p=0,                  # padding
+                 s=1,                  # stride
+                 d=1,                  # dilation
+                 act_type='silu',      # activation
+                 norm_type='BN',       # normalization
+                 depthwise=False):
+        super(Conv, self).__init__()
+        convs = []
+        add_bias = False if norm_type else True
+        if depthwise:
+            convs.append(get_conv2d(c1, c1, k=k, p=p, s=s, d=d, g=c1, bias=add_bias))
+            # depthwise conv
+            if norm_type:
+                convs.append(get_norm(norm_type, c1))
+            if act_type:
+                convs.append(get_activation(act_type))
+            # pointwise conv
+            convs.append(get_conv2d(c1, c2, k=1, p=0, s=1, d=d, g=1, bias=add_bias))
+            if norm_type:
+                convs.append(get_norm(norm_type, c2))
+            if act_type:
+                convs.append(get_activation(act_type))
+
+        else:
+            convs.append(get_conv2d(c1, c2, k=k, p=p, s=s, d=d, g=1, bias=add_bias))
+            if norm_type:
+                convs.append(get_norm(norm_type, c2))
+            if act_type:
+                convs.append(get_activation(act_type))
+            
+        self.convs = nn.Sequential(*convs)
+
+
+    def forward(self, x):
+        return self.convs(x)
+
+
+# BottleNeck
+class Bottleneck(nn.Module):
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 expand_ratio=0.5,
+                 shortcut=False,
+                 depthwise=False,
+                 act_type='silu',
+                 norm_type='BN'):
+        super(Bottleneck, self).__init__()
+        inter_dim = int(out_dim * expand_ratio)  # hidden channels            
+        self.cv1 = Conv(in_dim, inter_dim, k=3, p=1, norm_type=norm_type, act_type=act_type, depthwise=depthwise)
+        self.cv2 = Conv(inter_dim, out_dim, k=3, p=1, norm_type=norm_type, act_type=act_type, depthwise=depthwise)
+        self.shortcut = shortcut and in_dim == out_dim
+
+    def forward(self, x):
+        h = self.cv2(self.cv1(x))
+
+        return x + h if self.shortcut else h
+
+
+# ELAN-CSP-Block
+class ELAN_CSP_Block(nn.Module):
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 expand_ratio=0.5,
+                 nblocks=1,
+                 shortcut=False,
+                 depthwise=False,
+                 act_type='silu',
+                 norm_type='BN'):
+        super(ELAN_CSP_Block, self).__init__()
+        inter_dim = int(out_dim * expand_ratio)
+        self.cv1 = Conv(in_dim, inter_dim, k=1, norm_type=norm_type, act_type=act_type)
+        self.cv2 = Conv(in_dim, inter_dim, k=1, norm_type=norm_type, act_type=act_type)
+        self.m = nn.Sequential(*(
+            Bottleneck(inter_dim, inter_dim, 1.0, shortcut, depthwise, act_type, norm_type)
+            for _ in range(nblocks)))
+        self.cv3 = Conv((2 + nblocks) * inter_dim, out_dim, k=1, act_type=act_type, norm_type=norm_type)
+
+
+    def forward(self, x):
+        x1 = self.cv1(x)
+        x2 = self.cv2(x)
+        out = [x1, x2]
+
+        out.extend(m(out[-1]) for m in self.m)
+
+        out = self.cv3(torch.cat(out, dim=1))
+
+        return out

+ 137 - 0
models/yolov8/yolov8_head.py

@@ -0,0 +1,137 @@
+import torch
+import torch.nn as nn
+try:
+    from .yolov8_basic import Conv
+except:
+    from yolov8_basic import Conv
+
+
+class DecoupledHead(nn.Module):
+    def __init__(self, cfg, in_dim, fpn_dims, num_classes=80):
+        super().__init__()
+        print('==============================')
+        print('Head: Decoupled Head')
+        self.in_dim = in_dim
+        self.num_cls_head=cfg['num_cls_head']
+        self.num_reg_head=cfg['num_reg_head']
+        self.act_type=cfg['head_act']
+        self.norm_type=cfg['head_norm']
+
+        # cls head
+        cls_feats = []
+        self.cls_out_dim = max(fpn_dims[0], num_classes)
+        for i in range(cfg['num_cls_head']):
+            if i == 0:
+                cls_feats.append(
+                    Conv(in_dim, self.cls_out_dim, k=3, p=1, s=1, 
+                        act_type=self.act_type,
+                        norm_type=self.norm_type,
+                        depthwise=cfg['head_depthwise'])
+                        )
+            else:
+                cls_feats.append(
+                    Conv(self.cls_out_dim, self.cls_out_dim, k=3, p=1, s=1, 
+                        act_type=self.act_type,
+                        norm_type=self.norm_type,
+                        depthwise=cfg['head_depthwise'])
+                        )
+                
+        # reg head
+        reg_feats = []
+        self.reg_out_dim = max(16, fpn_dims[0]//4, 4*cfg['reg_max'])
+        for i in range(cfg['num_reg_head']):
+            if i == 0:
+                reg_feats.append(
+                    Conv(in_dim, self.reg_out_dim, k=3, p=1, s=1, 
+                        act_type=self.act_type,
+                        norm_type=self.norm_type,
+                        depthwise=cfg['head_depthwise'])
+                        )
+            else:
+                reg_feats.append(
+                    Conv(self.reg_out_dim, self.reg_out_dim, k=3, p=1, s=1, 
+                        act_type=self.act_type,
+                        norm_type=self.norm_type,
+                        depthwise=cfg['head_depthwise'])
+                        )
+
+        self.cls_feats = nn.Sequential(*cls_feats)
+        self.reg_feats = nn.Sequential(*reg_feats)
+
+
+    def forward(self, x):
+        """
+            in_feats: (Tensor) [B, C, H, W]
+        """
+        cls_feats = self.cls_feats(x)
+        reg_feats = self.reg_feats(x)
+
+        return cls_feats, reg_feats
+    
+
+# build detection head
+def build_head(cfg, in_dim, fpn_dims, num_classes=80):
+    head = DecoupledHead(cfg, in_dim, fpn_dims, num_classes)
+
+    return head
+
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    cfg = {
+        'num_cls_head': 2,
+        'num_reg_head': 2,
+        'head_act': 'silu',
+        'head_norm': 'BN',
+        'head_depthwise': False,
+        'reg_max': 16,
+    }
+    fpn_dims = [256, 512, 512]
+    # Head-1
+    model = build_head(cfg, 256, fpn_dims, num_classes=80)
+    x = torch.randn(1, 256, 80, 80)
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    # for out in outputs:
+    #     print(out.shape)
+
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('Head-1: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Head-1: Params : {:.2f} M'.format(params / 1e6))
+
+    # Head-2
+    model = build_head(cfg, 512, fpn_dims, num_classes=80)
+    x = torch.randn(1, 512, 40, 40)
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    # for out in outputs:
+    #     print(out.shape)
+
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('Head-2: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Head-2: Params : {:.2f} M'.format(params / 1e6))
+
+    # Head-3
+    model = build_head(cfg, 512, fpn_dims, num_classes=80)
+    x = torch.randn(1, 512, 20, 20)
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    # for out in outputs:
+    #     print(out.shape)
+
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('Head-3: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Head-3: Params : {:.2f} M'.format(params / 1e6))

+ 98 - 0
models/yolov8/yolov8_neck.py

@@ -0,0 +1,98 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .yolov8_basic import Conv
+except:
+    from yolov8_basic import Conv
+
+# Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
+class SPPF(nn.Module):
+    def __init__(self, in_dim, out_dim, expand_ratio=0.5, pooling_size=5, act_type='', norm_type=''):
+        super().__init__()
+        inter_dim = int(in_dim * expand_ratio)
+        self.out_dim = out_dim
+        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.cv2 = Conv(inter_dim * 4, out_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.m = nn.MaxPool2d(kernel_size=pooling_size, stride=1, padding=pooling_size // 2)
+
+    def forward(self, x):
+        x = self.cv1(x)
+        y1 = self.m(x)
+        y2 = self.m(y1)
+
+        return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
+
+
+# SPPF block with CSP module
+class SPPFBlockCSP(nn.Module):
+    """
+        CSP Spatial Pyramid Pooling Block
+    """
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 expand_ratio=0.5,
+                 pooling_size=5,
+                 act_type='lrelu',
+                 norm_type='BN',
+                 depthwise=False
+                 ):
+        super(SPPFBlockCSP, self).__init__()
+        inter_dim = int(in_dim * expand_ratio)
+        self.out_dim = out_dim
+        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.cv2 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.m = nn.Sequential(
+            Conv(inter_dim, inter_dim, k=3, p=1, 
+                 act_type=act_type, norm_type=norm_type, 
+                 depthwise=depthwise),
+            SPPF(inter_dim, 
+                 inter_dim, 
+                 expand_ratio=1.0, 
+                 pooling_size=pooling_size, 
+                 act_type=act_type, 
+                 norm_type=norm_type),
+            Conv(inter_dim, inter_dim, k=3, p=1, 
+                 act_type=act_type, norm_type=norm_type, 
+                 depthwise=depthwise)
+        )
+        self.cv3 = Conv(inter_dim * 2, self.out_dim, k=1, act_type=act_type, norm_type=norm_type)
+
+        
+    def forward(self, x):
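+        # CSP-style split: cv1 provides the shortcut branch, cv2 -> m runs the
+        # Conv / SPPF / Conv branch; cv3 fuses the concatenated result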
+        x1 = self.cv1(x)
+        x2 = self.cv2(x)
+        x3 = self.m(x2)
+        y = self.cv3(torch.cat([x1, x3], dim=1))
+
+        return y
+
+
+def build_neck(cfg, in_dim, out_dim):
+    model = cfg['neck']
+    print('==============================')
+    print('Neck: {}'.format(model))
+    # build neck
+    if model == 'sppf':
+        neck = SPPF(
+            in_dim=in_dim,
+            out_dim=out_dim,
+            expand_ratio=cfg['expand_ratio'], 
+            pooling_size=cfg['pooling_size'],
+            act_type=cfg['neck_act'],
+            norm_type=cfg['neck_norm']
+            )
+    elif model == 'sppf_block_csp':
+        neck = SPPFBlockCSP(
+            in_dim=in_dim,
+            out_dim=out_dim,
+            expand_ratio=cfg['expand_ratio'], 
+            pooling_size=cfg['pooling_size'],
+            act_type=cfg['neck_act'],
+            norm_type=cfg['neck_norm'],
+            depthwise=cfg['neck_depthwise']
+            )
+
+    return neck
+    
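A quick usage sketch for the neck (not part of the commit): the config keys mirror the ones read by build_neck above; the import path and the 512-channel, 20x20 input are assumptions standing in for a C5 feature map.

    import torch
    from yolov8_neck import build_neck   # assumes running from models/yolov8/

    cfg = {
        'neck': 'sppf',
        'expand_ratio': 0.5,
        'pooling_size': 5,
        'neck_act': 'silu',
        'neck_norm': 'BN',
    }
    neck = build_neck(cfg, in_dim=512, out_dim=512)
    out = neck(torch.randn(1, 512, 20, 20))
    print(out.shape)   # torch.Size([1, 512, 20, 20]) -- SPPF keeps the spatial size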

+ 150 - 0
models/yolov8/yolov8_pafpn.py

@@ -0,0 +1,150 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+try:
+    from .yolov8_basic import Conv, ELAN_CSP_Block
+except ImportError:
+    from yolov8_basic import Conv, ELAN_CSP_Block
+
+
+# PaFPN-ELAN
+class Yolov8PaFPN(nn.Module):
+    def __init__(self, 
+                 in_dims=[256, 512, 512],
+                 width=1.0,
+                 depth=1.0,
+                 ratio=1.0,
+                 act_type='silu',
+                 norm_type='BN',
+                 depthwise=False):
+        super(Yolov8PaFPN, self).__init__()
+        print('==============================')
+        print('FPN: {}'.format("ELAN_PaFPN"))
+        self.in_dims = in_dims
+        self.width = width
+        self.depth = depth
+        c3, c4, c5 = in_dims
+
+        # top down
+        ## P5 -> P4
+        self.head_elan_1 = ELAN_CSP_Block(in_dim=c5 + c4,
+                                          out_dim=int(512*width),
+                                          expand_ratio=0.5,
+                                          nblocks=int(3*depth),
+                                          shortcut=False,
+                                          depthwise=depthwise,
+                                          norm_type=norm_type,
+                                          act_type=act_type
+                                          )
+
+        # P4 -> P3
+        self.head_elan_2 = ELAN_CSP_Block(in_dim=int(512*width) + c3,
+                                          out_dim=int(256*width),
+                                          expand_ratio=0.5,
+                                          nblocks=int(3*depth),
+                                          shortcut=False,
+                                          depthwise=depthwise,
+                                          norm_type=norm_type,
+                                          act_type=act_type
+                                          )
+
+
+        # bottom up
+        # P3 -> P4
+        self.mp1 = Conv(int(256*width), int(256*width), k=3, p=1, s=2,
+                        act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.head_elan_3 = ELAN_CSP_Block(in_dim=int(256*width) + int(512*width),
+                                          out_dim=int(512*width),
+                                          expand_ratio=0.5,
+                                          nblocks=int(3*depth),
+                                          shortcut=False,
+                                          depthwise=depthwise,
+                                          norm_type=norm_type,
+                                          act_type=act_type
+                                          )
+
+        # P4 -> P5
+        self.mp2 = Conv(int(512 * width), int(512 * width), k=3, p=1, s=2,
+                        act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.head_elan_4 = ELAN_CSP_Block(in_dim=int(512 * width) + c5,
+                                          out_dim=int(512 * width * ratio),
+                                          expand_ratio=0.5,
+                                          nblocks=int(3*depth),
+                                          shortcut=False,
+                                          depthwise=depthwise,
+                                          norm_type=norm_type,
+                                          act_type=act_type
+                                          )
+
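+        # only the deepest (P5) output is widened by 'ratio'; P3/P4 follow 'width' alone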
+        self.out_dim = [int(256 * width), int(512 * width), int(512 * width * ratio)]
+
+
+    def forward(self, features):
+        c3, c4, c5 = features
+
+        # Top down
+        ## P5 -> P4
+        c6 = F.interpolate(c5, scale_factor=2.0)
+        c7 = torch.cat([c6, c4], dim=1)
+        c8 = self.head_elan_1(c7)
+        ## P4 -> P3
+        c9 = F.interpolate(c8, scale_factor=2.0)
+        c10 = torch.cat([c9, c3], dim=1)
+        c11 = self.head_elan_2(c10)
+
+        # Bottom up
+        # P3 -> P4
+        c12 = self.mp1(c11)
+        c13 = torch.cat([c12, c8], dim=1)
+        c14 = self.head_elan_3(c13)
+        # P4 -> P5
+        c15 = self.mp2(c14)
+        c16 = torch.cat([c15, c5], dim=1)
+        c17 = self.head_elan_4(c16)
+
+        out_feats = [c11, c14, c17] # [P3, P4, P5]
+        
+        return out_feats
+
+
+def build_fpn(cfg, in_dims):
+    model = cfg['fpn']
+    # build neck
+    if model == 'yolov8_pafpn':
+        fpn_net = Yolov8PaFPN(in_dims=in_dims,
+                             width=cfg['width'],
+                             depth=cfg['depth'],
+                             ratio=cfg['ratio'],
+                             act_type=cfg['fpn_act'],
+                             norm_type=cfg['fpn_norm'],
+                             depthwise=cfg['fpn_depthwise']
+                             )
+    return fpn_net
+
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    cfg = {
+        'fpn': 'yolov8_pafpn',
+        'fpn_act': 'silu',
+        'fpn_norm': 'BN',
+        'fpn_depthwise': False,
+        'width': 1.0,
+        'depth': 1.0,
+        'ratio': 1.0,
+    }
+    model = build_fpn(cfg, in_dims=[256, 512, 512])
+    pyramid_feats = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 512, 20, 20)]
+    t0 = time.time()
+    outputs = model(pyramid_feats)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    for out in outputs:
+        print(out.shape)
+
+    print('==============================')
+    flops, params = profile(model, inputs=(pyramid_feats, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))

+ 1 - 1
models/yolox/yolox.py

@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 
 from .yolox_backbone import build_backbone
-from .yolox_pafpn import build_fpn
+from .yolox_fpn import build_fpn
 from .yolox_head import build_head
 
 from utils.nms import multiclass_nms

+ 0 - 0
models/yolox/yolox_pafpn.py → models/yolox/yolox_fpn.py


+ 1 - 1
test.py

@@ -40,7 +40,7 @@ def parse_args():
 
     # model
     parser.add_argument('-m', '--model', default='yolov1', type=str,
-                        choices=['yolov1', 'yolov2', 'yolov3', 'yolov4', 'yolov5', 'yolov7', 'yolox'], help='build yolo')
+                        choices=['yolov1', 'yolov2', 'yolov3', 'yolov4', 'yolov5', 'yolov7', 'yolov8', 'yolox'], help='build yolo')
     parser.add_argument('--weight', default=None,
                         type=str, help='Trained state_dict file path to open')
     parser.add_argument('-ct', '--conf_thresh', default=0.1, type=float,

+ 1 - 1
train.py

@@ -56,7 +56,7 @@ def parse_args():
 
     # model
     parser.add_argument('-m', '--model', default='yolov1', type=str,
-                        choices=['yolov1', 'yolov2', 'yolov3', 'yolov4', 'yolov5', 'yolov7', 'yolox'], help='build yolo')
+                        choices=['yolov1', 'yolov2', 'yolov3', 'yolov4', 'yolov5', 'yolov7', 'yolov8', 'yolox'], help='build yolo')
     parser.add_argument('-ct', '--conf_thresh', default=0.005, type=float,
                         help='confidence threshold')
     parser.add_argument('-nt', '--nms_thresh', default=0.6, type=float,

+ 52 - 0
utils/box_ops.py

@@ -1,4 +1,5 @@
 import torch
+import math
 import numpy as np
 from torchvision.ops.boxes import box_area
 
@@ -97,6 +98,57 @@ def rescale_bboxes(bboxes, origin_img_size, cur_img_size, deltas=None):
     return bboxes
 
 
+def bbox2dist(anchor_points, bbox, reg_max):
+    '''Transform bbox(xyxy) to dist(ltrb).'''
+    x1y1, x2y2 = torch.split(bbox, 2, -1)
+    lt = anchor_points - x1y1
+    rb = x2y2 - anchor_points
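+    # the clamp keeps each distance just below reg_max so the DFL target bins stay in range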
+    dist = torch.cat([lt, rb], -1).clamp(0, reg_max - 0.01)
+    return dist
+
+
+# copy from YOLOv5
+def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
+    # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)
+
+    # Get the coordinates of bounding boxes
+    if xywh:  # transform from xywh to xyxy
+        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
+        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
+        b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
+        b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
+    else:  # x1, y1, x2, y2 = box1
+        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
+        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
+        w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
+        w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
+
+    # Intersection area
+    inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * \
+            (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(0)
+
+    # Union Area
+    union = w1 * h1 + w2 * h2 - inter + eps
+
+    # IoU
+    iou = inter / union
+    if CIoU or DIoU or GIoU:
+        cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) width
+        ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex height
+        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
+            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
+            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center dist ** 2
+            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
+                v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
+                with torch.no_grad():
+                    alpha = v / (v - iou + (1 + eps))
+                return iou - (rho2 / c2 + v * alpha)  # CIoU
+            return iou - rho2 / c2  # DIoU
+        c_area = cw * ch + eps  # convex area
+        return iou - (c_area - union) / c_area  # GIoU https://arxiv.org/pdf/1902.09630.pdf
+    return iou  # IoU
+
+
 if __name__ == '__main__':
     box1 = torch.tensor([[10, 10, 20, 20]])
     box2 = torch.tensor([[15, 15, 20, 20]])
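
A short round-trip sketch for the new helpers (not part of the commit; dist2bbox below is a hypothetical inverse written only for illustration):

    import torch
    from utils.box_ops import bbox2dist, bbox_iou

    anchor_points = torch.tensor([[8.0, 8.0]])             # one anchor center
    gt_box = torch.tensor([[2.0, 3.0, 14.0, 13.0]])        # xyxy ground truth
    ltrb = bbox2dist(anchor_points, gt_box, reg_max=16)    # tensor([[6., 5., 6., 5.]])

    def dist2bbox(anchor_points, dist):
        # hypothetical inverse of bbox2dist: ltrb distances back to an xyxy box
        lt, rb = torch.split(dist, 2, -1)
        return torch.cat([anchor_points - lt, anchor_points + rb], -1)

    pred_box = dist2bbox(anchor_points, ltrb)
    print(bbox_iou(gt_box, pred_box, xywh=False, CIoU=True))  # ~1.0 for a perfect box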