|
|
@@ -68,9 +68,14 @@ class RTDETR(nn.Module):
|
|
|
# -------------------- Decode bbox --------------------
|
|
|
cls_pred = out_logits[0]
|
|
|
box_pred = out_bbox[0]
|
|
|
+ ## cxcywh -> xyxy
|
|
|
x1y1_pred = box_pred[..., :2] - box_pred[..., 2:] * 0.5
|
|
|
x2y2_pred = box_pred[..., :2] + box_pred[..., 2:] * 0.5
|
|
|
box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
|
|
|
+ ## denormalize bbox
|
|
|
+ img_h, img_w = x.shape[-2:]
|
|
|
+ box_pred[..., 0::2] *= img_w
|
|
|
+ box_pred[..., 1::2] *= img_h
|
|
|
|
|
|
# -------------------- Top-k --------------------
|
|
|
cls_pred = cls_pred.flatten().sigmoid_()
|
|
|
@@ -82,7 +87,16 @@ class RTDETR(nn.Module):
|
|
|
topk_labels = topk_idxs % self.num_classes
|
|
|
topk_bboxes = box_pred[topk_box_idxs]
|
|
|
|
|
|
- return topk_bboxes, topk_scores, topk_labels
|
|
|
+ if self.deploy:
|
|
|
+ # [num_topk, 4 + 1]: xyxy box followed by its score (labels are not included)
|
|
|
+ outputs = torch.cat([topk_bboxes, topk_scores.unsqueeze(-1)], dim=-1)
|
|
|
+ return outputs
|
|
|
+ else:
|
|
|
+ topk_bboxes = topk_bboxes.cpu().numpy()
|
|
|
+ topk_scores = topk_scores.cpu().numpy()
|
|
|
+ topk_labels = topk_labels.cpu().numpy()
|
|
|
+
|
|
|
+ return topk_bboxes, topk_scores, topk_labels
|
|
|
|
|
|
|
|
|
# ---------------------- Main Process for Training ----------------------
|