loss_utils.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. import math
  2. import torch
  3. import torch.nn.functional as F
  4. import torch.distributed as dist
  5. from torchvision.ops.boxes import box_area
  6. # ------------------------- For loss -------------------------
  7. ## FocalLoss
  8. def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
  9. """
  10. Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
  11. Args:
  12. inputs: A float tensor of arbitrary shape.
  13. The predictions for each example.
  14. targets: A float tensor with the same shape as inputs. Stores the binary
  15. classification label for each element in inputs
  16. (0 for the negative class and 1 for the positive class).
  17. alpha: (optional) Weighting factor in range (0,1) to balance
  18. positive vs negative examples. Default = -1 (no weighting).
  19. gamma: Exponent of the modulating factor (1 - p_t) to
  20. balance easy vs hard examples.
  21. Returns:
  22. Loss tensor
  23. """
  24. prob = inputs.sigmoid()
  25. ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
  26. p_t = prob * targets + (1 - prob) * (1 - targets)
  27. loss = ce_loss * ((1 - p_t) ** gamma)
  28. if alpha >= 0:
  29. alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
  30. loss = alpha_t * loss
  31. return loss.mean(1).sum() / num_boxes
  32. # ------------------------- For box -------------------------
  33. def box_cxcywh_to_xyxy(x):
  34. x_c, y_c, w, h = x.unbind(-1)
  35. b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
  36. (x_c + 0.5 * w), (y_c + 0.5 * h)]
  37. return torch.stack(b, dim=-1)
  38. def box_xyxy_to_cxcywh(x):
  39. x0, y0, x1, y1 = x.unbind(-1)
  40. b = [(x0 + x1) / 2, (y0 + y1) / 2,
  41. (x1 - x0), (y1 - y0)]
  42. return torch.stack(b, dim=-1)
  43. def bbox2delta(proposals, gt, means=(0., 0., 0., 0.), stds=(1., 1., 1., 1.)):
  44. # hack for matcher
  45. if proposals.size() != gt.size():
  46. proposals = proposals[:, None]
  47. gt = gt[None]
  48. proposals = proposals.float()
  49. gt = gt.float()
  50. px, py, pw, ph = proposals.unbind(-1)
  51. gx, gy, gw, gh = gt.unbind(-1)
  52. dx = (gx - px) / (pw + 0.1)
  53. dy = (gy - py) / (ph + 0.1)
  54. dw = torch.log(gw / (pw + 0.1))
  55. dh = torch.log(gh / (ph + 0.1))
  56. deltas = torch.stack([dx, dy, dw, dh], dim=-1)
  57. means = deltas.new_tensor(means).unsqueeze(0)
  58. stds = deltas.new_tensor(stds).unsqueeze(0)
  59. deltas = deltas.sub_(means).div_(stds)
  60. return deltas
  61. def box_iou(boxes1, boxes2):
  62. area1 = box_area(boxes1)
  63. area2 = box_area(boxes2)
  64. lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
  65. rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
  66. wh = (rb - lt).clamp(min=0) # [N,M,2]
  67. inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
  68. union = area1[:, None] + area2 - inter
  69. iou = inter / union
  70. return iou, union
  71. def generalized_box_iou(boxes1, boxes2):
  72. """
  73. Generalized IoU from https://giou.stanford.edu/
  74. The boxes should be in [x0, y0, x1, y1] format
  75. Returns a [N, M] pairwise matrix, where N = len(boxes1)
  76. and M = len(boxes2)
  77. """
  78. # degenerate boxes gives inf / nan results
  79. # so do an early check
  80. assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
  81. assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
  82. iou, union = box_iou(boxes1, boxes2)
  83. lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
  84. rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
  85. wh = (rb - lt).clamp(min=0) # [N,M,2]
  86. area = wh[:, :, 0] * wh[:, :, 1]
  87. return iou - (area - union) / area
  88. # ------------------------- For distributed -------------------------
  89. def is_dist_avail_and_initialized():
  90. if not dist.is_available():
  91. return False
  92. if not dist.is_initialized():
  93. return False
  94. return True
  95. def get_world_size():
  96. if not is_dist_avail_and_initialized():
  97. return 1
  98. return dist.get_world_size()