#!/usr/bin/env python3
"""
Beam search algorithm demo.

Shows how to run beam-search inference with the two exported ONNX models,
simulating an input-method (IME) scenario: given surrounding context, a
pinyin string and already-confirmed characters, generate candidate
Chinese-character sequences.
"""
import argparse
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import onnxruntime as ort
import torch
import torch.nn.functional as F

# Make the src directory importable.
sys.path.insert(0, str(Path(__file__).parent))


class ONNXBeamSearch:
    """Beam-search decoder driven by two exported ONNX models."""

    def __init__(
        self, context_encoder_path: str, decoder_path: str, device: str = "cpu"
    ):
        """Initialize the two ONNX Runtime sessions.

        Args:
            context_encoder_path: Path to the context-encoder ONNX model.
            decoder_path: Path to the decoder ONNX model.
            device: Inference device ("cpu" or "cuda").
        """
        self.device = device

        # Configure ONNX Runtime execution providers; CUDA falls back to CPU.
        if device == "cuda":
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]

        # Create the ONNX Runtime sessions.
        self.context_session = ort.InferenceSession(
            context_encoder_path, providers=providers
        )
        self.decoder_session = ort.InferenceSession(decoder_path, providers=providers)

        # Index input/output metadata by name for later lookups.
        # (Avoid shadowing the builtins `input`/`output` in the comprehensions.)
        self.context_input_info = {
            node.name: node for node in self.context_session.get_inputs()
        }
        self.context_output_info = {
            node.name: node for node in self.context_session.get_outputs()
        }
        self.decoder_input_info = {
            node.name: node for node in self.decoder_session.get_inputs()
        }
        self.decoder_output_info = {
            node.name: node for node in self.decoder_session.get_outputs()
        }

        print("📦 ONNX模型加载完成")
        print(f" 上下文编码器输入: {list(self.context_input_info.keys())}")
        print(f" 上下文编码器输出: {list(self.context_output_info.keys())}")
        print(f" 解码器输入: {list(self.decoder_input_info.keys())}")
        print(f" 解码器输出: {list(self.decoder_output_info.keys())}")

    def prepare_inputs(
        self,
        text_before: str,
        text_after: str,
        pinyin: str,
        context_prompts: Optional[List[str]] = None,
        tokenizer=None,
        query_engine=None,
        max_seq_len: int = 128,
    ) -> Dict[str, np.ndarray]:
        """Prepare model inputs.

        NOTE: This is a simplified stand-in. A real implementation must run
        the tokenizer and pinyin-query engine; here random data is used for
        demonstration only, so the text arguments are currently unused.

        Args:
            text_before: Text before the cursor.
            text_after: Text after the cursor.
            pinyin: Pinyin input string.
            context_prompts: Context prompt strings (unused in this demo).
            tokenizer: Tokenizer (not implemented).
            query_engine: Query engine (not implemented).
            max_seq_len: Maximum sequence length.

        Returns:
            Dict with "input_ids", "pinyin_ids" and "attention_mask" arrays.
        """
        # A real implementation would:
        #   1. Use the tokenizer to turn text into input_ids.
        #   2. Use text_to_pinyin_ids to turn pinyin into pinyin_ids.
        #   3. Build the attention mask.
        batch_size = 1
        seq_len = min(max_seq_len, 64)  # shortened for the demo

        # Simulated inputs (random token ids).
        input_ids = np.random.randint(0, 1000, (batch_size, seq_len), dtype=np.int64)
        # Pinyin ids are a fixed length of 24 — presumably the model's fixed
        # pinyin-slot width; confirm against the export script.
        pinyin_ids = np.random.randint(0, 30, (batch_size, 24), dtype=np.int64)
        attention_mask = np.ones((batch_size, seq_len), dtype=np.int64)

        # Real pinyin handling would look like:
        #   from src.model.dataset import text_to_pinyin_ids
        #   pinyin_ids_list = text_to_pinyin_ids(pinyin)
        #   pinyin_ids_array = np.array(pinyin_ids_list, dtype=np.int64).reshape(1, -1)

        return {
            "input_ids": input_ids,
            "pinyin_ids": pinyin_ids,
            "attention_mask": attention_mask,
        }

    def run_context_encoder(
        self, inputs: Dict[str, np.ndarray]
    ) -> Tuple[np.ndarray, ...]:
        """Run the context encoder.

        Args:
            inputs: Input dict (input_ids, pinyin_ids, attention_mask).

        Returns:
            Encoder outputs, expected to be
            (context_H, pinyin_P, context_mask, pinyin_mask).
        """
        # Assemble ONNX inputs, filling any input the caller did not supply.
        onnx_inputs = {}
        for input_name, node in self.context_input_info.items():
            if input_name in inputs:
                onnx_inputs[input_name] = inputs[input_name]
                continue
            # Missing input: build a zero array of the declared shape.
            # ONNX dynamic axes are reported as strings (or None), which
            # np.zeros cannot take — substitute 1 for any symbolic dim.
            concrete_shape = tuple(
                dim if isinstance(dim, int) else 1 for dim in node.shape
            )
            if "int" in node.type:
                onnx_inputs[input_name] = np.zeros(concrete_shape, dtype=np.int64)
            else:
                onnx_inputs[input_name] = np.zeros(concrete_shape, dtype=np.float32)

        outputs = self.context_session.run(None, onnx_inputs)
        return tuple(outputs)

    def run_decoder(
        self,
        context_H: np.ndarray,
        pinyin_P: np.ndarray,
        history_slot_ids: np.ndarray,
        context_mask: np.ndarray,
        pinyin_mask: np.ndarray,
    ) -> np.ndarray:
        """Run one decoder step.

        Args:
            context_H: Context encoding [batch, seq_len, 512].
            pinyin_P: Pinyin encoding [batch, 24, 512].
            history_slot_ids: History slot ids [batch, 8].
            context_mask: Context mask [batch, seq_len].
            pinyin_mask: Pinyin mask [batch, 24].

        Returns:
            Logits [batch, vocab_size].
        """
        inputs = {
            "context_H": context_H,
            "pinyin_P": pinyin_P,
            "history_slot_ids": history_slot_ids,
            "context_mask": context_mask,
            "pinyin_mask": pinyin_mask,
        }
        outputs = self.decoder_session.run(None, inputs)
        return outputs[0]

    def beam_search(
        self,
        context_H: np.ndarray,
        pinyin_P: np.ndarray,
        context_mask: np.ndarray,
        pinyin_mask: np.ndarray,
        beam_size: int = 5,
        max_length: int = 8,
        vocab_size: int = 10019,
        temperature: float = 1.0,
        length_penalty: float = 0.6,
        initial_history: Optional[List[int]] = None,
    ) -> List[Tuple[List[int], float]]:
        """Beam-search decoding.

        Args:
            context_H: Context encoding.
            pinyin_P: Pinyin encoding.
            context_mask: Context mask.
            pinyin_mask: Pinyin mask.
            beam_size: Beam width.
            max_length: Maximum generated length.
            vocab_size: Vocabulary size.
            temperature: Softmax temperature (controls sharpness).
            length_penalty: Length-penalty exponent (GNMT-style).
            initial_history: Optional already-confirmed token ids that seed
                the decoder's history slots (default None keeps the original
                behavior of starting from an empty history).

        Returns:
            Candidates sorted best-first, each as (sequence, cumulative
            log-probability). The returned score is the raw cumulative
            log-prob; the length penalty is applied only for pruning.
        """
        prefix = list(initial_history) if initial_history else []

        # Each beam is (sequence, cumulative log-probability).
        beams: List[Tuple[List[int], float]] = [([], 0.0)]

        for _ in range(max_length):
            candidates = []

            for seq, cum_logprob in beams:
                # Build the 8-wide history window: confirmed prefix plus the
                # generated sequence, zero-padded or truncated to 8.
                full_seq = prefix + seq
                if len(full_seq) < 8:
                    history = full_seq + [0] * (8 - len(full_seq))
                else:
                    history = full_seq[-8:]
                history_array = np.array([history], dtype=np.int64)

                logits = self.run_decoder(
                    context_H, pinyin_P, history_array, context_mask, pinyin_mask
                )

                # Temperature scaling, then softmax to probabilities.
                logits = logits / temperature
                probs = F.softmax(torch.from_numpy(logits[0]), dim=-1).numpy()

                # Take top-k candidates (k = 2 * beam_size for diversity).
                top_k = min(beam_size * 2, vocab_size)
                top_indices = np.argsort(probs)[-top_k:][::-1]
                top_probs = probs[top_indices]

                for token_id, prob in zip(top_indices, top_probs):
                    new_seq = seq + [int(token_id)]
                    new_logprob = cum_logprob + np.log(prob + 1e-10)
                    # GNMT length penalty, used only for ranking/pruning.
                    penalized = new_logprob / (
                        (5 + len(new_seq)) ** length_penalty
                    )
                    candidates.append((new_seq, penalized, new_logprob))

            # Prune: keep the beam_size best by penalized score, but carry
            # the raw cumulative log-prob forward.
            candidates.sort(key=lambda item: item[1], reverse=True)
            beams = [(seq, raw) for seq, _, raw in candidates[:beam_size]]

            # Stop early once every beam ends with the end token (id 0).
            if all(seq[-1] == 0 for seq, _ in beams if len(seq) > 0):
                break

        return beams

    def interactive_beam_search(self):
        """Interactive beam-search demonstration with canned inputs."""
        print("\n" + "=" * 60)
        print("束搜索交互演示")
        print("=" * 60)
        print("\n📝 输入法场景模拟:")
        print(" 假设用户正在输入拼音 'nihao',已确认第一个字 '你'")
        print(" 上下文: 光标前文本 '今天天气很好',光标后文本 '我们去公园玩'")
        print(" 上下文提示: '张三,李四'(模型不掌握的专有名词)")
        print("-" * 60)

        # Canned inputs (a real application would read these from the user).
        text_before = "今天天气很好"
        text_after = "我们去公园玩"
        pinyin = "hao"  # the user continues typing "hao"
        slot_chars = ["你"]  # characters already confirmed
        context_prompts = ["张三", "李四"]

        print(f"\n📋 输入参数:")
        print(f" 光标前文本: '{text_before}'")
        print(f" 光标后文本: '{text_after}'")
        print(f" 拼音: '{pinyin}'")
        print(f" 槽位历史: {slot_chars}")
        print(f" 上下文提示: {context_prompts}")

        # Prepare inputs (simplified — random data).
        print(f"\n🔧 准备模型输入...")
        inputs = self.prepare_inputs(
            text_before=text_before,
            text_after=text_after,
            pinyin=pinyin,
            context_prompts=context_prompts,
        )

        # Run the context encoder.
        print(f"🧠 运行上下文编码器...")
        context_outputs = self.run_context_encoder(inputs)
        context_H, pinyin_P, context_mask, pinyin_mask = context_outputs
        print(f"✅ 上下文编码完成")
        print(f" context_H形状: {context_H.shape}")
        print(f" pinyin_P形状: {pinyin_P.shape}")

        # Map confirmed slot characters to ids (simplified: hard-coded id).
        # A real application would use the query engine for this mapping.
        slot_ids = [42] if slot_chars else [0]  # pretend '你' has id 42
        # Padding to 8 slots is handled inside beam_search.

        # Run beam search, seeding the history with the confirmed slots
        # (previously slot_ids was computed but never used).
        print(f"\n🔍 运行束搜索 (beam_size=3, max_length=4)...")
        beams = self.beam_search(
            context_H,
            pinyin_P,
            context_mask,
            pinyin_mask,
            beam_size=3,
            max_length=4,
            vocab_size=10019,
            initial_history=slot_ids,
        )

        # Display results.
        print(f"\n🏆 束搜索结果:")
        print("-" * 50)
        for i, (seq, score) in enumerate(beams):
            # Render id sequences (a real app would map ids back to hanzi).
            seq_str = " ".join(
                [f"ID:{tok}" if tok != 0 else "END" for tok in seq]
            )
            print(f"{i + 1}. 序列: [{seq_str}]")
            print(f" 对数概率: {score:.4f}")
            print(f" 概率: {np.exp(score):.6f}")
            print()

        print("📝 说明:")
        print(" - 'END' 表示结束符 (ID: 0)")
        print(" - 实际应用中应将ID转换为汉字")
        print(" - 最高分数的序列作为最终预测")

        return beams


