|
|
@@ -88,15 +88,9 @@ class RTDETR(nn.Module):
|
|
|
topk_bboxes = box_pred[topk_box_idxs]
|
|
|
|
|
|
if self.deploy:
|
|
|
- # [n_anchors_all, 4 + C]
|
|
|
- outputs = torch.cat([topk_bboxes, topk_scores.unsqueeze(-1)], dim=-1)
|
|
|
- return outputs
|
|
|
- else:
|
|
|
- topk_bboxes = topk_bboxes.cpu().numpy()
|
|
|
- topk_scores = topk_scores.cpu().numpy()
|
|
|
- topk_labels = topk_labels.cpu().numpy()
|
|
|
-
|
|
|
return topk_bboxes, topk_scores, topk_labels
|
|
|
+ else:
|
|
|
+ return topk_bboxes.cpu().numpy(), topk_scores.cpu().numpy(), topk_labels.cpu().numpy()
|
|
|
|
|
|
|
|
|
# ---------------------- Main Process for Training ----------------------
|