coco_evaluator.py

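"""COCO mAP evaluator.

Wraps pycocotools' COCOeval to score a detector on the COCO val2017
split: runs inference image by image, converts predictions to the COCO
results format, and reports AP@[.5:.95] and AP@.5.
"""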
import os
import contextlib
import torch

from pycocotools.cocoeval import COCOeval

from datasets import build_transform
from datasets.coco import build_coco


class COCOAPIEvaluator:
    def __init__(self, args, cfg, device):
        # ----------------- Basic parameters -----------------
        self.image_set = 'val2017'
        self.device = device
        # ----------------- Metrics -----------------
        self.map = 0.
        self.ap50_95 = 0.
        self.ap50 = 0.
        # ----------------- Dataset -----------------
        self.transform = build_transform(cfg, is_train=False)
        self.dataset = build_coco(args, self.transform, is_train=False)

    @torch.no_grad()
    def evaluate(self, model):
        ids = []
        coco_results = []
        model.eval()
        model.trainable = False

        # start testing
        for index, (image, target) in enumerate(self.dataset):
            if index % 500 == 0:
                print('[Eval: %d / %d]' % (index, len(self.dataset)))

            # image id
            id_ = int(target['image_id'])
            ids.append(id_)

            # inference
            image = image.unsqueeze(0).to(self.device)
            outputs = model(image)
            bboxes, scores, cls_inds = outputs

            # rescale normalized bboxes to the original image size
            orig_h, orig_w = target["orig_size"].tolist()
            bboxes[..., 0::2] *= orig_w
            bboxes[..., 1::2] *= orig_h

            # reformat results
            for i, box in enumerate(bboxes):
                x1 = float(box[0])
                y1 = float(box[1])
                x2 = float(box[2])
                y2 = float(box[3])
                # map the contiguous class index back to the COCO category id
                label = self.dataset.coco_indexs[int(cls_inds[i])]

                # COCO json format: [x, y, w, h]
                bbox = [x1, y1, x2 - x1, y2 - y1]
                score = float(scores[i])
                A = {"image_id": id_,
                     "category_id": label,
                     "bbox": bbox,
                     "score": score}
                coco_results.append(A)

        model.train()
        model.trainable = True

        annType = ['segm', 'bbox', 'keypoints']

        # Evaluate the Dt (detection) json against the ground truth
        if len(coco_results) > 0:
            print('evaluating ......')
            cocoGt = self.dataset.coco
            # suppress pycocotools prints
            with open(os.devnull, 'w') as devnull:
                with contextlib.redirect_stdout(devnull):
                    cocoDt = cocoGt.loadRes(coco_results)
                    cocoEval = COCOeval(cocoGt, cocoDt, annType[1])
                    cocoEval.params.imgIds = ids
                    cocoEval.evaluate()
                    cocoEval.accumulate()
                    cocoEval.summarize()

            # update mAP
            ap50_95, ap50 = cocoEval.stats[0], cocoEval.stats[1]
            print('ap50_95 : ', ap50_95)
            print('ap50 : ', ap50)
            self.map = ap50_95
            self.ap50_95 = ap50_95
            self.ap50 = ap50
            del coco_results
        else:
            print('No coco detection results!')
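

if __name__ == '__main__':
    # ------------------------------------------------------------------
    # Usage sketch (an assumption, not part of the original file).
    # `parse_args`, `build_config`, and `build_model` are hypothetical
    # stand-ins for whatever argument parsing, config loading, and model
    # construction this repo's train/test scripts actually provide.
    # ------------------------------------------------------------------
    args = parse_args()      # hypothetical: CLI args (dataset root, etc.)
    cfg = build_config(args) # hypothetical: model/transform config
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = build_model(args, cfg).to(device)  # hypothetical model factory
    evaluator = COCOAPIEvaluator(args, cfg, device)
    evaluator.evaluate(model)
    print('AP50:95 = %.4f, AP50 = %.4f' % (evaluator.ap50_95, evaluator.ap50))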