coco_evaluator.py

import json
import os
import contextlib

import torch
from pycocotools.cocoeval import COCOeval

from datasets import build_transform
from datasets.coco import build_coco

class COCOAPIEvaluator():
    def __init__(self, args, cfg, device):
        # ----------------- Basic parameters -----------------
        self.image_set = 'val2017'
        self.device = device
        # NOTE: `self.testset` is read in evaluate() but was never set here;
        # default to False (val-set evaluation) and pick up an `args.testset`
        # flag when the caller provides one.
        self.testset = getattr(args, 'testset', False)
        # ----------------- Metrics -----------------
        self.map = 0.
        self.ap50_95 = 0.
        self.ap50 = 0.
        # ----------------- Dataset -----------------
        self.transform = build_transform(cfg, is_train=False)
        self.dataset = build_coco(args, self.transform, is_train=False)

    @torch.no_grad()
    def evaluate(self, model):
        ids = []
        coco_results = []
        model.eval()
        model.trainable = False

        # start testing
        for index, (image, target) in enumerate(self.dataset):
            if index % 500 == 0:
                print('[Eval: %d / %d]' % (index, len(self.dataset)))

            # image id
            id_ = int(target['image_id'])
            ids.append(id_)

            # inference
            image = image.unsqueeze(0).to(self.device)
            outputs = model(image)
            bboxes, scores, cls_inds = outputs

            # rescale normalized box coordinates to the original image size
            orig_h, orig_w = target["orig_size"].tolist()
            bboxes[..., 0::2] *= orig_w
            bboxes[..., 1::2] *= orig_h

            # reformat results
            for i, box in enumerate(bboxes):
                x1 = float(box[0])
                y1 = float(box[1])
                x2 = float(box[2])
                y2 = float(box[3])
                # map the contiguous class index back to the COCO category id
                label = self.dataset.coco_indexs[int(cls_inds[i])]

                # COCO json format: [x, y, width, height]
                bbox = [x1, y1, x2 - x1, y2 - y1]
                score = float(scores[i])
                A = {"image_id": id_,
                     "category_id": label,
                     "bbox": bbox,
                     "score": score}
                coco_results.append(A)

        model.train()
        model.trainable = True

        annType = ['segm', 'bbox', 'keypoints']

        # Evaluate the Dt (detection) json by comparing it with the ground truth
        if len(coco_results) > 0:
            print('evaluating ......')
            cocoGt = self.dataset.coco
            if self.testset:
                # test-dev has no local annotations: dump the predictions to a
                # json file (the format expected for a test-dev submission),
                # then load them back through the COCO API
                json.dump(coco_results, open('coco_test-dev.json', 'w'))
                cocoDt = cocoGt.loadRes('coco_test-dev.json')
            else:
                # suppress pycocotools prints
                with open(os.devnull, 'w') as devnull:
                    with contextlib.redirect_stdout(devnull):
                        cocoDt = cocoGt.loadRes(coco_results)
            cocoEval = COCOeval(cocoGt, cocoDt, annType[1])
            cocoEval.params.imgIds = ids
            cocoEval.evaluate()
            cocoEval.accumulate()
            cocoEval.summarize()

            # update mAP
            ap50_95, ap50 = cocoEval.stats[0], cocoEval.stats[1]
            print('ap50_95 : ', ap50_95)
            print('ap50 : ', ap50)
            self.map = ap50_95
            self.ap50_95 = ap50_95
            self.ap50 = ap50

            del coco_results
        else:
            print('No coco detection results !')
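

# ------------------------------------------------------------------
# Minimal usage sketch. `args` and `cfg` are assumed to come from the
# project's own argument parser and config loader, and `build_model`
# is a hypothetical stand-in for the repo's model factory; only the
# COCOAPIEvaluator calls are taken from this file.
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   model = build_model(args, cfg).to(device)
#   evaluator = COCOAPIEvaluator(args, cfg, device)
#   evaluator.evaluate(model)
#   print('AP50:95 = %.4f, AP50 = %.4f' % (evaluator.ap50_95, evaluator.ap50))
# ------------------------------------------------------------------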