Ver Fonte

redesign YOLOX-Plus with TAL & DFL & VFL

yjh0410 há 2 anos atrás
pai
commit
5708765a4c

+ 4 - 2
config/yolox_plus_config.py

@@ -36,6 +36,7 @@ yolox_plus_cfg = {
         'num_cls_head': 2,
         'num_reg_head': 2,
         'head_depthwise': False,
+        'reg_max': 16,
         # ---------------- Train config ----------------
         ## input
         'multi_scale': [0.5, 1.5],   # 320 -> 960
@@ -47,9 +48,10 @@ yolox_plus_cfg = {
                     'beta': 6.0},
         # ---------------- Loss config ----------------
         ## loss weight
-        'cls_loss': 'qfl',
+        'cls_loss': 'vfl',
         'loss_cls_weight': 1.0,
-        'loss_box_weight': 2.0,
+        'loss_iou_weight': 2.0,
+        'loss_dfl_weight': 1.0,
         # ---------------- Train config ----------------
         ## close strong augmentation
         'no_aug_epoch': 20,

+ 124 - 64
models/detectors/yolox_plus/loss.py

@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from .matcher import TaskAlignedAssigner
-from utils.box_ops import bbox_iou
+from utils.box_ops import bbox2dist, bbox_iou
 
 
 
@@ -14,12 +14,15 @@ class Criterion(object):
         self.cfg = cfg
         self.device = device
         self.num_classes = num_classes
+        self.reg_max = cfg['reg_max']
+        self.use_dfl = cfg['reg_max'] > 1
         # loss
-        self.cls_lossf = ClassificationLoss(cfg)
-        self.reg_lossf = RegressionLoss(num_classes)
+        self.cls_lossf = ClassificationLoss(cfg, reduction='none')
+        self.reg_lossf = RegressionLoss(num_classes, cfg['reg_max'] - 1, self.use_dfl)
         # loss weight
         self.loss_cls_weight = cfg['loss_cls_weight']
-        self.loss_box_weight = cfg['loss_box_weight']
+        self.loss_iou_weight = cfg['loss_iou_weight']
+        self.loss_dfl_weight = cfg['loss_dfl_weight']
         # matcher
         matcher_config = cfg['matcher']
         self.matcher = TaskAlignedAssigner(
@@ -30,7 +33,7 @@ class Criterion(object):
             )
 
 
-    def __call__(self, outputs, targets, epoch=0):        
+    def __call__(self, outputs, targets):        
         """
             outputs['pred_cls']: List(Tensor) [B, M, C]
             outputs['pred_reg']: List(Tensor) [B, M, 4*reg_max]
@@ -44,12 +47,14 @@ class Criterion(object):
         """
         bs = outputs['pred_cls'][0].shape[0]
         device = outputs['pred_cls'][0].device
+        strides = outputs['stride_tensor']
         anchors = outputs['anchors']
         anchors = torch.cat(anchors, dim=0)
         num_anchors = anchors.shape[0]
 
         # preds: [B, M, C]
         cls_preds = torch.cat(outputs['pred_cls'], dim=1)
+        reg_preds = torch.cat(outputs['pred_reg'], dim=1)
         box_preds = torch.cat(outputs['pred_box'], dim=1)
         
         # label assignment
@@ -65,15 +70,15 @@ class Criterion(object):
             # check target
             if len(tgt_labels) == 0 or tgt_boxs.max().item() == 0.:
                 # There is no valid gt
-                gt_label = cls_preds.new_full((1, num_anchors), self.num_classes).long()  #[1, M,]
+                fg_mask = cls_preds.new_zeros(1, num_anchors).bool()               #[1, M,]
+                gt_label = cls_preds.new_zeros((1, num_anchors,))                  #[1, M,]
                 gt_score = cls_preds.new_zeros((1, num_anchors, self.num_classes)) #[1, M, C]
                 gt_box = cls_preds.new_zeros((1, num_anchors, 4))                  #[1, M, 4]
-                fg_mask = cls_preds.new_zeros(1, num_anchors).bool()               #[1, M,]
             else:
                 tgt_labels = tgt_labels[None, :, None]      # [1, Mp, 1]
                 tgt_boxs = tgt_boxs[None]                   # [1, Mp, 4]
                 (
-                    gt_label,   #[1, M,]
+                    gt_label,   #[1, M]
                     gt_box,     #[1, M, 4]
                     gt_score,   #[1, M, C]
                     fg_mask,    #[1, M,]
@@ -92,102 +97,144 @@ class Criterion(object):
 
         # List[B, 1, M, C] -> Tensor[B, M, C] -> Tensor[BM, C]
         fg_masks = torch.cat(fg_masks, 0).view(-1)                                    # [BM,]
-        gt_label_targets = torch.cat(gt_label_targets, 0).view(-1,)                   # [BM,]
+        gt_label_targets = torch.cat(gt_label_targets, 0).view(-1)                    # [BM,]
         gt_score_targets = torch.cat(gt_score_targets, 0).view(-1, self.num_classes)  # [BM, C]
         gt_bbox_targets = torch.cat(gt_bbox_targets, 0).view(-1, 4)                   # [BM, 4]
-        num_fgs = max(gt_score_targets.sum(), 1)
-       
+        
         # cls loss
         cls_preds = cls_preds.view(-1, self.num_classes)
-        loss_cls = self.cls_lossf(cls_preds, gt_label_targets, gt_score_targets)
+        gt_label_targets = torch.where(
+            fg_masks > 0,
+            gt_label_targets,
+            torch.full_like(gt_label_targets, self.num_classes)
+            )
+        gt_labels_one_hot = F.one_hot(gt_label_targets.long(), self.num_classes + 1)[..., :-1]
+        loss_cls = self.cls_lossf(cls_preds, gt_score_targets, gt_labels_one_hot)
 
         # reg loss
+        anchors = anchors[None].repeat(bs, 1, 1).view(-1, 2)                           # [BM, 2]
+        strides = torch.cat(strides, dim=0).unsqueeze(0).repeat(bs, 1, 1).view(-1, 1)  # [BM, 1]
         bbox_weight = gt_score_targets[fg_masks].sum(-1, keepdim=True)                 # [BM, 1]
+        reg_preds = reg_preds.view(-1, 4*self.reg_max)                                 # [BM, 4*reg_max]
         box_preds = box_preds.view(-1, 4)                                              # [BM, 4]
-        loss_box = self.reg_lossf(box_preds, gt_bbox_targets, bbox_weight, fg_masks)
+        loss_iou, loss_dfl = self.reg_lossf(
+            pred_regs = reg_preds,
+            pred_boxs = box_preds,
+            anchors = anchors,
+            gt_boxs = gt_bbox_targets,
+            bbox_weight = bbox_weight,
+            fg_masks = fg_masks,
+            strides = strides,
+            )
         
         # normalize loss
-        loss_cls = loss_cls.sum() / num_fgs
-        loss_box = loss_box.sum() / num_fgs
+        gt_score_targets_sum = max(gt_score_targets.sum(), 1)
+        loss_cls = loss_cls.sum() / gt_score_targets_sum
+        loss_iou = loss_iou.sum() / gt_score_targets_sum
+        loss_dfl = loss_dfl.sum() / gt_score_targets_sum
 
         # total loss
         losses = loss_cls * self.loss_cls_weight + \
-                 loss_box * self.loss_box_weight
-        
-        loss_dict = dict(
-                loss_cls = loss_cls,
-                loss_box = loss_box,
-                losses = losses
-        )
+                 loss_iou * self.loss_iou_weight
+        if self.use_dfl:
+            losses += loss_dfl * self.loss_dfl_weight
+            loss_dict = dict(
+                    loss_cls = loss_cls,
+                    loss_iou = loss_iou,
+                    loss_dfl = loss_dfl,
+                    losses = losses
+            )
+        else:
+            loss_dict = dict(
+                    loss_cls = loss_cls,
+                    loss_iou = loss_iou,
+                    losses = losses
+            )
 
         return loss_dict
     
 
 class ClassificationLoss(nn.Module):
-    def __init__(self, cfg):
+    def __init__(self, cfg, reduction='none'):
         super(ClassificationLoss, self).__init__()
         self.cfg = cfg
+        self.reduction = reduction
+        # For VFL
+        self.alpha = 0.75
+        self.gamma = 2.0
+
+    def varifocalloss(self, pred_logits, gt_score, gt_label, alpha=0.75, gamma=2.0):
+        focal_weight = alpha * pred_logits.sigmoid().pow(gamma) * (1 - gt_label) + gt_score * gt_label
+        with torch.cuda.amp.autocast(enabled=False):
+            bce_loss = F.binary_cross_entropy_with_logits(
+                pred_logits.float(), gt_score.float(), reduction='none')
+            loss = bce_loss * focal_weight
+
+            if self.reduction == 'sum':
+                loss = loss.sum()
+            elif self.reduction == 'mean':
+                loss = loss.mean()
 
-
-    def quality_focal_loss(self, pred_cls, gt_label, gt_score, beta=2.0):
-        # Quality FocalLoss
-        """
-            pred_cls: (torch.Tensor): [N, C]
-            gt_label: (torch.Tensor): [N,]
-            gt_score: (torch.Tensor): [N, C]
-        """
-        gt_label = gt_label.long()
-        gt_score = gt_score[torch.arange(gt_label.shape[0]), gt_label]
-
-        pred_sigmoid = pred_cls.sigmoid()
-        scale_factor = pred_sigmoid
-        zerolabel = scale_factor.new_zeros(pred_cls.shape)
-
-        ce_loss = F.binary_cross_entropy_with_logits(
-            pred_cls, zerolabel, reduction='none') * scale_factor.pow(beta)
-        
-        bg_class_ind = pred_cls.shape[-1]
-        pos = ((gt_label >= 0) & (gt_label < bg_class_ind)).nonzero().squeeze(1)
-        pos_label = gt_label[pos].long()
-
-        scale_factor = gt_score[pos] - pred_sigmoid[pos, pos_label]
-
-        ce_loss[pos, pos_label] = F.binary_cross_entropy_with_logits(
-            pred_cls[pos, pos_label], gt_score[pos],
-            reduction='none') * scale_factor.abs().pow(beta)
-
-        return ce_loss
-    
+        return loss
 
     def binary_cross_entropy(self, pred_logits, gt_score):
         loss = F.binary_cross_entropy_with_logits(
-            pred_logits, gt_score, reduction='none')
+            pred_logits.float(), gt_score.float(), reduction='none')
+
+        if self.reduction == 'sum':
+            loss = loss.sum()
+        elif self.reduction == 'mean':
+            loss = loss.mean()
 
         return loss
 
 
-    def forward(self, pred_logits, gt_label, gt_score):
+    def forward(self, pred_logits, gt_score, gt_label):
         if self.cfg['cls_loss'] == 'bce':
-            loss = self.binary_cross_entropy(pred_logits, gt_score)
-        elif self.cfg['cls_loss'] == 'qfl':
-            loss = self.quality_focal_loss(pred_logits, gt_label, gt_score)
-            
-        return loss
+            return self.binary_cross_entropy(pred_logits, gt_score)
+        elif self.cfg['cls_loss'] == 'vfl':
+            return self.varifocalloss(pred_logits, gt_score, gt_label, self.alpha, self.gamma)
 
 
 class RegressionLoss(nn.Module):
-    def __init__(self, num_classes):
+    def __init__(self, num_classes, reg_max, use_dfl):
         super(RegressionLoss, self).__init__()
         self.num_classes = num_classes
+        self.reg_max = reg_max
+        self.use_dfl = use_dfl
+
+
+    def df_loss(self, pred_regs, target):
+        gt_left = target.to(torch.long)
+        gt_right = gt_left + 1
+        weight_left = gt_right.to(torch.float) - target
+        weight_right = 1 - weight_left
+        # loss left
+        loss_left = F.cross_entropy(
+            pred_regs.view(-1, self.reg_max + 1),
+            gt_left.view(-1),
+            reduction='none').view(gt_left.shape) * weight_left
+        # loss right
+        loss_right = F.cross_entropy(
+            pred_regs.view(-1, self.reg_max + 1),
+            gt_right.view(-1),
+            reduction='none').view(gt_left.shape) * weight_right
+
+        loss = (loss_left + loss_right).mean(-1, keepdim=True)
+        
+        return loss
 
 
-    def forward(self, pred_boxs, gt_boxs, bbox_weight, fg_masks):
+    def forward(self, pred_regs, pred_boxs, anchors, gt_boxs, bbox_weight, fg_masks, strides):
         """
         Input:
+            pred_regs: (Tensor) [BM, 4*(reg_max + 1)]
             pred_boxs: (Tensor) [BM, 4]
+            anchors: (Tensor) [BM, 2]
             gt_boxs: (Tensor) [BM, 4]
             bbox_weight: (Tensor) [BM, 1]
             fg_masks: (Tensor) [BM,]
+            strides: (Tensor) [BM, 1]
         """
         # select positive samples mask
         num_pos = fg_masks.sum()
@@ -203,10 +250,23 @@ class RegressionLoss(nn.Module):
                             CIoU=True)
             loss_iou = (1.0 - ious) * bbox_weight
                
