# SUimeModelTraner/inference.py
# (web-page scraping artifacts removed: "603 lines / 22 KiB / Raw Blame History"
#  and the GitHub ambiguous-Unicode warning were not part of the source file)
#!/usr/bin/env python3
"""Inference script for the input-method model.

Usage:
    python inference.py --checkpoint ./output/checkpoints/best_model.pt

Interactive mode asks for each input field step by step:
    1. Context prompts: proper nouns, names etc. the model does not know (may be empty)
    2. Text before cursor: contiguous text before the cursor
    3. Text after cursor: contiguous text after the cursor
    4. Pinyin: the pinyin currently being typed
    5. Slot history: input already confirmed by the user
       (e.g. while typing "shanghai" the first character is already confirmed)

Example scenario (typing "shanghai", first character confirmed, now typing "tian"):
    Context prompts:     张三,李四
    Text before cursor:  今天天气很好
    Text after cursor:   我们去公园玩
    Pinyin:              tian
    Slot history:        上

NOTE(review): some quoted characters in the original docstring appear to have
been stripped by a text-extraction step (e.g. 已确认""); the confirmed
character is presumably 上 — verify against the original source.
"""
import argparse
import time
from pathlib import Path
from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F
from modelscope import AutoTokenizer

from src.model.dataset import text_to_pinyin_ids
from src.model.model import InputMethodEngine
from src.model.query import QueryEngine
class InputMethodInference:
    """Input-method model inference runner (model + tokenizer + query engine)."""

    def __init__(
        self,
        checkpoint_path: str,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        # NOTE(review): the default for `device` is evaluated once, at class
        # definition time; callers can always pass an explicit device string.
        self.device = torch.device(device)
        self.checkpoint_path = checkpoint_path
        # Load the trained model weights.
        print(f"正在加载模型从: {checkpoint_path}")
        self.load_model()
        # Load the tokenizer.
        print("正在加载tokenizer...")
        self.load_tokenizer()
        # Load the query engine (character <-> ID conversion).
        print("正在加载查询引擎...")
        self.load_query_engine()
        print(f"✅ 推理器初始化完成 (设备: {self.device})")
def load_model(self):
"""加载训练好的模型"""
# 创建模型实例(不编译)
self.model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
# 加载checkpoint
# 加载训练好的权重强制先加载到CPU再移动到目标设备
# 这样确保GPU训练的权重能正确转换到CPU
checkpoint = torch.load(self.checkpoint_path, map_location="cpu")
if "model_state_dict" in checkpoint:
self.model.load_state_dict(checkpoint["model_state_dict"])
else:
self.model.load_state_dict(checkpoint)
self.model.eval()
self.model.to(self.device)
print(
f"✅ 模型加载完成,参数量: {sum(p.numel() for p in self.model.parameters()):,}"
)
def load_tokenizer(self):
"""加载tokenizer"""
try:
# 从assets/tokenizer加载tokenizer
tokenizer_path = (
Path(__file__).parent / "src" / "model" / "assets" / "tokenizer"
)
self.tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path))
print(f"✅ Tokenizer加载完成词汇表大小: {self.tokenizer.vocab_size}")
except Exception as e:
print(f"⚠️ 无法加载tokenizer: {e}")
print("使用默认的bert-base-chinese tokenizer")
self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
def load_query_engine(self):
"""加载查询引擎用于字符-ID转换"""
try:
self.query_engine = QueryEngine()
# 加载默认的统计文件
stats_path = (
Path(__file__).parent
/ "src"
/ "model"
/ "assets"
/ "pinyin_char_statistics.json"
)
if stats_path.exists():
self.query_engine.load(stats_path)
print(
f"✅ 查询引擎加载完成,字符对数量: {len(self.query_engine._id_to_info)}"
)
else:
print(f"⚠️ 统计文件不存在: {stats_path}")
self.query_engine = None
except Exception as e:
print(f"⚠️ 无法加载查询引擎: {e}")
self.query_engine = None
def char_to_id(self, char: str, pinyin: Optional[str] = None) -> int:
"""将汉字转换为ID如果提供拼音则更精确"""
# 处理结束符
if char == "//":
return 0 # 假设0是结束符ID
if self.query_engine is None:
# 简单回退使用unicode编码
return ord(char) if len(char) == 1 else 0
try:
if pinyin is not None:
# 使用精确的字符-拼音对
info = self.query_engine.get_char_info_by_char_pinyin(char, pinyin)
if info:
return info.id
# 回退:获取字符的第一个拼音变体
results = self.query_engine.query_by_char(char, limit=1)
if results:
return results[0][0] # 返回ID
return 0
except:
return 0
def id_to_char(self, id: int) -> str:
"""将ID转换为汉字"""
# 处理结束符ID (假设0是结束符)
if id == 0:
return "//"
if self.query_engine is None:
return chr(id) if id < 0x110000 else "<?>"
try:
info = self.query_engine.query_by_id(id)
return info.char if info else f"[ID:{id}]"
except:
return f"[ID:{id}]"
def prepare_inputs(
self,
context_prompts: List[str],
text_before: str,
text_after: str,
pinyin: str,
slot_chars: List[str],
max_seq_len: int = 128,
):
"""
准备模型输入
Args:
context_prompts: 上下文提示(专有词汇、姓名等,用|分隔)
text_before: 光标前文本
text_after: 光标后文本
pinyin: 当前输入的拼音
slot_chars: 槽位内的汉字列表(用户已确认的输入历史)
max_seq_len: 最大序列长度
Returns:
模型输入字典
"""
# 1. 构建tokenizer输入
# 根据dataset.py格式为: "part4|part1" 和 part3
# part4: 上下文提示(专有词汇、姓名等,模型不掌握)
# part1: text_before
# part3: text_after
# 处理上下文提示
context_text = "|".join(context_prompts) if context_prompts else ""
# 构建输入文本
if context_text:
input_text = f"{context_text}|{text_before}"
else:
input_text = text_before
# 2. Tokenize
encoded = self.tokenizer(
input_text,
text_after,
max_length=max_seq_len,
padding="max_length",
truncation=True,
return_tensors="pt",
return_token_type_ids=True,
)
# 3. 处理拼音输入
pinyin_ids = text_to_pinyin_ids(pinyin)
if len(pinyin_ids) < 24:
pinyin_ids.extend([0] * (24 - len(pinyin_ids)))
else:
pinyin_ids = pinyin_ids[:24]
pinyin_tensor = torch.tensor([pinyin_ids], dtype=torch.long)
# 4. 处理历史槽位(用户已确认的输入历史)
history_slot_ids = []
for char in slot_chars:
# 为每个槽位汉字查找ID用户已确认的输入历史
char_id = self.char_to_id(char)
history_slot_ids.append(char_id)
# 填充到8个槽位
if len(history_slot_ids) < 8:
history_slot_ids.extend([0] * (8 - len(history_slot_ids)))
else:
history_slot_ids = history_slot_ids[:8]
history_tensor = torch.tensor([history_slot_ids], dtype=torch.long)
# 5. 移动到设备
inputs = {
"input_ids": encoded["input_ids"].to(self.device),
"token_type_ids": encoded["token_type_ids"].to(self.device),
"attention_mask": encoded["attention_mask"].to(self.device),
"pinyin_ids": pinyin_tensor.to(self.device),
"history_slot_ids": history_tensor.to(self.device),
}
return inputs
def predict(
self,
context_prompts: List[str],
text_before: str,
text_after: str,
pinyin: str,
slot_chars: List[str],
top_k: int = 20,
) -> Tuple[List[Tuple[str, float, int]], float]:
"""
执行推理
Args:
context_prompts: 上下文提示(专有词汇、姓名等,用|分隔)
text_before: 光标前文本
text_after: 光标后文本
pinyin: 当前输入的拼音
slot_chars: 槽位内的汉字列表用户已确认的输入历史最大8个
top_k: 返回top-k个预测结果
Returns:
(predictions, inference_time_ms)
predictions: List[Tuple[char, score, id]]
"""
start_time = time.perf_counter()
# 准备输入
inputs = self.prepare_inputs(
context_prompts, text_before, text_after, pinyin, slot_chars
)
# 调试信息:打印输入形状
print("\n🔍 调试信息 - 输入检查:")
for key, tensor in inputs.items():
print(
f" {key}: shape={tensor.shape}, dtype={tensor.dtype}, device={tensor.device}"
)
if key in ["history_slot_ids", "pinyin_ids"]:
print(f" 值: {tensor.cpu().numpy().tolist()}")
# 特别检查拼音和槽位输入
print("\n🔍 拼音输入详细分析:")
pinyin_tensor = inputs["pinyin_ids"]
pinyin_values = pinyin_tensor.cpu().numpy()[0]
# 将拼音ID转换回字符
from src.model.dataset import text_to_pinyin_ids
# 逆向转换可能需要一个反向映射,这里先简单显示
print(f" 拼音ID列表: {pinyin_values}")
print(f" 拼音非零ID: {[id for id in pinyin_values if id != 0]}")
print("\n🔍 槽位历史详细分析:")
slot_tensor = inputs["history_slot_ids"]
slot_values = slot_tensor.cpu().numpy()[0]
print(f" 槽位ID列表: {slot_values}")
print(f" 槽位非零ID: {[id for id in slot_values if id != 0]}")
# 将ID转换为汉字
slot_chars_converted = [self.id_to_char(id) for id in slot_values]
print(f" 槽位汉字: {slot_chars_converted}")
# 调试:检查模型权重
print(f"\n🔍 调试信息 - 模型检查:")
print(f" 模型设备: {self.device}")
print(f" 模型是否训练模式: {self.model.training}")
# 检查模型词汇表大小
vocab_size = self.model.vocab_size
print(f" 模型词汇表大小: {vocab_size}")
# 检查前5个ID对应的字符
print(f" 前5个ID对应的字符:")
for i in range(1, 6):
char = self.id_to_char(i)
print(f" ID {i}: '{char}'")
# 检查分类头权重
# 调试:检查分类头权重
classifier_weight = self.model.classifier.weight.data
classifier_bias = self.model.classifier.bias.data
print(f" 分类头权重形状: {classifier_weight.shape}")
print(
f" 分类头权重范围: [{classifier_weight.min():.4f}, {classifier_weight.max():.4f}]"
)
print(f" 分类头权重均值: {classifier_weight.mean():.4f}")
print(f" 分类头权重标准差: {classifier_weight.std():.4f}")
print(f" 分类头偏置形状: {classifier_bias.shape}")
print(
f" 分类头偏置范围: [{classifier_bias.min():.4f}, {classifier_bias.max():.4f}]"
)
print(f" 分类头偏置均值: {classifier_bias.mean():.4f}")
print(f" 分类头偏置标准差: {classifier_bias.std():.4f}")
# 推理CPU推理时禁用混合精度
with torch.no_grad():
if self.device.type == "cuda":
with torch.autocast(device_type="cuda"):
logits = self.model(**inputs)
else:
# CPU推理时不使用autocast
logits = self.model(**inputs)
# 调试检查logits
print(f"\n🔍 调试信息 - 输出检查:")
print(f" Logits形状: {logits.shape}")
print(f" Logits范围: [{logits.min():.4f}, {logits.max():.4f}]")
print(f" Logits均值: {logits.mean():.4f}")
# 检查logits中最大值和对应的ID
max_val, max_idx = torch.max(logits, dim=-1)
print(f" 最大logit值: {max_val.item():.4f}, 对应ID: {max_idx.item()}")
# 获取top-k预测
probs = F.softmax(logits, dim=-1)
top_probs, top_indices = torch.topk(probs, k=top_k, dim=-1)
# 调试:检查概率分布
print(f" 概率总和: {probs.sum().item():.4f}")
top_probs_array = top_probs.cpu().numpy().flatten()
top_indices_array = top_indices.cpu().numpy().flatten()
print(f" Top-{top_k}概率: {top_probs_array}")
print(f" Top-{top_k} ID: {top_indices_array}")
# 检查概率是否均匀分布
print(f" 概率分布分析:")
print(f" 平均概率: {probs.mean().item():.6f}")
print(f" 最大概率: {probs.max().item():.6f}")
print(f" 最小概率: {probs.min().item():.6f}")
print(f" 标准差: {probs.std().item():.6f}")
# 检查top-20概率是否都很小
if top_probs_array[0] < 0.01:
print(
f" ⚠️ 警告: 最高概率 ({top_probs_array[0]:.6f}) 小于0.01,模型可能未正确训练"
)
print(f" 💡 可能原因: 1) 权重未正确加载 2) 输入格式错误 3) 模型配置不匹配")
inference_time_ms = (time.perf_counter() - start_time) * 1000
# 转换为可读结果
predictions = []
for i in range(top_k):
idx = int(top_indices[0, i].item())
prob = top_probs[0, i].item()
char = self.id_to_char(idx)
predictions.append((char, prob, idx))
return predictions, inference_time_ms
    def interactive_mode(self):
        """Interactive inference loop: prompt for each input field step by step.

        NOTE(review): several quoted example characters in the prompt strings
        (e.g. 已确认'') appear to have been stripped by a text-extraction
        step — presumably the confirmed character was '上'. Confirm against
        the original source; the strings are kept byte-identical here.
        """
        print("\n" + "=" * 60)
        print("输入法模型推理 - 交互模式")
        print("=" * 60)
        print("说明:")
        print(" - 上下文提示: 模型不掌握的专有词汇、姓名等(可为空)")
        print(" - 光标前文本: 光标前的连续文本")
        print(" - 光标后文本: 光标后的连续文本")
        print(" - 拼音: 当前输入的拼音")
        print(" - 槽位历史: 用户已确认的输入历史,如输入'shanghai'已确认''")
        print("提示: 输入 'quit''exit''q' 可随时退出")
        print("-" * 60)
        while True:
            try:
                # Step 1: context prompts (comma-separated, may be empty).
                print("\n" + "=" * 60)
                print("第1步: 上下文提示(模型不掌握的专有词汇、姓名等)")
                print("格式: 用逗号分隔多个词汇,可为空")
                print("示例: 张三,李四,北京大学")
                context_input = input("请输入上下文提示(直接回车跳过): ").strip()
                if context_input.lower() in ["quit", "exit", "q"]:
                    print("退出交互模式")
                    break
                # Parse the comma-separated context prompts.
                context_prompts = []
                if context_input:
                    context_prompts = [
                        item.strip()
                        for item in context_input.split(",")
                        if item.strip()
                    ]
                print(
                    f"✅ 已记录上下文提示: {context_prompts if context_prompts else ''}"
                )
                # Step 2: text before the cursor.
                print("\n" + "-" * 40)
                print("第2步: 光标前文本")
                print("说明: 光标前的连续文本内容")
                print("示例: 今天天气很好")
                text_before = input("请输入光标前文本: ").strip()
                if text_before.lower() in ["quit", "exit", "q"]:
                    print("退出交互模式")
                    break
                print(f"✅ 已记录光标前文本: '{text_before}'")
                # Step 3: text after the cursor.
                print("\n" + "-" * 40)
                print("第3步: 光标后文本")
                print("说明: 光标后的连续文本内容")
                print("示例: 我们去公园玩")
                text_after = input("请输入光标后文本: ").strip()
                if text_after.lower() in ["quit", "exit", "q"]:
                    print("退出交互模式")
                    break
                print(f"✅ 已记录光标后文本: '{text_after}'")
                # Step 4: the pinyin currently being typed.
                print("\n" + "-" * 40)
                print("第4步: 拼音输入")
                print("说明: 当前正在输入的拼音")
                print("示例: tian, shang, hao")
                pinyin = input("请输入拼音: ").strip()
                if pinyin.lower() in ["quit", "exit", "q"]:
                    print("退出交互模式")
                    break
                print(f"✅ 已记录拼音: '{pinyin}'")
                # Step 5: history slots (characters the user already confirmed).
                print("\n" + "-" * 40)
                print("第5步: 槽位历史(已确认的输入)")
                print("说明: 用户已确认的输入历史,用逗号分隔")
                print("示例: 上 (表示输入'shanghai'已确认''")
                print(" 今天,天气 (表示已确认两个词)")
                slot_input = input("请输入槽位历史(直接回车表示无): ").strip()
                if slot_input.lower() in ["quit", "exit", "q"]:
                    print("退出交互模式")
                    break
                # Parse the comma-separated slot history.
                slot_chars = []
                if slot_input:
                    slot_chars = [
                        char.strip() for char in slot_input.split(",") if char.strip()
                    ]
                print(f"✅ 已记录槽位历史: {slot_chars if slot_chars else ''}")
                # Summary of everything collected.
                print("\n" + "=" * 60)
                print("📝 输入汇总:")
                print(f" 上下文提示: {context_prompts if context_prompts else ''}")
                print(f" 光标前文本: '{text_before}'")
                print(f" 光标后文本: '{text_after}'")
                print(f" 拼音: '{pinyin}'")
                print(f" 槽位历史: {slot_chars if slot_chars else ''}")
                # Run inference.
                print("\n🔮 推理中...")
                predictions, inference_time = self.predict(
                    context_prompts, text_before, text_after, pinyin, slot_chars
                )
                # Show the top-k results.
                print(f"\n✅ 推理完成 (耗时: {inference_time:.2f}ms)")
                print("\n🏆 Top-20 预测结果:")
                print("-" * 50)
                for i, (char, prob, idx) in enumerate(predictions):
                    if char == "//":
                        print(f"{i + 1:2d}. {'//':<4} (结束符) - 概率: {prob:.4f}")
                    else:
                        print(
                            f"{i + 1:2d}. {char:<4} (ID: {idx:>5}) - 概率: {prob:.4f}"
                        )
                # Also show common characters for the raw pinyin, if available.
                if pinyin and self.query_engine:
                    print(f"\n📖 拼音 '{pinyin}' 的常见汉字:")
                    pinyin_results = self.query_engine.query_by_pinyin(pinyin, limit=10)
                    if pinyin_results:
                        for j, (pid, char, count) in enumerate(pinyin_results):
                            print(f" {char} (频次: {count:,})")
                    else:
                        print(" (无匹配结果)")
                # Ask whether to run another round.
                print("\n" + "-" * 40)
                continue_input = input("是否继续推理?(y/n): ").strip().lower()
                if continue_input not in ["y", "yes", ""]:
                    print("退出交互模式")
                    break
            except KeyboardInterrupt:
                print("\n\n退出交互模式")
                break
            except Exception as e:
                # Report the error with a traceback, then let the user decide
                # whether to keep going.
                print(f"\n❌ 推理出错: {e}")
                import traceback
                traceback.print_exc()
                continue_input = input("\n是否继续?(y/n): ").strip().lower()
                if continue_input not in ["y", "yes", ""]:
                    print("退出交互模式")
                    break
