yjh0410 пре 2 година
родитељ
комит
39f76994f6

+ 11 - 13
config/model_config/yolox2_config.py

@@ -41,30 +41,28 @@ yolox2_cfg = {
         'multi_scale': [0.5, 1.5],   # 320 -> 960
         'trans_type': 'yolox_nano',
         # ---------------- Assignment config ----------------
-        ## matcher
-        'matcher': {'topk': 10,
-                    'alpha': 0.5,
-                    'beta': 6.0},
+        'matcher': {'soft_center_radius': 3.0,
+                    'topk_candicate': 13,
+                    'iou_weight': 3.0},
         # ---------------- Loss config ----------------
         ## loss weight
-        'loss_obj_weight': 1.0,
         'loss_cls_weight': 1.0,
         'loss_box_weight': 5.0,
         # ---------------- Train config ----------------
         ## close strong augmentation
         'no_aug_epoch': 20,
-        'trainer_type': 'rtmdet',
+        'trainer_type': 'yolo',
         ## optimizer
-        'optimizer': 'adamw',      # optional: sgd, AdamW
-        'momentum': None,          # SGD: 0.9;      AdamW: None
-        'weight_decay': 5e-2,      # SGD: 5e-4;     AdamW: 5e-2
-        'clip_grad': 15,           # SGD: 10.0;     AdamW: -1
+        'optimizer': 'sgd',        # optional: sgd, AdamW
+        'momentum': 0.9,           # SGD: 0.9;      AdamW: None
+        'weight_decay': 5e-4,      # SGD: 5e-4;     AdamW: 5e-2
+        'clip_grad': 10.0,         # SGD: 10.0;     AdamW: -1
         ## model EMA
-        'ema_decay': 0.9998,       # SGD: 0.9999;   AdamW: 0.9998
+        'ema_decay': 0.9999,       # SGD: 0.9999;   AdamW: 0.9998
         'ema_tau': 2000,
         ## lr schedule
-        'scheduler': 'linear',
-        'lr0': 0.001,               # SGD: 0.01;     AdamW: 0.001
+        'scheduler': 'cos_linear',
+        'lr0': 0.01,               # SGD: 0.01;     AdamW: 0.001
         'lrf': 0.01,               # SGD: 0.01;     AdamW: 0.01
         'warmup_momentum': 0.8,
         'warmup_bias_lr': 0.1,

+ 0 - 5
models/detectors/yolox2/build.py

@@ -33,11 +33,6 @@ def build_yolox2(args, cfg, device, num_classes=80, trainable=False, deploy=Fals
     # Init head
     init_prob = 0.01
     bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
-    ## obj pred
-    for obj_pred in model.obj_preds:
-        b = obj_pred.bias.view(1, -1)
-        b.data.fill_(bias_value.item())
-        obj_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
     ## cls pred
     for cls_pred in model.cls_preds:
         b = cls_pred.bias.view(1, -1)

+ 68 - 87
models/detectors/yolox2/loss.py

@@ -1,6 +1,6 @@
 import torch
 import torch.nn.functional as F
-from .matcher import TaskAlignedAssigner
+from .matcher import AlignedSimOTA
 from utils.box_ops import get_ious
 from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
 
@@ -15,37 +15,48 @@ class Criterion(object):
         self.device = device
         self.num_classes = num_classes
         # loss weight
-        self.loss_obj_weight = cfg['loss_obj_weight']
         self.loss_cls_weight = cfg['loss_cls_weight']
         self.loss_box_weight = cfg['loss_box_weight']
         # matcher
         matcher_config = cfg['matcher']
-        self.matcher = TaskAlignedAssigner(
-            topk=matcher_config['topk'],
+        self.matcher = AlignedSimOTA(
             num_classes=num_classes,
-            alpha=matcher_config['alpha'],
-            beta=matcher_config['beta']
+            soft_center_radius=matcher_config['soft_center_radius'],
+            topk=matcher_config['topk_candicate'],
+            iou_weight=matcher_config['iou_weight']
             )
+     
+     
+    def loss_classes(self, pred_cls, target, beta=2.0):
+        """
+            Quality Focal Loss
+            pred_cls: (torch.Tensor): [N, C]。
+            target:   (tuple([torch.Tensor], [torch.Tensor])): label -> (N,), score -> (N,)
+        """
+        label, score = target
+        pred_sigmoid = pred_cls.sigmoid()
+        scale_factor = pred_sigmoid
+        zerolabel = scale_factor.new_zeros(pred_cls.shape)
 
+        ce_loss = F.binary_cross_entropy_with_logits(
+            pred_cls, zerolabel, reduction='none') * scale_factor.pow(beta)
+        
+        bg_class_ind = pred_cls.shape[-1]
+        pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1)
+        pos_label = label[pos].long()
 
-    def loss_objectness(self, pred_obj, gt_obj):
-        loss_obj = F.binary_cross_entropy_with_logits(pred_obj, gt_obj, reduction='none')
-
-        return loss_obj
-    
+        scale_factor = score[pos] - pred_sigmoid[pos, pos_label]
 
-    def loss_classes(self, pred_cls, gt_label):
-        loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_label, reduction='none')
+        ce_loss[pos, pos_label] = F.binary_cross_entropy_with_logits(
+            pred_cls[pos, pos_label], score[pos],
+            reduction='none') * scale_factor.abs().pow(beta)
 
-        return loss_cls
+        return ce_loss
 
 
     def loss_bboxes(self, pred_box, gt_box):
         # regression loss
-        ious = get_ious(pred_box,
-                        gt_box,
-                        box_mode="xyxy",
-                        iou_type='giou')
+        ious = get_ious(pred_box, gt_box, "xyxy", 'giou')
         loss_box = 1.0 - ious
 
         return loss_box
@@ -54,100 +65,70 @@ class Criterion(object):
     def __call__(self, outputs, targets):        
         """
             outputs['pred_cls']: List(Tensor) [B, M, C]
-            outputs['pred_regs']: List(Tensor) [B, M, 4*(reg_max+1)]
-            outputs['pred_boxs']: List(Tensor) [B, M, 4]
-            outputs['anchors']: List(Tensor) [M, 2]
+            outputs['pred_box']: List(Tensor) [B, M, 4]
             outputs['strides']: List(Int) [8, 16, 32] output stride
-            outputs['stride_tensor']: List(Tensor) [M, 1]
             targets: (List) [dict{'boxes': [...], 
                                  'labels': [...], 
                                  'orig_size': ...}, ...]
         """
         bs = outputs['pred_cls'][0].shape[0]
         device = outputs['pred_cls'][0].device
-        anchors = torch.cat(outputs['anchors'], dim=0)
-        num_anchors = anchors.shape[0]
+        fpn_strides = outputs['strides']
+        anchors = outputs['anchors']
 
         # preds: [B, M, C]
-        obj_preds = torch.cat(outputs['pred_obj'], dim=1)
         cls_preds = torch.cat(outputs['pred_cls'], dim=1)
         box_preds = torch.cat(outputs['pred_box'], dim=1)
-        
-        # label assignment
-        gt_label_targets = []
-        gt_score_targets = []
-        gt_bbox_targets = []
-        fg_masks = []
 
+        cls_targets = []
+        box_targets = []
+        assign_metrics = []
         for batch_idx in range(bs):
-            tgt_labels = targets[batch_idx]["labels"].to(device)     # [Mp,]
-            tgt_boxs = targets[batch_idx]["boxes"].to(device)        # [Mp, 4]
-
-            # check target
-            if len(tgt_labels) == 0 or tgt_boxs.max().item() == 0.:
-                # There is no valid gt
-                fg_mask = cls_preds.new_zeros(1, num_anchors).bool()               #[1, M,]
-                gt_label = cls_preds.new_zeros((1, num_anchors,))                  #[1, M,]
-                gt_score = cls_preds.new_zeros((1, num_anchors, self.num_classes)) #[1, M, C]
-                gt_box = cls_preds.new_zeros((1, num_anchors, 4))                  #[1, M, 4]
-            else:
-                tgt_labels = tgt_labels[None, :, None]      # [1, Mp, 1]
-                tgt_boxs = tgt_boxs[None]                   # [1, Mp, 4]
-                (
-                    gt_label,   #[1, M]
-                    gt_box,     #[1, M, 4]
-                    gt_score,   #[1, M, C]
-                    fg_mask,    #[1, M,]
-                    _
-                ) = self.matcher(
-                    pd_scores = torch.sqrt(obj_preds[batch_idx:batch_idx+1].sigmoid() * \
-                                           cls_preds[batch_idx:batch_idx+1].sigmoid()).detach(), 
-                    pd_bboxes = box_preds[batch_idx:batch_idx+1].detach(),
-                    anc_points = anchors,
-                    gt_labels = tgt_labels,
-                    gt_bboxes = tgt_boxs
-                    )
-            gt_label_targets.append(gt_label)
-            gt_score_targets.append(gt_score)
-            gt_bbox_targets.append(gt_box)
-            fg_masks.append(fg_mask)
-
-        # List[B, 1, M, C] -> Tensor[B, M, C] -> Tensor[BM, C]
-        fg_masks = torch.cat(fg_masks, 0).view(-1)                                    # [BM,]
-        gt_label_targets = torch.cat(gt_label_targets, 0).view(-1)                    # [BM,]
-        gt_score_targets = torch.cat(gt_score_targets, 0).view(-1, self.num_classes)  # [BM, C]
-        gt_bbox_targets = torch.cat(gt_bbox_targets, 0).view(-1, 4)                   # [BM, 4]
-
-        obj_targets = fg_masks.unsqueeze(-1)        # [M, 1]
-        cls_targets = gt_score_targets[fg_masks]    # [Mp, C]
-        box_targets = gt_bbox_targets[fg_masks]     # [Mp, 4]
-        num_fgs = fg_masks.sum()
-        
+            tgt_labels = targets[batch_idx]["labels"].to(device)  # [N,]
+            tgt_bboxes = targets[batch_idx]["boxes"].to(device)   # [N, 4]
+            # label assignment
+            assigned_result = self.matcher(fpn_strides=fpn_strides,
+                                           anchors=anchors,
+                                           pred_cls=cls_preds[batch_idx].detach(),
+                                           pred_box=box_preds[batch_idx].detach(),
+                                           gt_labels=tgt_labels,
+                                           gt_bboxes=tgt_bboxes
+                                           )
+            cls_targets.append(assigned_result['assigned_labels'])
+            box_targets.append(assigned_result['assigned_bboxes'])
+            assign_metrics.append(assigned_result['assign_metrics'])
+
+        cls_targets = torch.cat(cls_targets, dim=0)
+        box_targets = torch.cat(box_targets, dim=0)
+        assign_metrics = torch.cat(assign_metrics, dim=0)
+
+        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+        bg_class_ind = self.num_classes
+        pos_inds = ((cls_targets >= 0)
+                    & (cls_targets < bg_class_ind)).nonzero().squeeze(1)
+        # num_fgs = assign_metrics.sum()
+        num_fgs = pos_inds.size(0)
+
         if is_dist_avail_and_initialized():
             torch.distributed.all_reduce(num_fgs)
-        num_fgs = (num_fgs / get_world_size()).clamp(1.0)
-
-        # obj loss
-        loss_obj = self.loss_objectness(obj_preds.view(-1, 1), obj_targets.float())
-        loss_obj = loss_obj.sum() / num_fgs
+        num_fgs = max(num_fgs / get_world_size(), 1.0)
         
         # cls loss
-        cls_preds_pos = cls_preds.view(-1, self.num_classes)[fg_masks]
-        loss_cls = self.loss_classes(cls_preds_pos, cls_targets)
+        cls_preds = cls_preds.view(-1, self.num_classes)
+        loss_cls = self.loss_classes(cls_preds, (cls_targets, assign_metrics))
         loss_cls = loss_cls.sum() / num_fgs
 
         # regression loss
-        box_preds_pos = box_preds.view(-1, 4)[fg_masks]
-        loss_box = self.loss_bboxes(box_preds_pos, box_targets)
-        loss_box = loss_box.sum() / num_fgs
+        box_preds_pos = box_preds.view(-1, 4)[pos_inds]
+        box_targets_pos = box_targets[pos_inds]
+        loss_box = self.loss_bboxes(box_preds_pos, box_targets_pos)
+        loss_box = loss_box.sum() / box_preds_pos.shape[0]
 
         # total loss
-        losses = self.loss_obj_weight * loss_obj + \
-                 self.loss_cls_weight * loss_cls + \
+        losses = self.loss_cls_weight * loss_cls + \
                  self.loss_box_weight * loss_box
 
         loss_dict = dict(
-                loss_obj = loss_obj,
                 loss_cls = loss_cls,
                 loss_box = loss_box,
                 losses = losses

+ 162 - 189
models/detectors/yolox2/matcher.py

@@ -1,203 +1,176 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from utils.box_ops import bbox_iou
-
-
-# -------------------------- Task Aligned Assigner --------------------------
-class TaskAlignedAssigner(nn.Module):
-    def __init__(self,
-                 topk=10,
-                 num_classes=80,
-                 alpha=0.5,
-                 beta=6.0, 
-                 eps=1e-9):
-        super(TaskAlignedAssigner, self).__init__()
-        self.topk = topk
-        self.num_classes = num_classes
-        self.bg_idx = num_classes
-        self.alpha = alpha
-        self.beta = beta
-        self.eps = eps
-
-    @torch.no_grad()
-    def forward(self,
-                pd_scores,
-                pd_bboxes,
-                anc_points,
-                gt_labels,
-                gt_bboxes):
-        """This code referenced to
-           https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py
-        Args:
-            pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
-            pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
-            anc_points (Tensor): shape(num_total_anchors, 2)
-            gt_labels (Tensor): shape(bs, n_max_boxes, 1)
-            gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
-        Returns:
-            target_labels (Tensor): shape(bs, num_total_anchors)
-            target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
-            target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
-            fg_mask (Tensor): shape(bs, num_total_anchors)
-        """
-        self.bs = pd_scores.size(0)
-        self.n_max_boxes = gt_bboxes.size(1)
-
-        mask_pos, align_metric, overlaps = self.get_pos_mask(
-            pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points)
-
-        target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(
-            mask_pos, overlaps, self.n_max_boxes)
-
-        # assigned target
-        target_labels, target_bboxes, target_scores = self.get_targets(
-            gt_labels, gt_bboxes, target_gt_idx, fg_mask)
-
-        # normalize
-        align_metric *= mask_pos
-        pos_align_metrics = align_metric.amax(axis=-1, keepdim=True)  # b, max_num_obj
-        pos_overlaps = (overlaps * mask_pos).amax(axis=-1, keepdim=True)  # b, max_num_obj
-        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
-        target_scores = target_scores * norm_align_metric
-
-        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
-
-
-    def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points):
-        # get anchor_align metric, (b, max_num_obj, h*w)
-        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes)
-        # get in_gts mask, (b, max_num_obj, h*w)
-        mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)
-        # get topk_metric mask, (b, max_num_obj, h*w)
-        mask_topk = self.select_topk_candidates(align_metric * mask_in_gts)
-        # merge all mask to a final mask, (b, max_num_obj, h*w)
-        mask_pos = mask_topk * mask_in_gts
-
-        return mask_pos, align_metric, overlaps
-
-
-    def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes):
-        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
-        ind[0] = torch.arange(end=self.bs).view(-1, 1).repeat(1, self.n_max_boxes)  # b, max_num_obj
-        ind[1] = gt_labels.long().squeeze(-1)  # b, max_num_obj
-        # get the scores of each grid for each gt cls
-        bbox_scores = pd_scores[ind[0], :, ind[1]]  # b, max_num_obj, h*w
+# ---------------------------------------------------------------------
+# Copyright (c) OpenMMLab. All rights reserved.
+# ---------------------------------------------------------------------
 
