|
@@ -23,16 +23,15 @@ class SetCriterion(object):
|
|
|
beta = cfg.tal_beta
|
|
beta = cfg.tal_beta
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- def loss_classes(self, pred_logits, gt_score, gt_label, fg_mask):
|
|
|
|
|
- gt_label = torch.where(fg_mask > 0, gt_label, torch.full_like(gt_label, self.num_classes))
|
|
|
|
|
- one_hot_label = F.one_hot(gt_label.long(), self.num_classes + 1)[..., :-1]
|
|
|
|
|
-
|
|
|
|
|
- pred_score = pred_logits.sigmoid()
|
|
|
|
|
- weight = 0.75 * pred_score.pow(2.0) * (1 - one_hot_label) + gt_score * one_hot_label
|
|
|
|
|
- with torch.cuda.amp.autocast(enabled=False):
|
|
|
|
|
- loss_cls = F.binary_cross_entropy_with_logits(pred_logits.float(), gt_score.float(), reduction='none')
|
|
|
|
|
- loss_cls = loss_cls * weight
|
|
|
|
|
- loss_cls = loss_cls.sum()
|
|
|
|
|
|
|
+ def loss_classes(self, pred_logits, gt_score):
|
|
|
|
|
+ alpha, gamma = 0.75, 2.0
|
|
|
|
|
+ pred_sigmoid = pred_logits.sigmoid()
|
|
|
|
|
+ focal_weight = gt_score * (gt_score > 0.0).float() + \
|
|
|
|
|
+ alpha * (pred_sigmoid - gt_score).abs().pow(gamma) * \
|
|
|
|
|
+ (gt_score <= 0.0).float()
|
|
|
|
|
+
|
|
|
|
|
+ loss_cls = F.binary_cross_entropy_with_logits(
|
|
|
|
|
+ pred_logits, gt_score, reduction='none') * focal_weight
|
|
|
|
|
|
|
|
return loss_cls
|
|
return loss_cls
|
|
|
|
|
|
|
@@ -63,7 +62,6 @@ class SetCriterion(object):
|
|
|
anchors = torch.cat(outputs['anchors'], dim=0)
|
|
anchors = torch.cat(outputs['anchors'], dim=0)
|
|
|
|
|
|
|
|
# --------------- label assignment ---------------
|
|
# --------------- label assignment ---------------
|
|
|
- gt_label_targets = []
|
|
|
|
|
gt_score_targets = []
|
|
gt_score_targets = []
|
|
|
gt_bbox_targets = []
|
|
gt_bbox_targets = []
|
|
|
fg_masks = []
|
|
fg_masks = []
|
|
@@ -92,7 +90,7 @@ class SetCriterion(object):
|
|
|
tgt_labels = tgt_labels[None, :, None] # [1, Mp, 1]
|
|
tgt_labels = tgt_labels[None, :, None] # [1, Mp, 1]
|
|
|
tgt_boxs = tgt_boxs[None] # [1, Mp, 4]
|
|
tgt_boxs = tgt_boxs[None] # [1, Mp, 4]
|
|
|
(
|
|
(
|
|
|
- gt_label, # [1, M]
|
|
|
|
|
|
|
+ _, # [1, M]
|
|
|
gt_box, # [1, M, 4]
|
|
gt_box, # [1, M, 4]
|
|
|
gt_score, # [1, M, C]
|
|
gt_score, # [1, M, C]
|
|
|
fg_mask, # [1, M,]
|
|
fg_mask, # [1, M,]
|
|
@@ -104,14 +102,12 @@ class SetCriterion(object):
|
|
|
gt_labels = tgt_labels,
|
|
gt_labels = tgt_labels,
|
|
|
gt_bboxes = tgt_boxs
|
|
gt_bboxes = tgt_boxs
|
|
|
)
|
|
)
|
|
|
- gt_label_targets.append(gt_label)
|
|
|
|
|
gt_score_targets.append(gt_score)
|
|
gt_score_targets.append(gt_score)
|
|
|
gt_bbox_targets.append(gt_box)
|
|
gt_bbox_targets.append(gt_box)
|
|
|
fg_masks.append(fg_mask)
|
|
fg_masks.append(fg_mask)
|
|
|
|
|
|
|
|
# List[B, 1, M, C] -> Tensor[B, M, C] -> Tensor[BM, C]
|
|
# List[B, 1, M, C] -> Tensor[B, M, C] -> Tensor[BM, C]
|
|
|
fg_masks = torch.cat(fg_masks, 0).view(-1) # [BM,]
|
|
fg_masks = torch.cat(fg_masks, 0).view(-1) # [BM,]
|
|
|
- gt_label_targets = torch.cat(gt_label_targets, 0).view(-1) # [BM,]
|
|
|
|
|
gt_score_targets = torch.cat(gt_score_targets, 0).view(-1, self.num_classes) # [BM, C]
|
|
gt_score_targets = torch.cat(gt_score_targets, 0).view(-1, self.num_classes) # [BM, C]
|
|
|
gt_bbox_targets = torch.cat(gt_bbox_targets, 0).view(-1, 4) # [BM, 4]
|
|
gt_bbox_targets = torch.cat(gt_bbox_targets, 0).view(-1, 4) # [BM, 4]
|
|
|
num_fgs = gt_score_targets.sum()
|
|
num_fgs = gt_score_targets.sum()
|
|
@@ -123,7 +119,7 @@ class SetCriterion(object):
|
|
|
|
|
|
|
|
# ------------------ Classification loss ------------------
|
|
# ------------------ Classification loss ------------------
|
|
|
cls_preds = cls_preds.view(-1, self.num_classes)
|
|
cls_preds = cls_preds.view(-1, self.num_classes)
|
|
|
- loss_cls = self.loss_classes(cls_preds, gt_score_targets, gt_label_targets, fg_masks)
|
|
|
|
|
|
|
+ loss_cls = self.loss_classes(cls_preds, gt_score_targets)
|
|
|
loss_cls = loss_cls.sum() / num_fgs
|
|
loss_cls = loss_cls.sum() / num_fgs
|
|
|
|
|
|
|
|
# ------------------ Regression loss ------------------
|
|
# ------------------ Regression loss ------------------
|