# coco_evaluator.py
  1. import json
  2. import tempfile
  3. import torch
  4. from dataset.coco import COCODataset
  5. from utils.box_ops import rescale_bboxes, rescale_bboxes_with_deltas
  6. from dataset.data_augment import SSDBaseTransform, YOLOv5BaseTransform
  7. try:
  8. from pycocotools.cocoeval import COCOeval
  9. except:
  10. print("It seems that the COCOAPI is not installed.")
  11. class COCOAPIEvaluator():
  12. """
  13. COCO AP Evaluation class.
  14. All the data in the val2017 dataset are processed \
  15. and evaluated by COCO API.
  16. """
  17. def __init__(self, data_dir, device, testset=False, transform=None):
  18. """
  19. Args:
  20. data_dir (str): dataset root directory
  21. img_size (int): image size after preprocess. images are resized \
  22. to squares whose shape is (img_size, img_size).
  23. confthre (float):
  24. confidence threshold ranging from 0 to 1, \
  25. which is defined in the config file.
  26. nmsthre (float):
  27. IoU threshold of non-max supression ranging from 0 to 1.
  28. """
  29. self.testset = testset
  30. if self.testset:
  31. image_set = 'test2017'
  32. else:
  33. image_set = 'val2017'
  34. self.dataset = COCODataset(data_dir=data_dir, image_set=image_set, is_train=False)
  35. self.transform = transform
  36. self.device = device
  37. self.map = 0.
  38. self.ap50_95 = 0.
  39. self.ap50 = 0.
  40. @torch.no_grad()
  41. def evaluate(self, model):
  42. """
  43. COCO average precision (AP) Evaluation. Iterate inference on the test dataset
  44. and the results are evaluated by COCO API.
  45. Args:
  46. model : model object
  47. Returns:
  48. ap50_95 (float) : calculated COCO AP for IoU=50:95
  49. ap50 (float) : calculated COCO AP for IoU=50
  50. """
  51. model.eval()
  52. ids = []
  53. data_dict = []
  54. num_images = len(self.dataset)
  55. print('total number of images: %d' % (num_images))
  56. # start testing
  57. for index in range(num_images): # all the data in val2017
  58. if index % 500 == 0:
  59. print('[Eval: %d / %d]'%(index, num_images))
  60. # load an image
  61. img, id_ = self.dataset.pull_image(index)
  62. orig_h, orig_w, _ = img.shape
  63. # preprocess
  64. x, _, deltas = self.transform(img)
  65. x = x.unsqueeze(0).to(self.device)
  66. id_ = int(id_)
  67. ids.append(id_)
  68. # inference
  69. outputs = model(x)
  70. bboxes, scores, cls_inds = outputs
  71. # rescale bboxes
  72. if isinstance(self.transform, SSDBaseTransform):
  73. origin_img_size = [orig_h, orig_w]
  74. cur_img_size = [*x.shape[-2:]]
  75. bboxes = rescale_bboxes(bboxes, origin_img_size, cur_img_size)
  76. elif isinstance(self.transform, YOLOv5BaseTransform):
  77. origin_img_size = [orig_h, orig_w]
  78. cur_img_size = [*x.shape[-2:]]
  79. bboxes = rescale_bboxes_with_deltas(bboxes, deltas, origin_img_size, cur_img_size)
  80. # process outputs
  81. for i, box in enumerate(bboxes):
  82. x1 = float(box[0])
  83. y1 = float(box[1])
  84. x2 = float(box[2])
  85. y2 = float(box[3])
  86. label = self.dataset.class_ids[int(cls_inds[i])]
  87. bbox = [x1, y1, x2 - x1, y2 - y1]
  88. score = float(scores[i]) # object score * class score
  89. A = {"image_id": id_, "category_id": label, "bbox": bbox,
  90. "score": score} # COCO json format
  91. data_dict.append(A)
  92. annType = ['segm', 'bbox', 'keypoints']
  93. # Evaluate the Dt (detection) json comparing with the ground truth
  94. if len(data_dict) > 0:
  95. print('evaluating ......')
  96. cocoGt = self.dataset.coco
  97. # workaround: temporarily write data to json file because pycocotools can't process dict in py36.
  98. if self.testset:
  99. json.dump(data_dict, open('coco_test-dev.json', 'w'))
  100. cocoDt = cocoGt.loadRes('coco_test-dev.json')
  101. return -1, -1
  102. else:
  103. _, tmp = tempfile.mkstemp()
  104. json.dump(data_dict, open(tmp, 'w'))
  105. cocoDt = cocoGt.loadRes(tmp)
  106. cocoEval = COCOeval(self.dataset.coco, cocoDt, annType[1])
  107. cocoEval.params.imgIds = ids
  108. cocoEval.evaluate()
  109. cocoEval.accumulate()
  110. cocoEval.summarize()
  111. ap50_95, ap50 = cocoEval.stats[0], cocoEval.stats[1]
  112. print('ap50_95 : ', ap50_95)
  113. print('ap50 : ', ap50)
  114. self.map = ap50_95
  115. self.ap50_95 = ap50_95
  116. self.ap50 = ap50
  117. return ap50, ap50_95
  118. else:
  119. return 0, 0