浏览代码

fix a bug in num_gts

yjh0410 1 年之前
父节点
当前提交
c37e4319bc
共有 1 个文件被更改,包括 1 次插入1 次删除
  1. 1 1
      models/detectors/rtdetr/loss.py

+ 1 - 1
models/detectors/rtdetr/loss.py

@@ -277,7 +277,7 @@ class DETRLoss(nn.Module):
 
     def _get_num_gts(self, targets):
         num_gts = sum(len(a) for a in targets)
-        num_gts = torch.as_tensor([num_gts]).float()
+        num_gts = torch.as_tensor([num_gts], device=targets[0].device).float()
 
         if is_dist_avail_and_initialized():
             torch.distributed.all_reduce(num_gts)