#!/usr/bin/env python3
"""Input-method model inference debugging script.

Used to diagnose why the model's predictions are abnormal: inspects a
checkpoint's weights, the input/tokenizer pipeline, the statistics-based
query engine, and a minimal model forward pass.
"""
from pathlib import Path

import torch
import torch.nn.functional as F


def debug_model_checkpoint(checkpoint_path: str):
    """Inspect a checkpoint file and return its state dict, or None on failure.

    Prints the checkpoint structure, the first ten weight names/shapes, and
    the classifier-head weight ranges (a common source of bad predictions).

    Args:
        checkpoint_path: Path to a ``torch.save``-produced checkpoint.

    Returns:
        The extracted state dict, or None if the file is missing or the
        checkpoint is not a dict.
    """
    print(f"\n🔍 调试checkpoint: {checkpoint_path}")

    if not Path(checkpoint_path).exists():
        print(f"❌ Checkpoint文件不存在: {checkpoint_path}")
        return None

    # NOTE(security): torch.load unpickles arbitrary objects — only use this
    # on checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    print(f"Checkpoint类型: {type(checkpoint)}")

    if not isinstance(checkpoint, dict):
        print(f"❌ Checkpoint不是字典类型: {type(checkpoint)}")
        return None

    print(f"Checkpoint键名: {list(checkpoint.keys())}")

    # Locate the actual state dict: trainers commonly nest it under
    # "model_state_dict" or "state_dict"; otherwise assume the checkpoint
    # itself is the (flat) state dict.
    if "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
        print(f"使用'model_state_dict'键,包含{len(state_dict)}个权重")
    elif "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
        print(f"使用'state_dict'键,包含{len(state_dict)}个权重")
    else:
        state_dict = checkpoint
        print(f"使用直接状态字典,包含{len(state_dict)}个权重")

    # Show the first 10 weights to sanity-check names and shapes.
    print("\n前10个权重键名和形状:")
    for i, (key, value) in enumerate(list(state_dict.items())[:10]):
        print(f"  {i + 1}. {key}: {value.shape}")

    # Report classifier-head shapes and value ranges explicitly; degenerate
    # ranges here usually explain abnormal predictions.
    classifier_keys = [k for k in state_dict if "classifier" in k]
    if classifier_keys:
        print("\n分类头相关权重:")
        for key in classifier_keys:
            weight = state_dict[key]
            print(
                f"  {key}: shape={weight.shape}, range=[{weight.min():.4f}, {weight.max():.4f}]"
            )
    else:
        print("\n⚠️ 未找到分类头权重")

    return state_dict


def debug_input_preparation():
    """Exercise pinyin-ID conversion and the tokenizer with sample inputs.

    Each sub-check is wrapped in try/except so a missing optional dependency
    (e.g. modelscope) is reported instead of aborting the whole debug run.
    """
    print("\n🔍 调试输入数据准备")

    # Test the text_to_pinyin_ids helper.
    try:
        from src.model.dataset import text_to_pinyin_ids

        test_pinyins = ["tian", "shang", "ha", "ni", "hao"]
        for pinyin in test_pinyins:
            ids = text_to_pinyin_ids(pinyin)
            print(f"拼音 '{pinyin}' -> ID列表: {ids}")
    except Exception as e:
        print(f"❌ text_to_pinyin_ids调试失败: {e}")

    # Test the tokenizer (import inside the try so a missing modelscope
    # package is reported like any other tokenizer failure).
    try:
        from modelscope import AutoTokenizer

        tokenizer_path = (
            Path(__file__).parent / "src" / "model" / "assets" / "tokenizer"
        )
        tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path))
        print(f"\n✅ Tokenizer加载成功,词汇表大小: {tokenizer.vocab_size}")

        test_texts = ["今天天气", "我们去公园", "张三|李四|今天天气"]
        for text in test_texts:
            encoded = tokenizer(
                text, max_length=128, truncation=True, return_tensors="pt"
            )
            print(f"文本 '{text}' -> input_ids形状: {encoded['input_ids'].shape}")
    except Exception as e:
        print(f"❌ Tokenizer加载失败: {e}")


