浏览代码

design a more general Trainer

yjh0410 2 年之前
父节点
当前提交
851d05f075
共有 1 个文件被更改,包括 5 次插入5 次删除
  1. 5 5
      models/detectors/yolox_plus/loss.py

+ 5 - 5
models/detectors/yolox_plus/loss.py

@@ -72,7 +72,7 @@ class Criterion(object):
                 tgt_labels = tgt_labels[None, :, None]      # [1, Mp, 1]
                 tgt_boxs = tgt_boxs[None]                   # [1, Mp, 4]
                 (
-                    gt_label,
+                    gt_label,   #[1, M,]
                     gt_box,     #[1, M, 4]
                     gt_score,   #[1, M, C]
                     fg_mask,    #[1, M,]
@@ -91,9 +91,10 @@ class Criterion(object):
 
         # List[B, 1, M, C] -> Tensor[B, M, C] -> Tensor[BM, C]
         fg_masks = torch.cat(fg_masks, 0).view(-1)                                    # [BM,]
-        gt_label_targets = torch.cat(gt_label_targets, 0).view(-1,)                   # [BM, 1]
+        gt_label_targets = torch.cat(gt_label_targets, 0).view(-1,)                   # [BM,]
         gt_score_targets = torch.cat(gt_score_targets, 0).view(-1, self.num_classes)  # [BM, C]
         gt_bbox_targets = torch.cat(gt_bbox_targets, 0).view(-1, 4)                   # [BM, 4]
+        num_fgs = max(gt_score_targets.sum(), 1)
        
         # cls loss
         cls_preds = cls_preds.view(-1, self.num_classes)
@@ -105,9 +106,8 @@ class Criterion(object):
         loss_box = self.reg_lossf(box_preds, gt_bbox_targets, bbox_weight, fg_masks)
         
         # normalize loss
-        gt_score_targets_sum = max(gt_score_targets.sum(), 1)
-        loss_cls = loss_cls.sum() / gt_score_targets_sum
-        loss_box = loss_box.sum() / gt_score_targets_sum
+        loss_cls = loss_cls.sum() / num_fgs
+        loss_box = loss_box.sum() / num_fgs
 
         # total loss
         losses = loss_cls * self.loss_cls_weight + \