yjh0410 1 år sedan
förälder
incheckning
0edc875536
1 ändrade filer med 6 tillägg och 2 borttagningar
  1. 6 2
      odlab/models/detectors/fcos/criterion.py

+ 6 - 2
odlab/models/detectors/fcos/criterion.py

@@ -114,10 +114,13 @@ class SetCriterion(nn.Module):
 
         return loss_box.sum() / num_boxes
 
-    def loss_bboxes_xyxy(self, pred_box, gt_box, num_boxes=1.0):
+    def loss_bboxes_xyxy(self, pred_box, gt_box, num_boxes=1.0, box_weight=None):
         ious = get_ious(pred_box, gt_box, box_mode="xyxy", iou_type='giou')
         loss_box = 1.0 - ious
 
+        if box_weight is not None:
+            loss_box = loss_box.unsqueeze(1) * box_weight
+
         return loss_box.sum() / num_boxes
     
     def fcos_loss(self, outputs, targets):
@@ -248,7 +251,8 @@ class SetCriterion(nn.Module):
         # -------------------- regression loss --------------------
         box_preds_pos = box_preds.view(-1, 4)[foreground_idxs]
         box_targets_pos = box_targets[foreground_idxs]
-        loss_bboxes = self.loss_bboxes_xyxy(box_preds_pos, box_targets_pos, num_fgs)
+        box_weight = assign_metrics[foreground_idxs]
+        loss_bboxes = self.loss_bboxes_xyxy(box_preds_pos, box_targets_pos, num_fgs, box_weight)
 
         loss_dict = dict(
                 loss_cls = loss_labels,