export_onnx.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. # Copyright (c) Megvii, Inc. and its affiliates.
  4. # Thanks to YOLOX: https://github.com/Megvii-BaseDetection/YOLOX/blob/main/tools/export_onnx.py
  5. import argparse
  6. import os
  7. from loguru import logger
  8. import sys
  9. sys.path.append('..')
  10. import torch
  11. from torch import nn
  12. from utils.misc import SiLU
  13. from utils.misc import load_weight, replace_module
  14. from config import build_model_config
  15. from models.detectors import build_model
  16. def make_parser():
  17. parser = argparse.ArgumentParser("YOLO ONNXRuntime")
  18. # basic
  19. parser.add_argument('-size', '--img_size', default=640, type=int,
  20. help='the max size of input image')
  21. parser.add_argument("--input", default="images", type=str,
  22. help="input node name of onnx model")
  23. parser.add_argument("--output", default="output", type=str,
  24. help="output node name of onnx model")
  25. parser.add_argument("-o", "--opset", default=11, type=int,
  26. help="onnx opset version")
  27. parser.add_argument("--batch-size", type=int, default=1,
  28. help="batch size")
  29. parser.add_argument("--dynamic", action="store_true", default=False,
  30. help="whether the input shape should be dynamic or not")
  31. parser.add_argument("--no-onnxsim", action="store_true", default=False,
  32. help="use onnxsim or not")
  33. parser.add_argument("-f", "--exp_file", default=None, type=str,
  34. help="experiment description file")
  35. parser.add_argument("-expn", "--experiment-name", type=str, default=None)
  36. parser.add_argument("opts", default=None, nargs=argparse.REMAINDER,
  37. help="Modify config options using the command-line")
  38. parser.add_argument('--save_dir', default='../weights/onnx/', type=str,
  39. help='Dir to save onnx file')
  40. # model
  41. parser.add_argument('-m', '--model', default='yolov1', type=str,
  42. help='build yolo')
  43. parser.add_argument('--weight', default=None,
  44. type=str, help='Trained state_dict file path to open')
  45. parser.add_argument('-ct', '--conf_thresh', default=0.1, type=float,
  46. help='confidence threshold')
  47. parser.add_argument('-nt', '--nms_thresh', default=0.5, type=float,
  48. help='NMS threshold')
  49. parser.add_argument('--topk', default=100, type=int,
  50. help='topk candidates for testing')
  51. parser.add_argument('-nc', '--num_classes', default=80, type=int,
  52. help='topk candidates for testing')
  53. parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
  54. help='fuse Conv & BN')
  55. parser.add_argument('--nms_class_agnostic', action='store_true', default=False,
  56. help='Perform NMS operations regardless of category.')
  57. return parser
  58. @logger.catch
  59. def main():
  60. args = make_parser().parse_args()
  61. logger.info("args value: {}".format(args))
  62. device = torch.device('cpu')
  63. # Dataset & Model Config
  64. model_cfg = build_model_config(args)
  65. # build model
  66. model = build_model(args, model_cfg, device, args.num_classes, False, deploy=True)
  67. # replace nn.SiLU with SiLU
  68. model = replace_module(model, nn.SiLU, SiLU)
  69. # load trained weight
  70. model = load_weight(model, args.weight, args.fuse_conv_bn)
  71. model = model.to(device).eval()
  72. logger.info("loading checkpoint done.")
  73. dummy_input = torch.randn(args.batch_size, 3, args.img_size, args.img_size)
  74. # save onnx file
  75. save_path = os.path.join(args.save_dir, str(args.opset))
  76. os.makedirs(save_path, exist_ok=True)
  77. output_name = os.path.join(args.model + '.onnx')
  78. output_path = os.path.join(save_path, output_name)
  79. torch.onnx._export(
  80. model,
  81. dummy_input,
  82. output_path,
  83. input_names=[args.input],
  84. output_names=[output_name],
  85. dynamic_axes={args.input: {0: 'batch'},
  86. output_name: {0: 'batch'}} if args.dynamic else None,
  87. opset_version=args.opset,
  88. )
  89. logger.info("generated onnx model named {}".format(output_path))
  90. if not args.no_onnxsim:
  91. import onnx
  92. from onnxsim import simplify
  93. input_shapes = {args.input: list(dummy_input.shape)} if args.dynamic else None
  94. # use onnxsimplify to reduce reduent model.
  95. onnx_model = onnx.load(output_path)
  96. model_simp, check = simplify(onnx_model,
  97. dynamic_input_shape=args.dynamic,
  98. input_shapes=input_shapes)
  99. assert check, "Simplified ONNX model could not be validated"
  100. # save onnxsim file
  101. save_path = os.path.join(save_path, 'onnxsim')
  102. os.makedirs(save_path, exist_ok=True)
  103. output_path = os.path.join(save_path, output_name)
  104. onnx.save(model_simp, output_path)
  105. logger.info("generated simplified onnx model named {}".format(output_path))
  106. if __name__ == "__main__":
  107. main()