-        overlaps = bbox_iou(gt_bboxes.unsqueeze(2), pd_bboxes.unsqueeze(1), xywh=False).squeeze(3).clamp(0)
-        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
-
-        return align_metric, overlaps
 
+import torch
+import torch.nn.functional as F
+from utils.box_ops import *
 
-    def select_topk_candidates(self, metrics, largest=True):
-        """
-        Args:
-            metrics: (b, max_num_obj, h*w).
-            topk_mask: (b, max_num_obj, topk) or None
-        """
 
-        num_anchors = metrics.shape[-1]  # h*w
-        # (b, max_num_obj, topk)
-        topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
-        topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).tile([1, 1, self.topk])
-        # (b, max_num_obj, topk)
-        topk_idxs[~topk_mask] = 0
-        # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
-        is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(-2)
-        # filter invalid bboxes
-        is_in_topk = torch.where(is_in_topk > 1, 0, is_in_topk)
-        return is_in_topk.to(metrics.dtype)
+# RTMDet SimOTA
+class AlignedSimOTA(object):
+    """
+        This code referenced to https://github.com/open-mmlab/mmyolo/models/task_modules/assigners/batch_dsl_assigner.py
+    """
+    def __init__(self, num_classes, soft_center_radius=3.0, topk=13, iou_weight=3.0):
+        self.num_classes = num_classes
+        self.soft_center_radius = soft_center_radius
+        self.topk = topk
+        self.iou_weight = iou_weight
 
 
-    def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
+    @torch.no_grad()
+    def __call__(self, 
+                 fpn_strides, 
+                 anchors, 
+                 pred_cls, 
+                 pred_box, 
+                 gt_labels,
+                 gt_bboxes):
+        # [M,]
+        strides = torch.cat([torch.ones_like(anchor_i[:, 0]) * stride_i
+                                for stride_i, anchor_i in zip(fpn_strides, anchors)], dim=-1)
+        # List[F, M, 2] -> [M, 2]
+        anchors = torch.cat(anchors, dim=0)
+        num_gt = len(gt_labels)
+
+        # check gt
+        if num_gt == 0 or gt_bboxes.max().item() == 0.:
+            return {
+                'assigned_labels': gt_labels.new_full(pred_cls[..., 0].shape,
+                                                      self.num_classes,
+                                                      dtype=torch.long),
+                'assigned_bboxes': gt_bboxes.new_full(pred_box.shape, 0),
+                'assign_metrics': gt_bboxes.new_full(pred_cls[..., 0].shape, 0)
+            }
+        
+        # get inside points: [N, M]
+        is_in_gt = self.find_inside_points(gt_bboxes, anchors)
+        valid_mask = is_in_gt.sum(dim=0) > 0  # [M,]
+
+        # ----------------------------------- soft center prior -----------------------------------
+        gt_center = (gt_bboxes[..., :2] + gt_bboxes[..., 2:]) / 2.0
+        distance = (anchors.unsqueeze(0) - gt_center.unsqueeze(1)
+                    ).pow(2).sum(-1).sqrt() / strides.unsqueeze(0)  # [N, M]
+        distance = distance * valid_mask.unsqueeze(0)
+        soft_center_prior = torch.pow(10, distance - self.soft_center_radius)
+
+        # ----------------------------------- regression cost -----------------------------------
+        pair_wise_ious, _ = box_iou(gt_bboxes, pred_box)  # [N, M]
+        pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8) * self.iou_weight
+
+        # ----------------------------------- classification cost -----------------------------------
+        ## select the predicted scores corresponded to the gt_labels
+        pairwise_pred_scores = pred_cls.permute(1, 0)  # [M, C] -> [C, M]
+        pairwise_pred_scores = pairwise_pred_scores[gt_labels.long(), :].float()   # [N, M]
+        ## scale factor
+        scale_factor = (pair_wise_ious - pairwise_pred_scores.sigmoid()).abs().pow(2.0)
+        ## cls cost
+        pair_wise_cls_loss = F.binary_cross_entropy_with_logits(
+            pairwise_pred_scores, pair_wise_ious,
+            reduction="none") * scale_factor # [N, M]
+            
+        del pairwise_pred_scores
+
+        ## foreground cost matrix
+        cost_matrix = pair_wise_cls_loss + pair_wise_ious_loss + soft_center_prior
+        max_pad_value = torch.ones_like(cost_matrix) * 1e9
+        cost_matrix = torch.where(valid_mask[None].repeat(num_gt, 1),   # [N, M]
+                                  cost_matrix, max_pad_value)
+
+        # ----------------------------------- dynamic label assignment -----------------------------------
+        (
+            matched_pred_ious,
+            matched_gt_inds,
+            fg_mask_inboxes
+        ) = self.dynamic_k_matching(
+            cost_matrix,
+            pair_wise_ious,
+            num_gt
+            )
+        del pair_wise_cls_loss, cost_matrix, pair_wise_ious, pair_wise_ious_loss
+
+        # -----------------------------------process assigned labels -----------------------------------
+        assigned_labels = gt_labels.new_full(pred_cls[..., 0].shape,
+                                             self.num_classes)  # [M,]
+        assigned_labels[fg_mask_inboxes] = gt_labels[matched_gt_inds].squeeze(-1)
+        assigned_labels = assigned_labels.long()  # [M,]
+
+        assigned_bboxes = gt_bboxes.new_full(pred_box.shape, 0)        # [M, 4]
+        assigned_bboxes[fg_mask_inboxes] = gt_bboxes[matched_gt_inds]  # [M, 4]
+
+        assign_metrics = gt_bboxes.new_full(pred_cls[..., 0].shape, 0) # [M, 4]
+        assign_metrics[fg_mask_inboxes] = matched_pred_ious            # [M, 4]
+
+        assigned_dict = dict(
+            assigned_labels=assigned_labels,
+            assigned_bboxes=assigned_bboxes,
+            assign_metrics=assign_metrics
+            )
+        
+        return assigned_dict
+
+
+    def find_inside_points(self, gt_bboxes, anchors):
         """
-        Args:
-            gt_labels: (b, max_num_obj, 1)
-            gt_bboxes: (b, max_num_obj, 4)
-            target_gt_idx: (b, h*w)
-            fg_mask: (b, h*w)
+            gt_bboxes: Tensor -> [N, 2]
+            anchors:   Tensor -> [M, 2]
         """
