浏览代码

fix a bug in num_gts

yjh0410 1 年之前
父节点
当前提交
7211aab24e
共有 3 个文件被更改,包括 5 次插入6 次删除
  1. 1 2
      config/model_config/rtdetr_config.py
  2. 1 1
      models/detectors/rtdetr/loss.py
  3. 3 3
      models/detectors/rtdetr/loss_utils.py

+ 1 - 2
config/model_config/rtdetr_config.py

@@ -52,8 +52,7 @@ rtdetr_cfg = {
         'use_vfl': True,
         'loss_coeff': {'class': 1,
                        'bbox': 5,
-                       'giou': 2,
-                       'no_object': 0.1,},
+                       'giou': 2,},
         # ---------------- Train config ----------------
         ## input
         'multi_scale': [0.5, 1.25],   # 320 -> 800

+ 1 - 1
models/detectors/rtdetr/loss.py

@@ -169,7 +169,7 @@ class DETRLoss(nn.Module):
             loss_cls = varifocal_loss_with_logits(logits,
                                                   target_score,
                                                   target_label,
-                                                  num_gts)
+                                                  num_gts / num_query_objects)
         else:
             loss_cls = sigmoid_focal_loss(logits,
                                           target_label,

+ 3 - 3
models/detectors/rtdetr/loss_utils.py

@@ -32,7 +32,7 @@ def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: f
         alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
         loss = alpha_t * loss
 
-    return loss.sum() / num_boxes
+    return loss.mean(1).sum() / num_boxes
 
 ## Variable FocalLoss
 def varifocal_loss_with_logits(pred_logits,
@@ -44,9 +44,9 @@ def varifocal_loss_with_logits(pred_logits,
     pred_score = F.sigmoid(pred_logits)
     weight = alpha * pred_score.pow(gamma) * (1 - label) + gt_score * label
     loss = F.binary_cross_entropy_with_logits(pred_logits, gt_score, reduction='none')
-    loss *= weight
+    loss = loss * weight
 
-    return loss.sum() / normalizer
+    return loss.mean(1).sum() / normalizer
 
 ## InverseSigmoid
 def inverse_sigmoid(x, eps=1e-5):