|
|
@@ -277,7 +277,7 @@ class DETRLoss(nn.Module):
|
|
|
|
|
|
def _get_num_gts(self, targets):
|
|
|
num_gts = sum(len(a) for a in targets)
|
|
|
- num_gts = torch.as_tensor([num_gts]).float()
|
|
|
+ num_gts = torch.as_tensor([num_gts], device=targets[0].device).float()
|
|
|
|
|
|
if is_dist_avail_and_initialized():
|
|
|
torch.distributed.all_reduce(num_gts)
|