|
|
@@ -41,6 +41,8 @@ class RT_DETR(nn.Module):
|
|
|
box_preds_x2y2 = box_pred[..., :2] + 0.5 * box_pred[..., 2:]
|
|
|
box_pred = torch.cat([box_preds_x1y1, box_preds_x2y2], dim=-1)
|
|
|
|
|
|
+ cls_pred = cls_pred[0]
|
|
|
+ box_pred = box_pred[0]
|
|
|
if self.no_multi_labels:
|
|
|
# [M,]
|
|
|
scores, labels = torch.max(cls_pred.sigmoid(), dim=1)
|
|
|
@@ -65,8 +67,8 @@ class RT_DETR(nn.Module):
|
|
|
return topk_bboxes, topk_scores, topk_labels
|
|
|
else:
|
|
|
# Top-k select
|
|
|
- cls_pred = cls_pred[0].flatten().sigmoid_()
|
|
|
- box_pred = box_pred[0]
|
|
|
+ cls_pred = cls_pred.flatten().sigmoid_()
|
|
|
+ box_pred = box_pred
|
|
|
|
|
|
# Keep top k top scoring indices only.
|
|
|
num_topk = min(self.num_topk, box_pred.size(0))
|