yjh0410 2 năm trước cách đây
mục cha
commit
e016adec23
1 tập tin đã thay đổi với 2 bổ sung8 xóa
  1. 2 8
      models/detectors/rtdetr/rtdetr.py

+ 2 - 8
models/detectors/rtdetr/rtdetr.py

@@ -88,15 +88,9 @@ class RTDETR(nn.Module):
         topk_bboxes = box_pred[topk_box_idxs]
 
         if self.deploy:
-            # [n_anchors_all, 4 + C]
-            outputs = torch.cat([topk_bboxes, topk_scores.unsqueeze(-1)], dim=-1)
-            return outputs
-        else:
-            topk_bboxes = topk_bboxes.cpu().numpy()
-            topk_scores = topk_scores.cpu().numpy()
-            topk_labels = topk_labels.cpu().numpy()
-
             return topk_bboxes, topk_scores, topk_labels
+        else:
+            return topk_bboxes.cpu().numpy(), topk_scores.cpu().numpy(), topk_labels.cpu().numpy()
         
 
     # ---------------------- Main Process for Training ----------------------