Browse Source

modify matcher & criterion

yjh0410 1 year ago
parent
commit
07f0d09089
2 changed files with 84 additions and 89 deletions
  1. 33 16
      yolo/models/yolov6/loss.py
  2. 51 73
      yolo/models/yolov6/matcher.py

+ 33 - 16
yolo/models/yolov6/loss.py

@@ -23,16 +23,26 @@ class SetCriterion(object):
                                            beta            = cfg.tal_beta
                                            )
 
-    def loss_classes(self, pred_logits, gt_score):
-        alpha, gamma = 0.75, 2.0
-        pred_sigmoid = pred_logits.sigmoid()
-        focal_weight = gt_score * (gt_score > 0.0).float() + \
-            alpha * (pred_sigmoid - gt_score).abs().pow(gamma) * \
-            (gt_score <= 0.0).float()
+    def loss_classes(self, pred_cls, labels, scores):
+        # compute bce loss
+        alpha = 0.75
+        gamma = 2.0
+        # pred and target should be of the same size
+        bg_class_ind = pred_cls.shape[-1]
+        pos_inds = ((labels >= 0) & (labels < bg_class_ind)).nonzero().squeeze(1)
+
+        new_scores = pred_cls.new_zeros(pred_cls.shape)
+        pos_labels = labels[pos_inds]
+        new_scores[pos_inds, pos_labels] = scores[pos_inds].clone().detach()
+
+        pred_sigmoid = pred_cls.sigmoid()
+        focal_weight = new_scores * (new_scores > 0.0).float() + \
+            alpha * (pred_sigmoid - new_scores).abs().pow(gamma) * \
+            (new_scores <= 0.0).float()
         
         loss_cls = F.binary_cross_entropy_with_logits(
-            pred_logits, gt_score, reduction='none') * focal_weight
-
+            pred_cls, new_scores, reduction='none') * focal_weight
+    
         return loss_cls
     
     def loss_bboxes(self, pred_box, gt_box, bbox_weight):
@@ -62,12 +72,13 @@ class SetCriterion(object):
         anchors = torch.cat(outputs['anchors'], dim=0)
         
         # --------------- label assignment ---------------
+        gt_label_targets = []
         gt_score_targets = []
         gt_bbox_targets = []
         fg_masks = []
-        for bid in range(bs):
-            tgt_labels = targets[bid]["labels"].to(device)     # [Mp,]
-            tgt_boxs = targets[bid]["boxes"].to(device)        # [Mp, 4]
+        for batch_idx in range(bs):
+            tgt_labels = targets[batch_idx]["labels"].to(device)     # [Mp,]
+            tgt_boxs = targets[batch_idx]["boxes"].to(device)        # [Mp, 4]
 
             # check target
             if len(tgt_labels) == 0 or tgt_boxs.max().item() == 0.:
@@ -80,36 +91,42 @@ class SetCriterion(object):
                 tgt_labels = tgt_labels[None, :, None]      # [1, Mp, 1]
                 tgt_boxs = tgt_boxs[None]                   # [1, Mp, 4]
                 (
-                    _,          # [1, M]
+                    gt_label,   # [1, M]
                     gt_box,     # [1, M, 4]
                     gt_score,   # [1, M, C]
                     fg_mask,    # [1, M,]
                     _
                 ) = self.matcher(
-                    pd_scores = cls_preds[bid:bid+1].detach().sigmoid(), 
-                    pd_bboxes = box_preds[bid:bid+1].detach(),
+                    pd_scores = cls_preds[batch_idx:batch_idx+1].detach().sigmoid(), 
+                    pd_bboxes = box_preds[batch_idx:batch_idx+1].detach(),
                     anc_points = anchors,
                     gt_labels = tgt_labels,
                     gt_bboxes = tgt_boxs
                     )
+            gt_label_targets.append(gt_label)
             gt_score_targets.append(gt_score)
             gt_bbox_targets.append(gt_box)
             fg_masks.append(fg_mask)
 
         # List[B, 1, M, C] -> Tensor[B, M, C] -> Tensor[BM, C]
         fg_masks = torch.cat(fg_masks, 0).view(-1)                                    # [BM,]
