浏览代码

debug YOLOvx-Nano

yjh0410 2 年之前
父节点
当前提交
d5eee62269

+ 21 - 21
README.md

@@ -122,43 +122,43 @@ python train.py --cuda -d coco --root path/to/COCO -m yolov1 -bs 16 --max_epoch
 
 * YOLOv5:
 
-| Model         |   Backbone         | Scale | Epoch | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |
-|---------------|--------------------|-------|-------|------------------------|-------------------|-------------------|--------------------|--------|
-| YOLOv5-N      | CSPDarkNet-N       |  640  |  250  |         29.8           |       47.1        |   7.7             |   2.4              | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov5_n_coco.pth) |
-| YOLOv5-S      | CSPDarkNet-S       |  640  |  250  |         37.8           |       56.5        |   27.1            |   9.0              | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov5_s_coco.pth) |
-| YOLOv5-M      | CSPDarkNet-M       |  640  |  250  |         43.5           |       62.5        |   74.3            |   25.4             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov5_m_coco.pth) |
-| YOLOv5-L      | CSPDarkNet-L       |  640  |  250  |         46.7           |       65.5        |   155.6           |   54.2             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov5_l_coco.pth) |
+| Model     |  Backbone    | Scale | Epoch | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |
+|-----------|--------------|-------|-------|------------------------|-------------------|-------------------|--------------------|--------|
+| YOLOv5-N  | CSPDarkNet-N |  640  |  250  |         29.8           |       47.1        |   7.7             |   2.4              | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov5_n_coco.pth) |
+| YOLOv5-S  | CSPDarkNet-S |  640  |  250  |         37.8           |       56.5        |   27.1            |   9.0              | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov5_s_coco.pth) |
+| YOLOv5-M  | CSPDarkNet-M |  640  |  250  |         43.5           |       62.5        |   74.3            |   25.4             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov5_m_coco.pth) |
+| YOLOv5-L  | CSPDarkNet-L |  640  |  250  |         46.7           |       65.5        |   155.6           |   54.2             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov5_l_coco.pth) |
 
 *For **YOLOv5-M** and **YOLOv5-L**, increasing the batch size may improve performance. Due to my computing resources, I can only set the batch size to 16.*
 
 * YOLOX:
 
-| Model         |   Backbone         | Scale | Epoch | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |
-|---------------|--------------------|-------|-------|------------------------|-------------------|-------------------|--------------------|--------|
-| YOLOX-N       | CSPDarkNet-N       |  640  |  300  |         31.1           |       49.5        |   7.5             |   2.3              | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolox_n_coco.pth) |
-| YOLOX-S       | CSPDarkNet-S       |  640  |  300  |         39.0           |       58.8        |   26.8            |   8.9              | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolox_s_coco.pth) |
-| YOLOX-M       | CSPDarkNet-M       |  640  |  300  |                        |                   |   74.3            |   25.4             |  |
-| YOLOX-L       | CSPDarkNet-L       |  640  |  300  |                        |                   |   155.4           |   54.2             |  |
+| Model   |   Backbone    | Scale | Epoch | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |
+|---------|---------------|-------|-------|------------------------|-------------------|-------------------|--------------------|--------|
+| YOLOX-N | CSPDarkNet-N  |  640  |  300  |         31.1           |       49.5        |   7.5             |   2.3              | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolox_n_coco.pth) |
+| YOLOX-S | CSPDarkNet-S  |  640  |  300  |         39.0           |       58.8        |   26.8            |   8.9              | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolox_s_coco.pth) |
+| YOLOX-M | CSPDarkNet-M  |  640  |  300  |                        |                   |   74.3            |   25.4             |  |
+| YOLOX-L | CSPDarkNet-L  |  640  |  300  |                        |                   |   155.4           |   54.2             |  |
 
 *For **YOLOX-M** and **YOLOX-L**, increasing the batch size may improve performance. Due to my computing resources, I can only set the batch size to 16.*
 
 * YOLOv7:
 
-| Model         |   Backbone         | Scale | Epoch | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |
-|---------------|--------------------|-------|-------|------------------------|-------------------|-------------------|--------------------|--------|
-| YOLOv7-T      | ELANNet-Tiny       |  640  |  300  |         38.0           |       56.8        |   22.6            |   7.9              | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov7_tiny_coco.pth) |
-| YOLOv7-L      | ELANNet-Large      |  640  |  300  |         48.0           |       67.5        |   144.6           |   44.0             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov7_large_coco.pth) |
+| Model    | Backbone      | Scale | Epoch | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |
+|----------|---------------|-------|-------|------------------------|-------------------|-------------------|--------------------|--------|
+| YOLOv7-T | ELANNet-Tiny  |  640  |  300  |         38.0           |       56.8        |   22.6            |   7.9              | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov7_tiny_coco.pth) |
+| YOLOv7-L | ELANNet-Large |  640  |  300  |         48.0           |       67.5        |   144.6           |   44.0             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov7_large_coco.pth) |
 
 *While YOLOv7 incorporates several technical details, such as anchor box, SimOTA, AuxiliaryHead, and RepConv, I found it too challenging to fully reproduce. Instead, I created a simpler version of YOLOv7 using an anchor-free structure and SimOTA. As a result, my reproduction had poor performance due to the absence of the other technical details. However, since it was only intended as a tutorial, I am not too concerned about this gap.*
 
-* YOLOX2 (Incremental Improved YOLOX):
+* YOLOvx (Incremental Improved YOLOX):
 
 | Model    | Scale | Epoch | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |
 |----------|-------|-------|------------------------|-------------------|-------------------|--------------------|--------|
-| YOLOX2-N |  640  |  300  |                        |                   |                   |                    |  |
-| YOLOX2-S |  640  |  300  |                        |                   |                   |                    |  |
-| YOLOX2-M |  640  |  300  |                        |                   |                   |                    |  |
-| YOLOX2-L |  640  |  300  |                        |                   |                   |                    |  |
+| YOLOvx-N |  640  |  300  |                        |                   |                   |                    |  |
+| YOLOvx-S |  640  |  300  |                        |                   |                   |                    |  |
+| YOLOvx-M |  640  |  300  |                        |                   |                   |                    |  |
+| YOLOvx-L |  640  |  300  |                        |                   |                   |                    |  |
 
 * E2E-YOLO (End-to-End YOLO without NMS):
 

+ 4 - 4
config/__init__.py

@@ -80,8 +80,8 @@ from .model_config.yolov3_config import yolov3_cfg
 from .model_config.yolov4_config import yolov4_cfg
 from .model_config.yolov5_config import yolov5_cfg
 from .model_config.yolov7_config import yolov7_cfg
+from .model_config.yolovx_config import yolovx_cfg
 from .model_config.yolox_config import yolox_cfg
-from .model_config.yolox2_config import yolox2_cfg
 from .model_config.rtdetr_config import rtdetr_cfg
 from .model_config.e2eyolo_config import e2eyolo_cfg
 
@@ -110,9 +110,9 @@ def build_model_config(args):
     # YOLOX
     elif args.model in ['yolox_n', 'yolox_s', 'yolox_m', 'yolox_l', 'yolox_x']:
         cfg = yolox_cfg[args.model]
