box_ops.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. import torch
  2. import numpy as np
  3. from torchvision.ops.boxes import box_area
  4. # modified from torchvision to also return the union
  5. def box_iou(boxes1, boxes2):
  6. area1 = box_area(boxes1)
  7. area2 = box_area(boxes2)
  8. lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
  9. rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
  10. wh = (rb - lt).clamp(min=0) # [N,M,2]
  11. inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
  12. union = area1[:, None] + area2 - inter
  13. iou = inter / union
  14. return iou, union
  15. def get_ious(bboxes1,
  16. bboxes2,
  17. box_mode="xyxy",
  18. iou_type="iou"):
  19. """
  20. Compute iou loss of type ['iou', 'giou', 'linear_iou']
  21. Args:
  22. inputs (tensor): pred values
  23. targets (tensor): target values
  24. weight (tensor): loss weight
  25. box_mode (str): 'xyxy' or 'ltrb', 'ltrb' is currently supported.
  26. loss_type (str): 'giou' or 'iou' or 'linear_iou'
  27. reduction (str): reduction manner
  28. Returns:
  29. loss (tensor): computed iou loss.
  30. """
  31. if box_mode == "ltrb":
  32. bboxes1 = torch.cat((-bboxes1[..., :2], bboxes1[..., 2:]), dim=-1)
  33. bboxes2 = torch.cat((-bboxes2[..., :2], bboxes2[..., 2:]), dim=-1)
  34. elif box_mode != "xyxy":
  35. raise NotImplementedError
  36. eps = torch.finfo(torch.float32).eps
  37. bboxes1_area = (bboxes1[..., 2] - bboxes1[..., 0]).clamp_(min=0) \
  38. * (bboxes1[..., 3] - bboxes1[..., 1]).clamp_(min=0)
  39. bboxes2_area = (bboxes2[..., 2] - bboxes2[..., 0]).clamp_(min=0) \
  40. * (bboxes2[..., 3] - bboxes2[..., 1]).clamp_(min=0)
  41. w_intersect = (torch.min(bboxes1[..., 2], bboxes2[..., 2])
  42. - torch.max(bboxes1[..., 0], bboxes2[..., 0])).clamp_(min=0)
  43. h_intersect = (torch.min(bboxes1[..., 3], bboxes2[..., 3])
  44. - torch.max(bboxes1[..., 1], bboxes2[..., 1])).clamp_(min=0)
  45. area_intersect = w_intersect * h_intersect
  46. area_union = bboxes2_area + bboxes1_area - area_intersect
  47. ious = area_intersect / area_union.clamp(min=eps)
  48. if iou_type == "iou":
  49. return ious
  50. elif iou_type == "giou":
  51. g_w_intersect = torch.max(bboxes1[..., 2], bboxes2[..., 2]) \
  52. - torch.min(bboxes1[..., 0], bboxes2[..., 0])
  53. g_h_intersect = torch.max(bboxes1[..., 3], bboxes2[..., 3]) \
  54. - torch.min(bboxes1[..., 1], bboxes2[..., 1])
  55. ac_uion = g_w_intersect * g_h_intersect
  56. gious = ious - (ac_uion - area_union) / ac_uion.clamp(min=eps)
  57. return gious
  58. else:
  59. raise NotImplementedError
  60. def rescale_bboxes(bboxes, origin_img_size, cur_img_size, deltas=None):
  61. origin_h, origin_w = origin_img_size
  62. cur_img_h, cur_img_w = cur_img_size
  63. if deltas is None:
  64. # rescale
  65. bboxes[..., [0, 2]] = bboxes[..., [0, 2]] / cur_img_w * origin_w
  66. bboxes[..., [1, 3]] = bboxes[..., [1, 3]] / cur_img_h * origin_h
  67. # clip bboxes
  68. bboxes[..., [0, 2]] = np.clip(bboxes[..., [0, 2]], a_min=0., a_max=origin_w)
  69. bboxes[..., [1, 3]] = np.clip(bboxes[..., [1, 3]], a_min=0., a_max=origin_h)
  70. else:
  71. # rescale
  72. bboxes[..., [0, 2]] = bboxes[..., [0, 2]] / (cur_img_w - deltas[0]) * origin_w
  73. bboxes[..., [1, 3]] = bboxes[..., [1, 3]] / (cur_img_h - deltas[1]) * origin_h
  74. # clip bboxes
  75. bboxes[..., [0, 2]] = np.clip(bboxes[..., [0, 2]], a_min=0., a_max=origin_w)
  76. bboxes[..., [1, 3]] = np.clip(bboxes[..., [1, 3]], a_min=0., a_max=origin_h)
  77. return bboxes
  78. if __name__ == '__main__':
  79. box1 = torch.tensor([[10, 10, 20, 20]])
  80. box2 = torch.tensor([[15, 15, 20, 20]])
  81. iou = box_iou(box1, box2)
  82. print(iou)