loss.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. import torch
  2. import torch.nn.functional as F
  3. from .matcher import AlignedSimOTA
  4. from utils.box_ops import get_ious
  5. from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
  6. class Criterion(object):
  7. def __init__(self,
  8. cfg,
  9. device,
  10. num_classes=80):
  11. self.cfg = cfg
  12. self.device = device
  13. self.num_classes = num_classes
  14. # loss weight
  15. self.loss_cls_weight = cfg['loss_cls_weight']
  16. self.loss_box_weight = cfg['loss_box_weight']
  17. # matcher
  18. matcher_config = cfg['matcher']
  19. self.matcher = AlignedSimOTA(
  20. num_classes=num_classes,
  21. soft_center_radius=matcher_config['soft_center_radius'],
  22. topk=matcher_config['topk_candicate'],
  23. iou_weight=matcher_config['iou_weight']
  24. )
  25. def loss_classes(self, pred_cls, target, beta=2.0):
  26. """
  27. Quality Focal Loss
  28. pred_cls: (torch.Tensor): [N, C]。
  29. target: (tuple([torch.Tensor], [torch.Tensor])): label -> (N,), score -> (N,)
  30. """
  31. label, score = target
  32. pred_sigmoid = pred_cls.sigmoid()
  33. scale_factor = pred_sigmoid
  34. zerolabel = scale_factor.new_zeros(pred_cls.shape)
  35. ce_loss = F.binary_cross_entropy_with_logits(
  36. pred_cls, zerolabel, reduction='none') * scale_factor.pow(beta)
  37. bg_class_ind = pred_cls.shape[-1]
  38. pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1)
  39. pos_label = label[pos].long()
  40. scale_factor = score[pos] - pred_sigmoid[pos, pos_label]
  41. ce_loss[pos, pos_label] = F.binary_cross_entropy_with_logits(
  42. pred_cls[pos, pos_label], score[pos],
  43. reduction='none') * scale_factor.abs().pow(beta)
  44. return ce_loss
  45. def loss_bboxes(self, pred_box, gt_box):
  46. # regression loss
  47. ious = get_ious(pred_box, gt_box, "xyxy", 'giou')
  48. loss_box = 1.0 - ious
  49. return loss_box
  50. def __call__(self, outputs, targets):
  51. """
  52. outputs['pred_cls']: List(Tensor) [B, M, C]
  53. outputs['pred_box']: List(Tensor) [B, M, 4]
  54. outputs['strides']: List(Int) [8, 16, 32] output stride
  55. targets: (List) [dict{'boxes': [...],
  56. 'labels': [...],
  57. 'orig_size': ...}, ...]
  58. """
  59. bs = outputs['pred_cls'][0].shape[0]
  60. device = outputs['pred_cls'][0].device
  61. fpn_strides = outputs['strides']
  62. anchors = outputs['anchors']
  63. # preds: [B, M, C]
  64. cls_preds = torch.cat(outputs['pred_cls'], dim=1)
  65. box_preds = torch.cat(outputs['pred_box'], dim=1)
  66. cls_targets = []
  67. box_targets = []
  68. assign_metrics = []
  69. for batch_idx in range(bs):
  70. tgt_labels = targets[batch_idx]["labels"].to(device) # [N,]
  71. tgt_bboxes = targets[batch_idx]["boxes"].to(device) # [N, 4]
  72. # label assignment
  73. assigned_result = self.matcher(fpn_strides=fpn_strides,
  74. anchors=anchors,
  75. pred_cls=cls_preds[batch_idx].detach(),
  76. pred_box=box_preds[batch_idx].detach(),
  77. gt_labels=tgt_labels,
  78. gt_bboxes=tgt_bboxes
  79. )
  80. cls_targets.append(assigned_result['assigned_labels'])
  81. box_targets.append(assigned_result['assigned_bboxes'])
  82. assign_metrics.append(assigned_result['assign_metrics'])
  83. cls_targets = torch.cat(cls_targets, dim=0)
  84. box_targets = torch.cat(box_targets, dim=0)
  85. assign_metrics = torch.cat(assign_metrics, dim=0)
  86. # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
  87. bg_class_ind = self.num_classes
  88. pos_inds = ((cls_targets >= 0)
  89. & (cls_targets < bg_class_ind)).nonzero().squeeze(1)
  90. # num_fgs = assign_metrics.sum()
  91. num_fgs = pos_inds.size(0)
  92. if is_dist_avail_and_initialized():
  93. torch.distributed.all_reduce(num_fgs)
  94. num_fgs = max(num_fgs / get_world_size(), 1.0)
  95. # cls loss
  96. cls_preds = cls_preds.view(-1, self.num_classes)
  97. loss_cls = self.loss_classes(cls_preds, (cls_targets, assign_metrics))
  98. loss_cls = loss_cls.sum() / num_fgs
  99. # regression loss
  100. box_preds_pos = box_preds.view(-1, 4)[pos_inds]
  101. box_targets_pos = box_targets[pos_inds]
  102. loss_box = self.loss_bboxes(box_preds_pos, box_targets_pos)
  103. loss_box = loss_box.sum() / box_preds_pos.shape[0]
  104. # total loss
  105. losses = self.loss_cls_weight * loss_cls + \
  106. self.loss_box_weight * loss_box
  107. loss_dict = dict(
  108. loss_cls = loss_cls,
  109. loss_box = loss_box,
  110. losses = losses
  111. )
  112. return loss_dict
  113. def build_criterion(cfg, device, num_classes):
  114. criterion = Criterion(
  115. cfg=cfg,
  116. device=device,
  117. num_classes=num_classes
  118. )
  119. return criterion
  120. if __name__ == "__main__":
  121. pass