-    # YOLOX2
-    elif args.model in ['yolox2_n', 'yolox2_s', 'yolox2_m', 'yolox2_l', 'yolox2_x']:
-        cfg = yolox2_cfg[args.model]
+    # YOLOvX
+    elif args.model in ['yolovx_n', 'yolovx_s', 'yolovx_m', 'yolovx_l', 'yolovx_x']:
+        cfg = yolovx_cfg[args.model]
     # RT-DETR
     elif args.model in ['rtdetr_n', 'rtdetr_s', 'rtdetr_m', 'rtdetr_l', 'rtdetr_x']:
         cfg = rtdetr_cfg[args.model]

+ 0 - 0
config/model_config/yolov7.yaml → config/model_config/official_yaml/yolov7.yaml


+ 18 - 15
config/model_config/yolox2_config.py → config/model_config/yolovx_config.py

@@ -1,8 +1,8 @@
-# YOLOX2 Config
+# YOLOvx Config
 
 
-yolox2_cfg = {
-    'yolox2_n':{
+yolovx_cfg = {
+    'yolovx_n':{
         # ---------------- Model config ----------------
         ## Backbone
         'backbone': 'elannet',
@@ -36,32 +36,35 @@ yolox2_cfg = {
         'num_cls_head': 2,
         'num_reg_head': 2,
         'head_depthwise': False,
+        'reg_max': 16,
         # ---------------- Train config ----------------
         ## input
         'multi_scale': [0.5, 1.5],   # 320 -> 960
         'trans_type': 'yolox_nano',
         # ---------------- Assignment config ----------------
-        'matcher': {'soft_center_radius': 3.0,
-                    'topk_candicate': 13,
-                    'iou_weight': 3.0},
+        'matcher': {'topk': 10,
+                    'alpha': 0.5,
+                    'beta': 6.0},
         # ---------------- Loss config ----------------
         ## loss weight
+        'cls_loss': 'vfl', # vfl (optional)
         'loss_cls_weight': 1.0,
-        'loss_box_weight': 5.0,
+        'loss_iou_weight': 5.0,
+        'loss_dfl_weight': 1.0,
         # ---------------- Train config ----------------
-        ## close strong augmentation
+        # training configuration
         'no_aug_epoch': 20,
         'trainer_type': 'yolo',
-        ## optimizer
-        'optimizer': 'sgd',        # optional: sgd, AdamW
-        'momentum': 0.9,           # SGD: 0.9;      AdamW: None
+        # optimizer
+        'optimizer': 'sgd',        # optional: sgd, adam, adamw
+        'momentum': 0.937,         # SGD: 0.937;    AdamW: invalid
         'weight_decay': 5e-4,      # SGD: 5e-4;     AdamW: 5e-2
-        'clip_grad': 10.0,         # SGD: 10.0;     AdamW: -1
-        ## model EMA
+        'clip_grad': 10,           # SGD: 10.0;     AdamW: -1
+        # model EMA
         'ema_decay': 0.9999,       # SGD: 0.9999;   AdamW: 0.9998
         'ema_tau': 2000,
-        ## lr schedule
-        'scheduler': 'cos_linear',
+        # lr schedule
+        'scheduler': 'linear',
         'lr0': 0.01,               # SGD: 0.01;     AdamW: 0.001
         'lrf': 0.01,               # SGD: 0.01;     AdamW: 0.01
         'warmup_momentum': 0.8,

+ 4 - 4
models/detectors/__init__.py

@@ -8,8 +8,8 @@ from .yolov3.build import build_yolov3
 from .yolov4.build import build_yolov4
 from .yolov5.build import build_yolov5
 from .yolov7.build import build_yolov7
+from .yolovx.build import build_yolovx
 from .yolox.build import build_yolox
-from .yolox2.build import build_yolox2
 from .rtdetr.build import build_rtdetr
 from .e2eyolo.build import build_e2eyolo
 
@@ -49,9 +49,9 @@ def build_model(args,
     elif args.model in ['yolox_n', 'yolox_s', 'yolox_m', 'yolox_l', 'yolox_x']:
         model, criterion = build_yolox(
             args, model_cfg, device, num_classes, trainable, deploy)
-    # YOLOX2
-    elif args.model in ['yolox2_n', 'yolox2_s', 'yolox2_m', 'yolox2_l', 'yolox2_x']:
-        model, criterion = build_yolox2(
+    # YOLOvx
+    elif args.model in ['yolovx_n', 'yolovx_s', 'yolovx_m', 'yolovx_l', 'yolovx_x']:
+        model, criterion = build_yolovx(
             args, model_cfg, device, num_classes, trainable, deploy)
     # RT-DETR
     elif args.model in ['rtdetr_n', 'rtdetr_s', 'rtdetr_m', 'rtdetr_l', 'rtdetr_x']:

+ 3 - 3
models/detectors/yolox2/build.py → models/detectors/yolovx/build.py

@@ -5,16 +5,16 @@ import torch
 import torch.nn as nn
 
 from .loss import build_criterion
-from .yolox2 import YOLOX2
+from .yolovx import YOLOvx
 
 
 # build object detector
-def build_yolox2(args, cfg, device, num_classes=80, trainable=False, deploy=False):
+def build_yolovx(args, cfg, device, num_classes=80, trainable=False, deploy=False):
     print('==============================')
     print('Build {} ...'.format(args.model.upper()))
         
     # -------------- Build YOLO --------------
-    model = YOLOX2(
+    model = YOLOvx(
         cfg=cfg,
         device=device, 
         num_classes=num_classes,

+ 283 - 0
models/detectors/yolovx/loss.py

@@ -0,0 +1,283 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .matcher import TaskAlignedAssigner
+from utils.box_ops import bbox2dist, bbox_iou
+
+
+
+class Criterion(object):
+    def __init__(self, 
+                 cfg, 
+                 device, 
+                 num_classes=80):
+        self.cfg = cfg
+        self.device = device
+        self.num_classes = num_classes
+        self.reg_max = cfg['reg_max']
+        self.use_dfl = cfg['reg_max'] > 1
+        # loss
+        self.cls_lossf = ClassificationLoss(cfg, reduction='none')
+        self.reg_lossf = RegressionLoss(num_classes, cfg['reg_max'] - 1, self.use_dfl)
+        # loss weight
+        self.loss_cls_weight = cfg['loss_cls_weight']
+        self.loss_iou_weight = cfg['loss_iou_weight']
+        self.loss_dfl_weight = cfg['loss_dfl_weight']
+        # matcher
+        matcher_config = cfg['matcher']
+        self.matcher = TaskAlignedAssigner(
+            topk=matcher_config['topk'],
+            num_classes=num_classes,
+            alpha=matcher_config['alpha'],
+            beta=matcher_config['beta']
+            )
+
+
+    def __call__(self, outputs, targets):        
+        """
+            outputs['pred_cls']: List(Tensor) [B, M, C]
+            outputs['pred_regs']: List(Tensor) [B, M, 4*(reg_max+1)]
+            outputs['pred_boxs']: List(Tensor) [B, M, 4]
+            outputs['anchors']: List(Tensor) [M, 2]
+            outputs['strides']: List(Int) [8, 16, 32] output stride
+            outputs['stride_tensor']: List(Tensor) [M, 1]
+            targets: (List) [dict{'boxes': [...], 
+                                 'labels': [...], 
+                                 'orig_size': ...}, ...]
+        """
+        bs = outputs['pred_cls'][0].shape[0]
+        device = outputs['pred_cls'][0].device
+        strides = outputs['stride_tensor']
+        anchors = outputs['anchors']
+        anchors = torch.cat(anchors, dim=0)
+        num_anchors = anchors.shape[0]
+
+        # preds: [B, M, C]
+        cls_preds = torch.cat(outputs['pred_cls'], dim=1)
+        reg_preds = torch.cat(outputs['pred_reg'], dim=1)
+        box_preds = torch.cat(outputs['pred_box'], dim=1)
+        
+        # label assignment
+        gt_label_targets = []
+        gt_score_targets = []
+        gt_bbox_targets = []
+        fg_masks = []
+
+        for batch_idx in range(bs):
+            tgt_labels = targets[batch_idx]["labels"].to(device)     # [Mp,]
+            tgt_boxs = targets[batch_idx]["boxes"].to(device)        # [Mp, 4]
+
+            # check target
+            if len(tgt_labels) == 0 or tgt_boxs.max().item() == 0.:
+                # There is no valid gt
+                fg_mask = cls_preds.new_zeros(1, num_anchors).bool()               #[1, M,]
+                gt_label = cls_preds.new_zeros((1, num_anchors,))                  #[1, M,]
+                gt_score = cls_preds.new_zeros((1, num_anchors, self.num_classes)) #[1, M, C]
+                gt_box = cls_preds.new_zeros((1, num_anchors, 4))                  #[1, M, 4]
+            else:
+                tgt_labels = tgt_labels[None, :, None]      # [1, Mp, 1]
+                tgt_boxs = tgt_boxs[None]                   # [1, Mp, 4]
+                (
+                    gt_label,   #[1, M]
+                    gt_box,     #[1, M, 4]
+                    gt_score,   #[1, M, C]
+                    fg_mask,    #[1, M,]
+                    _
+                ) = self.matcher(
+                    pd_scores = cls_preds[batch_idx:batch_idx+1].detach().sigmoid(), 
+                    pd_bboxes = box_preds[batch_idx:batch_idx+1].detach(),
+                    anc_points = anchors,
+                    gt_labels = tgt_labels,
+                    gt_bboxes = tgt_boxs
+                    )
+            gt_label_targets.append(gt_label)
+            gt_score_targets.append(gt_score)
+            gt_bbox_targets.append(gt_box)
+            fg_masks.append(fg_mask)
+
+        # List[B, 1, M, C] -> Tensor[B, M, C] -> Tensor[BM, C]
+        fg_masks = torch.cat(fg_masks, 0).view(-1)                                    # [BM,]
+        gt_label_targets = torch.cat(gt_label_targets, 0).view(-1)                    # [BM,]
+        gt_score_targets = torch.cat(gt_score_targets, 0).view(-1, self.num_classes)  # [BM, C]
+        gt_bbox_targets = torch.cat(gt_bbox_targets, 0).view(-1, 4)                   # [BM, 4]
+        
+        # cls loss
+        cls_preds = cls_preds.view(-1, self.num_classes)
+        gt_label_targets = torch.where(
+            fg_masks > 0,
+            gt_label_targets,
+            torch.full_like(gt_label_targets, self.num_classes)
+            )
+        gt_labels_one_hot = F.one_hot(gt_label_targets.long(), self.num_classes + 1)[..., :-1]
+        loss_cls = self.cls_lossf(cls_preds, gt_score_targets, gt_labels_one_hot)
+
+        # reg loss
+        anchors = anchors[None].repeat(bs, 1, 1).view(-1, 2)                           # [BM, 2]
+        strides = torch.cat(strides, dim=0).unsqueeze(0).repeat(bs, 1, 1).view(-1, 1)  # [BM, 1]
+        bbox_weight = gt_score_targets[fg_masks].sum(-1, keepdim=True)                 # [BM, 1]
+        reg_preds = reg_preds.view(-1, 4*self.reg_max)                                 # [BM, 4*(reg_max + 1)]
+        box_preds = box_preds.view(-1, 4)                                              # [BM, 4]
+        loss_iou, loss_dfl = self.reg_lossf(
+            pred_regs = reg_preds,
+            pred_boxs = box_preds,
+            anchors = anchors,
+            gt_boxs = gt_bbox_targets,
+            bbox_weight = bbox_weight,
+            fg_masks = fg_masks,
+            strides = strides,
+            )
+        
+        # normalize loss
+        gt_score_targets_sum = max(gt_score_targets.sum(), 1)
+        loss_cls = loss_cls.sum() / gt_score_targets_sum
+        loss_iou = loss_iou.sum() / gt_score_targets_sum
+        loss_dfl = loss_dfl.sum() / gt_score_targets_sum
+
+        # total loss
+        losses = loss_cls * self.loss_cls_weight + \
+                 loss_iou * self.loss_iou_weight
+        if self.use_dfl:
+            losses += loss_dfl * self.loss_dfl_weight
+            loss_dict = dict(
+                    loss_cls = loss_cls,
+                    loss_iou = loss_iou,
+                    loss_dfl = loss_dfl,
+                    losses = losses
+            )
+        else:
+            loss_dict = dict(
+                    loss_cls = loss_cls,
+                    loss_iou = loss_iou,
+                    losses = losses
+            )
+
+        return loss_dict
+    
+
+class ClassificationLoss(nn.Module):
+    def __init__(self, cfg, reduction='none'):
+        super(ClassificationLoss, self).__init__()
+        self.cfg = cfg
+        self.reduction = reduction
+        # For VFL
+        self.alpha = 0.75
+        self.gamma = 2.0
+
+    def varifocalloss(self, pred_logits, gt_score, gt_label, alpha=0.75, gamma=2.0):
+        focal_weight = alpha * pred_logits.sigmoid().pow(gamma) * (1 - gt_label) + gt_score * gt_label
+        with torch.cuda.amp.autocast(enabled=False):
+            bce_loss = F.binary_cross_entropy_with_logits(
+                pred_logits.float(), gt_score.float(), reduction='none')
+            loss = bce_loss * focal_weight
+
+            if self.reduction == 'sum':
+                loss = loss.sum()
+            elif self.reduction == 'mean':
+                loss = loss.mean()
+
+        return loss
+
+    def binary_cross_entropy(self, pred_logits, gt_score):
+        loss = F.binary_cross_entropy_with_logits(
+            pred_logits.float(), gt_score.float(), reduction='none')
+
+        if self.reduction == 'sum':
+            loss = loss.sum()
+        elif self.reduction == 'mean':
+            loss = loss.mean()
+
+        return loss
+
+
+    def forward(self, pred_logits, gt_score, gt_label):
+        if self.cfg['cls_loss'] == 'bce':
+            return self.binary_cross_entropy(pred_logits, gt_score)
+        elif self.cfg['cls_loss'] == 'vfl':
+            return self.varifocalloss(pred_logits, gt_score, gt_label, self.alpha, self.gamma)
+
+
+class RegressionLoss(nn.Module):
+    def __init__(self, num_classes, reg_max, use_dfl):
+        super(RegressionLoss, self).__init__()
+        self.num_classes = num_classes
+        self.reg_max = reg_max
+        self.use_dfl = use_dfl
+
+
+    def df_loss(self, pred_regs, target):
+        gt_left = target.to(torch.long)
+        gt_right = gt_left + 1
+        weight_left = gt_right.to(torch.float) - target
+        weight_right = 1 - weight_left
+        # loss left
+        loss_left = F.cross_entropy(
+            pred_regs.view(-1, self.reg_max + 1),
+            gt_left.view(-1),
+            reduction='none').view(gt_left.shape) * weight_left
+        # loss right
+        loss_right = F.cross_entropy(
+            pred_regs.view(-1, self.reg_max + 1),
+            gt_right.view(-1),
+            reduction='none').view(gt_left.shape) * weight_right
+
+        loss = (loss_left + loss_right).mean(-1, keepdim=True)
+        
+        return loss
+
+
+    def forward(self, pred_regs, pred_boxs, anchors, gt_boxs, bbox_weight, fg_masks, strides):
+        """
+        Input:
+            pred_regs: (Tensor) [BM, 4*(reg_max + 1)]
+            pred_boxs: (Tensor) [BM, 4]
+            anchors: (Tensor) [BM, 2]
+            gt_boxs: (Tensor) [BM, 4]
+            bbox_weight: (Tensor) [BM, 1]
+            fg_masks: (Tensor) [BM,]
+            strides: (Tensor) [BM, 1]
+        """
+        # select positive samples mask
+        num_pos = fg_masks.sum()
+
+        if num_pos > 0:
+            pred_boxs_pos = pred_boxs[fg_masks]
+            gt_boxs_pos = gt_boxs[fg_masks]
+
+            # iou loss
+            ious = bbox_iou(pred_boxs_pos,
+                            gt_boxs_pos,
+                            xywh=False,
+                            CIoU=True)
+            loss_iou = (1.0 - ious) * bbox_weight
+               
+            # dfl loss
+            if self.use_dfl:
+                pred_regs_pos = pred_regs[fg_masks]
+                gt_boxs_s = gt_boxs / strides
+                anchors_s = anchors / strides
+                gt_ltrb_s = bbox2dist(anchors_s, gt_boxs_s, self.reg_max)
+                gt_ltrb_s_pos = gt_ltrb_s[fg_masks]
+                loss_dfl = self.df_loss(pred_regs_pos, gt_ltrb_s_pos)
+                loss_dfl *= bbox_weight
+            else:
+                loss_dfl = pred_regs.sum() * 0.
+
+        else:
+            loss_iou = pred_regs.sum() * 0.
+            loss_dfl = pred_regs.sum() * 0.
+
+        return loss_iou, loss_dfl
+
+
+def build_criterion(cfg, device, num_classes):
+    criterion = Criterion(
+        cfg=cfg,
+        device=device,
+        num_classes=num_classes
+        )
+
+    return criterion
+
+
+if __name__ == "__main__":
+    pass

+ 204 - 0
models/detectors/yolovx/matcher.py

@@ -0,0 +1,204 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from utils.box_ops import bbox_iou
+
+
+# -------------------------- Task Aligned Assigner --------------------------
+class TaskAlignedAssigner(nn.Module):
+    def __init__(self,
+                 topk=10,
+                 num_classes=80,
+                 alpha=0.5,
+                 beta=6.0, 
+                 eps=1e-9):
+        super(TaskAlignedAssigner, self).__init__()
+        self.topk = topk
+        self.num_classes = num_classes
+        self.bg_idx = num_classes
+        self.alpha = alpha
+        self.beta = beta
+        self.eps = eps
+
+    @torch.no_grad()
+    def forward(self,
+                pd_scores,
+                pd_bboxes,
+                anc_points,
+                gt_labels,
+                gt_bboxes):
+        """This code referenced to
+           https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py
+        Args:
+            pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
+            pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
+            anc_points (Tensor): shape(num_total_anchors, 2)
+            gt_labels (Tensor): shape(bs, n_max_boxes, 1)
+            gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
+        Returns:
+            target_labels (Tensor): shape(bs, num_total_anchors)
+            target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
+            target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
+            fg_mask (Tensor): shape(bs, num_total_anchors)
+        """
+        self.bs = pd_scores.size(0)
+        self.n_max_boxes = gt_bboxes.size(1)
+
+        mask_pos, align_metric, overlaps = self.get_pos_mask(
+            pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points)
+
+        target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(
+            mask_pos, overlaps, self.n_max_boxes)
+
+        # assigned target
+        target_labels, target_bboxes, target_scores = self.get_targets(
+            gt_labels, gt_bboxes, target_gt_idx, fg_mask)
+
+        # normalize
+        align_metric *= mask_pos
+        pos_align_metrics = align_metric.amax(axis=-1, keepdim=True)  # b, max_num_obj
+        pos_overlaps = (overlaps * mask_pos).amax(axis=-1, keepdim=True)  # b, max_num_obj
+        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
+        target_scores = target_scores * norm_align_metric
+
+        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
+
+
+    def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points):
+        # get anchor_align metric, (b, max_num_obj, h*w)
+        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes)
+        # get in_gts mask, (b, max_num_obj, h*w)
+        mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)
+        # get topk_metric mask, (b, max_num_obj, h*w)
+        mask_topk = self.select_topk_candidates(align_metric * mask_in_gts)
+        # merge all mask to a final mask, (b, max_num_obj, h*w)
+        mask_pos = mask_topk * mask_in_gts
+
+        return mask_pos, align_metric, overlaps
+
+
+    def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes):
+        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
+        ind[0] = torch.arange(end=self.bs).view(-1, 1).repeat(1, self.n_max_boxes)  # b, max_num_obj
+        ind[1] = gt_labels.long().squeeze(-1)  # b, max_num_obj
+        # get the scores of each grid for each gt cls
+        bbox_scores = pd_scores[ind[0], :, ind[1]]  # b, max_num_obj, h*w
+
+        overlaps = bbox_iou(gt_bboxes.unsqueeze(2), pd_bboxes.unsqueeze(1), xywh=False)
+        overlaps = overlaps.squeeze(3).clamp(0)
+        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
+
+        return align_metric, overlaps
+
+
+    def select_topk_candidates(self, metrics, largest=True):
+        """
+        Args:
+            metrics: (b, max_num_obj, h*w).
+            topk_mask: (b, max_num_obj, topk) or None
+        """
+
+        num_anchors = metrics.shape[-1]  # h*w
+        # (b, max_num_obj, topk)
+        topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
+        topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).tile([1, 1, self.topk])
+        # (b, max_num_obj, topk)
+        topk_idxs[~topk_mask] = 0
+        # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
+        is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(-2)
+        # filter invalid bboxes
+        is_in_topk = torch.where(is_in_topk > 1, 0, is_in_topk)
+        return is_in_topk.to(metrics.dtype)
+
+
+    def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
+        """
+        Args:
+            gt_labels: (b, max_num_obj, 1)
+            gt_bboxes: (b, max_num_obj, 4)
+            target_gt_idx: (b, h*w)
+            fg_mask: (b, h*w)
+        """
+
+        # assigned target labels, (b, 1)
+        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
+        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (b, h*w)
+        target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)
+
+        # assigned target boxes, (b, max_num_obj, 4) -> (b, h*w)
+        target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx]
+
+        # assigned target scores
+        target_labels.clamp(0)
+        target_scores = F.one_hot(target_labels, self.num_classes)  # (b, h*w, 80)
+        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)  # (b, h*w, 80)
+        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)
+
+        return target_labels, target_bboxes, target_scores
+    
+
+# -------------------------- Basic Functions --------------------------
+def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
+    """select the positive anchors's center in gt
+    Args:
+        xy_centers (Tensor): shape(bs*n_max_boxes, num_total_anchors, 4)
+        gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
+    Return:
+        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    """
+    n_anchors = xy_centers.size(0)
+    bs, n_max_boxes, _ = gt_bboxes.size()
+    _gt_bboxes = gt_bboxes.reshape([-1, 4])
+    xy_centers = xy_centers.unsqueeze(0).repeat(bs * n_max_boxes, 1, 1)
+    gt_bboxes_lt = _gt_bboxes[:, 0:2].unsqueeze(1).repeat(1, n_anchors, 1)
+    gt_bboxes_rb = _gt_bboxes[:, 2:4].unsqueeze(1).repeat(1, n_anchors, 1)
+    b_lt = xy_centers - gt_bboxes_lt
+    b_rb = gt_bboxes_rb - xy_centers
+    bbox_deltas = torch.cat([b_lt, b_rb], dim=-1)
+    bbox_deltas = bbox_deltas.reshape([bs, n_max_boxes, n_anchors, -1])
+    return (bbox_deltas.min(axis=-1)[0] > eps).to(gt_bboxes.dtype)
+
+
+def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
+    """if an anchor box is assigned to multiple gts,
+        the one with the highest iou will be selected.
+    Args:
+        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+        overlaps (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    Return:
+        target_gt_idx (Tensor): shape(bs, num_total_anchors)
+        fg_mask (Tensor): shape(bs, num_total_anchors)
+        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    """
+    fg_mask = mask_pos.sum(axis=-2)
+    if fg_mask.max() > 1:
+        mask_multi_gts = (fg_mask.unsqueeze(1) > 1).repeat([1, n_max_boxes, 1])
+        max_overlaps_idx = overlaps.argmax(axis=1)
+        is_max_overlaps = F.one_hot(max_overlaps_idx, n_max_boxes)
+        is_max_overlaps = is_max_overlaps.permute(0, 2, 1).to(overlaps.dtype)
+        mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos)
+        fg_mask = mask_pos.sum(axis=-2)
+    target_gt_idx = mask_pos.argmax(axis=-2)
+    return target_gt_idx, fg_mask , mask_pos
+
+
+def iou_calculator(box1, box2, eps=1e-9):
+    """Calculate iou for batch
+    Args:
+        box1 (Tensor): shape(bs, n_max_boxes, 1, 4)
+        box2 (Tensor): shape(bs, 1, num_total_anchors, 4)
+    Return:
+        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    """
+    box1 = box1.unsqueeze(2)  # [N, M1, 4] -> [N, M1, 1, 4]
+    box2 = box2.unsqueeze(1)  # [N, M2, 4] -> [N, 1, M2, 4]
+    px1y1, px2y2 = box1[:, :, :, 0:2], box1[:, :, :, 2:4]
+    gx1y1, gx2y2 = box2[:, :, :, 0:2], box2[:, :, :, 2:4]
+    x1y1 = torch.maximum(px1y1, gx1y1)
+    x2y2 = torch.minimum(px2y2, gx2y2)
+    overlap = (x2y2 - x1y1).clip(0).prod(-1)
+    area1 = (px2y2 - px1y1).clip(0).prod(-1)
+    area2 = (gx2y2 - gx1y1).clip(0).prod(-1)
+    union = area1 + area2 - overlap + eps
+
+    return overlap / union

