export_onnx.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. # Copyright (c) Megvii, Inc. and its affiliates.
  4. # Thanks to YOLOX: https://github.com/Megvii-BaseDetection/YOLOX/blob/main/tools/export_onnx.py
  5. import argparse
  6. import os
  7. from loguru import logger
  8. import sys
  9. sys.path.append('..')
  10. import torch
  11. from torch import nn
  12. from utils.misc import SiLU
  13. from utils.misc import load_weight, replace_module
  14. from config import build_model_config
  15. from models.detectors import build_model
  16. def make_parser():
  17. parser = argparse.ArgumentParser("YOLO ONNXRuntime")
  18. # basic
  19. parser.add_argument('-size', '--img_size', default=640, type=int,
  20. help='the max size of input image')
  21. parser.add_argument("--input", default="images", type=str,
  22. help="input node name of onnx model")
  23. parser.add_argument("--output", default="output", type=str,
  24. help="output node name of onnx model")
  25. parser.add_argument("-o", "--opset", default=11, type=int,
  26. help="onnx opset version")
  27. parser.add_argument("--batch-size", type=int, default=1,
  28. help="batch size")
  29. parser.add_argument("--dynamic", action="store_true", default=False,
  30. help="whether the input shape should be dynamic or not")
  31. parser.add_argument("--no-onnxsim", action="store_true", default=False,
  32. help="use onnxsim or not")
  33. parser.add_argument("-f", "--exp_file", default=None, type=str,
  34. help="experiment description file")
  35. parser.add_argument("-expn", "--experiment-name", type=str, default=None)
  36. parser.add_argument("opts", default=None, nargs=argparse.REMAINDER,
  37. help="Modify config options using the command-line")
  38. parser.add_argument('--save_dir', default='../weights/onnx/', type=str,
  39. help='Dir to save onnx file')
  40. # model
  41. parser.add_argument('-m', '--model', default='yolov1', type=str,
  42. help='build yolo')
  43. parser.add_argument('--weight', default=None,
  44. type=str, help='Trained state_dict file path to open')
  45. parser.add_argument('-ct', '--conf_thresh', default=0.1, type=float,
  46. help='confidence threshold')
  47. parser.add_argument('-nt', '--nms_thresh', default=0.5, type=float,
  48. help='NMS threshold')
  49. parser.add_argument('--topk', default=100, type=int,
  50. help='topk candidates for testing')
  51. parser.add_argument('-nc', '--num_classes', default=80, type=int,
  52. help='topk candidates for testing')
  53. parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
  54. help='fuse Conv & BN')
  55. return parser
  56. @logger.catch
  57. def main():
  58. args = make_parser().parse_args()
  59. logger.info("args value: {}".format(args))
  60. device = torch.device('cpu')
  61. # Dataset & Model Config
  62. model_cfg = build_model_config(args)
  63. # build model
  64. model = build_model(args, model_cfg, device, args.num_classes, False, deploy=True)
  65. # replace nn.SiLU with SiLU
  66. model = replace_module(model, nn.SiLU, SiLU)
  67. # load trained weight
  68. model = load_weight(model, args.weight, args.fuse_conv_bn)
  69. model = model.to(device).eval()
  70. logger.info("loading checkpoint done.")
  71. dummy_input = torch.randn(args.batch_size, 3, args.img_size, args.img_size)
  72. # save onnx file
  73. save_path = os.path.join(args.save_dir, str(args.opset))
  74. os.makedirs(save_path, exist_ok=True)
  75. output_name = os.path.join(args.model + '.onnx')
  76. output_path = os.path.join(save_path, output_name)
  77. torch.onnx._export(
  78. model,
  79. dummy_input,
  80. output_path,
  81. input_names=[args.input],
  82. output_names=[output_name],
  83. dynamic_axes={args.input: {0: 'batch'},
  84. output_name: {0: 'batch'}} if args.dynamic else None,
  85. opset_version=args.opset,
  86. )
  87. logger.info("generated onnx model named {}".format(output_path))
  88. if not args.no_onnxsim:
  89. import onnx
  90. from onnxsim import simplify
  91. input_shapes = {args.input: list(dummy_input.shape)} if args.dynamic else None
  92. # use onnxsimplify to reduce reduent model.
  93. onnx_model = onnx.load(output_path)
  94. model_simp, check = simplify(onnx_model,
  95. dynamic_input_shape=args.dynamic,
  96. input_shapes=input_shapes)
  97. assert check, "Simplified ONNX model could not be validated"
  98. # save onnxsim file
  99. save_path = os.path.join(save_path, 'onnxsim')
  100. os.makedirs(save_path, exist_ok=True)
  101. output_path = os.path.join(save_path, output_name)
  102. onnx.save(model_simp, output_path)
  103. logger.info("generated simplified onnx model named {}".format(output_path))
  104. if __name__ == "__main__":
  105. main()