#!/usr/bin/env python3
"""
|
||
输入法模型ONNX导出脚本
|
||
|
||
将模型导出为两个ONNX部分:
|
||
1. context_encoder.onnx - 上下文编码器
|
||
2. decoder.onnx - 解码器
|
||
|
||
使用方法:
|
||
python export_onnx.py --checkpoint ./output/checkpoints/best_model.pt --output-dir ./exported_models
|
||
|
||
依赖安装:
|
||
pip install onnx onnxruntime
|
||
"""
|
||
|
||
import argparse
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
# Prepend the directory containing this script to the module search path;
# presumably so the project-local `src` package imported below resolves no
# matter which working directory the script is launched from.
sys.path.insert(0, str(Path(__file__).parent))
|
||
|
||
from src.model.onnx_export import check_onnx_available, run_full_export
|
||
|
||
|
||
def main(argv=None) -> None:
    """Parse command-line options, run the ONNX export and print a summary.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. The default ``None`` keeps the normal
            command-line behaviour, so existing ``main()`` callers are
            unaffected.

    Raises:
        SystemExit: With status 2 when argparse rejects the arguments, and
            with status 1 when ``check_onnx_available()`` returns a falsy
            value.
    """
    parser = argparse.ArgumentParser(description="输入法模型ONNX导出")
    parser.add_argument(
        "--checkpoint", "-c", type=str, required=True, help="模型checkpoint路径"
    )
    parser.add_argument(
        "--output-dir",
        "-o",
        type=str,
        default="./exported_models",
        help="输出目录(默认: ./exported_models)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="导出设备(默认: cpu)",
    )
    parser.add_argument(
        "--skip-verification", action="store_true", help="跳过ONNX模型验证"
    )

    args = parser.parse_args(argv)

    # Bail out before doing any work if the ONNX availability check fails.
    if not check_onnx_available():
        sys.exit(1)

    # Build the Path once; it is reused for the summary printed below.
    output_dir = Path(args.output_dir)
    print(f"输出目录: {output_dir.absolute()}")

    # The third element of the returned tuple is not needed for the summary.
    context_encoder_path, decoder_path, _config = run_full_export(
        checkpoint_path=args.checkpoint,
        output_dir=args.output_dir,
        device=args.device,
        skip_verification=args.skip_verification,
    )

    print()
    print("=" * 60)
    print("ONNX导出完成!")
    print("=" * 60)
    print("生成的模型文件:")
    print(f" - {context_encoder_path}")
    print(f" - {decoder_path}")
    print(f" - {output_dir / 'example_inputs.npz'}")
    print(f" - {output_dir / 'example_inputs.pt'}")
    print(f" - {output_dir / 'inference_example.py'}")
    print()
    print("使用方法:")
    # NOTE(review): confirm that `python -m onnx.checker <path>` really
    # validates a model file; the message is kept exactly as it was.
    print(f" 1. 检查模型: python -m onnx.checker {context_encoder_path}")
    print(f" 2. 运行推理示例: cd {output_dir} && python inference_example.py")
    print(" 3. 集成到您的应用: 参考 inference_example.py 中的 ONNXInference 类")
    print()
    print("注意:")
    print(" - 请确保安装了 onnxruntime: pip install onnxruntime")
    print(" - GPU推理需要 onnxruntime-gpu: pip install onnxruntime-gpu")
    print(" - MoE 层当前使用 'all' 模式(全量计算),稀疏化优化可后续迭代")
|
||
|
||
|
||
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|