|
|
@@ -41,10 +41,7 @@ class Criterion(object):
|
|
|
|
|
|
def loss_bboxes(self, pred_box, gt_box):
|
|
|
# regression loss
|
|
|
- ious = get_ious(pred_box,
|
|
|
- gt_box,
|
|
|
- box_mode="xyxy",
|
|
|
- iou_type='giou')
|
|
|
+ ious = get_ious(pred_box, gt_box, "xyxy", 'giou')
|
|
|
loss_box = 1.0 - ious
|
|
|
|
|
|
return loss_box
|
|
|
@@ -123,16 +120,16 @@ class Criterion(object):
|
|
|
torch.distributed.all_reduce(num_fgs)
|
|
|
num_fgs = (num_fgs / get_world_size()).clamp(1.0)
|
|
|
|
|
|
- # obj loss
|
|
|
+ # ------------------ objecntness loss ------------------
|
|
|
loss_obj = self.loss_objectness(obj_preds.view(-1, 1), obj_targets.float())
|
|
|
loss_obj = loss_obj.sum() / num_fgs
|
|
|
|
|
|
- # cls loss
|
|
|
+ # ------------------ classification loss ------------------
|
|
|
cls_preds_pos = cls_preds.view(-1, self.num_classes)[fg_masks]
|
|
|
loss_cls = self.loss_classes(cls_preds_pos, cls_targets)
|
|
|
loss_cls = loss_cls.sum() / num_fgs
|
|
|
|
|
|
- # regression loss
|
|
|
+ # ------------------ regression loss ------------------
|
|
|
box_preds_pos = box_preds.view(-1, 4)[fg_masks]
|
|
|
loss_box = self.loss_bboxes(box_preds_pos, box_targets)
|
|
|
loss_box = loss_box.sum() / num_fgs
|