SUimeModelTraner/inference.py

714 lines
26 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
输入法模型推理脚本
使用方法:
python inference.py --checkpoint ./output/checkpoints/best_model.pt
交互模式: 分步询问输入
1. 上下文提示: 模型不掌握的专有词汇、姓名等(可为空)
2. 光标前文本: 光标前的连续文本
3. 光标后文本: 光标后的连续文本
4. 拼音: 当前输入的拼音
5. 槽位历史: 用户已确认的输入历史,如输入"shanghai"已确认"上"
示例场景:
输入"shanghai"已确认"上",继续输入"tian"
上下文提示: 张三,李四
光标前文本: 今天天气很好
光标后文本: 我们去公园玩
拼音: tian
槽位历史: 上
"""
import argparse
import os
import sys
import time
from pathlib import Path
from typing import List, Optional, Tuple
import torch
import torch.nn.functional as F
from modelscope import AutoTokenizer
from src.model.dataset import text_to_pinyin_ids
from src.model.model import InputMethodEngine
from src.model.query import QueryEngine
class InputMethodInference:
    """Interactive inference wrapper for the input-method (IME) model.

    Loads a trained ``InputMethodEngine`` checkpoint, its tokenizer, and a
    pinyin/character query engine, then exposes single-step prediction
    (``predict``) and an interactive CLI loop (``interactive_mode``).
    """

    def __init__(
        self,
        checkpoint_path: str,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        """Load model, tokenizer, and query engine.

        Args:
            checkpoint_path: Path to the trained model checkpoint (.pt).
            device: Target device string; defaults to CUDA when available.
        """
        self.device = torch.device(device)
        self.checkpoint_path = checkpoint_path
        # Load the model weights.
        print(f"正在加载模型从: {checkpoint_path}")
        self.load_model()
        # Load the tokenizer.
        print("正在加载tokenizer...")
        self.load_tokenizer()
        # Load the char/pinyin query engine.
        print("正在加载查询引擎...")
        self.load_query_engine()
        print(f"✅ 推理器初始化完成 (设备: {self.device})")
        # Try to enable readline for better line editing (optional).
        try:
            import readline

            readline.set_completer_delims(" \t\n`~!@#$%^&*()-=+[{]}\\|;:'\",<>/?")
            print("📝 readline已启用支持更好的行编辑功能")
        except ImportError:
            print("📝 readline不可用使用标准输入")

    def _safe_input(self, prompt: str, default: str = "") -> str:
        """Prompt the user and return stripped input.

        Returns *default* when the user just presses Enter, and "" on
        EOF/KeyboardInterrupt or any unexpected input error.

        Args:
            prompt: Prompt text shown to the user.
            default: Value returned for empty input.
        Returns:
            The user's input string.
        """
        try:
            # Show the default (if any) inside the prompt.
            if default:
                full_prompt = f"{prompt} [{default}]: "
            else:
                full_prompt = f"{prompt}: "
            result = input(full_prompt)
            # Empty input with a default present -> return the default.
            if not result and default:
                return default
            return result.strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl+D / Ctrl+C: treat as empty input.
            print()
            return ""
        except Exception as e:
            print(f"\n⚠️ 输入错误: {e}")
            return ""

    def _clean_pinyin_input(self, pinyin: str) -> str:
        """Sanitize a raw pinyin string, emulating backspace/ESC editing.

        Only a-z, backtick, apostrophe and hyphen are kept (uppercase is
        lowercased); spaces and any other characters — including Chinese
        characters typed by mistake — are ignored.

        Args:
            pinyin: Raw pinyin input string.
        Returns:
            The cleaned pinyin string.
        """
        if not pinyin:
            return ""
        result = []
        for c in pinyin:
            # Chinese characters also satisfy isalpha(), so test the ASCII
            # ranges explicitly instead.
            is_valid_pinyin_char = (
                ("a" <= c <= "z")
                or ("A" <= c <= "Z")  # uppercase allowed, lowercased below
                or c in ["`", "'", "-"]
            )
            if is_valid_pinyin_char:
                result.append(c.lower())
            elif c == " ":
                # Spaces are ignored.
                continue
            elif c in ("\b", "\x7f", "\x08"):
                # Backspace / delete: drop the previous kept character.
                if result:
                    result.pop()
            elif c == "\x1b":
                # ESC clears everything typed so far.
                result.clear()
            else:
                # Anything else (incl. Chinese chars) is silently ignored;
                # note backspace cannot delete them since they were never kept.
                pass
        return "".join(result)

    def load_model(self):
        """Build the model instance and load checkpoint weights onto self.device."""
        # Create the model instance (without torch.compile).
        self.model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
        # Always load to CPU first so GPU-trained weights also work on
        # CPU-only machines, then move to the target device.
        checkpoint = torch.load(self.checkpoint_path, map_location="cpu")
        if "model_state_dict" in checkpoint:
            self.model.load_state_dict(checkpoint["model_state_dict"])
        else:
            self.model.load_state_dict(checkpoint)
        self.model.eval()
        self.model.to(self.device)
        print(
            f"✅ 模型加载完成,参数量: {sum(p.numel() for p in self.model.parameters()):,}"
        )

    def load_tokenizer(self):
        """Load the tokenizer bundled under assets; fall back to bert-base-chinese."""
        try:
            # Load the tokenizer shipped with the project under assets/tokenizer.
            tokenizer_path = (
                Path(__file__).parent / "src" / "model" / "assets" / "tokenizer"
            )
            self.tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path))
            print(f"✅ Tokenizer加载完成词汇表大小: {self.tokenizer.vocab_size}")
        except Exception as e:
            print(f"⚠️ 无法加载tokenizer: {e}")
            print("使用默认的bert-base-chinese tokenizer")
            self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")

    def load_query_engine(self):
        """Load the QueryEngine used for char<->ID conversion; None on failure."""
        try:
            self.query_engine = QueryEngine()
            # Load the default statistics file shipped with the project.
            stats_path = (
                Path(__file__).parent
                / "src"
                / "model"
                / "assets"
                / "pinyin_char_statistics.json"
            )
            if stats_path.exists():
                self.query_engine.load(stats_path)
                print(
                    f"✅ 查询引擎加载完成,字符对数量: {len(self.query_engine._id_to_info)}"
                )
            else:
                print(f"⚠️ 统计文件不存在: {stats_path}")
                self.query_engine = None
        except Exception as e:
            print(f"⚠️ 无法加载查询引擎: {e}")
            self.query_engine = None

    def char_to_id(self, char: str, pinyin: Optional[str] = None) -> int:
        """Map a Chinese character to its model ID (more precise with pinyin).

        Args:
            char: The character (or "//" end marker) to map.
            pinyin: Optional pinyin to disambiguate the character's variants.
        Returns:
            The ID; 0 for the end marker or when no mapping is found.
        """
        # "//" is the textual end marker.
        if char == "//":
            return 0  # 0 is assumed to be the end-marker ID
        if self.query_engine is None:
            # Fallback without a query engine: use the Unicode code point.
            return ord(char) if len(char) == 1 else 0
        try:
            if pinyin is not None:
                # Exact char+pinyin pair lookup.
                info = self.query_engine.get_char_info_by_char_pinyin(char, pinyin)
                if info:
                    return info.id
            # Fallback: first pinyin variant of the character.
            results = self.query_engine.query_by_char(char, limit=1)
            if results:
                return results[0][0]  # the ID
            return 0
        except Exception:
            # Narrowed from a bare except; any lookup failure maps to 0.
            return 0

    def id_to_char(self, id: int) -> str:
        """Map a model ID back to its character ("//" for the end-marker ID 0)."""
        if id == 0:
            return "//"
        if self.query_engine is None:
            # Fallback: treat the ID as a Unicode code point when in range.
            return chr(id) if id < 0x110000 else "<?>"
        try:
            info = self.query_engine.query_by_id(id)
            return info.char if info else f"[ID:{id}]"
        except Exception:
            # Narrowed from a bare except; show the raw ID on failure.
            return f"[ID:{id}]"

    def prepare_inputs(
        self,
        context_prompts: List[str],
        text_before: str,
        text_after: str,
        pinyin: str,
        slot_chars: List[str],
        max_seq_len: int = 128,
    ):
        """Build the model's input tensors for one prediction step.

        Args:
            context_prompts: Context hints (proper nouns/names the model does
                not know); joined with "|".
            text_before: Text before the cursor.
            text_after: Text after the cursor.
            pinyin: The pinyin currently being typed.
            slot_chars: Characters already confirmed by the user (history).
            max_seq_len: Maximum tokenized sequence length.
        Returns:
            Dict of tensors on self.device: input_ids, token_type_ids,
            attention_mask, pinyin_ids, history_slot_ids.
        """
        # 1. Build the tokenizer text pair. Matching the dataset format:
        #    first segment is "context|text_before", second is text_after.
        context_text = "|".join(context_prompts) if context_prompts else ""
        if context_text:
            input_text = f"{context_text}|{text_before}"
        else:
            input_text = text_before
        # 2. Tokenize both segments with padding/truncation to max_seq_len.
        encoded = self.tokenizer(
            input_text,
            text_after,
            max_length=max_seq_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            return_token_type_ids=True,
        )
        # 3. Pinyin: sanitize first (handles backspace/ESC), then encode and
        #    pad/truncate to a fixed length of 24 ids.
        cleaned_pinyin = self._clean_pinyin_input(pinyin)
        pinyin_ids = text_to_pinyin_ids(cleaned_pinyin)
        if len(pinyin_ids) < 24:
            pinyin_ids.extend([0] * (24 - len(pinyin_ids)))
        else:
            pinyin_ids = pinyin_ids[:24]
        pinyin_tensor = torch.tensor([pinyin_ids], dtype=torch.long)
        # 4. History slots: map each confirmed character to its ID, then
        #    pad/truncate to exactly 8 slots (0 = empty slot).
        history_slot_ids = [self.char_to_id(char) for char in slot_chars]
        if len(history_slot_ids) < 8:
            history_slot_ids.extend([0] * (8 - len(history_slot_ids)))
        else:
            history_slot_ids = history_slot_ids[:8]
        history_tensor = torch.tensor([history_slot_ids], dtype=torch.long)
        # 5. Move everything to the target device.
        return {
            "input_ids": encoded["input_ids"].to(self.device),
            "token_type_ids": encoded["token_type_ids"].to(self.device),
            "attention_mask": encoded["attention_mask"].to(self.device),
            "pinyin_ids": pinyin_tensor.to(self.device),
            "history_slot_ids": history_tensor.to(self.device),
        }

    def predict(
        self,
        context_prompts: List[str],
        text_before: str,
        text_after: str,
        pinyin: str,
        slot_chars: List[str],
        top_k: int = 20,
    ) -> Tuple[List[Tuple[str, float, int]], float]:
        """Run one prediction step and return top-k candidates.

        Args:
            context_prompts: Context hints (proper nouns/names).
            text_before: Text before the cursor.
            text_after: Text after the cursor.
            pinyin: The pinyin currently being typed.
            slot_chars: Confirmed-history characters (at most 8 are used).
            top_k: Number of candidates to return.
        Returns:
            (predictions, inference_time_ms) where predictions is a list of
            (char, probability, id) tuples.
        """
        start_time = time.perf_counter()
        inputs = self.prepare_inputs(
            context_prompts, text_before, text_after, pinyin, slot_chars
        )
        # Debug: shapes/dtypes of every input tensor.
        print("\n🔍 调试信息 - 输入检查:")
        for key, tensor in inputs.items():
            print(
                f" {key}: shape={tensor.shape}, dtype={tensor.dtype}, device={tensor.device}"
            )
            if key in ["history_slot_ids", "pinyin_ids"]:
                print(f" 值: {tensor.cpu().numpy().tolist()}")
        # Debug: pinyin ids in detail.
        print("\n🔍 拼音输入详细分析:")
        pinyin_tensor = inputs["pinyin_ids"]
        pinyin_values = pinyin_tensor.cpu().numpy()[0]
        print(f" 拼音ID列表: {pinyin_values}")
        print(f" 拼音非零ID: {[v for v in pinyin_values if v != 0]}")
        # Debug: history slots in detail (ids and decoded characters).
        print("\n🔍 槽位历史详细分析:")
        slot_tensor = inputs["history_slot_ids"]
        slot_values = slot_tensor.cpu().numpy()[0]
        print(f" 槽位ID列表: {slot_values}")
        print(f" 槽位非零ID: {[v for v in slot_values if v != 0]}")
        slot_chars_converted = [self.id_to_char(v) for v in slot_values]
        print(f" 槽位汉字: {slot_chars_converted}")
        # Debug: model state and a few ID->char samples.
        print("\n🔍 调试信息 - 模型检查:")
        print(f" 模型设备: {self.device}")
        print(f" 模型是否训练模式: {self.model.training}")
        vocab_size = self.model.vocab_size
        print(f" 模型词汇表大小: {vocab_size}")
        print(" 前5个ID对应的字符:")
        for i in range(1, 6):
            char = self.id_to_char(i)
            print(f" ID {i}: '{char}'")
        # Debug: classifier-head weight/bias statistics (sanity check that
        # the checkpoint was actually loaded, not random init).
        classifier_weight = self.model.classifier.weight.data
        classifier_bias = self.model.classifier.bias.data
        print(f" 分类头权重形状: {classifier_weight.shape}")
        print(
            f" 分类头权重范围: [{classifier_weight.min():.4f}, {classifier_weight.max():.4f}]"
        )
        print(f" 分类头权重均值: {classifier_weight.mean():.4f}")
        print(f" 分类头权重标准差: {classifier_weight.std():.4f}")
        print(f" 分类头偏置形状: {classifier_bias.shape}")
        print(
            f" 分类头偏置范围: [{classifier_bias.min():.4f}, {classifier_bias.max():.4f}]"
        )
        print(f" 分类头偏置均值: {classifier_bias.mean():.4f}")
        print(f" 分类头偏置标准差: {classifier_bias.std():.4f}")
        # Inference; mixed precision only on CUDA (disabled on CPU).
        with torch.no_grad():
            if self.device.type == "cuda":
                with torch.autocast(device_type="cuda"):
                    logits = self.model(**inputs)
            else:
                logits = self.model(**inputs)
        # Debug: raw logits.
        print("\n🔍 调试信息 - 输出检查:")
        print(f" Logits形状: {logits.shape}")
        print(f" Logits范围: [{logits.min():.4f}, {logits.max():.4f}]")
        print(f" Logits均值: {logits.mean():.4f}")
        max_val, max_idx = torch.max(logits, dim=-1)
        print(f" 最大logit值: {max_val.item():.4f}, 对应ID: {max_idx.item()}")
        # Top-k over softmax probabilities.
        probs = F.softmax(logits, dim=-1)
        top_probs, top_indices = torch.topk(probs, k=top_k, dim=-1)
        # Debug: probability distribution sanity checks.
        print(f" 概率总和: {probs.sum().item():.4f}")
        top_probs_array = top_probs.cpu().numpy().flatten()
        top_indices_array = top_indices.cpu().numpy().flatten()
        print(f" Top-{top_k}概率: {top_probs_array}")
        print(f" Top-{top_k} ID: {top_indices_array}")
        print(" 概率分布分析:")
        print(f" 平均概率: {probs.mean().item():.6f}")
        print(f" 最大概率: {probs.max().item():.6f}")
        print(f" 最小概率: {probs.min().item():.6f}")
        print(f" 标准差: {probs.std().item():.6f}")
        # A near-uniform distribution usually means the weights didn't load.
        if top_probs_array[0] < 0.01:
            print(
                f" ⚠️ 警告: 最高概率 ({top_probs_array[0]:.6f}) 小于0.01,模型可能未正确训练"
            )
            print(" 💡 可能原因: 1) 权重未正确加载 2) 输入格式错误 3) 模型配置不匹配")
        inference_time_ms = (time.perf_counter() - start_time) * 1000
        # Convert to readable (char, probability, id) tuples.
        predictions = []
        for i in range(top_k):
            idx = int(top_indices[0, i].item())
            prob = top_probs[0, i].item()
            char = self.id_to_char(idx)
            predictions.append((char, prob, idx))
        return predictions, inference_time_ms

    def interactive_mode(self):
        """Interactive REPL: prompt for each input field step-by-step, predict, print results."""
        print("\n" + "=" * 60)
        print("输入法模型推理 - 交互模式")
        print("=" * 60)
        # Warn when the terminal is not UTF-8 (Chinese input may break).
        encoding = sys.stdout.encoding or "unknown"
        print(f"终端编码: {encoding}")
        if encoding.lower() not in ["utf-8", "utf8"]:
            print("⚠️ 警告: 终端编码不是UTF-8中文输入可能有问题")
            print(" 建议设置: export LANG=en_US.UTF-8")
            print(" 或设置: export LC_ALL=en_US.UTF-8")
        print("\n说明:")
        print(" - 上下文提示: 模型不掌握的专有词汇、姓名等(可为空)")
        print(" - 光标前文本: 光标前的连续文本")
        print(" - 光标后文本: 光标后的连续文本")
        print(" - 拼音: 当前输入的拼音")
        print(" - 槽位历史: 用户已确认的输入历史,如输入'shanghai'已确认'上'")
        print("提示: 输入 'quit''exit''q' 可随时退出")
        print("-" * 60)
        while True:
            try:
                # Step 1: context prompts (comma separated, optional).
                print("\n" + "=" * 60)
                print("第1步: 上下文提示(模型不掌握的专有词汇、姓名等)")
                print("格式: 用逗号分隔多个词汇,可为空")
                print("示例: 张三,李四,北京大学")
                context_input = self._safe_input("请输入上下文提示(直接回车跳过)")
                if context_input.lower() in ["quit", "exit", "q"]:
                    print("退出交互模式")
                    break
                context_prompts = []
                if context_input:
                    context_prompts = [
                        item.strip()
                        for item in context_input.split(",")
                        if item.strip()
                    ]
                print(
                    f"✅ 已记录上下文提示: {context_prompts if context_prompts else ''}"
                )
                # Step 2: text before the cursor.
                print("\n" + "-" * 40)
                print("第2步: 光标前文本")
                print("说明: 光标前的连续文本内容")
                print("示例: 今天天气很好")
                text_before = self._safe_input("请输入光标前文本")
                if text_before.lower() in ["quit", "exit", "q"]:
                    print("退出交互模式")
                    break
                print(f"✅ 已记录光标前文本: '{text_before}'")
                # Step 3: text after the cursor.
                print("\n" + "-" * 40)
                print("第3步: 光标后文本")
                print("说明: 光标后的连续文本内容")
                print("示例: 我们去公园玩")
                text_after = self._safe_input("请输入光标后文本")
                if text_after.lower() in ["quit", "exit", "q"]:
                    print("退出交互模式")
                    break
                print(f"✅ 已记录光标后文本: '{text_after}'")
                # Step 4: current pinyin input.
                print("\n" + "-" * 40)
                print("第4步: 拼音输入")
                print("说明: 当前正在输入的拼音")
                print("示例: tian, shang, hao")
                pinyin = self._safe_input("请输入拼音")
                if pinyin.lower() in ["quit", "exit", "q"]:
                    print("退出交互模式")
                    break
                print(f"✅ 已记录拼音: '{pinyin}'")
                # Step 5: confirmed-history characters (comma separated).
                print("\n" + "-" * 40)
                print("第5步: 槽位历史(已确认的输入)")
                print("说明: 用户已确认的输入历史,用逗号分隔")
                print("示例: 上 (表示输入'shanghai'已确认'上')")
                print(" 今天,天气 (表示已确认两个词)")
                slot_input = self._safe_input("请输入槽位历史(直接回车表示无)")
                if slot_input.lower() in ["quit", "exit", "q"]:
                    print("退出交互模式")
                    break
                slot_chars = []
                if slot_input:
                    slot_chars = [
                        char.strip() for char in slot_input.split(",") if char.strip()
                    ]
                print(f"✅ 已记录槽位历史: {slot_chars if slot_chars else ''}")
                # Summary of everything collected.
                print("\n" + "=" * 60)
                print("📝 输入汇总:")
                print(f" 上下文提示: {context_prompts if context_prompts else ''}")
                print(f" 光标前文本: '{text_before}'")
                print(f" 光标后文本: '{text_after}'")
                print(f" 拼音: '{pinyin}'")
                print(f" 槽位历史: {slot_chars if slot_chars else ''}")
                # Run one prediction step.
                print("\n🔮 推理中...")
                predictions, inference_time = self.predict(
                    context_prompts, text_before, text_after, pinyin, slot_chars
                )
                # Display the ranked candidates.
                print(f"\n✅ 推理完成 (耗时: {inference_time:.2f}ms)")
                print("\n🏆 Top-20 预测结果:")
                print("-" * 50)
                for i, (char, prob, idx) in enumerate(predictions):
                    if char == "//":
                        print(f"{i + 1:2d}. {'//':<4} (结束符) - 概率: {prob:.4f}")
                    else:
                        print(
                            f"{i + 1:2d}. {char:<4} (ID: {idx:>5}) - 概率: {prob:.4f}"
                        )
                # For reference, show common characters for the raw pinyin.
                if pinyin and self.query_engine:
                    print(f"\n📖 拼音 '{pinyin}' 的常见汉字:")
                    pinyin_results = self.query_engine.query_by_pinyin(pinyin, limit=10)
                    if pinyin_results:
                        for pid, char, count in pinyin_results:
                            print(f" {char} (频次: {count:,})")
                    else:
                        print(" (无匹配结果)")
                # Ask whether to run another round.
                print("\n" + "-" * 40)
                continue_input = (
                    self._safe_input("是否继续推理?(y/n)", "y").strip().lower()
                )
                if continue_input not in ["y", "yes", ""]:
                    print("退出交互模式")
                    break
            except KeyboardInterrupt:
                print("\n\n退出交互模式")
                break
            except Exception as e:
                # Top-level REPL boundary: report and offer to continue.
                print(f"\n❌ 推理出错: {e}")
                import traceback

                traceback.print_exc()
                continue_input = (
                    self._safe_input("\n是否继续?(y/n)", "y").strip().lower()
                )
                if continue_input not in ["y", "yes", ""]:
                    print("退出交互模式")
                    break
def main():
    """CLI entry point: parse args, build the inference engine, run test and/or interactive mode."""
    parser = argparse.ArgumentParser(description="输入法模型推理")
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="模型checkpoint路径"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        choices=["auto", "cpu", "cuda"],
        help="推理设备 (默认: auto)",
    )
    # NOTE(review): store_true with default=True means this flag is always
    # True; kept as-is so existing invocations keep working.
    parser.add_argument(
        "--interactive", action="store_true", default=True, help="交互模式 (默认: True)"
    )
    parser.add_argument("--test", action="store_true", help="运行测试推理")
    args = parser.parse_args()
    # Resolve "auto" to a concrete device.
    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device
    # Build the inference engine (loads model/tokenizer/query engine).
    inference = InputMethodInference(args.checkpoint, device)
    # Optional smoke-test prediction with a fixed scenario.
    if args.test:
        print("\n🧪 运行测试推理...")
        print("测试场景: 输入'shanghai',已确认第一个字'上',继续输入'tian'")
        print("上下文提示: 张三、李四(模型不掌握的专有名词)")
        test_predictions, test_time = inference.predict(
            context_prompts=["张三", "李四"],
            text_before="今天天气",
            text_after="很好",
            pinyin="tian",
            slot_chars=["上"],  # the user has already confirmed "上"
        )
        print(f"测试推理耗时: {test_time:.2f}ms")
        print("Top-5 结果:")
        for i, (char, prob, idx) in enumerate(test_predictions[:5]):
            if char == "//":
                print(f" {i + 1}. // (结束符) - 概率: {prob:.4f}")
            else:
                print(f" {i + 1}. {char} (ID: {idx}) - 概率: {prob:.4f}")
    # Interactive REPL (effectively always on; see --interactive note above).
    if args.interactive:
        inference.interactive_mode()


if __name__ == "__main__":
    main()