|
|
@@ -19,18 +19,20 @@ class SetCriterion(nn.Module):
|
|
|
self.alpha = cfg.focal_loss_alpha
|
|
|
self.gamma = cfg.focal_loss_gamma
|
|
|
# ------------- Loss weight -------------
|
|
|
- self.weight_dict = {'loss_cls': cfg.loss_cls_weight,
|
|
|
- 'loss_reg': cfg.loss_reg_weight,
|
|
|
- 'loss_ctn': cfg.loss_ctn_weight}
|
|
|
- # ------------- Matcher -------------
|
|
|
+ # ------------- Matcher & Loss weight -------------
|
|
|
self.matcher_cfg = cfg.matcher_hpy
|
|
|
if cfg.matcher == 'fcos_matcher':
|
|
|
+ self.weight_dict = {'loss_cls': cfg.loss_cls_weight,
|
|
|
+ 'loss_reg': cfg.loss_reg_weight,
|
|
|
+ 'loss_ctn': cfg.loss_ctn_weight}
|
|
|
self.matcher = FcosMatcher(cfg.num_classes,
|
|
|
self.matcher_cfg['center_sampling_radius'],
|
|
|
self.matcher_cfg['object_sizes_of_interest'],
|
|
|
[1., 1., 1., 1.]
|
|
|
)
|
|
|
elif cfg.matcher == 'simota':
|
|
|
+ self.weight_dict = {'loss_cls': cfg.loss_cls_weight,
|
|
|
+ 'loss_reg': cfg.loss_reg_weight}
|
|
|
self.matcher = SimOtaMatcher(cfg.num_classes,
|
|
|
self.matcher_cfg['soft_center_radius'],
|
|
|
self.matcher_cfg['topk_candidates'])
|
|
|
@@ -47,6 +49,33 @@ class SetCriterion(nn.Module):
|
|
|
|
|
|
return loss_cls.sum() / num_boxes
|
|
|
|
|
|
def loss_labels_qfl(self, pred_cls, target, beta=2.0, num_boxes=1.0):
    """Quality Focal Loss computed on raw classification logits.

    Args:
        pred_cls: classification logits; documented by the original
            author as a tensor of shape [N, C].
        target: pair ``(label, score)``. ``label`` holds one class index
            per row and ``score`` the quality target for that row (both
            documented as shape (N,)). Rows whose label falls outside
            ``[0, C)`` are treated as pure negatives.
        beta: exponent applied to the modulating factor.
        num_boxes: value the summed loss is divided by.

    Returns:
        The sum of all per-entry losses divided by ``num_boxes``.
    """
    cls_ids, quality = target
    prob = pred_cls.sigmoid()
    num_classes = pred_cls.shape[-1]

    # First treat every (row, class) entry as a negative: binary
    # cross-entropy against an all-zero target, weighted by prob ** beta.
    all_zero = prob.new_zeros(pred_cls.shape)
    loss = F.binary_cross_entropy_with_logits(
        pred_cls, all_zero, reduction='none')
    loss = loss * prob.pow(beta)

    # Rows whose label is a real class index (0 <= label < C).
    is_fg = (cls_ids >= 0) & (cls_ids < num_classes)
    fg_rows = is_fg.nonzero().squeeze(1)
    if fg_rows.shape[0] > 0:
        fg_cls = cls_ids[fg_rows].long()
        fg_quality = quality[fg_rows]
        fg_logit = pred_cls[fg_rows, fg_cls]

        # Distance between the quality target and the predicted
        # probability of the labelled class; used as modulating factor.
        gap = fg_quality - prob[fg_rows, fg_cls]

        # Overwrite the entry of the labelled class with the
        # cross-entropy against the quality score, weighted by |gap| ** beta.
        loss[fg_rows, fg_cls] = F.binary_cross_entropy_with_logits(
            fg_logit, fg_quality, reduction='none') * gap.abs().pow(beta)

    return loss.sum() / num_boxes