def main():
    """CLI entry point: parse args, load models, run the demo."""
    parser = argparse.ArgumentParser(description="束搜索算法演示")
    parser.add_argument(
        "--context-encoder",
        type=str,
        default="./exported_models/context_encoder.onnx",
        help="上下文编码器ONNX路径(默认: ./exported_models/context_encoder.onnx)",
    )
    parser.add_argument(
        "--decoder",
        type=str,
        default="./exported_models/decoder.onnx",
        help="解码器ONNX路径(默认: ./exported_models/decoder.onnx)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="推理设备(默认: cpu)",
    )
    parser.add_argument("--beam-size", type=int, default=3, help="束大小(默认: 3)")
    parser.add_argument(
        "--max-length", type=int, default=4, help="最大生成长度(默认: 4)"
    )
    # NOTE(review): store_true combined with default=True means this flag can
    # never be turned off from the CLI; kept as-is for backward compatibility,
    # but consider argparse.BooleanOptionalAction (--no-interactive) instead.
    parser.add_argument(
        "--interactive",
        action="store_true",
        default=True,
        help="交互模式(默认: True)",
    )

    args = parser.parse_args()

    # Verify the model files exist before creating sessions.
    if not os.path.exists(args.context_encoder):
        print(f"❌ 上下文编码器文件不存在: {args.context_encoder}")
        print(f" 请先运行 export_onnx.py 导出模型")
        return
    if not os.path.exists(args.decoder):
        print(f"❌ 解码器文件不存在: {args.decoder}")
        print(f" 请先运行 export_onnx.py 导出模型")
        return

    print("🚀 束搜索算法演示")
    print("=" * 60)
    print(f"上下文编码器: {args.context_encoder}")
    print(f"解码器: {args.decoder}")
    print(f"设备: {args.device}")
    print(f"束大小: {args.beam_size}")
    print(f"最大长度: {args.max_length}")

    beam_searcher = ONNXBeamSearch(
        context_encoder_path=args.context_encoder,
        decoder_path=args.decoder,
        device=args.device,
    )

    if args.interactive:
        beam_searcher.interactive_beam_search()

    print("\n" + "=" * 60)
    print("🎉 演示完成")
    print("\n💡 下一步:")
    print(" 1. 实现完整的输入预处理(tokenizer和拼音转换)")
    print(" 2. 集成查询引擎以将ID转换为汉字")
    print(" 3. 根据实际场景调整束搜索参数")
    print(" 4. 性能优化:批量处理、缓存等")


if __name__ == "__main__":
    main()