# SUimeModelTraner/debug_inference.py
# (web-viewer residue removed: line/size counters, "Raw Permalink Blame
# History" toolbar text, and the ambiguous-Unicode warning banner were not
# part of the source file and prevented it from parsing.)
#!/usr/bin/env python3
"""
输入法模型推理调试脚本
用于诊断为什么模型预测结果异常
"""
from pathlib import Path
import torch
import torch.nn.functional as F
def debug_model_checkpoint(checkpoint_path: str):
    """Inspect a checkpoint file and return its state dict, or None on failure.

    Prints the checkpoint's top-level keys, the first 10 weight names/shapes,
    and the value range of any classifier-head weights (a degenerate range
    there is a common cause of broken predictions).

    Args:
        checkpoint_path: Path to a file produced by ``torch.save``.

    Returns:
        The state dict (mapping of parameter name -> tensor) when the file
        loads as a dict, otherwise ``None``.
    """
    print(f"\n🔍 调试checkpoint: {checkpoint_path}")
    if not Path(checkpoint_path).exists():
        print(f"❌ Checkpoint文件不存在: {checkpoint_path}")
        return None
    # NOTE(security): torch.load is pickle-based -- only load checkpoints
    # from trusted sources.
    # Fix: a corrupt or incompatible file used to raise and abort the whole
    # debug run; report the failure and continue instead.
    try:
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
    except Exception as e:
        print(f"❌ Checkpoint加载失败: {e}")
        return None
    print(f"Checkpoint类型: {type(checkpoint)}")
    if not isinstance(checkpoint, dict):
        print(f"❌ Checkpoint不是字典类型: {type(checkpoint)}")
        return None
    print(f"Checkpoint键名: {list(checkpoint.keys())}")
    # Training checkpoints usually nest the weights under one of these keys;
    # otherwise assume the dict itself IS the state dict.
    if "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
        print(f"使用'model_state_dict'键,包含{len(state_dict)}个权重")
    elif "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
        print(f"使用'state_dict'键,包含{len(state_dict)}个权重")
    else:
        state_dict = checkpoint
        print(f"使用直接状态字典,包含{len(state_dict)}个权重")
    # Show the first 10 parameter names and shapes.
    print("\n前10个权重键名和形状:")
    for i, (key, value) in enumerate(list(state_dict.items())[:10]):
        print(f" {i + 1}. {key}: {value.shape}")
    # Classifier head: the min/max range helps spot an uninitialized or
    # untrained output layer.
    classifier_keys = [k for k in state_dict.keys() if "classifier" in k]
    if classifier_keys:
        print(f"\n分类头相关权重:")
        for key in classifier_keys:
            weight = state_dict[key]
            print(
                f" {key}: shape={weight.shape}, range=[{weight.min():.4f}, {weight.max():.4f}]"
            )
    else:
        print(f"\n⚠️ 未找到分类头权重")
    return state_dict
def debug_input_preparation():
    """Exercise pinyin-ID conversion and the tokenizer on sample inputs.

    Prints the ID list produced for a few pinyin syllables and the encoded
    ``input_ids`` shape for a few sample texts.

    Fix: the ``src.model.dataset`` and ``modelscope`` imports used to sit
    outside any try/except, so a missing dependency aborted the entire debug
    run; each stage is now guarded and failures are reported instead.
    """
    print("\n🔍 调试输入数据准备")
    # Stage 1: project-local pinyin -> ID conversion.
    try:
        from src.model.dataset import text_to_pinyin_ids

        test_pinyins = ["tian", "shang", "ha", "ni", "hao"]
        for pinyin in test_pinyins:
            ids = text_to_pinyin_ids(pinyin)
            print(f"拼音 '{pinyin}' -> ID列表: {ids}")
    except Exception as e:
        print(f"❌ text_to_pinyin_ids调试失败: {e}")
    # Stage 2: tokenizer load + encode round-trip.
    try:
        from modelscope import AutoTokenizer

        tokenizer_path = (
            Path(__file__).parent / "src" / "model" / "assets" / "tokenizer"
        )
        tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path))
        print(f"\n✅ Tokenizer加载成功词汇表大小: {tokenizer.vocab_size}")
        test_texts = ["今天天气", "我们去公园", "张三|李四|今天天气"]
        for text in test_texts:
            encoded = tokenizer(
                text, max_length=128, truncation=True, return_tensors="pt"
            )
            print(f"文本 '{text}' -> input_ids形状: {encoded['input_ids'].shape}")
    except Exception as e:
        print(f"❌ Tokenizer加载失败: {e}")
