loss.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. import torch
  2. import torch.nn.functional as F
  3. from utils.box_ops import get_ious
  4. from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
  5. from .matcher import SimOtaMatcher
  6. class SetCriterion(object):
  7. def __init__(self, cfg):
  8. self.cfg = cfg
  9. self.num_classes = cfg.num_classes
  10. # --------------- Loss config ---------------
  11. self.loss_cls_weight = cfg.loss_cls
  12. self.loss_box_weight = cfg.loss_box
  13. # --------------- Matcher config ---------------
  14. self.matcher = SimOtaMatcher(soft_center_radius = cfg.ota_soft_center_radius,
  15. topk_candidates = cfg.ota_topk_candidates,
  16. num_classes = cfg.num_classes,
  17. )
  18. def loss_classes(self, pred_cls, target, beta=2.0):
  19. # Quality FocalLoss
  20. """
  21. pred_cls: (torch.Tensor): [N, C]。
  22. target: (tuple([torch.Tensor], [torch.Tensor])): label -> (N,), score -> (N)
  23. """
  24. label, score = target
  25. pred_sigmoid = pred_cls.sigmoid()
  26. scale_factor = pred_sigmoid
  27. zerolabel = scale_factor.new_zeros(pred_cls.shape)
  28. ce_loss = F.binary_cross_entropy_with_logits(
  29. pred_cls, zerolabel, reduction='none') * scale_factor.pow(beta)
  30. bg_class_ind = pred_cls.shape[-1]
  31. pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1)
  32. if pos.shape[0] > 0:
  33. pos_label = label[pos].long()
  34. scale_factor = score[pos] - pred_sigmoid[pos, pos_label]
  35. ce_loss[pos, pos_label] = F.binary_cross_entropy_with_logits(
  36. pred_cls[pos, pos_label], score[pos],
  37. reduction='none') * scale_factor.abs().pow(beta)
  38. return ce_loss
  39. def loss_bboxes(self, pred_box, gt_box, bbox_weight=None):
  40. ious = get_ious(pred_box, gt_box, box_mode="xyxy", iou_type='giou')
  41. loss_box = 1.0 - ious
  42. if bbox_weight is not None:
  43. loss_box = loss_box.squeeze(-1) * bbox_weight
  44. return loss_box
  45. def __call__(self, outputs, targets):
  46. """
  47. outputs['pred_cls']: List(Tensor) [B, M, C]
  48. outputs['pred_reg']: List(Tensor) [B, M, 4]
  49. outputs['pred_box']: List(Tensor) [B, M, 4]
  50. outputs['strides']: List(Int) [8, 16, 32] output stride
  51. targets: (List) [dict{'boxes': [...],
  52. 'labels': [...],
  53. 'orig_size': ...}, ...]
  54. """
  55. bs = outputs['pred_cls'].shape[0]
  56. device = outputs['pred_cls'].device
  57. anchors = outputs['anchors']
  58. stride = outputs['stride']
  59. # preds: [B, M, C]
  60. cls_preds = outputs['pred_cls']
  61. box_preds = outputs['pred_box']
  62. # --------------- label assignment ---------------
  63. cls_targets = []
  64. box_targets = []
  65. assign_metrics = []
  66. for batch_idx in range(bs):
  67. tgt_labels = targets[batch_idx]["labels"].to(device) # [N,]
  68. tgt_bboxes = targets[batch_idx]["boxes"].to(device) # [N, 4]
  69. assigned_result = self.matcher(stride=stride,
  70. anchors=anchors[..., :2],
  71. pred_cls=cls_preds[batch_idx].detach(),
  72. pred_box=box_preds[batch_idx].detach(),
  73. gt_labels=tgt_labels,
  74. gt_bboxes=tgt_bboxes
  75. )
  76. cls_targets.append(assigned_result['assigned_labels'])
  77. box_targets.append(assigned_result['assigned_bboxes'])
  78. assign_metrics.append(assigned_result['assign_metrics'])
  79. # List[B, M, C] -> Tensor[BM, C]
  80. cls_targets = torch.cat(cls_targets, dim=0)
  81. box_targets = torch.cat(box_targets, dim=0)
  82. assign_metrics = torch.cat(assign_metrics, dim=0)
  83. # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
  84. bg_class_ind = self.num_classes
  85. pos_inds = ((cls_targets >= 0) & (cls_targets < bg_class_ind)).nonzero().squeeze(1)
  86. num_fgs = assign_metrics.sum()
  87. if is_dist_avail_and_initialized():
  88. torch.distributed.all_reduce(num_fgs)
  89. num_fgs = (num_fgs / get_world_size()).clamp(1.0).item()
  90. bbox_weight = assign_metrics[pos_inds]
  91. # ------------------ Classification loss ------------------
  92. cls_preds = cls_preds.view(-1, self.num_classes)
  93. loss_cls = self.loss_classes(cls_preds, (cls_targets, assign_metrics))
  94. loss_cls = loss_cls.sum() / num_fgs
  95. # ------------------ Regression loss ------------------
  96. box_preds_pos = box_preds.view(-1, 4)[pos_inds]
  97. box_targets_pos = box_targets[pos_inds]
  98. loss_box = self.loss_bboxes(box_preds_pos, box_targets_pos, bbox_weight)
  99. loss_box = loss_box.sum() / num_fgs
  100. # total loss
  101. losses = self.loss_cls_weight * loss_cls + \
  102. self.loss_box_weight * loss_box
  103. loss_dict = dict(
  104. loss_cls = loss_cls,
  105. loss_box = loss_box,
  106. losses = losses
  107. )
  108. return loss_dict