JI_tools.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. #coding:utf-8
  2. import numpy as np
  3. from .matching import maxWeightMatching
  4. def compute_matching(dt_boxes, gt_boxes, bm_thr):
  5. assert dt_boxes.shape[-1] > 3 and gt_boxes.shape[-1] > 3
  6. if dt_boxes.shape[0] < 1 or gt_boxes.shape[0] < 1:
  7. return list()
  8. N, K = dt_boxes.shape[0], gt_boxes.shape[0]
  9. ious = compute_iou_matrix(dt_boxes, gt_boxes)
  10. rows, cols = np.where(ious > bm_thr)
  11. bipartites = [(i + 1, j + N + 1, ious[i, j]) for (i, j) in zip(rows, cols)]
  12. mates = maxWeightMatching(bipartites)
  13. if len(mates) < 1:
  14. return list()
  15. rows = np.where(np.array(mates) > -1)[0]
  16. indices = np.where(rows < N + 1)[0]
  17. rows = rows[indices]
  18. cols = np.array([mates[i] for i in rows])
  19. matches = [(i-1, j - N - 1) for (i, j) in zip(rows, cols)]
  20. return matches
  21. def compute_head_body_matching(dt_body, dt_head, gt_body, gt_head, bm_thr):
  22. assert dt_body.shape[-1] > 3 and gt_body.shape[-1] > 3
  23. assert dt_head.shape[-1] > 3 and gt_head.shape[-1] > 3
  24. assert dt_body.shape[0] == dt_head.shape[0]
  25. assert gt_body.shape[0] == gt_head.shape[0]
  26. N, K = dt_body.shape[0], gt_body.shape[0]
  27. ious_body = compute_iou_matrix(dt_body, gt_body)
  28. ious_head = compute_iou_matrix(dt_head, gt_head)
  29. mask_body = ious_body > bm_thr
  30. mask_head = ious_head > bm_thr
  31. # only keep the both matches detections
  32. mask = np.array(mask_body) & np.array(mask_head)
  33. ious = np.zeros((N, K))
  34. ious[mask] = (ious_body[mask] + ious_head[mask]) / 2
  35. rows, cols = np.where(ious > bm_thr)
  36. bipartites = [(i + 1, j + N + 1, ious[i, j]) for (i, j) in zip(rows, cols)]
  37. mates = maxWeightMatching(bipartites)
  38. if len(mates) < 1:
  39. return list()
  40. rows = np.where(np.array(mates) > -1)[0]
  41. indices = np.where(rows < N + 1)[0]
  42. rows = rows[indices]
  43. cols = np.array([mates[i] for i in rows])
  44. matches = [(i-1, j - N - 1) for (i, j) in zip(rows, cols)]
  45. return matches
  46. def compute_multi_head_body_matching(dt_body, dt_head_0, dt_head_1, gt_body, gt_head, bm_thr):
  47. assert dt_body.shape[-1] > 3 and gt_body.shape[-1] > 3
  48. assert dt_head_0.shape[-1] > 3 and gt_head.shape[-1] > 3
  49. assert dt_head_1.shape[-1] > 3 and gt_head.shape[-1] > 3
  50. assert dt_body.shape[0] == dt_head_0.shape[0]
  51. assert gt_body.shape[0] == gt_head.shape[0]
  52. N, K = dt_body.shape[0], gt_body.shape[0]
  53. ious_body = compute_iou_matrix(dt_body, gt_body)
  54. ious_head_0 = compute_iou_matrix(dt_head_0, gt_head)
  55. ious_head_1 = compute_iou_matrix(dt_head_1, gt_head)
  56. mask_body = ious_body > bm_thr
  57. mask_head_0 = ious_head_0 > bm_thr
  58. mask_head_1 = ious_head_1 > bm_thr
  59. mask_head = mask_head_0 | mask_head_1
  60. # only keep the both matches detections
  61. mask = np.array(mask_body) & np.array(mask_head)
  62. ious = np.zeros((N, K))
  63. #ious[mask] = (ious_body[mask] + ious_head[mask]) / 2
  64. ious[mask] = ious_body[mask]
  65. rows, cols = np.where(ious > bm_thr)
  66. bipartites = [(i + 1, j + N + 1, ious[i, j]) for (i, j) in zip(rows, cols)]
  67. mates = maxWeightMatching(bipartites)
  68. if len(mates) < 1:
  69. return list()
  70. rows = np.where(np.array(mates) > -1)[0]
  71. indices = np.where(rows < N + 1)[0]
  72. rows = rows[indices]
  73. cols = np.array([mates[i] for i in rows])
  74. matches = [(i-1, j - N - 1) for (i, j) in zip(rows, cols)]
  75. return matches
  76. def get_head_body_ignores(dt_body, dt_head, gt_body, gt_head, bm_thr):
  77. if gt_body.size:
  78. body_ioas = compute_ioa_matrix(dt_body, gt_body)
  79. head_ioas = compute_ioa_matrix(dt_head, gt_head)
  80. body_ioas = np.max(body_ioas, axis=1)
  81. head_ioas = np.max(head_ioas, axis=1)
  82. head_rows = np.where(head_ioas > bm_thr)[0]
  83. body_rows = np.where(body_ioas > bm_thr)[0]
  84. rows = set.union(set(head_rows), set(body_rows))
  85. return len(rows)
  86. else:
  87. return 0
  88. def get_ignores(dt_boxes, gt_boxes, bm_thr):
  89. if gt_boxes.size:
  90. ioas = compute_ioa_matrix(dt_boxes, gt_boxes)
  91. ioas = np.max(ioas, axis = 1)
  92. rows = np.where(ioas > bm_thr)[0]
  93. return len(rows)
  94. else:
  95. return 0
  96. def compute_ioa_matrix(dboxes: np.ndarray, gboxes: np.ndarray):
  97. eps = 1e-6
  98. assert dboxes.shape[-1] >= 4 and gboxes.shape[-1] >= 4
  99. N, K = dboxes.shape[0], gboxes.shape[0]
  100. dtboxes = np.tile(np.expand_dims(dboxes, axis = 1), (1, K, 1))
  101. gtboxes = np.tile(np.expand_dims(gboxes, axis = 0), (N, 1, 1))
  102. iw = np.minimum(dtboxes[:,:,2], gtboxes[:,:,2]) - np.maximum(dtboxes[:,:,0], gtboxes[:,:,0])
  103. ih = np.minimum(dtboxes[:,:,3], gtboxes[:,:,3]) - np.maximum(dtboxes[:,:,1], gtboxes[:,:,1])
  104. inter = np.maximum(0, iw) * np.maximum(0, ih)
  105. dtarea = np.maximum(dtboxes[:,:,2] - dtboxes[:,:,0], 0) * np.maximum(dtboxes[:,:,3] - dtboxes[:,:,1], 0)
  106. ioas = inter / (dtarea + eps)
  107. return ioas
  108. def compute_iou_matrix(dboxes:np.ndarray, gboxes:np.ndarray):
  109. eps = 1e-6
  110. assert dboxes.shape[-1] >= 4 and gboxes.shape[-1] >= 4
  111. N, K = dboxes.shape[0], gboxes.shape[0]
  112. dtboxes = np.tile(np.expand_dims(dboxes, axis = 1), (1, K, 1))
  113. gtboxes = np.tile(np.expand_dims(gboxes, axis = 0), (N, 1, 1))
  114. iw = np.minimum(dtboxes[:,:,2], gtboxes[:,:,2]) - np.maximum(dtboxes[:,:,0], gtboxes[:,:,0])
  115. ih = np.minimum(dtboxes[:,:,3], gtboxes[:,:,3]) - np.maximum(dtboxes[:,:,1], gtboxes[:,:,1])
  116. inter = np.maximum(0, iw) * np.maximum(0, ih)
  117. dtarea = (dtboxes[:,:,2] - dtboxes[:,:,0]) * (dtboxes[:,:,3] - dtboxes[:,:,1])
  118. gtarea = (gtboxes[:,:,2] - gtboxes[:,:,0]) * (gtboxes[:,:,3] - gtboxes[:,:,1])
  119. ious = inter / (dtarea + gtarea - inter + eps)
  120. return ious