box_ops.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. """
  3. Utilities for bounding box manipulation and GIoU.
  4. """
  5. import torch
  6. from torchvision.ops.boxes import box_area
  def get_ious(bboxes1,
               bboxes2,
               box_mode="xyxy",
               iou_type="iou"):
      """
      Compute element-wise IoU or GIoU between two aligned sets of boxes.

      Args:
          bboxes1 (tensor): boxes whose last dimension has size 4.
          bboxes2 (tensor): boxes paired element-wise with ``bboxes1``
              (same or broadcastable shape).
          box_mode (str): 'xyxy' for (x1, y1, x2, y2) corners, or 'ltrb' for
              (left, top, right, bottom) distances from a reference point
              shared by each pair of boxes.
          iou_type (str): 'iou' or 'giou'.
      Returns:
          ious (tensor): IoU (or GIoU) for every box pair; the last dimension
          of the inputs is reduced away.
      Raises:
          NotImplementedError: if ``box_mode`` or ``iou_type`` is unsupported.
      """
      if box_mode == "ltrb":
          # (l, t, r, b) -> (x1, y1, x2, y2) with the reference point at the origin.
          bboxes1 = torch.cat((-bboxes1[..., :2], bboxes1[..., 2:]), dim=-1)
          bboxes2 = torch.cat((-bboxes2[..., :2], bboxes2[..., 2:]), dim=-1)
      elif box_mode != "xyxy":
          raise NotImplementedError("unsupported box_mode: {!r}".format(box_mode))
      # Validate up front instead of after all the arithmetic below.
      if iou_type not in ("iou", "giou"):
          raise NotImplementedError("unsupported iou_type: {!r}".format(iou_type))

      eps = torch.finfo(torch.float32).eps
      # Negative extents (inverted boxes) are treated as empty.
      bboxes1_area = (bboxes1[..., 2] - bboxes1[..., 0]).clamp_(min=0) \
          * (bboxes1[..., 3] - bboxes1[..., 1]).clamp_(min=0)
      bboxes2_area = (bboxes2[..., 2] - bboxes2[..., 0]).clamp_(min=0) \
          * (bboxes2[..., 3] - bboxes2[..., 1]).clamp_(min=0)

      w_intersect = (torch.min(bboxes1[..., 2], bboxes2[..., 2])
                     - torch.max(bboxes1[..., 0], bboxes2[..., 0])).clamp_(min=0)
      h_intersect = (torch.min(bboxes1[..., 3], bboxes2[..., 3])
                     - torch.max(bboxes1[..., 1], bboxes2[..., 1])).clamp_(min=0)
      area_intersect = w_intersect * h_intersect
      area_union = bboxes2_area + bboxes1_area - area_intersect
      # eps keeps a zero union (two empty boxes) from producing NaN.
      ious = area_intersect / area_union.clamp(min=eps)
      if iou_type == "iou":
          return ious

      # GIoU: penalise by the part of the smallest enclosing box that the
      # union does not cover. The extents are clamped like the ones above so
      # inverted boxes cannot produce a negative enclosing area.
      enclose_w = (torch.max(bboxes1[..., 2], bboxes2[..., 2])
                   - torch.min(bboxes1[..., 0], bboxes2[..., 0])).clamp_(min=0)
      enclose_h = (torch.max(bboxes1[..., 3], bboxes2[..., 3])
                   - torch.min(bboxes1[..., 1], bboxes2[..., 1])).clamp_(min=0)
      enclose_area = enclose_w * enclose_h
      return ious - (enclose_area - area_union) / enclose_area.clamp(min=eps)
  def delta2bbox(proposals,
                 deltas,
                 max_shape=None,
                 wh_ratio_clip=16 / 1000,
                 clip_border=True,
                 add_ctr_clamp=False,
                 ctr_clamp=32):
      """Decode ``(dx, dy, dw, dh)`` deltas relative to center-form proposals.

      The decoded center is ``(cx, cy) + (w, h) * (dx, dy)`` and the decoded
      size is ``(w, h) * exp((dw, dh))``, after the clamping described below.

      Args:
          proposals (Tensor): reference boxes, last dim ``(cx, cy, w, h)``.
          deltas (Tensor): offsets, last dim ``(dx, dy, dw, dh)``; must
              broadcast against ``proposals``.
          max_shape (sequence, optional): ``max_shape[0]`` bounds y and
              ``max_shape[1]`` bounds x when clipping.
          wh_ratio_clip (float or Tensor): ``dw``/``dh`` are clamped to
              ``abs(log(wh_ratio_clip))``.
          clip_border (bool): clip the boxes to ``[0, max_shape]`` when
              ``max_shape`` is given.
          add_ctr_clamp (bool): if true, clamp the center shift (in the same
              units as ``w``/``h``) to ``[-ctr_clamp, ctr_clamp]`` and clamp
              ``dw``/``dh`` from above only.
          ctr_clamp (float): bound used when ``add_ctr_clamp`` is true.

      Returns:
          Tensor: boxes with last dim ``(x1, y1, x2, y2)``.

      NOTE(review): bbox2delta divides by ``w + 0.1`` and applies means/stds,
      neither of which is undone here, so this is not its exact inverse.
      """
      pxy = proposals[..., :2]
      pwh = proposals[..., 2:]
      dxy = deltas[..., :2]
      dwh = deltas[..., 2:]

      # Center shift, scaled by the proposal width/height.
      dxy_wh = pwh * dxy
      # Same value as building a default-dtype tensor and taking |log|, but
      # computed once on the host: no per-call tensor allocation and copy to
      # deltas.device just to obtain a scalar bound.
      max_ratio = float(torch.as_tensor(wh_ratio_clip).log().abs())
      if add_ctr_clamp:
          dxy_wh = torch.clamp(dxy_wh, max=ctr_clamp, min=-ctr_clamp)
          dwh = torch.clamp(dwh, max=max_ratio)
      else:
          dwh = dwh.clamp(min=-max_ratio, max=max_ratio)

      gxy = pxy + dxy_wh
      gwh = pwh * dwh.exp()
      x1y1 = gxy - (gwh * 0.5)
      x2y2 = gxy + (gwh * 0.5)
      bboxes = torch.cat([x1y1, x2y2], dim=-1)
      if clip_border and max_shape is not None:
          # Step slices are views, so clamp_ writes through to bboxes.
          bboxes[..., 0::2].clamp_(min=0).clamp_(max=max_shape[1])
          bboxes[..., 1::2].clamp_(min=0).clamp_(max=max_shape[0])
      return bboxes
  def bbox2delta(proposals, gt, means=(0., 0., 0., 0.), stds=(1., 1., 1., 1.)):
      """Encode ``gt`` boxes as normalized offsets from ``proposals``.

      Both inputs carry ``(x, y, w, h)`` on their last dimension (presumably
      center-based, like the proposals consumed by delta2bbox - confirm) and
      are cast to float32. The raw encoding is::

          dx = (gx - px) / (pw + 0.1)        dw = log(gw / (pw + 0.1))
          dy = (gy - py) / (ph + 0.1)        dh = log(gh / (ph + 0.1))

      which is then normalized as ``(delta - means) / stds``.

      When the two shapes differ (the matcher case) the inputs are broadcast
      pairwise, ``proposals[:, None]`` against ``gt[None]``, giving an
      ``[N, M, 4]`` result. NOTE(review): with equal shapes the encoding is
      element-wise instead, even if a pairwise result was intended - confirm
      callers never rely on pairwise output when N == M.
      """
      if gt.size() != proposals.size():
          # Pairwise broadcast for the matcher: [N, 1, 4] against [1, M, 4].
          proposals, gt = proposals[:, None], gt[None]

      px, py, pw, ph = proposals.float().unbind(-1)
      gx, gy, gw, gh = gt.float().unbind(-1)

      # The +0.1 keeps zero-size proposals from dividing by zero.
      denom_w = pw + 0.1
      denom_h = ph + 0.1
      deltas = torch.stack(
          [(gx - px) / denom_w,
           (gy - py) / denom_h,
           torch.log(gw / denom_w),
           torch.log(gh / denom_h)],
          dim=-1,
      )

      mean_t = deltas.new_tensor(means).unsqueeze(0)
      std_t = deltas.new_tensor(stds).unsqueeze(0)
      return deltas.sub_(mean_t).div_(std_t)
  def box_cxcywh_to_xyxy(x):
      """Convert boxes from ``(cx, cy, w, h)`` to corner form ``(x0, y0, x1, y1)``.

      Operates on the last dimension, so any leading batch shape is kept.
      """
      cx, cy, width, height = x.unbind(-1)
      half_w = 0.5 * width
      half_h = 0.5 * height
      corners = (cx - half_w, cy - half_h, cx + half_w, cy + half_h)
      return torch.stack(corners, dim=-1)
  def box_xyxy_to_cxcywh(x):
      """Convert boxes from corner form ``(x0, y0, x1, y1)`` to ``(cx, cy, w, h)``.

      Operates on the last dimension, so any leading batch shape is kept.
      """
      left, top, right, bottom = x.unbind(-1)
      center_x = (left + right) / 2
      center_y = (top + bottom) / 2
      width = right - left
      height = bottom - top
      return torch.stack([center_x, center_y, width, height], dim=-1)
  # modified from torchvision to also return the union
  def box_iou(boxes1, boxes2):
      """Pairwise IoU between two sets of ``(x0, y0, x1, y1)`` boxes.

      Args:
          boxes1 (Tensor): [N, 4] boxes.
          boxes2 (Tensor): [M, 4] boxes.

      Returns:
          tuple[Tensor, Tensor]: ``(iou, union)``, both [N, M]. ``union`` is the
          true union area and may be 0 for pairs of zero-area boxes; such pairs
          get an IoU of 0 instead of NaN.
      """
      area1 = box_area(boxes1)
      area2 = box_area(boxes2)

      lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
      rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]
      wh = (rb - lt).clamp(min=0)  # [N,M,2]
      inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

      union = area1[:, None] + area2 - inter
      # Guard only the denominator. Overwriting `union` itself (as before) made
      # the returned union 1 instead of 0, which skews the GIoU penalty
      # (e.g. two disjoint point boxes scored 0 instead of -1).
      safe_union = torch.where(union == 0, torch.ones_like(union), union)
      iou = inter / safe_union
      return iou, union
  def generalized_box_iou(boxes1, boxes2):
      """
      Pairwise Generalized IoU (https://giou.stanford.edu/).

      Inputs are [N, 4] and [M, 4] boxes in (x0, y0, x1, y1) format; the result
      is an [N, M] matrix with N = len(boxes1) and M = len(boxes2).

      Raises AssertionError if any box has x1 < x0 or y1 < y0.
      NOTE(review): zero-size boxes pass that check, but a pair whose enclosing
      box has zero area still divides by zero below.
      """
      # Inverted boxes would give inf / nan results, so reject them up front.
      for boxes in (boxes1, boxes2):
          assert (boxes[:, 2:] >= boxes[:, :2]).all()

      iou, union = box_iou(boxes1, boxes2)

      # Smallest box enclosing each (i, j) pair.
      top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2])
      bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
      extent = (bottom_right - top_left).clamp(min=0)  # [N, M, 2]
      enclosing_area = extent[..., 0] * extent[..., 1]

      # GIoU = IoU - (area(C) - area(A union B)) / area(C)
      return iou - (enclosing_area - union) / enclosing_area
  def masks_to_boxes(masks):
      """Compute the bounding boxes around the provided masks.

      The masks should be in format [N, H, W] where N is the number of masks,
      (H, W) are the spatial dimensions.

      Returns a [N, 4] tensor with the boxes in xyxy format, using inclusive
      pixel indices (a single pixel at row r, column c gives [c, r, c, r]).
      A mask with no foreground pixels yields the placeholder [1e8, 1e8, 0, 0].
      NOTE(review): mask values are multiplied by the coordinates, so masks
      are presumably binary (0/1).
      """
      if masks.numel() == 0:
          return torch.zeros((0, 4), device=masks.device)

      h, w = masks.shape[-2:]
      # Coordinate ramps that broadcast like a [1, H, W] grid. They live on
      # masks.device (CPU-only ramps fail for CUDA masks), and broadcasting
      # replaces torch.meshgrid, which warns when `indexing` is omitted.
      y = torch.arange(0, h, dtype=torch.float, device=masks.device).view(1, h, 1)
      x = torch.arange(0, w, dtype=torch.float, device=masks.device).view(1, 1, w)
      background = ~(masks.bool())

      x_mask = masks * x
      x_max = x_mask.flatten(1).max(-1)[0]
      # Background is filled with a large value so it never wins the min.
      x_min = x_mask.masked_fill(background, 1e8).flatten(1).min(-1)[0]

      y_mask = masks * y
      y_max = y_mask.flatten(1).max(-1)[0]
      y_min = y_mask.masked_fill(background, 1e8).flatten(1).min(-1)[0]

      return torch.stack([x_min, y_min, x_max, y_max], 1)