|
|
@@ -65,9 +65,10 @@ class Criterion(object):
|
|
|
# check target
|
|
|
if len(tgt_labels) == 0 or tgt_boxs.max().item() == 0.:
|
|
|
# There is no valid gt
|
|
|
- fg_mask = cls_preds.new_zeros(1, num_anchors).bool() #[1, M,]
|
|
|
+ gt_label = cls_preds.new_full((1, num_anchors), self.num_classes), #[1, M,]
|
|
|
gt_score = cls_preds.new_zeros((1, num_anchors, self.num_classes)) #[1, M, C]
|
|
|
gt_box = cls_preds.new_zeros((1, num_anchors, 4)) #[1, M, 4]
|
|
|
+ fg_mask = cls_preds.new_zeros(1, num_anchors).bool() #[1, M,]
|
|
|
else:
|
|
|
tgt_labels = tgt_labels[None, :, None] # [1, Mp, 1]
|
|
|
tgt_boxs = tgt_boxs[None] # [1, Mp, 4]
|