|
@@ -136,7 +136,6 @@ class ClassificationLoss(nn.Module):
|
|
|
gt_score: (torch.Tensor): [N, C]
|
|
gt_score: (torch.Tensor): [N, C]
|
|
|
"""
|
|
"""
|
|
|
gt_label = gt_label.long()
|
|
gt_label = gt_label.long()
|
|
|
- gt_score = gt_score[:]
|
|
|
|
|
gt_score = gt_score[torch.arange(gt_label.shape[0]), gt_label]
|
|
gt_score = gt_score[torch.arange(gt_label.shape[0]), gt_label]
|
|
|
|
|
|
|
|
pred_sigmoid = pred_cls.sigmoid()
|
|
pred_sigmoid = pred_cls.sigmoid()
|