+ 76 - 43
models/detectors/yolox2/yolox2.py → models/detectors/yolovx/yolovx.py

@@ -1,19 +1,20 @@
 # --------------- Torch components ---------------
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 # --------------- Model components ---------------
-from .yolox2_backbone import build_backbone
-from .yolox2_neck import build_neck
-from .yolox2_pafpn import build_fpn
-from .yolox2_head import build_head
+from .yolovx_backbone import build_backbone
+from .yolovx_neck import build_neck
+from .yolovx_pafpn import build_fpn
+from .yolovx_head import build_head
 
 # --------------- External components ---------------
 from utils.misc import multiclass_nms
 
 
-# YOLOX-2
-class YOLOX2(nn.Module):
+# YOLOvx
+class YOLOvx(nn.Module):
     def __init__(self, 
                  cfg,
                  device, 
@@ -23,11 +24,12 @@ class YOLOX2(nn.Module):
                  trainable = False, 
                  topk = 1000,
                  deploy = False):
-        super(YOLOX2, self).__init__()
+        super(YOLOvx, self).__init__()
         # ---------------------- Basic Parameters ----------------------
         self.cfg = cfg
         self.device = device
         self.stride = cfg['stride']
+        self.reg_max = cfg['reg_max']
         self.num_classes = num_classes
         self.trainable = trainable
         self.conf_thresh = conf_thresh
