#!/usr/bin/env python3 """ 输入法模型ONNX导出脚本 将模型导出为两个ONNX部分: 1. context_encoder.onnx - 上下文编码器 2. decoder.onnx - 解码器 使用方法: python export_onnx.py --checkpoint ./output/checkpoints/best_model.pt --output-dir ./exported_models 依赖安装: pip install onnx onnxruntime """ import argparse import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).parent)) from src.model.onnx_export import check_onnx_available, run_full_export def main(): parser = argparse.ArgumentParser(description="输入法模型ONNX导出") parser.add_argument( "--checkpoint", "-c", type=str, required=True, help="模型checkpoint路径" ) parser.add_argument( "--output-dir", "-o", type=str, default="./exported_models", help="输出目录(默认: ./exported_models)", ) parser.add_argument( "--device", type=str, default="cpu", choices=["cpu", "cuda"], help="导出设备(默认: cpu)", ) parser.add_argument( "--skip-verification", action="store_true", help="跳过ONNX模型验证" ) args = parser.parse_args() if not check_onnx_available(): sys.exit(1) print(f"输出目录: {Path(args.output_dir).absolute()}") context_encoder_path, decoder_path, config = run_full_export( checkpoint_path=args.checkpoint, output_dir=args.output_dir, device=args.device, skip_verification=args.skip_verification, ) output_dir = Path(args.output_dir) print() print("=" * 60) print("ONNX导出完成!") print("=" * 60) print("生成的模型文件:") print(f" - {context_encoder_path}") print(f" - {decoder_path}") print(f" - {output_dir / 'example_inputs.npz'}") print(f" - {output_dir / 'example_inputs.pt'}") print(f" - {output_dir / 'inference_example.py'}") print() print("使用方法:") print(f" 1. 检查模型: python -m onnx.checker {context_encoder_path}") print(f" 2. 运行推理示例: cd {output_dir} && python inference_example.py") print(f" 3. 集成到您的应用: 参考 inference_example.py 中的 ONNXInference 类") print() print("注意:") print(" - 请确保安装了 onnxruntime: pip install onnxruntime") print(" - GPU推理需要 onnxruntime-gpu: pip install onnxruntime-gpu") print(" - MoE 层当前使用 'all' 模式(全量计算),稀疏化优化可后续迭代") if __name__ == "__main__": main()