#!/usr/bin/env python3 """ 输入法模型推理脚本 使用方法: python inference.py --checkpoint ./output/checkpoints/best_model.pt 交互模式: 分步询问输入 1. 上下文提示: 模型不掌握的专有词汇、姓名等(可为空) 2. 光标前文本: 光标前的连续文本 3. 光标后文本: 光标后的连续文本 4. 拼音: 当前输入的拼音 5. 槽位历史: 用户已确认的输入历史(如输入shanghai已确认"上") 示例场景: 输入"shanghai"已确认"上",继续输入"tian" 上下文提示: 张三,李四 光标前文本: 今天天气很好 光标后文本: 我们去公园玩 拼音: tian 槽位历史: 上 """ import argparse import time from pathlib import Path from typing import List, Optional, Tuple import torch import torch.nn.functional as F from modelscope import AutoTokenizer from src.model.dataset import text_to_pinyin_ids from src.model.model import InputMethodEngine from src.model.query import QueryEngine class InputMethodInference: """输入法模型推理器""" def __init__( self, checkpoint_path: str, device: str = "cuda" if torch.cuda.is_available() else "cpu", ): self.device = torch.device(device) self.checkpoint_path = checkpoint_path # 加载组件 print(f"正在加载模型从: {checkpoint_path}") self.load_model() # 加载tokenizer print("正在加载tokenizer...") self.load_tokenizer() # 加载查询引擎 print("正在加载查询引擎...") self.load_query_engine() print(f"✅ 推理器初始化完成 (设备: {self.device})") def load_model(self): """加载训练好的模型""" # 创建模型实例(不编译) self.model = InputMethodEngine(pinyin_vocab_size=30, compile=False) # 加载checkpoint # 加载训练好的权重(强制先加载到CPU,再移动到目标设备) # 这样确保GPU训练的权重能正确转换到CPU checkpoint = torch.load(self.checkpoint_path, map_location="cpu") if "model_state_dict" in checkpoint: self.model.load_state_dict(checkpoint["model_state_dict"]) else: self.model.load_state_dict(checkpoint) self.model.eval() self.model.to(self.device) print( f"✅ 模型加载完成,参数量: {sum(p.numel() for p in self.model.parameters()):,}" ) def load_tokenizer(self): """加载tokenizer""" try: # 从assets/tokenizer加载tokenizer tokenizer_path = ( Path(__file__).parent / "src" / "model" / "assets" / "tokenizer" ) self.tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path)) print(f"✅ Tokenizer加载完成,词汇表大小: {self.tokenizer.vocab_size}") except Exception as e: print(f"⚠️ 无法加载tokenizer: {e}") print("使用默认的bert-base-chinese tokenizer") self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese") def load_query_engine(self): """加载查询引擎用于字符-ID转换""" try: self.query_engine = QueryEngine() # 加载默认的统计文件 stats_path = ( Path(__file__).parent / "src" / "model" / "assets" / "pinyin_char_statistics.json" ) if stats_path.exists(): self.query_engine.load(stats_path) print( f"✅ 查询引擎加载完成,字符对数量: {len(self.query_engine._id_to_info)}" ) else: print(f"⚠️ 统计文件不存在: {stats_path}") self.query_engine = None except Exception as e: print(f"⚠️ 无法加载查询引擎: {e}") self.query_engine = None def char_to_id(self, char: str, pinyin: Optional[str] = None) -> int: """将汉字转换为ID,如果提供拼音则更精确""" # 处理结束符 if char == "//": return 0 # 假设0是结束符ID if self.query_engine is None: # 简单回退:使用unicode编码 return ord(char) if len(char) == 1 else 0 try: if pinyin is not None: # 使用精确的字符-拼音对 info = self.query_engine.get_char_info_by_char_pinyin(char, pinyin) if info: return info.id # 回退:获取字符的第一个拼音变体 results = self.query_engine.query_by_char(char, limit=1) if results: return results[0][0] # 返回ID return 0 except: return 0 def id_to_char(self, id: int) -> str: """将ID转换为汉字""" # 处理结束符ID (假设0是结束符) if id == 0: return "//" if self.query_engine is None: return chr(id) if id < 0x110000 else "" try: info = self.query_engine.query_by_id(id) return info.char if info else f"[ID:{id}]" except: return f"[ID:{id}]" def prepare_inputs( self, context_prompts: List[str], text_before: str, text_after: str, pinyin: str, slot_chars: List[str], max_seq_len: int = 128, ): """ 准备模型输入 Args: context_prompts: 上下文提示(专有词汇、姓名等,用|分隔) text_before: 光标前文本 text_after: 光标后文本 pinyin: 当前输入的拼音 slot_chars: 槽位内的汉字列表(用户已确认的输入历史) max_seq_len: 最大序列长度 Returns: 模型输入字典 """ # 1. 构建tokenizer输入 # 根据dataset.py,格式为: "part4|part1" 和 part3 # part4: 上下文提示(专有词汇、姓名等,模型不掌握) # part1: text_before # part3: text_after # 处理上下文提示 context_text = "|".join(context_prompts) if context_prompts else "" # 构建输入文本 if context_text: input_text = f"{context_text}|{text_before}" else: input_text = text_before # 2. Tokenize encoded = self.tokenizer( input_text, text_after, max_length=max_seq_len, padding="max_length", truncation=True, return_tensors="pt", return_token_type_ids=True, ) # 3. 处理拼音输入 pinyin_ids = text_to_pinyin_ids(pinyin) if len(pinyin_ids) < 24: pinyin_ids.extend([0] * (24 - len(pinyin_ids))) else: pinyin_ids = pinyin_ids[:24] pinyin_tensor = torch.tensor([pinyin_ids], dtype=torch.long) # 4. 处理历史槽位(用户已确认的输入历史) history_slot_ids = [] for char in slot_chars: # 为每个槽位汉字查找ID(用户已确认的输入历史) char_id = self.char_to_id(char) history_slot_ids.append(char_id) # 填充到8个槽位 if len(history_slot_ids) < 8: history_slot_ids.extend([0] * (8 - len(history_slot_ids))) else: history_slot_ids = history_slot_ids[:8] history_tensor = torch.tensor([history_slot_ids], dtype=torch.long) # 5. 移动到设备 inputs = { "input_ids": encoded["input_ids"].to(self.device), "token_type_ids": encoded["token_type_ids"].to(self.device), "attention_mask": encoded["attention_mask"].to(self.device), "pinyin_ids": pinyin_tensor.to(self.device), "history_slot_ids": history_tensor.to(self.device), } return inputs def predict( self, context_prompts: List[str], text_before: str, text_after: str, pinyin: str, slot_chars: List[str], top_k: int = 20, ) -> Tuple[List[Tuple[str, float, int]], float]: """ 执行推理 Args: context_prompts: 上下文提示(专有词汇、姓名等,用|分隔) text_before: 光标前文本 text_after: 光标后文本 pinyin: 当前输入的拼音 slot_chars: 槽位内的汉字列表(用户已确认的输入历史,最大8个) top_k: 返回top-k个预测结果 Returns: (predictions, inference_time_ms) predictions: List[Tuple[char, score, id]] """ start_time = time.perf_counter() # 准备输入 inputs = self.prepare_inputs( context_prompts, text_before, text_after, pinyin, slot_chars ) # 调试信息:打印输入形状 print("\n🔍 调试信息 - 输入检查:") for key, tensor in inputs.items(): print( f" {key}: shape={tensor.shape}, dtype={tensor.dtype}, device={tensor.device}" ) if key in ["history_slot_ids", "pinyin_ids"]: print(f" 值: {tensor.cpu().numpy().tolist()}") # 特别检查拼音和槽位输入 print("\n🔍 拼音输入详细分析:") pinyin_tensor = inputs["pinyin_ids"] pinyin_values = pinyin_tensor.cpu().numpy()[0] # 将拼音ID转换回字符 from src.model.dataset import text_to_pinyin_ids # 逆向转换可能需要一个反向映射,这里先简单显示 print(f" 拼音ID列表: {pinyin_values}") print(f" 拼音非零ID: {[id for id in pinyin_values if id != 0]}") print("\n🔍 槽位历史详细分析:") slot_tensor = inputs["history_slot_ids"] slot_values = slot_tensor.cpu().numpy()[0] print(f" 槽位ID列表: {slot_values}") print(f" 槽位非零ID: {[id for id in slot_values if id != 0]}") # 将ID转换为汉字 slot_chars_converted = [self.id_to_char(id) for id in slot_values] print(f" 槽位汉字: {slot_chars_converted}") # 调试:检查模型权重 print(f"\n🔍 调试信息 - 模型检查:") print(f" 模型设备: {self.device}") print(f" 模型是否训练模式: {self.model.training}") # 检查模型词汇表大小 vocab_size = self.model.vocab_size print(f" 模型词汇表大小: {vocab_size}") # 检查前5个ID对应的字符 print(f" 前5个ID对应的字符:") for i in range(1, 6): char = self.id_to_char(i) print(f" ID {i}: '{char}'") # 检查分类头权重 # 调试:检查分类头权重 classifier_weight = self.model.classifier.weight.data classifier_bias = self.model.classifier.bias.data print(f" 分类头权重形状: {classifier_weight.shape}") print( f" 分类头权重范围: [{classifier_weight.min():.4f}, {classifier_weight.max():.4f}]" ) print(f" 分类头权重均值: {classifier_weight.mean():.4f}") print(f" 分类头权重标准差: {classifier_weight.std():.4f}") print(f" 分类头偏置形状: {classifier_bias.shape}") print( f" 分类头偏置范围: [{classifier_bias.min():.4f}, {classifier_bias.max():.4f}]" ) print(f" 分类头偏置均值: {classifier_bias.mean():.4f}") print(f" 分类头偏置标准差: {classifier_bias.std():.4f}") # 推理(CPU推理时禁用混合精度) with torch.no_grad(): if self.device.type == "cuda": with torch.autocast(device_type="cuda"): logits = self.model(**inputs) else: # CPU推理时不使用autocast logits = self.model(**inputs) # 调试:检查logits print(f"\n🔍 调试信息 - 输出检查:") print(f" Logits形状: {logits.shape}") print(f" Logits范围: [{logits.min():.4f}, {logits.max():.4f}]") print(f" Logits均值: {logits.mean():.4f}") # 检查logits中最大值和对应的ID max_val, max_idx = torch.max(logits, dim=-1) print(f" 最大logit值: {max_val.item():.4f}, 对应ID: {max_idx.item()}") # 获取top-k预测 probs = F.softmax(logits, dim=-1) top_probs, top_indices = torch.topk(probs, k=top_k, dim=-1) # 调试:检查概率分布 print(f" 概率总和: {probs.sum().item():.4f}") top_probs_array = top_probs.cpu().numpy().flatten() top_indices_array = top_indices.cpu().numpy().flatten() print(f" Top-{top_k}概率: {top_probs_array}") print(f" Top-{top_k} ID: {top_indices_array}") # 检查概率是否均匀分布 print(f" 概率分布分析:") print(f" 平均概率: {probs.mean().item():.6f}") print(f" 最大概率: {probs.max().item():.6f}") print(f" 最小概率: {probs.min().item():.6f}") print(f" 标准差: {probs.std().item():.6f}") # 检查top-20概率是否都很小 if top_probs_array[0] < 0.01: print( f" ⚠️ 警告: 最高概率 ({top_probs_array[0]:.6f}) 小于0.01,模型可能未正确训练" ) print(f" 💡 可能原因: 1) 权重未正确加载 2) 输入格式错误 3) 模型配置不匹配") inference_time_ms = (time.perf_counter() - start_time) * 1000 # 转换为可读结果 predictions = [] for i in range(top_k): idx = int(top_indices[0, i].item()) prob = top_probs[0, i].item() char = self.id_to_char(idx) predictions.append((char, prob, idx)) return predictions, inference_time_ms def interactive_mode(self): """交互式推理模式 - 分步询问输入""" print("\n" + "=" * 60) print("输入法模型推理 - 交互模式") print("=" * 60) print("说明:") print(" - 上下文提示: 模型不掌握的专有词汇、姓名等(可为空)") print(" - 光标前文本: 光标前的连续文本") print(" - 光标后文本: 光标后的连续文本") print(" - 拼音: 当前输入的拼音") print(" - 槽位历史: 用户已确认的输入历史,如输入'shanghai'已确认'上'") print("提示: 输入 'quit' 或 'exit' 或 'q' 可随时退出") print("-" * 60) while True: try: print("\n" + "=" * 60) print("第1步: 上下文提示(模型不掌握的专有词汇、姓名等)") print("格式: 用逗号分隔多个词汇,可为空") print("示例: 张三,李四,北京大学") context_input = input("请输入上下文提示(直接回车跳过): ").strip() if context_input.lower() in ["quit", "exit", "q"]: print("退出交互模式") break # 解析上下文提示 context_prompts = [] if context_input: context_prompts = [ item.strip() for item in context_input.split(",") if item.strip() ] print( f"✅ 已记录上下文提示: {context_prompts if context_prompts else '无'}" ) print("\n" + "-" * 40) print("第2步: 光标前文本") print("说明: 光标前的连续文本内容") print("示例: 今天天气很好") text_before = input("请输入光标前文本: ").strip() if text_before.lower() in ["quit", "exit", "q"]: print("退出交互模式") break print(f"✅ 已记录光标前文本: '{text_before}'") print("\n" + "-" * 40) print("第3步: 光标后文本") print("说明: 光标后的连续文本内容") print("示例: 我们去公园玩") text_after = input("请输入光标后文本: ").strip() if text_after.lower() in ["quit", "exit", "q"]: print("退出交互模式") break print(f"✅ 已记录光标后文本: '{text_after}'") print("\n" + "-" * 40) print("第4步: 拼音输入") print("说明: 当前正在输入的拼音") print("示例: tian, shang, hao") pinyin = input("请输入拼音: ").strip() if pinyin.lower() in ["quit", "exit", "q"]: print("退出交互模式") break print(f"✅ 已记录拼音: '{pinyin}'") print("\n" + "-" * 40) print("第5步: 槽位历史(已确认的输入)") print("说明: 用户已确认的输入历史,用逗号分隔") print("示例: 上 (表示输入'shanghai'已确认'上')") print(" 今天,天气 (表示已确认两个词)") slot_input = input("请输入槽位历史(直接回车表示无): ").strip() if slot_input.lower() in ["quit", "exit", "q"]: print("退出交互模式") break # 解析槽位历史 slot_chars = [] if slot_input: slot_chars = [ char.strip() for char in slot_input.split(",") if char.strip() ] print(f"✅ 已记录槽位历史: {slot_chars if slot_chars else '无'}") print("\n" + "=" * 60) print("📝 输入汇总:") print(f" 上下文提示: {context_prompts if context_prompts else '无'}") print(f" 光标前文本: '{text_before}'") print(f" 光标后文本: '{text_after}'") print(f" 拼音: '{pinyin}'") print(f" 槽位历史: {slot_chars if slot_chars else '无'}") # 执行推理 print("\n🔮 推理中...") predictions, inference_time = self.predict( context_prompts, text_before, text_after, pinyin, slot_chars ) # 显示结果 print(f"\n✅ 推理完成 (耗时: {inference_time:.2f}ms)") print("\n🏆 Top-20 预测结果:") print("-" * 50) for i, (char, prob, idx) in enumerate(predictions): if char == "//": print(f"{i + 1:2d}. {'//':<4} (结束符) - 概率: {prob:.4f}") else: print( f"{i + 1:2d}. {char:<4} (ID: {idx:>5}) - 概率: {prob:.4f}" ) # 显示原始拼音对应的可能汉字 if pinyin and self.query_engine: print(f"\n📖 拼音 '{pinyin}' 的常见汉字:") pinyin_results = self.query_engine.query_by_pinyin(pinyin, limit=10) if pinyin_results: for j, (pid, char, count) in enumerate(pinyin_results): print(f" {char} (频次: {count:,})") else: print(" (无匹配结果)") # 询问是否继续 print("\n" + "-" * 40) continue_input = input("是否继续推理?(y/n): ").strip().lower() if continue_input not in ["y", "yes", ""]: print("退出交互模式") break except KeyboardInterrupt: print("\n\n退出交互模式") break except Exception as e: print(f"\n❌ 推理出错: {e}") import traceback traceback.print_exc() # 询问是否继续 continue_input = input("\n是否继续?(y/n): ").strip().lower() if continue_input not in ["y", "yes", ""]: print("退出交互模式") break def main(): parser = argparse.ArgumentParser(description="输入法模型推理") parser.add_argument( "--checkpoint", type=str, required=True, help="模型checkpoint路径" ) parser.add_argument( "--device", type=str, default="auto", choices=["auto", "cpu", "cuda"], help="推理设备 (默认: auto)", ) parser.add_argument( "--interactive", action="store_true", default=True, help="交互模式 (默认: True)" ) parser.add_argument("--test", action="store_true", help="运行测试推理") args = parser.parse_args() # 选择设备 if args.device == "auto": device = "cuda" if torch.cuda.is_available() else "cpu" else: device = args.device # 初始化推理器 inference = InputMethodInference(args.checkpoint, device) # 测试推理 if args.test: print("\n🧪 运行测试推理...") print("测试场景: 输入'shanghai',已确认第一个字'上',继续输入'tian'") print("上下文提示: 张三、李四(模型不掌握的专有名词)") test_predictions, test_time = inference.predict( context_prompts=["张三", "李四"], text_before="今天天气", text_after="很好", pinyin="tian", slot_chars=["上"], # 用户已确认输入"上" ) print(f"测试推理耗时: {test_time:.2f}ms") print(f"Top-5 结果:") for i, (char, prob, idx) in enumerate(test_predictions[:5]): if char == "//": print(f" {i + 1}. // (结束符) - 概率: {prob:.4f}") else: print(f" {i + 1}. {char} (ID: {idx}) - 概率: {prob:.4f}") # 交互模式 if args.interactive: inference.interactive_mode() if __name__ == "__main__": main()