|
@@ -119,7 +119,7 @@ class SetCriterion(nn.Module):
|
|
|
loss_box = 1.0 - ious
|
|
loss_box = 1.0 - ious
|
|
|
|
|
|
|
|
if box_weight is not None:
|
|
if box_weight is not None:
|
|
|
- loss_box = loss_box.unsqueeze(1) * box_weight
|
|
|
|
|
|
|
+ loss_box = loss_box.squeeze(-1) * box_weight
|
|
|
|
|
|
|
|
return loss_box.sum() / num_boxes
|
|
return loss_box.sum() / num_boxes
|
|
|
|
|
|