#!/usr/bin/env python3
"""
输入法模型推理脚本

使用方法:
python inference.py --checkpoint ./output/checkpoints/best_model.pt

交互模式: 分步询问输入
1. 上下文提示: 模型不掌握的专有词汇、姓名等(可为空)
2. 光标前文本: 光标前的连续文本
3. 光标后文本: 光标后的连续文本
4. 拼音: 当前输入的拼音
5. 槽位历史: 用户已确认的输入历史(如输入shanghai已确认"上")

示例场景:
输入"shanghai"已确认"上",继续输入"tian"
上下文提示: 张三,李四
光标前文本: 今天天气很好
光标后文本: 我们去公园玩
拼音: tian
槽位历史: 上
"""
|
||
|
||
import argparse
|
||
import os
|
||
import sys
|
||
import time
|
||
from pathlib import Path
|
||
from typing import List, Optional, Tuple
|
||
|
||
import torch
|
||
import torch.nn.functional as F
|
||
from modelscope import AutoTokenizer
|
||
|
||
from src.model.dataset import text_to_pinyin_ids
|
||
from src.model.model import InputMethodEngine
|
||
from src.model.query import QueryEngine
|
||
|
||
|
||
class InputMethodInference:
|
||
"""输入法模型推理器"""
|
||
|
||
def __init__(
|
||
self,
|
||
checkpoint_path: str,
|
||
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
||
):
|
||
self.device = torch.device(device)
|
||
self.checkpoint_path = checkpoint_path
|
||
|
||
# 加载组件
|
||
print(f"正在加载模型从: {checkpoint_path}")
|
||
self.load_model()
|
||
|
||
# 加载tokenizer
|
||
print("正在加载tokenizer...")
|
||
self.load_tokenizer()
|
||
|
||
# 加载查询引擎
|
||
print("正在加载查询引擎...")
|
||
self.load_query_engine()
|
||
|
||
print(f"✅ 推理器初始化完成 (设备: {self.device})")
|
||
|
||
# 尝试启用readline以获得更好的行编辑功能
|
||
try:
|
||
import readline
|
||
|
||
# 设置readline使用UTF-8编码
|
||
readline.set_completer_delims(" \t\n`~!@#$%^&*()-=+[{]}\\|;:'\",<>/?")
|
||
print("📝 readline已启用,支持更好的行编辑功能")
|
||
except ImportError:
|
||
print("📝 readline不可用,使用标准输入")
|
||
|
||
def _safe_input(self, prompt: str, default: str = "") -> str:
|
||
"""
|
||
安全的输入函数,尝试正确处理UTF-8字符和退格键
|
||
|
||
Args:
|
||
prompt: 提示文本
|
||
default: 默认值
|
||
|
||
Returns:
|
||
用户输入的字符串
|
||
"""
|
||
try:
|
||
# 显示提示和默认值
|
||
if default:
|
||
full_prompt = f"{prompt} [{default}]: "
|
||
else:
|
||
full_prompt = f"{prompt}: "
|
||
|
||
# 使用标准input
|
||
result = input(full_prompt)
|
||
|
||
# 如果用户直接回车且存在默认值,则返回默认值
|
||
if not result and default:
|
||
return default
|
||
|
||
return result.strip()
|
||
except (EOFError, KeyboardInterrupt):
|
||
# 用户按Ctrl+D或Ctrl+C
|
||
print()
|
||
return ""
|
||
except Exception as e:
|
||
# 其他错误
|
||
print(f"\n⚠️ 输入错误: {e}")
|
||
return ""
|
||
|
||
def _clean_pinyin_input(self, pinyin: str) -> str:
|
||
"""
|
||
清理拼音输入字符串,处理退格键等特殊字符
|
||
|
||
拼音只允许: a-z, `, ', -
|
||
中文字符和其他字符会被忽略
|
||
|
||
Args:
|
||
pinyin: 原始拼音输入字符串
|
||
|
||
Returns:
|
||
清理后的拼音字符串
|
||
"""
|
||
if not pinyin:
|
||
return ""
|
||
|
||
result = []
|
||
for c in pinyin:
|
||
# 检查是否为合法拼音字符 (a-z, `, ', -)
|
||
# 注意: 中文字符的isalpha()也返回True,所以需要额外检查
|
||
is_valid_pinyin_char = (
|
||
("a" <= c <= "z")
|
||
or ("A" <= c <= "Z") # 允许大写字母,转换为小写
|
||
or c in ["`", "'", "-"]
|
||
)
|
||
|
||
if is_valid_pinyin_char:
|
||
# 合法拼音字符,转换为小写
|
||
result.append(c.lower())
|
||
elif c == " ":
|
||
# 空格忽略
|
||
continue
|
||
elif c == "\b" or c == "\x7f" or c == "\x08":
|
||
# 退格键、删除键:删除前一个字符
|
||
if result:
|
||
result.pop()
|
||
elif c == "\x1b":
|
||
# ESC键:清空所有输入
|
||
result.clear()
|
||
else:
|
||
# 其他字符(包括中文字符)忽略
|
||
# 注意:这里不添加到result,所以退格键无法删除它们
|
||
# 但用户可能在拼音输入中误输入中文字符,应该忽略
|
||
pass
|
||
|
||
return "".join(result)
|
||
|
||
def load_model(self):
|
||
"""加载训练好的模型"""
|
||
# 创建模型实例(不编译)
|
||
self.model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
|
||
|
||
# 加载checkpoint
|
||
# 加载训练好的权重(强制先加载到CPU,再移动到目标设备)
|
||
# 这样确保GPU训练的权重能正确转换到CPU
|
||
checkpoint = torch.load(self.checkpoint_path, map_location="cpu")
|
||
if "model_state_dict" in checkpoint:
|
||
self.model.load_state_dict(checkpoint["model_state_dict"])
|
||
else:
|
||
self.model.load_state_dict(checkpoint)
|
||
|
||
self.model.eval()
|
||
self.model.to(self.device)
|
||
print(
|
||
f"✅ 模型加载完成,参数量: {sum(p.numel() for p in self.model.parameters()):,}"
|
||
)
|
||
|
||
def load_tokenizer(self):
|
||
"""加载tokenizer"""
|
||
try:
|
||
# 从assets/tokenizer加载tokenizer
|
||
tokenizer_path = (
|
||
Path(__file__).parent / "src" / "model" / "assets" / "tokenizer"
|
||
)
|
||
self.tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path))
|
||
print(f"✅ Tokenizer加载完成,词汇表大小: {self.tokenizer.vocab_size}")
|
||
except Exception as e:
|
||
print(f"⚠️ 无法加载tokenizer: {e}")
|
||
print("使用默认的bert-base-chinese tokenizer")
|
||
self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
|
||
|
||
def load_query_engine(self):
|
||
"""加载查询引擎用于字符-ID转换"""
|
||
try:
|
||
self.query_engine = QueryEngine()
|
||
# 加载默认的统计文件
|
||
stats_path = (
|
||
Path(__file__).parent
|
||
/ "src"
|
||
/ "model"
|
||
/ "assets"
|
||
/ "pinyin_char_statistics.json"
|
||
)
|
||
if stats_path.exists():
|
||
self.query_engine.load(stats_path)
|
||
print(
|
||
f"✅ 查询引擎加载完成,字符对数量: {len(self.query_engine._id_to_info)}"
|
||
)
|
||
else:
|
||
print(f"⚠️ 统计文件不存在: {stats_path}")
|
||
self.query_engine = None
|
||
except Exception as e:
|
||
print(f"⚠️ 无法加载查询引擎: {e}")
|
||
self.query_engine = None
|
||
|
||
def char_to_id(self, char: str, pinyin: Optional[str] = None) -> int:
|
||
"""将汉字转换为ID,如果提供拼音则更精确"""
|
||
# 处理结束符
|
||
if char == "//":
|
||
return 0 # 假设0是结束符ID
|
||
|
||
if self.query_engine is None:
|
||
# 简单回退:使用unicode编码
|
||
return ord(char) if len(char) == 1 else 0
|
||
|
||
try:
|
||
if pinyin is not None:
|
||
# 使用精确的字符-拼音对
|
||
info = self.query_engine.get_char_info_by_char_pinyin(char, pinyin)
|
||
if info:
|
||
return info.id
|
||
|
||
# 回退:获取字符的第一个拼音变体
|
||
results = self.query_engine.query_by_char(char, limit=1)
|
||
if results:
|
||
return results[0][0] # 返回ID
|
||
|
||
return 0
|
||
except:
|
||
return 0
|
||
|
||
def id_to_char(self, id: int) -> str:
|
||
"""将ID转换为汉字"""
|
||
# 处理结束符ID (假设0是结束符)
|
||
if id == 0:
|
||
return "//"
|
||
|
||
if self.query_engine is None:
|
||
return chr(id) if id < 0x110000 else "<?>"
|
||
|
||
try:
|
||
info = self.query_engine.query_by_id(id)
|
||
return info.char if info else f"[ID:{id}]"
|
||
except:
|
||
return f"[ID:{id}]"
|
||
|
||
def prepare_inputs(
|
||
self,
|
||
context_prompts: List[str],
|
||
text_before: str,
|
||
text_after: str,
|
||
pinyin: str,
|
||
slot_chars: List[str],
|
||
max_seq_len: int = 128,
|
||
):
|
||
"""
|
||
准备模型输入
|
||
|
||
Args:
|
||
context_prompts: 上下文提示(专有词汇、姓名等,用|分隔)
|
||
text_before: 光标前文本
|
||
text_after: 光标后文本
|
||
pinyin: 当前输入的拼音
|
||
slot_chars: 槽位内的汉字列表(用户已确认的输入历史)
|
||
max_seq_len: 最大序列长度
|
||
|
||
Returns:
|
||
模型输入字典
|
||
"""
|
||
|
||
# 1. 构建tokenizer输入
|
||
# 根据test.py和dataset.py,格式为: "part4|part1" 和 part3
|
||
# part4: 上下文提示(专有词汇、姓名等,模型不掌握)
|
||
# part1: text_before
|
||
# part3: text_after
|
||
|
||
# 处理上下文提示
|
||
context_text = "|".join(context_prompts) if context_prompts else ""
|
||
|
||
# 构建输入文本 - 与test.py保持一致
|
||
# test.py: f"{part4}|{part1}" 作为第一个参数,part3作为第二个参数
|
||
if context_text:
|
||
input_text = f"{context_text}|{text_before}"
|
||
else:
|
||
input_text = text_before
|
||
|
||
# 2. Tokenize - 与test.py保持一致
|
||
encoded = self.tokenizer(
|
||
input_text,
|
||
text_after,
|
||
max_length=max_seq_len,
|
||
padding="max_length",
|
||
truncation=True,
|
||
return_tensors="pt",
|
||
return_token_type_ids=True,
|
||
)
|
||
|
||
# 3. 处理拼音输入 - 与test.py保持一致
|
||
# 首先清理拼音字符串,处理退格键等特殊字符
|
||
cleaned_pinyin = self._clean_pinyin_input(pinyin)
|
||
|
||
pinyin_ids = text_to_pinyin_ids(cleaned_pinyin)
|
||
if len(pinyin_ids) < 24:
|
||
pinyin_ids.extend([0] * (24 - len(pinyin_ids)))
|
||
else:
|
||
pinyin_ids = pinyin_ids[:24]
|
||
pinyin_tensor = torch.tensor([pinyin_ids], dtype=torch.long)
|
||
|
||
# 4. 处理历史槽位(用户已确认的输入历史)
|
||
history_slot_ids = []
|
||
for char in slot_chars:
|
||
# 为每个槽位汉字查找ID(用户已确认的输入历史)
|
||
char_id = self.char_to_id(char)
|
||
history_slot_ids.append(char_id)
|
||
|
||
# 填充到8个槽位
|
||
if len(history_slot_ids) < 8:
|
||
history_slot_ids.extend([0] * (8 - len(history_slot_ids)))
|
||
else:
|
||
history_slot_ids = history_slot_ids[:8]
|
||
|
||
history_tensor = torch.tensor([history_slot_ids], dtype=torch.long)
|
||
|
||
# 5. 移动到设备
|
||
inputs = {
|
||
"input_ids": encoded["input_ids"].to(self.device),
|
||
"token_type_ids": encoded["token_type_ids"].to(self.device),
|
||
"attention_mask": encoded["attention_mask"].to(self.device),
|
||
"pinyin_ids": pinyin_tensor.to(self.device),
|
||
"history_slot_ids": history_tensor.to(self.device),
|
||
}
|
||
|
||
return inputs
|
||
|
||
def predict(
|
||
self,
|
||
context_prompts: List[str],
|
||
text_before: str,
|
||
text_after: str,
|
||
pinyin: str,
|
||
slot_chars: List[str],
|
||
top_k: int = 20,
|
||
) -> Tuple[List[Tuple[str, float, int]], float]:
|
||
"""
|
||
执行推理
|
||
|
||
Args:
|
||
context_prompts: 上下文提示(专有词汇、姓名等,用|分隔)
|
||
text_before: 光标前文本
|
||
text_after: 光标后文本
|
||
pinyin: 当前输入的拼音
|
||
slot_chars: 槽位内的汉字列表(用户已确认的输入历史,最大8个)
|
||
top_k: 返回top-k个预测结果
|
||
|
||
Returns:
|
||
(predictions, inference_time_ms)
|
||
predictions: List[Tuple[char, score, id]]
|
||
"""
|
||
start_time = time.perf_counter()
|
||
|
||
# 准备输入
|
||
inputs = self.prepare_inputs(
|
||
context_prompts, text_before, text_after, pinyin, slot_chars
|
||
)
|
||
|
||
# 调试信息:打印输入形状
|
||
print("\n🔍 调试信息 - 输入检查:")
|
||
for key, tensor in inputs.items():
|
||
print(
|
||
f" {key}: shape={tensor.shape}, dtype={tensor.dtype}, device={tensor.device}"
|
||
)
|
||
if key in ["history_slot_ids", "pinyin_ids"]:
|
||
print(f" 值: {tensor.cpu().numpy().tolist()}")
|
||
|
||
# 特别检查拼音和槽位输入
|
||
print("\n🔍 拼音输入详细分析:")
|
||
pinyin_tensor = inputs["pinyin_ids"]
|
||
pinyin_values = pinyin_tensor.cpu().numpy()[0]
|
||
# 将拼音ID转换回字符
|
||
from src.model.dataset import text_to_pinyin_ids
|
||
|
||
# 逆向转换可能需要一个反向映射,这里先简单显示
|
||
print(f" 拼音ID列表: {pinyin_values}")
|
||
print(f" 拼音非零ID: {[id for id in pinyin_values if id != 0]}")
|
||
|
||
print("\n🔍 槽位历史详细分析:")
|
||
slot_tensor = inputs["history_slot_ids"]
|
||
slot_values = slot_tensor.cpu().numpy()[0]
|
||
print(f" 槽位ID列表: {slot_values}")
|
||
print(f" 槽位非零ID: {[id for id in slot_values if id != 0]}")
|
||
# 将ID转换为汉字
|
||
slot_chars_converted = [self.id_to_char(id) for id in slot_values]
|
||
print(f" 槽位汉字: {slot_chars_converted}")
|
||
|
||
# 调试:检查模型权重
|
||
print(f"\n🔍 调试信息 - 模型检查:")
|
||
print(f" 模型设备: {self.device}")
|
||
print(f" 模型是否训练模式: {self.model.training}")
|
||
|
||
# 检查模型词汇表大小
|
||
vocab_size = self.model.vocab_size
|
||
print(f" 模型词汇表大小: {vocab_size}")
|
||
|
||
# 检查前5个ID对应的字符
|
||
print(f" 前5个ID对应的字符:")
|
||
for i in range(1, 6):
|
||
char = self.id_to_char(i)
|
||
print(f" ID {i}: '{char}'")
|
||
|
||
# 检查分类头权重
|
||
# 调试:检查分类头权重
|
||
classifier_weight = self.model.classifier.weight.data
|
||
classifier_bias = self.model.classifier.bias.data
|
||
print(f" 分类头权重形状: {classifier_weight.shape}")
|
||
print(
|
||
f" 分类头权重范围: [{classifier_weight.min():.4f}, {classifier_weight.max():.4f}]"
|
||
)
|
||
print(f" 分类头权重均值: {classifier_weight.mean():.4f}")
|
||
print(f" 分类头权重标准差: {classifier_weight.std():.4f}")
|
||
print(f" 分类头偏置形状: {classifier_bias.shape}")
|
||
print(
|
||
f" 分类头偏置范围: [{classifier_bias.min():.4f}, {classifier_bias.max():.4f}]"
|
||
)
|
||
print(f" 分类头偏置均值: {classifier_bias.mean():.4f}")
|
||
print(f" 分类头偏置标准差: {classifier_bias.std():.4f}")
|
||
|
||
# 推理(CPU推理时禁用混合精度)
|
||
with torch.no_grad():
|
||
if self.device.type == "cuda":
|
||
with torch.autocast(device_type="cuda"):
|
||
logits = self.model(**inputs)
|
||
else:
|
||
# CPU推理时不使用autocast
|
||
logits = self.model(**inputs)
|
||
|
||
# 调试:检查logits
|
||
print(f"\n🔍 调试信息 - 输出检查:")
|
||
print(f" Logits形状: {logits.shape}")
|
||
print(f" Logits范围: [{logits.min():.4f}, {logits.max():.4f}]")
|
||
print(f" Logits均值: {logits.mean():.4f}")
|
||
|
||
# 检查logits中最大值和对应的ID
|
||
max_val, max_idx = torch.max(logits, dim=-1)
|
||
print(f" 最大logit值: {max_val.item():.4f}, 对应ID: {max_idx.item()}")
|
||
|
||
# 获取top-k预测
|
||
probs = F.softmax(logits, dim=-1)
|
||
top_probs, top_indices = torch.topk(probs, k=top_k, dim=-1)
|
||
|
||
# 调试:检查概率分布
|
||
print(f" 概率总和: {probs.sum().item():.4f}")
|
||
top_probs_array = top_probs.cpu().numpy().flatten()
|
||
top_indices_array = top_indices.cpu().numpy().flatten()
|
||
print(f" Top-{top_k}概率: {top_probs_array}")
|
||
print(f" Top-{top_k} ID: {top_indices_array}")
|
||
|
||
# 检查概率是否均匀分布
|
||
print(f" 概率分布分析:")
|
||
print(f" 平均概率: {probs.mean().item():.6f}")
|
||
print(f" 最大概率: {probs.max().item():.6f}")
|
||
print(f" 最小概率: {probs.min().item():.6f}")
|
||
print(f" 标准差: {probs.std().item():.6f}")
|
||
|
||
# 检查top-20概率是否都很小
|
||
if top_probs_array[0] < 0.01:
|
||
print(
|
||
f" ⚠️ 警告: 最高概率 ({top_probs_array[0]:.6f}) 小于0.01,模型可能未正确训练"
|
||
)
|
||
print(f" 💡 可能原因: 1) 权重未正确加载 2) 输入格式错误 3) 模型配置不匹配")
|
||
|
||
inference_time_ms = (time.perf_counter() - start_time) * 1000
|
||
|
||
# 转换为可读结果
|
||
predictions = []
|
||
for i in range(top_k):
|
||
idx = int(top_indices[0, i].item())
|
||
prob = top_probs[0, i].item()
|
||
char = self.id_to_char(idx)
|
||
predictions.append((char, prob, idx))
|
||
|
||
return predictions, inference_time_ms
|
||
|
||
    def interactive_mode(self):
        """Interactive inference mode: collect each input field step by step,
        run predict(), and display the ranked candidates in a loop until the
        user quits (via 'quit'/'exit'/'q', Ctrl+C, or answering 'n')."""
        print("\n" + "=" * 60)
        print("输入法模型推理 - 交互模式")
        print("=" * 60)

        # Check the terminal encoding; non-UTF-8 terminals may corrupt
        # Chinese input.
        encoding = sys.stdout.encoding or "unknown"
        print(f"终端编码: {encoding}")
        if encoding.lower() not in ["utf-8", "utf8"]:
            print("⚠️ 警告: 终端编码不是UTF-8,中文输入可能有问题")
            print("  建议设置: export LANG=en_US.UTF-8")
            print("  或设置: export LC_ALL=en_US.UTF-8")

        print("\n说明:")
        print("  - 上下文提示: 模型不掌握的专有词汇、姓名等(可为空)")
        print("  - 光标前文本: 光标前的连续文本")
        print("  - 光标后文本: 光标后的连续文本")
        print("  - 拼音: 当前输入的拼音")
        print("  - 槽位历史: 用户已确认的输入历史,如输入'shanghai'已确认'上'")
        print("提示: 输入 'quit' 或 'exit' 或 'q' 可随时退出")
        print("-" * 60)

        while True:
            try:
                # Step 1: context hints (proper nouns, names the model
                # does not know), comma-separated, may be empty.
                print("\n" + "=" * 60)
                print("第1步: 上下文提示(模型不掌握的专有词汇、姓名等)")
                print("格式: 用逗号分隔多个词汇,可为空")
                print("示例: 张三,李四,北京大学")
                context_input = self._safe_input("请输入上下文提示(直接回车跳过)")

                if context_input.lower() in ["quit", "exit", "q"]:
                    print("退出交互模式")
                    break

                # Parse the comma-separated context hints.
                context_prompts = []
                if context_input:
                    context_prompts = [
                        item.strip()
                        for item in context_input.split(",")
                        if item.strip()
                    ]

                print(
                    f"✅ 已记录上下文提示: {context_prompts if context_prompts else '无'}"
                )

                # Step 2: text before the cursor.
                print("\n" + "-" * 40)
                print("第2步: 光标前文本")
                print("说明: 光标前的连续文本内容")
                print("示例: 今天天气很好")
                text_before = self._safe_input("请输入光标前文本")

                if text_before.lower() in ["quit", "exit", "q"]:
                    print("退出交互模式")
                    break

                print(f"✅ 已记录光标前文本: '{text_before}'")

                # Step 3: text after the cursor.
                print("\n" + "-" * 40)
                print("第3步: 光标后文本")
                print("说明: 光标后的连续文本内容")
                print("示例: 我们去公园玩")
                text_after = self._safe_input("请输入光标后文本")

                if text_after.lower() in ["quit", "exit", "q"]:
                    print("退出交互模式")
                    break

                print(f"✅ 已记录光标后文本: '{text_after}'")

                # Step 4: the pinyin currently being typed.
                print("\n" + "-" * 40)
                print("第4步: 拼音输入")
                print("说明: 当前正在输入的拼音")
                print("示例: tian, shang, hao")
                pinyin = self._safe_input("请输入拼音")

                if pinyin.lower() in ["quit", "exit", "q"]:
                    print("退出交互模式")
                    break

                print(f"✅ 已记录拼音: '{pinyin}'")

                # Step 5: confirmed history slots (comma-separated).
                print("\n" + "-" * 40)
                print("第5步: 槽位历史(已确认的输入)")
                print("说明: 用户已确认的输入历史,用逗号分隔")
                print("示例: 上 (表示输入'shanghai'已确认'上')")
                print("      今天,天气 (表示已确认两个词)")
                slot_input = self._safe_input("请输入槽位历史(直接回车表示无)")

                if slot_input.lower() in ["quit", "exit", "q"]:
                    print("退出交互模式")
                    break

                # Parse the comma-separated confirmed characters.
                slot_chars = []
                if slot_input:
                    slot_chars = [
                        char.strip() for char in slot_input.split(",") if char.strip()
                    ]

                print(f"✅ 已记录槽位历史: {slot_chars if slot_chars else '无'}")

                # Summarize everything collected before running inference.
                print("\n" + "=" * 60)
                print("📝 输入汇总:")
                print(f"  上下文提示: {context_prompts if context_prompts else '无'}")
                print(f"  光标前文本: '{text_before}'")
                print(f"  光标后文本: '{text_after}'")
                print(f"  拼音: '{pinyin}'")
                print(f"  槽位历史: {slot_chars if slot_chars else '无'}")

                # Run inference.
                print("\n🔮 推理中...")
                predictions, inference_time = self.predict(
                    context_prompts, text_before, text_after, pinyin, slot_chars
                )

                # Display the ranked predictions; "//" is the end marker.
                print(f"\n✅ 推理完成 (耗时: {inference_time:.2f}ms)")
                print("\n🏆 Top-20 预测结果:")
                print("-" * 50)
                for i, (char, prob, idx) in enumerate(predictions):
                    if char == "//":
                        print(f"{i + 1:2d}. {'//':<4} (结束符) - 概率: {prob:.4f}")
                    else:
                        print(
                            f"{i + 1:2d}. {char:<4} (ID: {idx:>5}) - 概率: {prob:.4f}"
                        )

                # Also show common characters for the raw pinyin, when a
                # query engine is available.
                if pinyin and self.query_engine:
                    print(f"\n📖 拼音 '{pinyin}' 的常见汉字:")
                    pinyin_results = self.query_engine.query_by_pinyin(pinyin, limit=10)
                    if pinyin_results:
                        for j, (pid, char, count) in enumerate(pinyin_results):
                            print(f"  {char} (频次: {count:,})")
                    else:
                        print("  (无匹配结果)")

                # Ask whether to run another round (default: yes).
                print("\n" + "-" * 40)
                continue_input = (
                    self._safe_input("是否继续推理?(y/n)", "y").strip().lower()
                )
                if continue_input not in ["y", "yes", ""]:
                    print("退出交互模式")
                    break

            except KeyboardInterrupt:
                print("\n\n退出交互模式")
                break
            except Exception as e:
                # Report the error with a traceback but keep the loop
                # alive unless the user opts out.
                print(f"\n❌ 推理出错: {e}")
                import traceback

                traceback.print_exc()

                # Ask whether to continue after an error (default: yes).
                continue_input = (
                    self._safe_input("\n是否继续?(y/n)", "y").strip().lower()
                )
                if continue_input not in ["y", "yes", ""]:
                    print("退出交互模式")
                    break
