浏览代码

fix a bug in num_gts

yjh0410 1 年之前
父节点
当前提交
a89d5c1cc7
共有 1 个文件被更改,包括 2 次插入2 次删除
  1. 2 2
      models/detectors/rtdetr/loss.py

+ 2 - 2
models/detectors/rtdetr/loss.py

@@ -419,6 +419,6 @@ class DINOLoss(DETRLoss):
                 assert len(dn_positive_idx[i]) == len(gt_idx)
                 dn_match_indices.append((dn_positive_idx[i], gt_idx))
             else:
-                dn_match_indices.append((torch.zeros([0], dtype="int64"),
-                                         torch.zeros([0], dtype="int64")))
+                dn_match_indices.append((torch.zeros([0], device=labels[i].device, dtype="int64"),
+                                         torch.zeros([0], device=labels[i].device, dtype="int64")))
         return dn_match_indices