|
|
@@ -15,7 +15,6 @@ class Criterion(object):
|
|
|
self.device = device
|
|
|
self.num_classes = num_classes
|
|
|
self.reg_max = cfg['reg_max']
|
|
|
- self.use_dfl = cfg['reg_max'] > 1
|
|
|
# --------------- Loss config ---------------
|
|
|
self.loss_cls_weight = cfg['loss_cls_weight']
|
|
|
self.loss_box_weight = cfg['loss_box_weight']
|
|
|
@@ -170,125 +169,17 @@ class Criterion(object):
|
|
|
loss_dfl = loss_dfl.sum() / num_fgs
|
|
|
|
|
|
# total loss
|
|
|
- if not self.use_dfl:
|
|
|
- losses = loss_cls * self.loss_cls_weight + loss_box * self.loss_box_weight
|
|
|
- loss_dict = dict(
|
|
|
- loss_cls = loss_cls,
|
|
|
- loss_box = loss_box,
|
|
|
- losses = losses
|
|
|
- )
|
|
|
- else:
|
|
|
- losses = loss_cls * self.loss_cls_weight + loss_box * self.loss_box_weight + loss_dfl * self.loss_dfl_weight
|
|
|
- loss_dict = dict(
|
|
|
- loss_cls = loss_cls,
|
|
|
- loss_box = loss_box,
|
|
|
- loss_dfl = loss_dfl,
|
|
|
- losses = losses
|
|
|
- )
|
|
|
+ losses = loss_cls * self.loss_cls_weight + loss_box * self.loss_box_weight + loss_dfl * self.loss_dfl_weight
|
|
|
+ loss_dict = dict(
|
|
|
+ loss_cls = loss_cls,
|
|
|
+ loss_box = loss_box,
|
|
|
+ loss_dfl = loss_dfl,
|
|
|
+ losses = losses
|
|
|
+ )
|
|
|
|
|
|
return loss_dict
|
|
|
|
|
|
|
|
|
-class ClassificationLoss(nn.Module):
|
|
|
- def __init__(self, cfg, reduction='none'):
|
|
|
- super(ClassificationLoss, self).__init__()
|
|
|
- self.cfg = cfg
|
|
|
- self.reduction = reduction
|
|
|
- # For VFL
|
|
|
- self.alpha = 0.75
|
|
|
- self.gamma = 2.0
|
|
|
-
|
|
|
-
|
|
|
- def binary_cross_entropy(self, pred_logits, gt_score):
|
|
|
- loss = F.binary_cross_entropy_with_logits(
|
|
|
- pred_logits.float(), gt_score.float(), reduction='none')
|
|
|
-
|
|
|
- if self.reduction == 'sum':
|
|
|
- loss = loss.sum()
|
|
|
- elif self.reduction == 'mean':
|
|
|
- loss = loss.mean()
|
|
|
-
|
|
|
- return loss
|
|
|
-
|
|
|
-
|
|
|
- def forward(self, pred_logits, gt_score):
|
|
|
- if self.cfg['cls_loss'] == 'bce':
|
|
|
- return self.binary_cross_entropy(pred_logits, gt_score)
|
|
|
-
|
|
|
-
|
|
|
-class RegressionLoss(nn.Module):
|
|
|
- def __init__(self, num_classes, reg_max, use_dfl):
|
|
|
- super(RegressionLoss, self).__init__()
|
|
|
- self.num_classes = num_classes
|
|
|
- self.reg_max = reg_max
|
|
|
- self.use_dfl = use_dfl
|
|
|
-
|
|
|
-
|
|
|
- def df_loss(self, pred_regs, target):
|
|
|
- gt_left = target.to(torch.long)
|
|
|
- gt_right = gt_left + 1
|
|
|
- weight_left = gt_right.to(torch.float) - target
|
|
|
- weight_right = 1 - weight_left
|
|
|
- # loss left
|
|
|
- loss_left = F.cross_entropy(
|
|
|
- pred_regs.view(-1, self.reg_max + 1),
|
|
|
- gt_left.view(-1),
|
|
|
- reduction='none').view(gt_left.shape) * weight_left
|
|
|
- # loss right
|
|
|
- loss_right = F.cross_entropy(
|
|
|
- pred_regs.view(-1, self.reg_max + 1),
|
|
|
- gt_right.view(-1),
|
|
|
- reduction='none').view(gt_left.shape) * weight_right
|
|
|
-
|
|
|
- loss = (loss_left + loss_right).mean(-1, keepdim=True)
|
|
|
-
|
|
|
- return loss
|
|
|
-
|
|
|
-
|
|
|
- def forward(self, pred_regs, pred_boxs, anchors, gt_boxs, bbox_weight, fg_masks, strides):
|
|
|
- """
|
|
|
- Input:
|
|
|
- pred_regs: (Tensor) [BM, 4*(reg_max + 1)]
|
|
|
- pred_boxs: (Tensor) [BM, 4]
|
|
|
- anchors: (Tensor) [BM, 2]
|
|
|
- gt_boxs: (Tensor) [BM, 4]
|
|
|
- bbox_weight: (Tensor) [BM, 1]
|
|
|
- fg_masks: (Tensor) [BM,]
|
|
|
- strides: (Tensor) [BM, 1]
|
|
|
- """
|
|
|
- # select positive samples mask
|
|
|
- num_pos = fg_masks.sum()
|
|
|
-
|
|
|
- if num_pos > 0:
|
|
|
- pred_boxs_pos = pred_boxs[fg_masks]
|
|
|
- gt_boxs_pos = gt_boxs[fg_masks]
|
|
|
-
|
|
|
- # iou loss
|
|
|
- ious = bbox_iou(pred_boxs_pos,
|
|
|
- gt_boxs_pos,
|
|
|
- xywh=False,
|
|
|
- CIoU=True)
|
|
|
- loss_iou = (1.0 - ious) * bbox_weight
|
|
|
-
|
|
|
- # dfl loss
|
|
|
- if self.use_dfl:
|
|
|
- pred_regs_pos = pred_regs[fg_masks]
|
|
|
- gt_boxs_s = gt_boxs / strides
|
|
|
- anchors_s = anchors / strides
|
|
|
- gt_ltrb_s = bbox2dist(anchors_s, gt_boxs_s, self.reg_max)
|
|
|
- gt_ltrb_s_pos = gt_ltrb_s[fg_masks]
|
|
|
- loss_dfl = self.df_loss(pred_regs_pos, gt_ltrb_s_pos)
|
|
|
- loss_dfl *= bbox_weight
|
|
|
- else:
|
|
|
- loss_dfl = pred_regs.sum() * 0.
|
|
|
-
|
|
|
- else:
|
|
|
- loss_iou = pred_regs.sum() * 0.
|
|
|
- loss_dfl = pred_regs.sum() * 0.
|
|
|
-
|
|
|
- return loss_iou, loss_dfl
|
|
|
-
|
|
|
-
|
|
|
def build_criterion(cfg, device, num_classes):
|
|
|
criterion = Criterion(
|
|
|
cfg=cfg,
|