box_ops.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. from typing import List
  2. import math
  3. import numpy as np
  4. import torch
  5. from torchvision.ops.boxes import box_area
  6. # ------------------ Box ops ------------------
  7. def box_cxcywh_to_xyxy(x):
  8. x_c, y_c, w, h = x.unbind(-1)
  9. b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
  10. (x_c + 0.5 * w), (y_c + 0.5 * h)]
  11. return torch.stack(b, dim=-1)
  12. def box_xyxy_to_cxcywh(x):
  13. x0, y0, x1, y1 = x.unbind(-1)
  14. b = [(x0 + x1) / 2, (y0 + y1) / 2,
  15. (x1 - x0), (y1 - y0)]
  16. return torch.stack(b, dim=-1)
  17. def rescale_bboxes(bboxes, origin_size, ratio):
  18. # rescale bboxes
  19. if isinstance(ratio, float):
  20. bboxes /= ratio
  21. elif isinstance(ratio, List) and len(ratio) == 2:
  22. bboxes[..., [0, 2]] /= ratio[0]
  23. bboxes[..., [1, 3]] /= ratio[1]
  24. else:
  25. raise NotImplementedError("ratio should be a int or List[int, int] type.")
  26. # clip bboxes
  27. bboxes[..., [0, 2]] = np.clip(bboxes[..., [0, 2]], a_min=0., a_max=origin_size[0])
  28. bboxes[..., [1, 3]] = np.clip(bboxes[..., [1, 3]], a_min=0., a_max=origin_size[1])
  29. return bboxes
  30. def bbox2dist(anchor_points, bbox, reg_max):
  31. '''Transform bbox(xyxy) to dist(ltrb).'''
  32. x1y1, x2y2 = torch.split(bbox, 2, -1)
  33. lt = anchor_points - x1y1
  34. rb = x2y2 - anchor_points
  35. dist = torch.cat([lt, rb], -1).clamp(0, reg_max - 0.01)
  36. return dist
  37. def bbox2delta(proposals, gt, means=(0., 0., 0., 0.), stds=(1., 1., 1., 1.)):
  38. # hack for matcher
  39. if proposals.size() != gt.size():
  40. proposals = proposals[:, None]
  41. gt = gt[None]
  42. proposals = proposals.float()
  43. gt = gt.float()
  44. px, py, pw, ph = proposals.unbind(-1)
  45. gx, gy, gw, gh = gt.unbind(-1)
  46. dx = (gx - px) / (pw + 0.1)
  47. dy = (gy - py) / (ph + 0.1)
  48. dw = torch.log(gw / (pw + 0.1))
  49. dh = torch.log(gh / (ph + 0.1))
  50. deltas = torch.stack([dx, dy, dw, dh], dim=-1)
  51. means = deltas.new_tensor(means).unsqueeze(0)
  52. stds = deltas.new_tensor(stds).unsqueeze(0)
  53. deltas = deltas.sub_(means).div_(stds)
  54. return deltas
  55. # ------------------ IoU ops ------------------
  56. def box_iou(boxes1, boxes2):
  57. area1 = box_area(boxes1)
  58. area2 = box_area(boxes2)
  59. lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
  60. rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
  61. wh = (rb - lt).clamp(min=0) # [N,M,2]
  62. inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
  63. union = area1[:, None] + area2 - inter
  64. iou = inter / union
  65. return iou, union
  66. def generalized_box_iou(boxes1, boxes2):
  67. """
  68. Generalized IoU from https://giou.stanford.edu/
  69. The boxes should be in [x0, y0, x1, y1] format
  70. Returns a [N, M] pairwise matrix, where N = len(boxes1)
  71. and M = len(boxes2)
  72. """
  73. # degenerate boxes gives inf / nan results
  74. # so do an early check
  75. assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
  76. assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
  77. iou, union = box_iou(boxes1, boxes2)
  78. lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
  79. rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
  80. wh = (rb - lt).clamp(min=0) # [N,M,2]
  81. area = wh[:, :, 0] * wh[:, :, 1]
  82. return iou - (area - union) / area
  83. def get_ious(bboxes1,
  84. bboxes2,
  85. box_mode="xyxy",
  86. iou_type="iou"):
  87. """
  88. Compute iou loss of type ['iou', 'giou', 'linear_iou']
  89. Args:
  90. inputs (tensor): pred values
  91. targets (tensor): target values
  92. weight (tensor): loss weight
  93. box_mode (str): 'xyxy' or 'ltrb', 'ltrb' is currently supported.
  94. loss_type (str): 'giou' or 'iou' or 'linear_iou'
  95. reduction (str): reduction manner
  96. Returns:
  97. loss (tensor): computed iou loss.
  98. """
  99. if box_mode == "ltrb":
  100. bboxes1 = torch.cat((-bboxes1[..., :2], bboxes1[..., 2:]), dim=-1)
  101. bboxes2 = torch.cat((-bboxes2[..., :2], bboxes2[..., 2:]), dim=-1)
  102. elif box_mode != "xyxy":
  103. raise NotImplementedError
  104. eps = torch.finfo(torch.float32).eps
  105. bboxes1_area = (bboxes1[..., 2] - bboxes1[..., 0]).clamp_(min=0) \
  106. * (bboxes1[..., 3] - bboxes1[..., 1]).clamp_(min=0)
  107. bboxes2_area = (bboxes2[..., 2] - bboxes2[..., 0]).clamp_(min=0) \
  108. * (bboxes2[..., 3] - bboxes2[..., 1]).clamp_(min=0)
  109. w_intersect = (torch.min(bboxes1[..., 2], bboxes2[..., 2])
  110. - torch.max(bboxes1[..., 0], bboxes2[..., 0])).clamp_(min=0)
  111. h_intersect = (torch.min(bboxes1[..., 3], bboxes2[..., 3])
  112. - torch.max(bboxes1[..., 1], bboxes2[..., 1])).clamp_(min=0)
  113. area_intersect = w_intersect * h_intersect
  114. area_union = bboxes2_area + bboxes1_area - area_intersect
  115. ious = area_intersect / area_union.clamp(min=eps)
  116. if iou_type == "iou":
  117. return ious
  118. elif iou_type == "giou":
  119. g_w_intersect = torch.max(bboxes1[..., 2], bboxes2[..., 2]) \
  120. - torch.min(bboxes1[..., 0], bboxes2[..., 0])
  121. g_h_intersect = torch.max(bboxes1[..., 3], bboxes2[..., 3]) \
  122. - torch.min(bboxes1[..., 1], bboxes2[..., 1])
  123. ac_uion = g_w_intersect * g_h_intersect
  124. gious = ious - (ac_uion - area_union) / ac_uion.clamp(min=eps)
  125. return gious
  126. else:
  127. raise NotImplementedError
  128. # copy from YOLOv5
  129. def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
  130. # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)
  131. # Get the coordinates of bounding boxes
  132. if xywh: # transform from xywh to xyxy
  133. (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
  134. w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
  135. b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
  136. b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
  137. else: # x1, y1, x2, y2 = box1
  138. b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
  139. b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
  140. w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
  141. w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
  142. # Intersection area
  143. inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * \
  144. (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(0)
  145. # Union Area
  146. union = w1 * h1 + w2 * h2 - inter + eps
  147. # IoU
  148. iou = inter / union
  149. if CIoU or DIoU or GIoU:
  150. cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1) # convex (smallest enclosing box) width
  151. ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1) # convex height
  152. if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
  153. c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
  154. rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2
  155. if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
  156. v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
  157. with torch.no_grad():
  158. alpha = v / (v - iou + (1 + eps))
  159. return iou - (rho2 / c2 + v * alpha) # CIoU
  160. return iou - rho2 / c2 # DIoU
  161. c_area = cw * ch + eps # convex area
  162. return iou - (c_area - union) / c_area # GIoU https://arxiv.org/pdf/1902.09630.pdf
  163. return iou # IoU
  164. if __name__ == '__main__':
  165. box1 = torch.tensor([[10, 10, 20, 20]])
  166. box2 = torch.tensor([[15, 15, 20, 20]])
  167. iou = box_iou(box1, box2)
  168. print(iou)