def debug_query_engine():
    """Load the statistics-backed query engine and run sample lookups.

    Queries a handful of characters and pinyin syllables, printing the top
    results (or a no-result marker) for each. Any failure is caught,
    reported, and traced rather than propagated.
    """
    print("\n🔍 调试查询引擎")
    try:
        from src.model.query import QueryEngine

        engine = QueryEngine()
        stats_path = (
            Path(__file__).parent
            / "src"
            / "model"
            / "assets"
            / "pinyin_char_statistics.json"
        )
        if not stats_path.exists():
            print(f"❌ 统计文件不存在: {stats_path}")
            return
        engine.load(stats_path)
        print(f"✅ 查询引擎加载成功")
        # NOTE(review): these entries are empty strings — the intended test
        # characters appear to have been lost in an encoding/extraction step;
        # verify the originals against version control.
        test_chars = ["", "", "", "", ""]
        for ch in test_chars:
            hits = engine.query_by_char(ch, limit=3)
            if hits:
                print(f"字符 '{ch}': {[(r[0], r[1], r[2]) for r in hits[:3]]}")
            else:
                print(f"字符 '{ch}': 无结果")
        # Same exercise, keyed by pinyin syllable.
        test_pinyins = ["tian", "shang", "ren", "hao", "de"]
        for syllable in test_pinyins:
            hits = engine.query_by_pinyin(syllable, limit=3)
            if hits:
                print(
                    f"拼音 '{syllable}': {[(r[0], r[1], r[2]) for r in hits[:3]]}"
                )
            else:
                print(f"拼音 '{syllable}': 无结果")
    except Exception as e:
        print(f"❌ 查询引擎调试失败: {e}")
        import traceback

        traceback.print_exc()
def create_minimal_test_model():
    """Build a tiny InputMethodEngine and smoke-test one forward pass.

    Returns:
        The model instance on success, or ``None`` if construction or the
        forward pass raised (the traceback is printed for diagnosis).
    """
    print("\n🔍 创建最小测试模型")
    try:
        from src.model.model import InputMethodEngine

        # Deliberately tiny hyperparameters: quick to build, still enough to
        # drive the full forward path.
        net = InputMethodEngine(
            vocab_size=100,
            pinyin_vocab_size=30,
            dim=64,  # small width for a fast test
            num_slots=8,
            n_layers=2,
            n_heads=2,
            num_experts=4,
            max_seq_len=64,
            compile=False,
        )
        print(f"✅ 最小模型创建成功")
        print(f" 总参数量: {sum(p.numel() for p in net.parameters()):,}")
        print(f" 分类头形状: {net.classifier.weight.shape}")

        # Synthetic batch covering every model input.
        bs = 2
        input_ids = torch.randint(0, 100, (bs, 64))
        token_type_ids = torch.zeros((bs, 64), dtype=torch.long)
        attention_mask = torch.ones((bs, 64), dtype=torch.long)
        pinyin_ids = torch.randint(0, 30, (bs, 24))
        history_slot_ids = torch.randint(0, 100, (bs, 8))

        with torch.no_grad():
            logits = net(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                pinyin_ids=pinyin_ids,
                history_slot_ids=history_slot_ids,
            )
        print(f" 前向传播成功logits形状: {logits.shape}")
        print(f" logits范围: [{logits.min():.4f}, {logits.max():.4f}]")

        # Sanity-check that the logits normalize to a valid distribution.
        probs = F.softmax(logits, dim=-1)
        print(f" 概率范围: [{probs.min():.4f}, {probs.max():.4f}]")
        return net
    except Exception as e:
        print(f"❌ 最小模型测试失败: {e}")
        import traceback

        traceback.print_exc()
        return None
def main():
    """CLI entry point: run every debug stage in sequence.

    Accepts an optional ``--checkpoint`` path; without it, the checkpoint
    inspection stage is skipped and the remaining stages still run.
    """
    import argparse

    parser = argparse.ArgumentParser(description="输入法模型推理调试")
    parser.add_argument("--checkpoint", type=str, help="模型checkpoint路径")
    args = parser.parse_args()

    banner = "=" * 60
    print(banner)
    print("输入法模型推理调试工具")
    print(banner)

    # Stage 1: checkpoint inspection (optional).
    if args.checkpoint:
        debug_model_checkpoint(args.checkpoint)
    else:
        print("\n⚠️ 未提供checkpoint路径跳过checkpoint调试")

    # Stages 2-4: input preparation, query engine, minimal model smoke test.
    debug_input_preparation()
    debug_query_engine()
    create_minimal_test_model()

    print("\n" + banner)
    print("调试完成")
    print(banner)


if __name__ == "__main__":
    main()