+            # dfl loss
+            if self.use_dfl:
+                pred_regs_pos = pred_regs[fg_masks]
+                gt_boxs_s = gt_boxs / strides
+                anchors_s = anchors / strides
+                gt_ltrb_s = bbox2dist(anchors_s, gt_boxs_s, self.reg_max)
+                gt_ltrb_s_pos = gt_ltrb_s[fg_masks]
+                loss_dfl = self.df_loss(pred_regs_pos, gt_ltrb_s_pos)
+                loss_dfl *= bbox_weight
+            else:
+                loss_dfl = pred_regs.sum() * 0.
+
         else:
-            loss_iou = pred_boxs.sum() * 0.
+            loss_iou = pred_regs.sum() * 0.
+            loss_dfl = pred_regs.sum() * 0.
 
-        return loss_iou
+        return loss_iou, loss_dfl
 
 
 def build_criterion(cfg, device, num_classes):

+ 62 - 30
models/detectors/yolox_plus/yolox_plus.py

@@ -1,6 +1,7 @@
 # --------------- Torch components ---------------
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 # --------------- Model components ---------------
 from .yolox_plus_backbone import build_backbone
@@ -28,6 +29,7 @@ class YoloxPlus(nn.Module):
         self.cfg = cfg
         self.device = device
         self.stride = cfg['stride']
