浏览代码

fix a bug in num_gts

yjh0410 1 年之前
父节点
当前提交
137632a0da
共有 1 个文件被更改,包括 2 次插入2 次删除
  1. 2 2
      models/detectors/rtdetr/loss.py

+ 2 - 2
models/detectors/rtdetr/loss.py

@@ -419,6 +419,6 @@ class DINOLoss(DETRLoss):
                 assert len(dn_positive_idx[i]) == len(gt_idx)
                 assert len(dn_positive_idx[i]) == len(gt_idx)
                 dn_match_indices.append((dn_positive_idx[i], gt_idx))
                 dn_match_indices.append((dn_positive_idx[i], gt_idx))
             else:
             else:
-                dn_match_indices.append((torch.zeros([0], device=labels[i].device, dtype="int64"),
-                                         torch.zeros([0], device=labels[i].device, dtype="int64")))
+                dn_match_indices.append((torch.zeros([0], device=labels[i].device).long(),
+                                         torch.zeros([0], device=labels[i].device).long()))
         return dn_match_indices
         return dn_match_indices