+        num_anchors = anchors.shape[0]
+        num_gt = gt_bboxes.shape[0]
 
-        # assigned target labels, (b, 1)
-        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
-        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (b, h*w)
-        target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)
+        anchors_expand = anchors.unsqueeze(0).repeat(num_gt, 1, 1)           # [N, M, 2]
+        gt_bboxes_expand = gt_bboxes.unsqueeze(1).repeat(1, num_anchors, 1)  # [N, M, 4]
 
-        # assigned target boxes, (b, max_num_obj, 4) -> (b, h*w)
-        target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx]
+        # offset
+        lt = anchors_expand - gt_bboxes_expand[..., :2]
+        rb = gt_bboxes_expand[..., 2:] - anchors_expand
+        bbox_deltas = torch.cat([lt, rb], dim=-1)
 
-        # assigned target scores
-        target_labels.clamp(0)
-        target_scores = F.one_hot(target_labels, self.num_classes)  # (b, h*w, 80)
-        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)  # (b, h*w, 80)
-        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)
+        is_in_gts = bbox_deltas.min(dim=-1).values > 0
 
-        return target_labels, target_bboxes, target_scores
+        return is_in_gts
     
 
-# -------------------------- Basic Functions --------------------------
-def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
-    """select the positive anchors's center in gt
-    Args:
-        xy_centers (Tensor): shape(bs*n_max_boxes, num_total_anchors, 4)
-        gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
-    Return:
-        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
-    """
-    n_anchors = xy_centers.size(0)
-    bs, n_max_boxes, _ = gt_bboxes.size()
-    _gt_bboxes = gt_bboxes.reshape([-1, 4])
-    xy_centers = xy_centers.unsqueeze(0).repeat(bs * n_max_boxes, 1, 1)
-    gt_bboxes_lt = _gt_bboxes[:, 0:2].unsqueeze(1).repeat(1, n_anchors, 1)
-    gt_bboxes_rb = _gt_bboxes[:, 2:4].unsqueeze(1).repeat(1, n_anchors, 1)
-    b_lt = xy_centers - gt_bboxes_lt
-    b_rb = gt_bboxes_rb - xy_centers
-    bbox_deltas = torch.cat([b_lt, b_rb], dim=-1)
-    bbox_deltas = bbox_deltas.reshape([bs, n_max_boxes, n_anchors, -1])
-    return (bbox_deltas.min(axis=-1)[0] > eps).to(gt_bboxes.dtype)
-
-
-def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
-    """if an anchor box is assigned to multiple gts,
-        the one with the highest iou will be selected.
-    Args:
-        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
-        overlaps (Tensor): shape(bs, n_max_boxes, num_total_anchors)
-    Return:
-        target_gt_idx (Tensor): shape(bs, num_total_anchors)
-        fg_mask (Tensor): shape(bs, num_total_anchors)
-        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
-    """
-    fg_mask = mask_pos.sum(axis=-2)
-    if fg_mask.max() > 1:
-        mask_multi_gts = (fg_mask.unsqueeze(1) > 1).repeat([1, n_max_boxes, 1])
-        max_overlaps_idx = overlaps.argmax(axis=1)
-        is_max_overlaps = F.one_hot(max_overlaps_idx, n_max_boxes)
-        is_max_overlaps = is_max_overlaps.permute(0, 2, 1).to(overlaps.dtype)
-        mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos)
-        fg_mask = mask_pos.sum(axis=-2)
-    target_gt_idx = mask_pos.argmax(axis=-2)
-    return target_gt_idx, fg_mask , mask_pos
-
-
-def iou_calculator(box1, box2, eps=1e-9):
-    """Calculate iou for batch
-    Args:
-        box1 (Tensor): shape(bs, n_max_boxes, 1, 4)
-        box2 (Tensor): shape(bs, 1, num_total_anchors, 4)
-    Return:
-        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
-    """
-    box1 = box1.unsqueeze(2)  # [N, M1, 4] -> [N, M1, 1, 4]
-    box2 = box2.unsqueeze(1)  # [N, M2, 4] -> [N, 1, M2, 4]
-    px1y1, px2y2 = box1[:, :, :, 0:2], box1[:, :, :, 2:4]
-    gx1y1, gx2y2 = box2[:, :, :, 0:2], box2[:, :, :, 2:4]
-    x1y1 = torch.maximum(px1y1, gx1y1)
-    x2y2 = torch.minimum(px2y2, gx2y2)
-    overlap = (x2y2 - x1y1).clip(0).prod(-1)
-    area1 = (px2y2 - px1y1).clip(0).prod(-1)
-    area2 = (gx2y2 - gx1y1).clip(0).prod(-1)
-    union = area1 + area2 - overlap + eps
-
-    return overlap / union
+    def dynamic_k_matching(self, cost_matrix, pairwise_ious, num_gt):
+        """Use IoU and matching cost to calculate the dynamic top-k positive
+        targets.
+
+        Args:
+            cost_matrix (Tensor): Cost matrix.
+            pairwise_ious (Tensor): Pairwise iou matrix.
+            num_gt (int): Number of gt.
+            valid_mask (Tensor): Mask for valid bboxes.
+        Returns:
+            tuple: matched ious and gt indexes.
+        """
+        matching_matrix = torch.zeros_like(cost_matrix, dtype=torch.uint8)
+        # select candidate topk ious for dynamic-k calculation
+        candidate_topk = min(self.topk, pairwise_ious.size(1))
+        topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=1)
+        # calculate dynamic k for each gt
+        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
+
+        # sorting the batch cost matirx is faster than topk
+        _, sorted_indices = torch.sort(cost_matrix, dim=1)
+        for gt_idx in range(num_gt):
+            topk_ids = sorted_indices[gt_idx, :dynamic_ks[gt_idx]]
+            matching_matrix[gt_idx, :][topk_ids] = 1
+
+        del topk_ious, dynamic_ks, topk_ids
+
+        prior_match_gt_mask = matching_matrix.sum(0) > 1
+        if prior_match_gt_mask.sum() > 0:
+            cost_min, cost_argmin = torch.min(
+                cost_matrix[:, prior_match_gt_mask], dim=0)
+            matching_matrix[:, prior_match_gt_mask] *= 0
+            matching_matrix[cost_argmin, prior_match_gt_mask] = 1
+
+        # get foreground mask inside box and center prior
+        fg_mask_inboxes = matching_matrix.sum(0) > 0
+        matched_pred_ious = (matching_matrix *
+                             pairwise_ious).sum(0)[fg_mask_inboxes]
+        matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
+
+        return matched_pred_ious, matched_gt_inds, fg_mask_inboxes