+        self.reg_max = cfg['reg_max']
         self.num_classes = num_classes
         self.trainable = trainable
         self.conf_thresh = conf_thresh
@@ -36,6 +38,11 @@ class YoloxPlus(nn.Module):
         self.deploy = deploy
         
         # ---------------------- Network Parameters ----------------------
+        ## ----------- proj_conv ------------
+        self.proj = nn.Parameter(torch.linspace(0, cfg['reg_max'], cfg['reg_max']), requires_grad=False)
+        self.proj_conv = nn.Conv2d(self.reg_max, 1, kernel_size=1, bias=False)
+        self.proj_conv.weight = nn.Parameter(self.proj.view([1, cfg['reg_max'], 1, 1]).clone().detach(), requires_grad=False)
+
         ## ----------- Backbone -----------
         self.backbone, feats_dim = build_backbone(cfg, trainable&cfg['pretrained'])
 
@@ -143,31 +150,39 @@ class YoloxPlus(nn.Module):
             cls_pred, reg_pred = head(feat)
 
             # anchors: [M, 2]
-            fmp_size = cls_pred.shape[-2:]
+            B, _, H, W = reg_pred.size()
+            fmp_size = [H, W]
             anchors = self.generate_anchors(level, fmp_size)
 
-            # [1, C, H, W] -> [H, W, C] -> [M, C]
-            cls_pred = cls_pred[0].permute(1, 2, 0).contiguous().view(-1, self.num_classes)
-            reg_pred = reg_pred[0].permute(1, 2, 0).contiguous().view(-1, 4)
-
-            # decode bbox
-            ctr_pred = reg_pred[..., :2] * self.stride[level] + anchors[..., :2]
-            wh_pred = torch.exp(reg_pred[..., 2:]) * self.stride[level]
-            pred_x1y1 = ctr_pred - wh_pred * 0.5
-            pred_x2y2 = ctr_pred + wh_pred * 0.5
-            box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
+            # process preds
+            cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
+            reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)
+
+            # ----------------------- Decode bbox -----------------------
+            B, M = reg_pred.shape[:2]
+            # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max]
+            reg_pred = reg_pred.reshape([B, M, 4, self.reg_max])
+            # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
+            reg_pred = reg_pred.permute(0, 3, 2, 1).contiguous()
+            # [B, reg_max, 4, M] -> [B, 1, 4, M]
+            reg_pred = self.proj_conv(F.softmax(reg_pred, dim=1))
+            # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
+            reg_pred = reg_pred.view(B, 4, M).permute(0, 2, 1).contiguous()
+            ## ltrb -> xyxy
+            x1y1_pred = anchors[None] - reg_pred[..., :2] * self.stride[level]
+            x2y2_pred = anchors[None] + reg_pred[..., 2:] * self.stride[level]
+            box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
 
             # collect preds
