|
|
@@ -69,20 +69,19 @@ class Criterion(object):
|
|
|
torch.distributed.all_reduce(num_fgs)
|
|
|
num_fgs = (num_fgs / get_world_size()).clamp(1.0)
|
|
|
|
|
|
- # cls loss
|
|
|
- pred_cls_pos = pred_cls[pos_masks]
|
|
|
- gt_classes_pos = gt_classes[pos_masks]
|
|
|
- loss_cls = self.loss_classes(pred_cls_pos, gt_classes_pos)
|
|
|
- loss_cls = loss_cls.sum() / num_fgs
|
|
|
-
|
|
|
# box loss
|
|
|
pred_box_pos = pred_box[pos_masks]
|
|
|
gt_bboxes_pos = gt_bboxes[pos_masks]
|
|
|
loss_box, ious = self.loss_bboxes(pred_box_pos, gt_bboxes_pos)
|
|
|
loss_box = loss_box.sum() / num_fgs
|
|
|
|
|
|
+ # cls loss
|
|
|
+ pred_cls_pos = pred_cls[pos_masks]
|
|
|
+ gt_classes_pos = gt_classes[pos_masks] * ious.clamp(0.)
|
|
|
+ loss_cls = self.loss_classes(pred_cls_pos, gt_classes_pos)
|
|
|
+ loss_cls = loss_cls.sum() / num_fgs
|
|
|
+
|
|
|
# obj loss
|
|
|
- gt_objectness[pos_masks] *= ious.clamp(0.)
|
|
|
loss_obj = self.loss_objectness(pred_obj, gt_objectness)
|
|
|
loss_obj = loss_obj.sum() / num_fgs
|
|
|
|