onnx_inference.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. # Copyright (c) Megvii, Inc. and its affiliates.
  4. import argparse
  5. import os
  6. import cv2
  7. import time
  8. import numpy as np
  9. import sys
  10. sys.path.append('../../')
  11. import onnxruntime
  12. from utils.misc import PreProcessor, PostProcessor
  13. from utils.vis_tools import visualize
  14. def make_parser():
  15. parser = argparse.ArgumentParser("onnxruntime inference sample")
  16. parser.add_argument("-m", "--model", type=str, default="../../weights/onnx/11/yolov1.onnx",
  17. help="Input your onnx model.")
  18. parser.add_argument("-i", "--image_path", type=str, default='../test_image.jpg',
  19. help="Path to your input image.")
  20. parser.add_argument("-o", "--output_dir", type=str, default='../../det_results/onnx/',
  21. help="Path to your output directory.")
  22. parser.add_argument("-s", "--score_thr", type=float, default=0.35,
  23. help="Score threshould to filter the result.")
  24. parser.add_argument("-size", "--img_size", type=int, default=640,
  25. help="Specify an input shape for inference.")
  26. return parser
  27. if __name__ == '__main__':
  28. args = make_parser().parse_args()
  29. # class color for better visualization
  30. np.random.seed(0)
  31. class_colors = [(np.random.randint(255),
  32. np.random.randint(255),
  33. np.random.randint(255)) for _ in range(80)]
  34. # preprocessor
  35. prepocess = PreProcessor(img_size=args.img_size)
  36. # postprocessor
  37. postprocess = PostProcessor(num_classes=80, conf_thresh=args.score_thr, nms_thresh=0.5)
  38. # read an image
  39. input_shape = tuple([args.img_size, args.img_size])
  40. origin_img = cv2.imread(args.image_path)
  41. # preprocess
  42. x, ratio = prepocess(origin_img)
  43. t0 = time.time()
  44. # inference
  45. session = onnxruntime.InferenceSession(args.model)
  46. ort_inputs = {session.get_inputs()[0].name: x[None, :, :, :]}
  47. output = session.run(None, ort_inputs)
  48. print("inference time: {:.1f} ms".format((time.time() - t0)*1000))
  49. t0 = time.time()
  50. # post process
  51. bboxes, scores, labels = postprocess(output[0])
  52. bboxes /= ratio
  53. print("post-process time: {:.1f} ms".format((time.time() - t0)*1000))
  54. # visualize detection
  55. origin_img = visualize(
  56. img=origin_img,
  57. bboxes=bboxes,
  58. scores=scores,
  59. labels=labels,
  60. vis_thresh=args.score_thr,
  61. class_colors=class_colors
  62. )
  63. # show
  64. cv2.imshow('onnx detection', origin_img)
  65. cv2.waitKey(0)
  66. # save results
  67. os.makedirs(args.output_dir, exist_ok=True)
  68. output_path = os.path.join(args.output_dir, os.path.basename(args.image_path))
  69. cv2.imwrite(output_path, origin_img)