# coco_evaluator.py

import json
import os
import tempfile

import torch

from dataset.coco import COCODataset
from utils.box_ops import rescale_bboxes

try:
    from pycocotools.cocoeval import COCOeval
except ImportError:
    print("It seems that the COCO API (pycocotools) is not installed.")


class COCOAPIEvaluator():
    """
    COCO AP evaluation class.
    All the data in the val2017 (or test2017) split are processed
    and evaluated by the COCO API.
    """
    def __init__(self, data_dir, device, testset=False, transform=None):
        """
        Args:
            data_dir (str): dataset root directory.
            device (torch.device): device that inference runs on.
            testset (bool): if True, evaluate on the test2017 (test-dev)
                split instead of val2017.
            transform (callable): preprocessing transform; takes an image
                and returns the transformed tensor, a target placeholder,
                and the resize ratio.
        """
        # ----------------- Basic parameters -----------------
        self.image_set = 'test2017' if testset else 'val2017'
        self.transform = transform
        self.device = device
        self.testset = testset
        # ----------------- Metrics -----------------
        self.map = 0.
        self.ap50_95 = 0.
        self.ap50 = 0.
        # ----------------- Dataset -----------------
        self.dataset = COCODataset(data_dir=data_dir, image_set=self.image_set)

    @torch.no_grad()
    def evaluate(self, model):
        """
        COCO average precision (AP) evaluation. Runs inference over every
        image in the dataset and evaluates the results with the COCO API.
        Args:
            model : model object
        Returns:
            ap50 (float) : COCO AP for IoU=0.50
            ap50_95 (float) : COCO AP averaged over IoU=0.50:0.95
        """
        model.eval()
        ids = []
        data_dict = []
        num_images = len(self.dataset)
        print('total number of images: %d' % (num_images))

        # start testing
        for index in range(num_images):  # all the data in val2017
            if index % 500 == 0:
                print('[Eval: %d / %d]' % (index, num_images))

            # load an image
            img, id_ = self.dataset.pull_image(index)
            orig_h, orig_w, _ = img.shape

            # preprocess
            x, _, ratio = self.transform(img)
            x = x.unsqueeze(0).to(self.device) / 255.
            id_ = int(id_)
            ids.append(id_)

            # inference
            outputs = model(x)
            scores = outputs['scores']
            labels = outputs['labels']
            bboxes = outputs['bboxes']

            # rescale bboxes to the original image size
            bboxes = rescale_bboxes(bboxes, [orig_w, orig_h], ratio)

            # process outputs
            for i, box in enumerate(bboxes):
                x1 = float(box[0])
                y1 = float(box[1])
                x2 = float(box[2])
                y2 = float(box[3])
                # map the contiguous label index back to the COCO category id
                label = self.dataset.class_ids[int(labels[i])]
                # COCO uses [x, y, w, h] box format
                bbox = [x1, y1, x2 - x1, y2 - y1]
                score = float(scores[i])  # object score * class score
                A = {"image_id": id_, "category_id": label, "bbox": bbox,
                     "score": score}  # COCO json format
                data_dict.append(A)
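
        # At this point each entry of `data_dict` follows the COCO
        # detection-results format, e.g. (values purely illustrative):
        #   {"image_id": 139, "category_id": 62,
        #    "bbox": [5.0, 167.0, 149.0, 94.0], "score": 0.318}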

        annType = ['segm', 'bbox', 'keypoints']

        # Evaluate the Dt (detection) json comparing with the ground truth
        if len(data_dict) > 0:
            print('evaluating ......')
            cocoGt = self.dataset.coco
            # workaround: temporarily write the detections to a json file,
            # because pycocotools' loadRes can't process a dict in py36.
            if self.testset:
                with open('coco_test-dev.json', 'w') as f:
                    json.dump(data_dict, f)
                cocoDt = cocoGt.loadRes('coco_test-dev.json')
                # test-dev results are scored by the evaluation server,
                # so no local AP is available.
                return -1, -1
            else:
                fd, tmp = tempfile.mkstemp()
                with os.fdopen(fd, 'w') as f:
                    json.dump(data_dict, f)
                cocoDt = cocoGt.loadRes(tmp)

            cocoEval = COCOeval(self.dataset.coco, cocoDt, annType[1])
            cocoEval.params.imgIds = ids
            cocoEval.evaluate()
            cocoEval.accumulate()
            cocoEval.summarize()

            # stats[0] = AP@[0.50:0.95], stats[1] = AP@0.50
            ap50_95, ap50 = cocoEval.stats[0], cocoEval.stats[1]
            print('ap50_95 : ', ap50_95)
            print('ap50 : ', ap50)
            self.map = ap50_95
            self.ap50_95 = ap50_95
            self.ap50 = ap50

            return ap50, ap50_95
        else:
            return 0, 0
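

# ----------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). Assumptions:
# a COCO root at 'path/to/COCO' laid out as COCODataset expects, and a
# detector whose forward pass returns a dict with 'scores', 'labels',
# and 'bboxes', as consumed by evaluate() above. DummyModel and
# dummy_transform are hypothetical stand-ins for illustration only.
# ----------------------------------------------------------------------
if __name__ == '__main__':
    class DummyModel(torch.nn.Module):
        """Stand-in detector: predicts one low-confidence box per image."""
        def forward(self, x):
            return {'scores': torch.tensor([0.1]),
                    'labels': torch.tensor([0]),
                    'bboxes': torch.tensor([[0., 0., 10., 10.]])}

    def dummy_transform(img):
        """Stand-in transform: HWC uint8 image -> CHW float tensor,
        with a dummy target and an identity resize ratio."""
        x = torch.from_numpy(img).permute(2, 0, 1).float()
        return x, None, 1.0

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    evaluator = COCOAPIEvaluator(data_dir='path/to/COCO',
                                 device=device,
                                 testset=False,
                                 transform=dummy_transform)
    ap50, ap50_95 = evaluator.evaluate(DummyModel().to(device))
    print('AP50: %.4f, AP50:95: %.4f' % (ap50, ap50_95))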