diff --git a/.gitignore b/.gitignore index b210324..3cb1118 100644 --- a/.gitignore +++ b/.gitignore @@ -177,3 +177,9 @@ cython_debug/ uv.lock data/* + + +**/*.onnx +**/*.data +**/*.npz +**/*.pt diff --git a/beam_search_demo.py b/beam_search_demo.py new file mode 100644 index 0000000..6f73536 --- /dev/null +++ b/beam_search_demo.py @@ -0,0 +1,436 @@ +#!/usr/bin/env python3 +""" +束搜索算法演示 + +展示如何使用导出的两个ONNX模型进行束搜索推理, +模拟输入法场景:给定上下文、拼音和已确认字符,生成候选汉字序列。 +""" + +import argparse +import os +import sys +from pathlib import Path +from typing import List, Tuple, Dict, Any +import numpy as np +import onnxruntime as ort +import torch +import torch.nn.functional as F + +# 添加src目录到路径 +sys.path.insert(0, str(Path(__file__).parent)) + + +class ONNXBeamSearch: + """ + 基于ONNX模型的束搜索解码器 + """ + + def __init__( + self, context_encoder_path: str, decoder_path: str, device: str = "cpu" + ): + """ + 初始化 + + Args: + context_encoder_path: 上下文编码器ONNX模型路径 + decoder_path: 解码器ONNX模型路径 + device: 推理设备(cpu或cuda) + """ + self.device = device + + # 配置ONNX Runtime提供程序 + if device == "cuda": + providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] + else: + providers = ["CPUExecutionProvider"] + + # 创建ONNX Runtime会话 + self.context_session = ort.InferenceSession( + context_encoder_path, providers=providers + ) + self.decoder_session = ort.InferenceSession(decoder_path, providers=providers) + + # 获取输入输出信息 + self.context_input_info = { + input.name: input for input in self.context_session.get_inputs() + } + self.context_output_info = { + output.name: output for output in self.context_session.get_outputs() + } + self.decoder_input_info = { + input.name: input for input in self.decoder_session.get_inputs() + } + self.decoder_output_info = { + output.name: output for output in self.decoder_session.get_outputs() + } + + print("📦 ONNX模型加载完成") + print(f" 上下文编码器输入: {list(self.context_input_info.keys())}") + print(f" 上下文编码器输出: {list(self.context_output_info.keys())}") + print(f" 解码器输入: {list(self.decoder_input_info.keys())}") + print(f" 解码器输出: {list(self.decoder_output_info.keys())}") + + def prepare_inputs( + self, + text_before: str, + text_after: str, + pinyin: str, + context_prompts: List[str] = None, + tokenizer=None, + query_engine=None, + max_seq_len: int = 128, + ) -> Dict[str, np.ndarray]: + """ + 准备模型输入 + + 注意: 这是一个简化的版本,实际应用中需要实现完整的预处理逻辑 + 这里使用随机数据模拟,实际应使用tokenizer和查询引擎 + + Args: + text_before: 光标前文本 + text_after: 光标后文本 + pinyin: 拼音输入 + context_prompts: 上下文提示 + tokenizer: 分词器(未实现) + query_engine: 查询引擎(未实现) + max_seq_len: 最大序列长度 + + Returns: + 输入字典,包含input_ids, pinyin_ids, attention_mask + """ + # 在实际应用中,这里应该实现: + # 1. 使用tokenizer将文本转换为input_ids + # 2. 使用text_to_pinyin_ids将拼音转换为pinyin_ids + # 3. 构建attention_mask + + # 简化为随机数据(用于演示) + batch_size = 1 + seq_len = min(max_seq_len, 64) # 简化长度 + + # 模拟输入 + input_ids = np.random.randint(0, 1000, (batch_size, seq_len), dtype=np.int64) + pinyin_ids = np.random.randint( + 0, 30, (batch_size, 24), dtype=np.int64 + ) # 固定24长度 + attention_mask = np.ones((batch_size, seq_len), dtype=np.int64) + + # 在实际应用中,应该这样处理拼音: + # from src.model.dataset import text_to_pinyin_ids + # pinyin_ids_list = text_to_pinyin_ids(pinyin) + # pinyin_ids_array = np.array(pinyin_ids_list, dtype=np.int64).reshape(1, -1) + + return { + "input_ids": input_ids, + "pinyin_ids": pinyin_ids, + "attention_mask": attention_mask, + } + + def run_context_encoder( + self, inputs: Dict[str, np.ndarray] + ) -> Tuple[np.ndarray, ...]: + """ + 运行上下文编码器 + + Args: + inputs: 输入字典 + + Returns: + 上下文编码器输出: (context_H, pinyin_P, context_mask, pinyin_mask) + """ + # 准备ONNX输入(确保顺序正确) + onnx_inputs = {} + for input_name in self.context_input_info.keys(): + if input_name in inputs: + onnx_inputs[input_name] = inputs[input_name] + else: + # 对于缺失的输入,使用默认值 + shape = self.context_input_info[input_name].shape + dtype = self.context_input_info[input_name].type + # 简化处理:创建零数组 + if "int" in dtype: + onnx_inputs[input_name] = np.zeros(shape, dtype=np.int64) + else: + onnx_inputs[input_name] = np.zeros(shape, dtype=np.float32) + + # 运行推理 + outputs = self.context_session.run(None, onnx_inputs) + + return tuple(outputs) + + def run_decoder( + self, + context_H: np.ndarray, + pinyin_P: np.ndarray, + history_slot_ids: np.ndarray, + context_mask: np.ndarray, + pinyin_mask: np.ndarray, + ) -> np.ndarray: + """ + 运行解码器 + + Args: + context_H: 上下文编码 [batch, seq_len, 512] + pinyin_P: 拼音编码 [batch, 24, 512] + history_slot_ids: 历史槽位IDs [batch, 8] + context_mask: 上下文掩码 [batch, seq_len] + pinyin_mask: 拼音掩码 [batch, 24] + + Returns: + logits [batch, vocab_size] + """ + inputs = { + "context_H": context_H, + "pinyin_P": pinyin_P, + "history_slot_ids": history_slot_ids, + "context_mask": context_mask, + "pinyin_mask": pinyin_mask, + } + + # 运行推理 + outputs = self.decoder_session.run(None, inputs) + + return outputs[0] + + def beam_search( + self, + context_H: np.ndarray, + pinyin_P: np.ndarray, + context_mask: np.ndarray, + pinyin_mask: np.ndarray, + beam_size: int = 5, + max_length: int = 8, + vocab_size: int = 10019, + temperature: float = 1.0, + length_penalty: float = 0.6, + ) -> List[Tuple[List[int], float]]: + """ + 束搜索算法 + + Args: + context_H: 上下文编码 + pinyin_P: 拼音编码 + context_mask: 上下文掩码 + pinyin_mask: 拼音掩码 + beam_size: 束大小 + max_length: 最大生成长度 + vocab_size: 词汇表大小 + temperature: 温度参数(控制随机性) + length_penalty: 长度惩罚系数 + + Returns: + 排序后的候选序列列表,每个元素为(序列, 分数) + """ + # 初始束:空序列,对数概率为0 + beams = [([], 0.0)] # (序列, 累计对数概率) + + for step in range(max_length): + new_beams = [] + + for seq, score in beams: + # 构建history_slot_ids + if len(seq) < 8: + # 填充到8个槽位 + history = seq + [0] * (8 - len(seq)) + else: + # 只保留最近8个字符 + history = seq[-8:] + + history_array = np.array([history], dtype=np.int64) + + # 运行解码器 + logits = self.run_decoder( + context_H, pinyin_P, history_array, context_mask, pinyin_mask + ) + + # 应用温度 + logits = logits / temperature + + # 转换为概率 + probs = F.softmax(torch.from_numpy(logits[0]), dim=-1).numpy() + + # 获取top-k候选(k = beam_size * 2 以增加多样性) + top_k = min(beam_size * 2, vocab_size) + top_indices = np.argsort(probs)[-top_k:][::-1] + top_probs = probs[top_indices] + + # 扩展束 + for idx, prob in zip(top_indices, top_probs): + new_seq = seq + [int(idx)] + # 更新分数:累计对数概率 + log(prob) + new_score = score + np.log(prob + 1e-10) + + # 应用长度惩罚 + length_penalized_score = new_score / ( + (5 + len(new_seq)) ** length_penalty + ) + + new_beams.append((new_seq, length_penalized_score, new_score)) + + # 剪枝:保留beam_size个最佳候选 + new_beams.sort(key=lambda x: x[1], reverse=True) # 按惩罚后分数排序 + beams = [(seq, orig_score) for seq, _, orig_score in new_beams[:beam_size]] + + # 检查是否所有序列都已结束(结束符ID为0) + all_ended = all(seq[-1] == 0 for seq, _ in beams if len(seq) > 0) + if all_ended: + break + + # 返回原始分数(未应用长度惩罚) + return beams + + def interactive_beam_search(self): + """交互式束搜索演示""" + print("\n" + "=" * 60) + print("束搜索交互演示") + print("=" * 60) + + print("\n📝 输入法场景模拟:") + print(" 假设用户正在输入拼音 'nihao',已确认第一个字 '你'") + print(" 上下文: 光标前文本 '今天天气很好',光标后文本 '我们去公园玩'") + print(" 上下文提示: '张三,李四'(模型不掌握的专有名词)") + print("-" * 60) + + # 模拟输入(实际应用中应从用户获取) + text_before = "今天天气很好" + text_after = "我们去公园玩" + pinyin = "hao" # 继续输入"hao" + slot_chars = ["你"] # 已确认的字符 + context_prompts = ["张三", "李四"] + + print(f"\n📋 输入参数:") + print(f" 光标前文本: '{text_before}'") + print(f" 光标后文本: '{text_after}'") + print(f" 拼音: '{pinyin}'") + print(f" 槽位历史: {slot_chars}") + print(f" 上下文提示: {context_prompts}") + + # 准备输入(简化版,使用随机数据) + print(f"\n🔧 准备模型输入...") + inputs = self.prepare_inputs( + text_before=text_before, + text_after=text_after, + pinyin=pinyin, + context_prompts=context_prompts, + ) + + # 运行上下文编码器 + print(f"🧠 运行上下文编码器...") + context_outputs = self.run_context_encoder(inputs) + context_H, pinyin_P, context_mask, pinyin_mask = context_outputs + + print(f"✅ 上下文编码完成") + print(f" context_H形状: {context_H.shape}") + print(f" pinyin_P形状: {pinyin_P.shape}") + + # 将槽位历史转换为ID(简化:使用随机ID) + # 实际应用中应使用query_engine将汉字转换为ID + slot_ids = [42] if slot_chars else [0] # 假设'你'的ID是42 + if len(slot_ids) < 8: + slot_ids = slot_ids + [0] * (8 - len(slot_ids)) + + # 运行束搜索 + print(f"\n🔍 运行束搜索 (beam_size=3, max_length=4)...") + beams = self.beam_search( + context_H, + pinyin_P, + context_mask, + pinyin_mask, + beam_size=3, + max_length=4, + vocab_size=10019, + ) + + # 显示结果 + print(f"\n🏆 束搜索结果:") + print("-" * 50) + for i, (seq, score) in enumerate(beams): + # 将ID序列转换为汉字(简化:显示ID) + seq_str = " ".join([f"ID:{id}" if id != 0 else "END" for id in seq]) + print(f"{i + 1}. 序列: [{seq_str}]") + print(f" 对数概率: {score:.4f}") + print(f" 概率: {np.exp(score):.6f}") + print() + + print("📝 说明:") + print(" - 'END' 表示结束符 (ID: 0)") + print(" - 实际应用中应将ID转换为汉字") + print(" - 最高分数的序列作为最终预测") + + return beams + + +def main(): + parser = argparse.ArgumentParser(description="束搜索算法演示") + parser.add_argument( + "--context-encoder", + type=str, + default="./exported_models/context_encoder.onnx", + help="上下文编码器ONNX路径(默认: ./exported_models/context_encoder.onnx)", + ) + parser.add_argument( + "--decoder", + type=str, + default="./exported_models/decoder.onnx", + help="解码器ONNX路径(默认: ./exported_models/decoder.onnx)", + ) + parser.add_argument( + "--device", + type=str, + default="cpu", + choices=["cpu", "cuda"], + help="推理设备(默认: cpu)", + ) + parser.add_argument("--beam-size", type=int, default=3, help="束大小(默认: 3)") + parser.add_argument( + "--max-length", type=int, default=4, help="最大生成长度(默认: 4)" + ) + parser.add_argument( + "--interactive", + action="store_true", + default=True, + help="交互模式(默认: True)", + ) + + args = parser.parse_args() + + # 检查模型文件是否存在 + if not os.path.exists(args.context_encoder): + print(f"❌ 上下文编码器文件不存在: {args.context_encoder}") + print(f" 请先运行 export_onnx.py 导出模型") + return + + if not os.path.exists(args.decoder): + print(f"❌ 解码器文件不存在: {args.decoder}") + print(f" 请先运行 export_onnx.py 导出模型") + return + + print("🚀 束搜索算法演示") + print("=" * 60) + print(f"上下文编码器: {args.context_encoder}") + print(f"解码器: {args.decoder}") + print(f"设备: {args.device}") + print(f"束大小: {args.beam_size}") + print(f"最大长度: {args.max_length}") + + # 初始化束搜索器 + beam_searcher = ONNXBeamSearch( + context_encoder_path=args.context_encoder, + decoder_path=args.decoder, + device=args.device, + ) + + if args.interactive: + # 运行交互演示 + beam_searcher.interactive_beam_search() + + print("\n" + "=" * 60) + print("🎉 演示完成") + print("\n💡 下一步:") + print(" 1. 实现完整的输入预处理(tokenizer和拼音转换)") + print(" 2. 集成查询引擎以将ID转换为汉字") + print(" 3. 根据实际场景调整束搜索参数") + print(" 4. 性能优化:批量处理、缓存等") + + +if __name__ == "__main__": + main() diff --git a/export.record b/export.record new file mode 100644 index 0000000..902372c --- /dev/null +++ b/export.record @@ -0,0 +1,112 @@ +📁 输出目录: /home/songsenand/Project/SUimeModelTraner/exported_models +📦 加载checkpoint: /home/songsenand/下载/20260416best_model.pt +Downloading Model from https://www.modelscope.cn to directory: /home/songsenand/.cache/modelscope/hub/models/iic/nlp_structbert_backbone_lite_std + Loading weights: 0%| | 0/103 [00:00=18 to leverage latest ONNX features +W0417 14:32:00.100000 710675 .venv/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_registration.py:110] torchvision is not installed. Skipping torchvision::nms +W0417 14:32:00.101000 710675 .venv/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_registration.py:110] torchvision is not installed. Skipping torchvision::roi_align +W0417 14:32:00.101000 710675 .venv/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_registration.py:110] torchvision is not installed. Skipping torchvision::roi_pool +/home/songsenand/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/lib/python3.12/contextlib.py:144: UserWarning: The tensor attributes self.pinyin_lstm._flat_weights[0], self.pinyin_lstm._flat_weights[1], self.pinyin_lstm._flat_weights[2], self.pinyin_lstm._flat_weights[3], self.pinyin_lstm._flat_weights[4], self.pinyin_lstm._flat_weights[5], self.pinyin_lstm._flat_weights[6], self.pinyin_lstm._flat_weights[7], self.pinyin_lstm._flat_weights[8], self.pinyin_lstm._flat_weights[9], self.pinyin_lstm._flat_weights[10], self.pinyin_lstm._flat_weights[11], self.pinyin_lstm._flat_weights[12], self.pinyin_lstm._flat_weights[13], self.pinyin_lstm._flat_weights[14], self.pinyin_lstm._flat_weights[15] were assigned during export. Such attributes must be registered as buffers using the `register_buffer` API (https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer). + next(self.gen) +/home/songsenand/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead. + return cls.__new__(cls, *args) +The model version conversion is not supported by the onnxscript version converter and fallback is enabled. The model will be converted using the onnx C API (target version: 14). +Failed to convert the model to the target version 14 using the ONNX C API. The model was not modified +Traceback (most recent call last): + File "/home/songsenand/Project/SUimeModelTraner/.venv/lib/python3.12/site-packages/onnxscript/version_converter/__init__.py", line 120, in call + converted_proto = _c_api_utils.call_onnx_api( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/home/songsenand/Project/SUimeModelTraner/.venv/lib/python3.12/site-packages/onnxscript/version_converter/_c_api_utils.py", line 65, in call_onnx_api + result = func(proto) + ^^^^^^^^^^^ + File "/home/songsenand/Project/SUimeModelTraner/.venv/lib/python3.12/site-packages/onnxscript/version_converter/__init__.py", line 115, in _partial_convert_version + return onnx.version_converter.convert_version( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/home/songsenand/Project/SUimeModelTraner/.venv/lib/python3.12/site-packages/onnx/version_converter.py", line 39, in convert_version + converted_model_str = C.convert_version(model_str, target_version) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +RuntimeError: /github/workspace/onnx/version_converter/adapters/no_previous_version.h:24: adapt: Assertion `false` failed: No Previous Version of LayerNormalization exists +/home/songsenand/Project/SUimeModelTraner/.venv/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_onnx_program.py:487: UserWarning: # The axis name: batch_size will not be used, since it shares the same shape constraints with another axis: batch_size. + rename_mapping = _dynamic_shapes.create_rename_mapping( +/home/songsenand/Project/SUimeModelTraner/.venv/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_onnx_program.py:487: UserWarning: # The axis name: seq_len will not be used, since it shares the same shape constraints with another axis: seq_len. + rename_mapping = _dynamic_shapes.create_rename_mapping( +/home/songsenand/Project/SUimeModelTraner/.venv/lib/python3.12/site-packages/torch/nn/modules/transformer.py:531: UserWarning: The PyTorch API of nested tensors is in prototype stage and will change in the near future. We recommend specifying layout=torch.jagged when constructing a nested tensor, as this layout receives active development, has better operator coverage, and works with torch.compile. (Triggered internally at /pytorch/aten/src/ATen/NestedTensorImpl.cpp:178.) + output = torch._nested_tensor_from_mask( +/home/songsenand/Project/SUimeModelTraner/export_onnx.py:155: UserWarning: # 'dynamic_axes' is not recommended when dynamo=True, and may lead to 'torch._dynamo.exc.UserError: Constraints violated.' Supply the 'dynamic_shapes' argument instead if export is unsuccessful. + torch.onnx.export( +W0417 14:32:05.304000 710675 .venv/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_compat.py:133] Setting ONNX exporter to use operator set version 18 because the requested opset_version 14 is a lower version than we have implementations for. Automatic version conversion will be performed, which may not be successful at converting to the requested version. If version conversion is unsuccessful, the opset version of the exported model will be kept at 18. Please consider setting opset_version >=18 to leverage latest ONNX features +W0417 14:32:05.492000 710675 .venv/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_registration.py:110] torchvision is not installed. Skipping torchvision::nms +W0417 14:32:05.492000 710675 .venv/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_registration.py:110] torchvision is not installed. Skipping torchvision::roi_align +W0417 14:32:05.492000 710675 .venv/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_registration.py:110] torchvision is not installed. Skipping torchvision::roi_pool +/home/songsenand/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead. + return cls.__new__(cls, *args) +The model version conversion is not supported by the onnxscript version converter and fallback is enabled. The model will be converted using the onnx C API (target version: 14). +Failed to convert the model to the target version 14 using the ONNX C API. The model was not modified +Traceback (most recent call last): + File "/home/songsenand/Project/SUimeModelTraner/.venv/lib/python3.12/site-packages/onnxscript/version_converter/__init__.py", line 120, in call + converted_proto = _c_api_utils.call_onnx_api( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/home/songsenand/Project/SUimeModelTraner/.venv/lib/python3.12/site-packages/onnxscript/version_converter/_c_api_utils.py", line 65, in call_onnx_api + result = func(proto) + ^^^^^^^^^^^ + File "/home/songsenand/Project/SUimeModelTraner/.venv/lib/python3.12/site-packages/onnxscript/version_converter/__init__.py", line 115, in _partial_convert_version + return onnx.version_converter.convert_version( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/home/songsenand/Project/SUimeModelTraner/.venv/lib/python3.12/site-packages/onnx/version_converter.py", line 39, in convert_version + converted_model_str = C.convert_version(model_str, target_version) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +RuntimeError: /github/workspace/onnx/version_converter/adapters/no_previous_version.h:24: adapt: Assertion `false` failed: No Previous Version of LayerNormalization exists +/home/songsenand/Project/SUimeModelTraner/.venv/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_onnx_program.py:487: UserWarning: # The axis name: batch_size will not be used, since it shares the same shape constraints with another axis: batch_size. + rename_mapping = _dynamic_shapes.create_rename_mapping( +/home/songsenand/Project/SUimeModelTraner/.venv/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_onnx_program.py:487: UserWarning: # The axis name: seq_len will not be used, since it shares the same shape constraints with another axis: seq_len. + rename_mapping = _dynamic_shapes.create_rename_mapping( +📊 模型配置: {'learning_rate': 1e-06, 'weight_decay': 0.05, 'warmup_ratio': 0.1, 'label_smoothing': 0.1, 'total_steps': 781250} +正在导出上下文编码器到: exported_models/context_encoder.onnx +Applied 77 of general pattern rewrite rules. +✅ 上下文编码器导出完成 +✅ ONNX模型验证通过 +正在导解码器到: exported_models/decoder.onnx +Applied 21 of general pattern rewrite rules. +✅ 解码器导出完成 +✅ ONNX模型验证通过 +✅ 示例输入已保存到: exported_models/example_inputs.npz +✅ PyTorch示例输入已保存到: exported_models/example_inputs.pt +✅ 推理示例脚本已保存到: exported_models/inference_example.py + +============================================================ +🎉 ONNX导出完成! +============================================================ +生成的模型文件: + - exported_models/context_encoder.onnx + - exported_models/decoder.onnx + - exported_models/example_inputs.npz + - exported_models/example_inputs.pt + - exported_models/inference_example.py + +使用方法: + 1. 检查模型: python -m onnx.checker exported_models/context_encoder.onnx + 2. 运行推理示例: cd exported_models && python inference_example.py + 3. 集成到您的应用: 参考inference_example.py中的ONNXInference类 + +注意: + - 请确保安装了onnxruntime: pip install onnxruntime + - GPU推理需要onnxruntime-gpu: pip install onnxruntime-gpu + - 束搜索算法需要根据实际需求进行调整 diff --git a/export_onnx.py b/export_onnx.py new file mode 100644 index 0000000..b9c461b --- /dev/null +++ b/export_onnx.py @@ -0,0 +1,562 @@ +#!/usr/bin/env python3 +""" +输入法模型ONNX导出脚本 + +将模型导出为两个ONNX部分: +1. context_encoder.onnx - 上下文编码器 +2. decoder.onnx - 解码器 + +使用方法: + python export_onnx.py --checkpoint ./output/checkpoints/best_model.pt --output-dir ./exported_models + +依赖安装: + pip install onnx onnxruntime +""" + +import argparse +import os +import sys +from pathlib import Path + +import torch +import numpy as np + +# 检查ONNX是否可用 +try: + import onnx + import onnxruntime as ort + + ONNX_AVAILABLE = True +except ImportError: + ONNX_AVAILABLE = False + print("警告: ONNX或ONNX Runtime未安装") + print("请运行: pip install onnx onnxruntime") + +# 添加src目录到路径 +sys.path.insert(0, str(Path(__file__).parent)) + +from src.model.export_models import create_export_models_from_checkpoint + + +def check_onnx_available(): + """检查ONNX依赖是否可用""" + if not ONNX_AVAILABLE: + print("错误: ONNX导出需要以下依赖:") + print(" pip install onnx onnxruntime") + print("请安装后重试") + return False + return True + + +def export_context_encoder(model, output_path, config): + """ + 导出上下文编码器为ONNX格式 + + Args: + model: ContextEncoderExport实例 + output_path: 输出路径 + config: 模型配置 + """ + print(f"正在导出上下文编码器到: {output_path}") + + # 创建示例输入 - 使用batch_size=2以确保ONNX支持动态批处理 + batch_size = 2 + seq_len = config.get("max_seq_len", 128) + pinyin_len = 24 # 固定长度 + dim = config.get("dim", 512) + + input_ids = torch.randint(0, 100, (batch_size, seq_len), dtype=torch.long) + pinyin_ids = torch.randint(0, 30, (batch_size, pinyin_len), dtype=torch.long) + attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long) + + # 设置动态轴 + dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "pinyin_ids": {0: "batch_size"}, + "attention_mask": {0: "batch_size", 1: "seq_len"}, + "context_H": {0: "batch_size", 1: "seq_len"}, + "pinyin_P": {0: "batch_size"}, + "context_mask": {0: "batch_size", 1: "seq_len"}, + "pinyin_mask": {0: "batch_size"}, + } + + # 导出模型 + torch.onnx.export( + model, + (input_ids, pinyin_ids, attention_mask), + output_path, + input_names=["input_ids", "pinyin_ids", "attention_mask"], + output_names=["context_H", "pinyin_P", "context_mask", "pinyin_mask"], + dynamic_axes=dynamic_axes, + opset_version=18, # 支持现代操作符 + do_constant_folding=True, + verbose=False, + ) + + print("✅ 上下文编码器导出完成") + + # 验证导出 + try: + onnx_model = onnx.load(output_path) + onnx.checker.check_model(onnx_model) + print("✅ ONNX模型验证通过") + except Exception as e: + print(f"⚠️ ONNX模型验证警告: {e}") + + return input_ids, pinyin_ids, attention_mask + + +def export_decoder(model, output_path, config, example_inputs=None): + """ + 导出解码器为ONNX格式 + + Args: + model: DecoderExport实例 + output_path: 输出路径 + config: 模型配置 + example_inputs: 示例输入(用于验证一致性) + """ + print(f"正在导解码器到: {output_path}") + + # 创建示例输入 - 使用batch_size=2以确保ONNX支持动态批处理 + batch_size = 2 + seq_len = config.get("max_seq_len", 128) + pinyin_len = 24 + dim = config.get("dim", 512) + num_slots = config.get("num_slots", 8) + + # 使用提供的示例输入或创建新的 + if example_inputs is not None: + context_H, pinyin_P, context_mask, pinyin_mask = example_inputs + batch_size = context_H.size(0) # 使用实际batch size + history_slot_ids = torch.randint( + 0, 100, (batch_size, num_slots), dtype=torch.long + ) + else: + context_H = torch.randn(batch_size, seq_len, dim, dtype=torch.float32) + pinyin_P = torch.randn(batch_size, pinyin_len, dim, dtype=torch.float32) + context_mask = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int32) + pinyin_mask = torch.randint(0, 2, (batch_size, pinyin_len), dtype=torch.int32) + history_slot_ids = torch.randint( + 0, 100, (batch_size, num_slots), dtype=torch.long + ) + + # 设置动态轴 + dynamic_axes = { + "context_H": {0: "batch_size", 1: "seq_len"}, + "pinyin_P": {0: "batch_size"}, + "history_slot_ids": {0: "batch_size"}, + "context_mask": {0: "batch_size", 1: "seq_len"}, + "pinyin_mask": {0: "batch_size"}, + "logits": {0: "batch_size"}, + } + + # 导出模型 + torch.onnx.export( + model, + (context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask), + output_path, + input_names=[ + "context_H", + "pinyin_P", + "history_slot_ids", + "context_mask", + "pinyin_mask", + ], + output_names=["logits"], + dynamic_axes=dynamic_axes, + opset_version=18, + do_constant_folding=True, + verbose=False, + ) + + print("✅ 解码器导出完成") + + # 验证导出 + try: + onnx_model = onnx.load(output_path) + onnx.checker.check_model(onnx_model) + print("✅ ONNX模型验证通过") + except Exception as e: + print(f"⚠️ ONNX模型验证警告: {e}") + + return context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask + + +def save_example_inputs(output_dir, example_inputs_dict): + """ + 保存示例输入为NPZ文件,用于验证 + + Args: + output_dir: 输出目录 + example_inputs_dict: 示例输入字典 + """ + npz_path = os.path.join(output_dir, "example_inputs.npz") + + # 转换为numpy数组 + np_data = {} + for key, tensor in example_inputs_dict.items(): + if isinstance(tensor, torch.Tensor): + np_data[key] = tensor.cpu().numpy() + elif isinstance(tensor, tuple): + for i, t in enumerate(tensor): + if isinstance(t, torch.Tensor): + np_data[f"{key}_{i}"] = t.cpu().numpy() + + np.savez(npz_path, **np_data) + print(f"✅ 示例输入已保存到: {npz_path}") + + # 同时保存为PyTorch格式 + torch_path = os.path.join(output_dir, "example_inputs.pt") + torch.save(example_inputs_dict, torch_path) + print(f"✅ PyTorch示例输入已保存到: {torch_path}") + + +def create_inference_example(output_dir, config): + """ + 创建推理示例脚本 + + Args: + output_dir: 输出目录 + config: 模型配置 + """ + example_path = os.path.join(output_dir, "inference_example.py") + + example_code = '''#!/usr/bin/env python3 +""" +ONNX模型推理示例 + +展示如何使用导出的两个ONNX模型进行推理, +包括束搜索(beam search)算法。 +""" + +import numpy as np +import onnxruntime as ort +import torch +import torch.nn.functional as F +from typing import List, Tuple + + +class ONNXInference: + """ONNX模型推理器""" + + def __init__(self, context_encoder_path, decoder_path): + """ + 初始化ONNX推理器 + + Args: + context_encoder_path: 上下文编码器ONNX模型路径 + decoder_path: 解码器ONNX模型路径 + """ + # 创建ONNX Runtime会话 + self.context_encoder_session = ort.InferenceSession( + context_encoder_path, + providers=['CPUExecutionProvider'] # 或 'CUDAExecutionProvider' + ) + self.decoder_session = ort.InferenceSession( + decoder_path, + providers=['CPUExecutionProvider'] + ) + + # 获取输入输出名称 + self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()] + self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()] + self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()] + self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()] + + print(f"上下文编码器输入: {self.context_input_names}") + print(f"上下文编码器输出: {self.context_output_names}") + print(f"解码器输入: {self.decoder_input_names}") + print(f"解码器输出: {self.decoder_output_names}") + + def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128): + """ + 准备模型输入(与原始推理脚本保持一致) + + 注意: 这里需要实现文本到token的转换 + 为了简化示例,假设已经实现了相关函数 + """ + # 这里应该调用实际的预处理函数 + # 返回: input_ids, pinyin_ids, attention_mask, history_slot_ids + raise NotImplementedError("请实现实际的输入预处理") + + def run_context_encoder(self, input_ids, pinyin_ids, attention_mask): + """ + 运行上下文编码器 + + Args: + input_ids: [batch, seq_len] + pinyin_ids: [batch, 24] + attention_mask: [batch, seq_len] + + Returns: + context_H, pinyin_P, context_mask, pinyin_mask + """ + # 准备输入 + inputs = { + "input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids, + "pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids, + "attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask, + } + + # 运行推理 + outputs = self.context_encoder_session.run(self.context_output_names, inputs) + + # 解包输出 + context_H, pinyin_P, context_mask, pinyin_mask = outputs + + return ( + torch.from_numpy(context_H), + torch.from_numpy(pinyin_P), + torch.from_numpy(context_mask), + torch.from_numpy(pinyin_mask), + ) + + def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask): + """ + 运行解码器 + + Args: + context_H: [batch, seq_len, 512] + pinyin_P: [batch, 24, 512] + history_slot_ids: [batch, 8] + context_mask: [batch, seq_len] + pinyin_mask: [batch, 24] + + Returns: + logits: [batch, vocab_size] + """ + # 准备输入 + inputs = { + "context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H, + "pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P, + "history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids, + "context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask, + "pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask, + } + + # 运行推理 + outputs = self.decoder_session.run(self.decoder_output_names, inputs) + + # 解包输出 + logits = outputs[0] + return torch.from_numpy(logits) + + def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask, + beam_size=5, max_length=10, vocab_size=10019): + """ + 束搜索算法示例 + + Args: + context_H: 上下文编码 + pinyin_P: 拼音编码 + context_mask: 上下文掩码 + pinyin_mask: 拼音掩码 + beam_size: 束大小 + max_length: 最大生成长度 + vocab_size: 词汇表大小 + + Returns: + 最佳序列列表 + """ + # 初始束:空序列,分数为0 + beams = [([], 0.0)] # (序列, 对数概率) + + for step in range(max_length): + new_beams = [] + + for seq, score in beams: + # 构建history_slot_ids:已确认的字符ID + if len(seq) < 8: + history = seq + [0] * (8 - len(seq)) + else: + history = seq[-8:] # 只保留最近8个 + + history_tensor = torch.tensor([history], dtype=torch.long) + + # 运行解码器 + logits = self.run_decoder( + context_H, pinyin_P, history_tensor, + context_mask, pinyin_mask + ) + + # 获取概率 + probs = F.softmax(logits[0], dim=-1) + + # 获取top-k候选 + top_probs, top_indices = torch.topk(probs, beam_size) + + # 扩展束 + for prob, idx in zip(top_probs, top_indices): + new_seq = seq + [idx.item()] + new_score = score + torch.log(prob).item() + new_beams.append((new_seq, new_score)) + + # 剪枝:保留beam_size个最佳候选 + new_beams.sort(key=lambda x: x[1], reverse=True) + beams = new_beams[:beam_size] + + # 检查是否所有序列都已结束(以结束符0结尾) + all_ended = all(seq[-1] == 0 for seq, _ in beams if seq) + if all_ended: + break + + return beams + + def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids): + """ + 单步预测 + + Args: + input_ids: 输入token IDs + pinyin_ids: 拼音IDs + attention_mask: 注意力掩码 + history_slot_ids: 历史槽位IDs + + Returns: + 预测logits + """ + # 1. 运行上下文编码器 + context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder( + input_ids, pinyin_ids, attention_mask + ) + + # 2. 运行解码器 + logits = self.run_decoder( + context_H, pinyin_P, history_slot_ids, + context_mask, pinyin_mask + ) + + return logits + + +def main(): + """示例主函数""" + print("ONNX模型推理示例") + print("=" * 60) + + # 初始化推理器 + context_encoder_path = "context_encoder.onnx" + decoder_path = "decoder.onnx" + + if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path): + print("错误: 找不到ONNX模型文件") + print("请先运行export_onnx.py导出模型") + return + + inference = ONNXInference(context_encoder_path, decoder_path) + + print("✅ ONNX推理器初始化完成") + print("请参考此示例实现完整的输入法推理流程") + + +if __name__ == "__main__": + main() +''' + + with open(example_path, "w", encoding="utf-8") as f: + f.write(example_code) + + print(f"✅ 推理示例脚本已保存到: {example_path}") + + +def main(): + parser = argparse.ArgumentParser(description="输入法模型ONNX导出") + parser.add_argument( + "--checkpoint", "-c", type=str, required=True, help="模型checkpoint路径" + ) + parser.add_argument( + "--output-dir", + "-o", + type=str, + default="./exported_models", + help="输出目录(默认: ./exported_models)", + ) + parser.add_argument( + "--device", + type=str, + default="cpu", + choices=["cpu", "cuda"], + help="导出设备(默认: cpu)", + ) + parser.add_argument( + "--skip-verification", action="store_true", help="跳过ONNX模型验证" + ) + + args = parser.parse_args() + + # 检查依赖 + if not check_onnx_available(): + sys.exit(1) + + # 创建输出目录 + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + print(f"📁 输出目录: {output_dir.absolute()}") + + # 加载checkpoint并创建导出模型 + print(f"📦 加载checkpoint: {args.checkpoint}") + context_encoder_export, decoder_export, config = ( + create_export_models_from_checkpoint(args.checkpoint, args.device) + ) + + print(f"📊 模型配置: {config}") + + # 导出上下文编码器 + context_encoder_path = output_dir / "context_encoder.onnx" + example_inputs = export_context_encoder( + context_encoder_export, str(context_encoder_path), config + ) + + # 使用上下文编码器的输出作为解码器的示例输入 + with torch.no_grad(): + context_H, pinyin_P, context_mask, pinyin_mask = context_encoder_export( + *example_inputs + ) + + # 导出解码器 + decoder_path = output_dir / "decoder.onnx" + export_decoder( + decoder_export, + str(decoder_path), + config, + example_inputs=(context_H, pinyin_P, context_mask, pinyin_mask), + ) + + # 保存示例输入 + example_inputs_dict = { + "input_ids": example_inputs[0], + "pinyin_ids": example_inputs[1], + "attention_mask": example_inputs[2], + "context_H": context_H, + "pinyin_P": pinyin_P, + "context_mask": context_mask, + "pinyin_mask": pinyin_mask, + } + save_example_inputs(output_dir, example_inputs_dict) + + # 创建推理示例脚本 + create_inference_example(output_dir, config) + + print("\n" + "=" * 60) + print("🎉 ONNX导出完成!") + print("=" * 60) + print(f"生成的模型文件:") + print(f" - {context_encoder_path}") + print(f" - {decoder_path}") + print(f" - {output_dir}/example_inputs.npz") + print(f" - {output_dir}/example_inputs.pt") + print(f" - {output_dir}/inference_example.py") + print("\n使用方法:") + print(f" 1. 检查模型: python -m onnx.checker {context_encoder_path}") + print(f" 2. 运行推理示例: cd {output_dir} && python inference_example.py") + print(f" 3. 集成到您的应用: 参考inference_example.py中的ONNXInference类") + print("\n注意:") + print(" - 请确保安装了onnxruntime: pip install onnxruntime") + print(" - GPU推理需要onnxruntime-gpu: pip install onnxruntime-gpu") + print(" - 束搜索算法需要根据实际需求进行调整") + + +if __name__ == "__main__": + main() diff --git a/exported_models/inference_example.py b/exported_models/inference_example.py new file mode 100644 index 0000000..a140239 --- /dev/null +++ b/exported_models/inference_example.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python3 +""" +ONNX模型推理示例 + +展示如何使用导出的两个ONNX模型进行推理, +包括束搜索(beam search)算法。 +""" + +import numpy as np +import onnxruntime as ort +import torch +import torch.nn.functional as F +from typing import List, Tuple + + +class ONNXInference: + """ONNX模型推理器""" + + def __init__(self, context_encoder_path, decoder_path): + """ + 初始化ONNX推理器 + + Args: + context_encoder_path: 上下文编码器ONNX模型路径 + decoder_path: 解码器ONNX模型路径 + """ + # 创建ONNX Runtime会话 + self.context_encoder_session = ort.InferenceSession( + context_encoder_path, + providers=['CPUExecutionProvider'] # 或 'CUDAExecutionProvider' + ) + self.decoder_session = ort.InferenceSession( + decoder_path, + providers=['CPUExecutionProvider'] + ) + + # 获取输入输出名称 + self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()] + self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()] + self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()] + self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()] + + print(f"上下文编码器输入: {self.context_input_names}") + print(f"上下文编码器输出: {self.context_output_names}") + print(f"解码器输入: {self.decoder_input_names}") + print(f"解码器输出: {self.decoder_output_names}") + + def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128): + """ + 准备模型输入(与原始推理脚本保持一致) + + 注意: 这里需要实现文本到token的转换 + 为了简化示例,假设已经实现了相关函数 + """ + # 这里应该调用实际的预处理函数 + # 返回: input_ids, pinyin_ids, attention_mask, history_slot_ids + raise NotImplementedError("请实现实际的输入预处理") + + def run_context_encoder(self, input_ids, pinyin_ids, attention_mask): + """ + 运行上下文编码器 + + Args: + input_ids: [batch, seq_len] + pinyin_ids: [batch, 24] + attention_mask: [batch, seq_len] + + Returns: + context_H, pinyin_P, context_mask, pinyin_mask + """ + # 准备输入 + inputs = { + "input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids, + "pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids, + "attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask, + } + + # 运行推理 + outputs = self.context_encoder_session.run(self.context_output_names, inputs) + + # 解包输出 + context_H, pinyin_P, context_mask, pinyin_mask = outputs + + return ( + torch.from_numpy(context_H), + torch.from_numpy(pinyin_P), + torch.from_numpy(context_mask), + torch.from_numpy(pinyin_mask), + ) + + def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask): + """ + 运行解码器 + + Args: + context_H: [batch, seq_len, 512] + pinyin_P: [batch, 24, 512] + history_slot_ids: [batch, 8] + context_mask: [batch, seq_len] + pinyin_mask: [batch, 24] + + Returns: + logits: [batch, vocab_size] + """ + # 准备输入 + inputs = { + "context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H, + "pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P, + "history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids, + "context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask, + "pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask, + } + + # 运行推理 + outputs = self.decoder_session.run(self.decoder_output_names, inputs) + + # 解包输出 + logits = outputs[0] + return torch.from_numpy(logits) + + def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask, + beam_size=5, max_length=10, vocab_size=10019): + """ + 束搜索算法示例 + + Args: + context_H: 上下文编码 + pinyin_P: 拼音编码 + context_mask: 上下文掩码 + pinyin_mask: 拼音掩码 + beam_size: 束大小 + max_length: 最大生成长度 + vocab_size: 词汇表大小 + + Returns: + 最佳序列列表 + """ + # 初始束:空序列,分数为0 + beams = [([], 0.0)] # (序列, 对数概率) + + for step in range(max_length): + new_beams = [] + + for seq, score in beams: + # 构建history_slot_ids:已确认的字符ID + if len(seq) < 8: + history = seq + [0] * (8 - len(seq)) + else: + history = seq[-8:] # 只保留最近8个 + + history_tensor = torch.tensor([history], dtype=torch.long) + + # 运行解码器 + logits = self.run_decoder( + context_H, pinyin_P, history_tensor, + context_mask, pinyin_mask + ) + + # 获取概率 + probs = F.softmax(logits[0], dim=-1) + + # 获取top-k候选 + top_probs, top_indices = torch.topk(probs, beam_size) + + # 扩展束 + for prob, idx in zip(top_probs, top_indices): + new_seq = seq + [idx.item()] + new_score = score + torch.log(prob).item() + new_beams.append((new_seq, new_score)) + + # 剪枝:保留beam_size个最佳候选 + new_beams.sort(key=lambda x: x[1], reverse=True) + beams = new_beams[:beam_size] + + # 检查是否所有序列都已结束(以结束符0结尾) + all_ended = all(seq[-1] == 0 for seq, _ in beams if seq) + if all_ended: + break + + return beams + + def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids): + """ + 单步预测 + + Args: + input_ids: 输入token IDs + pinyin_ids: 拼音IDs + attention_mask: 注意力掩码 + history_slot_ids: 历史槽位IDs + + Returns: + 预测logits + """ + # 1. 运行上下文编码器 + context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder( + input_ids, pinyin_ids, attention_mask + ) + + # 2. 运行解码器 + logits = self.run_decoder( + context_H, pinyin_P, history_slot_ids, + context_mask, pinyin_mask + ) + + return logits + + +def main(): + """示例主函数""" + print("ONNX模型推理示例") + print("=" * 60) + + # 初始化推理器 + context_encoder_path = "context_encoder.onnx" + decoder_path = "decoder.onnx" + + if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path): + print("错误: 找不到ONNX模型文件") + print("请先运行export_onnx.py导出模型") + return + + inference = ONNXInference(context_encoder_path, decoder_path) + + print("✅ ONNX推理器初始化完成") + print("请参考此示例实现完整的输入法推理流程") + + +if __name__ == "__main__": + main() diff --git a/exported_models_test/inference_example.py b/exported_models_test/inference_example.py new file mode 100644 index 0000000..a140239 --- /dev/null +++ b/exported_models_test/inference_example.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python3 +""" +ONNX模型推理示例 + +展示如何使用导出的两个ONNX模型进行推理, +包括束搜索(beam search)算法。 +""" + +import numpy as np +import onnxruntime as ort +import torch +import torch.nn.functional as F +from typing import List, Tuple + + +class ONNXInference: + """ONNX模型推理器""" + + def __init__(self, context_encoder_path, decoder_path): + """ + 初始化ONNX推理器 + + Args: + context_encoder_path: 上下文编码器ONNX模型路径 + decoder_path: 解码器ONNX模型路径 + """ + # 创建ONNX Runtime会话 + self.context_encoder_session = ort.InferenceSession( + context_encoder_path, + providers=['CPUExecutionProvider'] # 或 'CUDAExecutionProvider' + ) + self.decoder_session = ort.InferenceSession( + decoder_path, + providers=['CPUExecutionProvider'] + ) + + # 获取输入输出名称 + self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()] + self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()] + self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()] + self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()] + + print(f"上下文编码器输入: {self.context_input_names}") + print(f"上下文编码器输出: {self.context_output_names}") + print(f"解码器输入: {self.decoder_input_names}") + print(f"解码器输出: {self.decoder_output_names}") + + def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128): + """ + 准备模型输入(与原始推理脚本保持一致) + + 注意: 这里需要实现文本到token的转换 + 为了简化示例,假设已经实现了相关函数 + """ + # 这里应该调用实际的预处理函数 + # 返回: input_ids, pinyin_ids, attention_mask, history_slot_ids + raise NotImplementedError("请实现实际的输入预处理") + + def run_context_encoder(self, input_ids, pinyin_ids, attention_mask): + """ + 运行上下文编码器 + + Args: + input_ids: [batch, seq_len] + pinyin_ids: [batch, 24] + attention_mask: [batch, seq_len] + + Returns: + context_H, pinyin_P, context_mask, pinyin_mask + """ + # 准备输入 + inputs = { + "input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids, + "pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids, + "attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask, + } + + # 运行推理 + outputs = self.context_encoder_session.run(self.context_output_names, inputs) + + # 解包输出 + context_H, pinyin_P, context_mask, pinyin_mask = outputs + + return ( + torch.from_numpy(context_H), + torch.from_numpy(pinyin_P), + torch.from_numpy(context_mask), + torch.from_numpy(pinyin_mask), + ) + + def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask): + """ + 运行解码器 + + Args: + context_H: [batch, seq_len, 512] + pinyin_P: [batch, 24, 512] + history_slot_ids: [batch, 8] + context_mask: [batch, seq_len] + pinyin_mask: [batch, 24] + + Returns: + logits: [batch, vocab_size] + """ + # 准备输入 + inputs = { + "context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H, + "pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P, + "history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids, + "context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask, + "pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask, + } + + # 运行推理 + outputs = self.decoder_session.run(self.decoder_output_names, inputs) + + # 解包输出 + logits = outputs[0] + return torch.from_numpy(logits) + + def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask, + beam_size=5, max_length=10, vocab_size=10019): + """ + 束搜索算法示例 + + Args: + context_H: 上下文编码 + pinyin_P: 拼音编码 + context_mask: 上下文掩码 + pinyin_mask: 拼音掩码 + beam_size: 束大小 + max_length: 最大生成长度 + vocab_size: 词汇表大小 + + Returns: + 最佳序列列表 + """ + # 初始束:空序列,分数为0 + beams = [([], 0.0)] # (序列, 对数概率) + + for step in range(max_length): + new_beams = [] + + for seq, score in beams: + # 构建history_slot_ids:已确认的字符ID + if len(seq) < 8: + history = seq + [0] * (8 - len(seq)) + else: + history = seq[-8:] # 只保留最近8个 + + history_tensor = torch.tensor([history], dtype=torch.long) + + # 运行解码器 + logits = self.run_decoder( + context_H, pinyin_P, history_tensor, + context_mask, pinyin_mask + ) + + # 获取概率 + probs = F.softmax(logits[0], dim=-1) + + # 获取top-k候选 + top_probs, top_indices = torch.topk(probs, beam_size) + + # 扩展束 + for prob, idx in zip(top_probs, top_indices): + new_seq = seq + [idx.item()] + new_score = score + torch.log(prob).item() + new_beams.append((new_seq, new_score)) + + # 剪枝:保留beam_size个最佳候选 + new_beams.sort(key=lambda x: x[1], reverse=True) + beams = new_beams[:beam_size] + + # 检查是否所有序列都已结束(以结束符0结尾) + all_ended = all(seq[-1] == 0 for seq, _ in beams if seq) + if all_ended: + break + + return beams + + def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids): + """ + 单步预测 + + Args: + input_ids: 输入token IDs + pinyin_ids: 拼音IDs + attention_mask: 注意力掩码 + history_slot_ids: 历史槽位IDs + + Returns: + 预测logits + """ + # 1. 运行上下文编码器 + context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder( + input_ids, pinyin_ids, attention_mask + ) + + # 2. 运行解码器 + logits = self.run_decoder( + context_H, pinyin_P, history_slot_ids, + context_mask, pinyin_mask + ) + + return logits + + +def main(): + """示例主函数""" + print("ONNX模型推理示例") + print("=" * 60) + + # 初始化推理器 + context_encoder_path = "context_encoder.onnx" + decoder_path = "decoder.onnx" + + if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path): + print("错误: 找不到ONNX模型文件") + print("请先运行export_onnx.py导出模型") + return + + inference = ONNXInference(context_encoder_path, decoder_path) + + print("✅ ONNX推理器初始化完成") + print("请参考此示例实现完整的输入法推理流程") + + +if __name__ == "__main__": + main() diff --git a/exported_models_test_fixed/inference_example.py b/exported_models_test_fixed/inference_example.py new file mode 100644 index 0000000..a140239 --- /dev/null +++ b/exported_models_test_fixed/inference_example.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python3 +""" +ONNX模型推理示例 + +展示如何使用导出的两个ONNX模型进行推理, +包括束搜索(beam search)算法。 +""" + +import numpy as np +import onnxruntime as ort +import torch +import torch.nn.functional as F +from typing import List, Tuple + + +class ONNXInference: + """ONNX模型推理器""" + + def __init__(self, context_encoder_path, decoder_path): + """ + 初始化ONNX推理器 + + Args: + context_encoder_path: 上下文编码器ONNX模型路径 + decoder_path: 解码器ONNX模型路径 + """ + # 创建ONNX Runtime会话 + self.context_encoder_session = ort.InferenceSession( + context_encoder_path, + providers=['CPUExecutionProvider'] # 或 'CUDAExecutionProvider' + ) + self.decoder_session = ort.InferenceSession( + decoder_path, + providers=['CPUExecutionProvider'] + ) + + # 获取输入输出名称 + self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()] + self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()] + self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()] + self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()] + + print(f"上下文编码器输入: {self.context_input_names}") + print(f"上下文编码器输出: {self.context_output_names}") + print(f"解码器输入: {self.decoder_input_names}") + print(f"解码器输出: {self.decoder_output_names}") + + def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128): + """ + 准备模型输入(与原始推理脚本保持一致) + + 注意: 这里需要实现文本到token的转换 + 为了简化示例,假设已经实现了相关函数 + """ + # 这里应该调用实际的预处理函数 + # 返回: input_ids, pinyin_ids, attention_mask, history_slot_ids + raise NotImplementedError("请实现实际的输入预处理") + + def run_context_encoder(self, input_ids, pinyin_ids, attention_mask): + """ + 运行上下文编码器 + + Args: + input_ids: [batch, seq_len] + pinyin_ids: [batch, 24] + attention_mask: [batch, seq_len] + + Returns: + context_H, pinyin_P, context_mask, pinyin_mask + """ + # 准备输入 + inputs = { + "input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids, + "pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids, + "attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask, + } + + # 运行推理 + outputs = self.context_encoder_session.run(self.context_output_names, inputs) + + # 解包输出 + context_H, pinyin_P, context_mask, pinyin_mask = outputs + + return ( + torch.from_numpy(context_H), + torch.from_numpy(pinyin_P), + torch.from_numpy(context_mask), + torch.from_numpy(pinyin_mask), + ) + + def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask): + """ + 运行解码器 + + Args: + context_H: [batch, seq_len, 512] + pinyin_P: [batch, 24, 512] + history_slot_ids: [batch, 8] + context_mask: [batch, seq_len] + pinyin_mask: [batch, 24] + + Returns: + logits: [batch, vocab_size] + """ + # 准备输入 + inputs = { + "context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H, + "pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P, + "history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids, + "context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask, + "pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask, + } + + # 运行推理 + outputs = self.decoder_session.run(self.decoder_output_names, inputs) + + # 解包输出 + logits = outputs[0] + return torch.from_numpy(logits) + + def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask, + beam_size=5, max_length=10, vocab_size=10019): + """ + 束搜索算法示例 + + Args: + context_H: 上下文编码 + pinyin_P: 拼音编码 + context_mask: 上下文掩码 + pinyin_mask: 拼音掩码 + beam_size: 束大小 + max_length: 最大生成长度 + vocab_size: 词汇表大小 + + Returns: + 最佳序列列表 + """ + # 初始束:空序列,分数为0 + beams = [([], 0.0)] # (序列, 对数概率) + + for step in range(max_length): + new_beams = [] + + for seq, score in beams: + # 构建history_slot_ids:已确认的字符ID + if len(seq) < 8: + history = seq + [0] * (8 - len(seq)) + else: + history = seq[-8:] # 只保留最近8个 + + history_tensor = torch.tensor([history], dtype=torch.long) + + # 运行解码器 + logits = self.run_decoder( + context_H, pinyin_P, history_tensor, + context_mask, pinyin_mask + ) + + # 获取概率 + probs = F.softmax(logits[0], dim=-1) + + # 获取top-k候选 + top_probs, top_indices = torch.topk(probs, beam_size) + + # 扩展束 + for prob, idx in zip(top_probs, top_indices): + new_seq = seq + [idx.item()] + new_score = score + torch.log(prob).item() + new_beams.append((new_seq, new_score)) + + # 剪枝:保留beam_size个最佳候选 + new_beams.sort(key=lambda x: x[1], reverse=True) + beams = new_beams[:beam_size] + + # 检查是否所有序列都已结束(以结束符0结尾) + all_ended = all(seq[-1] == 0 for seq, _ in beams if seq) + if all_ended: + break + + return beams + + def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids): + """ + 单步预测 + + Args: + input_ids: 输入token IDs + pinyin_ids: 拼音IDs + attention_mask: 注意力掩码 + history_slot_ids: 历史槽位IDs + + Returns: + 预测logits + """ + # 1. 运行上下文编码器 + context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder( + input_ids, pinyin_ids, attention_mask + ) + + # 2. 运行解码器 + logits = self.run_decoder( + context_H, pinyin_P, history_slot_ids, + context_mask, pinyin_mask + ) + + return logits + + +def main(): + """示例主函数""" + print("ONNX模型推理示例") + print("=" * 60) + + # 初始化推理器 + context_encoder_path = "context_encoder.onnx" + decoder_path = "decoder.onnx" + + if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path): + print("错误: 找不到ONNX模型文件") + print("请先运行export_onnx.py导出模型") + return + + inference = ONNXInference(context_encoder_path, decoder_path) + + print("✅ ONNX推理器初始化完成") + print("请参考此示例实现完整的输入法推理流程") + + +if __name__ == "__main__": + main() diff --git a/exported_models_test_new/inference_example.py b/exported_models_test_new/inference_example.py new file mode 100644 index 0000000..a140239 --- /dev/null +++ b/exported_models_test_new/inference_example.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python3 +""" +ONNX模型推理示例 + +展示如何使用导出的两个ONNX模型进行推理, +包括束搜索(beam search)算法。 +""" + +import numpy as np +import onnxruntime as ort +import torch +import torch.nn.functional as F +from typing import List, Tuple + + +class ONNXInference: + """ONNX模型推理器""" + + def __init__(self, context_encoder_path, decoder_path): + """ + 初始化ONNX推理器 + + Args: + context_encoder_path: 上下文编码器ONNX模型路径 + decoder_path: 解码器ONNX模型路径 + """ + # 创建ONNX Runtime会话 + self.context_encoder_session = ort.InferenceSession( + context_encoder_path, + providers=['CPUExecutionProvider'] # 或 'CUDAExecutionProvider' + ) + self.decoder_session = ort.InferenceSession( + decoder_path, + providers=['CPUExecutionProvider'] + ) + + # 获取输入输出名称 + self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()] + self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()] + self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()] + self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()] + + print(f"上下文编码器输入: {self.context_input_names}") + print(f"上下文编码器输出: {self.context_output_names}") + print(f"解码器输入: {self.decoder_input_names}") + print(f"解码器输出: {self.decoder_output_names}") + + def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128): + """ + 准备模型输入(与原始推理脚本保持一致) + + 注意: 这里需要实现文本到token的转换 + 为了简化示例,假设已经实现了相关函数 + """ + # 这里应该调用实际的预处理函数 + # 返回: input_ids, pinyin_ids, attention_mask, history_slot_ids + raise NotImplementedError("请实现实际的输入预处理") + + def run_context_encoder(self, input_ids, pinyin_ids, attention_mask): + """ + 运行上下文编码器 + + Args: + input_ids: [batch, seq_len] + pinyin_ids: [batch, 24] + attention_mask: [batch, seq_len] + + Returns: + context_H, pinyin_P, context_mask, pinyin_mask + """ + # 准备输入 + inputs = { + "input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids, + "pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids, + "attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask, + } + + # 运行推理 + outputs = self.context_encoder_session.run(self.context_output_names, inputs) + + # 解包输出 + context_H, pinyin_P, context_mask, pinyin_mask = outputs + + return ( + torch.from_numpy(context_H), + torch.from_numpy(pinyin_P), + torch.from_numpy(context_mask), + torch.from_numpy(pinyin_mask), + ) + + def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask): + """ + 运行解码器 + + Args: + context_H: [batch, seq_len, 512] + pinyin_P: [batch, 24, 512] + history_slot_ids: [batch, 8] + context_mask: [batch, seq_len] + pinyin_mask: [batch, 24] + + Returns: + logits: [batch, vocab_size] + """ + # 准备输入 + inputs = { + "context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H, + "pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P, + "history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids, + "context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask, + "pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask, + } + + # 运行推理 + outputs = self.decoder_session.run(self.decoder_output_names, inputs) + + # 解包输出 + logits = outputs[0] + return torch.from_numpy(logits) + + def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask, + beam_size=5, max_length=10, vocab_size=10019): + """ + 束搜索算法示例 + + Args: + context_H: 上下文编码 + pinyin_P: 拼音编码 + context_mask: 上下文掩码 + pinyin_mask: 拼音掩码 + beam_size: 束大小 + max_length: 最大生成长度 + vocab_size: 词汇表大小 + + Returns: + 最佳序列列表 + """ + # 初始束:空序列,分数为0 + beams = [([], 0.0)] # (序列, 对数概率) + + for step in range(max_length): + new_beams = [] + + for seq, score in beams: + # 构建history_slot_ids:已确认的字符ID + if len(seq) < 8: + history = seq + [0] * (8 - len(seq)) + else: + history = seq[-8:] # 只保留最近8个 + + history_tensor = torch.tensor([history], dtype=torch.long) + + # 运行解码器 + logits = self.run_decoder( + context_H, pinyin_P, history_tensor, + context_mask, pinyin_mask + ) + + # 获取概率 + probs = F.softmax(logits[0], dim=-1) + + # 获取top-k候选 + top_probs, top_indices = torch.topk(probs, beam_size) + + # 扩展束 + for prob, idx in zip(top_probs, top_indices): + new_seq = seq + [idx.item()] + new_score = score + torch.log(prob).item() + new_beams.append((new_seq, new_score)) + + # 剪枝:保留beam_size个最佳候选 + new_beams.sort(key=lambda x: x[1], reverse=True) + beams = new_beams[:beam_size] + + # 检查是否所有序列都已结束(以结束符0结尾) + all_ended = all(seq[-1] == 0 for seq, _ in beams if seq) + if all_ended: + break + + return beams + + def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids): + """ + 单步预测 + + Args: + input_ids: 输入token IDs + pinyin_ids: 拼音IDs + attention_mask: 注意力掩码 + history_slot_ids: 历史槽位IDs + + Returns: + 预测logits + """ + # 1. 运行上下文编码器 + context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder( + input_ids, pinyin_ids, attention_mask + ) + + # 2. 运行解码器 + logits = self.run_decoder( + context_H, pinyin_P, history_slot_ids, + context_mask, pinyin_mask + ) + + return logits + + +def main(): + """示例主函数""" + print("ONNX模型推理示例") + print("=" * 60) + + # 初始化推理器 + context_encoder_path = "context_encoder.onnx" + decoder_path = "decoder.onnx" + + if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path): + print("错误: 找不到ONNX模型文件") + print("请先运行export_onnx.py导出模型") + return + + inference = ONNXInference(context_encoder_path, decoder_path) + + print("✅ ONNX推理器初始化完成") + print("请参考此示例实现完整的输入法推理流程") + + +if __name__ == "__main__": + main() diff --git a/onnx_inference.py b/onnx_inference.py new file mode 100644 index 0000000..cd4da72 --- /dev/null +++ b/onnx_inference.py @@ -0,0 +1,763 @@ +#!/usr/bin/env python3 +""" +ONNX输入法模型推理脚本 + +使用ONNX Runtime进行推理,测量每个阶段的执行时长 + +使用方法: + python onnx_inference.py --context-encoder exported_models/context_encoder.onnx --decoder exported_models/decoder.onnx + +交互模式: 分步询问输入 + 1. 上下文提示: 模型不掌握的专有词汇、姓名等(可为空) + 2. 光标前文本: 光标前的连续文本 + 3. 光标后文本: 光标后的连续文本 + 4. 拼音: 当前输入的拼音 + 5. 槽位历史: 用户已确认的输入历史(如输入shanghai已确认"上") +""" + +import argparse +import os +import sys +import time +from pathlib import Path +from typing import List, Optional, Tuple + +import numpy as np +import onnxruntime as ort +import torch +import torch.nn.functional as F +from modelscope import AutoTokenizer + +from src.model.dataset import text_to_pinyin_ids +from src.model.query import QueryEngine + + +class ONNXInference: + """ONNX输入法模型推理器""" + + def __init__( + self, + context_encoder_path: str, + decoder_path: str, + vocab_size: int = 10019, + device: str = "cpu", + use_beam_search: bool = False, + beam_size: int = 5, + ): + self.vocab_size = vocab_size + self.device = device + self.use_beam_search = use_beam_search + self.beam_size = beam_size + + # 加载组件 + print(f"正在加载上下文编码器: {context_encoder_path}") + load_start = time.perf_counter() + self.load_context_encoder(context_encoder_path) + self.context_encoder_load_time = (time.perf_counter() - load_start) * 1000 + print(f" ✅ 上下文编码器加载完成 ({self.context_encoder_load_time:.2f}ms)") + + print(f"正在加载解码器: {decoder_path}") + load_start = time.perf_counter() + self.load_decoder(decoder_path) + self.decoder_load_time = (time.perf_counter() - load_start) * 1000 + print(f" ✅ 解码器加载完成 ({self.decoder_load_time:.2f}ms)") + + # 加载tokenizer + print("正在加载tokenizer...") + load_start = time.perf_counter() + self.load_tokenizer() + self.tokenizer_load_time = (time.perf_counter() - load_start) * 1000 + print(f" ✅ Tokenizer加载完成 ({self.tokenizer_load_time:.2f}ms)") + + # 加载查询引擎 + print("正在加载查询引擎...") + load_start = time.perf_counter() + self.load_query_engine() + self.query_engine_load_time = (time.perf_counter() - load_start) * 1000 + print(f" ✅ 查询引擎加载完成 ({self.query_engine_load_time:.2f}ms)") + + total_load_time = ( + self.context_encoder_load_time + + self.decoder_load_time + + self.tokenizer_load_time + + self.query_engine_load_time + ) + print(f"\n✅ 推理器初始化完成 (设备: {device})") + print(f" 总加载时间: {total_load_time:.2f}ms") + + # 尝试启用readline + try: + import readline + + readline.set_completer_delims(" \t\n`~!@#$%^&*()-=+[{]}\\|;:'\",<>/?") + except ImportError: + pass + + def load_context_encoder(self, model_path: str): + """加载上下文编码器ONNX模型""" + providers = ( + ["CUDAExecutionProvider", "CPUExecutionProvider"] + if self.device == "cuda" + else ["CPUExecutionProvider"] + ) + self.context_encoder_session = ort.InferenceSession( + model_path, providers=providers + ) + + self.context_input_names = [ + inp.name for inp in self.context_encoder_session.get_inputs() + ] + self.context_output_names = [ + out.name for out in self.context_encoder_session.get_outputs() + ] + + def load_decoder(self, model_path: str): + """加载解码器ONNX模型""" + providers = ( + ["CUDAExecutionProvider", "CPUExecutionProvider"] + if self.device == "cuda" + else ["CPUExecutionProvider"] + ) + self.decoder_session = ort.InferenceSession(model_path, providers=providers) + + self.decoder_input_names = [ + inp.name for inp in self.decoder_session.get_inputs() + ] + self.decoder_output_names = [ + out.name for out in self.decoder_session.get_outputs() + ] + + def load_tokenizer(self): + """加载tokenizer""" + try: + tokenizer_path = ( + Path(__file__).parent / "src" / "model" / "assets" / "tokenizer" + ) + self.tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path)) + except Exception: + print(" ⚠️ 无法加载自定义tokenizer,使用bert-base-chinese") + self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese") + + def load_query_engine(self): + """加载查询引擎""" + try: + self.query_engine = QueryEngine() + stats_path = ( + Path(__file__).parent + / "src" + / "model" + / "assets" + / "pinyin_char_statistics.json" + ) + if stats_path.exists(): + self.query_engine.load(stats_path) + except Exception: + self.query_engine = None + + def char_to_id(self, char: str, pinyin: Optional[str] = None) -> int: + """将汉字转换为ID""" + if char == "//": + return 0 + + if self.query_engine is None: + return ord(char) if len(char) == 1 else 0 + + try: + if pinyin is not None: + info = self.query_engine.get_char_info_by_char_pinyin(char, pinyin) + if info: + return info.id + + results = self.query_engine.query_by_char(char, limit=1) + if results: + return results[0][0] + return 0 + except: + return 0 + + def id_to_char(self, id: int) -> str: + """将ID转换为汉字""" + if id == 0: + return "//" + + if self.query_engine is None: + return chr(id) if id < 0x110000 else f"[ID:{id}]" + + try: + info = self.query_engine.query_by_id(id) + return info.char if info else f"[ID:{id}]" + except: + return f"[ID:{id}]" + + def _clean_pinyin_input(self, pinyin: str) -> str: + """清理拼音输入字符串""" + if not pinyin: + return "" + + result = [] + for c in pinyin: + is_valid = ("a" <= c <= "z") or ("A" <= c <= "Z") or c in ["`", "'", "-"] + if is_valid: + result.append(c.lower()) + elif c == " ": + continue + elif c in ["\b", "\x7f", "\x08"]: + if result: + result.pop() + elif c == "\x1b": + result.clear() + return "".join(result) + + def _safe_input(self, prompt: str, default: str = "") -> str: + """安全的输入函数""" + try: + full_prompt = f"{prompt} [{default}]: " if default else f"{prompt}: " + result = input(full_prompt) + if not result and default: + return default + return result.strip() + except (EOFError, KeyboardInterrupt): + print() + return "" + except Exception as e: + print(f"\n⚠️ 输入错误: {e}") + return "" + + def prepare_inputs( + self, + context_prompts: List[str], + text_before: str, + text_after: str, + pinyin: str, + slot_chars: List[str], + max_seq_len: int = 128, + ) -> dict: + """ + 准备模型输入 + + Returns: + dict: { + 'preprocess_time': float, # 预处理时间(ms) + 'input_ids': numpy array, + 'attention_mask': numpy array, + 'pinyin_ids': numpy array, + 'history_slot_ids': numpy array, + } + """ + preprocess_start = time.perf_counter() + + # 1. 构建tokenizer输入 + context_text = "|".join(context_prompts) if context_prompts else "" + + if context_text: + input_text = f"{context_text}|{text_before}" + else: + input_text = text_before + + # 2. Tokenize + encoded = self.tokenizer( + input_text, + text_after, + max_length=max_seq_len, + padding="max_length", + truncation=True, + return_tensors="pt", + return_token_type_ids=True, + ) + + input_ids = encoded["input_ids"].numpy() + attention_mask = encoded["attention_mask"].numpy() + + # 3. 处理拼音输入 + cleaned_pinyin = self._clean_pinyin_input(pinyin) + pinyin_ids = text_to_pinyin_ids(cleaned_pinyin) + + if len(pinyin_ids) < 24: + pinyin_ids.extend([0] * (24 - len(pinyin_ids))) + else: + pinyin_ids = pinyin_ids[:24] + + pinyin_ids = np.array([pinyin_ids], dtype=np.int64) + + # 4. 处理历史槽位 + history_slot_ids = [] + for char in slot_chars: + char_id = self.char_to_id(char) + history_slot_ids.append(char_id) + + if len(history_slot_ids) < 8: + history_slot_ids.extend([0] * (8 - len(history_slot_ids))) + else: + history_slot_ids = history_slot_ids[:8] + + history_slot_ids = np.array([history_slot_ids], dtype=np.int64) + + preprocess_time = (time.perf_counter() - preprocess_start) * 1000 + + return { + "preprocess_time": preprocess_time, + "input_ids": input_ids, + "attention_mask": attention_mask, + "pinyin_ids": pinyin_ids, + "history_slot_ids": history_slot_ids, + } + + def run_context_encoder( + self, input_ids: np.ndarray, pinyin_ids: np.ndarray, attention_mask: np.ndarray + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """ + 运行上下文编码器 + + Returns: + context_H, pinyin_P, context_mask, pinyin_mask + """ + context_start = time.perf_counter() + + inputs = { + "input_ids": input_ids, + "pinyin_ids": pinyin_ids, + "attention_mask": attention_mask, + } + + outputs = self.context_encoder_session.run(self.context_output_names, inputs) + + context_H, pinyin_P, context_mask, pinyin_mask = outputs + + self.last_context_encoder_time = (time.perf_counter() - context_start) * 1000 + + return context_H, pinyin_P, context_mask, pinyin_mask + + def run_decoder( + self, + context_H: np.ndarray, + pinyin_P: np.ndarray, + history_slot_ids: np.ndarray, + context_mask: np.ndarray, + pinyin_mask: np.ndarray, + ) -> np.ndarray: + """ + 运行解码器 + + Returns: + logits: [batch, vocab_size] + """ + decoder_start = time.perf_counter() + + inputs = { + "context_H": context_H, + "pinyin_P": pinyin_P, + "history_slot_ids": history_slot_ids, + "context_mask": context_mask, + "pinyin_mask": pinyin_mask, + } + + outputs = self.decoder_session.run(self.decoder_output_names, inputs) + + self.last_decoder_time = (time.perf_counter() - decoder_start) * 1000 + + return outputs[0] + + def predict( + self, + context_prompts: List[str], + text_before: str, + text_after: str, + pinyin: str, + slot_chars: List[str], + top_k: int = 20, + use_beam_search: bool = False, + beam_size: int = 5, + max_length: int = 10, + ) -> Tuple[List[Tuple[str, float, int]], dict]: + """ + 执行推理 + + Args: + context_prompts: 上下文提示 + text_before: 光标前文本 + text_after: 光标后文本 + pinyin: 当前输入的拼音 + slot_chars: 槽位内的汉字列表 + top_k: 返回top-k个预测结果 + use_beam_search: 是否使用束搜索 + beam_size: 束大小 + max_length: 最大生成长度 + + Returns: + (predictions, timing_info) + predictions: List[Tuple[char, prob, id]] + timing_info: 各阶段耗时字典 + """ + total_start = time.perf_counter() + + # 阶段1: 预处理 + prep_start = time.perf_counter() + inputs = self.prepare_inputs( + context_prompts, text_before, text_after, pinyin, slot_chars + ) + preprocess_time = inputs["preprocess_time"] + + input_ids = inputs["input_ids"] + attention_mask = inputs["attention_mask"] + pinyin_ids = inputs["pinyin_ids"] + history_slot_ids = inputs["history_slot_ids"] + prep_time = (time.perf_counter() - prep_start) * 1000 + + # 阶段2: 上下文编码 + context_start = time.perf_counter() + context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder( + input_ids, pinyin_ids, attention_mask + ) + context_encoder_time = self.last_context_encoder_time + + if use_beam_search: + # 阶段3: 束搜索解码 + decode_start = time.perf_counter() + predictions, beam_decode_time = self._beam_search_decode( + context_H, + pinyin_P, + context_mask, + pinyin_mask, + beam_size, + max_length, + top_k, + ) + decoder_time = beam_decode_time + else: + # 阶段3: 单步解码 + decode_start = time.perf_counter() + logits = self.run_decoder( + context_H, + pinyin_P, + history_slot_ids, + context_mask, + pinyin_mask, + ) + + # 阶段4: 后处理 + postprocess_start = time.perf_counter() + probs = self._softmax(logits) + top_indices, top_probs = self._topk(probs, top_k) + + predictions = [] + for i in range(top_k): + idx = int(top_indices[0, i]) + prob = float(top_probs[0, i]) + char = self.id_to_char(idx) + predictions.append((char, prob, idx)) + + postprocess_time = (time.perf_counter() - postprocess_start) * 1000 + decoder_time = self.last_decoder_time + + total_time = (time.perf_counter() - total_start) * 1000 + + timing_info = { + "预处理": prep_time, + "上下文编码": context_encoder_time, + "解码": decoder_time, + "后处理": postprocess_time if not use_beam_search else 0, + "总耗时": total_time, + } + + return predictions, timing_info + + def _softmax(self, logits: np.ndarray) -> np.ndarray: + """计算softmax""" + exp_logits = np.exp(logits - np.max(logits, axis=-1, keepdims=True)) + return exp_logits / np.sum(exp_logits, axis=-1, keepdims=True) + + def _topk(self, probs: np.ndarray, k: int) -> Tuple[np.ndarray, np.ndarray]: + """获取top-k""" + topk_indices = np.argsort(probs, axis=-1)[:, -k:][:, ::-1] + topk_probs = np.take_along_axis(probs, topk_indices, axis=-1) + return topk_indices, topk_probs + + def _beam_search_decode( + self, + context_H: np.ndarray, + pinyin_P: np.ndarray, + context_mask: np.ndarray, + pinyin_mask: np.ndarray, + beam_size: int, + max_length: int, + top_k: int, + ) -> Tuple[List[Tuple[str, float, int]], float]: + """束搜索解码""" + beams = [([], 0.0)] # (序列, 对数概率) + + for step in range(max_length): + new_beams = [] + + for seq, score in beams: + if len(seq) < 8: + history = seq + [0] * (8 - len(seq)) + else: + history = seq[-8:] + + history_tensor = np.array([history], dtype=np.int64) + + logits = self.run_decoder( + context_H, + pinyin_P, + history_tensor, + context_mask, + pinyin_mask, + ) + + probs = self._softmax(logits)[0] + topk_indices = np.argsort(probs)[-beam_size:][::-1] + topk_probs = probs[topk_indices] + + for idx, prob in zip(topk_indices, topk_probs): + new_seq = seq + [int(idx)] + new_score = score + np.log(prob + 1e-10) + new_beams.append((new_seq, new_score)) + + new_beams.sort(key=lambda x: x[1], reverse=True) + beams = new_beams[:beam_size] + + all_ended = all(seq[-1] == 0 for seq, _ in beams if seq) + if all_ended: + break + + # 返回top-k个候选 + predictions = [] + for seq, score in beams[:top_k]: + if seq: + char = self.id_to_char(seq[-1]) + prob = np.exp(score / max(len(seq), 1)) + else: + char = self.id_to_char(0) + prob = 0.0 + predictions.append((char, prob, seq[-1] if seq else 0)) + + decode_time = self.last_decoder_time # 只记录最后一次解码的时间 + + return predictions, decode_time + + def interactive_mode(self): + """交互式推理模式""" + print("\n" + "=" * 60) + print("ONNX输入法模型推理 - 交互模式") + print("=" * 60) + + encoding = sys.stdout.encoding or "unknown" + print(f"终端编码: {encoding}") + + print("\n说明:") + print(" - 上下文提示: 模型不掌握的专有词汇、姓名等(可为空)") + print(" - 光标前文本: 光标前的连续文本") + print(" - 光标后文本: 光标后的连续文本") + print(" - 拼音: 当前输入的拼音") + print(" - 槽位历史: 用户已确认的输入历史") + if self.use_beam_search: + print(f" - 解码模式: 束搜索 (beam_size={self.beam_size})") + else: + print(" - 解码模式: 单步解码 (使用 --beam 启用束搜索)") + print("提示: 输入 'quit' 或 'exit' 或 'q' 可随时退出") + print("-" * 60) + + while True: + try: + print("\n" + "=" * 60) + context_input = self._safe_input("第1步: 上下文提示(直接回车跳过)") + if context_input.lower() in ["quit", "exit", "q"]: + break + + context_prompts = [ + item.strip() for item in context_input.split(",") if item.strip() + ] + + print("\n" + "-" * 40) + text_before = self._safe_input("第2步: 光标前文本") + if text_before.lower() in ["quit", "exit", "q"]: + break + + print("\n" + "-" * 40) + text_after = self._safe_input("第3步: 光标后文本") + if text_after.lower() in ["quit", "exit", "q"]: + break + + print("\n" + "-" * 40) + pinyin = self._safe_input("第4步: 拼音输入") + if pinyin.lower() in ["quit", "exit", "q"]: + break + + print("\n" + "-" * 40) + slot_input = self._safe_input("第5步: 槽位历史(直接回车表示无)") + if slot_input.lower() in ["quit", "exit", "q"]: + break + + slot_chars = [ + char.strip() for char in slot_input.split(",") if char.strip() + ] + + print("\n" + "=" * 60) + print("📝 输入汇总:") + print(f" 上下文提示: {context_prompts if context_prompts else '无'}") + print(f" 光标前文本: '{text_before}'") + print(f" 光标后文本: '{text_after}'") + print(f" 拼音: '{pinyin}'") + print(f" 槽位历史: {slot_chars if slot_chars else '无'}") + + print("\n🔮 推理中...") + predictions, timing_info = self.predict( + context_prompts, + text_before, + text_after, + pinyin, + slot_chars, + top_k=20, + use_beam_search=self.use_beam_search, + beam_size=self.beam_size, + ) + + # 显示时间统计 + print(f"\n⏱️ 执行时间统计:") + print("-" * 40) + for stage, duration in timing_info.items(): + if duration > 0: + print(f" {stage:<12}: {duration:>8.2f} ms") + print("-" * 40) + + # 显示结果 + print("\n🏆 Top-20 预测结果:") + print("-" * 50) + for i, (char, prob, idx) in enumerate(predictions): + if char == "//": + print(f"{i + 1:2d}. {'//':<4} (结束符) - 概率: {prob:.4f}") + else: + print( + f"{i + 1:2d}. {char:<4} (ID: {idx:>5}) - 概率: {prob:.4f}" + ) + + # 显示拼音参考 + if pinyin and self.query_engine: + print(f"\n📖 拼音 '{pinyin}' 的常见汉字:") + pinyin_results = self.query_engine.query_by_pinyin(pinyin, limit=10) + if pinyin_results: + for j, (pid, char, count) in enumerate(pinyin_results): + print(f" {char} (频次: {count:,})") + + print("\n" + "-" * 40) + continue_input = ( + self._safe_input("是否继续推理?(y/n)", "y").strip().lower() + ) + if continue_input not in ["y", "yes", ""]: + break + + except KeyboardInterrupt: + print("\n\n退出交互模式") + break + except Exception as e: + print(f"\n❌ 推理出错: {e}") + import traceback + + traceback.print_exc() + + +def main(): + parser = argparse.ArgumentParser(description="ONNX输入法模型推理") + parser.add_argument( + "--context-encoder", + "-c", + type=str, + required=True, + help="上下文编码器ONNX模型路径", + ) + parser.add_argument( + "--decoder", + "-d", + type=str, + required=True, + help="解码器ONNX模型路径", + ) + parser.add_argument( + "--vocab-size", + type=int, + default=10019, + help="词汇表大小 (默认: 10019)", + ) + parser.add_argument( + "--device", + type=str, + default="cpu", + choices=["cpu", "cuda"], + help="推理设备 (默认: cpu)", + ) + parser.add_argument( + "--interactive", + action="store_true", + default=True, + help="交互模式 (默认: True)", + ) + parser.add_argument("--test", action="store_true", help="运行测试推理") + parser.add_argument( + "--beam", + action="store_true", + help="使用束搜索解码 (默认: 单步解码)", + ) + parser.add_argument( + "--beam-size", + type=int, + default=5, + help="束大小 (默认: 5)", + ) + + args = parser.parse_args() + + # 检查文件是否存在 + if not os.path.exists(args.context_encoder): + print(f"❌ 错误: 上下文编码器文件不存在: {args.context_encoder}") + sys.exit(1) + if not os.path.exists(args.decoder): + print(f"❌ 错误: 解码器文件不存在: {args.decoder}") + sys.exit(1) + + # 初始化推理器 + inference = ONNXInference( + context_encoder_path=args.context_encoder, + decoder_path=args.decoder, + vocab_size=args.vocab_size, + device=args.device, + use_beam_search=args.beam, + beam_size=args.beam_size, + ) + + # 测试推理 + if args.test: + print("\n🧪 运行测试推理...") + print("测试场景: 输入'shanghai',已确认第一个字'上',继续输入'tian'") + print("上下文提示: 张三、李四(模型不掌握的专有名词)") + + predictions, timing_info = inference.predict( + context_prompts=["张三", "李四"], + text_before="今天天气", + text_after="很好", + pinyin="tian", + slot_chars=["上"], + use_beam_search=args.beam, + beam_size=args.beam_size, + ) + + print(f"\n⏱️ 执行时间统计:") + print("-" * 40) + for stage, duration in timing_info.items(): + if duration > 0: + print(f" {stage:<12}: {duration:>8.2f} ms") + print("-" * 40) + + print(f"\nTop-5 结果:") + for i, (char, prob, idx) in enumerate(predictions[:5]): + if char == "//": + print(f" {i + 1}. // (结束符) - 概率: {prob:.4f}") + else: + print(f" {i + 1}. {char} (ID: {idx}) - 概率: {prob:.4f}") + + # 交互模式 + if args.interactive: + inference.interactive_mode() + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index cabac32..3408a04 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "onnxruntime>=1.24.2", "pandas>=3.0.0", "plotly>=5.0.0", + "jieba>=0.42.1", "pypinyin>=0.55.0", "requests>=2.32.5", "rich>=14.3.1", @@ -23,6 +24,7 @@ dependencies = [ "transformers==5.1.0", "typer>=0.21.1", "waitress>=3.0.2", + "onnx>=1.21.0", ] [project.scripts] diff --git a/src/analyze_data.py b/src/analyze_data.py new file mode 100644 index 0000000..b5a5719 --- /dev/null +++ b/src/analyze_data.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python3 +""" +数据分布查看脚本 +主要关注 labels 为:0-10, 200-210, 2900-2910, 5100-5110, 9520-9530 的分布情况 + +使用方法: + uv run python src/analyze_data.py --data_path ./data/train.txt + +注意: + data_path 可以是: + - 本地文本文件 (每行一个文本) + - HuggingFace 数据集路径 + - 目录路径 (该目录下需要有 dataset_info.json 或类似配置文件) +""" + +import os +import random +from collections import defaultdict + +import numpy as np +import torch +from torch.utils.data import DataLoader + +from model.dataset import PinyinInputDataset + + +def analyze_label_distribution(dataset: PinyinInputDataset, sample_size: int = 10000): + """分析 label 在指定区间的分布""" + target_ranges = [ + (0, 10), + (200, 210), + (2900, 2910), + (5100, 5110), + (9520, 9530), + ] + + id_counts = defaultdict(int) + all_examples = [] + + dataloader = DataLoader(dataset, batch_size=1, num_workers=16) + + total_count = 0 + sample_collected = 0 + + for batch in dataloader: + if sample_collected >= sample_size: + break + + label = batch["label"].item() + prefix = batch["prefix"][0] + suffix = batch["suffix"][0] + pinyin = batch["pinyin"][0] + history = batch["history_slot_ids"][0].tolist() + part4 = prefix.split("^")[0] if "^" in prefix else "" + + sample_collected += 1 + total_count += 1 + + in_target_range = False + for start, end in target_ranges: + if start <= label <= end: + id_counts[label] += 1 + in_target_range = True + + if len(all_examples) < 200: + all_examples.append( + { + "label": label, + "prefix": prefix, + "suffix": suffix, + "pinyin": pinyin, + "history": history, + "part4": part4, + } + ) + + print(f"\n{'=' * 60}") + print(f"采样总数: {total_count}") + print(f"{'=' * 60}\n") + + print("Label ID 分布统计:") + print("-" * 50) + for start, end in target_ranges: + print(f"\n区间 [{start:5d} - {end:5d}]:") + for label_id in range(start, end + 1): + count = id_counts[label_id] + percentage = (count / total_count * 100) if total_count > 0 else 0 + bar = "█" * min(count, 50) + print(f" ID {label_id:5d}: {count:6d} ({percentage:.3f}%) {bar}") + + print("\n") + print("=" * 80) + print("随机抽取 20 个样本详情:") + print("=" * 80) + + random.shuffle(all_examples) + for idx, ex in enumerate(all_examples[:20], 1): + print(f"\n样本 {idx}:") + print(f" Label: {ex['label']}") + print(f" Part4: {ex['part4']}") + print(f" 光标前: {ex['prefix']}") + print(f" 光标后: {ex['suffix']}") + print(f" 拼音: {ex['pinyin']}") + print(f" 历史槽位: {ex['history']}") + + +def main(): + import argparse + + default_data_path = os.path.expanduser("~/Data/corpus/CCI-Data/") + + parser = argparse.ArgumentParser( + description="分析数据集 label 分布", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=f""" +默认数据集路径: {default_data_path} + +示例: + uv run python src/analyze_data.py --data_path {default_data_path} --sample_size 10000 + uv run python src/analyze_data.py --data_path ./data/eval.txt --sample_size 5000 + """, + ) + parser.add_argument( + "--data_path", + type=str, + default=default_data_path, + help="数据集路径 (本地文件或HuggingFace路径)", + ) + parser.add_argument("--sample_size", type=int, default=10000, help="采样大小") + parser.add_argument("--max_workers", type=int, default=-1, help="DataLoader workers") + args = parser.parse_args() + + print(f"加载数据集: {args.data_path}") + print() + + try: + dataset = PinyinInputDataset( + data_path=args.data_path, + max_workers=args.max_workers, + max_iter_length=args.sample_size, + ) + except Exception as e: + print(f"\n错误: 无法加载数据集 '{args.data_path}'") + print(f"原因: {e}") + print("\n请确保:") + print(" 1. 数据路径存在且正确") + print(" 2. 如果是本地文本文件,每行应该是一个 JSON 对象或纯文本") + print(" 3. 如果是 HuggingFace 数据集,路径应该正确") + return + + analyze_label_distribution(dataset, sample_size=args.sample_size) + + +if __name__ == "__main__": + main() diff --git a/src/model/assets/pinyin_char_statistics.json b/src/model/assets/pinyin_char_statistics.json index 66ad104..506e141 100644 --- a/src/model/assets/pinyin_char_statistics.json +++ b/src/model/assets/pinyin_char_statistics.json @@ -8,7 +8,7 @@ "id": 0, "char": "", "pinyin": "", - "count": 11067734826 + "count": 494748360 }, "1": { "id": 1, diff --git a/src/model/dataset.py b/src/model/dataset.py index 4a07de5..3f8d928 100644 --- a/src/model/dataset.py +++ b/src/model/dataset.py @@ -1,3 +1,4 @@ +import jieba import random import re from importlib.resources import files @@ -23,6 +24,26 @@ CHAR_TO_ID["'"] = 28 # 显式添加引号 CHAR_TO_ID["-"] = 29 # 显式添加短横 +jieba.setLogLevel(jieba.logging.INFO) + + +def segment_text(text: str) -> List[str]: + """使用 jieba 分词,返回词列表""" + return list(jieba.cut(text, HMM=False)) + + +def build_word_boundaries(words: List[str]) -> List[Tuple[int, int]]: + """建立词边界列表 [(start, end), ...],基于顺序位置累加""" + result = [] + pos = 0 + for word in words: + start = pos + end = pos + len(word) + result.append((start, end)) + pos = end + return result + + def text_to_pinyin_ids(pinyin_str: str) -> List[int]: """ 将拼音字符串转换为 ID 列表。 @@ -41,17 +62,22 @@ class PinyinInputDataset(IterableDataset): max_iter_length=1e6, max_seq_length=128, text_field: str = "text", - py_style_weight=(9, 2, 1), + py_style_weight=(90, 2, 1), shuffle_buffer_size: int = 100000, retention_ratio: float = 0.8, length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, ): # 频率调整参数 (可根据需要调整) - self.drop_start_freq = 30_000_000 - self.max_drop_prob = 0.8 + self.drop_start_freq = 10_000_000 + self.max_drop_prob = 0.9 self.repeat_end_freq = 10_000 self.max_repeat_expect = 50 self.min_freq = 109 + self.word_break_prob = 0.10 + self.cont_length_probs = [0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04] + self._history_weights = [0.2, 0.2, 0.2, 0.9, 1.2, 1.8, 2.5, 3.5, 4.0] + + jieba.initialize() self.tokenizer = AutoTokenizer.from_pretrained( Path(str(files(__package__))) / "assets" / "tokenizer" @@ -83,14 +109,15 @@ class PinyinInputDataset(IterableDataset): # 提取每个样本的目标字符及其频率 self.sample_freqs = self.query_engine.get_all_weights() + self.max_freq = max(self.sample_freqs.values()) if self.sample_freqs else 0 def adjust_frequency(self, freq: int) -> int: """削峰填谷 - 根据频率调整采样次数,0表示丢弃""" # 1. 削峰处理(高频字) if freq >= self.drop_start_freq: # 线性丢弃概率计算 - max_freq = max(self.sample_freqs) # 或使用预定义的全局最大值 - if max_freq == self.drop_start_freq: + max_freq = self.max_freq # 使用预计算的最大频率值 + if max_freq <= self.drop_start_freq: drop_prob = 0.0 else: drop_prob = ( @@ -119,7 +146,11 @@ class PinyinInputDataset(IterableDataset): ) # 使用泊松分布实现随机重复 repeat_count = np.random.poisson(repeat_expect) - return max(1, repeat_count) + if repeat_expect < 1.0: + # 小期望值时,以概率 repeat_expect 采样 1 次 + return 1 if random.random() < repeat_expect else 0 + else: + return max(1, repeat_count) # 原逻辑 # 3. 中间频率字 else: @@ -196,6 +227,55 @@ class PinyinInputDataset(IterableDataset): mask_pinyin.append(py) return len(mask_pinyin), mask_pinyin + def _compute_pinyin_ids(self, pinyin_str: str) -> torch.Tensor: + pinyin_ids = text_to_pinyin_ids(pinyin_str) + len_py = len(pinyin_ids) + if len_py < 24: + pinyin_ids.extend([0] * (24 - len_py)) + else: + pinyin_ids = pinyin_ids[:24] + return torch.tensor(pinyin_ids, dtype=torch.long) + + def _add_word_samples( + self, + batch_samples: list, + labels: list, + encoded: dict, + part4: str, + part1: str, + part3: str, + pinyin_str: str, + pinyin_ids: torch.Tensor, + ) -> list: + for label_idx, label in enumerate(labels): + base_repeats = self.adjust_frequency(self.sample_freqs.get(label, 0)) + if base_repeats == 0: + continue + weight = ( + self._history_weights[label_idx] + if label_idx < len(self._history_weights) + else 3.0 + ) + repeats = max(1, int(base_repeats * weight)) + + history = labels[:label_idx] + len_h = len(history) + history.extend([0] * (8 - len_h)) + + sample_dict = { + "input_ids": encoded["input_ids"], + "token_type_ids": encoded["token_type_ids"], + "attention_mask": encoded["attention_mask"], + "label": torch.tensor([label], dtype=torch.long), + "history_slot_ids": torch.tensor(history, dtype=torch.long), + "prefix": f"{part4}^{part1}", + "suffix": part3, + "pinyin": pinyin_str, + "pinyin_ids": pinyin_ids, + } + batch_samples.extend([sample_dict] * repeats) + return batch_samples + def __iter__(self): worker_info = torch.utils.data.get_worker_info() if worker_info is not None: @@ -208,219 +288,219 @@ class PinyinInputDataset(IterableDataset): random.seed(seed % (2**32)) np.random.seed(seed % (2**32)) - # 安全检查:如果worker_id >= num_workers,则该worker不应该工作 - # 这可能发生在self.max_workers小于实际worker数量时 if worker_id >= num_workers: - return # 产生空迭代器 + return - # 使用局部变量存储分片数据集,避免竞争条件 worker_dataset = self.dataset.shard(num_shards=num_workers, index=worker_id) - # 计算每个worker的配额 - # 将 max_iter_length 转换为整数以确保整数除法 total_quota = int(self.max_iter_length) base_quota = total_quota // num_workers remainder = total_quota % num_workers - # 最后一个worker处理剩余的样本(如果有余数) if worker_id == num_workers - 1: worker_quota = base_quota + remainder else: worker_quota = base_quota else: - # 单worker情况,使用全部配额 worker_quota = int(self.max_iter_length) num_workers = 1 - worker_dataset = self.dataset # 不使用分片 + worker_dataset = self.dataset - # 每个worker有自己的迭代计数器 current_iter_index = 0 batch_samples = [] for sample in worker_dataset: - # 检查是否达到最大迭代次数 if current_iter_index >= worker_quota: break text = sample.get(self.text_field, "") - if text: - pinyin_list = self.generate_pinyin(text) - for i in range(len(text)): - # 在开始处理每个字符前检查配额 - if current_iter_index >= worker_quota: - break + if not text: + continue - labels = [] - # 如果text[i]不在字符库中,则跳过 - # 当i小于48时候,则将part1取text[0:i] - # 当i大于48时候,则将part1取text[i-48:i] - if not self.query_engine.is_chinese_char(text[i]): - continue - if i < 48: - part1 = text[0:i] - else: - part1 = text[i - 48 : i] + words = segment_text(text) + word_boundaries = build_word_boundaries(words) + pinyin_list = self.generate_pinyin(text) - # 方案C:提前检查从位置i开始连续有多少个字符在词库中 - max_valid_len = 0 - for j in range(i, min(i + 8, len(text))): - if self.query_engine.is_chinese_char(text[j]): - max_valid_len += 1 + for word_start, word_end in word_boundaries: + char_positions = [] + for i in range(word_start, word_end): + if self.query_engine.is_chinese_char(text[i]): + char_positions.append(i) + + if not char_positions: + continue + + word_len_chars = len(char_positions) + + should_break = ( + word_len_chars > 1 and random.random() < self.word_break_prob + ) + + if should_break: + break_pos = random.randint(1, word_len_chars - 1) + else: + break_pos = word_len_chars + + # ========== Phase 1: 前缀/整词 ========== + prefix_positions = char_positions[:break_pos] + prefix_text = "".join(text[i] for i in prefix_positions) + prefix_pinyin = [pinyin_list[i] for i in prefix_positions] + + _, mask_pinyin = self.get_mask_pinyin(prefix_text, prefix_pinyin) + split_char = np.random.choice( + ["", "`", "'", "-"], p=[0.9, 0.04, 0.04, 0.02] + ) + part2 = split_char.join(mask_pinyin) + pinyin_ids = self._compute_pinyin_ids(part2) + + try: + labels = [ + self.query_engine.get_char_info_by_char_pinyin( + text[i], pinyin_list[i] + ).id + for i in prefix_positions + ] + except AttributeError as e: + logger.error( + f"e: {e}, (text, pinyin): {prefix_text} - {prefix_pinyin}" + ) + continue + + # 整词末尾 10% 概率追加 EOS(破词前缀不加) + if not should_break and random.random() <= 0.1: + labels.append(0) + + # part1: 词起点前的文本(所有样本共享) + part1 = text[max(0, word_start - 48) : word_start] + + # part3: 词后文本 + part3 = "" + if random.random() > 0.7: + part3 = text[word_end : word_end + np.random.choice(range(1, 17))] + + # part4: 词提示 + part4 = "" + if random.random() > 0.7: + num_words = random.randint(1, 3) + if words: + selected_words = random.sample( + words, min(num_words, len(words)) + ) + part4 = "|".join(selected_words) + + encoded = self.tokenizer( + f"{part4}|{part1}", + part3, + max_length=self.max_seq_length, + padding="max_length", + truncation=True, + return_tensors="pt", + return_token_type_ids=True, + ) + + batch_samples = self._add_word_samples( + batch_samples, + labels, + encoded, + part4, + part1, + part3, + part2, + pinyin_ids, + ) + + # ========== Phase 2: 破词续接 ========== + if should_break and break_pos < word_len_chars: + cont_start = char_positions[break_pos] + + # 续接目标:从断点开始,可延伸到后续词,遇到非汉字停止 + target_len = np.random.choice(range(1, 9), p=self.cont_length_probs) + cont_positions = [] + pos = cont_start + while len(cont_positions) < target_len and pos < len(text): + if self.query_engine.is_chinese_char(text[pos]): + cont_positions.append(pos) else: break + pos += 1 - # 如果没有可用字符,跳过 - if max_valid_len == 0: + if not cont_positions: continue - # 首先取随机值pinyin_len(1-8),pinyin_len取值呈高斯分布,最大概率取3 - # 获取text[i + pinyin_len]字符,如果无法获取所指向的后,如果pinyin_len - # part2的长度为x,取pinyin_list[i:i+pinyin_len],为part2 - # 但是需要注意边界条件 - target_len = np.random.choice( - range(1, 9), p=[0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04] - ) - # 根据实际可用长度调整 - pinyin_len = min(target_len, max_valid_len) + cont_text = "".join(text[i] for i in cont_positions) + cont_pinyin = [pinyin_list[i] for i in cont_positions] - py_end = min(i + pinyin_len, len(text)) - pinyin_len, part2 = self.get_mask_pinyin( - text[i:py_end], pinyin_list[i:py_end] - ) - - split_char = np.random.choice( + _, mask_pinyin_cont = self.get_mask_pinyin(cont_text, cont_pinyin) + split_char_cont = np.random.choice( ["", "`", "'", "-"], p=[0.9, 0.04, 0.04, 0.02] ) + part2_cont = split_char_cont.join(mask_pinyin_cont) + pinyin_ids_cont = self._compute_pinyin_ids(part2_cont) - part2 = split_char.join(part2) - pinyin_ids = text_to_pinyin_ids(part2) - len_py = len(pinyin_ids) - if len_py < 24: - pinyin_ids.extend([0] * (24 - len_py)) - else: - pinyin_ids = pinyin_ids[:24] - pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long) - - # part3为文本,大概率(0.70)为空 - # 不为空则是i+pinyin_len所指向的字符以及所指向字符后x个字符 - # x为1-16中的任意整数,取值平均分布 - part3 = "" - if random.random() > 0.7: - part3 = text[ - i + pinyin_len : i - + pinyin_len - + np.random.choice(range(1, 17)) - ] - - # part4为文本,0.30的概率为空 - # 不为空则为1-5个连续字符串 - # 连续字符串的取值方法为:随机从字符库中取一个字符,以及该字符后x个字符 - # x为2-6中的任意整数,取值平均分布 - # 使用|将part4中的字符串连接起来 - part4 = "" - if random.random() > 0.7: - # 生成1-5个连续字符串 - num_strings = random.randint(1, 5) - string_list = [] - for _ in range(num_strings): - # 随机选择起始位置 - start_pos = random.randint(0, len(text) - 1) - # 随机选择x的值(2-6) - x = random.randint(2, 6) - # 获取连续字符串 - end_pos = min(start_pos + x + 1, len(text)) - string_list.append(text[start_pos:end_pos]) - # 用|连接所有字符串 - part4 = "|".join(string_list) try: - labels = [ - self.query_engine.get_char_info_by_char_pinyin(c, p).id - for c, p in zip( - text[i : i + pinyin_len], - pinyin_list[i : i + pinyin_len], - ) + cont_labels = [ + self.query_engine.get_char_info_by_char_pinyin( + text[i], pinyin_list[i] + ).id + for i in cont_positions ] except AttributeError as e: logger.error( - f"e: {e}, (text, pinyin): {text[i : i + pinyin_len]} - {pinyin_list[i : i + pinyin_len]}" + f"e: {e}, (text, pinyin): {cont_text} - {cont_pinyin}" ) continue - if random.random() <= 0.1: - labels.append(0) - encoded = self.tokenizer( - f"{part4}|{part1}", - part3, + # 续接末尾 10% 概率追加 EOS + if random.random() <= 0.1: + cont_labels.append(0) + + # part1_cont: 包含已确认前缀的上下文 + part1_cont = text[max(0, cont_start - 48) : cont_start] + + # part3_cont: 续接目标后的文本 + cont_end = cont_positions[-1] + 1 + part3_cont = "" + if random.random() > 0.7: + part3_cont = text[ + cont_end : cont_end + np.random.choice(range(1, 17)) + ] + + encoded_cont = self.tokenizer( + f"{part4}|{part1_cont}", + part3_cont, max_length=self.max_seq_length, padding="max_length", truncation=True, return_tensors="pt", return_token_type_ids=True, ) - samples = [] - # 历史槽位长度权重:增加长历史采样比例 - # 目标分布: H=0-2占45%, H=3-8占55% - history_weights = [0.2, 0.2, 0.2, 0.9, 1.2, 1.8, 2.5, 3.5, 4.0] - # 修复变量名冲突:将内层循环变量i重命名为label_idx - for label_idx, label in enumerate(labels): - base_repeats = self.adjust_frequency(label) - # 根据历史槽位长度调整采样次数 - weight = ( - history_weights[label_idx] - if label_idx < len(history_weights) - else 3.0 - ) - repeats = max(1, int(base_repeats * weight)) - - # 历史槽位:同一拼音序列中已确认的字符(模拟用户逐步确认过程) - masked_labels = labels[:label_idx] - len_l = len(masked_labels) - masked_labels.extend([0] * (8 - len_l)) - - samples.extend( - [ - { - "input_ids": encoded["input_ids"], - "token_type_ids": encoded["token_type_ids"], - "attention_mask": encoded["attention_mask"], - "label": torch.tensor([label], dtype=torch.long), - "history_slot_ids": torch.tensor( - masked_labels, dtype=torch.long - ), - "prefix": f"{part4}^{part1}", - "suffix": part3, - "pinyin": part2, - "pinyin_ids": pinyin_ids, - } - ] - * repeats - ) - - # 添加到缓冲区 - batch_samples.extend(samples) + batch_samples = self._add_word_samples( + batch_samples, + cont_labels, + encoded_cont, + part4, + part1_cont, + part3_cont, + part2_cont, + pinyin_ids_cont, + ) # 处理shuffle buffer - 单缓冲区半保留方案 if len(batch_samples) >= self.shuffle_buffer_size: - # 全量打乱缓冲区 indices = np.random.permutation(len(batch_samples)) - # 计算实际保留大小(不超过缓冲区大小) actual_retention = min(self.retention_size, len(batch_samples)) - # 计算输出数量 output_count = len(batch_samples) - actual_retention - # 输出前output_count个样本 for i in range(output_count): if current_iter_index >= worker_quota: - # 配额用完,清空缓冲区并返回 batch_samples = [] return yield batch_samples[indices[i]] current_iter_index += 1 - # 保留后actual_retention个样本(不清空缓冲区) retained_samples = [ batch_samples[idx] for idx in indices[output_count:] ] diff --git a/src/model/export_components.py b/src/model/export_components.py new file mode 100644 index 0000000..a497ed8 --- /dev/null +++ b/src/model/export_components.py @@ -0,0 +1,169 @@ +""" +ONNX导出专用组件 + +为了支持ONNX导出,对原始组件进行修改: +1. 移除packed sequence操作 +2. 处理动态形状问题 +3. 确保所有操作符都ONNX兼容 +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ExportPinyinLSTMEncoder(nn.Module): + """ + ONNX兼容的拼音LSTM编码器 + 简化版本,不使用packed sequence + """ + + def __init__(self, input_dim, hidden_dim=None, num_layers=2, dropout=0.2): + super().__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim if hidden_dim is not None else input_dim // 2 + self.num_layers = num_layers + self.dropout = dropout + + self.lstm = nn.LSTM( + input_size=input_dim, + hidden_size=self.hidden_dim, + num_layers=num_layers, + bidirectional=True, + batch_first=True, + dropout=dropout if num_layers > 1 else 0.0, + ) + + self.proj = nn.Linear(self.hidden_dim * 2, input_dim) + self.layer_norm = nn.LayerNorm(input_dim) + + def forward(self, x, mask=None): + """ + ONNX兼容的前向传播 + 不使用packed sequence,改为使用masked计算 + """ + # 简化:直接使用LSTM,不处理padding + # 在ONNX中,处理可变长度序列较复杂 + # 对于输入法场景,拼音长度固定为24,所以可以这样处理 + output, (hidden, cell) = self.lstm(x) + projected = self.proj(output) + return self.layer_norm(projected) + + +class ExportContextEncoder(nn.Module): + """ + ONNX兼容的上下文编码器 + """ + + def __init__( + self, vocab_size, pinyin_vocab_size, dim=512, n_layers=4, n_heads=4, max_len=128 + ): + super().__init__() + self.dim = dim + self.max_len = max_len + + # 使用原始text_emb,但需要确保它支持ONNX + from modelscope import AutoModel + + self.text_emb = AutoModel.from_pretrained( + "iic/nlp_structbert_backbone_lite_std" + ).embeddings + + self.pinyin_emb = nn.Embedding(pinyin_vocab_size, dim) + self.pos_emb = nn.Embedding(max_len, dim) + self.pinyin_pooling = ExportPinyinLSTMEncoder(dim) + + # Transformer Encoder + encoder_layer = nn.TransformerEncoderLayer( + d_model=dim, + nhead=n_heads, + dim_feedforward=dim * 4, + dropout=0.1, + batch_first=True, + ) + self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers) + self.ln = nn.LayerNorm(dim) + + def forward(self, text_ids, pinyin_ids, mask=None): + """ + ONNX兼容的前向传播 + """ + # 文本嵌入 + text_emb = self.text_emb(text_ids) # [B, seq_len, dim] + + # 位置编码 - 使用预计算的pos_ids + seq_len = text_emb.size(1) + # 使用静态最大长度,避免动态arange + if seq_len > self.max_len: + seq_len = self.max_len + + # 创建位置ID,确保在导出时是静态的 + # 使用torch.full而不是torch.arange,因为arange在动态形状下可能有问题 + pos_ids = torch.arange( + seq_len, device=text_ids.device, dtype=torch.long + ).unsqueeze(0) + + # 如果实际序列长度小于最大长度,截取位置嵌入 + if seq_len < self.max_len: + x = text_emb[:, :seq_len, :] + self.pos_emb(pos_ids) + else: + x = text_emb + self.pos_emb(pos_ids) + + # 处理mask + if mask is not None: + # 确保mask是bool类型 + src_mask = (mask == 0).to(torch.bool) + else: + src_mask = None + + # Transformer + H = self.transformer(x, src_key_padding_mask=src_mask) + H = self.ln(H) + + # 拼音编码 + pinyin_emb = self.pinyin_emb(pinyin_ids) # [B, pinyin_len, dim] + # 简化:不传递mask给LSTM + P = self.pinyin_pooling(pinyin_emb) # [B, pinyin_len, dim] + + return H, P + + +def create_export_context_encoder(original_context_encoder): + """ + 从原始ContextEncoder创建导出版本 + """ + # 获取原始配置 + config = { + "vocab_size": getattr(original_context_encoder, "vocab_size", 10019), + "pinyin_vocab_size": original_context_encoder.pinyin_emb.num_embeddings, + "dim": original_context_encoder.dim, + "n_layers": original_context_encoder.transformer.num_layers, + "n_heads": original_context_encoder.transformer.layers[0].self_attn.num_heads, + "max_len": original_context_encoder.pos_emb.num_embeddings, + } + + # 创建导出版本 + export_encoder = ExportContextEncoder(**config) + + # 复制权重 + # 复制text_emb权重(从原始AutoModel embeddings) + # 注意:这里假设原始text_emb的结构 + state_dict = original_context_encoder.state_dict() + export_state_dict = export_encoder.state_dict() + + # 复制匹配的权重 + for key in export_state_dict: + if key in state_dict: + export_state_dict[key] = state_dict[key] + + # 特殊处理:复制position embeddings + if "pos_emb.weight" in export_state_dict and "pos_emb.weight" in state_dict: + # 确保大小匹配 + orig_pos_emb = state_dict["pos_emb.weight"] + export_pos_emb = export_state_dict["pos_emb.weight"] + min_len = min(orig_pos_emb.size(0), export_pos_emb.size(0)) + export_state_dict["pos_emb.weight"][:min_len] = orig_pos_emb[:min_len] + + export_encoder.load_state_dict(export_state_dict) + + return export_encoder diff --git a/src/model/export_models.py b/src/model/export_models.py new file mode 100644 index 0000000..a7a98d7 --- /dev/null +++ b/src/model/export_models.py @@ -0,0 +1,291 @@ +""" +ONNX导出模型定义 + +定义两个子模型用于ONNX导出: +1. ContextEncoderExport: 输入文本和拼音,输出上下文编码 +2. DecoderExport: 输入上下文编码、拼音编码和槽位历史,输出logits +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +# 修补modelscope以支持AutoModel导入 +import sys +import modelscope + +if not hasattr(modelscope, "AutoModel"): + from modelscope import Model + + modelscope.AutoModel = Model + sys.modules["modelscope"].AutoModel = Model + +from .components import ( + ContextEncoder, + CrossAttentionFusion, + MoELayer, + SlotMemory, + PinyinLSTMEncoder, +) + + +class ContextEncoderExport(nn.Module): + """ + 上下文编码器导出模型 + 输入: input_ids, pinyin_ids, attention_mask + 输出: context_H, pinyin_P, context_mask, pinyin_mask + """ + + def __init__(self, context_encoder: ContextEncoder): + super().__init__() + self.context_encoder = context_encoder + self.dim = context_encoder.dim + + # 创建ONNX兼容的拼音LSTM编码器(简化版,不使用packed sequence) + # 复制原始LSTM的参数 + original_pooling = context_encoder.pinyin_pooling + self.pinyin_lstm = nn.LSTM( + input_size=original_pooling.input_dim, + hidden_size=original_pooling.hidden_dim, + num_layers=original_pooling.num_layers, + bidirectional=True, + batch_first=True, + dropout=original_pooling.dropout + if original_pooling.num_layers > 1 + else 0.0, + ) + self.pinyin_proj = nn.Linear( + original_pooling.hidden_dim * 2, original_pooling.input_dim + ) + self.pinyin_ln = nn.LayerNorm(original_pooling.input_dim) + + # 复制权重 + self.pinyin_lstm.load_state_dict(original_pooling.lstm.state_dict()) + self.pinyin_proj.load_state_dict(original_pooling.proj.state_dict()) + self.pinyin_ln.load_state_dict(original_pooling.layer_norm.state_dict()) + + def forward(self, input_ids, pinyin_ids, attention_mask): + """ + Args: + input_ids: [batch_size, seq_len] + pinyin_ids: [batch_size, pinyin_len] (pinyin_len固定为24) + attention_mask: [batch_size, seq_len] + + Returns: + context_H: [batch_size, seq_len, dim] + pinyin_P: [batch_size, pinyin_len, dim] + context_mask: [batch_size, seq_len] (bool -> int32) + pinyin_mask: [batch_size, pinyin_len] (bool -> int32) + """ + # 获取原始context_encoder的组件 + text_emb = self.context_encoder.text_emb + pinyin_emb = self.context_encoder.pinyin_emb + pos_emb = self.context_encoder.pos_emb + transformer = self.context_encoder.transformer + ln = self.context_encoder.ln + + # 文本嵌入 + text_emb_out = text_emb(input_ids) # [B, seq_len, dim] + + # 位置编码 - 修复torch.arange动态形状问题 + seq_len = text_emb_out.size(1) + # 使用预计算的位置ID,确保静态形状 + if seq_len > pos_emb.num_embeddings: + seq_len = pos_emb.num_embeddings + + # 创建位置ID - 使用torch.arange但确保是整数张量 + pos_ids = torch.arange( + seq_len, device=input_ids.device, dtype=torch.long + ).unsqueeze(0) + + # 如果实际序列长度小于最大长度,截取 + if seq_len < text_emb_out.size(1): + x = text_emb_out[:, :seq_len, :] + pos_emb(pos_ids) + else: + x = text_emb_out + pos_emb(pos_ids) + + # 处理mask - 确保bool类型 + if attention_mask is not None: + src_mask = (attention_mask == 0).to(torch.bool) + else: + src_mask = None + + # Transformer + H = transformer(x, src_key_padding_mask=src_mask) + H = ln(H) + + # 恢复原始序列长度(如果需要) + if seq_len < text_emb_out.size(1): + # 填充H到原始长度 + original_seq_len = text_emb_out.size(1) + padding = torch.zeros( + H.size(0), + original_seq_len - seq_len, + H.size(2), + device=H.device, + dtype=H.dtype, + ) + H = torch.cat([H, padding], dim=1) + + # 拼音编码 - 使用ONNX兼容版本 + pinyin_emb_out = pinyin_emb(pinyin_ids) # [B, pinyin_len, dim] + # 简化的LSTM,不使用packed sequence + pinyin_lstm_out, _ = self.pinyin_lstm(pinyin_emb_out) + pinyin_proj_out = self.pinyin_proj(pinyin_lstm_out) + P = self.pinyin_ln(pinyin_proj_out) + + # 生成mask(转换为int32以便ONNX支持) + context_mask = (attention_mask == 0).to(torch.int32) # 1表示padding,0表示有效 + pinyin_mask = (pinyin_ids == 0).to(torch.int32) # 1表示padding,0表示有效 + + return H, P, context_mask, pinyin_mask + + +class DecoderExport(nn.Module): + """ + 解码器导出模型 + 输入: context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask + 输出: logits + """ + + def __init__( + self, + slot_memory: SlotMemory, + cross_attn: CrossAttentionFusion, + moe: MoELayer, + slot_attention: nn.Module, + classifier: nn.Module, + num_slots: int = 8, + dim: int = 512, + ): + super().__init__() + self.slot_memory = slot_memory + self.cross_attn = cross_attn + self.moe = moe + self.slot_attention = slot_attention + self.classifier = classifier + self.num_slots = num_slots + self.dim = dim + + def forward(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask): + """ + Args: + context_H: [batch_size, seq_len, dim] + pinyin_P: [batch_size, pinyin_len, dim] + history_slot_ids: [batch_size, num_slots] + context_mask: [batch_size, seq_len] (int32, 1表示padding) + pinyin_mask: [batch_size, pinyin_len] (int32, 1表示padding) + + Returns: + logits: [batch_size, vocab_size] + """ + batch_size = context_H.size(0) + + # 确保history_slot_ids形状正确 + history_slot_ids = history_slot_ids.view(batch_size, self.num_slots) + + # 1. 槽位记忆 + S = self.slot_memory(history_slot_ids) # [batch_size, num_slots, dim] + + # 2. 交叉注意力融合 + # 转换mask: int32 -> bool (ONNX中需要bool类型) + context_mask_bool = context_mask.to(torch.bool) + pinyin_mask_bool = pinyin_mask.to(torch.bool) + + fused = self.cross_attn( + S, + context_H, + pinyin_P, + context_mask=context_mask_bool, + pinyin_mask=pinyin_mask_bool, + ) # [batch_size, num_slots, dim] + + # 3. MoE层 + moe_out = self.moe(fused) # [batch_size, num_slots, dim] + + # 4. 槽位注意力池化 + slot_scores = self.slot_attention(moe_out).squeeze( + -1 + ) # [batch_size, num_slots] + slot_weights = torch.softmax(slot_scores, dim=1) # [batch_size, num_slots] + pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1) # [batch_size, dim] + + # 5. 分类头 + logits = self.classifier(pooled) # [batch_size, vocab_size] + + return logits + + +def create_export_models_from_checkpoint(checkpoint_path, device="cpu"): + """ + 从checkpoint创建导出模型 + + Args: + checkpoint_path: 模型checkpoint路径 + device: 加载设备 + + Returns: + context_encoder_export: ContextEncoderExport实例 + decoder_export: DecoderExport实例 + model_config: 模型配置字典 + """ + # 加载原始模型配置 + checkpoint = torch.load(checkpoint_path, map_location=device) + + # 提取模型配置(从checkpoint或使用默认值) + if "config" in checkpoint: + config = checkpoint["config"] + else: + # 使用模型默认配置 + config = { + "vocab_size": 10019, + "pinyin_vocab_size": 30, + "dim": 512, + "num_slots": 8, + "n_layers": 4, + "n_heads": 4, + "num_experts": 10, + "max_seq_len": 128, + } + + # 创建原始模型 + from .model import InputMethodEngine + + model = InputMethodEngine( + vocab_size=config.get("vocab_size", 10019), + pinyin_vocab_size=config.get("pinyin_vocab_size", 30), + dim=config.get("dim", 512), + num_slots=config.get("num_slots", 8), + n_layers=config.get("n_layers", 4), + n_heads=config.get("n_heads", 4), + num_experts=config.get("num_experts", 10), + max_seq_len=config.get("max_seq_len", 128), + compile=False, + ) + + # 加载权重 + if "model_state_dict" in checkpoint: + model.load_state_dict(checkpoint["model_state_dict"]) + else: + model.load_state_dict(checkpoint) + + model.eval() + model.to(device) + + # 创建导出模型 + context_encoder_export = ContextEncoderExport(model.context_encoder) + decoder_export = DecoderExport( + slot_memory=model.slot_memory, + cross_attn=model.cross_attn, + moe=model.moe, + slot_attention=model.slot_attention, + classifier=model.classifier, + num_slots=model.num_slots, + dim=model.dim, + ) + + # 设置为评估模式 + context_encoder_export.eval() + decoder_export.eval() + + return context_encoder_export, decoder_export, config diff --git a/test_dataset.py b/test_dataset.py new file mode 100644 index 0000000..d27244d --- /dev/null +++ b/test_dataset.py @@ -0,0 +1,101 @@ +import sys + +sys.path.append("src") +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import time +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +from model.model import InputMethodEngine +from model.query import QueryEngine + +import random +import re +from importlib.resources import files +from pathlib import Path +from typing import Dict, List, Tuple + +import numpy as np +import torch +from datasets import load_dataset +from loguru import logger +from modelscope import AutoTokenizer +from pypinyin import lazy_pinyin +from pypinyin.contrib.tone_convert import to_initials +from torch.utils.data import IterableDataset +from model.dataset import PinyinInputDataset + + +def worker_init_fn(worker_id: int) -> None: + """ + 初始化每个DataLoader worker的随机种子,确保可复现性 + + Args: + worker_id: worker的ID + """ + worker_seed = torch.initial_seed() % (2**32) + np.random.seed(worker_seed) + random.seed(worker_seed) + + +def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + 自定义批处理函数,将多个样本组合成一个batch + + Args: + batch: 样本列表,每个样本是一个字典 + + Returns: + 批处理后的字典,tensor字段已stack,字符串字段保持为列表 + """ + # 处理tensor字段 - 使用squeeze去除多余的batch维度 + input_ids = torch.stack([item["input_ids"].squeeze(0) for item in batch]) + token_type_ids = torch.stack([item["token_type_ids"].squeeze(0) for item in batch]) + attention_mask = torch.stack([item["attention_mask"].squeeze(0) for item in batch]) + labels = torch.stack([item["label"].squeeze(0) for item in batch]) + history_slot_ids = torch.stack([item["history_slot_ids"] for item in batch]) + pinyin_ids = torch.stack([item["pinyin_ids"] for item in batch]) + + # 字符串字段保持为列表 + prefixes = [item["prefix"] for item in batch] + suffixes = [item["suffix"] for item in batch] + pinyins = [item["pinyin"] for item in batch] + + return { + "input_ids": input_ids, + "token_type_ids": token_type_ids, + "attention_mask": attention_mask, + "labels": labels, + "history_slot_ids": history_slot_ids, + "prefix": prefixes, + "suffix": suffixes, + "pinyin": pinyins, + "pinyin_ids": pinyin_ids, + } + + +train_dataset = PinyinInputDataset( + data_path="/home/songsenand/Data/corpus/CCI-Data/", + max_workers=-1, # 自动选择worker数量 + max_iter_length=1000000, + text_field="text", + py_style_weight=(90, 2, 1), + shuffle_buffer_size=20000, + length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, +) + +dataloader = DataLoader( + train_dataset, + batch_size=512, + num_workers=2, + worker_init_fn=worker_init_fn, + collate_fn=collate_fn, + prefetch_factor=2, # 减少预取以避免内存问题 + persistent_workers=True, +) + +for i, shape in tqdm(enumerate(dataloader), total=1000000/512): + pass + diff --git a/verify_onnx.py b/verify_onnx.py new file mode 100644 index 0000000..4976a62 --- /dev/null +++ b/verify_onnx.py @@ -0,0 +1,409 @@ +#!/usr/bin/env python3 +""" +ONNX模型验证脚本 + +验证导出的ONNX模型与原始PyTorch模型输出的一致性 +""" + +import argparse +import os +import sys +from pathlib import Path + +import numpy as np +import torch +import onnxruntime as ort + +# 添加src目录到路径 +sys.path.insert(0, str(Path(__file__).parent)) + +from src.model.export_models import create_export_models_from_checkpoint + + +def compare_outputs(pytorch_output, onnx_output, name="output", rtol=1e-3, atol=1e-5): + """ + 比较PyTorch和ONNX输出 + + Args: + pytorch_output: PyTorch张量 + onnx_output: ONNX Runtime输出(numpy数组) + name: 输出名称(用于错误信息) + rtol: 相对容差 + atol: 绝对容差 + + Returns: + bool: 是否匹配 + """ + # 转换PyTorch输出为numpy + if isinstance(pytorch_output, torch.Tensor): + pytorch_np = pytorch_output.detach().cpu().numpy() + else: + pytorch_np = np.array(pytorch_output) + + # 确保形状一致 + if pytorch_np.shape != onnx_output.shape: + print( + f"❌ {name} 形状不匹配: PyTorch {pytorch_np.shape} != ONNX {onnx_output.shape}" + ) + return False + + # 计算差异 + diff = np.abs(pytorch_np - onnx_output) + max_diff = np.max(diff) + mean_diff = np.mean(diff) + + # 检查是否在容差范围内 + is_close = np.allclose(pytorch_np, onnx_output, rtol=rtol, atol=atol) + + if is_close: + print(f"✅ {name} 匹配: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}") + else: + print(f"❌ {name} 不匹配: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}") + print( + f" 范围: PyTorch [{pytorch_np.min():.6f}, {pytorch_np.max():.6f}], " + f"ONNX [{onnx_output.min():.6f}, {onnx_output.max():.6f}]" + ) + + return is_close + + +def verify_context_encoder(checkpoint_path, onnx_path, device="cpu"): + """ + 验证上下文编码器 + + Args: + checkpoint_path: PyTorch checkpoint路径 + onnx_path: ONNX模型路径 + device: 设备 + + Returns: + bool: 验证是否通过 + """ + print(f"\n🔍 验证上下文编码器: {onnx_path}") + + # 加载PyTorch模型 + context_encoder_export, _, config = create_export_models_from_checkpoint( + checkpoint_path, device + ) + + # 创建ONNX Runtime会话 + session = ort.InferenceSession( + onnx_path, + providers=[ + "CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider" + ], + ) + + # 创建测试输入 + batch_size = 2 # 使用batch_size=2测试动态批处理 + seq_len = config.get("max_seq_len", 128) + pinyin_len = 24 + + input_ids = torch.randint(0, 100, (batch_size, seq_len), dtype=torch.long) + pinyin_ids = torch.randint(0, 30, (batch_size, pinyin_len), dtype=torch.long) + attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long) + # 随机屏蔽一些位置 + attention_mask[:, seq_len // 2 :] = 0 + + # PyTorch推理 + with torch.no_grad(): + pytorch_outputs = context_encoder_export(input_ids, pinyin_ids, attention_mask) + + # ONNX推理 + onnx_inputs = { + "input_ids": input_ids.numpy(), + "pinyin_ids": pinyin_ids.numpy(), + "attention_mask": attention_mask.numpy(), + } + + onnx_outputs = session.run(None, onnx_inputs) + + # 比较输出 + output_names = ["context_H", "pinyin_P", "context_mask", "pinyin_mask"] + all_match = True + + for i, name in enumerate(output_names): + if i < len(pytorch_outputs) and i < len(onnx_outputs): + match = compare_outputs(pytorch_outputs[i], onnx_outputs[i], name) + all_match = all_match and match + + return all_match + + +def verify_decoder(checkpoint_path, onnx_path, device="cpu"): + """ + 验证解码器 + + Args: + checkpoint_path: PyTorch checkpoint路径 + onnx_path: ONNX模型路径 + device: 设备 + + Returns: + bool: 验证是否通过 + """ + print(f"\n🔍 验证解码器: {onnx_path}") + + # 加载PyTorch模型 + _, decoder_export, config = create_export_models_from_checkpoint( + checkpoint_path, device + ) + + # 创建ONNX Runtime会话 + session = ort.InferenceSession( + onnx_path, + providers=[ + "CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider" + ], + ) + + # 创建测试输入 + batch_size = 2 + seq_len = config.get("max_seq_len", 128) + pinyin_len = 24 + dim = config.get("dim", 512) + num_slots = config.get("num_slots", 8) + + context_H = torch.randn(batch_size, seq_len, dim, dtype=torch.float32) + pinyin_P = torch.randn(batch_size, pinyin_len, dim, dtype=torch.float32) + context_mask = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int32) + pinyin_mask = torch.randint(0, 2, (batch_size, pinyin_len), dtype=torch.int32) + history_slot_ids = torch.randint(0, 100, (batch_size, num_slots), dtype=torch.long) + + # PyTorch推理 + with torch.no_grad(): + pytorch_output = decoder_export( + context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask + ) + + # ONNX推理 + onnx_inputs = { + "context_H": context_H.numpy(), + "pinyin_P": pinyin_P.numpy(), + "history_slot_ids": history_slot_ids.numpy(), + "context_mask": context_mask.numpy(), + "pinyin_mask": pinyin_mask.numpy(), + } + + onnx_outputs = session.run(None, onnx_inputs) + + # 比较输出 + return compare_outputs(pytorch_output, onnx_outputs[0], "logits") + + +def verify_end_to_end( + checkpoint_path, context_encoder_path, decoder_path, device="cpu" +): + """ + 端到端验证:比较完整推理流程 + + Args: + checkpoint_path: PyTorch checkpoint路径 + context_encoder_path: 上下文编码器ONNX路径 + decoder_path: 解码器ONNX路径 + device: 设备 + + Returns: + bool: 验证是否通过 + """ + print(f"\n🔍 端到端验证") + + # 加载原始PyTorch模型 + from src.model.model import InputMethodEngine + + checkpoint = torch.load(checkpoint_path, map_location=device) + + if "config" in checkpoint: + config = checkpoint["config"] + else: + config = { + "vocab_size": 10019, + "pinyin_vocab_size": 30, + "dim": 512, + "num_slots": 8, + "n_layers": 4, + "n_heads": 4, + "num_experts": 10, + "max_seq_len": 128, + } + + model = InputMethodEngine( + vocab_size=config.get("vocab_size", 10019), + pinyin_vocab_size=config.get("pinyin_vocab_size", 30), + dim=config.get("dim", 512), + num_slots=config.get("num_slots", 8), + n_layers=config.get("n_layers", 4), + n_heads=config.get("n_heads", 4), + num_experts=config.get("num_experts", 10), + max_seq_len=config.get("max_seq_len", 128), + compile=False, + ) + + if "model_state_dict" in checkpoint: + model.load_state_dict(checkpoint["model_state_dict"]) + else: + model.load_state_dict(checkpoint) + + model.eval() + model.to(device) + + # 创建ONNX Runtime会话 + context_session = ort.InferenceSession( + context_encoder_path, + providers=[ + "CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider" + ], + ) + decoder_session = ort.InferenceSession( + decoder_path, + providers=[ + "CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider" + ], + ) + + # 创建测试输入 + batch_size = 1 + seq_len = config.get("max_seq_len", 128) + pinyin_len = 24 + + input_ids = torch.randint(0, 100, (batch_size, seq_len), dtype=torch.long) + pinyin_ids = torch.randint(0, 30, (batch_size, pinyin_len), dtype=torch.long) + attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long) + history_slot_ids = torch.randint(0, 100, (batch_size, 8), dtype=torch.long) + + # PyTorch完整推理 + with torch.no_grad(): + pytorch_logits = model( + input_ids=input_ids, + token_type_ids=torch.zeros_like(input_ids), # 简化处理 + attention_mask=attention_mask, + pinyin_ids=pinyin_ids, + history_slot_ids=history_slot_ids, + ) + + # ONNX推理流程 + # 1. 上下文编码器 + context_inputs = { + "input_ids": input_ids.numpy(), + "pinyin_ids": pinyin_ids.numpy(), + "attention_mask": attention_mask.numpy(), + } + context_outputs = context_session.run(None, context_inputs) + context_H, pinyin_P, context_mask, pinyin_mask = context_outputs + + # 2. 解码器 + decoder_inputs = { + "context_H": context_H, + "pinyin_P": pinyin_P, + "history_slot_ids": history_slot_ids.numpy(), + "context_mask": context_mask, + "pinyin_mask": pinyin_mask, + } + onnx_outputs = decoder_session.run(None, decoder_inputs) + + # 比较输出 + return compare_outputs(pytorch_logits, onnx_outputs[0], "end_to_end_logits") + + +def main(): + parser = argparse.ArgumentParser(description="ONNX模型验证") + parser.add_argument( + "--checkpoint", "-c", type=str, required=True, help="PyTorch checkpoint路径" + ) + parser.add_argument( + "--context-encoder", + type=str, + help="上下文编码器ONNX路径(默认: ./exported_models/context_encoder.onnx)", + ) + parser.add_argument( + "--decoder", + type=str, + help="解码器ONNX路径(默认: ./exported_models/decoder.onnx)", + ) + parser.add_argument( + "--output-dir", + "-o", + type=str, + default="./exported_models", + help="导出目录(如果未指定单个模型路径)", + ) + parser.add_argument( + "--device", + type=str, + default="cpu", + choices=["cpu", "cuda"], + help="验证设备(默认: cpu)", + ) + parser.add_argument( + "--skip-context", action="store_true", help="跳过上下文编码器验证" + ) + parser.add_argument("--skip-decoder", action="store_true", help="跳解码器验证") + parser.add_argument("--skip-end-to-end", action="store_true", help="跳过端到端验证") + + args = parser.parse_args() + + # 确定模型路径 + if args.context_encoder: + context_encoder_path = args.context_encoder + else: + context_encoder_path = os.path.join(args.output_dir, "context_encoder.onnx") + + if args.decoder: + decoder_path = args.decoder + else: + decoder_path = os.path.join(args.output_dir, "decoder.onnx") + + print("🔬 ONNX模型验证") + print("=" * 60) + print(f"Checkpoint: {args.checkpoint}") + print(f"上下文编码器: {context_encoder_path}") + print(f"解码器: {decoder_path}") + print(f"设备: {args.device}") + print() + + all_pass = True + + # 验证上下文编码器 + if not args.skip_context and os.path.exists(context_encoder_path): + if verify_context_encoder(args.checkpoint, context_encoder_path, args.device): + print("✅ 上下文编码器验证通过") + else: + print("❌ 上下文编码器验证失败") + all_pass = False + elif not args.skip_context: + print("⚠️ 上下文编码器文件不存在,跳过验证") + + # 验证解码器 + if not args.skip_decoder and os.path.exists(decoder_path): + if verify_decoder(args.checkpoint, decoder_path, args.device): + print("✅ 解码器验证通过") + else: + print("❌ 解码器验证失败") + all_pass = False + elif not args.skip_decoder: + print("⚠️ 解码器文件不存在,跳过验证") + + # 端到端验证 + if ( + not args.skip_end_to_end + and os.path.exists(context_encoder_path) + and os.path.exists(decoder_path) + ): + if verify_end_to_end( + args.checkpoint, context_encoder_path, decoder_path, args.device + ): + print("✅ 端到端验证通过") + else: + print("❌ 端到端验证失败") + all_pass = False + + print("\n" + "=" * 60) + if all_pass: + print("🎉 所有验证通过!ONNX模型与PyTorch模型输出一致") + else: + print("❌ 部分验证失败,请检查模型导出过程") + sys.exit(1) + + +if __name__ == "__main__": + main()