yjh0410 2 年之前
父節點
當前提交
fca944e333
共有 2 個文件被更改,包括 4 次插入9 次删除
  1. 1 0
      config/model_config/rtcdet_v2_config.py
  2. 3 9
      models/detectors/rtcdet_v2/loss.py

+ 1 - 0
config/model_config/rtcdet_v2_config.py

@@ -53,6 +53,7 @@ rtcdet_v2_cfg = {
         # ---------------- Loss config ----------------
         ## Loss weight
         'ema_update': False,
+        'loss_box_aux': False,
         'loss_cls_weight': {'tal': 0.5, 'ota': 1.0},
         'loss_box_weight': {'tal': 7.0, 'ota': 5.0},
         'loss_dfl_weight': {'tal': 1.5, 'ota': 1.0},

+ 3 - 9
models/detectors/rtcdet_v2/loss.py

@@ -20,6 +20,7 @@ class Criterion(object):
         self.loss_cls_weight = cfg['loss_cls_weight']
         self.loss_box_weight = cfg['loss_box_weight']
         self.loss_dfl_weight = cfg['loss_dfl_weight']
+        self.loss_box_aux    = cfg['loss_box_aux']
         # ---------------- Matcher ----------------
         matcher_config = cfg['matcher']
         ## TAL assigner
@@ -116,7 +117,6 @@ class Criterion(object):
 
         return loss_box_aux
 
-
     # ----------------- Loss with TAL assigner -----------------
     def tal_loss(self, outputs, targets, epoch=0):
         """ Compute loss with TAL assigner """
@@ -225,16 +225,10 @@ class Criterion(object):
         )
 
         # ------------------ Aux regression loss ------------------
-        if epoch >= (self.max_epoch - self.no_aug_epoch - 1):
+        if epoch >= (self.max_epoch - self.no_aug_epoch - 1) and self.loss_box_aux:
             ## delta_preds
             delta_preds = torch.cat(outputs['pred_delta'], dim=1)
             delta_preds_pos = delta_preds.view(-1, 4)[fg_masks]
-            ## anchor tensors
-            anchors_tensors = torch.cat(outputs['anchors'], dim=0)[None].repeat(bs, 1, 1)
-            anchors_tensors_pos = anchors_tensors.view(-1, 2)[fg_masks]
-            ## stride tensors
-            stride_tensors = torch.cat(outputs['stride_tensors'], dim=0)[None].repeat(bs, 1, 1)
-            stride_tensors_pos = stride_tensors.view(-1, 1)[fg_masks]
             ## aux loss
             loss_box_aux = self.loss_bboxes_aux(delta_preds_pos, box_targets_pos, anchors_pos, strides_pos)
             loss_box_aux = loss_box_aux.sum() / num_fgs
@@ -351,7 +345,7 @@ class Criterion(object):
         )
 
         # ------------------ Aux regression loss ------------------
-        if epoch >= (self.max_epoch - self.no_aug_epoch - 1):
+        if epoch >= (self.max_epoch - self.no_aug_epoch - 1) and self.loss_box_aux:
             ## delta_preds
             delta_preds = torch.cat(outputs['pred_delta'], dim=1)
             delta_preds_pos = delta_preds.view(-1, 4)[fg_masks]