# compute_JI.py

import os
import sys
import json
import math
import argparse
from multiprocessing import Queue, Process

from tqdm import tqdm
import numpy as np

sys.path.insert(0, '../')
# Fall back to a plain import when the file is executed as a standalone script,
# where the relative form would fail with "no known parent package".
try:
    from .JIToolkits.JI_tools import compute_matching, get_ignores
except ImportError:
    from JIToolkits.JI_tools import compute_matching, get_ignores

# ---------------------------------- JI Evaluation functions ----------------------------------
def evaluation_all(path, target_key, nr_procs=10):
    records = load_json_lines(path)
    res_line = []
    res_JI = []
    for i in range(10):
        score_thr = 1e-1 * i
        total = len(records)
        stride = math.ceil(total / nr_procs)
        result_queue = Queue(10000)
        results, procs = [], []
        for j in range(nr_procs):
            start = j * stride
            end = min(start + stride, total)
            sample_data = records[start:end]
            p = Process(target=compute_JI_with_ignore,
                        args=(result_queue, sample_data, score_thr, target_key))
            p.start()
            procs.append(p)
        tqdm.monitor_interval = 0
        pbar = tqdm(total=total, leave=False, ascii=True)
        for _ in range(total):
            t = result_queue.get()
            results.append(t)
            pbar.update(1)
        for p in procs:
            p.join()
        pbar.close()
        line, mean_ratio = gather(results)
        line = 'score_thr:{:.1f}, {}'.format(score_thr, line)
        print(line)
        res_line.append(line)
        res_JI.append(mean_ratio)
    return res_line, max(res_JI)

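# Per-image evaluation run inside the worker processes spawned above. Detections
# kept at the current score threshold are matched to non-ignored GT boxes
# (tag != -1) by compute_matching under the overlap threshold bm_thresh, and
# unmatched boxes on either side that fall on ignore regions are discounted via
# get_ignores. With k matches, m valid GTs and n valid DTs, the reported ratio
# is k / (m + n - k + eps), i.e. the Jaccard Index of the matched sets, which is
# presumably the "JI" the file name refers to.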
def compute_JI_with_ignore(result_queue, records, score_thr, target_key, bm_thresh=0.5):
    for record in records:
        gt_boxes = load_bboxes(record, 'gtboxes', target_key, 'tag')
        gt_boxes[:, 2:4] += gt_boxes[:, :2]
        gt_boxes = clip_boundary(gt_boxes, record['height'], record['width'])
        dt_boxes = load_bboxes(record, 'dtboxes', target_key, 'score')
        dt_boxes[:, 2:4] += dt_boxes[:, :2]
        dt_boxes = clip_boundary(dt_boxes, record['height'], record['width'])
        keep = dt_boxes[:, -1] > score_thr
        dt_boxes = dt_boxes[keep][:, :-1]
        gt_tag = np.array(gt_boxes[:, -1] != -1)
        matches = compute_matching(dt_boxes, gt_boxes[gt_tag, :4], bm_thresh)
        # count unmatched detections that fall on ignore regions
        matched_indices = np.array([j for (j, _) in matches])
        unmatched_indices = list(set(np.arange(dt_boxes.shape[0])) - set(matched_indices))
        num_ignore_dt = get_ignores(dt_boxes[unmatched_indices], gt_boxes[~gt_tag, :4], bm_thresh)
        # count unmatched ground-truth boxes that fall on ignore regions
        matched_indices = np.array([j for (_, j) in matches])
        unmatched_indices = list(set(np.arange(gt_boxes[gt_tag].shape[0])) - set(matched_indices))
        num_ignore_gt = get_ignores(gt_boxes[gt_tag][unmatched_indices], gt_boxes[~gt_tag, :4], bm_thresh)
        # compute results
        eps = 1e-6
        k = len(matches)
        m = gt_tag.sum() - num_ignore_gt
        n = dt_boxes.shape[0] - num_ignore_dt
        ratio = k / (m + n - k + eps)
        recall = k / (m + eps)
        cover = k / (n + eps)
        noise = 1 - cover
        result_dict = dict(ratio=ratio, recall=recall, cover=cover,
                           noise=noise, k=k, m=m, n=n)
        result_queue.put_nowait(result_dict)

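# gather averages the per-image statistics produced by compute_JI_with_ignore.
# img_num only counts images that contain at least one valid detection or GT,
# so empty images do not dilute the means; valids/total/gtn are the summed
# match, detection and GT counts across all images. The commented-out line
# below keeps the fuller report (cover/recall/noise) available if needed.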
def gather(results):
    assert len(results)
    img_num = 0
    for result in results:
        if result['n'] != 0 or result['m'] != 0:
            img_num += 1
    mean_ratio = np.sum([rb['ratio'] for rb in results]) / img_num
    mean_cover = np.sum([rb['cover'] for rb in results]) / img_num
    mean_recall = np.sum([rb['recall'] for rb in results]) / img_num
    mean_noise = 1 - mean_cover
    valids = np.sum([rb['k'] for rb in results])
    total = np.sum([rb['n'] for rb in results])
    gtn = np.sum([rb['m'] for rb in results])
    #line = 'mean_ratio:{:.4f}, mean_cover:{:.4f}, mean_recall:{:.4f}, mean_noise:{:.4f}, valids:{}, total:{}, gtn:{}'.format(
    #    mean_ratio, mean_cover, mean_recall, mean_noise, valids, total, gtn)
    line = 'mean_ratio:{:.4f}, valids:{}, total:{}, gtn:{}'.format(
        mean_ratio, valids, total, gtn)
    return line, mean_ratio

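# common_process is a generic version of the fork/join pattern used in
# evaluation_all: it shards cls_list across nr_procs workers, collects one
# result per input item from the shared queue (a worker may put None to signal
# that an item produced nothing), and then joins the processes. It is not
# called elsewhere in this file and is kept as a reusable helper.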
def common_process(func, cls_list, nr_procs=10):
    total = len(cls_list)
    stride = math.ceil(total / nr_procs)
    result_queue = Queue(10000)
    results, procs = [], []
    for i in range(nr_procs):
        start = i * stride
        end = min(start + stride, total)
        sample_data = cls_list[start:end]
        p = Process(target=func, args=(result_queue, sample_data))
        p.start()
        procs.append(p)
    for i in range(total):
        t = result_queue.get()
        if t is None:
            continue
        results.append(t)
    for p in procs:
        p.join()
    return results

# ---------------------------------- Basic functions ----------------------------------
def load_json_lines(fpath):
    print(fpath)
    assert os.path.exists(fpath)
    with open(fpath, 'r') as fid:
        lines = fid.readlines()
    records = [json.loads(line.strip('\n')) for line in lines]
    return records

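# Illustrative sketch of the json-lines records this script expects (only the
# keys actually read are shown; 'fbox' stands in for whatever --target_key is):
#   {"height": 800, "width": 1200,
#    "gtboxes": [{"fbox": [x, y, w, h], "tag": 1}, ...],
#    "dtboxes": [{"fbox": [x, y, w, h], "score": 0.9}, ...]}
# A gt tag of -1 marks an ignore region; boxes are stored as [x, y, w, h] and
# converted to corner form inside compute_JI_with_ignore.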
def save_json_lines(content, fpath):
    with open(fpath, 'w') as fid:
        for db in content:
            line = json.dumps(db) + '\n'
            fid.write(line)

def load_bboxes(dict_input, key_name, key_box, key_score=None, key_tag=None):
    assert key_name in dict_input
    if len(dict_input[key_name]) < 1:
        return np.empty([0, 5])
    else:
        assert key_box in dict_input[key_name][0]
        if key_score:
            assert key_score in dict_input[key_name][0]
        if key_tag:
            assert key_tag in dict_input[key_name][0]
    if key_score:
        if key_tag:
            bboxes = np.vstack([np.hstack((rb[key_box], rb[key_score], rb[key_tag])) for rb in dict_input[key_name]])
        else:
            bboxes = np.vstack([np.hstack((rb[key_box], rb[key_score])) for rb in dict_input[key_name]])
    else:
        if key_tag:
            bboxes = np.vstack([np.hstack((rb[key_box], rb[key_tag])) for rb in dict_input[key_name]])
        else:
            bboxes = np.vstack([rb[key_box] for rb in dict_input[key_name]])
    return bboxes

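# load_bboxes stacks each entry as [x, y, w, h] followed by the optional score
# and/or tag columns in that order, so callers read the extra field with
# boxes[:, -1]; an empty list yields a (0, 5) float array so downstream slicing
# still works.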
def clip_boundary(boxes, height, width):
    assert boxes.shape[-1] >= 4
    boxes[:, 0] = np.minimum(np.maximum(boxes[:, 0], 0), width - 1)
    boxes[:, 1] = np.minimum(np.maximum(boxes[:, 1], 0), height - 1)
    boxes[:, 2] = np.maximum(np.minimum(boxes[:, 2], width), 0)
    boxes[:, 3] = np.maximum(np.minimum(boxes[:, 3], height), 0)
    return boxes

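# Note: clip_boundary assumes corner-format boxes (x1, y1, x2, y2); in this file
# compute_JI_with_ignore converts the loaded [x, y, w, h] boxes before clipping
# them to the image extent.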
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Analyze a json result file with iou match')
    parser.add_argument('--detfile', required=True, help='path of json result file to load')
    parser.add_argument('--target_key', required=True,
                        help='bbox key to read from each gtboxes/dtboxes entry')
    args = parser.parse_args()
    evaluation_all(args.detfile, args.target_key)
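# Example invocation (the path and the 'fbox' key are illustrative; pass
# whichever box key your dump file actually uses):
#   python compute_JI.py --detfile eval_dump.json --target_key fbox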