Pārlūkot izejas kodu

train RT-DETR-R18 on COCO

yjh0410 1 gadu atpakaļ
vecāks
revīzija
49cc168eeb
1 mainītis faili ar 4 papildinājumiem un 2 dzēšanām
  1. 4 2
      models/detectors/rtdetr/rtdetr.py

+ 4 - 2
models/detectors/rtdetr/rtdetr.py

@@ -41,6 +41,8 @@ class RT_DETR(nn.Module):
         box_preds_x2y2 = box_pred[..., :2] + 0.5 * box_pred[..., 2:]
         box_pred = torch.cat([box_preds_x1y1, box_preds_x2y2], dim=-1)
         
+        cls_pred = cls_pred[0]
+        box_pred = box_pred[0]
         if self.no_multi_labels:
             # [M,]
             scores, labels = torch.max(cls_pred.sigmoid(), dim=1)
@@ -65,8 +67,8 @@ class RT_DETR(nn.Module):
             return topk_bboxes, topk_scores, topk_labels
         else:
             # Top-k select
-            cls_pred = cls_pred[0].flatten().sigmoid_()
-            box_pred = box_pred[0]
+            cls_pred = cls_pred.flatten().sigmoid_()
+            box_pred = box_pred
 
             # Keep top k top scoring indices only.
             num_topk = min(self.num_topk, box_pred.size(0))