|
|
@@ -14,15 +14,14 @@ from torch import nn
|
|
|
|
|
|
from utils.misc import SiLU
|
|
|
from utils.misc import load_weight, replace_module
|
|
|
-from config import build_config
|
|
|
+
|
|
|
+from config import build_model_config
|
|
|
from models.detectors import build_model
|
|
|
|
|
|
|
|
|
def make_parser():
|
|
|
parser = argparse.ArgumentParser("YOLO ONNXRuntime")
|
|
|
# basic
|
|
|
- parser.add_argument("--output-name", type=str, default="yolo_free_large.onnx",
|
|
|
- help="output name of models")
|
|
|
parser.add_argument('-size', '--img_size', default=640, type=int,
|
|
|
help='the max size of input image')
|
|
|
parser.add_argument("--input", default="images", type=str,
|
|
|
@@ -48,7 +47,7 @@ def make_parser():
|
|
|
help='Dir to save onnx file')
|
|
|
|
|
|
# model
|
|
|
- parser.add_argument('-v', '--version', default='yolo_free_large', type=str,
|
|
|
+ parser.add_argument('-m', '--model', default='yolov1', type=str,
|
|
|
help='build yolo')
|
|
|
parser.add_argument('--weight', default=None,
|
|
|
type=str, help='Trained state_dict file path to open')
|
|
|
@@ -72,17 +71,11 @@ def main():
|
|
|
logger.info("args value: {}".format(args))
|
|
|
device = torch.device('cpu')
|
|
|
|
|
|
- # config
|
|
|
- cfg = build_config(args)
|
|
|
+ # Dataset & Model Config
|
|
|
+ model_cfg = build_model_config(args)
|
|
|
|
|
|
# build model
|
|
|
- model = build_model(
|
|
|
- args=args,
|
|
|
- cfg=cfg,
|
|
|
- device=device,
|
|
|
- num_classes=args.num_classes,
|
|
|
- trainable=False
|
|
|
- )
|
|
|
+ model = build_model(args, model_cfg, device, args.num_classes, False, deploy=True)
|
|
|
|
|
|
# replace nn.SiLU with SiLU
|
|
|
model = replace_module(model, nn.SiLU, SiLU)
|
|
|
@@ -97,20 +90,21 @@ def main():
|
|
|
# save onnx file
|
|
|
save_path = os.path.join(args.save_dir, str(args.opset))
|
|
|
os.makedirs(save_path, exist_ok=True)
|
|
|
- output_name = os.path.join(save_path, args.output_name)
|
|
|
+ output_name = os.path.join(args.model + '.onnx')
|
|
|
+ output_path = os.path.join(save_path, output_name)
|
|
|
|
|
|
torch.onnx._export(
|
|
|
model,
|
|
|
dummy_input,
|
|
|
- output_name,
|
|
|
+ output_path,
|
|
|
input_names=[args.input],
|
|
|
- output_names=[args.output],
|
|
|
+ output_names=[output_name],
|
|
|
dynamic_axes={args.input: {0: 'batch'},
|
|
|
- args.output: {0: 'batch'}} if args.dynamic else None,
|
|
|
+ output_name: {0: 'batch'}} if args.dynamic else None,
|
|
|
opset_version=args.opset,
|
|
|
)
|
|
|
|
|
|
- logger.info("generated onnx model named {}".format(output_name))
|
|
|
+ logger.info("generated onnx model named {}".format(output_path))
|
|
|
|
|
|
if not args.no_onnxsim:
|
|
|
import onnx
|
|
|
@@ -120,7 +114,7 @@ def main():
|
|
|
input_shapes = {args.input: list(dummy_input.shape)} if args.dynamic else None
|
|
|
|
|
|
# use onnxsimplify to reduce reduent model.
|
|
|
- onnx_model = onnx.load(output_name)
|
|
|
+ onnx_model = onnx.load(output_path)
|
|
|
model_simp, check = simplify(onnx_model,
|
|
|
dynamic_input_shape=args.dynamic,
|
|
|
input_shapes=input_shapes)
|
|
|
@@ -129,9 +123,9 @@ def main():
|
|
|
# save onnxsim file
|
|
|
save_path = os.path.join(save_path, 'onnxsim')
|
|
|
os.makedirs(save_path, exist_ok=True)
|
|
|
- output_name = os.path.join(save_path, args.output_name)
|
|
|
- onnx.save(model_simp, output_name)
|
|
|
- logger.info("generated simplified onnx model named {}".format(output_name))
|
|
|
+ output_path = os.path.join(save_path, output_name)
|
|
|
+ onnx.save(model_simp, output_path)
|
|
|
+ logger.info("generated simplified onnx model named {}".format(output_path))
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|