export_onnx.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. # Copyright (c) Megvii, Inc. and its affiliates.
  4. # Thanks to YOLOX: https://github.com/Megvii-BaseDetection/YOLOX/blob/main/tools/export_onnx.py
  5. import argparse
  6. import os
  7. from loguru import logger
  8. import sys
  9. sys.path.append('..')
  10. import torch
  11. from torch import nn
  12. from utils.misc import SiLU
  13. from utils.misc import load_weight, replace_module
  14. from config import build_config
  15. from models import build_model
  16. def make_parser():
  17. parser = argparse.ArgumentParser("FreeYOLO ONNXRuntime")
  18. # basic
  19. parser.add_argument('--img_size', default=640, type=int,
  20. help='the max size of input image')
  21. parser.add_argument("--input", default="images", type=str,
  22. help="input node name of onnx model")
  23. parser.add_argument("--output", default="output", type=str,
  24. help="output node name of onnx model")
  25. parser.add_argument("--opset", default=13, type=int,
  26. help="onnx opset version")
  27. parser.add_argument("--batch-size", type=int, default=1,
  28. help="batch size")
  29. parser.add_argument("--dynamic", action="store_true", default=False,
  30. help="whether the input shape should be dynamic or not")
  31. parser.add_argument("--no-onnxsim", action="store_true", default=False,
  32. help="use onnxsim or not")
  33. parser.add_argument("-f", "--exp_file", default=None, type=str,
  34. help="experiment description file")
  35. parser.add_argument("-expn", "--experiment-name", type=str, default=None)
  36. parser.add_argument("opts", default=None, nargs=argparse.REMAINDER,
  37. help="Modify config options using the command-line")
  38. # model
  39. parser.add_argument('--model', default='yolov8_n', type=str,
  40. help='build FreeYOLOv2')
  41. parser.add_argument('--weight', default=None,
  42. type=str, help='Trained state_dict file path to open')
  43. parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
  44. help='fuse Conv & BN')
  45. return parser
  46. @logger.catch
  47. def main():
  48. args = make_parser().parse_args()
  49. logger.info("args value: {}".format(args))
  50. # Build config
  51. cfg = build_config(args)
  52. cfg.num_classes = 80 # for coco
  53. # Build model
  54. model = build_model(args, cfg, is_val=False)
  55. # Load trained weight
  56. model = load_weight(model, args.weight, args.fuse_conv_bn)
  57. model.eval()
  58. logger.info(" => loading checkpoint done.")
  59. dummy_input = torch.randn(args.batch_size, 3, args.img_size, args.img_size)
  60. # save onnx file
  61. save_path = os.path.join(os.path.split(args.weight)[0], str(args.opset))
  62. os.makedirs(save_path, exist_ok=True)
  63. output_name = os.path.join(args.model + '.onnx')
  64. output_path = os.path.join(save_path, output_name)
  65. torch.onnx._export(
  66. model,
  67. dummy_input,
  68. output_path,
  69. input_names=[args.input],
  70. output_names=[output_name],
  71. dynamic_axes={args.input: {0: 'batch'},
  72. output_name: {0: 'batch'}} if args.dynamic else None,
  73. opset_version=args.opset,
  74. )
  75. logger.info("generated onnx model named {}".format(output_path))
  76. if not args.no_onnxsim:
  77. import onnx
  78. from onnxsim import simplify
  79. input_shapes = {args.input: list(dummy_input.shape)} if args.dynamic else None
  80. # use onnxsimplify to reduce reduent model.
  81. onnx_model = onnx.load(output_path)
  82. model_simp, check = simplify(onnx_model,
  83. dynamic_input_shape=args.dynamic,
  84. input_shapes=input_shapes)
  85. assert check, "Simplified ONNX model could not be validated"
  86. # save onnxsim file
  87. save_path = os.path.join(save_path, 'onnxsim')
  88. os.makedirs(save_path, exist_ok=True)
  89. output_path = os.path.join(save_path, output_name)
  90. onnx.save(model_simp, output_path)
  91. logger.info("generated simplified onnx model named {}".format(output_path))
  92. if __name__ == "__main__":
  93. main()