@@ -37,6 +39,11 @@ class YOLOX2(nn.Module):
         self.head_dim = round(256*cfg['width'])
         
         # ---------------------- Network Parameters ----------------------
+        ## ----------- proj_conv ------------
+        self.proj = nn.Parameter(torch.linspace(0, cfg['reg_max'], cfg['reg_max']), requires_grad=False)
+        self.proj_conv = nn.Conv2d(self.reg_max, 1, kernel_size=1, bias=False)
+        self.proj_conv.weight = nn.Parameter(self.proj.view([1, cfg['reg_max'], 1, 1]).clone().detach(), requires_grad=False)
+
         ## ----------- Backbone -----------
         self.backbone, feats_dim = build_backbone(cfg, trainable&cfg['pretrained'])
 
@@ -57,7 +64,7 @@ class YOLOX2(nn.Module):
                                 for _ in range(len(self.stride))
                               ]) 
         self.reg_preds = nn.ModuleList(
-                            [nn.Conv2d(self.head_dim, 4, kernel_size=1) 
+                            [nn.Conv2d(self.head_dim, 4*cfg['reg_max'], kernel_size=1) 
                                 for _ in range(len(self.stride))
                               ])                 
 
@@ -156,36 +163,45 @@ class YOLOX2(nn.Module):
             reg_pred = self.reg_preds[level](reg_feat)
             
             # anchors: [M, 2]
-            fmp_size = cls_feat.shape[-2:]
-            anchors = self.generate_anchors(level, fmp_size)
+            B, _, H, W = cls_feat.size()
+            anchors = self.generate_anchors(level, [H, W])
             
-            # [1, C, H, W] -> [H, W, C] -> [M, C]
-            cls_pred = cls_pred[0].permute(1, 2, 0).contiguous().view(-1, self.num_classes)
-            reg_pred = reg_pred[0].permute(1, 2, 0).contiguous().view(-1, 4)
-
-            # decode bbox
-            ctr_pred = reg_pred[..., :2] * self.stride[level] + anchors[..., :2]
-            wh_pred = torch.exp(reg_pred[..., 2:]) * self.stride[level]
-            pred_x1y1 = ctr_pred - wh_pred * 0.5
-            pred_x2y2 = ctr_pred + wh_pred * 0.5
-            box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
-
-            all_cls_preds.append(cls_pred)
-            all_box_preds.append(box_pred)
+            # process preds
+            cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
+            reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)
+
+            # ----------------------- Decode bbox -----------------------
+            B, M = reg_pred.shape[:2]
+            # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max] -> [B, 4, M, reg_max]
+            reg_pred = reg_pred.reshape([B, M, 4, self.reg_max])
+            # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
+            reg_pred = reg_pred.permute(0, 3, 2, 1).contiguous()
+            # [B, reg_max, 4, M] -> [B, 1, 4, M]
+            reg_pred = self.proj_conv(F.softmax(reg_pred, dim=1))
+            # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
+            reg_pred = reg_pred.view(B, 4, M).permute(0, 2, 1).contiguous()
+            ## tlbr -> xyxy
+            x1y1_pred = anchors[None] - reg_pred[..., :2] * self.stride[level]
+            x2y2_pred = anchors[None] + reg_pred[..., 2:] * self.stride[level]
+            box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
+
+            # collect preds
+            all_cls_preds.append(cls_pred[0])
+            all_box_preds.append(box_pred[0])
 
         if self.deploy:
