764 lines
25 KiB
Python
764 lines
25 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
ONNX输入法模型推理脚本
|
||
|
||
使用ONNX Runtime进行推理,测量每个阶段的执行时长
|
||
|
||
使用方法:
|
||
python onnx_inference.py --context-encoder exported_models/context_encoder.onnx --decoder exported_models/decoder.onnx
|
||
|
||
交互模式: 分步询问输入
|
||
1. 上下文提示: 模型不掌握的专有词汇、姓名等(可为空)
|
||
2. 光标前文本: 光标前的连续文本
|
||
3. 光标后文本: 光标后的连续文本
|
||
4. 拼音: 当前输入的拼音
|
||
5. 槽位历史: 用户已确认的输入历史(如输入shanghai已确认"上")
|
||
"""
|
||
|
||
import argparse
|
||
import os
|
||
import sys
|
||
import time
|
||
from pathlib import Path
|
||
from typing import List, Optional, Tuple
|
||
|
||
import numpy as np
|
||
import onnxruntime as ort
|
||
import torch
|
||
import torch.nn.functional as F
|
||
from modelscope import AutoTokenizer
|
||
|
||
from src.model.dataset import text_to_pinyin_ids
|
||
from src.model.query import QueryEngine
|
||
|
||
|
||
class ONNXInference:
|
||
"""ONNX输入法模型推理器"""
|
||
|
||
def __init__(
|
||
self,
|
||
context_encoder_path: str,
|
||
decoder_path: str,
|
||
vocab_size: int = 10019,
|
||
device: str = "cpu",
|
||
use_beam_search: bool = False,
|
||
beam_size: int = 5,
|
||
):
|
||
self.vocab_size = vocab_size
|
||
self.device = device
|
||
self.use_beam_search = use_beam_search
|
||
self.beam_size = beam_size
|
||
|
||
# 加载组件
|
||
print(f"正在加载上下文编码器: {context_encoder_path}")
|
||
load_start = time.perf_counter()
|
||
self.load_context_encoder(context_encoder_path)
|
||
self.context_encoder_load_time = (time.perf_counter() - load_start) * 1000
|
||
print(f" ✅ 上下文编码器加载完成 ({self.context_encoder_load_time:.2f}ms)")
|
||
|
||
print(f"正在加载解码器: {decoder_path}")
|
||
load_start = time.perf_counter()
|
||
self.load_decoder(decoder_path)
|
||
self.decoder_load_time = (time.perf_counter() - load_start) * 1000
|
||
print(f" ✅ 解码器加载完成 ({self.decoder_load_time:.2f}ms)")
|
||
|
||
# 加载tokenizer
|
||
print("正在加载tokenizer...")
|
||
load_start = time.perf_counter()
|
||
self.load_tokenizer()
|
||
self.tokenizer_load_time = (time.perf_counter() - load_start) * 1000
|
||
print(f" ✅ Tokenizer加载完成 ({self.tokenizer_load_time:.2f}ms)")
|
||
|
||
# 加载查询引擎
|
||
print("正在加载查询引擎...")
|
||
load_start = time.perf_counter()
|
||
self.load_query_engine()
|
||
self.query_engine_load_time = (time.perf_counter() - load_start) * 1000
|
||
print(f" ✅ 查询引擎加载完成 ({self.query_engine_load_time:.2f}ms)")
|
||
|
||
total_load_time = (
|
||
self.context_encoder_load_time
|
||
+ self.decoder_load_time
|
||
+ self.tokenizer_load_time
|
||
+ self.query_engine_load_time
|
||
)
|
||
print(f"\n✅ 推理器初始化完成 (设备: {device})")
|
||
print(f" 总加载时间: {total_load_time:.2f}ms")
|
||
|
||
# 尝试启用readline
|
||
try:
|
||
import readline
|
||
|
||
readline.set_completer_delims(" \t\n`~!@#$%^&*()-=+[{]}\\|;:'\",<>/?")
|
||
except ImportError:
|
||
pass
|
||
|
||
def load_context_encoder(self, model_path: str):
|
||
"""加载上下文编码器ONNX模型"""
|
||
providers = (
|
||
["CUDAExecutionProvider", "CPUExecutionProvider"]
|
||
if self.device == "cuda"
|
||
else ["CPUExecutionProvider"]
|
||
)
|
||
self.context_encoder_session = ort.InferenceSession(
|
||
model_path, providers=providers
|
||
)
|
||
|
||
self.context_input_names = [
|
||
inp.name for inp in self.context_encoder_session.get_inputs()
|
||
]
|
||
self.context_output_names = [
|
||
out.name for out in self.context_encoder_session.get_outputs()
|
||
]
|
||
|
||
def load_decoder(self, model_path: str):
|
||
"""加载解码器ONNX模型"""
|
||
providers = (
|
||
["CUDAExecutionProvider", "CPUExecutionProvider"]
|
||
if self.device == "cuda"
|
||
else ["CPUExecutionProvider"]
|
||
)
|
||
self.decoder_session = ort.InferenceSession(model_path, providers=providers)
|
||
|
||
self.decoder_input_names = [
|
||
inp.name for inp in self.decoder_session.get_inputs()
|
||
]
|
||
self.decoder_output_names = [
|
||
out.name for out in self.decoder_session.get_outputs()
|
||
]
|
||
|
||
def load_tokenizer(self):
|
||
"""加载tokenizer"""
|
||
try:
|
||
tokenizer_path = (
|
||
Path(__file__).parent / "src" / "model" / "assets" / "tokenizer"
|
||
)
|
||
self.tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path))
|
||
except Exception:
|
||
print(" ⚠️ 无法加载自定义tokenizer,使用bert-base-chinese")
|
||
self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
|
||
|
||
def load_query_engine(self):
|
||
"""加载查询引擎"""
|
||
try:
|
||
self.query_engine = QueryEngine()
|
||
stats_path = (
|
||
Path(__file__).parent
|
||
/ "src"
|
||
/ "model"
|
||
/ "assets"
|
||
/ "pinyin_char_statistics.json"
|
||
)
|
||
if stats_path.exists():
|
||
self.query_engine.load(stats_path)
|
||
except Exception:
|
||
self.query_engine = None
|
||
|
||
def char_to_id(self, char: str, pinyin: Optional[str] = None) -> int:
|
||
"""将汉字转换为ID"""
|
||
if char == "//":
|
||
return 0
|
||
|
||
if self.query_engine is None:
|
||
return ord(char) if len(char) == 1 else 0
|
||
|
||
try:
|
||
if pinyin is not None:
|
||
info = self.query_engine.get_char_info_by_char_pinyin(char, pinyin)
|
||
if info:
|
||
return info.id
|
||
|
||
results = self.query_engine.query_by_char(char, limit=1)
|
||
if results:
|
||
return results[0][0]
|
||
return 0
|
||
except:
|
||
return 0
|
||
|
||
def id_to_char(self, id: int) -> str:
|
||
"""将ID转换为汉字"""
|
||
if id == 0:
|
||
return "//"
|
||
|
||
if self.query_engine is None:
|
||
return chr(id) if id < 0x110000 else f"[ID:{id}]"
|
||
|
||
try:
|
||
info = self.query_engine.query_by_id(id)
|
||
return info.char if info else f"[ID:{id}]"
|
||
except:
|
||
return f"[ID:{id}]"
|
||
|
||
def _clean_pinyin_input(self, pinyin: str) -> str:
|
||
"""清理拼音输入字符串"""
|
||
if not pinyin:
|
||
return ""
|
||
|
||
result = []
|
||
for c in pinyin:
|
||
is_valid = ("a" <= c <= "z") or ("A" <= c <= "Z") or c in ["`", "'", "-"]
|
||
if is_valid:
|
||
result.append(c.lower())
|
||
elif c == " ":
|
||
continue
|
||
elif c in ["\b", "\x7f", "\x08"]:
|
||
if result:
|
||
result.pop()
|
||
elif c == "\x1b":
|
||
result.clear()
|
||
return "".join(result)
|
||
|
||
def _safe_input(self, prompt: str, default: str = "") -> str:
|
||
"""安全的输入函数"""
|
||
try:
|
||
full_prompt = f"{prompt} [{default}]: " if default else f"{prompt}: "
|
||
result = input(full_prompt)
|
||
if not result and default:
|
||
return default
|
||
return result.strip()
|
||
except (EOFError, KeyboardInterrupt):
|
||
print()
|
||
return ""
|
||
except Exception as e:
|
||
print(f"\n⚠️ 输入错误: {e}")
|
||
return ""
|
||
|
||
def prepare_inputs(
|
||
self,
|
||
context_prompts: List[str],
|
||
text_before: str,
|
||
text_after: str,
|
||
pinyin: str,
|
||
slot_chars: List[str],
|
||
max_seq_len: int = 128,
|
||
) -> dict:
|
||
"""
|
||
准备模型输入
|
||
|
||
Returns:
|
||
dict: {
|
||
'preprocess_time': float, # 预处理时间(ms)
|
||
'input_ids': numpy array,
|
||
'attention_mask': numpy array,
|
||
'pinyin_ids': numpy array,
|
||
'history_slot_ids': numpy array,
|
||
}
|
||
"""
|
||
preprocess_start = time.perf_counter()
|
||
|
||
# 1. 构建tokenizer输入
|
||
context_text = "|".join(context_prompts) if context_prompts else ""
|
||
|
||
if context_text:
|
||
input_text = f"{context_text}|{text_before}"
|
||
else:
|
||
input_text = text_before
|
||
|
||
# 2. Tokenize
|
||
encoded = self.tokenizer(
|
||
input_text,
|
||
text_after,
|
||
max_length=max_seq_len,
|
||
padding="max_length",
|
||
truncation=True,
|
||
return_tensors="pt",
|
||
return_token_type_ids=True,
|
||
)
|
||
|
||
input_ids = encoded["input_ids"].numpy()
|
||
attention_mask = encoded["attention_mask"].numpy()
|
||
|
||
# 3. 处理拼音输入
|
||
cleaned_pinyin = self._clean_pinyin_input(pinyin)
|
||
pinyin_ids = text_to_pinyin_ids(cleaned_pinyin)
|
||
|
||
if len(pinyin_ids) < 24:
|
||
pinyin_ids.extend([0] * (24 - len(pinyin_ids)))
|
||
else:
|
||
pinyin_ids = pinyin_ids[:24]
|
||
|
||
pinyin_ids = np.array([pinyin_ids], dtype=np.int64)
|
||
|
||
# 4. 处理历史槽位
|
||
history_slot_ids = []
|
||
for char in slot_chars:
|
||
char_id = self.char_to_id(char)
|
||
history_slot_ids.append(char_id)
|
||
|
||
if len(history_slot_ids) < 8:
|
||
history_slot_ids.extend([0] * (8 - len(history_slot_ids)))
|
||
else:
|
||
history_slot_ids = history_slot_ids[:8]
|
||
|
||
history_slot_ids = np.array([history_slot_ids], dtype=np.int64)
|
||
|
||
preprocess_time = (time.perf_counter() - preprocess_start) * 1000
|
||
|
||
return {
|
||
"preprocess_time": preprocess_time,
|
||
"input_ids": input_ids,
|
||
"attention_mask": attention_mask,
|
||
"pinyin_ids": pinyin_ids,
|
||
"history_slot_ids": history_slot_ids,
|
||
}
|
||
|
||
def run_context_encoder(
|
||
self, input_ids: np.ndarray, pinyin_ids: np.ndarray, attention_mask: np.ndarray
|
||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||
"""
|
||
运行上下文编码器
|
||
|
||
Returns:
|
||
context_H, pinyin_P, context_mask, pinyin_mask
|
||
"""
|
||
context_start = time.perf_counter()
|
||
|
||
inputs = {
|
||
"input_ids": input_ids,
|
||
"pinyin_ids": pinyin_ids,
|
||
"attention_mask": attention_mask,
|
||
}
|
||
|
||
outputs = self.context_encoder_session.run(self.context_output_names, inputs)
|
||
|
||
context_H, pinyin_P, context_mask, pinyin_mask = outputs
|
||
|
||
self.last_context_encoder_time = (time.perf_counter() - context_start) * 1000
|
||
|
||
return context_H, pinyin_P, context_mask, pinyin_mask
|
||
|
||
def run_decoder(
|
||
self,
|
||
context_H: np.ndarray,
|
||
pinyin_P: np.ndarray,
|
||
history_slot_ids: np.ndarray,
|
||
context_mask: np.ndarray,
|
||
pinyin_mask: np.ndarray,
|
||
) -> np.ndarray:
|
||
"""
|
||
运行解码器
|
||
|
||
Returns:
|
||
logits: [batch, vocab_size]
|
||
"""
|
||
decoder_start = time.perf_counter()
|
||
|
||
inputs = {
|
||
"context_H": context_H,
|
||
"pinyin_P": pinyin_P,
|
||
"history_slot_ids": history_slot_ids,
|
||
"context_mask": context_mask,
|
||
"pinyin_mask": pinyin_mask,
|
||
}
|
||
|
||
outputs = self.decoder_session.run(self.decoder_output_names, inputs)
|
||
|
||
self.last_decoder_time = (time.perf_counter() - decoder_start) * 1000
|
||
|
||
return outputs[0]
|
||
|
||
def predict(
|
||
self,
|
||
context_prompts: List[str],
|
||
text_before: str,
|
||
text_after: str,
|
||
pinyin: str,
|
||
slot_chars: List[str],
|
||
top_k: int = 20,
|
||
use_beam_search: bool = False,
|
||
beam_size: int = 5,
|
||
max_length: int = 10,
|
||
) -> Tuple[List[Tuple[str, float, int]], dict]:
|
||
"""
|
||
执行推理
|
||
|
||
Args:
|
||
context_prompts: 上下文提示
|
||
text_before: 光标前文本
|
||
text_after: 光标后文本
|
||
pinyin: 当前输入的拼音
|
||
slot_chars: 槽位内的汉字列表
|
||
top_k: 返回top-k个预测结果
|
||
use_beam_search: 是否使用束搜索
|
||
beam_size: 束大小
|
||
max_length: 最大生成长度
|
||
|
||
Returns:
|
||
(predictions, timing_info)
|
||
predictions: List[Tuple[char, prob, id]]
|
||
timing_info: 各阶段耗时字典
|
||
"""
|
||
total_start = time.perf_counter()
|
||
|
||
# 阶段1: 预处理
|
||
prep_start = time.perf_counter()
|
||
inputs = self.prepare_inputs(
|
||
context_prompts, text_before, text_after, pinyin, slot_chars
|
||
)
|
||
preprocess_time = inputs["preprocess_time"]
|
||
|
||
input_ids = inputs["input_ids"]
|
||
attention_mask = inputs["attention_mask"]
|
||
pinyin_ids = inputs["pinyin_ids"]
|
||
history_slot_ids = inputs["history_slot_ids"]
|
||
prep_time = (time.perf_counter() - prep_start) * 1000
|
||
|
||
# 阶段2: 上下文编码
|
||
context_start = time.perf_counter()
|
||
context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
|
||
input_ids, pinyin_ids, attention_mask
|
||
)
|
||
context_encoder_time = self.last_context_encoder_time
|
||
|
||
if use_beam_search:
|
||
# 阶段3: 束搜索解码
|
||
decode_start = time.perf_counter()
|
||
predictions, beam_decode_time = self._beam_search_decode(
|
||
context_H,
|
||
pinyin_P,
|
||
context_mask,
|
||
pinyin_mask,
|
||
beam_size,
|
||
max_length,
|
||
top_k,
|
||
)
|
||
decoder_time = beam_decode_time
|
||
else:
|
||
# 阶段3: 单步解码
|
||
decode_start = time.perf_counter()
|
||
logits = self.run_decoder(
|
||
context_H,
|
||
pinyin_P,
|
||
history_slot_ids,
|
||
context_mask,
|
||
pinyin_mask,
|
||
)
|
||
|
||
# 阶段4: 后处理
|
||
postprocess_start = time.perf_counter()
|
||
probs = self._softmax(logits)
|
||
top_indices, top_probs = self._topk(probs, top_k)
|
||
|
||
predictions = []
|
||
for i in range(top_k):
|
||
idx = int(top_indices[0, i])
|
||
prob = float(top_probs[0, i])
|
||
char = self.id_to_char(idx)
|
||
predictions.append((char, prob, idx))
|
||
|
||
postprocess_time = (time.perf_counter() - postprocess_start) * 1000
|
||
decoder_time = self.last_decoder_time
|
||
|
||
total_time = (time.perf_counter() - total_start) * 1000
|
||
|
||
timing_info = {
|
||
"预处理": prep_time,
|
||
"上下文编码": context_encoder_time,
|
||
"解码": decoder_time,
|
||
"后处理": postprocess_time if not use_beam_search else 0,
|
||
"总耗时": total_time,
|
||
}
|
||
|
||
return predictions, timing_info
|
||
|
||
def _softmax(self, logits: np.ndarray) -> np.ndarray:
|
||
"""计算softmax"""
|
||
exp_logits = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
|
||
return exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
|
||
|
||
def _topk(self, probs: np.ndarray, k: int) -> Tuple[np.ndarray, np.ndarray]:
|
||
"""获取top-k"""
|
||
topk_indices = np.argsort(probs, axis=-1)[:, -k:][:, ::-1]
|
||
topk_probs = np.take_along_axis(probs, topk_indices, axis=-1)
|
||
return topk_indices, topk_probs
|
||
|
||
    def _beam_search_decode(
        self,
        context_H: np.ndarray,
        pinyin_P: np.ndarray,
        context_mask: np.ndarray,
        pinyin_mask: np.ndarray,
        beam_size: int,
        max_length: int,
        top_k: int,
    ) -> Tuple[List[Tuple[str, float, int]], float]:
        """Beam-search decoding over repeated single-step decoder calls.

        Each step expands every beam by its `beam_size` most likely next
        tokens, keeps the best `beam_size` sequences by cumulative
        log-probability, and stops early once every non-empty beam ends with
        the end token (ID 0).

        Returns:
            (predictions, decode_time_ms): up to `top_k` tuples of
            (char, length-normalized probability, last token id).
            NOTE(review): decode_time_ms covers only the LAST decoder call,
            not the whole search — confirm this is intentional.
        """
        beams = [([], 0.0)]  # (token-id sequence, cumulative log-probability)

        for step in range(max_length):
            new_beams = []

            for seq, score in beams:
                # The decoder consumes a fixed window of 8 history slots:
                # pad short sequences with 0, otherwise keep the last 8 ids.
                if len(seq) < 8:
                    history = seq + [0] * (8 - len(seq))
                else:
                    history = seq[-8:]

                history_tensor = np.array([history], dtype=np.int64)

                logits = self.run_decoder(
                    context_H,
                    pinyin_P,
                    history_tensor,
                    context_mask,
                    pinyin_mask,
                )

                probs = self._softmax(logits)[0]
                topk_indices = np.argsort(probs)[-beam_size:][::-1]
                topk_probs = probs[topk_indices]

                for idx, prob in zip(topk_indices, topk_probs):
                    new_seq = seq + [int(idx)]
                    # Epsilon guards log() against zero-probability tokens.
                    new_score = score + np.log(prob + 1e-10)
                    new_beams.append((new_seq, new_score))

            # Keep only the best `beam_size` candidates for the next step.
            new_beams.sort(key=lambda x: x[1], reverse=True)
            beams = new_beams[:beam_size]

            # Early exit once every non-empty beam has emitted the end token.
            all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
            if all_ended:
                break

        # Emit up to top_k candidates, scored by length-normalized probability.
        predictions = []
        for seq, score in beams[:top_k]:
            if seq:
                char = self.id_to_char(seq[-1])
                prob = np.exp(score / max(len(seq), 1))
            else:
                char = self.id_to_char(0)
                prob = 0.0
            predictions.append((char, prob, seq[-1] if seq else 0))

        decode_time = self.last_decoder_time  # only the last decoder call's time

        return predictions, decode_time
|
||
|
||
    def interactive_mode(self):
        """Interactive inference loop.

        Collects the five inputs step by step (context prompts, text before
        and after the cursor, pinyin, slot history), runs predict(), then
        prints timing statistics, the top-20 candidates and — when the query
        engine is available — common characters for the entered pinyin.
        Typing quit/exit/q at any prompt, or Ctrl-C, leaves the loop.
        """
        print("\n" + "=" * 60)
        print("ONNX输入法模型推理 - 交互模式")
        print("=" * 60)

        encoding = sys.stdout.encoding or "unknown"
        print(f"终端编码: {encoding}")

        print("\n说明:")
        print(" - 上下文提示: 模型不掌握的专有词汇、姓名等(可为空)")
        print(" - 光标前文本: 光标前的连续文本")
        print(" - 光标后文本: 光标后的连续文本")
        print(" - 拼音: 当前输入的拼音")
        print(" - 槽位历史: 用户已确认的输入历史")
        if self.use_beam_search:
            print(f" - 解码模式: 束搜索 (beam_size={self.beam_size})")
        else:
            print(" - 解码模式: 单步解码 (使用 --beam 启用束搜索)")
        print("提示: 输入 'quit' 或 'exit' 或 'q' 可随时退出")
        print("-" * 60)

        while True:
            try:
                print("\n" + "=" * 60)
                context_input = self._safe_input("第1步: 上下文提示(直接回车跳过)")
                if context_input.lower() in ["quit", "exit", "q"]:
                    break

                # Comma-separated prompts; blank entries are dropped.
                context_prompts = [
                    item.strip() for item in context_input.split(",") if item.strip()
                ]

                print("\n" + "-" * 40)
                text_before = self._safe_input("第2步: 光标前文本")
                if text_before.lower() in ["quit", "exit", "q"]:
                    break

                print("\n" + "-" * 40)
                text_after = self._safe_input("第3步: 光标后文本")
                if text_after.lower() in ["quit", "exit", "q"]:
                    break

                print("\n" + "-" * 40)
                pinyin = self._safe_input("第4步: 拼音输入")
                if pinyin.lower() in ["quit", "exit", "q"]:
                    break

                print("\n" + "-" * 40)
                slot_input = self._safe_input("第5步: 槽位历史(直接回车表示无)")
                if slot_input.lower() in ["quit", "exit", "q"]:
                    break

                # Comma-separated confirmed characters; blank entries dropped.
                slot_chars = [
                    char.strip() for char in slot_input.split(",") if char.strip()
                ]

                print("\n" + "=" * 60)
                print("📝 输入汇总:")
                print(f" 上下文提示: {context_prompts if context_prompts else '无'}")
                print(f" 光标前文本: '{text_before}'")
                print(f" 光标后文本: '{text_after}'")
                print(f" 拼音: '{pinyin}'")
                print(f" 槽位历史: {slot_chars if slot_chars else '无'}")

                print("\n🔮 推理中...")
                predictions, timing_info = self.predict(
                    context_prompts,
                    text_before,
                    text_after,
                    pinyin,
                    slot_chars,
                    top_k=20,
                    use_beam_search=self.use_beam_search,
                    beam_size=self.beam_size,
                )

                # Timing statistics (stages with zero duration are skipped).
                print(f"\n⏱️ 执行时间统计:")
                print("-" * 40)
                for stage, duration in timing_info.items():
                    if duration > 0:
                        print(f" {stage:<12}: {duration:>8.2f} ms")
                print("-" * 40)

                # Top-20 prediction results.
                print("\n🏆 Top-20 预测结果:")
                print("-" * 50)
                for i, (char, prob, idx) in enumerate(predictions):
                    if char == "//":
                        print(f"{i + 1:2d}. {'//':<4} (结束符) - 概率: {prob:.4f}")
                    else:
                        print(
                            f"{i + 1:2d}. {char:<4} (ID: {idx:>5}) - 概率: {prob:.4f}"
                        )

                # Pinyin reference from the statistics engine, when available.
                if pinyin and self.query_engine:
                    print(f"\n📖 拼音 '{pinyin}' 的常见汉字:")
                    pinyin_results = self.query_engine.query_by_pinyin(pinyin, limit=10)
                    if pinyin_results:
                        for j, (pid, char, count) in enumerate(pinyin_results):
                            print(f" {char} (频次: {count:,})")

                print("\n" + "-" * 40)
                continue_input = (
                    self._safe_input("是否继续推理?(y/n)", "y").strip().lower()
                )
                if continue_input not in ["y", "yes", ""]:
                    break

            except KeyboardInterrupt:
                print("\n\n退出交互模式")
                break
            except Exception as e:
                # Report the error and keep the loop alive for the next round.
                print(f"\n❌ 推理出错: {e}")
                import traceback

                traceback.print_exc()
|
||
|
||
|
||
def main():
    """Parse CLI arguments, build the inference engine, then run the optional
    smoke test and/or the interactive loop."""
    parser = argparse.ArgumentParser(description="ONNX输入法模型推理")
    parser.add_argument(
        "--context-encoder",
        "-c",
        type=str,
        required=True,
        help="上下文编码器ONNX模型路径",
    )
    parser.add_argument(
        "--decoder",
        "-d",
        type=str,
        required=True,
        help="解码器ONNX模型路径",
    )
    parser.add_argument(
        "--vocab-size",
        type=int,
        default=10019,
        help="词汇表大小 (默认: 10019)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="推理设备 (默认: cpu)",
    )
    # BUGFIX: `--interactive` was a store_true flag with default=True, so the
    # interactive mode could never be turned off from the command line.
    # `--no-interactive` (same dest) now disables it; passing `--interactive`
    # keeps its previous (no-op) behavior, so existing invocations still work.
    parser.add_argument(
        "--interactive",
        action="store_true",
        default=True,
        help="交互模式 (默认: True)",
    )
    parser.add_argument(
        "--no-interactive",
        dest="interactive",
        action="store_false",
        help="禁用交互模式",
    )
    parser.add_argument("--test", action="store_true", help="运行测试推理")
    parser.add_argument(
        "--beam",
        action="store_true",
        help="使用束搜索解码 (默认: 单步解码)",
    )
    parser.add_argument(
        "--beam-size",
        type=int,
        default=5,
        help="束大小 (默认: 5)",
    )

    args = parser.parse_args()

    # Fail fast when either model file is missing.
    if not os.path.exists(args.context_encoder):
        print(f"❌ 错误: 上下文编码器文件不存在: {args.context_encoder}")
        sys.exit(1)
    if not os.path.exists(args.decoder):
        print(f"❌ 错误: 解码器文件不存在: {args.decoder}")
        sys.exit(1)

    # Build the inference engine (loads both ONNX sessions and assets).
    inference = ONNXInference(
        context_encoder_path=args.context_encoder,
        decoder_path=args.decoder,
        vocab_size=args.vocab_size,
        device=args.device,
        use_beam_search=args.beam,
        beam_size=args.beam_size,
    )

    # Optional smoke test with a fixed scenario.
    if args.test:
        print("\n🧪 运行测试推理...")
        print("测试场景: 输入'shanghai',已确认第一个字'上',继续输入'tian'")
        print("上下文提示: 张三、李四(模型不掌握的专有名词)")

        predictions, timing_info = inference.predict(
            context_prompts=["张三", "李四"],
            text_before="今天天气",
            text_after="很好",
            pinyin="tian",
            slot_chars=["上"],
            use_beam_search=args.beam,
            beam_size=args.beam_size,
        )

        print(f"\n⏱️ 执行时间统计:")
        print("-" * 40)
        for stage, duration in timing_info.items():
            if duration > 0:
                print(f" {stage:<12}: {duration:>8.2f} ms")
        print("-" * 40)

        print(f"\nTop-5 结果:")
        for i, (char, prob, idx) in enumerate(predictions[:5]):
            if char == "//":
                print(f" {i + 1}. // (结束符) - 概率: {prob:.4f}")
            else:
                print(f" {i + 1}. {char} (ID: {idx}) - 概率: {prob:.4f}")

    # Interactive loop (on by default; disable with --no-interactive).
    if args.interactive:
        inference.interactive_mode()
|
||
|
||
|
||
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()
|