#!/usr/bin/env python3
"""
ONNX IME (input method) model inference script.

Runs inference with ONNX Runtime and measures the duration of each stage.

Usage:
    python onnx_inference.py --context-encoder exported_models/context_encoder.onnx --decoder exported_models/decoder.onnx

Interactive mode asks for input step by step:
    1. Context prompts: proper nouns / names the model does not know (may be empty)
    2. Text before cursor: contiguous text before the cursor
    3. Text after cursor: contiguous text after the cursor
    4. Pinyin: the pinyin currently being typed
    5. Slot history: input already confirmed by the user
       (e.g. typing "shanghai" with the first char "上" confirmed)
"""

import argparse
import os
import sys
import time
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np
import onnxruntime as ort
import torch
import torch.nn.functional as F
from modelscope import AutoTokenizer

from src.model.dataset import text_to_pinyin_ids
from src.model.query import QueryEngine


class ONNXInference:
    """ONNX IME model inference driver (context encoder + decoder)."""

    def __init__(
        self,
        context_encoder_path: str,
        decoder_path: str,
        vocab_size: int = 10019,
        device: str = "cpu",
        use_beam_search: bool = False,
        beam_size: int = 5,
    ):
        """Load both ONNX sessions, the tokenizer and the query engine.

        Args:
            context_encoder_path: path to the context-encoder ONNX file.
            decoder_path: path to the decoder ONNX file.
            vocab_size: output vocabulary size (default 10019).
            device: "cpu" or "cuda" (selects ORT execution providers).
            use_beam_search: default decoding mode for interactive_mode().
            beam_size: beam width used when beam search is enabled.
        """
        self.vocab_size = vocab_size
        self.device = device
        self.use_beam_search = use_beam_search
        self.beam_size = beam_size

        # Load each component, timing every stage in milliseconds.
        print(f"正在加载上下文编码器: {context_encoder_path}")
        load_start = time.perf_counter()
        self.load_context_encoder(context_encoder_path)
        self.context_encoder_load_time = (time.perf_counter() - load_start) * 1000
        print(f" ✅ 上下文编码器加载完成 ({self.context_encoder_load_time:.2f}ms)")

        print(f"正在加载解码器: {decoder_path}")
        load_start = time.perf_counter()
        self.load_decoder(decoder_path)
        self.decoder_load_time = (time.perf_counter() - load_start) * 1000
        print(f" ✅ 解码器加载完成 ({self.decoder_load_time:.2f}ms)")

        print("正在加载tokenizer...")
        load_start = time.perf_counter()
        self.load_tokenizer()
        self.tokenizer_load_time = (time.perf_counter() - load_start) * 1000
        print(f" ✅ Tokenizer加载完成 ({self.tokenizer_load_time:.2f}ms)")

        print("正在加载查询引擎...")
        load_start = time.perf_counter()
        self.load_query_engine()
        self.query_engine_load_time = (time.perf_counter() - load_start) * 1000
        print(f" ✅ 查询引擎加载完成 ({self.query_engine_load_time:.2f}ms)")

        total_load_time = (
            self.context_encoder_load_time
            + self.decoder_load_time
            + self.tokenizer_load_time
            + self.query_engine_load_time
        )
        print(f"\n✅ 推理器初始化完成 (设备: {device})")
        print(f" 总加载时间: {total_load_time:.2f}ms")

        # Best-effort: enable readline line editing when available (not on all platforms).
        try:
            import readline

            readline.set_completer_delims(" \t\n`~!@#$%^&*()-=+[{]}\\|;:'\",<>/?")
        except ImportError:
            pass

    def load_context_encoder(self, model_path: str):
        """Create the ORT session for the context-encoder ONNX model."""
        providers = (
            ["CUDAExecutionProvider", "CPUExecutionProvider"]
            if self.device == "cuda"
            else ["CPUExecutionProvider"]
        )
        self.context_encoder_session = ort.InferenceSession(
            model_path, providers=providers
        )
        self.context_input_names = [
            inp.name for inp in self.context_encoder_session.get_inputs()
        ]
        self.context_output_names = [
            out.name for out in self.context_encoder_session.get_outputs()
        ]

    def load_decoder(self, model_path: str):
        """Create the ORT session for the decoder ONNX model."""
        providers = (
            ["CUDAExecutionProvider", "CPUExecutionProvider"]
            if self.device == "cuda"
            else ["CPUExecutionProvider"]
        )
        self.decoder_session = ort.InferenceSession(model_path, providers=providers)
        self.decoder_input_names = [
            inp.name for inp in self.decoder_session.get_inputs()
        ]
        self.decoder_output_names = [
            out.name for out in self.decoder_session.get_outputs()
        ]

    def load_tokenizer(self):
        """Load the bundled tokenizer, falling back to bert-base-chinese."""
        try:
            tokenizer_path = (
                Path(__file__).parent / "src" / "model" / "assets" / "tokenizer"
            )
            self.tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path))
        except Exception:
            print(" ⚠️ 无法加载自定义tokenizer,使用bert-base-chinese")
            self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")

    def load_query_engine(self):
        """Load the char/pinyin query engine; on any failure run without it."""
        try:
            self.query_engine = QueryEngine()
            stats_path = (
                Path(__file__).parent
                / "src"
                / "model"
                / "assets"
                / "pinyin_char_statistics.json"
            )
            if stats_path.exists():
                self.query_engine.load(stats_path)
        except Exception:
            # Deliberate best-effort: char_to_id/id_to_char degrade to ord()/chr().
            self.query_engine = None

    def char_to_id(self, char: str, pinyin: Optional[str] = None) -> int:
        """Map a Chinese character to its vocabulary ID.

        "//" is the end-of-sequence sentinel mapped to 0. Without a query
        engine, falls back to the Unicode code point; any lookup failure
        returns 0.
        """
        if char == "//":
            return 0
        if self.query_engine is None:
            return ord(char) if len(char) == 1 else 0
        try:
            if pinyin is not None:
                info = self.query_engine.get_char_info_by_char_pinyin(char, pinyin)
                if info:
                    return info.id
            results = self.query_engine.query_by_char(char, limit=1)
            if results:
                return results[0][0]
            return 0
        except Exception:  # was a bare except: don't swallow SystemExit/KeyboardInterrupt
            return 0

    def id_to_char(self, id: int) -> str:
        """Map a vocabulary ID back to its character ("//" for ID 0).

        NOTE: the parameter name shadows the builtin `id`; kept for
        backward compatibility with keyword callers.
        """
        if id == 0:
            return "//"
        if self.query_engine is None:
            # Without the engine, treat the ID as a Unicode code point if valid.
            return chr(id) if id < 0x110000 else f"[ID:{id}]"
        try:
            info = self.query_engine.query_by_id(id)
            return info.char if info else f"[ID:{id}]"
        except Exception:  # was a bare except
            return f"[ID:{id}]"

    def _clean_pinyin_input(self, pinyin: str) -> str:
        """Sanitize a raw pinyin string.

        Keeps lowercase letters plus backtick/apostrophe/hyphen, drops
        spaces, treats "\\b"/DEL as backspace, and ESC clears everything
        typed so far.
        """
        if not pinyin:
            return ""
        result = []
        for c in pinyin:
            is_valid = ("a" <= c <= "z") or ("A" <= c <= "Z") or c in ["`", "'", "-"]
            if is_valid:
                result.append(c.lower())
            elif c == " ":
                continue
            elif c in ["\b", "\x7f", "\x08"]:  # "\b" == "\x08"; duplicate kept for clarity
                if result:
                    result.pop()
            elif c == "\x1b":
                result.clear()
        return "".join(result)

    def _safe_input(self, prompt: str, default: str = "") -> str:
        """input() wrapper that survives EOF/interrupt and applies a default."""
        try:
            full_prompt = f"{prompt} [{default}]: " if default else f"{prompt}: "
            result = input(full_prompt)
            if not result and default:
                return default
            return result.strip()
        except (EOFError, KeyboardInterrupt):
            print()
            return ""
        except Exception as e:
            print(f"\n⚠️ 输入错误: {e}")
            return ""

    def prepare_inputs(
        self,
        context_prompts: List[str],
        text_before: str,
        text_after: str,
        pinyin: str,
        slot_chars: List[str],
        max_seq_len: int = 128,
    ) -> dict:
        """Build numpy model inputs from the raw user input.

        Returns:
            dict with keys:
                'preprocess_time': float, preprocessing time in ms
                'input_ids': numpy array
                'attention_mask': numpy array
                'pinyin_ids': numpy array, padded/truncated to length 24
                'history_slot_ids': numpy array, padded/truncated to length 8
        """
        preprocess_start = time.perf_counter()

        # 1. Build the tokenizer input: prompts joined with '|' prefix the
        #    before-cursor text; after-cursor text goes in as the pair segment.
        context_text = "|".join(context_prompts) if context_prompts else ""
        if context_text:
            input_text = f"{context_text}|{text_before}"
        else:
            input_text = text_before

        # 2. Tokenize
        encoded = self.tokenizer(
            input_text,
            text_after,
            max_length=max_seq_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            return_token_type_ids=True,
        )
        input_ids = encoded["input_ids"].numpy()
        attention_mask = encoded["attention_mask"].numpy()

        # 3. Pinyin input: clean, convert to IDs, pad/truncate to 24.
        cleaned_pinyin = self._clean_pinyin_input(pinyin)
        pinyin_ids = text_to_pinyin_ids(cleaned_pinyin)
        if len(pinyin_ids) < 24:
            pinyin_ids.extend([0] * (24 - len(pinyin_ids)))
        else:
            pinyin_ids = pinyin_ids[:24]
        pinyin_ids = np.array([pinyin_ids], dtype=np.int64)

        # 4. History slots: map chars to IDs, pad/truncate to 8.
        history_slot_ids = [self.char_to_id(char) for char in slot_chars]
        if len(history_slot_ids) < 8:
            history_slot_ids.extend([0] * (8 - len(history_slot_ids)))
        else:
            history_slot_ids = history_slot_ids[:8]
        history_slot_ids = np.array([history_slot_ids], dtype=np.int64)

        preprocess_time = (time.perf_counter() - preprocess_start) * 1000
        return {
            "preprocess_time": preprocess_time,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pinyin_ids": pinyin_ids,
            "history_slot_ids": history_slot_ids,
        }

    def run_context_encoder(
        self, input_ids: np.ndarray, pinyin_ids: np.ndarray, attention_mask: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Run the context encoder session.

        Returns:
            (context_H, pinyin_P, context_mask, pinyin_mask) — stores the
            elapsed time in self.last_context_encoder_time (ms).
        """
        context_start = time.perf_counter()
        inputs = {
            "input_ids": input_ids,
            "pinyin_ids": pinyin_ids,
            "attention_mask": attention_mask,
        }
        outputs = self.context_encoder_session.run(self.context_output_names, inputs)
        context_H, pinyin_P, context_mask, pinyin_mask = outputs
        self.last_context_encoder_time = (time.perf_counter() - context_start) * 1000
        return context_H, pinyin_P, context_mask, pinyin_mask

    def run_decoder(
        self,
        context_H: np.ndarray,
        pinyin_P: np.ndarray,
        history_slot_ids: np.ndarray,
        context_mask: np.ndarray,
        pinyin_mask: np.ndarray,
    ) -> np.ndarray:
        """Run the decoder session for one step.

        Returns:
            logits: [batch, vocab_size] — stores the elapsed time in
            self.last_decoder_time (ms).
        """
        decoder_start = time.perf_counter()
        inputs = {
            "context_H": context_H,
            "pinyin_P": pinyin_P,
            "history_slot_ids": history_slot_ids,
            "context_mask": context_mask,
            "pinyin_mask": pinyin_mask,
        }
        outputs = self.decoder_session.run(self.decoder_output_names, inputs)
        self.last_decoder_time = (time.perf_counter() - decoder_start) * 1000
        return outputs[0]

    def predict(
        self,
        context_prompts: List[str],
        text_before: str,
        text_after: str,
        pinyin: str,
        slot_chars: List[str],
        top_k: int = 20,
        use_beam_search: bool = False,
        beam_size: int = 5,
        max_length: int = 10,
    ) -> Tuple[List[Tuple[str, float, int]], dict]:
        """Run one full inference pass.

        Args:
            context_prompts: proper nouns / names the model does not know.
            text_before: text before the cursor.
            text_after: text after the cursor.
            pinyin: pinyin currently being typed.
            slot_chars: already-confirmed characters (history slots).
            top_k: number of candidates to return.
            use_beam_search: use beam-search decoding instead of single-step.
            beam_size: beam width.
            max_length: maximum generated length for beam search.

        Returns:
            (predictions, timing_info) where predictions is a list of
            (char, prob, id) tuples and timing_info maps stage name -> ms.
        """
        total_start = time.perf_counter()

        # Stage 1: preprocessing.
        prep_start = time.perf_counter()
        inputs = self.prepare_inputs(
            context_prompts, text_before, text_after, pinyin, slot_chars
        )
        preprocess_time = inputs["preprocess_time"]
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        pinyin_ids = inputs["pinyin_ids"]
        history_slot_ids = inputs["history_slot_ids"]
        prep_time = (time.perf_counter() - prep_start) * 1000

        # Stage 2: context encoding.
        context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
            input_ids, pinyin_ids, attention_mask
        )
        context_encoder_time = self.last_context_encoder_time

        # Initialize so the timing dict never reads an unbound name
        # (original only defined this on the single-step path).
        postprocess_time = 0.0

        if use_beam_search:
            # Stage 3: beam-search decoding (no separate postprocess stage).
            predictions, beam_decode_time = self._beam_search_decode(
                context_H,
                pinyin_P,
                context_mask,
                pinyin_mask,
                beam_size,
                max_length,
                top_k,
            )
            decoder_time = beam_decode_time
        else:
            # Stage 3: single-step decoding.
            logits = self.run_decoder(
                context_H,
                pinyin_P,
                history_slot_ids,
                context_mask,
                pinyin_mask,
            )

            # Stage 4: postprocessing (softmax + top-k + id->char).
            postprocess_start = time.perf_counter()
            probs = self._softmax(logits)
            top_indices, top_probs = self._topk(probs, top_k)
            predictions = []
            for i in range(top_k):
                idx = int(top_indices[0, i])
                prob = float(top_probs[0, i])
                char = self.id_to_char(idx)
                predictions.append((char, prob, idx))
            postprocess_time = (time.perf_counter() - postprocess_start) * 1000
            decoder_time = self.last_decoder_time

        total_time = (time.perf_counter() - total_start) * 1000
        timing_info = {
            "预处理": prep_time,
            "上下文编码": context_encoder_time,
            "解码": decoder_time,
            "后处理": postprocess_time if not use_beam_search else 0,
            "总耗时": total_time,
        }
        return predictions, timing_info

    def _softmax(self, logits: np.ndarray) -> np.ndarray:
        """Numerically stable softmax over the last axis."""
        exp_logits = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
        return exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)

    def _topk(self, probs: np.ndarray, k: int) -> Tuple[np.ndarray, np.ndarray]:
        """Return the top-k (indices, probs) per row, highest first."""
        topk_indices = np.argsort(probs, axis=-1)[:, -k:][:, ::-1]
        topk_probs = np.take_along_axis(probs, topk_indices, axis=-1)
        return topk_indices, topk_probs

    def _beam_search_decode(
        self,
        context_H: np.ndarray,
        pinyin_P: np.ndarray,
        context_mask: np.ndarray,
        pinyin_mask: np.ndarray,
        beam_size: int,
        max_length: int,
        top_k: int,
    ) -> Tuple[List[Tuple[str, float, int]], float]:
        """Beam-search decoding; stops early when every beam ends with ID 0.

        Returns:
            (predictions, decode_time_ms). Each prediction is
            (last_char, length-normalized prob, last_id). decode_time only
            covers the final decoder call, not the whole search.
        """
        beams = [([], 0.0)]  # (sequence, log-prob)
        for _ in range(max_length):
            new_beams = []
            for seq, score in beams:
                # Decoder expects a fixed window of 8 history slots.
                if len(seq) < 8:
                    history = seq + [0] * (8 - len(seq))
                else:
                    history = seq[-8:]
                history_tensor = np.array([history], dtype=np.int64)
                logits = self.run_decoder(
                    context_H,
                    pinyin_P,
                    history_tensor,
                    context_mask,
                    pinyin_mask,
                )
                probs = self._softmax(logits)[0]
                topk_indices = np.argsort(probs)[-beam_size:][::-1]
                topk_probs = probs[topk_indices]
                for idx, prob in zip(topk_indices, topk_probs):
                    new_seq = seq + [int(idx)]
                    new_score = score + np.log(prob + 1e-10)
                    new_beams.append((new_seq, new_score))
            new_beams.sort(key=lambda x: x[1], reverse=True)
            beams = new_beams[:beam_size]
            all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
            if all_ended:
                break

        # Build the top-k candidates from the surviving beams.
        predictions = []
        for seq, score in beams[:top_k]:
            if seq:
                char = self.id_to_char(seq[-1])
                prob = np.exp(score / max(len(seq), 1))  # length-normalized
            else:
                char = self.id_to_char(0)
                prob = 0.0
            predictions.append((char, prob, seq[-1] if seq else 0))
        decode_time = self.last_decoder_time  # only the last decoder call is timed
        return predictions, decode_time

    def interactive_mode(self):
        """Interactive REPL: prompt for the five inputs, predict, display."""
        print("\n" + "=" * 60)
        print("ONNX输入法模型推理 - 交互模式")
        print("=" * 60)
        encoding = sys.stdout.encoding or "unknown"
        print(f"终端编码: {encoding}")
        print("\n说明:")
        print(" - 上下文提示: 模型不掌握的专有词汇、姓名等(可为空)")
        print(" - 光标前文本: 光标前的连续文本")
        print(" - 光标后文本: 光标后的连续文本")
        print(" - 拼音: 当前输入的拼音")
        print(" - 槽位历史: 用户已确认的输入历史")
        if self.use_beam_search:
            print(f" - 解码模式: 束搜索 (beam_size={self.beam_size})")
        else:
            print(" - 解码模式: 单步解码 (使用 --beam 启用束搜索)")
        print("提示: 输入 'quit' 或 'exit' 或 'q' 可随时退出")
        print("-" * 60)

        while True:
            try:
                print("\n" + "=" * 60)
                context_input = self._safe_input("第1步: 上下文提示(直接回车跳过)")
                if context_input.lower() in ["quit", "exit", "q"]:
                    break
                context_prompts = [
                    item.strip() for item in context_input.split(",") if item.strip()
                ]

                print("\n" + "-" * 40)
                text_before = self._safe_input("第2步: 光标前文本")
                if text_before.lower() in ["quit", "exit", "q"]:
                    break

                print("\n" + "-" * 40)
                text_after = self._safe_input("第3步: 光标后文本")
                if text_after.lower() in ["quit", "exit", "q"]:
                    break

                print("\n" + "-" * 40)
                pinyin = self._safe_input("第4步: 拼音输入")
                if pinyin.lower() in ["quit", "exit", "q"]:
                    break

                print("\n" + "-" * 40)
                slot_input = self._safe_input("第5步: 槽位历史(直接回车表示无)")
                if slot_input.lower() in ["quit", "exit", "q"]:
                    break
                slot_chars = [
                    char.strip() for char in slot_input.split(",") if char.strip()
                ]

                print("\n" + "=" * 60)
                print("📝 输入汇总:")
                print(f" 上下文提示: {context_prompts if context_prompts else '无'}")
                print(f" 光标前文本: '{text_before}'")
                print(f" 光标后文本: '{text_after}'")
                print(f" 拼音: '{pinyin}'")
                print(f" 槽位历史: {slot_chars if slot_chars else '无'}")

                print("\n🔮 推理中...")
                predictions, timing_info = self.predict(
                    context_prompts,
                    text_before,
                    text_after,
                    pinyin,
                    slot_chars,
                    top_k=20,
                    use_beam_search=self.use_beam_search,
                    beam_size=self.beam_size,
                )

                # Timing summary (zero-duration stages are hidden).
                print(f"\n⏱️ 执行时间统计:")
                print("-" * 40)
                for stage, duration in timing_info.items():
                    if duration > 0:
                        print(f" {stage:<12}: {duration:>8.2f} ms")
                print("-" * 40)

                # Prediction results.
                print("\n🏆 Top-20 预测结果:")
                print("-" * 50)
                for i, (char, prob, idx) in enumerate(predictions):
                    if char == "//":
                        print(f"{i + 1:2d}. {'//':<4} (结束符) - 概率: {prob:.4f}")
                    else:
                        print(
                            f"{i + 1:2d}. {char:<4} (ID: {idx:>5}) - 概率: {prob:.4f}"
                        )

                # Pinyin reference from the query engine, when available.
                if pinyin and self.query_engine:
                    print(f"\n📖 拼音 '{pinyin}' 的常见汉字:")
                    pinyin_results = self.query_engine.query_by_pinyin(pinyin, limit=10)
                    if pinyin_results:
                        for _pid, char, count in pinyin_results:
                            print(f" {char} (频次: {count:,})")

                print("\n" + "-" * 40)
                continue_input = (
                    self._safe_input("是否继续推理?(y/n)", "y").strip().lower()
                )
                if continue_input not in ["y", "yes", ""]:
                    break
            except KeyboardInterrupt:
                print("\n\n退出交互模式")
                break
            except Exception as e:
                print(f"\n❌ 推理出错: {e}")
                import traceback

                traceback.print_exc()


def main():
    """Parse CLI arguments, build the inference driver, run test/interactive."""
    parser = argparse.ArgumentParser(description="ONNX输入法模型推理")
    parser.add_argument(
        "--context-encoder",
        "-c",
        type=str,
        required=True,
        help="上下文编码器ONNX模型路径",
    )
    parser.add_argument(
        "--decoder",
        "-d",
        type=str,
        required=True,
        help="解码器ONNX模型路径",
    )
    parser.add_argument(
        "--vocab-size",
        type=int,
        default=10019,
        help="词汇表大小 (默认: 10019)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="推理设备 (默认: cpu)",
    )
    # NOTE(review): store_true with default=True makes this flag a no-op —
    # interactive mode is always on. Consider argparse.BooleanOptionalAction
    # if an opt-out is actually wanted; kept as-is for CLI compatibility.
    parser.add_argument(
        "--interactive",
        action="store_true",
        default=True,
        help="交互模式 (默认: True)",
    )
    parser.add_argument("--test", action="store_true", help="运行测试推理")
    parser.add_argument(
        "--beam",
        action="store_true",
        help="使用束搜索解码 (默认: 单步解码)",
    )
    parser.add_argument(
        "--beam-size",
        type=int,
        default=5,
        help="束大小 (默认: 5)",
    )
    args = parser.parse_args()

    # Validate model paths before doing any heavy loading.
    if not os.path.exists(args.context_encoder):
        print(f"❌ 错误: 上下文编码器文件不存在: {args.context_encoder}")
        sys.exit(1)
    if not os.path.exists(args.decoder):
        print(f"❌ 错误: 解码器文件不存在: {args.decoder}")
        sys.exit(1)

    inference = ONNXInference(
        context_encoder_path=args.context_encoder,
        decoder_path=args.decoder,
        vocab_size=args.vocab_size,
        device=args.device,
        use_beam_search=args.beam,
        beam_size=args.beam_size,
    )

    # Optional canned test inference.
    if args.test:
        print("\n🧪 运行测试推理...")
        print("测试场景: 输入'shanghai',已确认第一个字'上',继续输入'tian'")
        print("上下文提示: 张三、李四(模型不掌握的专有名词)")
        predictions, timing_info = inference.predict(
            context_prompts=["张三", "李四"],
            text_before="今天天气",
            text_after="很好",
            pinyin="tian",
            slot_chars=["上"],
            use_beam_search=args.beam,
            beam_size=args.beam_size,
        )
        print(f"\n⏱️ 执行时间统计:")
        print("-" * 40)
        for stage, duration in timing_info.items():
            if duration > 0:
                print(f" {stage:<12}: {duration:>8.2f} ms")
        print("-" * 40)
        print(f"\nTop-5 结果:")
        for i, (char, prob, idx) in enumerate(predictions[:5]):
            if char == "//":
                print(f" {i + 1}. // (结束符) - 概率: {prob:.4f}")
            else:
                print(f" {i + 1}. {char} (ID: {idx}) - 概率: {prob:.4f}")

    if args.interactive:
        inference.interactive_mode()


if __name__ == "__main__":
    main()