|
|
@@ -77,7 +77,7 @@ class Criterion(object):
|
|
|
|
|
|
# cls loss
|
|
|
pred_cls_pos = pred_cls[pos_masks]
|
|
|
- gt_classes_pos = gt_classes[pos_masks] * ious.clamp(0.)
|
|
|
+ gt_classes_pos = gt_classes[pos_masks] * ious.unsqueeze(-1).clamp(0.)
|
|
|
loss_cls = self.loss_classes(pred_cls_pos, gt_classes_pos)
|
|
|
loss_cls = loss_cls.sum() / num_fgs
|
|
|
|