Ver código fonte

fix a bug in num_gts

yjh0410 1 ano atrás
pai
commit
a89d5c1cc7
1 arquivos alterados com 2 adições e 2 exclusões
  1. 2 2
      models/detectors/rtdetr/loss.py

+ 2 - 2
models/detectors/rtdetr/loss.py

@@ -419,6 +419,6 @@ class DINOLoss(DETRLoss):
                 assert len(dn_positive_idx[i]) == len(gt_idx)
                 dn_match_indices.append((dn_positive_idx[i], gt_idx))
             else:
-                dn_match_indices.append((torch.zeros([0], dtype="int64"),
-                                         torch.zeros([0], dtype="int64")))
+                dn_match_indices.append((torch.zeros([0], device=labels[i].device, dtype="int64"),
+                                         torch.zeros([0], device=labels[i].device, dtype="int64")))
         return dn_match_indices