coco_evaluator.py

import json
import os
import contextlib
import torch
from pycocotools.cocoeval import COCOeval

from datasets import build_transform
from datasets.coco import build_coco


class COCOAPIEvaluator():
    def __init__(self, args, cfg, device):
        # ----------------- Basic parameters -----------------
        self.image_set = 'val2017'
        self.device = device
        # ----------------- Metrics -----------------
        self.map = 0.
        self.ap50_95 = 0.
        self.ap50 = 0.
        # ----------------- Dataset -----------------
        self.transform = build_transform(cfg, is_train=False)
        self.dataset = build_coco(args, self.transform, is_train=False)
    @torch.no_grad()
    def evaluate(self, model):
        ids = []
        coco_results = []
        model.eval()
        model.trainable = False

        # start testing
        for index, (image, target) in enumerate(self.dataset):
            if index % 500 == 0:
                print('[Eval: %d / %d]' % (index, len(self.dataset)))

            # image id
            id_ = int(target['image_id'])
            ids.append(id_)

            # inference on a single image (batch size 1)
            image = image.unsqueeze(0).to(self.device)
            outputs = model(image)
            scores = outputs['scores']
            labels = outputs['labels']
            bboxes = outputs['bboxes']

            # rescale bboxes to the original image size
            orig_h, orig_w = target["orig_size"].tolist()
            bboxes[..., 0::2] *= orig_w
            bboxes[..., 1::2] *= orig_h
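            # Note: the scaling above assumes the model outputs normalized
            # xyxy boxes in [0, 1]; even indices (x1, x2) are multiplied by
            # the original width and odd indices (y1, y2) by the original
            # height to map them back to pixel coordinates.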
  43. # reformat results
  44. for i, box in enumerate(bboxes):
  45. x1 = float(box[0])
  46. y1 = float(box[1])
  47. x2 = float(box[2])
  48. y2 = float(box[3])
  49. label = self.dataset.coco_indexs[int(labels[i])]
  50. # COCO json format
  51. bbox = [x1, y1, x2 - x1, y2 - y1]
  52. score = float(scores[i])
  53. A = {"image_id": id_,
  54. "category_id": label,
  55. "bbox": bbox,
  56. "score": score}
  57. coco_results.append(A)
        model.train()
        model.trainable = True
        annType = ['segm', 'bbox', 'keypoints']

        # Evaluate the Dt (detection) json against the ground truth
        if len(coco_results) > 0:
            print('evaluating ......')
            cocoGt = self.dataset.coco
            # suppress pycocotools prints
            with open(os.devnull, 'w') as devnull:
                with contextlib.redirect_stdout(devnull):
                    cocoDt = cocoGt.loadRes(coco_results)
            cocoEval = COCOeval(cocoGt, cocoDt, annType[1])
            cocoEval.params.imgIds = ids
            cocoEval.evaluate()
            cocoEval.accumulate()
            cocoEval.summarize()
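            # COCOeval.stats layout (pycocotools): index 0 is AP@[.50:.95],
            # index 1 is AP@.50, index 2 is AP@.75, indices 3-5 are AP for
            # small/medium/large objects, and 6-11 are the AR metrics.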
            # update mAP
            ap50_95, ap50 = cocoEval.stats[0], cocoEval.stats[1]
            print('ap50_95 : ', ap50_95)
            print('ap50 : ', ap50)
            self.map = ap50_95
            self.ap50_95 = ap50_95
            self.ap50 = ap50
            del coco_results
        else:
            print('No coco detection results!')
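

# --------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original file). It
# assumes an argparse-style `args` accepted by build_coco, a project config
# `cfg` accepted by build_transform, and a detector whose forward pass on a
# single image returns a dict with 'scores', 'labels', and normalized xyxy
# 'bboxes'. The names `args`, `cfg`, and `model` below are placeholders that
# the surrounding project would supply.
# --------------------------------------------------------------------------
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--root', default='path/to/COCO',
                        help='dataset root directory (assumed flag)')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    cfg = {}      # placeholder config; the real project defines its contents
    model = ...   # placeholder detector; must follow the output contract above

    evaluator = COCOAPIEvaluator(args, cfg, device)
    evaluator.evaluate(model)
    print('AP@[.50:.95]: %.4f | AP@.50: %.4f'
          % (evaluator.ap50_95, evaluator.ap50))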