loss_utils.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. import math
  2. import torch
  3. import torch.nn.functional as F
  4. import torch.distributed as dist
  5. from torchvision.ops.boxes import box_area
  6. # ------------------------- For loss -------------------------
  7. ## FocalLoss
  8. def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
  9. """
  10. Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
  11. Args:
  12. inputs: A float tensor of arbitrary shape.
  13. The predictions for each example.
  14. targets: A float tensor with the same shape as inputs. Stores the binary
  15. classification label for each element in inputs
  16. (0 for the negative class and 1 for the positive class).
  17. alpha: (optional) Weighting factor in range (0,1) to balance
  18. positive vs negative examples. Default = -1 (no weighting).
  19. gamma: Exponent of the modulating factor (1 - p_t) to
  20. balance easy vs hard examples.
  21. Returns:
  22. Loss tensor
  23. """
  24. prob = inputs.sigmoid()
  25. ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
  26. p_t = prob * targets + (1 - prob) * (1 - targets)
  27. loss = ce_loss * ((1 - p_t) ** gamma)
  28. if alpha >= 0:
  29. alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
  30. loss = alpha_t * loss
  31. return loss.mean(1).sum() / num_boxes
  32. ## Varifocal loss
  33. def varifocal_loss_with_logits(pred_logits,
  34. gt_score,
  35. label,
  36. normalizer=1.0,
  37. alpha=0.75,
  38. gamma=2.0):
  39. pred_score = F.sigmoid(pred_logits)
  40. weight = alpha * pred_score.pow(gamma) * (1 - label) + gt_score * label
  41. loss = F.binary_cross_entropy_with_logits(pred_logits, gt_score, reduction='none')
  42. loss = loss * weight
  43. return loss.mean(1).sum() / normalizer
  44. ## InverseSigmoid
  45. def inverse_sigmoid(x, eps=1e-5):
  46. x = x.clamp(min=0, max=1)
  47. x1 = x.clamp(min=eps)
  48. x2 = (1 - x).clamp(min=eps)
  49. return torch.log(x1/x2)
  50. ## GIoU loss
  51. class GIoULoss(object):
  52. """ Modified GIoULoss from Paddle-Paddle"""
  53. def __init__(self, eps=1e-10, reduction='none'):
  54. self.eps = eps
  55. self.reduction = reduction
  56. assert reduction in ('none', 'mean', 'sum')
  57. def bbox_overlap(self, box1, box2, eps=1e-10):
  58. """calculate the iou of box1 and box2
  59. Args:
  60. box1 (Tensor): box1 with the shape (..., 4)
  61. box2 (Tensor): box1 with the shape (..., 4)
  62. eps (float): epsilon to avoid divide by zero
  63. Return:
  64. iou (Tensor): iou of box1 and box2
  65. overlap (Tensor): overlap of box1 and box2
  66. union (Tensor): union of box1 and box2
  67. """
  68. x1, y1, x2, y2 = box1
  69. x1g, y1g, x2g, y2g = box2
  70. xkis1 = torch.max(x1, x1g)
  71. ykis1 = torch.max(y1, y1g)
  72. xkis2 = torch.min(x2, x2g)
  73. ykis2 = torch.min(y2, y2g)
  74. w_inter = (xkis2 - xkis1).clip(0)
  75. h_inter = (ykis2 - ykis1).clip(0)
  76. overlap = w_inter * h_inter
  77. area1 = (x2 - x1) * (y2 - y1)
  78. area2 = (x2g - x1g) * (y2g - y1g)
  79. union = area1 + area2 - overlap + eps
  80. iou = overlap / union
  81. return iou, overlap, union
  82. def __call__(self, pbox, gbox):
  83. # x1, y1, x2, y2 = torch.split(pbox, 4, dim=-1)
  84. # x1g, y1g, x2g, y2g = torch.split(gbox, 4, dim=-1)
  85. x1, y1, x2, y2 = torch.chunk(pbox, 4, dim=-1)
  86. x1g, y1g, x2g, y2g = torch.chunk(gbox, 4, dim=-1)
  87. box1 = [x1, y1, x2, y2]
  88. box2 = [x1g, y1g, x2g, y2g]
  89. iou, _, union = self.bbox_overlap(box1, box2, self.eps)
  90. xc1 = torch.min(x1, x1g)
  91. yc1 = torch.min(y1, y1g)
  92. xc2 = torch.max(x2, x2g)
  93. yc2 = torch.max(y2, y2g)
  94. area_c = (xc2 - xc1) * (yc2 - yc1) + self.eps
  95. miou = iou - ((area_c - union) / area_c)
  96. giou = 1 - miou
  97. if self.reduction == 'none':
  98. loss = giou
  99. elif self.reduction == 'sum':
  100. loss = giou.sum()
  101. elif self.reduction == 'mean':
  102. loss = giou.mean()
  103. return loss
  104. # ------------------------- For box -------------------------
  105. def box_cxcywh_to_xyxy(x):
  106. x_c, y_c, w, h = x.unbind(-1)
  107. b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
  108. (x_c + 0.5 * w), (y_c + 0.5 * h)]
  109. return torch.stack(b, dim=-1)
  110. def box_xyxy_to_cxcywh(x):
  111. x0, y0, x1, y1 = x.unbind(-1)
  112. b = [(x0 + x1) / 2, (y0 + y1) / 2,
  113. (x1 - x0), (y1 - y0)]
  114. return torch.stack(b, dim=-1)
  115. def box_iou(boxes1, boxes2):
  116. area1 = box_area(boxes1)
  117. area2 = box_area(boxes2)
  118. lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
  119. rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
  120. wh = (rb - lt).clamp(min=0) # [N,M,2]
  121. inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
  122. union = area1[:, None] + area2 - inter
  123. iou = inter / union
  124. return iou, union
  125. def generalized_box_iou(boxes1, boxes2):
  126. """
  127. Generalized IoU from https://giou.stanford.edu/
  128. The boxes should be in [x0, y0, x1, y1] format
  129. Returns a [N, M] pairwise matrix, where N = len(boxes1)
  130. and M = len(boxes2)
  131. """
  132. # degenerate boxes gives inf / nan results
  133. # so do an early check
  134. assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
  135. assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
  136. iou, union = box_iou(boxes1, boxes2)
  137. lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
  138. rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
  139. wh = (rb - lt).clamp(min=0) # [N,M,2]
  140. area = wh[:, :, 0] * wh[:, :, 1]
  141. return iou - (area - union) / area
  142. def bbox_iou(box1, box2, giou=False, diou=False, ciou=False, eps=1e-9):
  143. """Modified from Paddle-paddle
  144. Args:
  145. box1 (list): [x, y, w, h], all have the shape [b, na, h, w, 1]
  146. box2 (list): [x, y, w, h], all have the shape [b, na, h, w, 1]
  147. giou (bool): whether use giou or not, default False
  148. diou (bool): whether use diou or not, default False
  149. ciou (bool): whether use ciou or not, default False
  150. eps (float): epsilon to avoid divide by zero
  151. Return:
  152. iou (Tensor): iou of box1 and box1, with the shape [b, na, h, w, 1]
  153. """
  154. px1, py1, px2, py2 = torch.chunk(box1, 4, -1)
  155. gx1, gy1, gx2, gy2 = torch.chunk(box2, 4, -1)
  156. x1 = torch.max(px1, gx1)
  157. y1 = torch.max(py1, gy1)
  158. x2 = torch.min(px2, gx2)
  159. y2 = torch.min(py2, gy2)
  160. overlap = ((x2 - x1).clamp(0)) * ((y2 - y1).clamp(0))
  161. area1 = (px2 - px1) * (py2 - py1)
  162. area1 = area1.clamp(0)
  163. area2 = (gx2 - gx1) * (gy2 - gy1)
  164. area2 = area2.clamp(0)
  165. union = area1 + area2 - overlap + eps
  166. iou = overlap / union
  167. if giou or ciou or diou:
  168. # convex w, h
  169. cw = torch.max(px2, gx2) - torch.min(px1, gx1)
  170. ch = torch.max(py2, gy2) - torch.min(py1, gy1)
  171. if giou:
  172. c_area = cw * ch + eps
  173. return iou - (c_area - union) / c_area
  174. else:
  175. # convex diagonal squared
  176. c2 = cw**2 + ch**2 + eps
  177. # center distance
  178. rho2 = ((px1 + px2 - gx1 - gx2)**2 + (py1 + py2 - gy1 - gy2)**2) / 4
  179. if diou:
  180. return iou - rho2 / c2
  181. else:
  182. w1, h1 = px2 - px1, py2 - py1 + eps
  183. w2, h2 = gx2 - gx1, gy2 - gy1 + eps
  184. delta = torch.atan(w1 / h1) - torch.atan(w2 / h2)
  185. v = (4 / math.pi**2) * torch.pow(delta, 2)
  186. alpha = v / (1 + eps - iou + v)
  187. alpha.requires_grad_ = False
  188. return iou - (rho2 / c2 + v * alpha)
  189. else:
  190. return iou
  191. # ------------------------- For distributed -------------------------
  192. def is_dist_avail_and_initialized():
  193. if not dist.is_available():
  194. return False
  195. if not dist.is_initialized():
  196. return False
  197. return True
  198. def get_world_size():
  199. if not is_dist_avail_and_initialized():
  200. return 1
  201. return dist.get_world_size()