+            # no post process
             cls_preds = torch.cat(all_cls_preds, dim=0)
-            box_preds = torch.cat(all_box_preds, dim=0)
-            scores = cls_preds.sigmoid()
-            bboxes = box_preds
+            box_pred = torch.cat(all_box_preds, dim=0)
             # [n_anchors_all, 4 + C]
-            outputs = torch.cat([bboxes, scores], dim=-1)
+            outputs = torch.cat([box_pred, cls_preds.sigmoid()], dim=-1)
 
             return outputs
+
         else:
             # post process
             bboxes, scores, labels = self.post_process(all_cls_preds, all_box_preds)
-        
+            
             return bboxes, scores, labels
 
 
@@ -208,7 +224,9 @@ class YOLOX2(nn.Module):
 
             # ---------------- Preds ----------------
             all_anchors = []
+            all_strides = []
             all_cls_preds = []
+            all_reg_preds = []
             all_box_preds = []
             for level, (cls_feat, reg_feat) in enumerate(zip(cls_feats, reg_feats)):
                 # prediction
@@ -216,29 +234,44 @@ class YOLOX2(nn.Module):
                 reg_pred = self.reg_preds[level](reg_feat)
 
                 B, _, H, W = cls_pred.size()
-                fmp_size = [H, W]
                 # generate anchor boxes: [M, 4]
