Parcourir la source

fix a bug in num_gts

yjh0410 il y a 1 an
Parent
commit
c37e4319bc
1 fichiers modifiés avec 1 ajouts et 1 suppressions
  1. 1 1
      models/detectors/rtdetr/loss.py

+ 1 - 1
models/detectors/rtdetr/loss.py

@@ -277,7 +277,7 @@ class DETRLoss(nn.Module):
 
     def _get_num_gts(self, targets):
         num_gts = sum(len(a) for a in targets)
-        num_gts = torch.as_tensor([num_gts]).float()
+        num_gts = torch.as_tensor([num_gts], device=targets[0].device).float()
 
         if is_dist_avail_and_initialized():
             torch.distributed.all_reduce(num_gts)