#!/usr/bin/env python3
"""
Inference debugging script for the input-method model.

Used to diagnose why the model's prediction results are abnormal.
"""

from itertools import islice
from pathlib import Path

import torch
import torch.nn.functional as F


def _format_weight_summary(weight):
    """Return a printable summary (shape and value range) for one checkpoint entry."""
    shape = getattr(weight, "shape", None)
    if shape is None:
        # Checkpoints often carry bookkeeping values (epoch, lr, ...) next to tensors.
        return f"非张量值, 类型={type(weight).__name__}"
    has_range = (
        torch.is_tensor(weight) and weight.numel() > 0 and not weight.is_complex()
    )
    if not has_range:
        # min()/max() are undefined for empty or complex tensors.
        return f"shape={shape}, range=[n/a]"
    low, high = weight.min().item(), weight.max().item()
    return f"shape={shape}, range=[{low:.4f}, {high:.4f}]"


def debug_model_checkpoint(checkpoint_path: str):
    """Load a checkpoint on CPU and print a summary of the weights it holds.

    The state dict is taken from the ``"model_state_dict"`` key, else the
    ``"state_dict"`` key, else the checkpoint dict itself is treated as the
    state dict. The first ten entries and every entry whose key contains
    ``"classifier"`` are printed.

    Args:
        checkpoint_path: Path of the checkpoint file to inspect.

    Returns:
        The selected state dict, or ``None`` when the file does not exist or
        the loaded object is not a dict.
    """
    print(f"\n🔍 调试checkpoint: {checkpoint_path}")

    if not Path(checkpoint_path).exists():
        print(f"❌ Checkpoint文件不存在: {checkpoint_path}")
        return None

    # NOTE(security): torch.load unpickles the file; only point this tool at
    # checkpoints from a trusted source (or pass weights_only=True if the
    # checkpoint contains nothing but tensors and plain containers).
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    print(f"Checkpoint类型: {type(checkpoint)}")

    if not isinstance(checkpoint, dict):
        print(f"❌ Checkpoint不是字典类型: {type(checkpoint)}")
        return None

    print(f"Checkpoint键名: {list(checkpoint.keys())}")

    # Prefer a nested state dict; otherwise treat the checkpoint as one.
    for wrapper_key in ("model_state_dict", "state_dict"):
        if wrapper_key in checkpoint:
            state_dict = checkpoint[wrapper_key]
            print(f"使用'{wrapper_key}'键,包含{len(state_dict)}个权重")
            break
    else:
        state_dict = checkpoint
        print(f"使用直接状态字典,包含{len(state_dict)}个权重")

    # Show the first ten entries without materialising the whole item list.
    print("\n前10个权重键名和形状:")
    for position, (key, value) in enumerate(islice(state_dict.items(), 10), start=1):
        shape = getattr(value, "shape", None)
        if shape is None:
            # Non-tensor entry (e.g. an epoch counter): report its type instead.
            print(f" {position}. {key}: 非张量值, 类型={type(value).__name__}")
        else:
            print(f" {position}. {key}: {shape}")

    # Look specifically at the classification head; str() guards non-string keys.
    classifier_keys = [key for key in state_dict if "classifier" in str(key)]
    if classifier_keys:
        print("\n分类头相关权重:")
        for key in classifier_keys:
            print(f" {key}: {_format_weight_summary(state_dict[key])}")
    else:
        print("\n⚠️ 未找到分类头权重")

    return state_dict
def debug_input_preparation():
    """Exercise the input-preparation helpers and print what they produce.

    Two independent checks are run: converting sample pinyin strings with the
    project's ``text_to_pinyin_ids``, and loading the tokenizer stored under
    ``src/model/assets/tokenizer`` next to this script and encoding a few
    sample texts with it. Each check reports its own failure (including a
    missing dependency) so that one broken check does not abort the others.

    Returns:
        None. All results are printed.
    """
    print("\n🔍 调试输入数据准备")

    # Check 1: pinyin string -> id list conversion.
    try:
        from src.model.dataset import text_to_pinyin_ids

        for pinyin in ("tian", "shang", "ha", "ni", "hao"):
            ids = text_to_pinyin_ids(pinyin)
            print(f"拼音 '{pinyin}' -> ID列表: {ids}")
    except Exception as e:
        print(f"❌ 拼音ID转换调试失败: {e}")

    # Check 2: tokenizer loading and encoding. The import lives inside the
    # try so that a missing modelscope install is reported, not fatal.
    try:
        from modelscope import AutoTokenizer

        tokenizer_path = (
            Path(__file__).parent / "src" / "model" / "assets" / "tokenizer"
        )
        tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path))
        print(f"\n✅ Tokenizer加载成功,词汇表大小: {tokenizer.vocab_size}")

        test_texts = ["今天天气", "我们去公园", "张三|李四|今天天气"]
        for text in test_texts:
            encoded = tokenizer(
                text, max_length=128, truncation=True, return_tensors="pt"
            )
            print(f"文本 '{text}' -> input_ids形状: {encoded['input_ids'].shape}")
    except Exception as e:
        print(f"❌ Tokenizer加载失败: {e}")
