#!/usr/bin/env python3 """ 输入法模型ONNX导出脚本 将模型导出为两个ONNX部分: 1. context_encoder.onnx - 上下文编码器 2. decoder.onnx - 解码器 使用方法: python export_onnx.py --checkpoint ./output/checkpoints/best_model.pt --output-dir ./exported_models 依赖安装: pip install onnx onnxruntime """ import argparse import os import sys from pathlib import Path import torch import numpy as np # 检查ONNX是否可用 try: import onnx import onnxruntime as ort ONNX_AVAILABLE = True except ImportError: ONNX_AVAILABLE = False print("警告: ONNX或ONNX Runtime未安装") print("请运行: pip install onnx onnxruntime") # 添加src目录到路径 sys.path.insert(0, str(Path(__file__).parent)) from src.model.export_models import create_export_models_from_checkpoint def check_onnx_available(): """检查ONNX依赖是否可用""" if not ONNX_AVAILABLE: print("错误: ONNX导出需要以下依赖:") print(" pip install onnx onnxruntime") print("请安装后重试") return False return True def export_context_encoder(model, output_path, config): """ 导出上下文编码器为ONNX格式 Args: model: ContextEncoderExport实例 output_path: 输出路径 config: 模型配置 """ print(f"正在导出上下文编码器到: {output_path}") # 创建示例输入 - 使用batch_size=2以确保ONNX支持动态批处理 batch_size = 2 seq_len = config.get("max_seq_len", 128) pinyin_len = 24 # 固定长度 dim = config.get("dim", 512) input_ids = torch.randint(0, 100, (batch_size, seq_len), dtype=torch.long) pinyin_ids = torch.randint(0, 30, (batch_size, pinyin_len), dtype=torch.long) attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long) # 设置动态轴 dynamic_axes = { "input_ids": {0: "batch_size", 1: "seq_len"}, "pinyin_ids": {0: "batch_size"}, "attention_mask": {0: "batch_size", 1: "seq_len"}, "context_H": {0: "batch_size", 1: "seq_len"}, "pinyin_P": {0: "batch_size"}, "context_mask": {0: "batch_size", 1: "seq_len"}, "pinyin_mask": {0: "batch_size"}, } # 导出模型 torch.onnx.export( model, (input_ids, pinyin_ids, attention_mask), output_path, input_names=["input_ids", "pinyin_ids", "attention_mask"], output_names=["context_H", "pinyin_P", "context_mask", "pinyin_mask"], dynamic_axes=dynamic_axes, opset_version=18, # 支持现代操作符 do_constant_folding=True, verbose=False, ) print("✅ 上下文编码器导出完成") # 验证导出 try: onnx_model = onnx.load(output_path) onnx.checker.check_model(onnx_model) print("✅ ONNX模型验证通过") except Exception as e: print(f"⚠️ ONNX模型验证警告: {e}") return input_ids, pinyin_ids, attention_mask def export_decoder(model, output_path, config, example_inputs=None): """ 导出解码器为ONNX格式 Args: model: DecoderExport实例 output_path: 输出路径 config: 模型配置 example_inputs: 示例输入(用于验证一致性) """ print(f"正在导解码器到: {output_path}") # 创建示例输入 - 使用batch_size=2以确保ONNX支持动态批处理 batch_size = 2 seq_len = config.get("max_seq_len", 128) pinyin_len = 24 dim = config.get("dim", 512) num_slots = config.get("num_slots", 8) # 使用提供的示例输入或创建新的 if example_inputs is not None: context_H, pinyin_P, context_mask, pinyin_mask = example_inputs batch_size = context_H.size(0) # 使用实际batch size history_slot_ids = torch.randint( 0, 100, (batch_size, num_slots), dtype=torch.long ) else: context_H = torch.randn(batch_size, seq_len, dim, dtype=torch.float32) pinyin_P = torch.randn(batch_size, pinyin_len, dim, dtype=torch.float32) context_mask = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int32) pinyin_mask = torch.randint(0, 2, (batch_size, pinyin_len), dtype=torch.int32) history_slot_ids = torch.randint( 0, 100, (batch_size, num_slots), dtype=torch.long ) # 设置动态轴 dynamic_axes = { "context_H": {0: "batch_size", 1: "seq_len"}, "pinyin_P": {0: "batch_size"}, "history_slot_ids": {0: "batch_size"}, "context_mask": {0: "batch_size", 1: "seq_len"}, "pinyin_mask": {0: "batch_size"}, "logits": {0: "batch_size"}, } # 导出模型 torch.onnx.export( model, (context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask), output_path, input_names=[ "context_H", "pinyin_P", "history_slot_ids", "context_mask", "pinyin_mask", ], output_names=["logits"], dynamic_axes=dynamic_axes, opset_version=18, do_constant_folding=True, verbose=False, ) print("✅ 解码器导出完成") # 验证导出 try: onnx_model = onnx.load(output_path) onnx.checker.check_model(onnx_model) print("✅ ONNX模型验证通过") except Exception as e: print(f"⚠️ ONNX模型验证警告: {e}") return context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask def save_example_inputs(output_dir, example_inputs_dict): """ 保存示例输入为NPZ文件,用于验证 Args: output_dir: 输出目录 example_inputs_dict: 示例输入字典 """ npz_path = os.path.join(output_dir, "example_inputs.npz") # 转换为numpy数组 np_data = {} for key, tensor in example_inputs_dict.items(): if isinstance(tensor, torch.Tensor): np_data[key] = tensor.cpu().numpy() elif isinstance(tensor, tuple): for i, t in enumerate(tensor): if isinstance(t, torch.Tensor): np_data[f"{key}_{i}"] = t.cpu().numpy() np.savez(npz_path, **np_data) print(f"✅ 示例输入已保存到: {npz_path}") # 同时保存为PyTorch格式 torch_path = os.path.join(output_dir, "example_inputs.pt") torch.save(example_inputs_dict, torch_path) print(f"✅ PyTorch示例输入已保存到: {torch_path}") def create_inference_example(output_dir, config): """ 创建推理示例脚本 Args: output_dir: 输出目录 config: 模型配置 """ example_path = os.path.join(output_dir, "inference_example.py") example_code = '''#!/usr/bin/env python3 """ ONNX模型推理示例 展示如何使用导出的两个ONNX模型进行推理, 包括束搜索(beam search)算法。 """ import numpy as np import onnxruntime as ort import torch import torch.nn.functional as F from typing import List, Tuple class ONNXInference: """ONNX模型推理器""" def __init__(self, context_encoder_path, decoder_path): """ 初始化ONNX推理器 Args: context_encoder_path: 上下文编码器ONNX模型路径 decoder_path: 解码器ONNX模型路径 """ # 创建ONNX Runtime会话 self.context_encoder_session = ort.InferenceSession( context_encoder_path, providers=['CPUExecutionProvider'] # 或 'CUDAExecutionProvider' ) self.decoder_session = ort.InferenceSession( decoder_path, providers=['CPUExecutionProvider'] ) # 获取输入输出名称 self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()] self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()] self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()] self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()] print(f"上下文编码器输入: {self.context_input_names}") print(f"上下文编码器输出: {self.context_output_names}") print(f"解码器输入: {self.decoder_input_names}") print(f"解码器输出: {self.decoder_output_names}") def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128): """ 准备模型输入(与原始推理脚本保持一致) 注意: 这里需要实现文本到token的转换 为了简化示例,假设已经实现了相关函数 """ # 这里应该调用实际的预处理函数 # 返回: input_ids, pinyin_ids, attention_mask, history_slot_ids raise NotImplementedError("请实现实际的输入预处理") def run_context_encoder(self, input_ids, pinyin_ids, attention_mask): """ 运行上下文编码器 Args: input_ids: [batch, seq_len] pinyin_ids: [batch, 24] attention_mask: [batch, seq_len] Returns: context_H, pinyin_P, context_mask, pinyin_mask """ # 准备输入 inputs = { "input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids, "pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids, "attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask, } # 运行推理 outputs = self.context_encoder_session.run(self.context_output_names, inputs) # 解包输出 context_H, pinyin_P, context_mask, pinyin_mask = outputs return ( torch.from_numpy(context_H), torch.from_numpy(pinyin_P), torch.from_numpy(context_mask), torch.from_numpy(pinyin_mask), ) def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask): """ 运行解码器 Args: context_H: [batch, seq_len, 512] pinyin_P: [batch, 24, 512] history_slot_ids: [batch, 8] context_mask: [batch, seq_len] pinyin_mask: [batch, 24] Returns: logits: [batch, vocab_size] """ # 准备输入 inputs = { "context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H, "pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P, "history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids, "context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask, "pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask, } # 运行推理 outputs = self.decoder_session.run(self.decoder_output_names, inputs) # 解包输出 logits = outputs[0] return torch.from_numpy(logits) def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask, beam_size=5, max_length=10, vocab_size=10019): """ 束搜索算法示例 Args: context_H: 上下文编码 pinyin_P: 拼音编码 context_mask: 上下文掩码 pinyin_mask: 拼音掩码 beam_size: 束大小 max_length: 最大生成长度 vocab_size: 词汇表大小 Returns: 最佳序列列表 """ # 初始束:空序列,分数为0 beams = [([], 0.0)] # (序列, 对数概率) for step in range(max_length): new_beams = [] for seq, score in beams: # 构建history_slot_ids:已确认的字符ID if len(seq) < 8: history = seq + [0] * (8 - len(seq)) else: history = seq[-8:] # 只保留最近8个 history_tensor = torch.tensor([history], dtype=torch.long) # 运行解码器 logits = self.run_decoder( context_H, pinyin_P, history_tensor, context_mask, pinyin_mask ) # 获取概率 probs = F.softmax(logits[0], dim=-1) # 获取top-k候选 top_probs, top_indices = torch.topk(probs, beam_size) # 扩展束 for prob, idx in zip(top_probs, top_indices): new_seq = seq + [idx.item()] new_score = score + torch.log(prob).item() new_beams.append((new_seq, new_score)) # 剪枝:保留beam_size个最佳候选 new_beams.sort(key=lambda x: x[1], reverse=True) beams = new_beams[:beam_size] # 检查是否所有序列都已结束(以结束符0结尾) all_ended = all(seq[-1] == 0 for seq, _ in beams if seq) if all_ended: break return beams def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids): """ 单步预测 Args: input_ids: 输入token IDs pinyin_ids: 拼音IDs attention_mask: 注意力掩码 history_slot_ids: 历史槽位IDs Returns: 预测logits """ # 1. 运行上下文编码器 context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder( input_ids, pinyin_ids, attention_mask ) # 2. 运行解码器 logits = self.run_decoder( context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask ) return logits def main(): """示例主函数""" print("ONNX模型推理示例") print("=" * 60) # 初始化推理器 context_encoder_path = "context_encoder.onnx" decoder_path = "decoder.onnx" if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path): print("错误: 找不到ONNX模型文件") print("请先运行export_onnx.py导出模型") return inference = ONNXInference(context_encoder_path, decoder_path) print("✅ ONNX推理器初始化完成") print("请参考此示例实现完整的输入法推理流程") if __name__ == "__main__": main() ''' with open(example_path, "w", encoding="utf-8") as f: f.write(example_code) print(f"✅ 推理示例脚本已保存到: {example_path}") def main(): parser = argparse.ArgumentParser(description="输入法模型ONNX导出") parser.add_argument( "--checkpoint", "-c", type=str, required=True, help="模型checkpoint路径" ) parser.add_argument( "--output-dir", "-o", type=str, default="./exported_models", help="输出目录(默认: ./exported_models)", ) parser.add_argument( "--device", type=str, default="cpu", choices=["cpu", "cuda"], help="导出设备(默认: cpu)", ) parser.add_argument( "--skip-verification", action="store_true", help="跳过ONNX模型验证" ) args = parser.parse_args() # 检查依赖 if not check_onnx_available(): sys.exit(1) # 创建输出目录 output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) print(f"📁 输出目录: {output_dir.absolute()}") # 加载checkpoint并创建导出模型 print(f"📦 加载checkpoint: {args.checkpoint}") context_encoder_export, decoder_export, config = ( create_export_models_from_checkpoint(args.checkpoint, args.device) ) print(f"📊 模型配置: {config}") # 导出上下文编码器 context_encoder_path = output_dir / "context_encoder.onnx" example_inputs = export_context_encoder( context_encoder_export, str(context_encoder_path), config ) # 使用上下文编码器的输出作为解码器的示例输入 with torch.no_grad(): context_H, pinyin_P, context_mask, pinyin_mask = context_encoder_export( *example_inputs ) # 导出解码器 decoder_path = output_dir / "decoder.onnx" export_decoder( decoder_export, str(decoder_path), config, example_inputs=(context_H, pinyin_P, context_mask, pinyin_mask), ) # 保存示例输入 example_inputs_dict = { "input_ids": example_inputs[0], "pinyin_ids": example_inputs[1], "attention_mask": example_inputs[2], "context_H": context_H, "pinyin_P": pinyin_P, "context_mask": context_mask, "pinyin_mask": pinyin_mask, } save_example_inputs(output_dir, example_inputs_dict) # 创建推理示例脚本 create_inference_example(output_dir, config) print("\n" + "=" * 60) print("🎉 ONNX导出完成!") print("=" * 60) print(f"生成的模型文件:") print(f" - {context_encoder_path}") print(f" - {decoder_path}") print(f" - {output_dir}/example_inputs.npz") print(f" - {output_dir}/example_inputs.pt") print(f" - {output_dir}/inference_example.py") print("\n使用方法:") print(f" 1. 检查模型: python -m onnx.checker {context_encoder_path}") print(f" 2. 运行推理示例: cd {output_dir} && python inference_example.py") print(f" 3. 集成到您的应用: 参考inference_example.py中的ONNXInference类") print("\n注意:") print(" - 请确保安装了onnxruntime: pip install onnxruntime") print(" - GPU推理需要onnxruntime-gpu: pip install onnxruntime-gpu") print(" - 束搜索算法需要根据实际需求进行调整") if __name__ == "__main__": main()