|
|
@@ -44,14 +44,13 @@ class AlignedSimOTA(object):
|
|
|
with torch.cuda.amp.autocast(enabled=False):
|
|
|
# [Mp, C] -> [N, Mp, C]
|
|
|
cls_preds_expand = cls_preds.unsqueeze(0).repeat(num_gt, 1, 1)
|
|
|
- score_preds = torch.sigmoid(cls_preds_expand)
|
|
|
# prepare cls_target
|
|
|
cls_targets = F.one_hot(tgt_labels.long(), self.num_classes).float()
|
|
|
- cls_targets = cls_targets.unsqueeze(1).repeat(1, score_preds.size(1), 1)
|
|
|
+ cls_targets = cls_targets.unsqueeze(1).repeat(1, cls_preds_expand.size(1), 1)
|
|
|
cls_targets *= pair_wise_ious.unsqueeze(-1) # iou-aware
|
|
|
# [N, Mp]
|
|
|
cls_cost = F.binary_cross_entropy_with_logits(cls_preds_expand, cls_targets, reduction="none").sum(-1)
|
|
|
- del score_preds, cls_preds_expand
|
|
|
+ del cls_preds_expand
|
|
|
|
|
|
#----------------------- Dynamic K-Matching -----------------------
|
|
|
cost_matrix = (
|