+ 10 - 27
models/detectors/yolox2/yolox2.py

@@ -49,13 +49,9 @@ class YOLOX2(nn.Module):
         self.fpn_dims = self.fpn.out_dim
 
         ## ----------- Heads -----------
-        self.group_heads = build_head(cfg, self.fpn_dims, self.head_dim, num_classes) 
+        self.heads = build_head(cfg, self.fpn_dims, self.head_dim, num_classes) 
 
         ## ----------- Preds -----------
-        self.obj_preds = nn.ModuleList(
-                            [nn.Conv2d(self.head_dim, 1, kernel_size=1) 
-                                for _ in range(len(self.stride))
-                              ]) 
         self.cls_preds = nn.ModuleList(
                             [nn.Conv2d(self.head_dim, num_classes, kernel_size=1) 
                                 for _ in range(len(self.stride))
@@ -84,21 +80,19 @@ class YOLOX2(nn.Module):
         return anchors
         
     ## post-process
-    def post_process(self, obj_preds, cls_preds, box_preds):
+    def post_process(self, cls_preds, box_preds):
         """
         Input:
-            obj_preds: List(Tensor) [[H x W, 1], ...]
             cls_preds: List(Tensor) [[H x W, C], ...]
             box_preds: List(Tensor) [[H x W, 4], ...]
-            anchors:   List(Tensor) [[H x W, 2], ...]
         """
         all_scores = []
         all_labels = []
         all_bboxes = []
         
-        for obj_pred_i, cls_pred_i, box_pred_i in zip(obj_preds, cls_preds, box_preds):
-            # (H x W x KA x C,)
-            scores_i = (torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid())).flatten()
+        for cls_pred_i, box_pred_i in zip(cls_preds, box_preds):
+            # (H x W x C,)
+            scores_i = cls_pred_i.sigmoid().flatten()
 
             # Keep top k top scoring indices only.
             num_topk = min(self.topk, box_pred_i.size(0))
@@ -151,15 +145,13 @@ class YOLOX2(nn.Module):
         pyramid_feats = self.fpn(pyramid_feats)
 
         # ---------------- Heads ----------------
-        cls_feats, reg_feats = self.group_heads(pyramid_feats)
+        cls_feats, reg_feats = self.heads(pyramid_feats)
 
         # ---------------- Preds ----------------
-        all_obj_preds = []
         all_cls_preds = []
         all_box_preds = []
         for level, (cls_feat, reg_feat) in enumerate(zip(cls_feats, reg_feats)):
             # prediction
-            obj_pred = self.obj_preds[level](reg_feat)
             cls_pred = self.cls_preds[level](cls_feat)
             reg_pred = self.reg_preds[level](reg_feat)
             
@@ -168,7 +160,6 @@ class YOLOX2(nn.Module):
             anchors = self.generate_anchors(level, fmp_size)
             
             # [1, C, H, W] -> [H, W, C] -> [M, C]
-            obj_pred = obj_pred[0].permute(1, 2, 0).contiguous().view(-1, 1)
             cls_pred = cls_pred[0].permute(1, 2, 0).contiguous().view(-1, self.num_classes)
             reg_pred = reg_pred[0].permute(1, 2, 0).contiguous().view(-1, 4)
 
@@ -179,15 +170,13 @@ class YOLOX2(nn.Module):
             pred_x2y2 = ctr_pred + wh_pred * 0.5
             box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
 
-            all_obj_preds.append(obj_pred)
             all_cls_preds.append(cls_pred)
             all_box_preds.append(box_pred)
 
         if self.deploy:
-            obj_preds = torch.cat(all_obj_preds, dim=0)
             cls_preds = torch.cat(all_cls_preds, dim=0)
             box_preds = torch.cat(all_box_preds, dim=0)
-            scores = torch.sqrt(obj_preds.sigmoid() * cls_preds.sigmoid())
+            scores = cls_preds.sigmoid()
             bboxes = box_preds
             # [n_anchors_all, 4 + C]
             outputs = torch.cat([bboxes, scores], dim=-1)
@@ -195,8 +184,7 @@ class YOLOX2(nn.Module):
             return outputs
         else:
             # post process
-            bboxes, scores, labels = self.post_process(
-                all_obj_preds, all_cls_preds, all_box_preds)
+            bboxes, scores, labels = self.post_process(all_cls_preds, all_box_preds)
         
             return bboxes, scores, labels
 
@@ -216,16 +204,14 @@ class YOLOX2(nn.Module):
             pyramid_feats = self.fpn(pyramid_feats)
 
             # ---------------- Heads ----------------
-            cls_feats, reg_feats = self.group_heads(pyramid_feats)
+            cls_feats, reg_feats = self.heads(pyramid_feats)
 
             # ---------------- Preds ----------------
             all_anchors = []
-            all_obj_preds = []
             all_cls_preds = []
             all_box_preds = []
             for level, (cls_feat, reg_feat) in enumerate(zip(cls_feats, reg_feats)):
                 # prediction
-                obj_pred = self.obj_preds[level](reg_feat)
                 cls_pred = self.cls_preds[level](cls_feat)
                 reg_pred = self.reg_preds[level](reg_feat)
 
@@ -235,7 +221,6 @@ class YOLOX2(nn.Module):
                 anchors = self.generate_anchors(level, fmp_size)
                 
                 # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
-                obj_pred = obj_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 1)
                 cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
                 reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4)
 
