export_onnx.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. # Copyright (c) Megvii, Inc. and its affiliates.
  4. # Thanks to YOLOX: https://github.com/Megvii-BaseDetection/YOLOX/blob/main/tools/export_onnx.py
  5. import argparse
  6. import os
  7. from loguru import logger
  8. import sys
  9. sys.path.append('..')
  10. import torch
  11. from torch import nn
  12. from utils.misc import SiLU
  13. from utils.misc import load_weight, replace_module
  14. from config import build_config
  15. from models.detectors import build_model
  16. def make_parser():
  17. parser = argparse.ArgumentParser("YOLO ONNXRuntime")
  18. # basic
  19. parser.add_argument("--output-name", type=str, default="yolo_free_large.onnx",
  20. help="output name of models")
  21. parser.add_argument('-size', '--img_size', default=640, type=int,
  22. help='the max size of input image')
  23. parser.add_argument("--input", default="images", type=str,
  24. help="input node name of onnx model")
  25. parser.add_argument("--output", default="output", type=str,
  26. help="output node name of onnx model")
  27. parser.add_argument("-o", "--opset", default=11, type=int,
  28. help="onnx opset version")
  29. parser.add_argument("--batch-size", type=int, default=1,
  30. help="batch size")
  31. parser.add_argument("--dynamic", action="store_true", default=False,
  32. help="whether the input shape should be dynamic or not")
  33. parser.add_argument("--no-onnxsim", action="store_true", default=False,
  34. help="use onnxsim or not")
  35. parser.add_argument("-f", "--exp_file", default=None, type=str,
  36. help="experiment description file")
  37. parser.add_argument("-expn", "--experiment-name", type=str, default=None)
  38. parser.add_argument("opts", default=None, nargs=argparse.REMAINDER,
  39. help="Modify config options using the command-line")
  40. parser.add_argument("--decode_in_inference", action="store_true", default=False,
  41. help="decode in inference or not")
  42. parser.add_argument('--save_dir', default='../weights/onnx/', type=str,
  43. help='Dir to save onnx file')
  44. # model
  45. parser.add_argument('-v', '--version', default='yolo_free_large', type=str,
  46. help='build yolo')
  47. parser.add_argument('--weight', default=None,
  48. type=str, help='Trained state_dict file path to open')
  49. parser.add_argument('-ct', '--conf_thresh', default=0.1, type=float,
  50. help='confidence threshold')
  51. parser.add_argument('-nt', '--nms_thresh', default=0.5, type=float,
  52. help='NMS threshold')
  53. parser.add_argument('--topk', default=100, type=int,
  54. help='topk candidates for testing')
  55. parser.add_argument('-nc', '--num_classes', default=80, type=int,
  56. help='topk candidates for testing')
  57. parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
  58. help='fuse Conv & BN')
  59. return parser
  60. @logger.catch
  61. def main():
  62. args = make_parser().parse_args()
  63. logger.info("args value: {}".format(args))
  64. device = torch.device('cpu')
  65. # config
  66. cfg = build_config(args)
  67. # build model
  68. model = build_model(
  69. args=args,
  70. cfg=cfg,
  71. device=device,
  72. num_classes=args.num_classes,
  73. trainable=False
  74. )
  75. # replace nn.SiLU with SiLU
  76. model = replace_module(model, nn.SiLU, SiLU)
  77. # load trained weight
  78. model = load_weight(model, args.weight, args.fuse_conv_bn)
  79. model = model.to(device).eval()
  80. logger.info("loading checkpoint done.")
  81. dummy_input = torch.randn(args.batch_size, 3, args.img_size, args.img_size)
  82. # save onnx file
  83. save_path = os.path.join(args.save_dir, str(args.opset))
  84. os.makedirs(save_path, exist_ok=True)
  85. output_name = os.path.join(save_path, args.output_name)
  86. torch.onnx._export(
  87. model,
  88. dummy_input,
  89. output_name,
  90. input_names=[args.input],
  91. output_names=[args.output],
  92. dynamic_axes={args.input: {0: 'batch'},
  93. args.output: {0: 'batch'}} if args.dynamic else None,
  94. opset_version=args.opset,
  95. )
  96. logger.info("generated onnx model named {}".format(output_name))
  97. if not args.no_onnxsim:
  98. import onnx
  99. from onnxsim import simplify
  100. input_shapes = {args.input: list(dummy_input.shape)} if args.dynamic else None
  101. # use onnxsimplify to reduce reduent model.
  102. onnx_model = onnx.load(output_name)
  103. model_simp, check = simplify(onnx_model,
  104. dynamic_input_shape=args.dynamic,
  105. input_shapes=input_shapes)
  106. assert check, "Simplified ONNX model could not be validated"
  107. # save onnxsim file
  108. save_path = os.path.join(save_path, 'onnxsim')
  109. os.makedirs(save_path, exist_ok=True)
  110. output_name = os.path.join(save_path, args.output_name)
  111. onnx.save(model_simp, output_name)
  112. logger.info("generated simplified onnx model named {}".format(output_name))
  113. if __name__ == "__main__":
  114. main()