@@ -119,7 +119,7 @@ class SetCriterion(nn.Module):
loss_box = 1.0 - ious
if box_weight is not None:
- loss_box = loss_box.unsqueeze(1) * box_weight
+ loss_box = loss_box.squeeze(-1) * box_weight
return loss_box.sum() / num_boxes