|
@@ -72,7 +72,7 @@ class Criterion(object):
|
|
|
tgt_labels = tgt_labels[None, :, None] # [1, Mp, 1]
|
|
tgt_labels = tgt_labels[None, :, None] # [1, Mp, 1]
|
|
|
tgt_boxs = tgt_boxs[None] # [1, Mp, 4]
|
|
tgt_boxs = tgt_boxs[None] # [1, Mp, 4]
|
|
|
(
|
|
(
|
|
|
- gt_label,
|
|
|
|
|
|
|
+ gt_label, #[1, M,]
|
|
|
gt_box, #[1, M, 4]
|
|
gt_box, #[1, M, 4]
|
|
|
gt_score, #[1, M, C]
|
|
gt_score, #[1, M, C]
|
|
|
fg_mask, #[1, M,]
|
|
fg_mask, #[1, M,]
|
|
@@ -91,9 +91,10 @@ class Criterion(object):
|
|
|
|
|
|
|
|
# List[B, 1, M, C] -> Tensor[B, M, C] -> Tensor[BM, C]
|
|
# List[B, 1, M, C] -> Tensor[B, M, C] -> Tensor[BM, C]
|
|
|
fg_masks = torch.cat(fg_masks, 0).view(-1) # [BM,]
|
|
fg_masks = torch.cat(fg_masks, 0).view(-1) # [BM,]
|
|
|
- gt_label_targets = torch.cat(gt_label_targets, 0).view(-1,) # [BM, 1]
|
|
|
|
|
|
|
+ gt_label_targets = torch.cat(gt_label_targets, 0).view(-1,) # [BM,]
|
|
|
gt_score_targets = torch.cat(gt_score_targets, 0).view(-1, self.num_classes) # [BM, C]
|
|
gt_score_targets = torch.cat(gt_score_targets, 0).view(-1, self.num_classes) # [BM, C]
|
|
|
gt_bbox_targets = torch.cat(gt_bbox_targets, 0).view(-1, 4) # [BM, 4]
|
|
gt_bbox_targets = torch.cat(gt_bbox_targets, 0).view(-1, 4) # [BM, 4]
|
|
|
|
|
+ num_fgs = max(gt_score_targets.sum(), 1)
|
|
|
|
|
|
|
|
# cls loss
|
|
# cls loss
|
|
|
cls_preds = cls_preds.view(-1, self.num_classes)
|
|
cls_preds = cls_preds.view(-1, self.num_classes)
|
|
@@ -105,9 +106,8 @@ class Criterion(object):
|
|
|
loss_box = self.reg_lossf(box_preds, gt_bbox_targets, bbox_weight, fg_masks)
|
|
loss_box = self.reg_lossf(box_preds, gt_bbox_targets, bbox_weight, fg_masks)
|
|
|
|
|
|
|
|
# normalize loss
|
|
# normalize loss
|
|
|
- gt_score_targets_sum = max(gt_score_targets.sum(), 1)
|
|
|
|
|
- loss_cls = loss_cls.sum() / gt_score_targets_sum
|
|
|
|
|
- loss_box = loss_box.sum() / gt_score_targets_sum
|
|
|
|
|
|
|
+ loss_cls = loss_cls.sum() / num_fgs
|
|
|
|
|
+ loss_box = loss_box.sum() / num_fgs
|
|
|
|
|
|
|
|
# total loss
|
|
# total loss
|
|
|
losses = loss_cls * self.loss_cls_weight + \
|
|
losses = loss_cls * self.loss_cls_weight + \
|