|
@@ -2,42 +2,33 @@ import torch
|
|
|
import torch.nn as nn
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
-from utils.box_ops import get_ious
|
|
|
|
|
from utils.misc import sigmoid_focal_loss
|
|
from utils.misc import sigmoid_focal_loss
|
|
|
from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
|
|
from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
|
|
|
|
|
|
|
|
-from .matcher import FcosMatcher, AlignedOTAMatcher
|
|
|
|
|
|
|
+from .matcher import FcosMatcher
|
|
|
|
|
|
|
|
|
|
|
|
|
-class SetCriterion(nn.Module):
|
|
|
|
|
|
|
+class SetCriterion(object):
|
|
|
def __init__(self, cfg):
|
|
def __init__(self, cfg):
|
|
|
- super().__init__()
|
|
|
|
|
# ------------- Basic parameters -------------
|
|
# ------------- Basic parameters -------------
|
|
|
self.cfg = cfg
|
|
self.cfg = cfg
|
|
|
self.num_classes = cfg.num_classes
|
|
self.num_classes = cfg.num_classes
|
|
|
|
|
+
|
|
|
# ------------- Focal loss -------------
|
|
# ------------- Focal loss -------------
|
|
|
self.alpha = cfg.focal_loss_alpha
|
|
self.alpha = cfg.focal_loss_alpha
|
|
|
self.gamma = cfg.focal_loss_gamma
|
|
self.gamma = cfg.focal_loss_gamma
|
|
|
|
|
+
|
|
|
# ------------- Loss weight -------------
|
|
# ------------- Loss weight -------------
|
|
|
- # ------------- Matcher & Loss weight -------------
|
|
|
|
|
- self.matcher_cfg = cfg.matcher_hpy
|
|
|
|
|
- if cfg.matcher == 'fcos_matcher':
|
|
|
|
|
- self.weight_dict = {'loss_cls': cfg.loss_cls_weight,
|
|
|
|
|
- 'loss_reg': cfg.loss_reg_weight,
|
|
|
|
|
- 'loss_ctn': cfg.loss_ctn_weight}
|
|
|
|
|
- self.matcher = FcosMatcher(cfg.num_classes,
|
|
|
|
|
- self.matcher_cfg['center_sampling_radius'],
|
|
|
|
|
- self.matcher_cfg['object_sizes_of_interest'],
|
|
|
|
|
- [1., 1., 1., 1.]
|
|
|
|
|
- )
|
|
|
|
|
- elif cfg.matcher == 'simota':
|
|
|
|
|
- self.weight_dict = {'loss_cls': cfg.loss_cls_weight,
|
|
|
|
|
- 'loss_reg': cfg.loss_reg_weight}
|
|
|
|
|
- self.matcher = AlignedOTAMatcher(cfg.num_classes,
|
|
|
|
|
- self.matcher_cfg['soft_center_radius'],
|
|
|
|
|
- self.matcher_cfg['topk_candidates'])
|
|
|
|
|
- else:
|
|
|
|
|
- raise NotImplementedError("Unknown matcher: {}.".format(cfg.matcher))
|
|
|
|
|
|
|
+ self.weight_dict = {'loss_cls': cfg.loss_cls,
|
|
|
|
|
+ 'loss_reg': cfg.loss_reg,
|
|
|
|
|
+ 'loss_ctn': cfg.loss_ctn,}
|
|
|
|
|
+
|
|
|
|
|
+ # ------------- Matcher -------------
|
|
|
|
|
+ self.matcher = FcosMatcher(cfg.num_classes,
|
|
|
|
|
+ center_sampling_radius=cfg.center_sampling_radius,
|
|
|
|
|
+ object_sizes_of_interest=cfg.object_sizes_of_interest,
|
|
|
|
|
+ box_weights=[1., 1., 1., 1.],
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
def loss_labels(self, pred_cls, tgt_cls, num_boxes=1.0):
|
|
def loss_labels(self, pred_cls, tgt_cls, num_boxes=1.0):
|
|
|
"""
|
|
"""
|
|
@@ -49,34 +40,7 @@ class SetCriterion(nn.Module):
|
|
|
|
|
|
|
|
return loss_cls.sum() / num_boxes
|
|
return loss_cls.sum() / num_boxes
|
|
|
|
|
|
|
|
- def loss_labels_qfl(self, pred_cls, target, beta=2.0, num_boxes=1.0):
|
|
|
|
|
- # Quality FocalLoss
|
|
|
|
|
- """
|
|
|
|
|
- pred_cls: (torch.Tensor): [N, C]。
|
|
|
|
|
- target: (tuple([torch.Tensor], [torch.Tensor])): label -> (N,), score -> (N)
|
|
|
|
|
- """
|
|
|
|
|
- label, score = target
|
|
|
|
|
- pred_sigmoid = pred_cls.sigmoid()
|
|
|
|
|
- scale_factor = pred_sigmoid
|
|
|
|
|
- zerolabel = scale_factor.new_zeros(pred_cls.shape)
|
|
|
|
|
-
|
|
|
|
|
- ce_loss = F.binary_cross_entropy_with_logits(
|
|
|
|
|
- pred_cls, zerolabel, reduction='none') * scale_factor.pow(beta)
|
|
|
|
|
-
|
|
|
|
|
- bg_class_ind = pred_cls.shape[-1]
|
|
|
|
|
- pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1)
|
|
|
|
|
- if pos.shape[0] > 0:
|
|
|
|
|
- pos_label = label[pos].long()
|
|
|
|
|
-
|
|
|
|
|
- scale_factor = score[pos] - pred_sigmoid[pos, pos_label]
|
|
|
|
|
-
|
|
|
|
|
- ce_loss[pos, pos_label] = F.binary_cross_entropy_with_logits(
|
|
|
|
|
- pred_cls[pos, pos_label], score[pos],
|
|
|
|
|
- reduction='none') * scale_factor.abs().pow(beta)
|
|
|
|
|
-
|
|
|
|
|
- return ce_loss.sum() / num_boxes
|
|
|
|
|
-
|
|
|
|
|
- def loss_bboxes_ltrb(self, pred_delta, tgt_delta, bbox_quality=None, num_boxes=1.0):
|
|
|
|
|
|
|
+ def loss_bboxes(self, pred_delta, tgt_delta, bbox_quality=None, num_boxes=1.0):
|
|
|
"""
|
|
"""
|
|
|
pred_box: (Tensor) [N, 4]
|
|
pred_box: (Tensor) [N, 4]
|
|
|
tgt_box: (Tensor) [N, 4]
|
|
tgt_box: (Tensor) [N, 4]
|
|
@@ -114,16 +78,7 @@ class SetCriterion(nn.Module):
|
|
|
|
|
|
|
|
return loss_box.sum() / num_boxes
|
|
return loss_box.sum() / num_boxes
|
|
|
|
|
|
|
|
- def loss_bboxes_xyxy(self, pred_box, gt_box, num_boxes=1.0, box_weight=None):
|
|
|
|
|
- ious = get_ious(pred_box, gt_box, box_mode="xyxy", iou_type='giou')
|
|
|
|
|
- loss_box = 1.0 - ious
|
|
|
|
|
-
|
|
|
|
|
- if box_weight is not None:
|
|
|
|
|
- loss_box = loss_box.squeeze(-1) * box_weight
|
|
|
|
|
-
|
|
|
|
|
- return loss_box.sum() / num_boxes
|
|
|
|
|
-
|
|
|
|
|
- def fcos_loss(self, outputs, targets):
|
|
|
|
|
|
|
+ def __call__(self, outputs, targets):
|
|
|
"""
|
|
"""
|
|
|
outputs['pred_cls']: (Tensor) [B, M, C]
|
|
outputs['pred_cls']: (Tensor) [B, M, C]
|
|
|
outputs['pred_reg']: (Tensor) [B, M, 4]
|
|
outputs['pred_reg']: (Tensor) [B, M, 4]
|
|
@@ -137,10 +92,10 @@ class SetCriterion(nn.Module):
|
|
|
device = outputs['pred_cls'][0].device
|
|
device = outputs['pred_cls'][0].device
|
|
|
fpn_strides = outputs['strides']
|
|
fpn_strides = outputs['strides']
|
|
|
anchors = outputs['anchors']
|
|
anchors = outputs['anchors']
|
|
|
- pred_cls = torch.cat(outputs['pred_cls'], dim=1).view(-1, self.num_classes)
|
|
|
|
|
|
|
+
|
|
|
|
|
+ pred_cls = torch.cat(outputs['pred_cls'], dim=1).view(-1, self.num_classes)
|
|
|
pred_delta = torch.cat(outputs['pred_reg'], dim=1).view(-1, 4)
|
|
pred_delta = torch.cat(outputs['pred_reg'], dim=1).view(-1, 4)
|
|
|
- pred_ctn = torch.cat(outputs['pred_ctn'], dim=1).view(-1, 1)
|
|
|
|
|
- masks = ~torch.cat(outputs['mask'], dim=1).view(-1)
|
|
|
|
|
|
|
+ pred_ctn = torch.cat(outputs['pred_ctn'], dim=1).view(-1, 1)
|
|
|
|
|
|
|
|
# -------------------- Label Assignment --------------------
|
|
# -------------------- Label Assignment --------------------
|
|
|
gt_classes, gt_deltas, gt_centerness = self.matcher(fpn_strides, anchors, targets)
|
|
gt_classes, gt_deltas, gt_centerness = self.matcher(fpn_strides, anchors, targets)
|
|
@@ -148,33 +103,31 @@ class SetCriterion(nn.Module):
|
|
|
gt_deltas = gt_deltas.view(-1, 4).to(device)
|
|
gt_deltas = gt_deltas.view(-1, 4).to(device)
|
|
|
gt_centerness = gt_centerness.view(-1, 1).to(device)
|
|
gt_centerness = gt_centerness.view(-1, 1).to(device)
|
|
|
|
|
|
|
|
- foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
|
|
|
|
|
- num_foreground = foreground_idxs.sum()
|
|
|
|
|
|
|
+ fg_masks = (gt_classes >= 0) & (gt_classes != self.num_classes)
|
|
|
|
|
+ num_fgs = fg_masks.sum()
|
|
|
|
|
|
|
|
if is_dist_avail_and_initialized():
|
|
if is_dist_avail_and_initialized():
|
|
|
- torch.distributed.all_reduce(num_foreground)
|
|
|
|
|
- num_foreground = torch.clamp(num_foreground / get_world_size(), min=1).item()
|
|
|
|
|
|
|
+ torch.distributed.all_reduce(num_fgs)
|
|
|
|
|
+ num_fgs = torch.clamp(num_fgs / get_world_size(), min=1).item()
|
|
|
|
|
|
|
|
- num_foreground_centerness = gt_centerness[foreground_idxs].sum()
|
|
|
|
|
|
|
+ num_fgs_ctn = gt_centerness[fg_masks].sum()
|
|
|
if is_dist_avail_and_initialized():
|
|
if is_dist_avail_and_initialized():
|
|
|
- torch.distributed.all_reduce(num_foreground_centerness)
|
|
|
|
|
- num_targets = torch.clamp(num_foreground_centerness / get_world_size(), min=1).item()
|
|
|
|
|
|
|
+ torch.distributed.all_reduce(num_fgs_ctn)
|
|
|
|
|
+ num_targets = torch.clamp(num_fgs_ctn / get_world_size(), min=1).item()
|
|
|
|
|
|
|
|
# -------------------- classification loss --------------------
|
|
# -------------------- classification loss --------------------
|
|
|
gt_classes_target = torch.zeros_like(pred_cls)
|
|
gt_classes_target = torch.zeros_like(pred_cls)
|
|
|
- gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1
|
|
|
|
|
- valid_idxs = (gt_classes >= 0) & masks
|
|
|
|
|
- loss_labels = self.loss_labels(
|
|
|
|
|
- pred_cls[valid_idxs], gt_classes_target[valid_idxs], num_foreground)
|
|
|
|
|
|
|
+ gt_classes_target[fg_masks, gt_classes[fg_masks]] = 1
|
|
|
|
|
+ loss_labels = self.loss_labels(pred_cls, gt_classes_target, num_fgs)
|
|
|
|
|
|
|
|
# -------------------- regression loss --------------------
|
|
# -------------------- regression loss --------------------
|
|
|
- loss_bboxes = self.loss_bboxes_ltrb(
|
|
|
|
|
- pred_delta[foreground_idxs], gt_deltas[foreground_idxs], gt_centerness[foreground_idxs], num_targets)
|
|
|
|
|
|
|
+ loss_bboxes = self.loss_bboxes(
|
|
|
|
|
+ pred_delta[fg_masks], gt_deltas[fg_masks], gt_centerness[fg_masks], num_targets)
|
|
|
|
|
|
|
|
# -------------------- centerness loss --------------------
|
|
# -------------------- centerness loss --------------------
|
|
|
loss_centerness = F.binary_cross_entropy_with_logits(
|
|
loss_centerness = F.binary_cross_entropy_with_logits(
|
|
|
- pred_ctn[foreground_idxs], gt_centerness[foreground_idxs], reduction='none')
|
|
|
|
|
- loss_centerness = loss_centerness.sum() / num_foreground
|
|
|
|
|
|
|
+ pred_ctn[fg_masks], gt_centerness[fg_masks], reduction='none')
|
|
|
|
|
+ loss_centerness = loss_centerness.sum() / num_fgs
|
|
|
|
|
|
|
|
total_loss = loss_labels * self.weight_dict["loss_cls"] + \
|
|
total_loss = loss_labels * self.weight_dict["loss_cls"] + \
|
|
|
loss_bboxes * self.weight_dict["loss_reg"] + \
|
|
loss_bboxes * self.weight_dict["loss_reg"] + \
|
|
@@ -187,104 +140,3 @@ class SetCriterion(nn.Module):
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
return loss_dict
|
|
return loss_dict
|
|
|
-
|
|
|
|
|
- def ota_loss(self, outputs, targets):
|
|
|
|
|
- """
|
|
|
|
|
- outputs['pred_cls']: (Tensor) [B, M, C]
|
|
|
|
|
- outputs['pred_reg']: (Tensor) [B, M, 4]
|
|
|
|
|
- outputs['pred_box']: (Tensor) [B, M, 4]
|
|
|
|
|
- outputs['strides']: (List) [8, 16, 32, ...] stride of the model output
|
|
|
|
|
- targets: (List) [dict{'boxes': [...],
|
|
|
|
|
- 'labels': [...],
|
|
|
|
|
- 'orig_size': ...}, ...]
|
|
|
|
|
- """
|
|
|
|
|
- # -------------------- Pre-process --------------------
|
|
|
|
|
- bs = outputs['pred_cls'][0].shape[0]
|
|
|
|
|
- device = outputs['pred_cls'][0].device
|
|
|
|
|
- fpn_strides = outputs['strides']
|
|
|
|
|
- anchors = outputs['anchors']
|
|
|
|
|
- # preds: [B, M, C]
|
|
|
|
|
- # preds: [B, M, C]
|
|
|
|
|
- cls_preds = torch.cat(outputs['pred_cls'], dim=1)
|
|
|
|
|
- box_preds = torch.cat(outputs['pred_box'], dim=1)
|
|
|
|
|
- masks = ~torch.cat(outputs['mask'], dim=1).view(-1)
|
|
|
|
|
-
|
|
|
|
|
- # -------------------- Label Assignment --------------------
|
|
|
|
|
- cls_targets = []
|
|
|
|
|
- box_targets = []
|
|
|
|
|
- assign_metrics = []
|
|
|
|
|
- for batch_idx in range(bs):
|
|
|
|
|
- tgt_labels = targets[batch_idx]["labels"].to(device) # [N,]
|
|
|
|
|
- tgt_bboxes = targets[batch_idx]["boxes"].to(device) # [N, 4]
|
|
|
|
|
- # refine target
|
|
|
|
|
- tgt_boxes_wh = tgt_bboxes[..., 2:] - tgt_bboxes[..., :2]
|
|
|
|
|
- min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
|
|
|
|
|
- keep = (min_tgt_size >= 8)
|
|
|
|
|
- tgt_bboxes = tgt_bboxes[keep]
|
|
|
|
|
- tgt_labels = tgt_labels[keep]
|
|
|
|
|
- # label assignment
|
|
|
|
|
- assigned_result = self.matcher(fpn_strides=fpn_strides,
|
|
|
|
|
- anchors=anchors,
|
|
|
|
|
- pred_cls=cls_preds[batch_idx].detach(),
|
|
|
|
|
- pred_box=box_preds[batch_idx].detach(),
|
|
|
|
|
- gt_labels=tgt_labels,
|
|
|
|
|
- gt_bboxes=tgt_bboxes
|
|
|
|
|
- )
|
|
|
|
|
- cls_targets.append(assigned_result['assigned_labels'])
|
|
|
|
|
- box_targets.append(assigned_result['assigned_bboxes'])
|
|
|
|
|
- assign_metrics.append(assigned_result['assign_metrics'])
|
|
|
|
|
-
|
|
|
|
|
- # List[B, M, C] -> Tensor[BM, C]
|
|
|
|
|
- cls_targets = torch.cat(cls_targets, dim=0)
|
|
|
|
|
- box_targets = torch.cat(box_targets, dim=0)
|
|
|
|
|
- assign_metrics = torch.cat(assign_metrics, dim=0)
|
|
|
|
|
-
|
|
|
|
|
- valid_idxs = (cls_targets >= 0) & masks
|
|
|
|
|
- foreground_idxs = (cls_targets >= 0) & (cls_targets != self.num_classes)
|
|
|
|
|
- num_fgs = assign_metrics.sum()
|
|
|
|
|
-
|
|
|
|
|
- if is_dist_avail_and_initialized():
|
|
|
|
|
- torch.distributed.all_reduce(num_fgs)
|
|
|
|
|
- num_fgs = torch.clamp(num_fgs / get_world_size(), min=1).item()
|
|
|
|
|
-
|
|
|
|
|
- # -------------------- classification loss --------------------
|
|
|
|
|
- cls_preds = cls_preds.view(-1, self.num_classes)[valid_idxs]
|
|
|
|
|
- qfl_targets = (cls_targets[valid_idxs], assign_metrics[valid_idxs])
|
|
|
|
|
- loss_labels = self.loss_labels_qfl(cls_preds, qfl_targets, 2.0, num_fgs)
|
|
|
|
|
-
|
|
|
|
|
- # -------------------- regression loss --------------------
|
|
|
|
|
- box_preds_pos = box_preds.view(-1, 4)[foreground_idxs]
|
|
|
|
|
- box_targets_pos = box_targets[foreground_idxs]
|
|
|
|
|
- box_weight = assign_metrics[foreground_idxs]
|
|
|
|
|
- loss_bboxes = self.loss_bboxes_xyxy(box_preds_pos, box_targets_pos, num_fgs, box_weight)
|
|
|
|
|
-
|
|
|
|
|
- total_loss = loss_labels * self.weight_dict["loss_cls"] + \
|
|
|
|
|
- loss_bboxes * self.weight_dict["loss_reg"]
|
|
|
|
|
- loss_dict = dict(
|
|
|
|
|
- loss_cls = loss_labels,
|
|
|
|
|
- loss_reg = loss_bboxes,
|
|
|
|
|
- losses = total_loss,
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- return loss_dict
|
|
|
|
|
-
|
|
|
|
|
- def forward(self, outputs, targets):
|
|
|
|
|
- """
|
|
|
|
|
- outputs['pred_cls']: (Tensor) [B, M, C]
|
|
|
|
|
- outputs['pred_reg']: (Tensor) [B, M, 4]
|
|
|
|
|
- outputs['pred_ctn']: (Tensor) [B, M, 1]
|
|
|
|
|
- outputs['strides']: (List) [8, 16, 32, ...] stride of the model output
|
|
|
|
|
- targets: (List) [dict{'boxes': [...],
|
|
|
|
|
- 'labels': [...],
|
|
|
|
|
- 'orig_size': ...}, ...]
|
|
|
|
|
- """
|
|
|
|
|
- if self.cfg.matcher == "fcos_matcher":
|
|
|
|
|
- return self.fcos_loss(outputs, targets)
|
|
|
|
|
- elif self.cfg.matcher == "simota":
|
|
|
|
|
- return self.ota_loss(outputs, targets)
|
|
|
|
|
- else:
|
|
|
|
|
- raise NotImplementedError
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-if __name__ == "__main__":
|
|
|
|
|
- pass
|
|
|