
modify YOLOv8's matcher & aligner

yjh0410, 2 years ago
commit 0aedff165b
8 changed files with 373 additions and 212 deletions
  1. README.md (+2, -2)
  2. README_CN.md (+2, -2)
  3. config/__init__.py (+2, -2)
  4. config/yolov8_config.py (+300, -58)
  5. models/__init__.py (+1, -1)
  6. models/yolov8/loss.py (+10, -40)
  7. models/yolov8/matcher.py (+54, -53)
  8. models/yolov8/yolov8_backbone.py (+2, -54)

+ 2 - 2
README.md

@@ -103,8 +103,8 @@ python train.py --cuda -d coco --root path/to/COCO -v yolov1 -bs 16 --max_epoch
 | YOLOv1       | ResNet-18     |  640  |  √   |  150  |       |        27.9            |       47.5        | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov1_coco.pth) |
 | YOLOv2       | DarkNet-19    |  640  |  √   |  150  |       |        32.7            |       50.9        | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov2_coco.pth) |
 | YOLOv3       | DarkNet-53    |  640  |  √   |  250  |       |        42.9            |       63.5        | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov3_coco.pth) |
-| YOLOv4       | CSPDarkNet-53 |  640  |  √   |  250  |       |                        |                   |  |
-| YOLOv5       | CSPDarkNet-L  |  640  |  √   |  250  |       |        46.6            |       65.8        | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov4_coco.pth) |
+| YOLOv4       | CSPDarkNet-L  |  640  |  √   |  250  |       |        46.6            |       65.8        | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov4_coco.pth) |
+| YOLOv5       | CSPDarkNet-53 |  640  |  √   |  250  |       |                        |                   |  |
 | YOLOX        | CSPDarkNet-L  |  640  |  √   |  300  |       |        46.6            |       66.1        | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolox_coco.pth) |
 | YOLOv7-Nano  | ELANNet-Nano  |  640  |  √   |  300  |       |                        |                   |  |
 | YOLOv7-Tiny  | ELANNet-Tiny  |  640  |  √   |  300  |       |                        |                   |  |

+ 2 - 2
README_CN.md

@@ -106,8 +106,8 @@ python train.py --cuda -d coco --root path/to/COCO -v yolov1 -bs 16 --max_epoch
 | YOLOv1       | ResNet-18     |  640  |  √   |  150  |       |        27.9            |       47.5        | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov1_coco.pth) |
 | YOLOv2       | DarkNet-19    |  640  |  √   |  150  |       |        32.7            |       50.9        | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov2_coco.pth) |
 | YOLOv3       | DarkNet-53    |  640  |  √   |  250  |       |        42.9            |       63.5        | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov3_coco.pth) |
-| YOLOv4       | CSPDarkNet-53 |  640  |  √   |  250  |       |                        |                   |  |
-| YOLOv5       | CSPDarkNet-L  |  640  |  √   |  250  |       |        46.6            |       65.8        | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov4_coco.pth) |
+| YOLOv4       | CSPDarkNet-L  |  640  |  √   |  250  |       |        46.6            |       65.8        | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov4_coco.pth) |
+| YOLOv5       | CSPDarkNet-53 |  640  |  √   |  250  |       |                        |                   |  |
 | YOLOX        | CSPDarkNet-L  |  640  |  √   |  300  |       |        46.6            |       66.1        | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolox_coco.pth) |
 | YOLOv7-Nano  | ELANNet-Nano  |  640  |  √   |  300  |       |                        |                   |  |
 | YOLOv7-Tiny  | ELANNet-Tiny  |  640  |  √   |  300  |       |                        |                   |  |

+ 2 - 2
config/__init__.py

@@ -31,8 +31,8 @@ def build_model_config(args):
     elif args.model in ['yolov7_nano', 'yolov7_tiny', 'yolov7_large', 'yolov7_huge']:
         cfg = yolov7_cfg[args.model]
     # YOLOv8
-    elif args.model == 'yolov8':
-        cfg = yolov8_cfg
+    elif args.model in ['yolov8_nano', 'yolov8_small', 'yolov8_medium', 'yolov8_large', 'yolov8_huge']:
+        cfg = yolov8_cfg[args.model]
     # YOLOX
     elif args.model == 'yolox':
         cfg = yolox_cfg
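
With this commit each YOLOv8 size has its own entry in yolov8_cfg, keyed by the model name, so build_model_config indexes into the dict instead of returning a single config. A minimal sketch of that lookup pattern (SimpleNamespace stands in for the argparse namespace and is not from the repo):

    # Minimal sketch of the new per-size config lookup; `SimpleNamespace`
    # is a stand-in for the argparse namespace build_model_config receives.
    from types import SimpleNamespace

    yolov8_cfg = {
        'yolov8_nano':  {'width': 0.25, 'depth': 0.34, 'ratio': 2.0},
        'yolov8_small': {'width': 0.50, 'depth': 0.34, 'ratio': 2.0},
        # ... medium / large / huge follow the same pattern
    }

    args = SimpleNamespace(model='yolov8_small')
    if args.model in ['yolov8_nano', 'yolov8_small', 'yolov8_medium', 'yolov8_large', 'yolov8_huge']:
        cfg = yolov8_cfg[args.model]      # size-specific config dict
    print(cfg['width'], cfg['depth'])     # 0.5 0.34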

+ 300 - 58
config/yolov8_config.py

@@ -1,62 +1,304 @@
 # yolov8 config
 
 yolov8_cfg = {
-    # input
-    'trans_type': 'yolov5_strong',
-    'multi_scale': [0.5, 1.5],   # 320 -> 960
-    # model
-    'backbone': 'elan_cspnet',
-    'pretrained': True,
-    'bk_act': 'silu',
-    'bk_norm': 'BN',
-    'bk_dpw': False,
-    'width': 1.0,
-    'depth': 1.0,
-    'ratio': 1.0,
-    'stride': [8, 16, 32],  # P3, P4, P5
-    # neck
-    'neck': 'sppf',
-    'expand_ratio': 0.5,
-    'pooling_size': 5,
-    'neck_act': 'silu',
-    'neck_norm': 'BN',
-    'neck_depthwise': False,
-    # fpn
-    'fpn': 'yolov8_pafpn',
-    'fpn_act': 'silu',
-    'fpn_norm': 'BN',
-    'fpn_depthwise': False,
-    # head
-    'head': 'decoupled_head',
-    'head_act': 'silu',
-    'head_norm': 'BN',
-    'num_cls_head': 2,
-    'num_reg_head': 2,
-    'head_depthwise': False,
-    'reg_max': 16,
-    # matcher
-    'matcher': {'topk': 10,
-                'alpha': 0.5,
-                'beta': 6.0},
-    # loss weight
-    'cls_loss': 'bce', # vfl (optional)
-    'loss_cls_weight': 0.5,
-    'loss_iou_weight': 7.5,
-    'loss_dfl_weight': 1.5,
-    # training configuration
-    'no_aug_epoch': 10,
-    # optimizer
-    'optimizer': 'sgd',        # optional: sgd, adamw
-    'momentum': 0.937,         # SGD: 0.937;    AdamW: invalid
-    'weight_decay': 5e-4,      # SGD: 5e-4;     AdamW: 5e-2
-    'clip_grad': 10,           # SGD: 10.0;     AdamW: -1
-    # model EMA
-    'ema_decay': 0.9999,       # SGD: 0.9999;   AdamW: 0.9998
-    'ema_tau': 2000,
-    # lr schedule
-    'scheduler': 'linear',
-    'lr0': 0.01,              # SGD: 0.01;     AdamW: 0.004
-    'lrf': 0.01,               # SGD: 0.01;     AdamW: 0.05
-    'warmup_momentum': 0.8,
-    'warmup_bias_lr': 0.1,
+    'yolov8_nano':{
+        # input
+        'trans_type': 'yolov5_weak',
+        'multi_scale': [0.5, 1.5],   # 320 -> 960
+        # model
+        'backbone': 'elan_cspnet',
+        'bk_act': 'silu',
+        'bk_norm': 'BN',
+        'bk_dpw': False,
+        'width': 0.25,
+        'depth': 0.34,
+        'ratio': 2.0,
+        'stride': [8, 16, 32],  # P3, P4, P5
+        # neck
+        'neck': 'sppf',
+        'expand_ratio': 0.5,
+        'pooling_size': 5,
+        'neck_act': 'silu',
+        'neck_norm': 'BN',
+        'neck_depthwise': False,
+        # fpn
+        'fpn': 'yolov8_pafpn',
+        'fpn_act': 'silu',
+        'fpn_norm': 'BN',
+        'fpn_depthwise': False,
+        # head
+        'head': 'decoupled_head',
+        'head_act': 'silu',
+        'head_norm': 'BN',
+        'num_cls_head': 2,
+        'num_reg_head': 2,
+        'head_depthwise': False,
+        'reg_max': 16,
+        # matcher
+        'matcher': {'topk': 10,
+                    'alpha': 0.5,
+                    'beta': 6.0},
+        # loss weight
+        'cls_loss': 'bce', # vfl (optional)
+        'loss_cls_weight': 0.5,
+        'loss_iou_weight': 7.5,
+        'loss_dfl_weight': 1.5,
+        # training configuration
+        'no_aug_epoch': 10,
+        # optimizer
+        'optimizer': 'sgd',        # optional: sgd, adamw
+        'momentum': 0.937,         # SGD: 0.937;    AdamW: invalid
+        'weight_decay': 5e-4,      # SGD: 5e-4;     AdamW: 5e-2
+        'clip_grad': 10,           # SGD: 10.0;     AdamW: -1
+        # model EMA
+        'ema_decay': 0.9999,       # SGD: 0.9999;   AdamW: 0.9998
+        'ema_tau': 2000,
+        # lr schedule
+        'scheduler': 'linear',
+        'lr0': 0.01,              # SGD: 0.01;     AdamW: 0.004
+        'lrf': 0.01,               # SGD: 0.01;     AdamW: 0.05
+        'warmup_momentum': 0.8,
+        'warmup_bias_lr': 0.1,
+    },
+
+    'yolov8_small':{
+        # input
+        'trans_type': 'yolov5_strong',
+        'multi_scale': [0.5, 1.5],   # 320 -> 960
+        # model
+        'backbone': 'elan_cspnet',
+        'bk_act': 'silu',
+        'bk_norm': 'BN',
+        'bk_dpw': False,
+        'width': 0.5,
+        'depth': 0.34,
+        'ratio': 2.0,
+        'stride': [8, 16, 32],  # P3, P4, P5
+        # neck
+        'neck': 'sppf',
+        'expand_ratio': 0.5,
+        'pooling_size': 5,
+        'neck_act': 'silu',
+        'neck_norm': 'BN',
+        'neck_depthwise': False,
+        # fpn
+        'fpn': 'yolov8_pafpn',
+        'fpn_act': 'silu',
+        'fpn_norm': 'BN',
+        'fpn_depthwise': False,
+        # head
+        'head': 'decoupled_head',
+        'head_act': 'silu',
+        'head_norm': 'BN',
+        'num_cls_head': 2,
+        'num_reg_head': 2,
+        'head_depthwise': False,
+        'reg_max': 16,
+        # matcher
+        'matcher': {'topk': 10,
+                    'alpha': 0.5,
+                    'beta': 6.0},
+        # loss weight
+        'cls_loss': 'bce', # vfl (optional)
+        'loss_cls_weight': 0.5,
+        'loss_iou_weight': 7.5,
+        'loss_dfl_weight': 1.5,
+        # training configuration
+        'no_aug_epoch': 10,
+        # optimizer
+        'optimizer': 'sgd',        # optional: sgd, adamw
+        'momentum': 0.937,         # SGD: 0.937;    AdamW: invalid
+        'weight_decay': 5e-4,      # SGD: 5e-4;     AdamW: 5e-2
+        'clip_grad': 10,           # SGD: 10.0;     AdamW: -1
+        # model EMA
+        'ema_decay': 0.9999,       # SGD: 0.9999;   AdamW: 0.9998
+        'ema_tau': 2000,
+        # lr schedule
+        'scheduler': 'linear',
+        'lr0': 0.01,              # SGD: 0.01;     AdamW: 0.004
+        'lrf': 0.01,               # SGD: 0.01;     AdamW: 0.05
+        'warmup_momentum': 0.8,
+        'warmup_bias_lr': 0.1,
+    },
+
+    'yolov8_medium':{
+        # input
+        'trans_type': 'yolov5_strong',
+        'multi_scale': [0.5, 1.5],   # 320 -> 960
+        # model
+        'backbone': 'elan_cspnet',
+        'bk_act': 'silu',
+        'bk_norm': 'BN',
+        'bk_dpw': False,
+        'width': 0.75,
+        'depth': 0.67,
+        'ratio': 1.5,
+        'stride': [8, 16, 32],  # P3, P4, P5
+        # neck
+        'neck': 'sppf',
+        'expand_ratio': 0.5,
+        'pooling_size': 5,
+        'neck_act': 'silu',
+        'neck_norm': 'BN',
+        'neck_depthwise': False,
+        # fpn
+        'fpn': 'yolov8_pafpn',
+        'fpn_act': 'silu',
+        'fpn_norm': 'BN',
+        'fpn_depthwise': False,
+        # head
+        'head': 'decoupled_head',
+        'head_act': 'silu',
+        'head_norm': 'BN',
+        'num_cls_head': 2,
+        'num_reg_head': 2,
+        'head_depthwise': False,
+        'reg_max': 16,
+        # matcher
+        'matcher': {'topk': 10,
+                    'alpha': 0.5,
+                    'beta': 6.0},
+        # loss weight
+        'cls_loss': 'bce', # vfl (optional)
+        'loss_cls_weight': 0.5,
+        'loss_iou_weight': 7.5,
+        'loss_dfl_weight': 1.5,
+        # training configuration
+        'no_aug_epoch': 10,
+        # optimizer
+        'optimizer': 'sgd',        # optional: sgd, adamw
+        'momentum': 0.937,         # SGD: 0.937;    AdamW: invalid
+        'weight_decay': 5e-4,      # SGD: 5e-4;     AdamW: 5e-2
+        'clip_grad': 10,           # SGD: 10.0;     AdamW: -1
+        # model EMA
+        'ema_decay': 0.9999,       # SGD: 0.9999;   AdamW: 0.9998
+        'ema_tau': 2000,
+        # lr schedule
+        'scheduler': 'linear',
+        'lr0': 0.01,              # SGD: 0.01;     AdamW: 0.004
+        'lrf': 0.01,               # SGD: 0.01;     AdamW: 0.05
+        'warmup_momentum': 0.8,
+        'warmup_bias_lr': 0.1,
+    },
+
+    'yolov8_large':{
+        # input
+        'trans_type': 'yolov5_strong',
+        'multi_scale': [0.5, 1.5],   # 320 -> 960
+        # model
+        'backbone': 'elan_cspnet',
+        'bk_act': 'silu',
+        'bk_norm': 'BN',
+        'bk_dpw': False,
+        'width': 1.0,
+        'depth': 1.0,
+        'ratio': 1.0,
+        'stride': [8, 16, 32],  # P3, P4, P5
+        # neck
+        'neck': 'sppf',
+        'expand_ratio': 0.5,
+        'pooling_size': 5,
+        'neck_act': 'silu',
+        'neck_norm': 'BN',
+        'neck_depthwise': False,
+        # fpn
+        'fpn': 'yolov8_pafpn',
+        'fpn_act': 'silu',
+        'fpn_norm': 'BN',
+        'fpn_depthwise': False,
+        # head
+        'head': 'decoupled_head',
+        'head_act': 'silu',
+        'head_norm': 'BN',
+        'num_cls_head': 2,
+        'num_reg_head': 2,
+        'head_depthwise': False,
+        'reg_max': 16,
+        # matcher
+        'matcher': {'topk': 10,
+                    'alpha': 0.5,
+                    'beta': 6.0},
+        # loss weight
+        'cls_loss': 'bce', # vfl (optional)
+        'loss_cls_weight': 0.5,
+        'loss_iou_weight': 7.5,
+        'loss_dfl_weight': 1.5,
+        # training configuration
+        'no_aug_epoch': 10,
+        # optimizer
+        'optimizer': 'sgd',        # optional: sgd, adamw
+        'momentum': 0.937,         # SGD: 0.937;    AdamW: invalid
+        'weight_decay': 5e-4,      # SGD: 5e-4;     AdamW: 5e-2
+        'clip_grad': 10,           # SGD: 10.0;     AdamW: -1
+        # model EMA
+        'ema_decay': 0.9999,       # SGD: 0.9999;   AdamW: 0.9998
+        'ema_tau': 2000,
+        # lr schedule
+        'scheduler': 'linear',
+        'lr0': 0.01,              # SGD: 0.01;     AdamW: 0.004
+        'lrf': 0.01,               # SGD: 0.01;     AdamW: 0.05
+        'warmup_momentum': 0.8,
+        'warmup_bias_lr': 0.1,
+    },
+
+    'yolov8_huge':{
+        # input
+        'trans_type': 'yolov5_strong',
+        'multi_scale': [0.5, 1.5],   # 320 -> 960
+        # model
+        'backbone': 'elan_cspnet',
+        'bk_act': 'silu',
+        'bk_norm': 'BN',
+        'bk_dpw': False,
+        'width': 1.25,
+        'depth': 1.0,
+        'ratio': 1.0,
+        'stride': [8, 16, 32],  # P3, P4, P5
+        # neck
+        'neck': 'sppf',
+        'expand_ratio': 0.5,
+        'pooling_size': 5,
+        'neck_act': 'silu',
+        'neck_norm': 'BN',
+        'neck_depthwise': False,
+        # fpn
+        'fpn': 'yolov8_pafpn',
+        'fpn_act': 'silu',
+        'fpn_norm': 'BN',
+        'fpn_depthwise': False,
+        # head
+        'head': 'decoupled_head',
+        'head_act': 'silu',
+        'head_norm': 'BN',
+        'num_cls_head': 2,
+        'num_reg_head': 2,
+        'head_depthwise': False,
+        'reg_max': 16,
+        # matcher
+        'matcher': {'topk': 10,
+                    'alpha': 0.5,
+                    'beta': 6.0},
+        # loss weight
+        'cls_loss': 'bce', # vfl (optional)
+        'loss_cls_weight': 0.5,
+        'loss_iou_weight': 7.5,
+        'loss_dfl_weight': 1.5,
+        # training configuration
+        'no_aug_epoch': 10,
+        # optimizer
+        'optimizer': 'sgd',        # optional: sgd, adamw
+        'momentum': 0.937,         # SGD: 0.937;    AdamW: invalid
+        'weight_decay': 5e-4,      # SGD: 5e-4;     AdamW: 5e-2
+        'clip_grad': 10,           # SGD: 10.0;     AdamW: -1
+        # model EMA
+        'ema_decay': 0.9999,       # SGD: 0.9999;   AdamW: 0.9998
+        'ema_tau': 2000,
+        # lr schedule
+        'scheduler': 'linear',
+        'lr0': 0.01,              # SGD: 0.01;     AdamW: 0.004
+        'lrf': 0.01,               # SGD: 0.01;     AdamW: 0.05
+        'warmup_momentum': 0.8,
+        'warmup_bias_lr': 0.1,
+    },
+
 }
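
The five entries above differ almost entirely in their width/depth/ratio scaling factors (plus the nano entry's weaker augmentation). As a rough illustration of how such factors are typically applied in a YOLOv8-style backbone — the base channel counts and block depth below are assumptions for illustration, not values read out of ELAN_CSPNet:

    # Hedged sketch: how width/depth/ratio scaling factors are commonly
    # applied in YOLOv8-style backbones. Base channels/depth are assumed.
    def scale(width, depth, ratio, base_channels=(64, 128, 256, 512), base_depth=3):
        chs = [max(round(c * width), 1) for c in base_channels]
        chs.append(max(round(512 * width * ratio), 1))  # last stage widened by `ratio`
        reps = max(round(base_depth * depth), 1)        # block repeats scaled by `depth`
        return chs, reps

    print(scale(0.25, 0.34, 2.0))  # yolov8_nano  -> ([16, 32, 64, 128, 256], 1)
    print(scale(1.0, 1.0, 1.0))    # yolov8_large -> ([64, 128, 256, 512, 512], 3)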

+ 1 - 1
models/__init__.py

@@ -43,7 +43,7 @@ def build_model(args,
         model, criterion = build_yolov7(
             args, model_cfg, device, num_classes, trainable)
     # YOLOv8
-    elif args.model == 'yolov8':
+    elif args.model in ['yolov8_nano', 'yolov8_small', 'yolov8_medium', 'yolov8_large', 'yolov8_huge']:
         model, criterion = build_yolov8(
             args, model_cfg, device, num_classes, trainable)
     # YOLOX   

+ 10 - 40
models/yolov8/loss.py

@@ -58,7 +58,6 @@ class Criterion(object):
         box_preds = torch.cat(outputs['pred_box'], dim=1)
         
         # label assignment
-        gt_label_targets = []
         gt_score_targets = []
         gt_bbox_targets = []
         fg_masks = []
@@ -71,17 +70,17 @@ class Criterion(object):
             if len(tgt_labels) == 0 or tgt_boxs.max().item() == 0.:
                 # There is no valid gt
                 fg_mask = cls_preds.new_zeros(1, num_anchors).bool()               #[1, M,]
-                gt_label = cls_preds.new_zeros((1, num_anchors,))                  #[1, M,]
                 gt_score = cls_preds.new_zeros((1, num_anchors, self.num_classes)) #[1, M, C]
                 gt_box = cls_preds.new_zeros((1, num_anchors, 4))                  #[1, M, 4]
             else:
                 tgt_labels = tgt_labels[None, :, None]      # [1, Mp, 1]
                 tgt_boxs = tgt_boxs[None]                   # [1, Mp, 4]
                 (
-                    gt_label,   #[1, M,]
+                    _,
                     gt_box,     #[1, M, 4]
                     gt_score,   #[1, M, C]
-                    fg_mask     #[1, M,]
+                    fg_mask,    #[1, M,]
+                    _
                 ) = self.matcher(
                     pd_scores = cls_preds[batch_idx:batch_idx+1].detach().sigmoid(), 
                     pd_bboxes = box_preds[batch_idx:batch_idx+1].detach(),
@@ -89,26 +88,18 @@ class Criterion(object):
                     gt_labels = tgt_labels,
                     gt_bboxes = tgt_boxs
                     )
-            gt_label_targets.append(gt_label)
             gt_score_targets.append(gt_score)
             gt_bbox_targets.append(gt_box)
             fg_masks.append(fg_mask)
 
         # List[B, 1, M, C] -> Tensor[B, M, C] -> Tensor[BM, C]
         fg_masks = torch.cat(fg_masks, 0).view(-1)                                    # [BM,]
-        gt_label_targets = torch.cat(gt_label_targets, 0).view(-1)                    # [BM,]
         gt_score_targets = torch.cat(gt_score_targets, 0).view(-1, self.num_classes)  # [BM, C]
         gt_bbox_targets = torch.cat(gt_bbox_targets, 0).view(-1, 4)                   # [BM, 4]
         
         # cls loss
         cls_preds = cls_preds.view(-1, self.num_classes)
-        gt_label_targets = torch.where(
-            fg_masks > 0,
-            gt_label_targets,
-            torch.full_like(gt_label_targets, self.num_classes)
-            )
-        gt_labels_one_hot = F.one_hot(gt_label_targets.long(), self.num_classes + 1)[..., :-1]
-        loss_cls = self.cls_lossf(cls_preds, gt_score_targets, gt_labels_one_hot)
+        loss_cls = self.cls_lossf(cls_preds, gt_score_targets)
 
         # reg loss
         anchors = anchors[None].repeat(bs, 1, 1).view(-1, 2)                           # [BM, 2]
@@ -126,15 +117,11 @@ class Criterion(object):
             strides = strides,
             )
         
-        loss_cls = loss_cls.sum()
-        loss_iou = loss_iou.sum()
-        loss_dfl = loss_dfl.sum()
-        gt_score_targets_sum = gt_score_targets.sum()
         # normalize loss
-        if gt_score_targets_sum > 0:
-            loss_cls /= gt_score_targets_sum
-            loss_iou /= gt_score_targets_sum
-            loss_dfl /= gt_score_targets_sum
+        gt_score_targets_sum = max(gt_score_targets.sum(), 1)
+        loss_cls = loss_cls.sum() / gt_score_targets_sum
+        loss_iou = loss_iou.sum() / gt_score_targets_sum
+        loss_dfl = loss_dfl.sum() / gt_score_targets_sum
 
         # total loss
         losses = loss_cls * self.loss_cls_weight + \
@@ -167,21 +154,6 @@ class ClassificationLoss(nn.Module):
         self.gamma = 2.0
 
 
-    def varifocalloss(self, pred_logits, gt_score, gt_label, alpha=0.75, gamma=2.0):
-        focal_weight = alpha * pred_logits.sigmoid().pow(gamma) * (1 - gt_label) + gt_score * gt_label
-        with torch.cuda.amp.autocast(enabled=False):
-            bce_loss = F.binary_cross_entropy_with_logits(
-                pred_logits.float(), gt_score.float(), reduction='none')
-            loss = bce_loss * focal_weight
-
-            if self.reduction == 'sum':
-                loss = loss.sum()
-            elif self.reduction == 'mean':
-                loss = loss.mean()
-
-        return loss
-
-
     def binary_cross_entropy(self, pred_logits, gt_score):
         loss = F.binary_cross_entropy_with_logits(
             pred_logits.float(), gt_score.float(), reduction='none')
@@ -194,10 +166,8 @@ class ClassificationLoss(nn.Module):
         return loss
 
 
-    def forward(self, pred_logits, gt_score, gt_label):
-        if self.cfg['cls_loss'] == 'vfl':
-            return self.varifocalloss(pred_logits, gt_score, gt_label, self.alpha, self.gamma)
-        elif self.cfg['cls_loss'] == 'bce':
+    def forward(self, pred_logits, gt_score):
+        if self.cfg['cls_loss'] == 'bce':
             return self.binary_cross_entropy(pred_logits, gt_score)
 
 

+ 54 - 53
models/yolov8/matcher.py

@@ -1,8 +1,10 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from utils.box_ops import bbox_iou
 
 
+# -------------------------- Task Aligned Assigner --------------------------
 class TaskAlignedAssigner(nn.Module):
     def __init__(self,
                  topk=10,
@@ -54,89 +56,88 @@ class TaskAlignedAssigner(nn.Module):
 
         # normalize
         align_metric *= mask_pos
-        pos_align_metrics = align_metric.max(axis=-1, keepdim=True)[0]
-        pos_overlaps = (overlaps * mask_pos).max(axis=-1, keepdim=True)[0]
-        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).max(-2)[0].unsqueeze(-1)
+        pos_align_metrics = align_metric.amax(axis=-1, keepdim=True)  # b, max_num_obj
+        pos_overlaps = (overlaps * mask_pos).amax(axis=-1, keepdim=True)  # b, max_num_obj
+        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
         target_scores = target_scores * norm_align_metric
 
-        return target_labels, target_bboxes, target_scores, fg_mask.bool()
+        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
 
 
-    def get_pos_mask(self,
-                     pd_scores,
-                     pd_bboxes,
-                     gt_labels,
-                     gt_bboxes,
-                     anc_points):
-
-        # get anchor_align metric
+    def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points):
+        # get anchor_align metric, (b, max_num_obj, h*w)
         align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes)
-        # get in_gts mask
+        # get in_gts mask, (b, max_num_obj, h*w)
         mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)
-        # get topk_metric mask
+        # get topk_metric mask, (b, max_num_obj, h*w)
         mask_topk = self.select_topk_candidates(align_metric * mask_in_gts)
-        # merge all mask to a final mask
+        # merge all mask to a final mask, (b, max_num_obj, h*w)
         mask_pos = mask_topk * mask_in_gts
 
         return mask_pos, align_metric, overlaps
 
 
-    def get_box_metrics(self,
-                        pd_scores,
-                        pd_bboxes,
-                        gt_labels,
-                        gt_bboxes):
-
-        pd_scores = pd_scores.permute(0, 2, 1)
-        gt_labels = gt_labels.long()
-        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)
-        ind[0] = torch.arange(end=self.bs).view(-1, 1).repeat(1, self.n_max_boxes)
-        ind[1] = gt_labels.squeeze(-1)
-        bbox_scores = pd_scores[ind[0], ind[1]]
+    def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes):
+        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
+        ind[0] = torch.arange(end=self.bs).view(-1, 1).repeat(1, self.n_max_boxes)  # b, max_num_obj
+        ind[1] = gt_labels.long().squeeze(-1)  # b, max_num_obj
+        # get the scores of each grid for each gt cls
+        bbox_scores = pd_scores[ind[0], :, ind[1]]  # b, max_num_obj, h*w
 
-        overlaps = iou_calculator(gt_bboxes, pd_bboxes)
+        overlaps = bbox_iou(gt_bboxes.unsqueeze(2), pd_bboxes.unsqueeze(1), xywh=False,
+                            CIoU=True).squeeze(3).clamp(0)
         align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
 
         return align_metric, overlaps
 
 
     def select_topk_candidates(self, metrics, largest=True):
-        num_anchors = metrics.shape[-1]
-        topk_metrics, topk_idxs = torch.topk(
-            metrics, self.topk, axis=-1, largest=largest)
-        topk_mask = (topk_metrics.max(axis=-1, keepdim=True)[0] > self.eps).tile(
-            [1, 1, self.topk])
-        topk_idxs = torch.where(topk_mask, topk_idxs, torch.zeros_like(topk_idxs))
-        is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(axis=-2)
-        is_in_topk = torch.where(is_in_topk > 1,
-            torch.zeros_like(is_in_topk), is_in_topk)
+        """
+        Args:
+            metrics: (b, max_num_obj, h*w).
+            topk_mask: (b, max_num_obj, topk) or None
+        """
+
+        num_anchors = metrics.shape[-1]  # h*w
+        # (b, max_num_obj, topk)
+        topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
+        topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).tile([1, 1, self.topk])
+        # (b, max_num_obj, topk)
+        topk_idxs[~topk_mask] = 0
+        # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
+        is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(-2)
+        # filter invalid bboxes
+        is_in_topk = torch.where(is_in_topk > 1, 0, is_in_topk)
         return is_in_topk.to(metrics.dtype)
 
 
-    def get_targets(self,
-                    gt_labels,
-                    gt_bboxes,
-                    target_gt_idx,
-                    fg_mask):
+    def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
+        """
+        Args:
+            gt_labels: (b, max_num_obj, 1)
+            gt_bboxes: (b, max_num_obj, 4)
+            target_gt_idx: (b, h*w)
+            fg_mask: (b, h*w)
+        """
 
-        # assigned target labels
-        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[...,None]
-        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes
-        target_labels = gt_labels.long().flatten()[target_gt_idx]
+        # assigned target labels, (b, 1)
+        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
+        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (b, h*w)
+        target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)
 
-        # assigned target boxes
-        target_bboxes = gt_bboxes.reshape([-1, 4])[target_gt_idx]
+        # assigned target boxes, (b, max_num_obj, 4) -> (b, h*w)
+        target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx]
 
         # assigned target scores
-        target_labels[target_labels<0] = 0
-        target_scores = F.one_hot(target_labels, self.num_classes)
-        fg_scores_mask  = fg_mask[:, :, None].repeat(1, 1, self.num_classes)
-        target_scores = torch.where(fg_scores_mask > 0, target_scores,
-                                        torch.full_like(target_scores, 0))
+        target_labels.clamp(0)
+        target_scores = F.one_hot(target_labels, self.num_classes)  # (b, h*w, 80)
+        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)  # (b, h*w, 80)
+        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)
 
         return target_labels, target_bboxes, target_scores
     
 
+# -------------------------- Basic Functions --------------------------
 def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
     """select the positive anchors's center in gt
     Args:
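
The core of the assigner is unchanged in spirit: an alignment metric of predicted class score raised to alpha times box overlap raised to beta, with the top-k anchors inside each GT box kept as candidate positives (the overlap is now computed with CIoU via bbox_iou). A hedged toy sketch of that metric and top-k selection, using random inputs and plain overlap values for brevity:

    # Toy sketch of the task-aligned metric: align = score**alpha * iou**beta,
    # then keep the top-k anchors per GT as candidate positives. Random
    # "ious" stand in for the CIoU overlaps the patched matcher computes.
    import torch

    def task_aligned_topk(scores, ious, alpha=0.5, beta=6.0, topk=10):
        # scores, ious: [num_gt, num_anchors]
        align_metric = scores.pow(alpha) * ious.pow(beta)
        topk_idxs = align_metric.topk(min(topk, align_metric.shape[-1]), dim=-1).indices
        mask = torch.zeros_like(align_metric)
        mask.scatter_(-1, topk_idxs, 1.0)   # mark the selected anchors
        return align_metric, mask

    scores = torch.rand(2, 100)   # predicted scores for each GT's class
    ious   = torch.rand(2, 100)   # predicted-box overlap with each GT box
    metric, pos_mask = task_aligned_topk(scores, ious)
    print(pos_mask.sum(dim=-1))   # tensor([10., 10.])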

+ 2 - 54
models/yolov8/yolov8_backbone.py

@@ -7,18 +7,8 @@ except:
     from yolov8_basic import Conv, ELAN_CSP_Block
 
 
-# ---------------------------- ImageNet pretrained weights ----------------------------
-model_urls = {
-    'elan_cspnet_nano': "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elan_cspnet_nano.pth",
-    'elan_cspnet_small': None,
-    'elan_cspnet_medium': None,
-    'elan_cspnet_large': "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elan_cspnet_large.pth",
-    'elan_cspnet_huge': None,
-}
-
-
 # ---------------------------- Backbones ----------------------------
-# ELAN-CSPNet
+## ELAN-CSPNet
 class ELAN_CSPNet(nn.Module):
     def __init__(self, width=1.0, depth=1.0, ratio=1.0, act_type='silu', norm_type='BN', depthwise=False):
         super(ELAN_CSPNet, self).__init__()
@@ -66,37 +56,7 @@ class ELAN_CSPNet(nn.Module):
 
 
 # ---------------------------- Functions ----------------------------
-## load pretrained weight
-def load_weight(model, model_name):
-    # load weight
-    print('Loading pretrained weight ...')
-    url = model_urls[model_name]
-    if url is not None:
-        checkpoint = torch.hub.load_state_dict_from_url(
-            url=url, map_location="cpu", check_hash=True)
-        # checkpoint state dict
-        checkpoint_state_dict = checkpoint.pop("model")
-        # model state dict
-        model_state_dict = model.state_dict()
-        # check
-        for k in list(checkpoint_state_dict.keys()):
-            if k in model_state_dict:
-                shape_model = tuple(model_state_dict[k].shape)
-                shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
-                if shape_model != shape_checkpoint:
-                    checkpoint_state_dict.pop(k)
-            else:
-                checkpoint_state_dict.pop(k)
-                print(k)
-
-        model.load_state_dict(checkpoint_state_dict)
-    else:
-        print('No pretrained for {}'.format(model_name))
-
-    return model
-
-
-# build ELAN-Net
+## build ELAN-Net
 def build_backbone(cfg): 
     # model
     backbone = ELAN_CSPNet(
@@ -108,18 +68,6 @@ def build_backbone(cfg):
         depthwise=cfg['bk_dpw']
         )
         
-    # check whether to load imagenet pretrained weight
-    if cfg['pretrained']:
-        if cfg['width'] == 0.25 and cfg['depth'] == 0.34 and cfg['ratio'] == 2.0:
-            backbone = load_weight(backbone, model_name='elan_cspnet_nano')
-        elif cfg['width'] == 0.5 and cfg['depth'] == 0.34 and cfg['ratio'] == 2.0:
-            backbone = load_weight(backbone, model_name='elan_cspnet_small')
-        elif cfg['width'] == 0.75 and cfg['depth'] == 0.67 and cfg['ratio'] == 1.5:
-            backbone = load_weight(backbone, model_name='elan_cspnet_medium')
-        elif cfg['width'] == 1.0 and cfg['depth'] == 1.0 and cfg['ratio'] == 1.0:
-            backbone = load_weight(backbone, model_name='elan_cspnet_large')
-        elif cfg['width'] == 1.25 and cfg['depth'] == 1.34 and cfg['ratio'] == 1.0:
-            backbone = load_weight(backbone, model_name='elan_cspnet_huge')
     feat_dims = backbone.feat_dims
 
     return backbone, feat_dims
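
Since build_backbone no longer loads ImageNet weights, it simply instantiates ELAN_CSPNet from the config's scaling and layer options. A minimal usage sketch, assuming this repo's module layout and that the backbone returns C3/C4/C5 feature maps at strides 8/16/32:

    # Minimal usage sketch (import path and output shapes assumed from
    # this repo's layout; no pretrained weights are loaded any more).
    import torch
    from models.yolov8.yolov8_backbone import build_backbone

    cfg = {
        'width': 0.25, 'depth': 0.34, 'ratio': 2.0,      # yolov8_nano scaling
        'bk_act': 'silu', 'bk_norm': 'BN', 'bk_dpw': False,
    }
    backbone, feat_dims = build_backbone(cfg)
    feats = backbone(torch.randn(1, 3, 640, 640))
    print(feat_dims, [f.shape for f in feats])  # three pyramid levels: C3, C4, C5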