-                anchors = self.generate_anchors(level, fmp_size)
+                anchors = self.generate_anchors(level, [H, W])
+                # stride tensor: [M, 1]
+                stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride[level]
                 
-                # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
+                # process preds
                 cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
-                reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4)
-
-                # decode bbox
-                ctr_pred = reg_pred[..., :2] * self.stride[level] + anchors[..., :2]
-                wh_pred = torch.exp(reg_pred[..., 2:]) * self.stride[level]
-                pred_x1y1 = ctr_pred - wh_pred * 0.5
-                pred_x2y2 = ctr_pred + wh_pred * 0.5
-                box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
-
+                reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)
+
+                # ----------------------- Decode bbox -----------------------
+                B, M = reg_pred.shape[:2]
+                # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max] -> [B, 4, M, reg_max]
+                reg_pred_ = reg_pred.reshape([B, M, 4, self.reg_max])
+                # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
+                reg_pred_ = reg_pred_.permute(0, 3, 2, 1).contiguous()
+                # [B, reg_max, 4, M] -> [B, 1, 4, M]
+                reg_pred_ = self.proj_conv(F.softmax(reg_pred_, dim=1))
+                # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
+                reg_pred_ = reg_pred_.view(B, 4, M).permute(0, 2, 1).contiguous()
+                ## tlbr -> xyxy
+                x1y1_pred = anchors[None] - reg_pred_[..., :2] * self.stride[level]
+                x2y2_pred = anchors[None] + reg_pred_[..., 2:] * self.stride[level]
+                box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
+
+                # collect preds
                 all_cls_preds.append(cls_pred)
+                all_reg_preds.append(reg_pred)
                 all_box_preds.append(box_pred)
                 all_anchors.append(anchors)
+                all_strides.append(stride_tensor)
             
             # output dict
             outputs = {"pred_cls": all_cls_preds,        # List(Tensor) [B, M, C]
+                       "pred_reg": all_reg_preds,        # List(Tensor) [B, M, 4*(reg_max)]
                        "pred_box": all_box_preds,        # List(Tensor) [B, M, 4]
-                       "anchors": all_anchors,           # List(Tensor) [B, M, 2]
-                       'strides': self.stride}           # List(Int) [8, 16, 32]
-
+                       "anchors": all_anchors,           # List(Tensor) [M, 2]
+                       "strides": self.stride,           # List(Int) = [8, 16, 32]
+                       "stride_tensor": all_strides      # List(Tensor) [M, 1]
+                       }
+            
             return outputs 

+ 2 - 2
models/detectors/yolox2/yolox2_backbone.py → models/detectors/yolovx/yolovx_backbone.py

@@ -1,9 +1,9 @@
 import torch
 import torch.nn as nn
 try:
-    from .yolox2_basic import Conv, ELANBlock, DownSample
+    from .yolovx_basic import Conv, ELANBlock, DownSample
 except:
-    from yolox2_basic import Conv, ELANBlock, DownSample
+    from yolovx_basic import Conv, ELANBlock, DownSample
 
 
 

+ 0 - 0
models/detectors/yolox2/yolox2_basic.py → models/detectors/yolovx/yolovx_basic.py


+ 1 - 1
models/detectors/yolox2/yolox2_head.py → models/detectors/yolovx/yolovx_head.py

@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 
-from .yolox2_basic import Conv
+from .yolovx_basic import Conv
 
 
 class SingleLevelHead(nn.Module):

+ 1 - 1
models/detectors/yolox2/yolox2_neck.py → models/detectors/yolovx/yolovx_neck.py

@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-from .yolox2_basic import Conv
+from .yolovx_basic import Conv
 
 
 # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher

+ 1 - 1
models/detectors/yolox2/yolox2_pafpn.py → models/detectors/yolovx/yolovx_pafpn.py

