SUimeModelTraner/export_onnx.py

87 lines
2.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
输入法模型ONNX导出脚本
将模型导出为两个ONNX部分
1. context_encoder.onnx - 上下文编码器
2. decoder.onnx - 解码器
使用方法:
python export_onnx.py --checkpoint ./output/checkpoints/best_model.pt --output-dir ./exported_models
依赖安装:
pip install onnx onnxruntime
"""
import argparse
import sys
from pathlib import Path
# Put this script's own directory first on sys.path so the project-local
# `src` package imported on the next line can be found even when the script
# is launched from a different working directory.
sys.path.insert(0, str(Path(__file__).parent))
from src.model.onnx_export import check_onnx_available, run_full_export
def _build_parser():
    """Build the command-line parser for the ONNX export script.

    Returns:
        An ``argparse.ArgumentParser`` accepting ``--checkpoint/-c``
        (required), ``--output-dir/-o``, ``--device`` and
        ``--skip-verification``.
    """
    parser = argparse.ArgumentParser(description="输入法模型ONNX导出")
    parser.add_argument(
        "--checkpoint", "-c", type=str, required=True, help="模型checkpoint路径"
    )
    parser.add_argument(
        "--output-dir",
        "-o",
        type=str,
        default="./exported_models",
        # The closing parenthesis was missing from this help text.
        help="输出目录(默认: ./exported_models)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="导出设备(默认: cpu)",
    )
    parser.add_argument(
        "--skip-verification", action="store_true", help="跳过ONNX模型验证"
    )
    return parser


def _print_summary(context_encoder_path, decoder_path, output_dir):
    """Print the generated file list and follow-up usage hints.

    Args:
        context_encoder_path: Context-encoder model path returned by
            ``run_full_export``.
        decoder_path: Decoder model path returned by ``run_full_export``.
        output_dir: ``Path`` of the export output directory.
    """
    print()
    print("=" * 60)
    print("ONNX导出完成")
    print("=" * 60)
    print("生成的模型文件:")
    print(f" - {context_encoder_path}")
    print(f" - {decoder_path}")
    # NOTE(review): these auxiliary files are listed unconditionally;
    # presumably run_full_export always writes them -- confirm.
    print(f" - {output_dir / 'example_inputs.npz'}")
    print(f" - {output_dir / 'example_inputs.pt'}")
    print(f" - {output_dir / 'inference_example.py'}")
    print()
    print("使用方法:")
    # `onnx.checker` is a library module with no command-line entry point,
    # so `python -m onnx.checker <file>` would exit without checking
    # anything. Call check_model explicitly instead. Paths are
    # double-quoted so directories containing spaces still work.
    check_cmd = (
        "python -c \"import onnx, sys; onnx.checker.check_model(sys.argv[1]); "
        "print('OK')\" "
        f"\"{context_encoder_path}\""
    )
    print(f" 1. 检查模型: {check_cmd}")
    print(f" 2. 运行推理示例: cd \"{output_dir}\" && python inference_example.py")
    print(" 3. 集成到您的应用: 参考 inference_example.py 中的 ONNXInference 类")
    print()
    print("注意:")
    print(" - 请确保安装了 onnxruntime: pip install onnxruntime")
    print(" - GPU推理需要 onnxruntime-gpu: pip install onnxruntime-gpu")
    print(" - MoE 层当前使用 'all' 模式(全量计算),稀疏化优化可后续迭代")


def main():
    """Parse CLI arguments, export the model to ONNX and print a summary.

    Exits with status 1 when ``check_onnx_available()`` returns a falsy
    value. Otherwise it delegates the export to ``run_full_export`` and
    prints the resulting file paths plus usage hints.
    """
    args = _build_parser().parse_args()
    if not check_onnx_available():
        # NOTE(review): nothing is printed here; presumably
        # check_onnx_available reports what is missing -- confirm.
        sys.exit(1)
    print(f"输出目录: {Path(args.output_dir).absolute()}")
    # The third value returned by run_full_export is not used by this script.
    context_encoder_path, decoder_path, _config = run_full_export(
        checkpoint_path=args.checkpoint,
        output_dir=args.output_dir,
        device=args.device,
        skip_verification=args.skip_verification,
    )
    _print_summary(context_encoder_path, decoder_path, Path(args.output_dir))
if __name__ == "__main__":
    # Script entry point: runs only when executed directly, never on import.
    # main() returns None, so the process exits with status 0 on success.
    sys.exit(main())