-            all_cls_preds.append(cls_pred)
-            all_box_preds.append(box_pred)
+            all_cls_preds.append(cls_pred[0])
+            all_box_preds.append(box_pred[0])
 
         if self.deploy:
+            # no post process
             cls_preds = torch.cat(all_cls_preds, dim=0)
-            box_preds = torch.cat(all_box_preds, dim=0)
-            scores = cls_preds.sigmoid()
-            bboxes = box_preds
+            box_pred = torch.cat(all_box_preds, dim=0)
             # [n_anchors_all, 4 + C]
-            outputs = torch.cat([bboxes, scores], dim=-1)
+            outputs = torch.cat([box_pred, cls_preds.sigmoid()], dim=-1)
 
             return outputs
 
@@ -195,36 +210,53 @@ class YoloxPlus(nn.Module):
             # ---------------- Heads ----------------
             all_anchors = []
             all_cls_preds = []
+            all_reg_preds = []
             all_box_preds = []
+            all_strides = []
             for level, (feat, head) in enumerate(zip(pyramid_feats, self.det_heads)):
                 # ---------------- Pred ----------------
                 cls_pred, reg_pred = head(feat)
 
-                # generate anchor boxes: [M, 4]
                 B, _, H, W = cls_pred.size()
                 fmp_size = [H, W]