@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .yolox2_basic import (Conv, build_reduce_layer, build_downsample_layer, build_fpn_block)
+from .yolovx_basic import (Conv, build_reduce_layer, build_downsample_layer, build_fpn_block)
 
 
 # YOLO-Style PaFPN

+ 0 - 151
models/detectors/yolox2/loss.py

@@ -1,151 +0,0 @@
-import torch
-import torch.nn.functional as F
-from .matcher import AlignedSimOTA
-from utils.box_ops import get_ious
-from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
-
-
-
-class Criterion(object):
-    def __init__(self, 
-                 cfg, 
-                 device, 
-                 num_classes=80):
-        self.cfg = cfg
-        self.device = device
-        self.num_classes = num_classes
-        # loss weight
-        self.loss_cls_weight = cfg['loss_cls_weight']
-        self.loss_box_weight = cfg['loss_box_weight']
-        # matcher
-        matcher_config = cfg['matcher']
-        self.matcher = AlignedSimOTA(
-            num_classes=num_classes,
-            soft_center_radius=matcher_config['soft_center_radius'],
-            topk=matcher_config['topk_candicate'],
-            iou_weight=matcher_config['iou_weight']
-            )
-     
-     
-    def loss_classes(self, pred_cls, target, beta=2.0):
-        """
-            Quality Focal Loss
-            pred_cls: (torch.Tensor): [N, C]。
-            target:   (tuple([torch.Tensor], [torch.Tensor])): label -> (N,), score -> (N,)
-        """
-        label, score = target
-        pred_sigmoid = pred_cls.sigmoid()
-        scale_factor = pred_sigmoid
-        zerolabel = scale_factor.new_zeros(pred_cls.shape)
-
-        ce_loss = F.binary_cross_entropy_with_logits(
-            pred_cls, zerolabel, reduction='none') * scale_factor.pow(beta)
-        
-        bg_class_ind = pred_cls.shape[-1]
-        pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1)
-        pos_label = label[pos].long()
-
-        scale_factor = score[pos] - pred_sigmoid[pos, pos_label]
-
-        ce_loss[pos, pos_label] = F.binary_cross_entropy_with_logits(
-            pred_cls[pos, pos_label], score[pos],
-            reduction='none') * scale_factor.abs().pow(beta)
-
-        return ce_loss
-
-
-    def loss_bboxes(self, pred_box, gt_box):
-        # regression loss
-        ious = get_ious(pred_box, gt_box, "xyxy", 'giou')
-        loss_box = 1.0 - ious
-
-        return loss_box
-
-
-    def __call__(self, outputs, targets):        
-        """
-            outputs['pred_cls']: List(Tensor) [B, M, C]
-            outputs['pred_box']: List(Tensor) [B, M, 4]
-            outputs['strides']: List(Int) [8, 16, 32] output stride
-            targets: (List) [dict{'boxes': [...], 
-                                 'labels': [...], 
-                                 'orig_size': ...}, ...]
-        """
-        bs = outputs['pred_cls'][0].shape[0]
-        device = outputs['pred_cls'][0].device
-        fpn_strides = outputs['strides']
-        anchors = outputs['anchors']
-
-        # preds: [B, M, C]
-        cls_preds = torch.cat(outputs['pred_cls'], dim=1)
-        box_preds = torch.cat(outputs['pred_box'], dim=1)
-
-        cls_targets = []
-        box_targets = []
-        assign_metrics = []
-        for batch_idx in range(bs):
-            tgt_labels = targets[batch_idx]["labels"].to(device)  # [N,]
-            tgt_bboxes = targets[batch_idx]["boxes"].to(device)   # [N, 4]
-            # label assignment
-            assigned_result = self.matcher(fpn_strides=fpn_strides,
-                                           anchors=anchors,
-                                           pred_cls=cls_preds[batch_idx].detach(),
-                                           pred_box=box_preds[batch_idx].detach(),
-                                           gt_labels=tgt_labels,
-                                           gt_bboxes=tgt_bboxes
-                                           )
-            cls_targets.append(assigned_result['assigned_labels'])
-            box_targets.append(assigned_result['assigned_bboxes'])
-            assign_metrics.append(assigned_result['assign_metrics'])
-
-        cls_targets = torch.cat(cls_targets, dim=0)
-        box_targets = torch.cat(box_targets, dim=0)
-        assign_metrics = torch.cat(assign_metrics, dim=0)
-
-        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
-        bg_class_ind = self.num_classes
-        pos_inds = ((cls_targets >= 0)
-                    & (cls_targets < bg_class_ind)).nonzero().squeeze(1)
-        # num_fgs = assign_metrics.sum()
-        num_fgs = pos_inds.size(0)
-
-        if is_dist_avail_and_initialized():
-            torch.distributed.all_reduce(num_fgs)
-        num_fgs = max(num_fgs / get_world_size(), 1.0)
-        
-        # cls loss
-        cls_preds = cls_preds.view(-1, self.num_classes)
-        loss_cls = self.loss_classes(cls_preds, (cls_targets, assign_metrics))
-        loss_cls = loss_cls.sum() / num_fgs
-
-        # regression loss
-        box_preds_pos = box_preds.view(-1, 4)[pos_inds]
-        box_targets_pos = box_targets[pos_inds]
-        loss_box = self.loss_bboxes(box_preds_pos, box_targets_pos)
-        loss_box = loss_box.sum() / box_preds_pos.shape[0]
-
-        # total loss
-        losses = self.loss_cls_weight * loss_cls + \
-                 self.loss_box_weight * loss_box
-
-        loss_dict = dict(
-                loss_cls = loss_cls,
-                loss_box = loss_box,
-                losses = losses
-        )
-
-        return loss_dict
-    
-
-def build_criterion(cfg, device, num_classes):
-    criterion = Criterion(
-        cfg=cfg,
-        device=device,
-        num_classes=num_classes
-        )
-
-    return criterion
-
-
-if __name__ == "__main__":
-    pass

+ 0 - 176
models/detectors/yolox2/matcher.py

