소스 검색

debug loss of YOLOv2

yjh0410 2 년 전
부모
커밋
074eb26310
1개의 변경된 파일1개의 추가작업 그리고 1개의 파일을 삭제
  1. 1 1
      models/yolov2/loss.py

+ 1 - 1
models/yolov2/loss.py

@@ -77,7 +77,7 @@ class Criterion(object):
         
         # cls loss
         pred_cls_pos = pred_cls[pos_masks]
-        gt_classes_pos = gt_classes[pos_masks] * ious.clamp(0.)
+        gt_classes_pos = gt_classes[pos_masks] * ious.unsqueeze(-1).clamp(0.)
         loss_cls = self.loss_classes(pred_cls_pos, gt_classes_pos)
         loss_cls = loss_cls.sum() / num_fgs