+                # generate anchors: [M, 2]
                 anchors = self.generate_anchors(level, fmp_size)
+                # stride tensor: [M, 1]
+                stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride[level]
                 
                 # process preds
-                # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
                 cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
-                reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4)
-
-                # decode bbox
-                ctr_pred = reg_pred[..., :2] * self.stride[level] + anchors[..., :2]
-                wh_pred = torch.exp(reg_pred[..., 2:]) * self.stride[level]
-                pred_x1y1 = ctr_pred - wh_pred * 0.5
-                pred_x2y2 = ctr_pred + wh_pred * 0.5
-                box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
-
+                reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)
+
+                # ----------------------- Decode bbox -----------------------
+                B, M = reg_pred.shape[:2]
+                # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max]
+                reg_pred_ = reg_pred.reshape([B, M, 4, self.reg_max])
+                # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
+                reg_pred_ = reg_pred_.permute(0, 3, 2, 1).contiguous()
+                # [B, reg_max, 4, M] -> [B, 1, 4, M]
+                reg_pred_ = self.proj_conv(F.softmax(reg_pred_, dim=1))
+                # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
+                reg_pred_ = reg_pred_.view(B, 4, M).permute(0, 2, 1).contiguous()
+                ## ltrb -> xyxy
+                x1y1_pred = anchors[None] - reg_pred_[..., :2] * self.stride[level]
+                x2y2_pred = anchors[None] + reg_pred_[..., 2:] * self.stride[level]
+                box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
+
+                # collect preds
                 all_cls_preds.append(cls_pred)
+                all_reg_preds.append(reg_pred)
                 all_box_preds.append(box_pred)
                 all_anchors.append(anchors)
+                all_strides.append(stride_tensor)
             
             # output dict
             outputs = {"pred_cls": all_cls_preds,        # List(Tensor) [B, M, C]
+                       "pred_reg": all_reg_preds,        # List(Tensor) [B, M, 4*(reg_max)]
                        "pred_box": all_box_preds,        # List(Tensor) [B, M, 4]
-                       "anchors": all_anchors,           # List(Tensor) [B, M, 2]
-                       'strides': self.stride}           # List(Int) [8, 16, 32]
+                       "anchors": all_anchors,           # List(Tensor) [M, 2]
+                       "strides": self.stride,           # List(Int) = [8, 16, 32]
+                       "stride_tensor": all_strides      # List(Tensor) [M, 1]
+                       }
             
             return outputs 

+ 1 - 1
models/detectors/yolox_plus/yolox_plus_head.py

@@ -59,7 +59,7 @@ class DecoupledHead(nn.Module):
 
         ## Pred
         self.cls_pred = nn.Conv2d(self.cls_out_dim, num_classes, kernel_size=1) 
-        self.reg_pred = nn.Conv2d(self.reg_out_dim, 4, kernel_size=1) 
+        self.reg_pred = nn.Conv2d(self.reg_out_dim, 4*cfg['reg_max'], kernel_size=1) 
 
 
     def forward(self, x):