yjh0410 2 年之前
父節點
當前提交
e016adec23
共有 1 個文件被更改,包括 2 次插入8 次删除
  1. 2 8
      models/detectors/rtdetr/rtdetr.py

+ 2 - 8
models/detectors/rtdetr/rtdetr.py

@@ -88,15 +88,9 @@ class RTDETR(nn.Module):
         topk_bboxes = box_pred[topk_box_idxs]
 
         if self.deploy:
-            # [n_anchors_all, 4 + C]
-            outputs = torch.cat([topk_bboxes, topk_scores.unsqueeze(-1)], dim=-1)
-            return outputs
-        else:
-            topk_bboxes = topk_bboxes.cpu().numpy()
-            topk_scores = topk_scores.cpu().numpy()
-            topk_labels = topk_labels.cpu().numpy()
-
             return topk_bboxes, topk_scores, topk_labels
+        else:
+            return topk_bboxes.cpu().numpy(), topk_scores.cpu().numpy(), topk_labels.cpu().numpy()
         
 
     # ---------------------- Main Process for Training ----------------------