box_ops.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. """
  3. Utilities for bounding box manipulation and GIoU.
  4. """
  5. import torch
  6. from torchvision.ops.boxes import box_area
  7. def get_ious(bboxes1,
  8. bboxes2,
  9. box_mode="xyxy",
  10. iou_type="iou"):
  11. """
  12. Compute iou loss of type ['iou', 'giou', 'linear_iou']
  13. Args:
  14. inputs (tensor): pred values
  15. targets (tensor): target values
  16. weight (tensor): loss weight
  17. box_mode (str): 'xyxy' or 'ltrb', 'ltrb' is currently supported.
  18. loss_type (str): 'giou' or 'iou' or 'linear_iou'
  19. reduction (str): reduction manner
  20. Returns:
  21. loss (tensor): computed iou loss.
  22. """
  23. if box_mode == "ltrb":
  24. bboxes1 = torch.cat((-bboxes1[..., :2], bboxes1[..., 2:]), dim=-1)
  25. bboxes2 = torch.cat((-bboxes2[..., :2], bboxes2[..., 2:]), dim=-1)
  26. elif box_mode != "xyxy":
  27. raise NotImplementedError
  28. eps = torch.finfo(torch.float32).eps
  29. bboxes1_area = (bboxes1[..., 2] - bboxes1[..., 0]).clamp_(min=0) \
  30. * (bboxes1[..., 3] - bboxes1[..., 1]).clamp_(min=0)
  31. bboxes2_area = (bboxes2[..., 2] - bboxes2[..., 0]).clamp_(min=0) \
  32. * (bboxes2[..., 3] - bboxes2[..., 1]).clamp_(min=0)
  33. w_intersect = (torch.min(bboxes1[..., 2], bboxes2[..., 2])
  34. - torch.max(bboxes1[..., 0], bboxes2[..., 0])).clamp_(min=0)
  35. h_intersect = (torch.min(bboxes1[..., 3], bboxes2[..., 3])
  36. - torch.max(bboxes1[..., 1], bboxes2[..., 1])).clamp_(min=0)
  37. area_intersect = w_intersect * h_intersect
  38. area_union = bboxes2_area + bboxes1_area - area_intersect
  39. ious = area_intersect / area_union.clamp(min=eps)
  40. if iou_type == "iou":
  41. return ious
  42. elif iou_type == "giou":
  43. g_w_intersect = torch.max(bboxes1[..., 2], bboxes2[..., 2]) \
  44. - torch.min(bboxes1[..., 0], bboxes2[..., 0])
  45. g_h_intersect = torch.max(bboxes1[..., 3], bboxes2[..., 3]) \
  46. - torch.min(bboxes1[..., 1], bboxes2[..., 1])
  47. ac_uion = g_w_intersect * g_h_intersect
  48. gious = ious - (ac_uion - area_union) / ac_uion.clamp(min=eps)
  49. return gious
  50. else:
  51. raise NotImplementedError
  52. def box_cxcywh_to_xyxy(x):
  53. x_c, y_c, w, h = x.unbind(-1)
  54. b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
  55. (x_c + 0.5 * w), (y_c + 0.5 * h)]
  56. return torch.stack(b, dim=-1)
  57. def box_xyxy_to_cxcywh(x):
  58. x0, y0, x1, y1 = x.unbind(-1)
  59. b = [(x0 + x1) / 2, (y0 + y1) / 2,
  60. (x1 - x0), (y1 - y0)]
  61. return torch.stack(b, dim=-1)
  62. # modified from torchvision to also return the union
  63. def box_iou(boxes1, boxes2):
  64. area1 = box_area(boxes1)
  65. area2 = box_area(boxes2)
  66. lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
  67. rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
  68. wh = (rb - lt).clamp(min=0) # [N,M,2]
  69. inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
  70. union = area1[:, None] + area2 - inter
  71. union[union == 0.0] = 1.0
  72. iou = inter / union
  73. return iou, union
  74. def generalized_box_iou(boxes1, boxes2):
  75. """
  76. Generalized IoU from https://giou.stanford.edu/
  77. The boxes should be in [x0, y0, x1, y1] format
  78. Returns a [N, M] pairwise matrix, where N = len(boxes1)
  79. and M = len(boxes2)
  80. """
  81. # degenerate boxes gives inf / nan results
  82. # so do an early check
  83. assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
  84. assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
  85. iou, union = box_iou(boxes1, boxes2)
  86. lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
  87. rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
  88. wh = (rb - lt).clamp(min=0) # [N,M,2]
  89. area = wh[:, :, 0] * wh[:, :, 1]
  90. return iou - (area - union) / area