yjh0410 2 years ago
parent
commit
f6753b0e9b

+ 15 - 15
config/model_config/rtcdet_config.py

@@ -14,6 +14,7 @@ rtcdet_cfg = {
         'depth': 0.34,
         'stride': [8, 16, 32],  # P3, P4, P5
         'max_stride': 32,
+        'reg_max': 16,
         ## Neck: SPP
         'neck': 'sppf',
         'neck_expand_ratio': 0.5,
@@ -44,16 +45,15 @@ rtcdet_cfg = {
         'trans_type': 'yolox_pico',
         # ---------------- Assignment config ----------------
         ## Matcher
-        'matcher': {'soft_center_radius': 3.0,
-                    'topk_candidate': 13,
-                    'iou_weight': 3.0
-                    },
+        'matcher': {'center_sampling_radius': 2.5,
+                    'topk_candidate': 10},
         # ---------------- Loss config ----------------
         ## Loss weight
         'ema_update': False,
         'loss_box_aux': True,
         'loss_cls_weight': 1.0,
-        'loss_box_weight': 2.0,
+        'loss_box_weight': 5.0,
+        'loss_dfl_weight': 1.0,
         # ---------------- Train config ----------------
         'trainer_type': 'rtcdet',
     },
@@ -70,6 +70,7 @@ rtcdet_cfg = {
         'depth': 0.34,
         'stride': [8, 16, 32],  # P3, P4, P5
         'max_stride': 32,
+        'reg_max': 16,
         ## Neck: SPP
         'neck': 'sppf',
         'neck_expand_ratio': 0.5,
@@ -100,16 +101,15 @@ rtcdet_cfg = {
         'trans_type': 'yolox_small',
         # ---------------- Assignment config ----------------
         ## Matcher
-        'matcher': {'soft_center_radius': 3.0,
-                    'topk_candidate': 13,
-                    'iou_weight': 3.0
-                    },
+        'matcher': {'center_sampling_radius': 2.5,
+                    'topk_candidate': 10},
         # ---------------- Loss config ----------------
         ## Loss weight
         'ema_update': False,
         'loss_box_aux': True,
         'loss_cls_weight': 1.0,
-        'loss_box_weight': 2.0,
+        'loss_box_weight': 5.0,
+        'loss_dfl_weight': 1.0,
         # ---------------- Train config ----------------
         'trainer_type': 'rtcdet',
     },
@@ -126,6 +126,7 @@ rtcdet_cfg = {
         'depth': 1.0,
         'stride': [8, 16, 32],  # P3, P4, P5
         'max_stride': 32,
+        'reg_max': 16,
         ## Neck: SPP
         'neck': 'sppf',
         'neck_expand_ratio': 0.5,
@@ -156,16 +157,15 @@ rtcdet_cfg = {
         'trans_type': 'yolox_large',
         # ---------------- Assignment config ----------------
         ## Matcher
-        'matcher': {'soft_center_radius': 3.0,
-                    'topk_candidate': 13,
-                    'iou_weight': 3.0
-                    },
+        'matcher': {'center_sampling_radius': 2.5,
+                    'topk_candidate': 10},
         # ---------------- Loss config ----------------
         ## Loss weight
         'ema_update': False,
         'loss_box_aux': True,
         'loss_cls_weight': 1.0,
-        'loss_box_weight': 2.0,
+        'loss_box_weight': 5.0,
+        'loss_dfl_weight': 1.0,
         # ---------------- Train config ----------------
         'trainer_type': 'rtcdet',
     },

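Two changes matter in this config: the RTMDet-style matcher keys (soft_center_radius, iou_weight) give way to YOLOX-style SimOTA keys (center_sampling_radius, topk_candidate), and reg_max: 16 plus loss_dfl_weight arrive for the new Distribution Focal Loss (DFL) head. With DFL, each box side (l, t, r, b) is predicted as a distribution over reg_max bins, and a continuous target distance is split between its two nearest bins. A minimal sketch with illustrative values (not repo code):

import torch

reg_max = 16                 # bins per box side, set above in the config
d = torch.tensor(6.3)        # a stride-normalized target distance
left = d.long()              # lower bin: 6 (gt_left in the new loss_dfl)
right = left + 1             # upper bin: 7 (gt_right)
w_left = right.float() - d   # 0.7, weight on the lower bin
w_right = 1.0 - w_left       # 0.3, weight on the upper bin
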
+ 127 - 85
models/detectors/rtcdet/loss.py

@@ -19,18 +19,17 @@ class Criterion(object):
         # ---------------- Loss weight ----------------
         self.loss_cls_weight = cfg['loss_cls_weight']
         self.loss_box_weight = cfg['loss_box_weight']
+        self.loss_dfl_weight = cfg['loss_dfl_weight']
         self.loss_box_aux    = cfg['loss_box_aux']
         # ---------------- Matcher ----------------
         matcher_config = cfg['matcher']
         ## Aligned SimOTA assigner
         self.ota_matcher = AlignedSimOTA(
-            num_classes=num_classes,
-            soft_center_radius=matcher_config['soft_center_radius'],
+            center_sampling_radius=matcher_config['center_sampling_radius'],
             topk_candidate=matcher_config['topk_candidate'],
-            iou_weight=matcher_config['iou_weight']
+            num_classes=num_classes
         )
 
-
     def ema_update(self, name: str, value, initial_value, momentum=0.9):
         if hasattr(self, name):
             old = getattr(self, name)
@@ -40,56 +39,72 @@ class Criterion(object):
         setattr(self, name, new)
         return new
 
+    # ----------------- Loss functions -----------------
+    def loss_classes(self, pred_cls, gt_score, gt_label=None, vfl=False):
+        if vfl:
+            assert gt_label is not None
+            # compute varifocal loss
+            alpha, gamma = 0.75, 2.0
+            focal_weight = alpha * pred_cls.sigmoid().pow(gamma) * (1 - gt_label) + gt_score * gt_label
+            bce_loss = F.binary_cross_entropy_with_logits(pred_cls, gt_score, reduction='none')
+            loss_cls = bce_loss * focal_weight
+        else:
+            # compute bce loss
+            loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_score, reduction='none')
 
-    def loss_classes(self, pred_cls, target, beta=2.0):
-        # Quality FocalLoss
-        """
-            pred_cls: (torch.Tensor): [N, C].
-            target:   (tuple([torch.Tensor], [torch.Tensor])): label -> [N,], score -> [N,]
-        """
-        label, score = target
-        pred_sigmoid = pred_cls.sigmoid()
-        scale_factor = pred_sigmoid
-        zerolabel = scale_factor.new_zeros(pred_cls.shape)
-
-        ce_loss = F.binary_cross_entropy_with_logits(
-            pred_cls, zerolabel, reduction='none') * scale_factor.pow(beta)
-        
-        bg_class_ind = pred_cls.shape[-1]
-        pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1)
-        pos_label = label[pos].long()
-
-        scale_factor = score[pos] - pred_sigmoid[pos, pos_label]
-
-        ce_loss[pos, pos_label] = F.binary_cross_entropy_with_logits(
-            pred_cls[pos, pos_label], score[pos],
-            reduction='none') * scale_factor.abs().pow(beta)
-
-        return ce_loss
-    
+        return loss_cls
 
-    def loss_bboxes(self, pred_box, gt_box):
+    def loss_bboxes(self, pred_box, gt_box, bbox_weight=None):
         # regression loss
         ious = get_ious(pred_box, gt_box, 'xyxy', 'giou')
         loss_box = 1.0 - ious
 
+        if bbox_weight is not None:
+            loss_box *= bbox_weight
+
         return loss_box
+    
+    def loss_dfl(self, pred_reg, gt_box, anchor, stride, bbox_weight=None):
+        # rescale coords by stride
+        gt_box_s = gt_box / stride
+        anchor_s = anchor / stride
+
+        # compute deltas
+        gt_ltrb_s = bbox2dist(anchor_s, gt_box_s, self.cfg['reg_max'] - 1)
+
+        gt_left = gt_ltrb_s.to(torch.long)
+        gt_right = gt_left + 1
+
+        weight_left = gt_right.to(torch.float) - gt_ltrb_s
+        weight_right = 1 - weight_left
+
+        # loss left
+        loss_left = F.cross_entropy(
+            pred_reg.view(-1, self.cfg['reg_max']),
+            gt_left.view(-1),
+            reduction='none').view(gt_left.shape) * weight_left
+        # loss right
+        loss_right = F.cross_entropy(
+            pred_reg.view(-1, self.cfg['reg_max']),
+            gt_right.view(-1),
+            reduction='none').view(gt_left.shape) * weight_right
+
+        loss_dfl = (loss_left + loss_right).mean(-1)
+        
+        if bbox_weight is not None:
+            loss_dfl *= bbox_weight
 
+        return loss_dfl
 
-    def loss_bboxes_aux(self, pred_reg, gt_box, anchors, stride_tensors):
-        # xyxy -> cxcy&bwbh
-        gt_cxcy = (gt_box[..., :2] + gt_box[..., 2:]) * 0.5
-        gt_bwbh = gt_box[..., 2:] - gt_box[..., :2]
-        # encode gt box
-        gt_cxcy_encode = (gt_cxcy - anchors) / stride_tensors
-        gt_bwbh_encode = torch.log(gt_bwbh / stride_tensors)
-        gt_box_encode = torch.cat([gt_cxcy_encode, gt_bwbh_encode], dim=-1)
-        # l1 loss
-        loss_box_aux = F.l1_loss(pred_reg, gt_box_encode, reduction='none')
+    def loss_bboxes_aux(self, pred_delta, gt_box, anchors, stride_tensors):
+        gt_delta_tl = (anchors - gt_box[..., :2]) / stride_tensors
+        gt_delta_rb = (gt_box[..., 2:] - anchors) / stride_tensors
+        gt_delta = torch.cat([gt_delta_tl, gt_delta_rb], dim=1)
+        loss_box_aux = F.l1_loss(pred_delta, gt_delta, reduction='none')
 
         return loss_box_aux
-
-
+    
+    # ----------------- Main process -----------------
     def __call__(self, outputs, targets, epoch=0):
         bs = outputs['pred_cls'][0].shape[0]
         device = outputs['pred_cls'][0].device
@@ -104,31 +119,48 @@ class Criterion(object):
         # --------------- label assignment ---------------
         cls_targets = []
         box_targets = []
-        assign_metrics = []
+        fg_masks = []
         for batch_idx in range(bs):
-            tgt_labels = targets[batch_idx]["labels"].to(device)  # [N,]
-            tgt_bboxes = targets[batch_idx]["boxes"].to(device)   # [N, 4]
-            # label assignment
-            assigned_result = self.ota_matcher(fpn_strides=fpn_strides,
-                                               anchors=anchors,
-                                               pred_cls=cls_preds[batch_idx].detach(),
-                                               pred_box=box_preds[batch_idx].detach(),
-                                               gt_labels=tgt_labels,
-                                               gt_bboxes=tgt_bboxes
-                                              )
-            cls_targets.append(assigned_result['assigned_labels'])
-            box_targets.append(assigned_result['assigned_bboxes'])
-            assign_metrics.append(assigned_result['assign_metrics'])
-        
-        cls_targets = torch.cat(cls_targets, dim=0)
-        box_targets = torch.cat(box_targets, dim=0)
-        assign_metrics = torch.cat(assign_metrics, dim=0)
-
-        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
-        bg_class_ind = self.num_classes
-        pos_inds = ((cls_targets >= 0) & (cls_targets < bg_class_ind)).nonzero().squeeze(1)
-        num_fgs = assign_metrics.sum()
-        
+            tgt_labels = targets[batch_idx]["labels"].to(device)
+            tgt_bboxes = targets[batch_idx]["boxes"].to(device)
+
+            # check target
+            if len(tgt_labels) == 0 or tgt_bboxes.max().item() == 0.:
+                # There is no valid gt
+                cls_target = cls_preds.new_zeros((num_anchors, self.num_classes))
+                box_target = cls_preds.new_zeros((0, 4))
+                fg_mask = cls_preds.new_zeros(num_anchors).bool()
+            else:
+                (
+                    fg_mask,
+                    assigned_labels,
+                    assigned_ious,
+                    assigned_indexs
+                ) = self.ota_matcher(
+                    fpn_strides = fpn_strides,
+                    anchors = anchors,
+                    pred_cls = cls_preds[batch_idx], 
+                    pred_box = box_preds[batch_idx],
+                    tgt_labels = tgt_labels,
+                    tgt_bboxes = tgt_bboxes
+                    )
+                # prepare cls targets
+                assigned_labels = F.one_hot(assigned_labels.long(), self.num_classes)
+                assigned_labels = assigned_labels * assigned_ious.unsqueeze(-1)
+                cls_target = assigned_labels.new_zeros((num_anchors, self.num_classes))
+                cls_target[fg_mask] = assigned_labels
+                # prepare box targets
+                box_target = tgt_bboxes[assigned_indexs]
+
+            cls_targets.append(cls_target)
+            box_targets.append(box_target)
+            fg_masks.append(fg_mask)
+
+        cls_targets = torch.cat(cls_targets, 0)
+        box_targets = torch.cat(box_targets, 0)
+        fg_masks = torch.cat(fg_masks, 0)
+        num_fgs = fg_masks.sum()
+
         # average loss normalizer across all the GPUs
         if is_dist_avail_and_initialized():
             torch.distributed.all_reduce(num_fgs)
@@ -142,38 +174,49 @@ class Criterion(object):
         
         # ------------------ Classification loss ------------------
         cls_preds = cls_preds.view(-1, self.num_classes)
-        loss_cls = self.loss_classes(cls_preds, (cls_targets, assign_metrics))
+        loss_cls = self.loss_classes(cls_preds, cls_targets)
         loss_cls = loss_cls.sum() / normalizer
 
         # ------------------ Regression loss ------------------
-        box_preds_pos = box_preds.view(-1, 4)[pos_inds]
-        box_targets_pos = box_targets[pos_inds]
-        loss_box = self.loss_bboxes(box_preds_pos, box_targets_pos)
+        box_preds_pos = box_preds.view(-1, 4)[fg_masks]
+        loss_box = self.loss_bboxes(box_preds_pos, box_targets)
         loss_box = loss_box.sum() / normalizer
 
+        # ------------------ Distribution focal loss  ------------------
+        ## process anchors
+        anchors = torch.cat(anchors, dim=0)
+        anchors = anchors[None].repeat(bs, 1, 1).view(-1, 2)
+        ## process stride tensors
+        strides = torch.cat(outputs['stride_tensor'], dim=0)
+        strides = strides.unsqueeze(0).repeat(bs, 1, 1).view(-1, 1)
+        ## fg preds
+        reg_preds_pos = reg_preds.view(-1, 4*self.cfg['reg_max'])[fg_masks]
+        anchors_pos = anchors[fg_masks]
+        strides_pos = strides[fg_masks]
+        ## compute dfl
+        loss_dfl = self.loss_dfl(reg_preds_pos, box_targets, anchors_pos, strides_pos)
+        loss_dfl = loss_dfl.sum() / normalizer
+
+        # total loss
         losses = self.loss_cls_weight * loss_cls + \
-                 self.loss_box_weight * loss_box
+                 self.loss_box_weight * loss_box + \
+                 self.loss_dfl_weight * loss_dfl
 
         loss_dict = dict(
                 loss_cls = loss_cls,
                 loss_box = loss_box,
+                loss_dfl = loss_dfl,
                 losses = losses
         )
 
         # ------------------ Aux regression loss ------------------
-        if epoch >= (self.max_epoch - self.no_aug_epoch - 1):
-            ## reg_preds
-            reg_preds = torch.cat(outputs['pred_reg'], dim=1)
-            reg_preds_pos = reg_preds.view(-1, 4)[pos_inds]
-            ## anchor tensors
-            anchors_tensors = torch.cat(outputs['anchors'], dim=0)[None].repeat(bs, 1, 1)
-            anchors_tensors_pos = anchors_tensors.view(-1, 2)[pos_inds]
-            ## stride tensors
-            stride_tensors = torch.cat(outputs['stride_tensors'], dim=0)[None].repeat(bs, 1, 1)
-            stride_tensors_pos = stride_tensors.view(-1, 1)[pos_inds]
+        if epoch >= (self.max_epoch - self.no_aug_epoch - 1) and self.loss_box_aux:
+            ## delta_preds
+            delta_preds = torch.cat(outputs['pred_delta'], dim=1)
+            delta_preds_pos = delta_preds.view(-1, 4)[fg_masks]
             ## aux loss
-            loss_box_aux = self.loss_bboxes_aux(reg_preds_pos, box_targets_pos, anchors_tensors_pos, stride_tensors_pos)
-            loss_box_aux = loss_box_aux.sum() / normalizer
+            loss_box_aux = self.loss_bboxes_aux(delta_preds_pos, box_targets, anchors_pos, strides_pos)
+            loss_box_aux = loss_box_aux.sum() / num_fgs
 
             losses += loss_box_aux
             loss_dict['loss_box_aux'] = loss_box_aux
@@ -181,7 +224,6 @@ class Criterion(object):
 
         return loss_dict
 
-
 def build_criterion(args, cfg, device, num_classes):
     criterion = Criterion(
         args=args,

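The new loss_dfl depends on a bbox2dist helper imported from elsewhere in the repo; its body is not part of this commit. A hedged sketch of what it presumably does, following the common YOLOv6/YOLOv8 convention (an assumption, not the repo's actual code): convert xyxy boxes into (l, t, r, b) distances from the anchor points, clamped so the upper DFL bin index stays valid.

import torch

def bbox2dist(anchor_points, bboxes, reg_max):
    # distances from anchor centers to the four box edges: xyxy -> ltrb
    lt = anchor_points - bboxes[..., :2]
    rb = bboxes[..., 2:] - anchor_points
    # clamp below reg_max so gt_right = gt_left + 1 never overflows the bins;
    # loss_dfl passes reg_max - 1 here, matching cross_entropy over reg_max classes
    return torch.cat([lt, rb], dim=-1).clamp_(0, reg_max - 0.01)
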
+ 144 - 141
models/detectors/rtcdet/matcher.py

@@ -1,23 +1,17 @@
-# ---------------------------------------------------------------------
-# Copyright (c) OpenMMLab. All rights reserved.
-# ---------------------------------------------------------------------
-
-
 import torch
 import torch.nn.functional as F
 from utils.box_ops import *
 
 
-# RTMDet's Assigner
+# -------------------------- Aligned SimOTA Assigner --------------------------
 class AlignedSimOTA(object):
     """
-        This code referenced to https://github.com/open-mmlab/mmyolo/models/task_modules/assigners/batch_dsl_assigner.py
+        This code references https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/models/yolo_head.py
     """
-    def __init__(self, num_classes=80, soft_center_radius=3.0, topk_candidate=13, iou_weight=3.0):
+    def __init__(self, num_classes, center_sampling_radius, topk_candidate):
         self.num_classes = num_classes
-        self.soft_center_radius = soft_center_radius
+        self.center_sampling_radius = center_sampling_radius
         self.topk_candidate = topk_candidate
-        self.iou_weight = iou_weight
 
 
     @torch.no_grad()
@@ -26,151 +20,160 @@ class AlignedSimOTA(object):
                  anchors, 
                  pred_cls, 
                  pred_box, 
-                 gt_labels,
-                 gt_bboxes):
+                 tgt_labels,
+                 tgt_bboxes):
         # [M,]
-        strides = torch.cat([torch.ones_like(anchor_i[:, 0]) * stride_i
+        strides_tensor = torch.cat([torch.ones_like(anchor_i[:, 0]) * stride_i
                                 for stride_i, anchor_i in zip(fpn_strides, anchors)], dim=-1)
         # List[F, M, 2] -> [M, 2]
         anchors = torch.cat(anchors, dim=0)
-        num_gt = len(gt_labels)
-
-        # check gt
-        if num_gt == 0 or gt_bboxes.max().item() == 0.:
-            return {
-                'assigned_labels': gt_labels.new_full(pred_cls[..., 0].shape,
-                                                      self.num_classes,
-                                                      dtype=torch.long),
-                'assigned_bboxes': gt_bboxes.new_full(pred_box.shape, 0),
-                'assign_metrics': gt_bboxes.new_full(pred_cls[..., 0].shape, 0)
-            }
-        
-        # get inside points: [N, M]
-        is_in_gt = self.find_inside_points(gt_bboxes, anchors)
-        valid_mask = is_in_gt.sum(dim=0) > 0  # [M,]
-
-        # ----------------------------------- soft center prior -----------------------------------
-        gt_center = (gt_bboxes[..., :2] + gt_bboxes[..., 2:]) / 2.0
-        distance = (anchors.unsqueeze(0) - gt_center.unsqueeze(1)
-                    ).pow(2).sum(-1).sqrt() / strides.unsqueeze(0)  # [N, M]
-        distance = distance * valid_mask.unsqueeze(0)
-        soft_center_prior = torch.pow(10, distance - self.soft_center_radius)
-
-        # ----------------------------------- regression cost -----------------------------------
-        pair_wise_ious, _ = box_iou(gt_bboxes, pred_box)  # [N, M]
-        pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8) * self.iou_weight
-
-        # ----------------------------------- classification cost -----------------------------------
-        ## select the predicted scores corresponded to the gt_labels
-        pairwise_pred_scores = pred_cls.permute(1, 0)  # [M, C] -> [C, M]
-        pairwise_pred_scores = pairwise_pred_scores[gt_labels.long(), :].float()   # [N, M]
-        ## scale factor
-        scale_factor = (pair_wise_ious - pairwise_pred_scores.sigmoid()).abs().pow(2.0)
-        ## cls cost
-        pair_wise_cls_loss = F.binary_cross_entropy_with_logits(
-            pairwise_pred_scores, pair_wise_ious,
-            reduction="none") * scale_factor # [N, M]
-            
-        del pairwise_pred_scores
-
-        ## foreground cost matrix
-        cost_matrix = pair_wise_cls_loss + pair_wise_ious_loss + soft_center_prior
-        max_pad_value = torch.ones_like(cost_matrix) * 1e9
-        cost_matrix = torch.where(valid_mask[None].repeat(num_gt, 1),   # [N, M]
-                                  cost_matrix, max_pad_value)
-
-        # ----------------------------------- dynamic label assignment -----------------------------------
+        num_anchor = anchors.shape[0]        
+        num_gt = len(tgt_labels)
+
+        # ----------------------- Find inside points -----------------------
+        fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(
+            tgt_bboxes, anchors, strides_tensor, num_anchor, num_gt)
+        cls_preds = pred_cls[fg_mask].float()   # [Mp, C]
+        box_preds = pred_box[fg_mask].float()   # [Mp, 4]
+
+        # ----------------------- Reg cost -----------------------
+        pair_wise_ious, _ = box_iou(tgt_bboxes, box_preds)      # [N, Mp]
+        reg_cost = -torch.log(pair_wise_ious + 1e-8)            # [N, Mp]
+
+        # ----------------------- Cls cost -----------------------
+        with torch.cuda.amp.autocast(enabled=False):
+            # [Mp, C] -> [N, Mp, C]
+            score_preds = cls_preds.sigmoid_().unsqueeze(0).repeat(num_gt, 1, 1)
+            # prepare cls_target
+            cls_targets = F.one_hot(tgt_labels.long(), self.num_classes).float()
+            cls_targets = cls_targets.unsqueeze(1).repeat(1, score_preds.size(1), 1)
+            cls_targets *= pair_wise_ious.unsqueeze(-1)  # iou-aware
+            # [N, Mp]
+            cls_cost = F.binary_cross_entropy(score_preds, cls_targets, reduction="none").sum(-1)
+        del score_preds
+
+        #----------------------- Dynamic K-Matching -----------------------
+        cost_matrix = (
+            cls_cost
+            + 3.0 * reg_cost
+            + 100000.0 * (~is_in_boxes_and_center)
+        ) # [N, Mp]
+
         (
-            matched_pred_ious,
-            matched_gt_inds,
-            fg_mask_inboxes
+            assigned_labels,         # [num_fg,]
+            assigned_ious,           # [num_fg,]
+            assigned_indexs,         # [num_fg,]
         ) = self.dynamic_k_matching(
             cost_matrix,
             pair_wise_ious,
-            num_gt
-            )
-        del pair_wise_cls_loss, cost_matrix, pair_wise_ious, pair_wise_ious_loss
-
-        # -----------------------------------process assigned labels -----------------------------------
-        assigned_labels = gt_labels.new_full(pred_cls[..., 0].shape,
-                                             self.num_classes)  # [M,]
-        assigned_labels[fg_mask_inboxes] = gt_labels[matched_gt_inds].squeeze(-1)
-        assigned_labels = assigned_labels.long()  # [M,]
-
-        assigned_bboxes = gt_bboxes.new_full(pred_box.shape, 0)        # [M, 4]
-        assigned_bboxes[fg_mask_inboxes] = gt_bboxes[matched_gt_inds]  # [M, 4]
-
-        assign_metrics = gt_bboxes.new_full(pred_cls[..., 0].shape, 0) # [M, 4]
-        assign_metrics[fg_mask_inboxes] = matched_pred_ious            # [M, 4]
-
-        assigned_dict = dict(
-            assigned_labels=assigned_labels,
-            assigned_bboxes=assigned_bboxes,
-            assign_metrics=assign_metrics
+            tgt_labels,
+            num_gt,
+            fg_mask
             )
+        del cls_cost, cost_matrix, pair_wise_ious, reg_cost
+
+        return fg_mask, assigned_labels, assigned_ious, assigned_indexs
+
+
+    def get_in_boxes_info(
+        self,
+        gt_bboxes,   # [N, 4]
+        anchors,     # [M, 2]
+        strides,     # [M,]
+        num_anchors, # M
+        num_gt,      # N
+        ):
+        # anchor center
+        x_centers = anchors[:, 0]
+        y_centers = anchors[:, 1]
+
+        # [M,] -> [1, M] -> [N, M]
+        x_centers = x_centers.unsqueeze(0).repeat(num_gt, 1)
+        y_centers = y_centers.unsqueeze(0).repeat(num_gt, 1)
+
+        # [N,] -> [N, 1] -> [N, M]
+        gt_bboxes_l = gt_bboxes[:, 0].unsqueeze(1).repeat(1, num_anchors) # x1
+        gt_bboxes_t = gt_bboxes[:, 1].unsqueeze(1).repeat(1, num_anchors) # y1
+        gt_bboxes_r = gt_bboxes[:, 2].unsqueeze(1).repeat(1, num_anchors) # x2
+        gt_bboxes_b = gt_bboxes[:, 3].unsqueeze(1).repeat(1, num_anchors) # y2
+
+        b_l = x_centers - gt_bboxes_l
+        b_r = gt_bboxes_r - x_centers
+        b_t = y_centers - gt_bboxes_t
+        b_b = gt_bboxes_b - y_centers
+        bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)
+
+        is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
+        is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
+        # in fixed center
+        center_radius = self.center_sampling_radius
+
+        # [N, 2]
+        gt_centers = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) * 0.5
         
-        return assigned_dict
-
-
-    def find_inside_points(self, gt_bboxes, anchors):
-        """
-            gt_bboxes: Tensor -> [N, 2]
-            anchors:   Tensor -> [M, 2]
-        """
-        num_anchors = anchors.shape[0]
-        num_gt = gt_bboxes.shape[0]
-
-        anchors_expand = anchors.unsqueeze(0).repeat(num_gt, 1, 1)           # [N, M, 2]
-        gt_bboxes_expand = gt_bboxes.unsqueeze(1).repeat(1, num_anchors, 1)  # [N, M, 4]
-
-        # offset
-        lt = anchors_expand - gt_bboxes_expand[..., :2]
-        rb = gt_bboxes_expand[..., 2:] - anchors_expand
-        bbox_deltas = torch.cat([lt, rb], dim=-1)
-
-        is_in_gts = bbox_deltas.min(dim=-1).values > 0
-
-        return is_in_gts
+        # [1, M]
+        center_radius_ = center_radius * strides.unsqueeze(0)
+
+        gt_bboxes_l = gt_centers[:, 0].unsqueeze(1).repeat(1, num_anchors) - center_radius_ # x1
+        gt_bboxes_t = gt_centers[:, 1].unsqueeze(1).repeat(1, num_anchors) - center_radius_ # y1
+        gt_bboxes_r = gt_centers[:, 0].unsqueeze(1).repeat(1, num_anchors) + center_radius_ # x2
+        gt_bboxes_b = gt_centers[:, 1].unsqueeze(1).repeat(1, num_anchors) + center_radius_ # y2
+
+        c_l = x_centers - gt_bboxes_l
+        c_r = gt_bboxes_r - x_centers
+        c_t = y_centers - gt_bboxes_t
+        c_b = gt_bboxes_b - y_centers
+        center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
+        is_in_centers = center_deltas.min(dim=-1).values > 0.0
+        is_in_centers_all = is_in_centers.sum(dim=0) > 0
+
+        # in boxes and in centers
+        is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
+
+        is_in_boxes_and_center = (
+            is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
+        )
+        return is_in_boxes_anchor, is_in_boxes_and_center
     
-
-    def dynamic_k_matching(self, cost_matrix, pairwise_ious, num_gt):
-        """Use IoU and matching cost to calculate the dynamic top-k positive
-        targets.
-
-        Args:
-            cost_matrix (Tensor): Cost matrix.
-            pairwise_ious (Tensor): Pairwise iou matrix.
-            num_gt (int): Number of gt.
-            valid_mask (Tensor): Mask for valid bboxes.
-        Returns:
-            tuple: matched ious and gt indexes.
-        """
-        matching_matrix = torch.zeros_like(cost_matrix, dtype=torch.uint8)
-        # select candidate topk ious for dynamic-k calculation
-        candidate_topk = min(self.topk_candidate, pairwise_ious.size(1))
-        topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=1)
-        # calculate dynamic k for each gt
+    
+    def dynamic_k_matching(
+        self, 
+        cost, 
+        pair_wise_ious, 
+        gt_classes, 
+        num_gt, 
+        fg_mask
+        ):
+        # Dynamic K
+        # ---------------------------------------------------------------
+        matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
+
+        ious_in_boxes_matrix = pair_wise_ious
+        n_candidate_k = min(self.topk_candidate, ious_in_boxes_matrix.size(1))
+        topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)
         dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
-
-        # sorting the batch cost matirx is faster than topk
-        _, sorted_indices = torch.sort(cost_matrix, dim=1)
+        dynamic_ks = dynamic_ks.tolist()
         for gt_idx in range(num_gt):
-            topk_ids = sorted_indices[gt_idx, :dynamic_ks[gt_idx]]
-            matching_matrix[gt_idx, :][topk_ids] = 1
-
-        del topk_ious, dynamic_ks, topk_ids
+            _, pos_idx = torch.topk(
+                cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
+            )
+            matching_matrix[gt_idx][pos_idx] = 1
 
-        prior_match_gt_mask = matching_matrix.sum(0) > 1
-        if prior_match_gt_mask.sum() > 0:
-            cost_min, cost_argmin = torch.min(
-                cost_matrix[:, prior_match_gt_mask], dim=0)
-            matching_matrix[:, prior_match_gt_mask] *= 0
-            matching_matrix[cost_argmin, prior_match_gt_mask] = 1
+        del topk_ious, dynamic_ks, pos_idx
 
-        # get foreground mask inside box and center prior
+        anchor_matching_gt = matching_matrix.sum(0)
+        if (anchor_matching_gt > 1).sum() > 0:
+            _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
+            matching_matrix[:, anchor_matching_gt > 1] *= 0
+            matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
         fg_mask_inboxes = matching_matrix.sum(0) > 0
-        matched_pred_ious = (matching_matrix *
-                             pairwise_ious).sum(0)[fg_mask_inboxes]
-        matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
 
-        return matched_pred_ious, matched_gt_inds, fg_mask_inboxes
+        fg_mask[fg_mask.clone()] = fg_mask_inboxes
+
+        assigned_indexs = matching_matrix[:, fg_mask_inboxes].argmax(0)
+        assigned_labels = gt_classes[assigned_indexs]
+
+        assigned_ious = (matching_matrix * pair_wise_ious).sum(0)[
+            fg_mask_inboxes
+        ]
+        return assigned_labels, assigned_ious, assigned_indexs

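The dynamic-k rule in dynamic_k_matching is easy to trace on a toy example: each gt keeps its k lowest-cost candidate anchors, where k is the clamped integer sum of that gt's top candidate IoUs. Made-up numbers, not repo code:

import torch

pair_wise_ious = torch.tensor([[0.9, 0.8, 0.5, 0.1],    # gt 0 vs 4 anchors
                               [0.2, 0.1, 0.1, 0.05]])  # gt 1
topk_ious, _ = torch.topk(pair_wise_ious, k=4, dim=1)   # topk_candidate capped by anchor count
dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
print(dynamic_ks.tolist())  # [2, 1]: gt 0 gets 2 positives, gt 1 keeps at least 1
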
+ 3 - 2
models/detectors/rtcdet/rtcdet.py

@@ -29,6 +29,7 @@ class RTCDet(nn.Module):
         self.cfg = cfg
         self.device = device
         self.stride = cfg['stride']
+        self.reg_max = cfg['reg_max']
         self.num_classes = num_classes
         self.trainable = trainable
         self.conf_thresh = conf_thresh
@@ -55,8 +56,8 @@ class RTCDet(nn.Module):
 
         ## ----------- Preds -----------
         self.pred_layers = build_pred_layer(
-            self.det_heads.cls_head_dim, self.det_heads.reg_head_dim,
-            self.stride, num_classes, num_coords=4, num_levels=len(self.stride))
+            self.det_heads.cls_head_dim, self.det_heads.reg_head_dim, self.stride,
+            num_classes=num_classes, num_coords=4, num_levels=len(self.stride), reg_max=self.reg_max)
 
 
     ## post-process

+ 3 - 3
models/detectors/rtcdet/rtcdet_backbone.py

@@ -141,9 +141,9 @@ if __name__ == '__main__':
         'pretrained': False,
         'bk_act': 'silu',
         'bk_norm': 'BN',
-        'bk_depthwise': True,
-        'width': 0.25,
-        'depth': 0.34,
+        'bk_depthwise': False,
+        'width': 1.0,
+        'depth': 1.0,
         'stride': [8, 16, 32],  # P3, P4, P5
         'max_stride': 32,
     }

+ 28 - 12
models/detectors/rtcdet/rtcdet_pred.py

@@ -49,7 +49,7 @@ class SingleLevelPredLayer(nn.Module):
 
 # Multi-level pred layer
 class MultiLevelPredLayer(nn.Module):
-    def __init__(self, cls_dim, reg_dim, strides, num_classes, num_coords=4, num_levels=3):
+    def __init__(self, cls_dim, reg_dim, strides, num_classes, num_coords=4, num_levels=3, reg_max=16):
         super().__init__()
         # --------- Basic Parameters ----------
         self.cls_dim = cls_dim
@@ -58,6 +58,7 @@ class MultiLevelPredLayer(nn.Module):
         self.num_classes = num_classes
         self.num_coords = num_coords
         self.num_levels = num_levels
+        self.reg_max = reg_max
 
         # ----------- Network Parameters -----------
         ## pred layers
@@ -66,9 +67,13 @@ class MultiLevelPredLayer(nn.Module):
                 cls_dim,
                 reg_dim,
                 num_classes,
-                num_coords)
+                num_coords * self.reg_max)
                 for _ in range(num_levels)
             ])
+        ## proj conv
+        self.proj = nn.Parameter(torch.linspace(0, reg_max, reg_max), requires_grad=False)
+        self.proj_conv = nn.Conv2d(self.reg_max, 1, kernel_size=1, bias=False)
+        self.proj_conv.weight = nn.Parameter(self.proj.view([1, reg_max, 1, 1]).clone().detach(), requires_grad=False)
 
 
     def generate_anchors(self, level, fmp_size):
@@ -92,6 +97,7 @@ class MultiLevelPredLayer(nn.Module):
         all_cls_preds = []
         all_reg_preds = []
         all_box_preds = []
+        all_delta_preds = []
         for level in range(self.num_levels):
             # pred
             cls_pred, reg_pred = self.multi_level_preds[level](cls_feats[level], reg_feats[level])
@@ -106,18 +112,27 @@ class MultiLevelPredLayer(nn.Module):
             
             # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
             cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
-            reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4)
+            reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)
 
             # ----------------------- Decode bbox -----------------------
-            ctr_pred = reg_pred[..., :2] * self.strides[level] + anchors[..., :2]
-            wh_pred = torch.exp(reg_pred[..., 2:]) * self.strides[level]
-            pred_x1y1 = ctr_pred - wh_pred * 0.5
-            pred_x2y2 = ctr_pred + wh_pred * 0.5
-            box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
+            B, M = reg_pred.shape[:2]
+            # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max] -> [B, 4, M, reg_max]
+            delta_pred = reg_pred.reshape([B, M, 4, self.reg_max])
+            # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
+            delta_pred = delta_pred.permute(0, 3, 2, 1).contiguous()
+            # [B, reg_max, 4, M] -> [B, 1, 4, M]
+            delta_pred = self.proj_conv(F.softmax(delta_pred, dim=1))
+            # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
+            delta_pred = delta_pred.view(B, 4, M).permute(0, 2, 1).contiguous()
+            ## tlbr -> xyxy
+            x1y1_pred = anchors[None] - delta_pred[..., :2] * self.strides[level]
+            x2y2_pred = anchors[None] + delta_pred[..., 2:] * self.strides[level]
+            box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
 
             all_cls_preds.append(cls_pred)
             all_reg_preds.append(reg_pred)
             all_box_preds.append(box_pred)
+            all_delta_preds.append(delta_pred)
             all_anchors.append(anchors)
             all_strides.append(stride_tensor)
         
@@ -125,16 +140,17 @@ class MultiLevelPredLayer(nn.Module):
         outputs = {"pred_cls": all_cls_preds,        # List(Tensor) [B, M, C]
                    "pred_reg": all_reg_preds,        # List(Tensor) [B, M, 4*(reg_max)]
                    "pred_box": all_box_preds,        # List(Tensor) [B, M, 4]
+                   "pred_delta": all_delta_preds,    # List(Tensor) [B, M, 4]
                    "anchors": all_anchors,           # List(Tensor) [M, 2]
                    "strides": self.strides,          # List(Int) = [8, 16, 32]
-                   "stride_tensors": all_strides      # List(Tensor) [M, 1]
+                   "stride_tensor": all_strides      # List(Tensor) [M, 1]
                    }
 
         return outputs
     
 
 # build detection head
-def build_pred_layer(cls_dim, reg_dim, strides, num_classes, num_coords=4, num_levels=3):
-    pred_layers = MultiLevelPredLayer(cls_dim, reg_dim, strides, num_classes, num_coords, num_levels) 
+def build_pred_layer(cls_dim, reg_dim, strides, num_classes, num_coords=4, num_levels=3, reg_max=16):
+    pred_layers = MultiLevelPredLayer(cls_dim, reg_dim, strides, num_classes, num_coords, num_levels, reg_max) 
 
-    return pred_layers
+    return pred_layers

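The frozen proj_conv added above decodes each reg_max-bin side distribution into a scalar by taking its expectation under a softmax. A minimal equivalent without the 1x1 conv (sketch only; shapes assumed from the forward pass above):

import torch
import torch.nn.functional as F

B, M, reg_max = 2, 100, 16
reg_pred = torch.randn(B, M, 4 * reg_max)          # raw head output
delta = reg_pred.view(B, M, 4, reg_max)
bins = torch.linspace(0, reg_max, reg_max)         # same values as self.proj
delta = (F.softmax(delta, dim=-1) * bins).sum(-1)  # [B, M, 4] expected ltrb deltas
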
+ 1 - 1
train_multi_gpus.sh

@@ -4,7 +4,7 @@ python -m torch.distributed.run --nproc_per_node=8 train.py \
                                                     -dist \
                                                     -d coco \
                                                     --root /data/datasets/ \
-                                                    -m rtcdet_l \
+                                                    -m rtcdet_s \
                                                     -bs 128 \
                                                     -size 640 \
                                                     --wp_epoch 3 \

+ 1 - 1
train_single_gpu.sh

@@ -3,7 +3,7 @@ python train.py \
         --cuda \
         -d coco \
         --root /data/datasets/ \
-        -m rtcdet_l \
+        -m rtcdet_s \
         -bs 8 \
         -size 640 \
         --wp_epoch 3 \