def debug_query_engine():
    """Load the pinyin/char statistics file and run sample queries."""
    print("\n🔍 调试查询引擎")

    try:
        from src.model.query import QueryEngine

        query_engine = QueryEngine()
        stats_path = (
            Path(__file__).parent
            / "src"
            / "model"
            / "assets"
            / "pinyin_char_statistics.json"
        )
        if stats_path.exists():
            query_engine.load(stats_path)
            print("✅ 查询引擎加载成功")

            # Sample character lookups.
            test_chars = ["天", "上", "人", "好", "的"]
            for char in test_chars:
                results = query_engine.query_by_char(char, limit=3)
                if results:
                    print(f"字符 '{char}': {[(r[0], r[1], r[2]) for r in results[:3]]}")
                else:
                    print(f"字符 '{char}': 无结果")

            # Sample pinyin lookups.
            test_pinyins = ["tian", "shang", "ren", "hao", "de"]
            for pinyin in test_pinyins:
                results = query_engine.query_by_pinyin(pinyin, limit=3)
                if results:
                    print(
                        f"拼音 '{pinyin}': {[(r[0], r[1], r[2]) for r in results[:3]]}"
                    )
                else:
                    print(f"拼音 '{pinyin}': 无结果")
        else:
            print(f"❌ 统计文件不存在: {stats_path}")
    except Exception as e:
        print(f"❌ 查询引擎调试失败: {e}")
        import traceback

        traceback.print_exc()


def create_minimal_test_model():
    """Build a tiny InputMethodEngine and run one forward pass as a smoke test.

    Returns:
        The model on success, None on any failure (reported with traceback).
    """
    print("\n🔍 创建最小测试模型")

    try:
        from src.model.model import InputMethodEngine

        # Deliberately tiny hyperparameters so the smoke test is fast on CPU.
        model = InputMethodEngine(
            vocab_size=100,
            pinyin_vocab_size=30,
            dim=64,  # small width for a quick test
            num_slots=8,
            n_layers=2,
            n_heads=2,
            num_experts=4,
            max_seq_len=64,
            compile=False,
        )
        print("✅ 最小模型创建成功")
        print(f"  总参数量: {sum(p.numel() for p in model.parameters()):,}")
        print(f"  分类头形状: {model.classifier.weight.shape}")

        # One forward pass with random inputs to verify output shape and
        # that logits/probabilities are in a sane range.
        batch_size = 2
        input_ids = torch.randint(0, 100, (batch_size, 64))
        token_type_ids = torch.zeros((batch_size, 64), dtype=torch.long)
        attention_mask = torch.ones((batch_size, 64), dtype=torch.long)
        pinyin_ids = torch.randint(0, 30, (batch_size, 24))
        history_slot_ids = torch.randint(0, 100, (batch_size, 8))

        with torch.no_grad():
            logits = model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                pinyin_ids=pinyin_ids,
                history_slot_ids=history_slot_ids,
            )

        print(f"  前向传播成功,logits形状: {logits.shape}")
        print(f"  logits范围: [{logits.min():.4f}, {logits.max():.4f}]")

        probs = F.softmax(logits, dim=-1)
        print(f"  概率范围: [{probs.min():.4f}, {probs.max():.4f}]")

        return model
    except Exception as e:
        print(f"❌ 最小模型测试失败: {e}")
        import traceback

        traceback.print_exc()
        return None


def main():
    """Parse CLI arguments and run each debug stage in turn."""
    import argparse

    parser = argparse.ArgumentParser(description="输入法模型推理调试")
    parser.add_argument("--checkpoint", type=str, help="模型checkpoint路径")
    args = parser.parse_args()

    print("=" * 60)
    print("输入法模型推理调试工具")
    print("=" * 60)

    # Each stage prints its own diagnostics; return values are not needed here.
    if args.checkpoint:
        debug_model_checkpoint(args.checkpoint)
    else:
        print("\n⚠️ 未提供checkpoint路径,跳过checkpoint调试")

    debug_input_preparation()
    debug_query_engine()
    create_minimal_test_model()

    print("\n" + "=" * 60)
    print("调试完成")
    print("=" * 60)


if __name__ == "__main__":
    main()