feat: add model weight inspection and inference debugging tool scripts

This commit is contained in:
songsenand 2026-04-06 12:29:22 +08:00
parent a203e67aff
commit 2f0166c8ce
9 changed files with 1704 additions and 161 deletions

check_weights.py (new executable file)

@@ -0,0 +1,147 @@
#!/usr/bin/env python3
"""
Quick script to inspect how model weights were loaded from a checkpoint.
"""
from pathlib import Path

import torch


def analyze_checkpoint(checkpoint_path):
    """Analyze a checkpoint file."""
    print(f"🔍 Analyzing checkpoint: {checkpoint_path}")
    if not Path(checkpoint_path).exists():
        print("❌ File does not exist")
        return None
    try:
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        print("✅ Loaded successfully")
        print(f"  Type: {type(checkpoint)}")
        if isinstance(checkpoint, dict):
            print(f"  Keys: {list(checkpoint.keys())}")
            # Locate the model state dict
            if "model_state_dict" in checkpoint:
                state_dict = checkpoint["model_state_dict"]
                print("  🔍 Using 'model_state_dict'")
            elif "state_dict" in checkpoint:
                state_dict = checkpoint["state_dict"]
                print("  🔍 Using 'state_dict'")
            else:
                # The checkpoint may itself be a raw state dict
                state_dict = checkpoint
                print("  🔍 Using the checkpoint as a raw state dict")
            if state_dict:
                print(f"  Total weights: {len(state_dict)}")
                # Analyze the classifier-head weights
                classifier_keys = [key for key in state_dict.keys() if "classifier" in key]
                if classifier_keys:
                    print("  📊 Classifier-head weights:")
                    for key in classifier_keys:
                        weight = state_dict[key]
                        print(f"    {key}: shape={weight.shape}")
                        print(f"      Range: [{weight.min():.6f}, {weight.max():.6f}]")
                        print(f"      Mean: {weight.mean():.6f}")
                        print(f"      Std: {weight.std():.6f}")
                        # A very small std suggests the weights were never properly trained
                        if weight.std() < 0.01:
                            print("      ⚠️ Warning: weight std is very small; the layer may not have been trained properly")
                # Show sample keys from the model architecture
                print("\n  🔑 First 20 architecture keys:")
                for i, key in enumerate(list(state_dict.keys())[:20]):
                    weight = state_dict[key]
                    print(f"    {i + 1:2d}. {key:40} shape={str(weight.shape):15}")
                # Check for the expected components
                expected_components = [
                    "context_encoder",
                    "slot_memory",
                    "cross_attn",
                    "moe",
                    "classifier",
                ]
                found_components = []
                for comp in expected_components:
                    if any(comp in key for key in state_dict.keys()):
                        found_components.append(comp)
                print(f"\n  📋 Components found: {found_components}")
                missing = set(expected_components) - set(found_components)
                if missing:
                    print(f"  ❌ Missing components: {missing}")
            return state_dict
        else:
            print("❌ Checkpoint is not a dict")
    except Exception as e:
        print(f"❌ Loading failed: {e}")
        import traceback

        traceback.print_exc()
    return None


def check_weight_distribution(state_dict):
    """Inspect the weight distributions."""
    print("\n📊 Weight distribution statistics:")
    weight_stats = []
    for key, weight in state_dict.items():
        if "weight" in key and len(weight.shape) >= 2:  # weight matrices only, no biases
            stats = {
                "key": key,
                "shape": weight.shape,
                "min": weight.min().item(),
                "max": weight.max().item(),
                "mean": weight.mean().item(),
                "std": weight.std().item(),
                "abs_mean": weight.abs().mean().item(),
            }
            weight_stats.append(stats)
    # Print the first 10 weights
    for i, stats in enumerate(weight_stats[:10]):
        print(f"  {i + 1:2d}. {stats['key']:40}")
        print(f"      Shape: {stats['shape']}")
        print(f"      Range: [{stats['min']:.6f}, {stats['max']:.6f}]")
        print(f"      Mean: {stats['mean']:.6f} ± {stats['std']:.6f}")
        # A near-zero std is a red flag for an untrained layer
        if stats["std"] < 0.01:
            print("      ⚠️ Warning: std is very small; the layer may be untrained")
    return weight_stats


def main():
    import sys

    if len(sys.argv) < 2:
        print("Usage: python check_weights.py <checkpoint_path>")
        print("Example: python check_weights.py ./output/checkpoints/best_model.pt")
        return
    checkpoint_path = sys.argv[1]
    state_dict = analyze_checkpoint(checkpoint_path)
    if state_dict:
        check_weight_distribution(state_dict)


if __name__ == "__main__":
    main()
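One sanity check on the 0.01 threshold used above: a freshly initialized nn.Linear is not near-zero. PyTorch's default Kaiming-uniform init gives a weight std of roughly 1/sqrt(3*fan_in), so a std far below that points at collapsed or never-trained weights rather than a random init. A minimal sketch of that baseline, assuming only stock PyTorch (not part of the commit):

import math

import torch.nn as nn

# Default nn.Linear init is kaiming_uniform_(a=sqrt(5)):
# bound = 1/sqrt(fan_in), so std = bound/sqrt(3) = 1/sqrt(3*fan_in).
for fan_in in (64, 512, 2048):
    layer = nn.Linear(fan_in, 128)
    expected = 1.0 / math.sqrt(3 * fan_in)
    print(f"fan_in={fan_in}: measured std={layer.weight.std():.4f}, expected ~{expected:.4f}")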

debug_inference.py (new file)

@@ -0,0 +1,237 @@
#!/usr/bin/env python3
"""
Debugging script for input-method model inference.
Used to diagnose why the model's predictions look abnormal.
"""
from pathlib import Path

import torch
import torch.nn.functional as F


def debug_model_checkpoint(checkpoint_path: str):
    """Debug a checkpoint file."""
    print(f"\n🔍 Debugging checkpoint: {checkpoint_path}")
    if not Path(checkpoint_path).exists():
        print(f"❌ Checkpoint file does not exist: {checkpoint_path}")
        return None
    # Load the checkpoint
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    print(f"Checkpoint type: {type(checkpoint)}")
    if isinstance(checkpoint, dict):
        print(f"Checkpoint keys: {list(checkpoint.keys())}")
        # Look for a model state dict
        if "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
            print(f"Using the 'model_state_dict' key with {len(state_dict)} weights")
        elif "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
            print(f"Using the 'state_dict' key with {len(state_dict)} weights")
        else:
            # The checkpoint may itself be a raw state dict
            state_dict = checkpoint
            print(f"Using the checkpoint as a raw state dict with {len(state_dict)} weights")
        # Print the first 10 weight keys and shapes
        print("\nFirst 10 weight keys and shapes:")
        for i, (key, value) in enumerate(list(state_dict.items())[:10]):
            print(f"  {i + 1}. {key}: {value.shape}")
        # Check the classifier head specifically
        classifier_keys = [k for k in state_dict.keys() if "classifier" in k]
        if classifier_keys:
            print("\nClassifier-head weights:")
            for key in classifier_keys:
                weight = state_dict[key]
                print(
                    f"  {key}: shape={weight.shape}, range=[{weight.min():.4f}, {weight.max():.4f}]"
                )
        else:
            print("\n⚠️ No classifier-head weights found")
        return state_dict
    else:
        print(f"❌ Checkpoint is not a dict: {type(checkpoint)}")
        return None


def debug_input_preparation():
    """Debug input data preparation."""
    print("\n🔍 Debugging input preparation")
    # Test the text_to_pinyin_ids function
    from src.model.dataset import text_to_pinyin_ids

    test_pinyins = ["tian", "shang", "ha", "ni", "hao"]
    for pinyin in test_pinyins:
        ids = text_to_pinyin_ids(pinyin)
        print(f"Pinyin '{pinyin}' -> ID list: {ids}")
    # Test the tokenizer
    from modelscope import AutoTokenizer

    try:
        tokenizer_path = (
            Path(__file__).parent / "src" / "model" / "assets" / "tokenizer"
        )
        tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path))
        print(f"\n✅ Tokenizer loaded, vocab size: {tokenizer.vocab_size}")
        test_texts = ["今天天气", "我们去公园", "张三|李四|今天天气"]
        for text in test_texts:
            encoded = tokenizer(
                text, max_length=128, truncation=True, return_tensors="pt"
            )
            print(f"Text '{text}' -> input_ids shape: {encoded['input_ids'].shape}")
    except Exception as e:
        print(f"❌ Tokenizer loading failed: {e}")


def debug_query_engine():
    """Debug the query engine."""
    print("\n🔍 Debugging query engine")
    try:
        from src.model.query import QueryEngine

        query_engine = QueryEngine()
        stats_path = (
            Path(__file__).parent
            / "src"
            / "model"
            / "assets"
            / "pinyin_char_statistics.json"
        )
        if stats_path.exists():
            query_engine.load(stats_path)
            print("✅ Query engine loaded")
            # Test character lookup
            test_chars = ["", "", "", "", ""]
            for char in test_chars:
                results = query_engine.query_by_char(char, limit=3)
                if results:
                    print(f"Char '{char}': {[(r[0], r[1], r[2]) for r in results[:3]]}")
                else:
                    print(f"Char '{char}': no results")
            # Test pinyin lookup
            test_pinyins = ["tian", "shang", "ren", "hao", "de"]
            for pinyin in test_pinyins:
                results = query_engine.query_by_pinyin(pinyin, limit=3)
                if results:
                    print(
                        f"Pinyin '{pinyin}': {[(r[0], r[1], r[2]) for r in results[:3]]}"
                    )
                else:
                    print(f"Pinyin '{pinyin}': no results")
        else:
            print(f"❌ Statistics file does not exist: {stats_path}")
    except Exception as e:
        print(f"❌ Query engine debugging failed: {e}")
        import traceback

        traceback.print_exc()


def create_minimal_test_model():
    """Create a minimal test model to verify basic functionality."""
    print("\n🔍 Creating minimal test model")
    try:
        from src.model.model import InputMethodEngine

        # Create a small model
        model = InputMethodEngine(
            vocab_size=100,
            pinyin_vocab_size=30,
            dim=64,  # small dimension for testing
            num_slots=8,
            n_layers=2,
            n_heads=2,
            num_experts=4,
            max_seq_len=64,
            compile=False,
        )
        print("✅ Minimal model created")
        print(f"  Total parameters: {sum(p.numel() for p in model.parameters()):,}")
        print(f"  Classifier shape: {model.classifier.weight.shape}")
        # Test a forward pass
        batch_size = 2
        input_ids = torch.randint(0, 100, (batch_size, 64))
        token_type_ids = torch.zeros((batch_size, 64), dtype=torch.long)
        attention_mask = torch.ones((batch_size, 64), dtype=torch.long)
        pinyin_ids = torch.randint(0, 30, (batch_size, 24))
        history_slot_ids = torch.randint(0, 100, (batch_size, 8))
        with torch.no_grad():
            logits = model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                pinyin_ids=pinyin_ids,
                history_slot_ids=history_slot_ids,
            )
        print(f"  Forward pass OK, logits shape: {logits.shape}")
        print(f"  Logits range: [{logits.min():.4f}, {logits.max():.4f}]")
        # Inspect the output distribution
        probs = F.softmax(logits, dim=-1)
        print(f"  Probability range: [{probs.min():.4f}, {probs.max():.4f}]")
        return model
    except Exception as e:
        print(f"❌ Minimal model test failed: {e}")
        import traceback

        traceback.print_exc()
        return None


def main():
    """Main debugging entry point."""
    import argparse

    parser = argparse.ArgumentParser(description="Input-method model inference debugger")
    parser.add_argument("--checkpoint", type=str, help="Path to the model checkpoint")
    args = parser.parse_args()
    print("=" * 60)
    print("Input-method model inference debugging tool")
    print("=" * 60)
    # Debug the checkpoint
    if args.checkpoint:
        state_dict = debug_model_checkpoint(args.checkpoint)
    else:
        print("\n⚠️ No checkpoint path given; skipping checkpoint debugging")
        state_dict = None
    # Debug input preparation
    debug_input_preparation()
    # Debug the query engine
    debug_query_engine()
    # Create the minimal test model
    model = create_minimal_test_model()
    print("\n" + "=" * 60)
    print("Debugging complete")
    print("=" * 60)


if __name__ == "__main__":
    main()

inference.py (new file)

@@ -0,0 +1,602 @@
#!/usr/bin/env python3
"""
Inference script for the input-method model.

Usage:
    python inference.py --checkpoint ./output/checkpoints/best_model.pt

Interactive mode asks for input step by step:
    1. Context hints: proper nouns, names, etc. the model does not know (may be empty)
    2. Text before cursor: the contiguous text before the cursor
    3. Text after cursor: the contiguous text after the cursor
    4. Pinyin: the pinyin currently being typed
    5. Slot history: input already confirmed by the user, e.g. typing "shanghai" with "上" confirmed

Example scenario:
    Typing "shanghai", "上" confirmed, now typing "tian":
        Context hints: 张三,李四
        Text before cursor: 今天天气很好
        Text after cursor: 我们去公园玩
        Pinyin: tian
        Slot history: 上
"""
import argparse
import time
from pathlib import Path
from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F
from modelscope import AutoTokenizer

from src.model.dataset import text_to_pinyin_ids
from src.model.model import InputMethodEngine
from src.model.query import QueryEngine


class InputMethodInference:
    """Input-method model inference runner."""

    def __init__(
        self,
        checkpoint_path: str,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        self.device = torch.device(device)
        self.checkpoint_path = checkpoint_path
        # Load the components
        print(f"Loading model from: {checkpoint_path}")
        self.load_model()
        # Load the tokenizer
        print("Loading tokenizer...")
        self.load_tokenizer()
        # Load the query engine
        print("Loading query engine...")
        self.load_query_engine()
        print(f"✅ Inference runner initialized (device: {self.device})")

    def load_model(self):
        """Load the trained model."""
        # Create the model instance (uncompiled)
        self.model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
        # Load the trained weights, forcing them onto CPU first and moving to the
        # target device afterwards, so GPU-trained weights convert cleanly to CPU.
        checkpoint = torch.load(self.checkpoint_path, map_location="cpu")
        if "model_state_dict" in checkpoint:
            self.model.load_state_dict(checkpoint["model_state_dict"])
        else:
            self.model.load_state_dict(checkpoint)
        self.model.eval()
        self.model.to(self.device)
        print(
            f"✅ Model loaded, parameters: {sum(p.numel() for p in self.model.parameters()):,}"
        )

    def load_tokenizer(self):
        """Load the tokenizer."""
        try:
            # Load the tokenizer from assets/tokenizer
            tokenizer_path = (
                Path(__file__).parent / "src" / "model" / "assets" / "tokenizer"
            )
            self.tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path))
            print(f"✅ Tokenizer loaded, vocab size: {self.tokenizer.vocab_size}")
        except Exception as e:
            print(f"⚠️ Could not load tokenizer: {e}")
            print("Falling back to the default bert-base-chinese tokenizer")
            self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")

    def load_query_engine(self):
        """Load the query engine used for char<->ID conversion."""
        try:
            self.query_engine = QueryEngine()
            # Load the default statistics file
            stats_path = (
                Path(__file__).parent
                / "src"
                / "model"
                / "assets"
                / "pinyin_char_statistics.json"
            )
            if stats_path.exists():
                self.query_engine.load(stats_path)
                print(
                    f"✅ Query engine loaded, char-pinyin pairs: {len(self.query_engine._id_to_info)}"
                )
            else:
                print(f"⚠️ Statistics file does not exist: {stats_path}")
                self.query_engine = None
        except Exception as e:
            print(f"⚠️ Could not load query engine: {e}")
            self.query_engine = None

    def char_to_id(self, char: str, pinyin: Optional[str] = None) -> int:
        """Convert a Chinese character to an ID (more precise when pinyin is given)."""
        # Handle the end-of-input token
        if char == "//":
            return 0  # assume 0 is the end-token ID
        if self.query_engine is None:
            # Crude fallback: use the Unicode code point
            return ord(char) if len(char) == 1 else 0
        try:
            if pinyin is not None:
                # Use the exact char-pinyin pair
                info = self.query_engine.get_char_info_by_char_pinyin(char, pinyin)
                if info:
                    return info.id
            # Fallback: take the character's first pinyin variant
            results = self.query_engine.query_by_char(char, limit=1)
            if results:
                return results[0][0]  # return the ID
            return 0
        except Exception:
            return 0

    def id_to_char(self, id: int) -> str:
        """Convert an ID back to a Chinese character."""
        # Handle the end-token ID (assumed to be 0)
        if id == 0:
            return "//"
        if self.query_engine is None:
            return chr(id) if id < 0x110000 else "<?>"
        try:
            info = self.query_engine.query_by_id(id)
            return info.char if info else f"[ID:{id}]"
        except Exception:
            return f"[ID:{id}]"

    def prepare_inputs(
        self,
        context_prompts: List[str],
        text_before: str,
        text_after: str,
        pinyin: str,
        slot_chars: List[str],
        max_seq_len: int = 128,
    ):
        """
        Prepare the model inputs.

        Args:
            context_prompts: context hints (proper nouns, names, etc.), joined with '|'
            text_before: text before the cursor
            text_after: text after the cursor
            pinyin: the pinyin currently being typed
            slot_chars: characters in the slots (input already confirmed by the user)
            max_seq_len: maximum sequence length

        Returns:
            A dict of model inputs.
        """
        # 1. Build the tokenizer input.
        # Following dataset.py, the format is "part4|part1" plus part3:
        #   part4: context hints (proper nouns, names, etc. the model does not know)
        #   part1: text_before
        #   part3: text_after
        context_text = "|".join(context_prompts) if context_prompts else ""
        if context_text:
            input_text = f"{context_text}|{text_before}"
        else:
            input_text = text_before
        # 2. Tokenize
        encoded = self.tokenizer(
            input_text,
            text_after,
            max_length=max_seq_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            return_token_type_ids=True,
        )
        # 3. Encode the pinyin input
        pinyin_ids = text_to_pinyin_ids(pinyin)
        if len(pinyin_ids) < 24:
            pinyin_ids.extend([0] * (24 - len(pinyin_ids)))
        else:
            pinyin_ids = pinyin_ids[:24]
        pinyin_tensor = torch.tensor([pinyin_ids], dtype=torch.long)
        # 4. Encode the history slots (input already confirmed by the user)
        history_slot_ids = []
        for char in slot_chars:
            # Look up the ID for each confirmed slot character
            char_id = self.char_to_id(char)
            history_slot_ids.append(char_id)
        # Pad to 8 slots
        if len(history_slot_ids) < 8:
            history_slot_ids.extend([0] * (8 - len(history_slot_ids)))
        else:
            history_slot_ids = history_slot_ids[:8]
        history_tensor = torch.tensor([history_slot_ids], dtype=torch.long)
        # 5. Move everything to the target device
        inputs = {
            "input_ids": encoded["input_ids"].to(self.device),
            "token_type_ids": encoded["token_type_ids"].to(self.device),
            "attention_mask": encoded["attention_mask"].to(self.device),
            "pinyin_ids": pinyin_tensor.to(self.device),
            "history_slot_ids": history_tensor.to(self.device),
        }
        return inputs

    def predict(
        self,
        context_prompts: List[str],
        text_before: str,
        text_after: str,
        pinyin: str,
        slot_chars: List[str],
        top_k: int = 20,
    ) -> Tuple[List[Tuple[str, float, int]], float]:
        """
        Run inference.

        Args:
            context_prompts: context hints (proper nouns, names, etc.), joined with '|'
            text_before: text before the cursor
            text_after: text after the cursor
            pinyin: the pinyin currently being typed
            slot_chars: confirmed slot characters (at most 8)
            top_k: number of predictions to return

        Returns:
            (predictions, inference_time_ms)
            predictions: List[Tuple[char, score, id]]
        """
        start_time = time.perf_counter()
        # Prepare the inputs
        inputs = self.prepare_inputs(
            context_prompts, text_before, text_after, pinyin, slot_chars
        )
        # Debug info: input shapes
        print("\n🔍 Debug - input check:")
        for key, tensor in inputs.items():
            print(
                f"  {key}: shape={tensor.shape}, dtype={tensor.dtype}, device={tensor.device}"
            )
            if key in ["history_slot_ids", "pinyin_ids"]:
                print(f"    Values: {tensor.cpu().numpy().tolist()}")
        # Take a closer look at the pinyin and slot inputs
        print("\n🔍 Pinyin input details:")
        pinyin_tensor = inputs["pinyin_ids"]
        pinyin_values = pinyin_tensor.cpu().numpy()[0]
        # Converting IDs back to pinyin would need a reverse mapping; just show the IDs
        print(f"  Pinyin IDs: {pinyin_values}")
        print(f"  Non-zero pinyin IDs: {[id for id in pinyin_values if id != 0]}")
        print("\n🔍 Slot history details:")
        slot_tensor = inputs["history_slot_ids"]
        slot_values = slot_tensor.cpu().numpy()[0]
        print(f"  Slot IDs: {slot_values}")
        print(f"  Non-zero slot IDs: {[id for id in slot_values if id != 0]}")
        # Convert the IDs back to characters
        slot_chars_converted = [self.id_to_char(id) for id in slot_values]
        print(f"  Slot characters: {slot_chars_converted}")
        # Debug: inspect the model
        print("\n🔍 Debug - model check:")
        print(f"  Model device: {self.device}")
        print(f"  Model in training mode: {self.model.training}")
        # Model vocab size
        vocab_size = self.model.vocab_size
        print(f"  Model vocab size: {vocab_size}")
        # Characters for the first 5 IDs
        print("  Characters for IDs 1-5:")
        for i in range(1, 6):
            char = self.id_to_char(i)
            print(f"    ID {i}: '{char}'")
        # Debug: inspect the classifier head
        classifier_weight = self.model.classifier.weight.data
        classifier_bias = self.model.classifier.bias.data
        print(f"  Classifier weight shape: {classifier_weight.shape}")
        print(
            f"  Classifier weight range: [{classifier_weight.min():.4f}, {classifier_weight.max():.4f}]"
        )
        print(f"  Classifier weight mean: {classifier_weight.mean():.4f}")
        print(f"  Classifier weight std: {classifier_weight.std():.4f}")
        print(f"  Classifier bias shape: {classifier_bias.shape}")
        print(
            f"  Classifier bias range: [{classifier_bias.min():.4f}, {classifier_bias.max():.4f}]"
        )
        print(f"  Classifier bias mean: {classifier_bias.mean():.4f}")
        print(f"  Classifier bias std: {classifier_bias.std():.4f}")
        # Inference (mixed precision disabled on CPU)
        with torch.no_grad():
            if self.device.type == "cuda":
                with torch.autocast(device_type="cuda"):
                    logits = self.model(**inputs)
            else:
                # No autocast for CPU inference
                logits = self.model(**inputs)
        # Debug: inspect the logits
        print("\n🔍 Debug - output check:")
        print(f"  Logits shape: {logits.shape}")
        print(f"  Logits range: [{logits.min():.4f}, {logits.max():.4f}]")
        print(f"  Logits mean: {logits.mean():.4f}")
        # Largest logit and its ID
        max_val, max_idx = torch.max(logits, dim=-1)
        print(f"  Max logit: {max_val.item():.4f}, ID: {max_idx.item()}")
        # Top-k predictions
        probs = F.softmax(logits, dim=-1)
        top_probs, top_indices = torch.topk(probs, k=top_k, dim=-1)
        # Debug: probability distribution
        print(f"  Probability sum: {probs.sum().item():.4f}")
        top_probs_array = top_probs.cpu().numpy().flatten()
        top_indices_array = top_indices.cpu().numpy().flatten()
        print(f"  Top-{top_k} probabilities: {top_probs_array}")
        print(f"  Top-{top_k} IDs: {top_indices_array}")
        # Check whether the distribution is close to uniform
        print("  Probability distribution analysis:")
        print(f"    Mean probability: {probs.mean().item():.6f}")
        print(f"    Max probability: {probs.max().item():.6f}")
        print(f"    Min probability: {probs.min().item():.6f}")
        print(f"    Std: {probs.std().item():.6f}")
        # Warn when even the top probability is tiny
        if top_probs_array[0] < 0.01:
            print(
                f"  ⚠️ Warning: top probability ({top_probs_array[0]:.6f}) is below 0.01; the model may not be trained properly"
            )
            print("  💡 Possible causes: 1) weights not loaded correctly 2) malformed inputs 3) model config mismatch")
        inference_time_ms = (time.perf_counter() - start_time) * 1000
        # Convert to readable results
        predictions = []
        for i in range(top_k):
            idx = int(top_indices[0, i].item())
            prob = top_probs[0, i].item()
            char = self.id_to_char(idx)
            predictions.append((char, prob, idx))
        return predictions, inference_time_ms

    def interactive_mode(self):
        """Interactive inference mode - asks for input step by step."""
        print("\n" + "=" * 60)
        print("Input-method model inference - interactive mode")
        print("=" * 60)
        print("Notes:")
        print("  - Context hints: proper nouns, names, etc. the model does not know (may be empty)")
        print("  - Text before cursor: the contiguous text before the cursor")
        print("  - Text after cursor: the contiguous text after the cursor")
        print("  - Pinyin: the pinyin currently being typed")
        print("  - Slot history: confirmed input, e.g. typing 'shanghai' with '上' confirmed")
        print("Tip: enter 'quit', 'exit' or 'q' at any time to leave")
        print("-" * 60)
        while True:
            try:
                print("\n" + "=" * 60)
                print("Step 1: context hints (proper nouns, names, etc. the model does not know)")
                print("Format: multiple items separated by commas; may be empty")
                print("Example: 张三,李四,北京大学")
                context_input = input("Context hints (press Enter to skip): ").strip()
                if context_input.lower() in ["quit", "exit", "q"]:
                    print("Leaving interactive mode")
                    break
                # Parse the context hints
                context_prompts = []
                if context_input:
                    context_prompts = [
                        item.strip()
                        for item in context_input.split(",")
                        if item.strip()
                    ]
                print(
                    f"✅ Context hints recorded: {context_prompts if context_prompts else '(none)'}"
                )
                print("\n" + "-" * 40)
                print("Step 2: text before the cursor")
                print("Description: the contiguous text before the cursor")
                print("Example: 今天天气很好")
                text_before = input("Text before cursor: ").strip()
                if text_before.lower() in ["quit", "exit", "q"]:
                    print("Leaving interactive mode")
                    break
                print(f"✅ Text before cursor recorded: '{text_before}'")
                print("\n" + "-" * 40)
                print("Step 3: text after the cursor")
                print("Description: the contiguous text after the cursor")
                print("Example: 我们去公园玩")
                text_after = input("Text after cursor: ").strip()
                if text_after.lower() in ["quit", "exit", "q"]:
                    print("Leaving interactive mode")
                    break
                print(f"✅ Text after cursor recorded: '{text_after}'")
                print("\n" + "-" * 40)
                print("Step 4: pinyin input")
                print("Description: the pinyin currently being typed")
                print("Example: tian, shang, hao")
                pinyin = input("Pinyin: ").strip()
                if pinyin.lower() in ["quit", "exit", "q"]:
                    print("Leaving interactive mode")
                    break
                print(f"✅ Pinyin recorded: '{pinyin}'")
                print("\n" + "-" * 40)
                print("Step 5: slot history (confirmed input)")
                print("Description: input already confirmed by the user, comma-separated")
                print("Example: 上 (typing 'shanghai' with '上' confirmed)")
                print("         今天,天气 (two confirmed words)")
                slot_input = input("Slot history (press Enter for none): ").strip()
                if slot_input.lower() in ["quit", "exit", "q"]:
                    print("Leaving interactive mode")
                    break
                # Parse the slot history
                slot_chars = []
                if slot_input:
                    slot_chars = [
                        char.strip() for char in slot_input.split(",") if char.strip()
                    ]
                print(f"✅ Slot history recorded: {slot_chars if slot_chars else '(none)'}")
                print("\n" + "=" * 60)
                print("📝 Input summary:")
                print(f"  Context hints: {context_prompts if context_prompts else '(none)'}")
                print(f"  Text before cursor: '{text_before}'")
                print(f"  Text after cursor: '{text_after}'")
                print(f"  Pinyin: '{pinyin}'")
                print(f"  Slot history: {slot_chars if slot_chars else '(none)'}")
                # Run inference
                print("\n🔮 Running inference...")
                predictions, inference_time = self.predict(
                    context_prompts, text_before, text_after, pinyin, slot_chars
                )
                # Show the results
                print(f"\n✅ Inference finished ({inference_time:.2f}ms)")
                print("\n🏆 Top-20 predictions:")
                print("-" * 50)
                for i, (char, prob, idx) in enumerate(predictions):
                    if char == "//":
                        print(f"{i + 1:2d}. {'//':<4} (end token) - prob: {prob:.4f}")
                    else:
                        print(
                            f"{i + 1:2d}. {char:<4} (ID: {idx:>5}) - prob: {prob:.4f}"
                        )
                # Show common characters for the raw pinyin
                if pinyin and self.query_engine:
                    print(f"\n📖 Common characters for pinyin '{pinyin}':")
                    pinyin_results = self.query_engine.query_by_pinyin(pinyin, limit=10)
                    if pinyin_results:
                        for j, (pid, char, count) in enumerate(pinyin_results):
                            print(f"  {char} (count: {count:,})")
                    else:
                        print("  (no matches)")
                # Ask whether to continue
                print("\n" + "-" * 40)
                continue_input = input("Run another prediction? (y/n): ").strip().lower()
                if continue_input not in ["y", "yes", ""]:
                    print("Leaving interactive mode")
                    break
            except KeyboardInterrupt:
                print("\n\nLeaving interactive mode")
                break
            except Exception as e:
                print(f"\n❌ Inference error: {e}")
                import traceback

                traceback.print_exc()
                # Ask whether to continue
                continue_input = input("\nContinue? (y/n): ").strip().lower()
                if continue_input not in ["y", "yes", ""]:
                    print("Leaving interactive mode")
                    break


def main():
    parser = argparse.ArgumentParser(description="Input-method model inference")
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="Path to the model checkpoint"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        choices=["auto", "cpu", "cuda"],
        help="Inference device (default: auto)",
    )
    parser.add_argument(
        "--interactive", action="store_true", default=True, help="Interactive mode (default: True)"
    )
    parser.add_argument("--test", action="store_true", help="Run a test inference")
    args = parser.parse_args()
    # Pick the device
    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device
    # Initialize the inference runner
    inference = InputMethodInference(args.checkpoint, device)
    # Test inference
    if args.test:
        print("\n🧪 Running test inference...")
        print("Test scenario: typing 'shanghai', first character '上' confirmed, now typing 'tian'")
        print("Context hints: 张三, 李四 (proper nouns the model does not know)")
        test_predictions, test_time = inference.predict(
            context_prompts=["张三", "李四"],
            text_before="今天天气",
            text_after="很好",
            pinyin="tian",
            slot_chars=["上"],  # the user has confirmed "上"
        )
        print(f"Test inference time: {test_time:.2f}ms")
        print("Top-5 results:")
        for i, (char, prob, idx) in enumerate(test_predictions[:5]):
            if char == "//":
                print(f"  {i + 1}. // (end token) - prob: {prob:.4f}")
            else:
                print(f"  {i + 1}. {char} (ID: {idx}) - prob: {prob:.4f}")
    # Interactive mode
    if args.interactive:
        inference.interactive_mode()


if __name__ == "__main__":
    main()
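For a quick scripted smoke test without the interactive prompts, the runner can also be driven directly from Python. A minimal sketch, assuming the module is importable as inference and the checkpoint path below exists (not part of the commit):

from inference import InputMethodInference

# Hypothetical checkpoint path; substitute your own.
runner = InputMethodInference("./output/checkpoints/best_model.pt", device="cpu")

# The documented scenario: "shanghai" typed, "上" confirmed, now typing "tian".
predictions, elapsed_ms = runner.predict(
    context_prompts=["张三", "李四"],
    text_before="今天天气",
    text_after="很好",
    pinyin="tian",
    slot_chars=["上"],
    top_k=5,
)
for char, prob, idx in predictions:
    print(f"{char} (ID {idx}): {prob:.4f}")
print(f"Inference took {elapsed_ms:.2f} ms")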


@@ -311,35 +311,54 @@ class MoELayer(nn.Module):
             out: [batch, seq_len, dim]
         """
         B, L, D = x.shape
+        num_tokens = B * L
 
-        # 1. Compute Gating Scores
-        gates = self.gate(x)  # [B, L, num_experts]
+        # Flatten the input for processing
+        x_flat = x.view(num_tokens, D)  # [B*L, D]
 
-        # 2. Select Top-K Experts
-        topk_vals, topk_indices = torch.topk(gates, self.top_k, dim=-1)  # [B, L, K]
+        # 1. Compute the gating scores
+        gates = self.gate(x_flat)  # [B*L, num_experts]
 
-        # Normalize weights for selected experts
-        weights = F.softmax(topk_vals, dim=-1)  # [B, L, K]
+        # 2. Select the top-k experts
+        topk_weights, topk_indices = torch.topk(gates, self.top_k, dim=-1)  # [B*L, K]
 
-        # 3. Dispatch and Compute
-        # Initialize output
-        out = torch.zeros_like(x)
+        # Normalize the selected experts' weights
+        topk_weights = F.softmax(topk_weights, dim=-1)  # [B*L, K]
 
-        # Reshape for easier processing: flatten batch and sequence dimensions
-        x_flat = x.view(-1, D)  # [B*L, D]
-        weights_flat = weights.view(-1, self.top_k)  # [B*L, K]
-        topk_indices_flat = topk_indices.view(-1, self.top_k)  # [B*L, K]
+        # 3. Run all experts in parallel (removes the dynamic control flow of the Python loops).
+        # torch.compile unrolls this list comprehension because num_experts is a compile-time constant.
+        expert_outputs = torch.stack(
+            [expert(x_flat) for expert in self.experts], dim=1
+        )  # [B*L, num_experts, D]
 
-        # For each of the top-k positions
-        for k in range(self.top_k):
-            # Get expert indices and weights for this position
-            expert_indices = topk_indices_flat[:, k]  # [B*L]
-            expert_weights = weights_flat[:, k].unsqueeze(-1)  # [B*L, 1]
+        # 4. Gather each token's selected expert outputs.
+        # Expand the indices to match expert_outputs: [B*L, num_experts, D]
+        indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, D)  # [B*L, K, D]
+        selected_outputs = torch.gather(
+            expert_outputs, 1, indices_expanded
+        )  # [B*L, K, D]
 
-            # Process each expert separately
-            for e_idx in range(self.num_experts):
-                # Mask for tokens assigned to this expert at position k
-                mask = expert_indices == e_idx  # [B*L]
-                if not mask.any():
-                    continue
-                # Extract tokens for this expert
-                x_selected = x_flat[mask]  # [N_selected, D]
-                if x_selected.numel() == 0:
-                    continue
-                # Pass through expert
-                expert_out = self.experts[e_idx](x_selected)  # [N_selected, D]
-                # Apply expert weights and add to output
-                weighted_out = expert_out * expert_weights[mask]
-                # Scatter back to flat output
-                out_flat = out.view(-1, D)
-                out_flat[mask] += weighted_out
+        # 5. Weighted sum over the k selected experts
+        weighted_outputs = selected_outputs * topk_weights.unsqueeze(-1)  # [B*L, K, D]
+        out_flat = weighted_outputs.sum(dim=1)  # [B*L, D]
 
-        # Reshape back to original shape
-        out = out.view(B, L, D)
-        return out
+        # Restore the original shape
+        return out_flat.view(B, L, D)
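The rewrite trades memory for compile-friendliness: every expert now runs on every token, and torch.gather then keeps only the top-k outputs, which must match loop-based dispatch exactly. A minimal sketch of that equivalence on random data, assuming plain nn.Linear experts (the real expert definition lies outside this hunk; not part of the commit):

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
N, D, E, K = 8, 16, 5, 2  # tokens, dim, experts, top-k
x = torch.randn(N, D)
experts = nn.ModuleList(nn.Linear(D, D) for _ in range(E))
gate = nn.Linear(D, E)

topk_w, topk_i = torch.topk(gate(x), K, dim=-1)
topk_w = F.softmax(topk_w, dim=-1)

# Vectorized: run all experts, then gather the selected outputs
all_out = torch.stack([e(x) for e in experts], dim=1)  # [N, E, D]
picked = torch.gather(all_out, 1, topk_i.unsqueeze(-1).expand(-1, -1, D))  # [N, K, D]
vec = (picked * topk_w.unsqueeze(-1)).sum(dim=1)  # [N, D]

# Loop-based dispatch, as in the removed code
loop = torch.zeros_like(x)
for k in range(K):
    for e_idx in range(E):
        mask = topk_i[:, k] == e_idx
        if mask.any():
            loop[mask] += experts[e_idx](x[mask]) * topk_w[mask, k : k + 1]

print(torch.allclose(vec, loop, atol=1e-5))  # True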


@@ -1,12 +1,16 @@
+import http.server
 import json
 import os
+import socketserver
 import subprocess
 import sys
+import threading
 import time
 import webbrowser
 from datetime import datetime
 from pathlib import Path
-from typing import Optional, Union
+from typing import Callable, Optional, Union
+from urllib.parse import urlparse
 
 import typer
@@ -335,6 +339,215 @@ def check_status(
        raise typer.Exit(code=1)
def create_http_handler(status_file_path: str, enable_cors: bool = True):
    """Create the HTTP request handler class."""

    class TrainingStatusHTTPHandler(http.server.SimpleHTTPRequestHandler):
        def do_GET(self):
            # Only allow access to the status file
            parsed_path = urlparse(self.path)
            if parsed_path.path not in ["/", "/training_status.json"]:
                self.send_error(404, "File not found")
                return
            try:
                # Check whether the file exists (including the temp file)
                main_file = status_file_path
                temp_file = f"{status_file_path}.tmp"
                # Try to read the data, retrying up to 3 times
                max_retries = 3
                retry_delay = 0.1  # 100ms
                content = None
                json_valid = False
                for attempt in range(max_retries):
                    try:
                        # Check the main file first
                        if not os.path.exists(main_file):
                            # The temp file may exist if a write is in progress
                            if os.path.exists(temp_file):
                                # Only the temp file exists; try reading it
                                file_to_read = temp_file
                            else:
                                # Neither file exists
                                break
                        else:
                            file_to_read = main_file
                        # Read the file contents
                        with open(file_to_read, "r", encoding="utf-8") as f:
                            content = f.read()
                        # Validate the JSON
                        if content:
                            json.loads(content)  # validate the JSON format
                            json_valid = True
                            break  # valid JSON, stop retrying
                        else:
                            # Empty content, wait and retry
                            time.sleep(retry_delay)
                    except (json.JSONDecodeError, IOError) as e:
                        # JSON parse error or IO error: wait, then retry
                        if attempt < max_retries - 1:
                            time.sleep(retry_delay)
                        else:
                            # The last attempt failed as well
                            raise e
                if not content or not json_valid:
                    self.send_error(404, "Status file not found or invalid JSON")
                    return
                # Set the response headers
                self.send_response(200)
                self.send_header("Content-type", "application/json")
                self.send_header("Content-Length", str(len(content)))
                # Cache-control headers to stop browsers from caching
                self.send_header("Cache-Control", "no-cache, no-store, must-revalidate")
                self.send_header("Pragma", "no-cache")
                self.send_header("Expires", "0")
                # CORS headers
                if enable_cors:
                    self.send_header("Access-Control-Allow-Origin", "*")
                    self.send_header("Access-Control-Allow-Methods", "GET, OPTIONS")
                    self.send_header("Access-Control-Allow-Headers", "Content-Type")
                self.end_headers()
                # Send the body
                self.wfile.write(content.encode("utf-8"))
            except Exception as e:
                self.send_error(500, f"Internal server error: {str(e)}")

        def do_OPTIONS(self):
            """Handle OPTIONS requests (CORS preflight)."""
            self.send_response(200)
            self.send_header("Access-Control-Allow-Origin", "*")
            self.send_header("Access-Control-Allow-Methods", "GET, OPTIONS")
            self.send_header("Access-Control-Allow-Headers", "Content-Type")
            self.end_headers()

        def log_message(self, format, *args):
            """Override logging to reduce noise."""
            # Logging could be enabled selectively:
            # typer.echo(f"HTTP Server: {format % args}")
            pass

    return TrainingStatusHTTPHandler


def start_http_server(
    status_file: str,
    port: int,
    host: str,
    enable_cors: bool = True,
) -> Callable:
    """
    Start the HTTP server.

    Args:
        status_file: path to the status file
        port: port number
        host: host address
        enable_cors: whether to enable CORS

    Returns:
        A function that stops the server.
    """
    # Resolve the absolute path
    status_file_path = os.path.abspath(status_file)
    # Create the custom handler
    handler = create_http_handler(status_file_path, enable_cors)
    # Create the server
    server = socketserver.TCPServer((host, port), handler)
    # Serve in a background thread
    server_thread = threading.Thread(target=server.serve_forever)
    server_thread.daemon = True
    server_thread.start()
    typer.echo("🌐 HTTP server started")
    typer.echo(f"  📁 Status file: {status_file_path}")
    typer.echo(f"  🔗 URL: http://{host}:{port}/training_status.json")
    typer.echo(f"  🌍 CORS: {'enabled' if enable_cors else 'disabled'}")
    typer.echo("\nPress Ctrl+C to stop the server\n")

    # Return the stop function
    def stop_server():
        typer.echo("\n🛑 Stopping HTTP server...")
        server.shutdown()
        server.server_close()
        typer.echo("✅ HTTP server stopped")

    return stop_server


@app.command(name="serve")
def serve_status_file(
    status_file: str = typer.Option(
        "./output/training_status.json",
        "--status-file",
        "-s",
        help="Path to the training status JSON file",
    ),
    port: int = typer.Option(8080, "--port", "-p", help="HTTP server port"),
    host: str = typer.Option("0.0.0.0", "--host", help="HTTP server host"),
    cors: bool = typer.Option(True, "--cors", help="Enable CORS support"),
):
    """
    Serve the training status JSON file over HTTP.

    Once started, the data is available at http://<host>:<port>/training_status.json
    """
    # Check whether the status file exists
    if not os.path.exists(status_file):
        typer.echo(f"⚠️ Warning: status file does not exist: {status_file}")
        typer.echo("The training script creates this file automatically once training starts.")
        typer.echo("You can start the HTTP server first and begin training afterwards.")
        # Create the directory if needed
        os.makedirs(os.path.dirname(status_file), exist_ok=True)
        # Create an empty JSON file
        with open(status_file, "w", encoding="utf-8") as f:
            json.dump([], f)
        typer.echo(f"✅ Created empty status file: {status_file}")
    try:
        # Start the HTTP server
        stop_server = start_http_server(
            status_file=status_file,
            port=port,
            host=host,
            enable_cors=cors,
        )
        # Wait for the user to interrupt
        try:
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            stop_server()
    except OSError as e:
        if "Address already in use" in str(e):
            typer.echo(f"❌ Error: port {port} is already in use")
            typer.echo("Use a different port: --port <port>")
        else:
            typer.echo(f"❌ Error starting HTTP server: {e}")
        raise typer.Exit(code=1)
    except Exception as e:
        typer.echo(f"❌ Error starting HTTP server: {e}")
        raise typer.Exit(code=1)


def main():
    """Main entry point."""
    app()
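From the consuming side, the endpoint can be polled the same way the dashboard's load_from_url does. A minimal sketch, assuming the server runs locally on the default port and requests is installed (not part of the commit):

import time

import requests

URL = "http://127.0.0.1:8080/training_status.json"  # assumed host/port

for _ in range(3):
    try:
        resp = requests.get(URL, timeout=10)
        resp.raise_for_status()
        records = resp.json()
        print(f"{len(records)} status records")
    except requests.exceptions.RequestException as e:
        print(f"fetch failed: {e}")
    time.sleep(2)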


@@ -2,6 +2,7 @@ import json
 import math
 import os
 import random
+import tempfile
 from datetime import datetime
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -490,9 +491,13 @@ class Trainer:
                 [self.training_status_data] if self.training_status_data else []
             )
-            # Write the file
-            with open(self.status_file, "w", encoding="utf-8") as f:
+            # Atomic write, so readers never see partial JSON:
+            # write to a temp file first, then rename it into place.
+            temp_file = Path(f"{self.status_file}.tmp")
+            with open(temp_file, "w", encoding="utf-8") as f:
                 json.dump(self.training_status_data, f, indent=2, ensure_ascii=False)
+            # Atomic rename (atomic on Unix filesystems)
+            temp_file.rename(self.status_file)
         except Exception as e:
             logger.error(f"Failed to write training status: {e}")
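One portability note on the rename step: Path.rename is atomic on POSIX but raises on Windows when the destination already exists, whereas os.replace overwrites atomically on both. A standalone sketch of the same write-then-swap pattern (function name and fsync call are illustrative additions, not the trainer's actual code):

import json
import os
from pathlib import Path


def write_status_atomically(status_file: Path, records: list) -> None:
    """Write JSON to a temp file, then atomically swap it into place."""
    temp_file = status_file.with_suffix(status_file.suffix + ".tmp")
    with open(temp_file, "w", encoding="utf-8") as f:
        json.dump(records, f, indent=2, ensure_ascii=False)
        f.flush()
        os.fsync(f.fileno())  # ensure the bytes hit disk before the swap
    # os.replace overwrites the destination atomically on POSIX and Windows
    os.replace(temp_file, status_file)


status_path = Path("./output/training_status.json")
status_path.parent.mkdir(parents=True, exist_ok=True)
write_status_atomically(status_path, [{"step": 1}])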


@@ -6,10 +6,15 @@ from datetime import datetime
 from pathlib import Path
 
 import pandas as pd
-import plotly.express as px
+# import plotly.express as px  # currently unused
 import plotly.graph_objects as go
 import streamlit as st
-from plotly.subplots import make_subplots
+# from plotly.subplots import make_subplots  # currently unused
+
+# HTTP URL support:
+# the requests module is imported lazily inside load_from_url
 
 # Add the project path to sys.path so the model can be imported
 sys.path.append(str(Path(__file__).parent.parent))
@@ -95,7 +100,46 @@ st.markdown(
 
 def load_training_data(file_path):
-    """Load training status data"""
+    """Load training status data from a local file or an HTTP URL."""
+    # Check for an HTTP/HTTPS URL
+    if file_path.startswith(("http://", "https://")):
+        return load_from_url(file_path)
+    else:
+        return load_from_local_file(file_path)
+
+
+def load_from_url(url):
+    """Load data from an HTTP URL."""
+    # Import requests lazily
+    try:
+        import requests
+    except ImportError:
+        st.error(
+            "The requests library is not installed; data cannot be loaded from an HTTP URL. Install it with: pip install requests"
+        )
+        return pd.DataFrame()
+
+    try:
+        # Use a sensible timeout
+        response = requests.get(url, timeout=10)
+        response.raise_for_status()
+        data = response.json()
+        return convert_to_dataframe(data)
+    except requests.exceptions.RequestException as e:
+        st.warning(f"Failed to load data from URL: {e}")
+        return pd.DataFrame()
+    except json.JSONDecodeError:
+        st.warning("The remote response is not valid JSON")
+        return pd.DataFrame()
+    except Exception as e:
+        st.warning(f"Error while loading data from URL: {e}")
+        return pd.DataFrame()
+
+
+def load_from_local_file(file_path):
+    """Load data from a local file."""
     try:
         if not os.path.exists(file_path):
             st.warning(f"File does not exist: {file_path}")
@@ -104,6 +148,16 @@ def load_training_data(file_path):
         with open(file_path, "r", encoding="utf-8") as f:
             data = json.load(f)
+        return convert_to_dataframe(data)
+    except json.JSONDecodeError:
+        return pd.DataFrame()
+    except Exception:
+        return pd.DataFrame()
+
+
+def convert_to_dataframe(data):
+    """Convert the data to a DataFrame, with validation and cleanup."""
+    # Detect config files (check for typical config keys)
     config_keys = [
         "train_data_path",
@@ -113,6 +167,7 @@ def load_training_data(file_path):
         "batch_size",
         "num_epochs",
     ]
+
     if isinstance(data, dict):
         # Check whether this is a config file
         if any(key in data for key in config_keys):
@@ -169,11 +224,6 @@ def load_training_data(file_path):
             return df
-    except json.JSONDecodeError:
-        return pd.DataFrame()
-    except Exception:
-        return pd.DataFrame()
 
 
 def create_metric_card(label, value, delta=None, help_text=None):
     """Create a metric card."""
@@ -381,9 +431,9 @@ def main():
             "TRAINING_STATUS_FILE", "./output/training_status.json"
         )
         status_file = st.text_input(
-            "Status file path",
+            "Status file path or URL",
             value=default_status_file,
-            help="Path to the JSON status file generated during training",
+            help="Either:\n1. a local file path (e.g. ./output/training_status.json)\n2. an HTTP/HTTPS URL (e.g. http://<server-ip>:<port>/training_status.json)",
         )
 
         # Refresh interval

test.py (modified)

@@ -5,20 +5,11 @@ from torch.utils.data import DataLoader
 from tqdm import tqdm
 
 from model.dataset import PinyinInputDataset
 from model.model import InputMethodEngine
 from model.trainer import collate_fn, worker_init_fn
 
-# Try to import DataLoader2 from torchdata, fallback to standard DataLoader
-try:
-    from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
-
-    DATA_LOADER2_AVAILABLE = True
-    print("✅ Using DataLoader2 from torchdata")
-except ImportError:
-    DATA_LOADER2_AVAILABLE = False
-    print("⚠️ torchdata not installed, falling back to standard DataLoader")
-
-max_iter_length = 128 * 128
-batch_size = 1024
+max_iter_length = 5
+batch_size = 1
 
 if sys.platform == "win32":
     dataset_path = "data"
@@ -29,40 +20,10 @@ dataset = PinyinInputDataset(dataset_path, max_iter_length=max_iter_length)
 def create_dataloader():
-    """
-    Create dataloader with DataLoader2 if available, otherwise fallback to DataLoader.
-
-    This function tries to handle streaming datasets better with DataLoader2.
-    """
-    if DATA_LOADER2_AVAILABLE:
-        try:
-            # DataLoader2 configuration for streaming datasets
-            # Use MultiProcessingReadingService with careful worker settings
-            reading_service = MultiProcessingReadingService(
-                num_workers=2,  # Start with 2 workers for streaming dataset
-                prefetch_factor=2,  # Reduced prefetch for better memory management
-                persistent_workers=True,
-                pin_memory=torch.cuda.is_available(),
-                worker_init_fn=worker_init_fn,
-            )
-
-            dataloader = DataLoader2(
-                dataset,
-                reading_service=reading_service,
-                batch_size=batch_size,
-                collate_fn=collate_fn,
-                shuffle=False,  # Dataset handles shuffling internally
-            )
-            print(f"✅ Created DataLoader2 with {2} workers")
-            return dataloader
-        except Exception as e:
-            print(f"⚠️ DataLoader2 creation failed: {e}, falling back to DataLoader")
-
-    # Fallback to standard DataLoader
-    print("📊 Using standard DataLoader")
+    """Create the dataloader."""
     dataloader = DataLoader(
         dataset,
         batch_size=batch_size,
-        num_workers=2,  # Limited to 2 for streaming dataset compatibility
+        num_workers=1,  # Limited to 1 for streaming dataset compatibility
         pin_memory=torch.cuda.is_available(),
         worker_init_fn=worker_init_fn,
         collate_fn=collate_fn,
@@ -72,32 +33,52 @@ def create_dataloader():
     return dataloader
 
+samples = []
 # Create the dataloader
 dataloader = create_dataloader()
 
-# Test the dataloader
-print(f"🔍 Testing dataloader with batch_size={batch_size}")
-print(f"   Dataset max_iter_length: {max_iter_length}")
-print(f"   Expected batches: {max_iter_length / batch_size:.0f}")
+# Convert to list to test loading (as in original code)
+dataloader_list = list([i for i in dataloader])
+print(f"✅ Successfully loaded {len(dataloader_list)} batches")
 
-try:
-    # Convert to list to test loading (as in original code)
-    dataloader_list = list([i for i in dataloader])
-    print(f"✅ Successfully loaded {len(dataloader_list)} batches")
+# Process batches
+for i, line in tqdm(enumerate(dataloader_list), total=len(dataloader_list)):
+    samples.append(line)
 
-    # Process batches
-    for i, line in tqdm(enumerate(dataloader_list), total=len(dataloader_list)):
-        zero_labels = (line["labels"] == 0).sum()
-        print(f"Batch {i}: labels==0 count = {zero_labels.item()}")
-        # Early exit for testing
-        if i >= 5:  # Limit to 5 batches for quick testing
-            print("⚠️ Limited to 5 batches for testing")
-            break
+model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
+checkpoint = torch.load("/home/songsenand/下载/best_model.pt", map_location="cpu")
+model.load_state_dict(checkpoint["model_state_dict"])
 
-except Exception as e:
-    print(f"❌ Error during dataloader iteration: {e}")
-    import traceback
-
-    traceback.print_exc()
+sample = samples[0]
+input_ids = sample["input_ids"]
+token_type_ids = sample["token_type_ids"]
+attention_mask = sample["attention_mask"]
+pinyin_ids = sample["pinyin_ids"]
+history_slot_ids = sample["history_slot_ids"]
+for k, v in sample.items():
+    if isinstance(v, str):
+        print(f"{k}: {v}")
+res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
+sort_res = sorted([(i + 1, v) for i, v in enumerate(res[0])], key=lambda x: x[1])
+print(sort_res[0:5])
 
-print("🏁 Test completed")
+# Added after the res computation in test.py
+import torch.nn.functional as F
+
+# Compute the softmax probabilities
+probs = F.softmax(res, dim=-1)
+
+print("\n📊 Probability distribution analysis:")
+print(f"  Shape: {probs.shape}")
+print(f"  Probability sum: {probs.sum().item():.6f}")
+print(f"  Max probability: {probs.max().item():.6f}")
+print(f"  Min probability: {probs.min().item():.6f}")
+print(f"  Mean probability: {probs.mean().item():.6f}")
+
+# Top-20 probabilities
+top_probs, top_indices = torch.topk(probs, k=20)
+print("\n🏆 Top-20 predictions:")
+for i in range(20):
+    idx = top_indices[0, i].item()
+    prob = top_probs[0, i].item()
+    print(f"  {i + 1:2d}. ID {idx:5d}: {prob:.6f}")

test_cpu_inference.py (new file)

@@ -0,0 +1,289 @@
#!/usr/bin/env python3
"""
Test that a GPU-trained model can run inference on CPU.
Works through device-conversion and weight-loading issues.
"""
import sys
from pathlib import Path

import torch


def test_device_conversion(checkpoint_path):
    """Test device conversion."""
    print("=" * 60)
    print("GPU->CPU device conversion test")
    print("=" * 60)
    # Method 1: load straight onto the CPU
    print("\n🔍 Method 1: load directly onto CPU")
    try:
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        print("✅ Loaded directly onto CPU")
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
            # Check for CUDA tensors
            cuda_tensors = 0
            for key, tensor in state_dict.items():
                if tensor.is_cuda:
                    cuda_tensors += 1
            print(f"  CUDA tensors in the state dict: {cuda_tensors}/{len(state_dict)}")
            # Move all weights to CPU
            cpu_state_dict = {k: v.cpu() for k, v in state_dict.items()}
            print("  All weights moved to CPU")
            return cpu_state_dict, checkpoint
    except Exception as e:
        print(f"❌ Method 1 failed: {e}")
    return None, None


def test_model_creation(state_dict, config=None):
    """Test model creation and weight loading."""
    print("\n" + "=" * 60)
    print("Model creation and weight loading test")
    print("=" * 60)
    try:
        from src.model.model import InputMethodEngine

        # Take the config from the checkpoint, or fall back to defaults
        if config and "config" in config:
            model_config = config["config"]
            print("📋 Using the config stored in the checkpoint:")
            for key, value in model_config.items():
                print(f"  {key}: {value}")
            model = InputMethodEngine(
                vocab_size=model_config.get("vocab_size", 10019),
                pinyin_vocab_size=model_config.get("pinyin_vocab_size", 30),
                dim=model_config.get("dim", 512),
                num_slots=model_config.get("num_slots", 8),
                n_layers=model_config.get("n_layers", 4),
                n_heads=model_config.get("n_heads", 4),
                num_experts=model_config.get("num_experts", 20),
                max_seq_len=model_config.get("max_seq_len", 128),
                compile=False,
            )
        else:
            print("📋 Using the default config")
            model = InputMethodEngine(compile=False)
        print("✅ Model created")
        print(f"  Total parameters: {sum(p.numel() for p in model.parameters()):,}")
        print(f"  Vocab size: {model.vocab_size}")
        # Try to load the weights
        print("\n🔄 Loading weights...")
        try:
            model.load_state_dict(state_dict)
            print("✅ Weights loaded (strict=True)")
        except RuntimeError as e:
            print(f"⚠️ Strict loading failed: {e}")
            print("🔄 Retrying with strict=False...")
            model.load_state_dict(state_dict, strict=False)
            print("✅ Weights loaded (strict=False)")
        # Inspect the classifier head
        print("\n📊 Classifier head check:")
        classifier_weight = model.classifier.weight.data
        classifier_bias = model.classifier.bias.data
        print(f"  Weight shape: {classifier_weight.shape}")
        print(
            f"  Weight range: [{classifier_weight.min():.6f}, {classifier_weight.max():.6f}]"
        )
        print(
            f"  Weight mean: {classifier_weight.mean():.6f} ± {classifier_weight.std():.6f}"
        )
        print(f"  Bias shape: {classifier_bias.shape}")
        print(f"  Bias range: [{classifier_bias.min():.6f}, {classifier_bias.max():.6f}]")
        return model
    except Exception as e:
        print(f"❌ Model creation failed: {e}")
        import traceback

        traceback.print_exc()
        return None


def test_forward_pass(model):
    """Test a forward pass."""
    print("\n" + "=" * 60)
    print("Forward pass test")
    print("=" * 60)
    try:
        model.eval()
        # Build simple test inputs
        batch_size = 2
        seq_len = 64
        # Random but plausible inputs
        input_ids = torch.randint(0, 1000, (batch_size, seq_len), dtype=torch.long)
        token_type_ids = torch.zeros((batch_size, seq_len), dtype=torch.long)
        attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)
        pinyin_ids = torch.randint(1, 30, (batch_size, 24), dtype=torch.long)  # avoid 0
        history_slot_ids = torch.randint(0, 100, (batch_size, 8), dtype=torch.long)
        print("📋 Test inputs:")
        print(f"  input_ids: {input_ids.shape}")
        print(f"  token_type_ids: {token_type_ids.shape}")
        print(f"  attention_mask: {attention_mask.shape}")
        print(f"  pinyin_ids: {pinyin_ids.shape}")
        print(f"  history_slot_ids: {history_slot_ids.shape}")
        # Run the forward pass
        with torch.no_grad():
            logits = model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                pinyin_ids=pinyin_ids,
                history_slot_ids=history_slot_ids,
            )
        print("\n✅ Forward pass succeeded")
        print(f"  Logits shape: {logits.shape}")
        print(f"  Logits range: [{logits.min():.6f}, {logits.max():.6f}]")
        print(f"  Logits mean: {logits.mean():.6f} ± {logits.std():.6f}")
        # Sanity-check the output
        if logits.abs().max() < 1e-6:
            print("⚠️ Warning: logits are extremely small; the weights may not have loaded correctly")
        # Softmax probabilities
        probs = torch.nn.functional.softmax(logits, dim=-1)
        print("\n📊 Probability distribution:")
        print(f"  Probability sum: {probs.sum(dim=-1).mean().item():.6f} (should be 1.0)")
        print(f"  Max probability: {probs.max().item():.6f}")
        print(f"  Min probability: {probs.min().item():.6f}")
        print(
            f"  Mean probability: {probs.mean().item():.6f} (expected: ~{1.0 / logits.size(-1):.6f})"
        )
        # Top-5 predictions
        top_probs, top_indices = torch.topk(probs, k=5, dim=-1)
        print("\n🏆 Top-5 predictions for batch 0:")
        for i in range(5):
            print(
                f"  {i + 1}. ID {top_indices[0, i].item()}: {top_probs[0, i].item():.6f}"
            )
        return True
    except Exception as e:
        print(f"❌ Forward pass failed: {e}")
        import traceback

        traceback.print_exc()
        return False


def test_id_mapping():
    """Test the ID mapping."""
    print("\n" + "=" * 60)
    print("ID mapping test")
    print("=" * 60)
    try:
        from src.model.query import QueryEngine

        query_engine = QueryEngine()
        stats_path = (
            Path(__file__).parent
            / "src"
            / "model"
            / "assets"
            / "pinyin_char_statistics.json"
        )
        if stats_path.exists():
            query_engine.load(stats_path)
            print("✅ Query engine loaded")
            # IDs of some common characters
            test_chars = ["", "", "", "", "", "", ""]
            print("📋 Common character ID mapping:")
            for char in test_chars:
                results = query_engine.query_by_char(char, limit=1)
                if results:
                    char_id, pinyin, count = results[0]
                    print(
                        f"  '{char}' -> ID: {char_id}, pinyin: {pinyin}, count: {count:,}"
                    )
                else:
                    print(f"  '{char}' -> not found")
            # Check the ID range
            all_ids = list(query_engine._id_to_info.keys())
            if all_ids:
                print("\n📊 ID range statistics:")
                print(f"  Min ID: {min(all_ids)}")
                print(f"  Max ID: {max(all_ids)}")
                print(f"  Total IDs: {len(all_ids)}")
                # Check whether the IDs are contiguous
                sorted_ids = sorted(all_ids)
                gaps = []
                for i in range(1, len(sorted_ids)):
                    if sorted_ids[i] - sorted_ids[i - 1] > 1:
                        gaps.append((sorted_ids[i - 1], sorted_ids[i]))
                if gaps:
                    print(f"  ⚠️ Found ID gaps: {len(gaps)}")
                    for i, (prev, curr) in enumerate(gaps[:3]):
                        print(f"    Gap {i + 1}: {prev} -> {curr} (diff: {curr - prev})")
                else:
                    print("  ✅ IDs are essentially contiguous")
        else:
            print(f"❌ Statistics file does not exist: {stats_path}")
    except Exception as e:
        print(f"❌ ID mapping test failed: {e}")


def main():
    if len(sys.argv) < 2:
        print("Usage: python test_cpu_inference.py <checkpoint_path>")
        print("Example: python test_cpu_inference.py ~/下载/best_model.pt")
        return
    checkpoint_path = sys.argv[1]
    # Test device conversion
    state_dict, full_checkpoint = test_device_conversion(checkpoint_path)
    if state_dict is None:
        print("\n❌ Device conversion failed; cannot continue")
        return
    # Test model creation
    model = test_model_creation(state_dict, full_checkpoint)
    if model is None:
        print("\n❌ Model creation failed; cannot continue")
        return
    # Test the forward pass
    success = test_forward_pass(model)
    if success:
        print("\n✅ CPU inference test passed!")
    else:
        print("\n❌ CPU inference test failed")
    # Test the ID mapping
    test_id_mapping()
    print("\n" + "=" * 60)
    print("Testing complete")
    print("=" * 60)


if __name__ == "__main__":
    main()