|
||
|
||
|
||
def main():
|
||
parser = argparse.ArgumentParser(description="输入法模型推理")
|
||
parser.add_argument(
|
||
"--checkpoint", type=str, required=True, help="模型checkpoint路径"
|
||
)
|
||
parser.add_argument(
|
||
"--device",
|
||
type=str,
|
||
default="auto",
|
||
choices=["auto", "cpu", "cuda"],
|
||
help="推理设备 (默认: auto)",
|
||
)
|
||
parser.add_argument(
|
||
"--interactive", action="store_true", default=True, help="交互模式 (默认: True)"
|
||
)
|
||
parser.add_argument("--test", action="store_true", help="运行测试推理")
|
||
|
||
args = parser.parse_args()
|
||
|
||
# 选择设备
|
||
if args.device == "auto":
|
||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||
else:
|
||
device = args.device
|
||
|
||
# 初始化推理器
|
||
inference = InputMethodInference(args.checkpoint, device)
|
||
|
||
# 测试推理
|
||
if args.test:
|
||
print("\n🧪 运行测试推理...")
|
||
print("测试场景: 输入'shanghai',已确认第一个字'上',继续输入'tian'")
|
||
print("上下文提示: 张三、李四(模型不掌握的专有名词)")
|
||
test_predictions, test_time = inference.predict(
|
||
context_prompts=["张三", "李四"],
|
||
text_before="今天天气",
|
||
text_after="很好",
|
||
pinyin="tian",
|
||
slot_chars=["上"], # 用户已确认输入"上"
|
||
)
|
||
print(f"测试推理耗时: {test_time:.2f}ms")
|
||
print(f"Top-5 结果:")
|
||
for i, (char, prob, idx) in enumerate(test_predictions[:5]):
|
||
if char == "//":
|
||
print(f" {i + 1}. // (结束符) - 概率: {prob:.4f}")
|
||
else:
|
||
print(f" {i + 1}. {char} (ID: {idx}) - 概率: {prob:.4f}")
|
||
|
||
# 交互模式
|
||
if args.interactive:
|
||
inference.interactive_mode()
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|