|
|
|
+
|
|
|
def loss_bboxes_ltrb(self, pred_delta, tgt_delta, bbox_quality=None, num_boxes=1.0):
|
|
|
"""
|
|
|
pred_box: (Tensor) [N, 4]
|
|
|
@@ -157,27 +186,27 @@ class SetCriterion(nn.Module):
|
|
|
outputs['pred_cls']: (Tensor) [B, M, C]
|
|
|
outputs['pred_reg']: (Tensor) [B, M, 4]
|
|
|
outputs['pred_box']: (Tensor) [B, M, 4]
|
|
|
- outputs['pred_ctn']: (Tensor) [B, M, 1]
|
|
|
outputs['strides']: (List) [8, 16, 32, ...] stride of the model output
|
|
|
targets: (List) [dict{'boxes': [...],
|
|
|
'labels': [...],
|
|
|
'orig_size': ...}, ...]
|
|
|
"""
|
|
|
# -------------------- Pre-process --------------------
|
|
|
- device = outputs['pred_cls'][0].device
|
|
|
- batch_size = outputs['pred_cls'][0].shape[0]
|
|
|
+ bs = outputs['pred_cls'][0].shape[0]
|
|
|
+ device = outputs['pred_cls'][0].device
|
|
|
fpn_strides = outputs['strides']
|
|
|
- anchors = outputs['anchors']
|
|
|
- pred_cls = torch.cat(outputs['pred_cls'], dim=1) # [B, M, C]
|
|
|
- pred_box = torch.cat(outputs['pred_box'], dim=1) # [B, M, 4]
|
|
|
- pred_ctn = torch.cat(outputs['pred_ctn'], dim=1) # [B, M, 1]
|
|
|
+ anchors = outputs['anchors']
|
|
|
+ # cls_preds: [B, M, C]
|
|
|
+ # box_preds: [B, M, 4]
|
|
|
+ cls_preds = torch.cat(outputs['pred_cls'], dim=1)
|
|
|
+ box_preds = torch.cat(outputs['pred_box'], dim=1)
|
|
|
masks = ~torch.cat(outputs['mask'], dim=1).view(-1)
|
|
|
|
|
|
# -------------------- Label Assignment --------------------
|
|
|
- gt_classes = []
|
|
|
- gt_bboxes = []
|
|
|
- gt_centerness = []
|
|
|
- for batch_idx in range(batch_size):
|
|
|
+ cls_targets = []
|
|
|
+ box_targets = []
|
|
|
+ assign_metrics = []
|
|
|
+ for batch_idx in range(bs):
|
|
|
tgt_labels = targets[batch_idx]["labels"].to(device) # [N,]
|
|
|
tgt_bboxes = targets[batch_idx]["boxes"].to(device) # [N, 4]
|
|
|
# refine target
|
|
|
@@ -189,52 +218,41 @@ class SetCriterion(nn.Module):
|
|
|
# label assignment
|
|
|
assigned_result = self.matcher(fpn_strides=fpn_strides,
|
|
|
anchors=anchors,
|
|
|
- pred_cls=pred_cls[batch_idx].detach(),
|
|
|
- pred_box=pred_box[batch_idx].detach(),
|
|
|
- pred_iou=pred_ctn[batch_idx].detach(),
|
|
|
+ pred_cls=cls_preds[batch_idx].detach(),
|
|
|
+ pred_box=box_preds[batch_idx].detach(),
|
|
|
gt_labels=tgt_labels,
|
|
|
gt_bboxes=tgt_bboxes
|
|
|
)
|
|
|
- gt_classes.append(assigned_result['assigned_labels'])
|
|
|
- gt_bboxes.append(assigned_result['assigned_bboxes'])
|
|
|
- gt_centerness.append(assigned_result['assign_metrics'])
|
|
|
+ cls_targets.append(assigned_result['assigned_labels'])
|
|
|
+ box_targets.append(assigned_result['assigned_bboxes'])
|
|
|
+ assign_metrics.append(assigned_result['assign_metrics'])
|
|
|
|
|
|
# List[B, M, C] -> Tensor[BM, C]
|
|
|
- gt_classes = torch.cat(gt_classes, dim=0) # [BM,]
|
|
|
- gt_bboxes = torch.cat(gt_bboxes, dim=0) # [BM, 4]
|
|
|
- gt_centerness = torch.cat(gt_centerness, dim=0) # [BM,]
|
|
|
+ cls_targets = torch.cat(cls_targets, dim=0)
|
|
|
+ box_targets = torch.cat(box_targets, dim=0)
|
|
|
+ assign_metrics = torch.cat(assign_metrics, dim=0)
|
|
|
|
|
|
- valid_idxs = (gt_classes >= 0) & masks
|
|
|
- foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
|
|
|
- num_foreground = foreground_idxs.sum()
|
|
|
+ valid_idxs = (cls_targets >= 0) & masks
|
|
|
+ foreground_idxs = (cls_targets >= 0) & (cls_targets != self.num_classes)
|
|
|
+ num_fgs = assign_metrics.sum()
|
|
|
|
|
|
if is_dist_avail_and_initialized():
|
|
|
- torch.distributed.all_reduce(num_foreground)
|
|
|
- num_foreground = torch.clamp(num_foreground / get_world_size(), min=1).item()
|
|
|
+ torch.distributed.all_reduce(num_fgs)
|
|
|
+ num_fgs = torch.clamp(num_fgs / get_world_size(), min=1).item()
|
|
|
|
|
|
# -------------------- classification loss --------------------
|
|
|
- pred_cls = pred_cls.view(-1, self.num_classes)
|
|
|
- gt_classes_target = torch.zeros_like(pred_cls)
|
|
|
- gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1
|
|
|
- loss_labels = self.loss_labels(pred_cls[valid_idxs], gt_classes_target[valid_idxs], num_foreground)
|
|
|
+ cls_preds = cls_preds.view(-1, self.num_classes)[valid_idxs]
|
|
|
+ qfl_targets = (cls_targets[valid_idxs], assign_metrics[valid_idxs])
|
|
|
+ loss_labels = self.loss_labels_qfl(cls_preds, qfl_targets, 2.0, num_fgs)
|
|
|
|
|
|
# -------------------- regression loss --------------------
|
|
|
- pred_box = pred_box.view(-1, 4)
|
|
|
- pred_box_pos = pred_box[foreground_idxs]
|
|
|
- gt_box_pos = gt_bboxes[foreground_idxs]
|
|
|
- loss_bboxes = self.loss_bboxes_xyxy(pred_box_pos, gt_box_pos, num_foreground)
|
|
|
-
|
|
|
- # -------------------- centerness loss --------------------
|
|
|
- pred_ctn = pred_ctn.view(-1)
|
|
|
- pred_ctn_pos = pred_ctn[foreground_idxs]
|
|
|
- gt_ctn_pos = gt_centerness[foreground_idxs]
|
|
|
- loss_centerness = F.binary_cross_entropy_with_logits(pred_ctn_pos, gt_ctn_pos, reduction='none')
|
|
|
- loss_centerness = loss_centerness.sum() / num_foreground
|
|
|
+ box_preds_pos = box_preds.view(-1, 4)[foreground_idxs]
|
|
|
+ box_targets_pos = box_targets[foreground_idxs]
|
|
|
+ loss_bboxes = self.loss_bboxes_xyxy(box_preds_pos, box_targets_pos, num_fgs)
|
|
|
|
|
|
loss_dict = dict(
|
|
|
loss_cls = loss_labels,
|
|
|
loss_reg = loss_bboxes,
|
|
|
- loss_ctn = loss_centerness,
|
|
|
)
|
|
|
|
|
|
return loss_dict
|