yjh0410 пре 2 година
родитељ
комит
8710feb48b
1 измењених фајлова са 6 додато и 7 уклоњено
  1. 6 7
      models/yolov2/loss.py

+ 6 - 7
models/yolov2/loss.py

@@ -69,20 +69,19 @@ class Criterion(object):
             torch.distributed.all_reduce(num_fgs)
         num_fgs = (num_fgs / get_world_size()).clamp(1.0)
 
-        # cls loss
-        pred_cls_pos = pred_cls[pos_masks]
-        gt_classes_pos = gt_classes[pos_masks]
-        loss_cls = self.loss_classes(pred_cls_pos, gt_classes_pos)
-        loss_cls = loss_cls.sum() / num_fgs
-
         # box loss
         pred_box_pos = pred_box[pos_masks]
         gt_bboxes_pos = gt_bboxes[pos_masks]
         loss_box, ious = self.loss_bboxes(pred_box_pos, gt_bboxes_pos)
         loss_box = loss_box.sum() / num_fgs
         
+        # cls loss
+        pred_cls_pos = pred_cls[pos_masks]
+        gt_classes_pos = gt_classes[pos_masks] * ious.clamp(0.)
+        loss_cls = self.loss_classes(pred_cls_pos, gt_classes_pos)
+        loss_cls = loss_cls.sum() / num_fgs
+
         # obj loss
-        gt_objectness[pos_masks] *= ious.clamp(0.)
         loss_obj = self.loss_objectness(pred_obj, gt_objectness)
         loss_obj = loss_obj.sum() / num_fgs