yjh0410 пре 1 година
родитељ
комит
a89d5c1cc7
1 измењених фајлова са 2 додато и 2 уклоњено
  1. 2 2
      models/detectors/rtdetr/loss.py

+ 2 - 2
models/detectors/rtdetr/loss.py

@@ -419,6 +419,6 @@ class DINOLoss(DETRLoss):
                 assert len(dn_positive_idx[i]) == len(gt_idx)
                 dn_match_indices.append((dn_positive_idx[i], gt_idx))
             else:
-                dn_match_indices.append((torch.zeros([0], dtype="int64"),
-                                         torch.zeros([0], dtype="int64")))
+                dn_match_indices.append((torch.zeros([0], device=labels[i].device, dtype="int64"),
+                                         torch.zeros([0], device=labels[i].device, dtype="int64")))
         return dn_match_indices