yjh0410 2 ani în urmă
părinte
comite
96348a74f1
2 a modificat fișierele cu 17 adăugiri și 153 ștergeri
  1. 8 14
      config/model_config/rtcdet_v2_config.py
  2. 9 139
      models/detectors/rtcdet_v2/loss.py

+ 8 - 14
config/model_config/rtcdet_v2_config.py

@@ -44,19 +44,16 @@ rtcdet_v2_cfg = {
         'trans_type': 'yolox_small',
         # ---------------- Assignment config ----------------
         ## Matcher
-        'matcher': {'tal': {'topk': 10,
-                            'alpha': 0.5,
-                            'beta': 6.0},
-                    'ota': {'center_sampling_radius': 2.5,
+        'matcher': {'ota': {'center_sampling_radius': 2.5,
                              'topk_candidate': 10},
                     },
         # ---------------- Loss config ----------------
         ## Loss weight
         'ema_update': False,
         'loss_box_aux': True,
-        'loss_cls_weight': {'tal': 0.5, 'ota': 1.0},
-        'loss_box_weight': {'tal': 7.0, 'ota': 5.0},
-        'loss_dfl_weight': {'tal': 1.5, 'ota': 1.0},
+        'loss_cls_weight': 1.0,
+        'loss_box_weight': 5.0,
+        'loss_dfl_weight': 1.0,
         # ---------------- Train config ----------------
         'trainer_type': 'yolox',
     },
@@ -103,19 +100,16 @@ rtcdet_v2_cfg = {
         'trans_type': 'yolox_large',
         # ---------------- Assignment config ----------------
         ## Matcher
-        'matcher': {'tal': {'topk': 10,
-                            'alpha': 0.5,
-                            'beta': 6.0},
-                    'ota': {'center_sampling_radius': 2.5,
+        'matcher': {'ota': {'center_sampling_radius': 2.5,
                              'topk_candidate': 10},
                     },
         # ---------------- Loss config ----------------
         ## Loss weight
         'ema_update': False,
         'loss_box_aux': True,
-        'loss_cls_weight': {'tal': 0.5, 'ota': 1.0},
-        'loss_box_weight': {'tal': 7.0, 'ota': 5.0},
-        'loss_dfl_weight': {'tal': 1.5, 'ota': 1.0},
+        'loss_cls_weight': 1.0,
+        'loss_box_weight': 5.0,
+        'loss_dfl_weight': 1.0,
         # ---------------- Train config ----------------
         'trainer_type': 'yolox',
     },

+ 9 - 139
models/detectors/rtcdet_v2/loss.py

@@ -4,7 +4,7 @@ import torch.nn.functional as F
 from utils.box_ops import bbox2dist, get_ious
 from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
 
-from .matcher import TaskAlignedAssigner, AlignedSimOTA
+from .matcher import AlignedSimOTA
 
 
 class Criterion(object):
@@ -23,13 +23,6 @@ class Criterion(object):
         self.loss_box_aux    = cfg['loss_box_aux']
         # ---------------- Matcher ----------------
         matcher_config = cfg['matcher']
-        ## TAL assigner
-        self.tal_matcher = TaskAlignedAssigner(
-            topk=matcher_config['tal']['topk'],
-            alpha=matcher_config['tal']['alpha'],
-            beta=matcher_config['tal']['beta'],
-            num_classes=num_classes
-            )
         ## SimOTA assigner
         self.ota_matcher = AlignedSimOTA(
             center_sampling_radius=matcher_config['ota']['center_sampling_radius'],
@@ -37,11 +30,6 @@ class Criterion(object):
             num_classes=num_classes
         )
 
-    def __call__(self, outputs, targets, epoch=0):
-        if epoch < self.args.max_epoch // 2:
-            return self.ota_loss(outputs, targets)
-        else:
-            return self.tal_loss(outputs, targets)
 
     def ema_update(self, name: str, value, initial_value, momentum=0.9):
         if hasattr(self, name):
@@ -52,7 +40,7 @@ class Criterion(object):
         setattr(self, name, new)
         return new
 
-    # ----------------- Loss functions -----------------
+
     def loss_classes(self, pred_cls, gt_score, gt_label=None, vfl=False):
         if vfl:
             assert gt_label is not None
@@ -67,6 +55,7 @@ class Criterion(object):
 
         return loss_cls
 
+
     def loss_bboxes(self, pred_box, gt_box, bbox_weight=None):
         # regression loss
         ious = get_ious(pred_box, gt_box, 'xyxy', 'giou')
@@ -77,6 +66,7 @@ class Criterion(object):
 
         return loss_box
 
+
     def loss_dfl(self, pred_reg, gt_box, anchor, stride, bbox_weight=None):
         # rescale coords by stride
         gt_box_s = gt_box / stride
@@ -109,6 +99,7 @@ class Criterion(object):
 
         return loss_dfl
 
+
     def loss_bboxes_aux(self, pred_delta, gt_box, anchors, stride_tensors):
         gt_delta_tl = (anchors - gt_box[..., :2]) / stride_tensors
         gt_delta_rb = (gt_box[..., 2:] - anchors) / stride_tensors
@@ -117,129 +108,8 @@ class Criterion(object):
 
         return loss_box_aux
 
-    # ----------------- Loss with TAL assigner -----------------
-    def tal_loss(self, outputs, targets, epoch=0):
-        """ Compute loss with TAL assigner """
-        bs = outputs['pred_cls'][0].shape[0]
-        device = outputs['pred_cls'][0].device
-        anchors = torch.cat(outputs['anchors'], dim=0)
-        num_anchors = anchors.shape[0]
-        # preds: [B, M, C]
-        cls_preds = torch.cat(outputs['pred_cls'], dim=1)
-        reg_preds = torch.cat(outputs['pred_reg'], dim=1)
-        box_preds = torch.cat(outputs['pred_box'], dim=1)
-
-        # --------------- label assignment ---------------
-        gt_label_targets = []
-        gt_score_targets = []
-        gt_bbox_targets = []
-        fg_masks = []
-        for batch_idx in range(bs):
-            tgt_labels = targets[batch_idx]["labels"].to(device)
-            tgt_bboxes = targets[batch_idx]["boxes"].to(device)
 
-            # check target
-            if len(tgt_labels) == 0 or tgt_bboxes.max().item() == 0.:
-                # There is no valid gt
-                fg_mask = cls_preds.new_zeros(1, num_anchors).bool()               #[1, M,]
-                gt_label = cls_preds.new_zeros((1, num_anchors,))                  #[1, M,]
-                gt_score = cls_preds.new_zeros((1, num_anchors, self.num_classes)) #[1, M, C]
-                gt_box = cls_preds.new_zeros((1, num_anchors, 4))                  #[1, M, 4]
-            else:
-                tgt_labels = tgt_labels[None, :, None]      # [1, Mp, 1]
-                tgt_bboxes = tgt_bboxes[None]                   # [1, Mp, 4]
-                (
-                    gt_label,   #[1, M]
-                    gt_box,     #[1, M, 4]
-                    gt_score,   #[1, M, C]
-                    fg_mask,    #[1, M,]
-                    _
-                ) = self.tal_matcher(
-                    pd_scores = cls_preds[batch_idx:batch_idx+1].detach().sigmoid(), 
-                    pd_bboxes = box_preds[batch_idx:batch_idx+1].detach(),
-                    anc_points = anchors,
-                    gt_labels = tgt_labels,
-                    gt_bboxes = tgt_bboxes
-                    )
-            gt_label_targets.append(gt_label)
-            gt_score_targets.append(gt_score)
-            gt_bbox_targets.append(gt_box)
-            fg_masks.append(fg_mask)
-
-        # List[B, 1, M, C] -> Tensor[B, M, C] -> Tensor[BM, C]
-        fg_masks = torch.cat(fg_masks, 0).view(-1)                                    # [BM,]
-        gt_score_targets = torch.cat(gt_score_targets, 0).view(-1, self.num_classes)  # [BM, C]
-        gt_bbox_targets = torch.cat(gt_bbox_targets, 0).view(-1, 4)                   # [BM, 4]
-        gt_label_targets = torch.cat(gt_label_targets, 0).view(-1)                    # [BM,]
-        gt_label_targets = torch.where(fg_masks > 0, gt_label_targets, torch.full_like(gt_label_targets, self.num_classes))
-        gt_labels_one_hot = F.one_hot(gt_label_targets.long(), self.num_classes + 1)[..., :-1]
-        bbox_weight = gt_score_targets[fg_masks].sum(-1)
-        num_fgs = max(gt_score_targets.sum(), 1)
-
-        # average loss normalizer across all the GPUs
-        if is_dist_avail_and_initialized():
-            torch.distributed.all_reduce(num_fgs)
-        num_fgs = max(num_fgs / get_world_size(), 1.0)
-
-        # update loss normalizer with EMA
-        if self.use_ema_update:
-            normalizer = self.ema_update("loss_normalizer", max(num_fgs, 1), 100)
-        else:
-            normalizer = num_fgs
-
-        # ------------------ Classification loss ------------------
-        cls_preds = cls_preds.view(-1, self.num_classes)
-        loss_cls = self.loss_classes(cls_preds, gt_score_targets, gt_labels_one_hot, vfl=False)
-        loss_cls = loss_cls.sum() / normalizer
-
-        # ------------------ Regression loss ------------------
-        box_preds_pos = box_preds.view(-1, 4)[fg_masks]
-        box_targets_pos = gt_bbox_targets[fg_masks]
-        loss_box = self.loss_bboxes(box_preds_pos, box_targets_pos, bbox_weight)
-        loss_box = loss_box.sum() / normalizer
-
-        # ------------------ Distribution focal loss  ------------------
-        ## process anchors
-        anchors = anchors[None].repeat(bs, 1, 1).view(-1, 2)
-        ## process stride tensors
-        strides = torch.cat(outputs['stride_tensor'], dim=0)
-        strides = strides.unsqueeze(0).repeat(bs, 1, 1).view(-1, 1)
-        ## fg preds
-        reg_preds_pos = reg_preds.view(-1, 4*self.cfg['reg_max'])[fg_masks]
-        anchors_pos = anchors[fg_masks]
-        strides_pos = strides[fg_masks]
-        ## compute dfl
-        loss_dfl = self.loss_dfl(reg_preds_pos, box_targets_pos, anchors_pos, strides_pos, bbox_weight)
-        loss_dfl = loss_dfl.sum() / normalizer
-
-        # total loss
-        losses = self.loss_cls_weight['tal'] * loss_cls + \
-                 self.loss_box_weight['tal'] * loss_box + \
-                 self.loss_dfl_weight['tal'] * loss_dfl
-
-        loss_dict = dict(
-                loss_cls = loss_cls,
-                loss_box = loss_box,
-                loss_dfl = loss_dfl,
-                losses = losses
-        )
-
-        # ------------------ Aux regression loss ------------------
-        if epoch >= (self.max_epoch - self.no_aug_epoch - 1) and self.loss_box_aux:
-            ## delta_preds
-            delta_preds = torch.cat(outputs['pred_delta'], dim=1)
-            delta_preds_pos = delta_preds.view(-1, 4)[fg_masks]
-            ## aux loss
-            loss_box_aux = self.loss_bboxes_aux(delta_preds_pos, box_targets_pos, anchors_pos, strides_pos)
-            loss_box_aux = loss_box_aux.sum() / num_fgs
-
-            losses += loss_box_aux
-            loss_dict['loss_box_aux'] = loss_box_aux
-
-        return loss_dict
-    
-    # ----------------- Loss with SimOTA assigner -----------------
-    def ota_loss(self, outputs, targets, epoch=0):
+    def __call__(self, outputs, targets, epoch=0):
         """ Compute loss with SimOTA assigner """
         bs = outputs['pred_cls'][0].shape[0]
         device = outputs['pred_cls'][0].device
@@ -333,9 +203,9 @@ class Criterion(object):
         loss_dfl = loss_dfl.sum() / normalizer
 
         # total loss
-        losses = self.loss_cls_weight['ota'] * loss_cls + \
-                 self.loss_box_weight['ota'] * loss_box + \
-                 self.loss_dfl_weight['ota'] * loss_dfl
+        losses = self.loss_cls_weight * loss_cls + \
+                 self.loss_box_weight * loss_box + \
+                 self.loss_dfl_weight * loss_dfl
 
         loss_dict = dict(
                 loss_cls = loss_cls,