+        gt_label_targets = torch.cat(gt_label_targets, 0).view(-1)                    # [BM,]
         gt_score_targets = torch.cat(gt_score_targets, 0).view(-1, self.num_classes)  # [BM, C]
         gt_bbox_targets = torch.cat(gt_bbox_targets, 0).view(-1, 4)                   # [BM, 4]
         num_fgs = gt_score_targets.sum()
-        
+
         # Average loss normalizer across all the GPUs
         if is_dist_avail_and_initialized():
             torch.distributed.all_reduce(num_fgs)
         num_fgs = (num_fgs / get_world_size()).clamp(1.0)
 
         # ------------------ Classification loss ------------------
+        target_labels = torch.where(fg_masks > 0, gt_label_targets,
+                                    torch.full_like(gt_label_targets, self.num_classes))
+        target_scores = gt_score_targets.new_zeros(gt_score_targets.shape[0])
+        target_scores[fg_masks] = gt_score_targets[fg_masks, target_labels[fg_masks]]
         cls_preds = cls_preds.view(-1, self.num_classes)
-        loss_cls = self.loss_classes(cls_preds, gt_score_targets)
+        loss_cls = self.loss_classes(cls_preds, target_labels, target_scores)
         loss_cls = loss_cls.sum() / num_fgs
 
         # ------------------ Regression loss ------------------

+ 51 - 73
yolo/models/yolov6/matcher.py

@@ -3,14 +3,15 @@ import torch.nn as nn
 from utils.box_ops import bbox_iou
 
 
-# -------------------------- Task Aligned Assigner --------------------------
+# ------------------ Task Aligned Assigner ------------------
 class TaskAlignedAssigner(nn.Module):
     def __init__(self,
                  num_classes     = 80,
                  topk_candidates = 10,
                  alpha           = 0.5,
                  beta            = 6.0, 
-                 eps             = 1e-9):
+                 eps             = 1e-9,
+                 ):
         super(TaskAlignedAssigner, self).__init__()
         self.topk_candidates = topk_candidates
         self.num_classes = num_classes
@@ -32,7 +33,7 @@ class TaskAlignedAssigner(nn.Module):
         mask_pos, align_metric, overlaps = self.get_pos_mask(
             pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points)
 
-        target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(
+        target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(
             mask_pos, overlaps, self.n_max_boxes)
 
         # Assigned target
@@ -48,9 +49,55 @@ class TaskAlignedAssigner(nn.Module):
 
         return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
 
+    def select_highest_overlaps(self, mask_pos, overlaps, n_max_boxes):
+        """if an anchor box is assigned to multiple gts,
+            the one with the highest iou will be selected.
+        Args:
+            mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+            overlaps (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+        Return:
+            target_gt_idx (Tensor): shape(bs, num_total_anchors)
+            fg_mask (Tensor): shape(bs, num_total_anchors)
+            mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+        """
+        fg_mask = mask_pos.sum(-2)
+        if fg_mask.max() > 1:  # one anchor is assigned to multiple gt_bboxes
+            mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)
+            max_overlaps_idx = overlaps.argmax(1)  # (b, h*w)
+
+            is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
+            is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
+
+            mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()  # (b, n_max_boxes, h*w)
+            fg_mask = mask_pos.sum(-2)
+        # For each anchor, find the index of the gt it is assigned to
+        target_gt_idx = mask_pos.argmax(-2)  # (b, h*w)
+
+        return target_gt_idx, fg_mask, mask_pos
+
+    def select_candidates_in_gts(self, xy_centers, gt_bboxes, eps=1e-9):
+        """select the positive anchors's center in gt
+        Args:
+            xy_centers (Tensor): shape(num_total_anchors, 2)
+            gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
+        Return:
+            (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+        """
+        n_anchors = xy_centers.size(0)
+        bs, n_max_boxes, _ = gt_bboxes.size()
+        _gt_bboxes = gt_bboxes.reshape([-1, 4])
+        xy_centers = xy_centers.unsqueeze(0).repeat(bs * n_max_boxes, 1, 1)
+        gt_bboxes_lt = _gt_bboxes[:, 0:2].unsqueeze(1).repeat(1, n_anchors, 1)
+        gt_bboxes_rb = _gt_bboxes[:, 2:4].unsqueeze(1).repeat(1, n_anchors, 1)
+        b_lt = xy_centers - gt_bboxes_lt
+        b_rb = gt_bboxes_rb - xy_centers
+        bbox_deltas = torch.cat([b_lt, b_rb], dim=-1)
+        bbox_deltas = bbox_deltas.reshape([bs, n_max_boxes, n_anchors, -1])
+        return (bbox_deltas.min(axis=-1)[0] > eps).to(gt_bboxes.dtype)
+
     def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points):
         # get in_gts mask, (b, max_num_obj, h*w)