@@ -246,14 +231,12 @@ class YOLOX2(nn.Module):
                 pred_x2y2 = ctr_pred + wh_pred * 0.5
                 box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
 
-                all_obj_preds.append(obj_pred)
                 all_cls_preds.append(cls_pred)
                 all_box_preds.append(box_pred)
                 all_anchors.append(anchors)
             
             # output dict
-            outputs = {"pred_obj": all_obj_preds,        # List(Tensor) [B, M, 1]
-                       "pred_cls": all_cls_preds,        # List(Tensor) [B, M, C]
+            outputs = {"pred_cls": all_cls_preds,        # List(Tensor) [B, M, C]
                        "pred_box": all_box_preds,        # List(Tensor) [B, M, 4]
                        "anchors": all_anchors,           # List(Tensor) [B, M, 2]
                        'strides': self.stride}           # List(Int) [8, 16, 32]

+ 2 - 2
utils/solver/lr_scheduler.py

@@ -12,11 +12,11 @@ def build_lr_scheduler(cfg, optimizer, epochs):
         
     elif cfg['scheduler'] == 'linear':
         lf = lambda x: (1 - x / epochs) * (1.0 - cfg['lrf']) + cfg['lrf']
-
+    elif cfg['scheduler'] == 'cos_linear':
+            lf = lambda x: (1 - x / epochs) * (1.0 - cfg['lrf']) + cfg['lrf'] if x > epochs // 2 else ((1 - math.cos(x * math.pi / epochs)) / 2) * (cfg['lrf'] - 1) + 1
     else:
         print('unknown lr scheduler.')
         exit(0)
-
     scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
 
     return scheduler, lf