@@ -1,176 +0,0 @@
-# ---------------------------------------------------------------------
-# Copyright (c) OpenMMLab. All rights reserved.
-# ---------------------------------------------------------------------
-
-
-import torch
-import torch.nn.functional as F
-from utils.box_ops import *
-
-
-# RTMDet SimOTA
-class AlignedSimOTA(object):
-    """
-        This code referenced to https://github.com/open-mmlab/mmyolo/models/task_modules/assigners/batch_dsl_assigner.py
-    """
-    def __init__(self, num_classes, soft_center_radius=3.0, topk=13, iou_weight=3.0):
-        self.num_classes = num_classes
-        self.soft_center_radius = soft_center_radius
-        self.topk = topk
-        self.iou_weight = iou_weight
-
-
-    @torch.no_grad()
-    def __call__(self, 
-                 fpn_strides, 
-                 anchors, 
-                 pred_cls, 
-                 pred_box, 
-                 gt_labels,
-                 gt_bboxes):
-        # [M,]
-        strides = torch.cat([torch.ones_like(anchor_i[:, 0]) * stride_i
-                                for stride_i, anchor_i in zip(fpn_strides, anchors)], dim=-1)
-        # List[F, M, 2] -> [M, 2]
-        anchors = torch.cat(anchors, dim=0)
-        num_gt = len(gt_labels)
-
-        # check gt
-        if num_gt == 0 or gt_bboxes.max().item() == 0.:
-            return {
-                'assigned_labels': gt_labels.new_full(pred_cls[..., 0].shape,
-                                                      self.num_classes,
-                                                      dtype=torch.long),
-                'assigned_bboxes': gt_bboxes.new_full(pred_box.shape, 0),
-                'assign_metrics': gt_bboxes.new_full(pred_cls[..., 0].shape, 0)
-            }
-        
-        # get inside points: [N, M]
-        is_in_gt = self.find_inside_points(gt_bboxes, anchors)
-        valid_mask = is_in_gt.sum(dim=0) > 0  # [M,]
-
-        # ----------------------------------- soft center prior -----------------------------------
-        gt_center = (gt_bboxes[..., :2] + gt_bboxes[..., 2:]) / 2.0
-        distance = (anchors.unsqueeze(0) - gt_center.unsqueeze(1)
-                    ).pow(2).sum(-1).sqrt() / strides.unsqueeze(0)  # [N, M]
-        distance = distance * valid_mask.unsqueeze(0)
-        soft_center_prior = torch.pow(10, distance - self.soft_center_radius)
-
-        # ----------------------------------- regression cost -----------------------------------
-        pair_wise_ious, _ = box_iou(gt_bboxes, pred_box)  # [N, M]
-        pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8) * self.iou_weight
-
-        # ----------------------------------- classification cost -----------------------------------
-        ## select the predicted scores corresponded to the gt_labels
-        pairwise_pred_scores = pred_cls.permute(1, 0)  # [M, C] -> [C, M]
-        pairwise_pred_scores = pairwise_pred_scores[gt_labels.long(), :].float()   # [N, M]
-        ## scale factor
-        scale_factor = (pair_wise_ious - pairwise_pred_scores.sigmoid()).abs().pow(2.0)
-        ## cls cost
-        pair_wise_cls_loss = F.binary_cross_entropy_with_logits(
-            pairwise_pred_scores, pair_wise_ious,
-            reduction="none") * scale_factor # [N, M]
-            
-        del pairwise_pred_scores
-
-        ## foreground cost matrix
-        cost_matrix = pair_wise_cls_loss + pair_wise_ious_loss + soft_center_prior
-        max_pad_value = torch.ones_like(cost_matrix) * 1e9
-        cost_matrix = torch.where(valid_mask[None].repeat(num_gt, 1),   # [N, M]
-                                  cost_matrix, max_pad_value)
-
-        # ----------------------------------- dynamic label assignment -----------------------------------
-        (
-            matched_pred_ious,
-            matched_gt_inds,
-            fg_mask_inboxes
-        ) = self.dynamic_k_matching(
-            cost_matrix,
-            pair_wise_ious,
-            num_gt
-            )
-        del pair_wise_cls_loss, cost_matrix, pair_wise_ious, pair_wise_ious_loss
-
-        # -----------------------------------process assigned labels -----------------------------------
-        assigned_labels = gt_labels.new_full(pred_cls[..., 0].shape,
-                                             self.num_classes)  # [M,]
-        assigned_labels[fg_mask_inboxes] = gt_labels[matched_gt_inds].squeeze(-1)
-        assigned_labels = assigned_labels.long()  # [M,]
-
-        assigned_bboxes = gt_bboxes.new_full(pred_box.shape, 0)        # [M, 4]
-        assigned_bboxes[fg_mask_inboxes] = gt_bboxes[matched_gt_inds]  # [M, 4]
-
-        assign_metrics = gt_bboxes.new_full(pred_cls[..., 0].shape, 0) # [M, 4]
-        assign_metrics[fg_mask_inboxes] = matched_pred_ious            # [M, 4]
-
-        assigned_dict = dict(
-            assigned_labels=assigned_labels,
-            assigned_bboxes=assigned_bboxes,
-            assign_metrics=assign_metrics
-            )
-        
-        return assigned_dict
-
-
-    def find_inside_points(self, gt_bboxes, anchors):
-        """
-            gt_bboxes: Tensor -> [N, 2]
-            anchors:   Tensor -> [M, 2]
-        """
-        num_anchors = anchors.shape[0]
-        num_gt = gt_bboxes.shape[0]
-
-        anchors_expand = anchors.unsqueeze(0).repeat(num_gt, 1, 1)           # [N, M, 2]
-        gt_bboxes_expand = gt_bboxes.unsqueeze(1).repeat(1, num_anchors, 1)  # [N, M, 4]
-
-        # offset
-        lt = anchors_expand - gt_bboxes_expand[..., :2]
-        rb = gt_bboxes_expand[..., 2:] - anchors_expand
-        bbox_deltas = torch.cat([lt, rb], dim=-1)
-
-        is_in_gts = bbox_deltas.min(dim=-1).values > 0
-
-        return is_in_gts
-    
-
-    def dynamic_k_matching(self, cost_matrix, pairwise_ious, num_gt):
-        """Use IoU and matching cost to calculate the dynamic top-k positive
-        targets.
-
-        Args:
-            cost_matrix (Tensor): Cost matrix.
-            pairwise_ious (Tensor): Pairwise iou matrix.
-            num_gt (int): Number of gt.
-            valid_mask (Tensor): Mask for valid bboxes.
-        Returns:
-            tuple: matched ious and gt indexes.
-        """
-        matching_matrix = torch.zeros_like(cost_matrix, dtype=torch.uint8)
-        # select candidate topk ious for dynamic-k calculation
-        candidate_topk = min(self.topk, pairwise_ious.size(1))
-        topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=1)
-        # calculate dynamic k for each gt
-        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
-
-        # sorting the batch cost matirx is faster than topk
-        _, sorted_indices = torch.sort(cost_matrix, dim=1)
-        for gt_idx in range(num_gt):
-            topk_ids = sorted_indices[gt_idx, :dynamic_ks[gt_idx]]
-            matching_matrix[gt_idx, :][topk_ids] = 1
-
-        del topk_ious, dynamic_ks, topk_ids
-
-        prior_match_gt_mask = matching_matrix.sum(0) > 1
-        if prior_match_gt_mask.sum() > 0:
-            cost_min, cost_argmin = torch.min(
-                cost_matrix[:, prior_match_gt_mask], dim=0)
-            matching_matrix[:, prior_match_gt_mask] *= 0
-            matching_matrix[cost_argmin, prior_match_gt_mask] = 1
-
-        # get foreground mask inside box and center prior
-        fg_mask_inboxes = matching_matrix.sum(0) > 0
-        matched_pred_ious = (matching_matrix *
-                             pairwise_ious).sum(0)[fg_mask_inboxes]
-        matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
-
-        return matched_pred_ious, matched_gt_inds, fg_mask_inboxes

+ 1 - 1
train.sh

@@ -3,7 +3,7 @@ python train.py \
         --cuda \
         -d coco \
         --root /mnt/share/ssd2/dataset/ \
-        -m yolox2_n \
+        -m yolovx_n \
         -bs 16 \
         -size 640 \
         --wp_epoch 3 \