-        mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)
+        mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
         # get anchor_align metric, (b, max_num_obj, h*w)
         align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts)
         # get topk_metric mask, (b, max_num_obj, h*w)
@@ -127,72 +174,3 @@ class TaskAlignedAssigner(nn.Module):
         target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)
 
         return target_labels, target_bboxes, target_scores
-    
-
-# -------------------------- Basic Functions --------------------------
-def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
-    """select the positive anchors's center in gt
-    Args:
-        xy_centers (Tensor): shape(bs*n_max_boxes, num_total_anchors, 4)
-        gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
-    Return:
-        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
-    """
-    n_anchors = xy_centers.size(0)
-    bs, n_max_boxes, _ = gt_bboxes.size()
-    _gt_bboxes = gt_bboxes.reshape([-1, 4])
-    xy_centers = xy_centers.unsqueeze(0).repeat(bs * n_max_boxes, 1, 1)
-    gt_bboxes_lt = _gt_bboxes[:, 0:2].unsqueeze(1).repeat(1, n_anchors, 1)
-    gt_bboxes_rb = _gt_bboxes[:, 2:4].unsqueeze(1).repeat(1, n_anchors, 1)
-    b_lt = xy_centers - gt_bboxes_lt
-    b_rb = gt_bboxes_rb - xy_centers
-    bbox_deltas = torch.cat([b_lt, b_rb], dim=-1)
-    bbox_deltas = bbox_deltas.reshape([bs, n_max_boxes, n_anchors, -1])
-    return (bbox_deltas.min(axis=-1)[0] > eps).to(gt_bboxes.dtype)
-
-def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
-    """if an anchor box is assigned to multiple gts,
-        the one with the highest iou will be selected.
-    Args:
-        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
-        overlaps (Tensor): shape(bs, n_max_boxes, num_total_anchors)
-    Return:
-        target_gt_idx (Tensor): shape(bs, num_total_anchors)
-        fg_mask (Tensor): shape(bs, num_total_anchors)
-        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
-    """
-    fg_mask = mask_pos.sum(-2)
-    if fg_mask.max() > 1:  # one anchor is assigned to multiple gt_bboxes
-        mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)
-        max_overlaps_idx = overlaps.argmax(1)  # (b, h*w)
-
-        is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
-        is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
-
-        mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()  # (b, n_max_boxes, h*w)
-        fg_mask = mask_pos.sum(-2)
-    # Find each grid serve which gt(index)
-    target_gt_idx = mask_pos.argmax(-2)  # (b, h*w)
-
-    return target_gt_idx, fg_mask, mask_pos
-
-def iou_calculator(box1, box2, eps=1e-9):
-    """Calculate iou for batch
-    Args:
-        box1 (Tensor): shape(bs, n_max_boxes, 1, 4)
-        box2 (Tensor): shape(bs, 1, num_total_anchors, 4)
-    Return:
-        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
-    """
-    box1 = box1.unsqueeze(2)  # [N, M1, 4] -> [N, M1, 1, 4]
-    box2 = box2.unsqueeze(1)  # [N, M2, 4] -> [N, 1, M2, 4]
-    px1y1, px2y2 = box1[:, :, :, 0:2], box1[:, :, :, 2:4]
-    gx1y1, gx2y2 = box2[:, :, :, 0:2], box2[:, :, :, 2:4]
-    x1y1 = torch.maximum(px1y1, gx1y1)
-    x2y2 = torch.minimum(px2y2, gx2y2)
-    overlap = (x2y2 - x1y1).clip(0).prod(-1)
-    area1 = (px2y2 - px1y1).clip(0).prod(-1)
-    area2 = (gx2y2 - gx1y1).clip(0).prod(-1)
-    union = area1 + area2 - overlap + eps
-
-    return overlap / union