def main():
    """CLI entry point: load the model, optionally run a smoke test, then
    start the interactive session."""
    parser = argparse.ArgumentParser(description="输入法模型推理")
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="模型checkpoint路径"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        choices=["auto", "cpu", "cuda"],
        help="推理设备 (默认: auto)",
    )
    # BUG FIX: `action="store_true"` combined with default=True made the flag
    # a no-op (interactive mode could never be disabled). BooleanOptionalAction
    # keeps `--interactive` working and adds `--no-interactive` to turn it off.
    parser.add_argument(
        "--interactive",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="交互模式 (默认: True)",
    )
    parser.add_argument("--test", action="store_true", help="运行测试推理")
    args = parser.parse_args()
    # Resolve the inference device.
    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device
    # Build the inference runner.
    inference = InputMethodInference(args.checkpoint, device)
    # Optional one-shot smoke-test inference.
    if args.test:
        print("\n🧪 运行测试推理...")
        # FIX: restored the stripped character '上' — grounded by the original
        # inline comment on slot_chars: 用户已确认输入"上".
        print("测试场景: 输入'shanghai',已确认第一个字'上',继续输入'tian'")
        print("上下文提示: 张三、李四(模型不掌握的专有名词)")
        test_predictions, test_time = inference.predict(
            context_prompts=["张三", "李四"],
            text_before="今天天气",
            text_after="很好",
            pinyin="tian",
            slot_chars=["上"],  # the user already confirmed "上"
        )
        print(f"测试推理耗时: {test_time:.2f}ms")
        print(f"Top-5 结果:")
        for i, (char, prob, idx) in enumerate(test_predictions[:5]):
            if char == "//":
                print(f" {i + 1}. // (结束符) - 概率: {prob:.4f}")
            else:
                print(f" {i + 1}. {char} (ID: {idx}) - 概率: {prob:.4f}")
    # Interactive step-by-step session.
    if args.interactive:
        inference.interactive_mode()


if __name__ == "__main__":
    main()