|
|
@@ -39,9 +39,6 @@ class SetCriterion(nn.Module):
|
|
|
def loss_labels(self, outputs, targets, indices, num_boxes):
|
|
|
assert 'pred_logits' in outputs
|
|
|
src_logits = outputs['pred_logits']
|
|
|
-
|
|
|
- idx = self._get_src_permutation_idx(indices)
|
|
|
- target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
|
|
|
target_classes = torch.full(src_logits.shape[:2], self.num_classes,
|
|
|
dtype=torch.int64, device=src_logits.device)
|
|
|
|