|
@@ -65,7 +65,7 @@ class Criterion(object):
|
|
|
# check target
|
|
# check target
|
|
|
if len(tgt_labels) == 0 or tgt_boxs.max().item() == 0.:
|
|
if len(tgt_labels) == 0 or tgt_boxs.max().item() == 0.:
|
|
|
# There is no valid gt
|
|
# There is no valid gt
|
|
|
- gt_label = cls_preds.new_full((1, num_anchors), self.num_classes) #[1, M,]
|
|
|
|
|
|
|
+ gt_label = cls_preds.new_full((1, num_anchors), self.num_classes).long() #[1, M,]
|
|
|
gt_score = cls_preds.new_zeros((1, num_anchors, self.num_classes)) #[1, M, C]
|
|
gt_score = cls_preds.new_zeros((1, num_anchors, self.num_classes)) #[1, M, C]
|
|
|
gt_box = cls_preds.new_zeros((1, num_anchors, 4)) #[1, M, 4]
|
|
gt_box = cls_preds.new_zeros((1, num_anchors, 4)) #[1, M, 4]
|
|
|
fg_mask = cls_preds.new_zeros(1, num_anchors).bool() #[1, M,]
|
|
fg_mask = cls_preds.new_zeros(1, num_anchors).bool() #[1, M,]
|