def debug_query_engine():
    """Run a handful of sample lookups against the project's query engine.

    Loads ``pinyin_char_statistics.json`` from ``src/model/assets`` next to
    this script, then prints up to three results for several sample
    characters and several sample pinyin strings. Any exception is reported
    together with its traceback instead of being propagated.

    Returns:
        None. All results are printed.
    """
    print("\n🔍 调试查询引擎")

    def first_three(rows):
        # Keep only the first three fields of at most three result rows.
        return [(row[0], row[1], row[2]) for row in rows[:3]]

    try:
        from src.model.query import QueryEngine

        engine = QueryEngine()
        assets_dir = Path(__file__).parent / "src" / "model" / "assets"
        stats_file = assets_dir / "pinyin_char_statistics.json"

        if not stats_file.exists():
            print(f"❌ 统计文件不存在: {stats_file}")
            return

        engine.load(stats_file)
        print("✅ 查询引擎加载成功")

        # Lookups keyed by character.
        for ch in ("天", "上", "人", "好", "的"):
            hits = engine.query_by_char(ch, limit=3)
            if not hits:
                print(f"字符 '{ch}': 无结果")
                continue
            print(f"字符 '{ch}': {first_three(hits)}")

        # Lookups keyed by pinyin.
        for syllable in ("tian", "shang", "ren", "hao", "de"):
            hits = engine.query_by_pinyin(syllable, limit=3)
            if not hits:
                print(f"拼音 '{syllable}': 无结果")
                continue
            print(f"拼音 '{syllable}': {first_three(hits)}")
    except Exception as exc:
        print(f"❌ 查询引擎调试失败: {exc}")
        import traceback

        traceback.print_exc()
def create_minimal_test_model():
    """Build a tiny ``InputMethodEngine`` and push one random batch through it.

    Prints the parameter count, the classifier weight shape, and the value
    ranges of the logits and of their softmax. Any exception is reported
    together with its traceback instead of being propagated.

    Returns:
        The constructed model, or ``None`` if construction or the forward
        pass failed.
    """
    print("\n🔍 创建最小测试模型")

    # Sizes used for both the model configuration and the random inputs.
    vocab_size = 100
    pinyin_vocab_size = 30
    seq_len = 64
    num_slots = 8
    pinyin_len = 24  # NOTE(review): presumably the pinyin input length the model expects; confirm
    batch_size = 2

    try:
        from src.model.model import InputMethodEngine

        # Deliberately small configuration so the check runs quickly.
        model = InputMethodEngine(
            vocab_size=vocab_size,
            pinyin_vocab_size=pinyin_vocab_size,
            dim=64,
            num_slots=num_slots,
            n_layers=2,
            n_heads=2,
            num_experts=4,
            max_seq_len=seq_len,
            compile=False,
        )

        print("✅ 最小模型创建成功")
        total_params = sum(param.numel() for param in model.parameters())
        print(f" 总参数量: {total_params:,}")
        print(f" 分类头形状: {model.classifier.weight.shape}")

        # Random inputs for a single forward pass.
        text_shape = (batch_size, seq_len)
        batch = {
            "input_ids": torch.randint(0, vocab_size, text_shape),
            "token_type_ids": torch.zeros(text_shape, dtype=torch.long),
            "attention_mask": torch.ones(text_shape, dtype=torch.long),
            "pinyin_ids": torch.randint(
                0, pinyin_vocab_size, (batch_size, pinyin_len)
            ),
            "history_slot_ids": torch.randint(
                0, vocab_size, (batch_size, num_slots)
            ),
        }

        with torch.no_grad():
            logits = model(**batch)

        print(f" 前向传播成功,logits形状: {logits.shape}")
        print(f" logits范围: [{logits.min():.4f}, {logits.max():.4f}]")

        # Sanity-check the normalised output distribution.
        probs = F.softmax(logits, dim=-1)
        print(f" 概率范围: [{probs.min():.4f}, {probs.max():.4f}]")
    except Exception as exc:
        print(f"❌ 最小模型测试失败: {exc}")
        import traceback

        traceback.print_exc()
        return None

    return model
def main(argv=None):
    """Run every diagnostic in sequence.

    Args:
        argv: Optional list of command-line arguments. When ``None`` (the
            default) the arguments are read from ``sys.argv``.

    The ``--checkpoint`` option selects a checkpoint file to inspect; without
    it the checkpoint check is skipped. The input-preparation, query-engine
    and minimal-model checks always run.
    """
    import argparse

    parser = argparse.ArgumentParser(description="输入法模型推理调试")
    parser.add_argument("--checkpoint", type=str, help="模型checkpoint路径")
    args = parser.parse_args(argv)

    banner = "=" * 60
    print(banner)
    print("输入法模型推理调试工具")
    print(banner)

    # Checkpoint inspection is optional; the return values of the checks are
    # not needed here because each one prints its own report.
    if args.checkpoint:
        debug_model_checkpoint(args.checkpoint)
    else:
        print("\n⚠️ 未提供checkpoint路径,跳过checkpoint调试")

    debug_input_preparation()
    debug_query_engine()
    create_minimal_test_model()

    print("\n" + banner)
    print("调试完成")
    print(banner)


# Run the diagnostics only when executed as a script, not when imported.
if __name__ == "__main__":
    main()