loss.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. import torch
  2. import torch.nn.functional as F
  3. try:
  4. from .loss_utils import get_ious, get_world_size, is_dist_avail_and_initialized
  5. from .matcher import AlignedSimOtaMatcher
  6. except:
  7. from loss_utils import get_ious, get_world_size, is_dist_avail_and_initialized
  8. from matcher import AlignedSimOtaMatcher
  9. class Criterion(object):
  10. def __init__(self, cfg, num_classes=80):
  11. # ------------ Basic parameters ------------
  12. self.cfg = cfg
  13. self.num_classes = num_classes
  14. # --------------- Matcher config ---------------
  15. self.matcher_hpy = cfg['matcher_hpy']
  16. self.matcher = AlignedSimOtaMatcher(soft_center_radius = self.matcher_hpy['soft_center_radius'],
  17. topk_candidates = self.matcher_hpy['topk_candidates'],
  18. num_classes = num_classes,
  19. )
  20. # ------------- Loss weight -------------
  21. self.weight_dict = {'loss_cls': cfg['loss_coeff']['class'],
  22. 'loss_box': cfg['loss_coeff']['bbox'],
  23. 'loss_giou': cfg['loss_coeff']['giou']}
  24. def loss_classes(self, pred_cls, target, num_gts, beta=2.0):
  25. # Quality FocalLoss
  26. """
  27. pred_cls: (torch.Tensor): [N, C]。
  28. target: (tuple([torch.Tensor], [torch.Tensor])): label -> (N,), score -> (N)
  29. """
  30. label, score = target
  31. pred_sigmoid = pred_cls.sigmoid()
  32. scale_factor = pred_sigmoid
  33. zerolabel = scale_factor.new_zeros(pred_cls.shape)
  34. ce_loss = F.binary_cross_entropy_with_logits(
  35. pred_cls, zerolabel, reduction='none') * scale_factor.pow(beta)
  36. bg_class_ind = pred_cls.shape[-1]
  37. pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1)
  38. pos_label = label[pos].long()
  39. scale_factor = score[pos] - pred_sigmoid[pos, pos_label]
  40. ce_loss[pos, pos_label] = F.binary_cross_entropy_with_logits(
  41. pred_cls[pos, pos_label], score[pos],
  42. reduction='none') * scale_factor.abs().pow(beta)
  43. losses = {}
  44. losses['loss_cls'] = ce_loss.sum() / num_gts
  45. return losses
  46. def loss_bboxes(self, pred_reg, pred_box, gt_box, anchors, stride_tensors, num_gts):
  47. # --------------- Compute L1 loss ---------------
  48. ## xyxy -> cxcy&bwbh
  49. gt_cxcy = (gt_box[..., :2] + gt_box[..., 2:]) * 0.5
  50. gt_bwbh = gt_box[..., 2:] - gt_box[..., :2]
  51. ## Encode gt box
  52. gt_cxcy_encode = (gt_cxcy - anchors) / stride_tensors
  53. gt_bwbh_encode = torch.log(gt_bwbh / stride_tensors)
  54. gt_box_encode = torch.cat([gt_cxcy_encode, gt_bwbh_encode], dim=-1)
  55. # L1 loss
  56. loss_box = F.l1_loss(pred_reg, gt_box_encode, reduction='none')
  57. # --------------- Compute GIoU loss ---------------
  58. gious = get_ious(pred_box, gt_box, box_mode="xyxy", iou_type='giou')
  59. loss_giou = 1.0 - gious
  60. losses = {}
  61. losses['loss_box'] = loss_box.sum() / num_gts
  62. losses['loss_giou'] = loss_giou.sum() / num_gts
  63. return losses
  64. def __call__(self, outputs, targets):
  65. """
  66. outputs['pred_cls']: List(Tensor) [B, M, C]
  67. outputs['pred_box']: List(Tensor) [B, M, 4]
  68. outputs['pred_box']: List(Tensor) [B, M, 4]
  69. outputs['strides']: List(Int) [8, 16, 32] output stride
  70. targets: (List) [dict{'boxes': [...],
  71. 'labels': [...],
  72. 'orig_size': ...}, ...]
  73. """
  74. bs = outputs['pred_cls'][0].shape[0]
  75. device = outputs['pred_cls'][0].device
  76. anchors = outputs['anchors']
  77. fpn_strides = outputs['strides']
  78. stride_tensors = outputs['stride_tensors']
  79. losses = dict()
  80. # preds: [B, M, C]
  81. cls_preds = torch.cat(outputs['pred_cls'], dim=1)
  82. box_preds = torch.cat(outputs['pred_box'], dim=1)
  83. reg_preds = torch.cat(outputs['pred_reg'], dim=1)
  84. # --------------- label assignment ---------------
  85. cls_targets = []
  86. box_targets = []
  87. assign_metrics = []
  88. for batch_idx in range(bs):
  89. tgt_labels = targets[batch_idx]["labels"].to(device) # [N,]
  90. tgt_bboxes = targets[batch_idx]["boxes"].to(device) # [N, 4]
  91. assigned_result = self.matcher(fpn_strides=fpn_strides,
  92. anchors=anchors,
  93. pred_cls=cls_preds[batch_idx].detach(),
  94. pred_box=box_preds[batch_idx].detach(),
  95. gt_labels=tgt_labels,
  96. gt_bboxes=tgt_bboxes
  97. )
  98. cls_targets.append(assigned_result['assigned_labels'])
  99. box_targets.append(assigned_result['assigned_bboxes'])
  100. assign_metrics.append(assigned_result['assign_metrics'])
  101. # List[B, M, C] -> Tensor[BM, C]
  102. cls_targets = torch.cat(cls_targets, dim=0)
  103. box_targets = torch.cat(box_targets, dim=0)
  104. assign_metrics = torch.cat(assign_metrics, dim=0)
  105. # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
  106. bg_class_ind = self.num_classes
  107. pos_inds = ((cls_targets >= 0) & (cls_targets < bg_class_ind)).nonzero().squeeze(1)
  108. num_fgs = assign_metrics.sum()
  109. if is_dist_avail_and_initialized():
  110. torch.distributed.all_reduce(num_fgs)
  111. num_fgs = (num_fgs / get_world_size()).clamp(1.0).item()
  112. # ------------------ Classification loss ------------------
  113. cls_preds = cls_preds.view(-1, self.num_classes)
  114. loss_dict = self.loss_classes(cls_preds, (cls_targets, assign_metrics), num_fgs)
  115. loss_dict = {k: loss_dict[k] * self.weight_dict[k] for k in loss_dict if k in self.weight_dict}
  116. losses.update(loss_dict)
  117. # ------------------ Regression loss ------------------
  118. box_targets_pos = box_targets[pos_inds]
  119. ## positive predictions
  120. box_preds_pos = box_preds.view(-1, 4)[pos_inds]
  121. reg_preds_pos = reg_preds.view(-1, 4)[pos_inds]
  122. ## anchor tensors
  123. anchors_tensors = torch.cat(anchors, dim=0)[None].repeat(bs, 1, 1)
  124. anchors_tensors_pos = anchors_tensors.view(-1, 2)[pos_inds]
  125. ## stride tensors
  126. stride_tensors = torch.cat(stride_tensors, dim=0)[None].repeat(bs, 1, 1)
  127. stride_tensors_pos = stride_tensors.view(-1, 1)[pos_inds]
  128. ## aux loss
  129. loss_dict = self.loss_bboxes(reg_preds_pos, box_preds_pos, box_targets_pos, anchors_tensors_pos, stride_tensors_pos, num_fgs)
  130. loss_dict = {k: loss_dict[k] * self.weight_dict[k] for k in loss_dict if k in self.weight_dict}
  131. losses.update(loss_dict)
  132. return losses
  133. def build_criterion(cfg, num_classes):
  134. criterion = Criterion(cfg, num_classes)
  135. return criterion