feat: 添加模型权重检查与推理调试工具脚本
This commit is contained in:
parent
a203e67aff
commit
2f0166c8ce
|
|
@ -0,0 +1,147 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
快速检查模型权重加载情况的脚本
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def analyze_checkpoint(checkpoint_path):
    """Inspect a saved checkpoint: load it on CPU and print structure/statistics.

    Args:
        checkpoint_path: path to a ``torch.save``'d file.

    Returns:
        The extracted state dict on success, otherwise ``None`` (implicitly).
    """
    print(f"🔍 分析checkpoint: {checkpoint_path}")

    if not Path(checkpoint_path).exists():
        print(f"❌ 文件不存在")
        return

    try:
        # Always load to CPU so GPU-trained checkpoints open on any machine.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        print(f"✅ 加载成功")
        print(f" 类型: {type(checkpoint)}")

        if isinstance(checkpoint, dict):
            print(f" 键名: {list(checkpoint.keys())}")

            # Locate the model state dict — different trainers nest it
            # under different keys, or save it directly.
            state_dict = None
            if "model_state_dict" in checkpoint:
                state_dict = checkpoint["model_state_dict"]
                print(f" 🔍 使用'model_state_dict'键")
            elif "state_dict" in checkpoint:
                state_dict = checkpoint["state_dict"]
                print(f" 🔍 使用'state_dict'键")
            else:
                # May be a raw state dict saved without a wrapper.
                state_dict = checkpoint
                print(f" 🔍 使用直接状态字典")

            if state_dict:
                print(f" 总权重数: {len(state_dict)}")

                # Collect classifier-head weights (substring match on key names).
                classifier_keys = []
                for key in state_dict.keys():
                    if "classifier" in key:
                        classifier_keys.append(key)

                if classifier_keys:
                    print(f" 📊 分类头相关权重:")
                    for key in classifier_keys:
                        weight = state_dict[key]
                        print(f" {key}: shape={weight.shape}")
                        print(f" 范围: [{weight.min():.6f}, {weight.max():.6f}]")
                        print(f" 均值: {weight.mean():.6f}")
                        print(f" 标准差: {weight.std():.6f}")

                        # Heuristic: a near-zero spread suggests the head never
                        # moved away from (or collapsed to) its initialization.
                        if weight.std() < 0.01:
                            print(f" ⚠️ 警告: 权重标准差很小,可能未正确训练")

                # Show a sample of parameter names to reveal the architecture.
                print(f"\n 🔑 模型架构键名示例(前20个):")
                for i, key in enumerate(list(state_dict.keys())[:20]):
                    weight = state_dict[key]
                    print(f" {i + 1:2d}. {key:40} shape={str(weight.shape):15}")

            # Check that the expected submodules are present
            # (substring match against every parameter name).
            expected_components = [
                "context_encoder",
                "slot_memory",
                "cross_attn",
                "moe",
                "classifier",
            ]
            found_components = []
            for comp in expected_components:
                found = any(comp in key for key in state_dict.keys())
                if found:
                    found_components.append(comp)

            print(f"\n 📋 找到的模型组件: {found_components}")
            missing = set(expected_components) - set(found_components)
            if missing:
                print(f" ❌ 缺失的组件: {missing}")

            return state_dict
        else:
            print(f"❌ checkpoint不是字典类型")

    except Exception as e:
        print(f"❌ 加载失败: {e}")
        import traceback

        traceback.print_exc()
|
||||
|
||||
|
||||
def check_weight_distribution(state_dict):
    """Summarize the value distribution of every 2-D+ weight tensor.

    Args:
        state_dict: mapping of parameter name -> tensor.

    Returns:
        A list of per-tensor statistic dicts (key/shape/min/max/mean/std/abs_mean).
    """
    print(f"\n📊 权重分布统计:")

    # Only weight matrices (rank >= 2); biases and other 1-D params are skipped.
    weight_stats = [
        {
            "key": name,
            "shape": tensor.shape,
            "min": tensor.min().item(),
            "max": tensor.max().item(),
            "mean": tensor.mean().item(),
            "std": tensor.std().item(),
            "abs_mean": tensor.abs().mean().item(),
        }
        for name, tensor in state_dict.items()
        if "weight" in name and len(tensor.shape) >= 2
    ]

    # Report at most the first ten tensors.
    for idx, entry in enumerate(weight_stats[:10]):
        print(f" {idx + 1:2d}. {entry['key']:40}")
        print(f" 形状: {entry['shape']}")
        print(f" 范围: [{entry['min']:.6f}, {entry['max']:.6f}]")
        print(f" 均值: {entry['mean']:.6f} ± {entry['std']:.6f}")

        # Near-zero spread usually means the layer never moved from init.
        if entry["std"] < 0.01:
            print(f" ⚠️ 警告: 标准差很小,可能未训练")

    return weight_stats
|
||||
|
||||
|
||||
def main():
    """CLI entry point: analyze the checkpoint given as the first argument."""
    import sys

    # Require exactly one positional argument: the checkpoint path.
    if len(sys.argv) < 2:
        print("使用方法: python check_weights.py <checkpoint_path>")
        print("示例: python check_weights.py ./output/checkpoints/best_model.pt")
        return

    state_dict = analyze_checkpoint(sys.argv[1])
    # Only dig into distributions when a state dict was actually extracted.
    if state_dict:
        check_weight_distribution(state_dict)


if __name__ == "__main__":
    main()
|
||||
|
|
@ -0,0 +1,237 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
输入法模型推理调试脚本
|
||||
用于诊断为什么模型预测结果异常
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def debug_model_checkpoint(checkpoint_path: str):
    """Load a checkpoint on CPU and print its structure for debugging.

    Args:
        checkpoint_path: path to a ``torch.save``'d file.

    Returns:
        The extracted state dict, or ``None`` when the file is missing or
        is not a dict.
    """
    print(f"\n🔍 调试checkpoint: {checkpoint_path}")

    if not Path(checkpoint_path).exists():
        print(f"❌ Checkpoint文件不存在: {checkpoint_path}")
        return None

    # Load on CPU so this works regardless of where the model was trained.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    print(f"Checkpoint类型: {type(checkpoint)}")

    if isinstance(checkpoint, dict):
        print(f"Checkpoint键名: {list(checkpoint.keys())}")

        # Unwrap the state dict; trainers differ in which key they use.
        if "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
            print(f"使用'model_state_dict'键,包含{len(state_dict)}个权重")
        elif "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
            print(f"使用'state_dict'键,包含{len(state_dict)}个权重")
        else:
            # Probably a raw state dict saved without a wrapper.
            state_dict = checkpoint
            print(f"使用直接状态字典,包含{len(state_dict)}个权重")

        # Sample of parameter names/shapes for a quick architecture check.
        print("\n前10个权重键名和形状:")
        for i, (key, value) in enumerate(list(state_dict.items())[:10]):
            print(f" {i + 1}. {key}: {value.shape}")

        # Classifier head is the usual suspect when predictions look random.
        classifier_keys = [k for k in state_dict.keys() if "classifier" in k]
        if classifier_keys:
            print(f"\n分类头相关权重:")
            for key in classifier_keys:
                weight = state_dict[key]
                print(
                    f" {key}: shape={weight.shape}, range=[{weight.min():.4f}, {weight.max():.4f}]"
                )
        else:
            print(f"\n⚠️ 未找到分类头权重")

        return state_dict
    else:
        print(f"❌ Checkpoint不是字典类型: {type(checkpoint)}")
        return None
|
||||
|
||||
|
||||
def debug_input_preparation():
    """Exercise the pinyin-ID conversion and the tokenizer on sample inputs."""
    print("\n🔍 调试输入数据准备")

    # Exercise text_to_pinyin_ids on a few pinyin syllables.
    from src.model.dataset import text_to_pinyin_ids

    test_pinyins = ["tian", "shang", "ha", "ni", "hao"]
    for pinyin in test_pinyins:
        ids = text_to_pinyin_ids(pinyin)
        print(f"拼音 '{pinyin}' -> ID列表: {ids}")

    # Exercise the tokenizer shipped under assets/; fall back to an error
    # message only (no alternative tokenizer is tried here).
    from modelscope import AutoTokenizer

    try:
        tokenizer_path = (
            Path(__file__).parent / "src" / "model" / "assets" / "tokenizer"
        )
        tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path))
        print(f"\n✅ Tokenizer加载成功,词汇表大小: {tokenizer.vocab_size}")

        # The third sample uses '|' separators — presumably the same
        # context-prompt format the inference script builds; verify there.
        test_texts = ["今天天气", "我们去公园", "张三|李四|今天天气"]
        for text in test_texts:
            encoded = tokenizer(
                text, max_length=128, truncation=True, return_tensors="pt"
            )
            print(f"文本 '{text}' -> input_ids形状: {encoded['input_ids'].shape}")
    except Exception as e:
        print(f"❌ Tokenizer加载失败: {e}")
|
||||
|
||||
|
||||
def debug_query_engine():
    """Load the QueryEngine statistics and run sample char/pinyin lookups."""
    print("\n🔍 调试查询引擎")

    try:
        from src.model.query import QueryEngine

        query_engine = QueryEngine()
        stats_path = (
            Path(__file__).parent
            / "src"
            / "model"
            / "assets"
            / "pinyin_char_statistics.json"
        )

        if stats_path.exists():
            query_engine.load(stats_path)
            print(f"✅ 查询引擎加载成功")

            # Character -> candidates lookup.
            # NOTE(review): results appear to be (id, char, count)-like
            # 3-tuples — confirm against QueryEngine.
            test_chars = ["天", "上", "人", "好", "的"]
            for char in test_chars:
                results = query_engine.query_by_char(char, limit=3)
                if results:
                    print(f"字符 '{char}': {[(r[0], r[1], r[2]) for r in results[:3]]}")
                else:
                    print(f"字符 '{char}': 无结果")

            # Pinyin -> candidates lookup.
            test_pinyins = ["tian", "shang", "ren", "hao", "de"]
            for pinyin in test_pinyins:
                results = query_engine.query_by_pinyin(pinyin, limit=3)
                if results:
                    print(
                        f"拼音 '{pinyin}': {[(r[0], r[1], r[2]) for r in results[:3]]}"
                    )
                else:
                    print(f"拼音 '{pinyin}': 无结果")
        else:
            print(f"❌ 统计文件不存在: {stats_path}")
    except Exception as e:
        print(f"❌ 查询引擎调试失败: {e}")
        import traceback

        traceback.print_exc()
|
||||
|
||||
|
||||
def create_minimal_test_model():
    """Build a tiny InputMethodEngine and run one forward pass as a smoke test.

    Returns:
        The constructed model on success, otherwise ``None``.
    """
    print("\n🔍 创建最小测试模型")

    try:
        from src.model.model import InputMethodEngine

        # A deliberately tiny configuration so construction and the
        # forward pass are fast and fit on CPU.
        model = InputMethodEngine(
            vocab_size=100,
            pinyin_vocab_size=30,
            dim=64,  # small width for a quick test
            num_slots=8,
            n_layers=2,
            n_heads=2,
            num_experts=4,
            max_seq_len=64,
            compile=False,
        )

        print(f"✅ 最小模型创建成功")
        print(f" 总参数量: {sum(p.numel() for p in model.parameters()):,}")
        print(f" 分类头形状: {model.classifier.weight.shape}")

        # Random-but-valid inputs; shapes mirror the real pipeline
        # (seq len 64, 24 pinyin slots, 8 history slots).
        batch_size = 2
        input_ids = torch.randint(0, 100, (batch_size, 64))
        token_type_ids = torch.zeros((batch_size, 64), dtype=torch.long)
        attention_mask = torch.ones((batch_size, 64), dtype=torch.long)
        pinyin_ids = torch.randint(0, 30, (batch_size, 24))
        history_slot_ids = torch.randint(0, 100, (batch_size, 8))

        with torch.no_grad():
            logits = model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                pinyin_ids=pinyin_ids,
                history_slot_ids=history_slot_ids,
            )

        print(f" 前向传播成功,logits形状: {logits.shape}")
        print(f" logits范围: [{logits.min():.4f}, {logits.max():.4f}]")

        # Sanity check: softmax should produce a valid distribution.
        probs = F.softmax(logits, dim=-1)
        print(f" 概率范围: [{probs.min():.4f}, {probs.max():.4f}]")

        return model

    except Exception as e:
        print(f"❌ 最小模型测试失败: {e}")
        import traceback

        traceback.print_exc()
        return None
|
||||
|
||||
|
||||
def main():
    """Entry point: run every debug check, optionally against a checkpoint."""
    import argparse

    parser = argparse.ArgumentParser(description="输入法模型推理调试")
    parser.add_argument("--checkpoint", type=str, help="模型checkpoint路径")
    args = parser.parse_args()

    banner = "=" * 60
    print(banner)
    print("输入法模型推理调试工具")
    print(banner)

    # Checkpoint inspection runs only when a path was supplied.
    state_dict = None
    if args.checkpoint:
        state_dict = debug_model_checkpoint(args.checkpoint)
    else:
        print("\n⚠️ 未提供checkpoint路径,跳过checkpoint调试")

    # The remaining checks are independent of the checkpoint.
    debug_input_preparation()
    debug_query_engine()
    model = create_minimal_test_model()

    print("\n" + banner)
    print("调试完成")
    print(banner)


if __name__ == "__main__":
    main()
|
||||
|
|
@ -0,0 +1,602 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
输入法模型推理脚本
|
||||
|
||||
使用方法:
|
||||
python inference.py --checkpoint ./output/checkpoints/best_model.pt
|
||||
|
||||
交互模式: 分步询问输入
|
||||
1. 上下文提示: 模型不掌握的专有词汇、姓名等(可为空)
|
||||
2. 光标前文本: 光标前的连续文本
|
||||
3. 光标后文本: 光标后的连续文本
|
||||
4. 拼音: 当前输入的拼音
|
||||
5. 槽位历史: 用户已确认的输入历史(如输入shanghai已确认"上")
|
||||
|
||||
示例场景:
|
||||
输入"shanghai"已确认"上",继续输入"tian"
|
||||
上下文提示: 张三,李四
|
||||
光标前文本: 今天天气很好
|
||||
光标后文本: 我们去公园玩
|
||||
拼音: tian
|
||||
槽位历史: 上
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from modelscope import AutoTokenizer
|
||||
|
||||
from src.model.dataset import text_to_pinyin_ids
|
||||
from src.model.model import InputMethodEngine
|
||||
from src.model.query import QueryEngine
|
||||
|
||||
|
||||
class InputMethodInference:
    """Interactive inference wrapper for the input-method model.

    Loads the model checkpoint, the tokenizer, and the pinyin/char
    QueryEngine, and exposes ``predict`` plus a step-by-step
    ``interactive_mode`` REPL.
    """

    def __init__(
        self,
        checkpoint_path: str,
        # NOTE(review): this default is evaluated once at import time;
        # pass device explicitly if CUDA availability can change.
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        self.device = torch.device(device)
        self.checkpoint_path = checkpoint_path

        # Load the model weights.
        print(f"正在加载模型从: {checkpoint_path}")
        self.load_model()

        # Load the tokenizer.
        print("正在加载tokenizer...")
        self.load_tokenizer()

        # Load the query engine (char <-> id mapping).
        print("正在加载查询引擎...")
        self.load_query_engine()

        print(f"✅ 推理器初始化完成 (设备: {self.device})")

    def load_model(self):
        """Instantiate the engine and load trained weights from the checkpoint."""
        # Build the model uncompiled so state-dict keys match plain modules.
        self.model = InputMethodEngine(pinyin_vocab_size=30, compile=False)

        # Load weights to CPU first, then move to the target device,
        # so GPU-trained checkpoints also load on CPU-only machines.
        checkpoint = torch.load(self.checkpoint_path, map_location="cpu")
        if "model_state_dict" in checkpoint:
            self.model.load_state_dict(checkpoint["model_state_dict"])
        else:
            self.model.load_state_dict(checkpoint)

        self.model.eval()
        self.model.to(self.device)
        print(
            f"✅ 模型加载完成,参数量: {sum(p.numel() for p in self.model.parameters()):,}"
        )

    def load_tokenizer(self):
        """Load the project tokenizer, falling back to bert-base-chinese."""
        try:
            # Prefer the tokenizer bundled under assets/tokenizer.
            tokenizer_path = (
                Path(__file__).parent / "src" / "model" / "assets" / "tokenizer"
            )
            self.tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path))
            print(f"✅ Tokenizer加载完成,词汇表大小: {self.tokenizer.vocab_size}")
        except Exception as e:
            print(f"⚠️ 无法加载tokenizer: {e}")
            print("使用默认的bert-base-chinese tokenizer")
            self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")

    def load_query_engine(self):
        """Load the QueryEngine used for char <-> id conversion (may end up None)."""
        try:
            self.query_engine = QueryEngine()
            # Default statistics file shipped with the package.
            stats_path = (
                Path(__file__).parent
                / "src"
                / "model"
                / "assets"
                / "pinyin_char_statistics.json"
            )
            if stats_path.exists():
                self.query_engine.load(stats_path)
                print(
                    f"✅ 查询引擎加载完成,字符对数量: {len(self.query_engine._id_to_info)}"
                )
            else:
                print(f"⚠️ 统计文件不存在: {stats_path}")
                self.query_engine = None
        except Exception as e:
            print(f"⚠️ 无法加载查询引擎: {e}")
            self.query_engine = None

    def char_to_id(self, char: str, pinyin: Optional[str] = None) -> int:
        """Map a Chinese character to its model id; a pinyin hint disambiguates.

        Returns 0 when no mapping can be found (0 doubles as the end-marker id).
        """
        # End-of-input marker.
        if char == "//":
            return 0  # NOTE(review): assumes id 0 is the end marker — confirm

        if self.query_engine is None:
            # Crude fallback: the character's Unicode code point.
            return ord(char) if len(char) == 1 else 0

        try:
            if pinyin is not None:
                # Exact (char, pinyin) pair lookup when pinyin is known.
                info = self.query_engine.get_char_info_by_char_pinyin(char, pinyin)
                if info:
                    return info.id

            # Fallback: first pinyin variant of the character.
            results = self.query_engine.query_by_char(char, limit=1)
            if results:
                return results[0][0]  # first field is the id

            return 0
        except:
            # NOTE(review): bare except swallows all errors, incl. typos —
            # consider narrowing.
            return 0

    def id_to_char(self, id: int) -> str:
        """Map a model id back to its character; '//' for the end marker."""
        # End-of-input marker id (assumed to be 0).
        if id == 0:
            return "//"

        if self.query_engine is None:
            # Crude fallback: treat the id as a Unicode code point.
            return chr(id) if id < 0x110000 else "<?>"

        try:
            info = self.query_engine.query_by_id(id)
            return info.char if info else f"[ID:{id}]"
        except:
            # NOTE(review): bare except — same caveat as char_to_id.
            return f"[ID:{id}]"

    def prepare_inputs(
        self,
        context_prompts: List[str],
        text_before: str,
        text_after: str,
        pinyin: str,
        slot_chars: List[str],
        max_seq_len: int = 128,
    ):
        """Build the model input dict from raw user-facing fields.

        Args:
            context_prompts: extra vocabulary hints (proper nouns, names),
                joined with '|'.
            text_before: text before the cursor.
            text_after: text after the cursor.
            pinyin: the pinyin currently being typed.
            slot_chars: characters already confirmed by the user (history).
            max_seq_len: tokenizer max length / padding length.

        Returns:
            Dict of tensors on ``self.device`` ready for ``self.model(**inputs)``.
        """

        # 1. Build the tokenizer input. Per dataset.py the format is
        #    "part4|part1" paired with part3, where
        #    part4 = context prompts, part1 = text_before, part3 = text_after.

        # Join the context prompts with '|'.
        context_text = "|".join(context_prompts) if context_prompts else ""

        # Prefix the prompts to the before-cursor text.
        if context_text:
            input_text = f"{context_text}|{text_before}"
        else:
            input_text = text_before

        # 2. Tokenize as a sentence pair (before-text vs after-text).
        encoded = self.tokenizer(
            input_text,
            text_after,
            max_length=max_seq_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            return_token_type_ids=True,
        )

        # 3. Pinyin ids, padded/truncated to a fixed length of 24.
        pinyin_ids = text_to_pinyin_ids(pinyin)
        if len(pinyin_ids) < 24:
            pinyin_ids.extend([0] * (24 - len(pinyin_ids)))
        else:
            pinyin_ids = pinyin_ids[:24]
        pinyin_tensor = torch.tensor([pinyin_ids], dtype=torch.long)

        # 4. History slots: ids of already-confirmed characters.
        history_slot_ids = []
        for char in slot_chars:
            # Look up each confirmed character's id.
            char_id = self.char_to_id(char)
            history_slot_ids.append(char_id)

        # Pad/truncate to exactly 8 slots.
        if len(history_slot_ids) < 8:
            history_slot_ids.extend([0] * (8 - len(history_slot_ids)))
        else:
            history_slot_ids = history_slot_ids[:8]

        history_tensor = torch.tensor([history_slot_ids], dtype=torch.long)

        # 5. Move everything to the inference device.
        inputs = {
            "input_ids": encoded["input_ids"].to(self.device),
            "token_type_ids": encoded["token_type_ids"].to(self.device),
            "attention_mask": encoded["attention_mask"].to(self.device),
            "pinyin_ids": pinyin_tensor.to(self.device),
            "history_slot_ids": history_tensor.to(self.device),
        }

        return inputs

    def predict(
        self,
        context_prompts: List[str],
        text_before: str,
        text_after: str,
        pinyin: str,
        slot_chars: List[str],
        top_k: int = 20,
    ) -> Tuple[List[Tuple[str, float, int]], float]:
        """Run one forward pass and return the top-k character candidates.

        Args:
            context_prompts: vocabulary hints (proper nouns, names).
            text_before: text before the cursor.
            text_after: text after the cursor.
            pinyin: the pinyin currently being typed.
            slot_chars: confirmed input history (max 8 characters).
            top_k: number of candidates to return.

        Returns:
            (predictions, inference_time_ms) where predictions is a list of
            (char, probability, id) tuples. Timing includes the verbose
            debug printing below, so it overstates pure model latency.
        """
        start_time = time.perf_counter()

        # Build model inputs from the raw fields.
        inputs = self.prepare_inputs(
            context_prompts, text_before, text_after, pinyin, slot_chars
        )

        # Debug: shapes/dtypes/devices of every input tensor.
        print("\n🔍 调试信息 - 输入检查:")
        for key, tensor in inputs.items():
            print(
                f" {key}: shape={tensor.shape}, dtype={tensor.dtype}, device={tensor.device}"
            )
            if key in ["history_slot_ids", "pinyin_ids"]:
                print(f" 值: {tensor.cpu().numpy().tolist()}")

        # Debug: pinyin input in detail.
        print("\n🔍 拼音输入详细分析:")
        pinyin_tensor = inputs["pinyin_ids"]
        pinyin_values = pinyin_tensor.cpu().numpy()[0]
        # Imported here but unused — a reverse id->pinyin mapping would be
        # needed to decode; only the raw ids are shown below.
        from src.model.dataset import text_to_pinyin_ids

        print(f" 拼音ID列表: {pinyin_values}")
        print(f" 拼音非零ID: {[id for id in pinyin_values if id != 0]}")

        # Debug: history slots in detail.
        print("\n🔍 槽位历史详细分析:")
        slot_tensor = inputs["history_slot_ids"]
        slot_values = slot_tensor.cpu().numpy()[0]
        print(f" 槽位ID列表: {slot_values}")
        print(f" 槽位非零ID: {[id for id in slot_values if id != 0]}")
        # Round-trip the ids back to characters as a sanity check.
        slot_chars_converted = [self.id_to_char(id) for id in slot_values]
        print(f" 槽位汉字: {slot_chars_converted}")

        # Debug: model state.
        print(f"\n🔍 调试信息 - 模型检查:")
        print(f" 模型设备: {self.device}")
        print(f" 模型是否训练模式: {self.model.training}")

        # Model vocabulary size.
        vocab_size = self.model.vocab_size
        print(f" 模型词汇表大小: {vocab_size}")

        # Show what the first few ids decode to.
        print(f" 前5个ID对应的字符:")
        for i in range(1, 6):
            char = self.id_to_char(i)
            print(f" ID {i}: '{char}'")

        # Debug: classifier-head weight statistics (random-looking stats
        # here usually mean the checkpoint didn't load into this head).
        classifier_weight = self.model.classifier.weight.data
        classifier_bias = self.model.classifier.bias.data
        print(f" 分类头权重形状: {classifier_weight.shape}")
        print(
            f" 分类头权重范围: [{classifier_weight.min():.4f}, {classifier_weight.max():.4f}]"
        )
        print(f" 分类头权重均值: {classifier_weight.mean():.4f}")
        print(f" 分类头权重标准差: {classifier_weight.std():.4f}")
        print(f" 分类头偏置形状: {classifier_bias.shape}")
        print(
            f" 分类头偏置范围: [{classifier_bias.min():.4f}, {classifier_bias.max():.4f}]"
        )
        print(f" 分类头偏置均值: {classifier_bias.mean():.4f}")
        print(f" 分类头偏置标准差: {classifier_bias.std():.4f}")

        # Inference; autocast only on CUDA (mixed precision is disabled on CPU).
        with torch.no_grad():
            if self.device.type == "cuda":
                with torch.autocast(device_type="cuda"):
                    logits = self.model(**inputs)
            else:
                # Plain FP32 forward pass on CPU.
                logits = self.model(**inputs)

        # Debug: raw logits.
        print(f"\n🔍 调试信息 - 输出检查:")
        print(f" Logits形状: {logits.shape}")
        print(f" Logits范围: [{logits.min():.4f}, {logits.max():.4f}]")
        print(f" Logits均值: {logits.mean():.4f}")

        # Argmax of the logits.
        max_val, max_idx = torch.max(logits, dim=-1)
        print(f" 最大logit值: {max_val.item():.4f}, 对应ID: {max_idx.item()}")

        # Top-k candidates over the softmax distribution.
        probs = F.softmax(logits, dim=-1)
        top_probs, top_indices = torch.topk(probs, k=top_k, dim=-1)

        # Debug: probability distribution.
        print(f" 概率总和: {probs.sum().item():.4f}")
        top_probs_array = top_probs.cpu().numpy().flatten()
        top_indices_array = top_indices.cpu().numpy().flatten()
        print(f" Top-{top_k}概率: {top_probs_array}")
        print(f" Top-{top_k} ID: {top_indices_array}")

        # Is the distribution suspiciously close to uniform?
        print(f" 概率分布分析:")
        print(f" 平均概率: {probs.mean().item():.6f}")
        print(f" 最大概率: {probs.max().item():.6f}")
        print(f" 最小概率: {probs.min().item():.6f}")
        print(f" 标准差: {probs.std().item():.6f}")

        # Warn when even the best candidate is nearly uniform-probability.
        if top_probs_array[0] < 0.01:
            print(
                f" ⚠️ 警告: 最高概率 ({top_probs_array[0]:.6f}) 小于0.01,模型可能未正确训练"
            )
            print(f" 💡 可能原因: 1) 权重未正确加载 2) 输入格式错误 3) 模型配置不匹配")

        inference_time_ms = (time.perf_counter() - start_time) * 1000

        # Convert ids to characters for the caller.
        predictions = []
        for i in range(top_k):
            idx = int(top_indices[0, i].item())
            prob = top_probs[0, i].item()
            char = self.id_to_char(idx)
            predictions.append((char, prob, idx))

        return predictions, inference_time_ms

    def interactive_mode(self):
        """Step-by-step interactive REPL: prompt for each field, then predict."""
        print("\n" + "=" * 60)
        print("输入法模型推理 - 交互模式")
        print("=" * 60)
        print("说明:")
        print(" - 上下文提示: 模型不掌握的专有词汇、姓名等(可为空)")
        print(" - 光标前文本: 光标前的连续文本")
        print(" - 光标后文本: 光标后的连续文本")
        print(" - 拼音: 当前输入的拼音")
        print(" - 槽位历史: 用户已确认的输入历史,如输入'shanghai'已确认'上'")
        print("提示: 输入 'quit' 或 'exit' 或 'q' 可随时退出")
        print("-" * 60)

        while True:
            try:
                # Step 1: context prompts (proper nouns, names; optional).
                print("\n" + "=" * 60)
                print("第1步: 上下文提示(模型不掌握的专有词汇、姓名等)")
                print("格式: 用逗号分隔多个词汇,可为空")
                print("示例: 张三,李四,北京大学")
                context_input = input("请输入上下文提示(直接回车跳过): ").strip()

                if context_input.lower() in ["quit", "exit", "q"]:
                    print("退出交互模式")
                    break

                # Split on ',' and drop empty items.
                context_prompts = []
                if context_input:
                    context_prompts = [
                        item.strip()
                        for item in context_input.split(",")
                        if item.strip()
                    ]

                print(
                    f"✅ 已记录上下文提示: {context_prompts if context_prompts else '无'}"
                )

                # Step 2: text before the cursor.
                print("\n" + "-" * 40)
                print("第2步: 光标前文本")
                print("说明: 光标前的连续文本内容")
                print("示例: 今天天气很好")
                text_before = input("请输入光标前文本: ").strip()

                if text_before.lower() in ["quit", "exit", "q"]:
                    print("退出交互模式")
                    break

                print(f"✅ 已记录光标前文本: '{text_before}'")

                # Step 3: text after the cursor.
                print("\n" + "-" * 40)
                print("第3步: 光标后文本")
                print("说明: 光标后的连续文本内容")
                print("示例: 我们去公园玩")
                text_after = input("请输入光标后文本: ").strip()

                if text_after.lower() in ["quit", "exit", "q"]:
                    print("退出交互模式")
                    break

                print(f"✅ 已记录光标后文本: '{text_after}'")

                # Step 4: current pinyin.
                print("\n" + "-" * 40)
                print("第4步: 拼音输入")
                print("说明: 当前正在输入的拼音")
                print("示例: tian, shang, hao")
                pinyin = input("请输入拼音: ").strip()

                if pinyin.lower() in ["quit", "exit", "q"]:
                    print("退出交互模式")
                    break

                print(f"✅ 已记录拼音: '{pinyin}'")

                # Step 5: confirmed input history (slot characters).
                print("\n" + "-" * 40)
                print("第5步: 槽位历史(已确认的输入)")
                print("说明: 用户已确认的输入历史,用逗号分隔")
                print("示例: 上 (表示输入'shanghai'已确认'上')")
                print(" 今天,天气 (表示已确认两个词)")
                slot_input = input("请输入槽位历史(直接回车表示无): ").strip()

                if slot_input.lower() in ["quit", "exit", "q"]:
                    print("退出交互模式")
                    break

                # Split on ',' and drop empty items.
                slot_chars = []
                if slot_input:
                    slot_chars = [
                        char.strip() for char in slot_input.split(",") if char.strip()
                    ]

                print(f"✅ 已记录槽位历史: {slot_chars if slot_chars else '无'}")

                # Echo everything back before running the model.
                print("\n" + "=" * 60)
                print("📝 输入汇总:")
                print(f" 上下文提示: {context_prompts if context_prompts else '无'}")
                print(f" 光标前文本: '{text_before}'")
                print(f" 光标后文本: '{text_after}'")
                print(f" 拼音: '{pinyin}'")
                print(f" 槽位历史: {slot_chars if slot_chars else '无'}")

                # Run inference.
                print("\n🔮 推理中...")
                predictions, inference_time = self.predict(
                    context_prompts, text_before, text_after, pinyin, slot_chars
                )

                # Show the top-k predictions.
                print(f"\n✅ 推理完成 (耗时: {inference_time:.2f}ms)")
                print("\n🏆 Top-20 预测结果:")
                print("-" * 50)
                for i, (char, prob, idx) in enumerate(predictions):
                    if char == "//":
                        print(f"{i + 1:2d}. {'//':<4} (结束符) - 概率: {prob:.4f}")
                    else:
                        print(
                            f"{i + 1:2d}. {char:<4} (ID: {idx:>5}) - 概率: {prob:.4f}"
                        )

                # For comparison: frequent characters for this pinyin
                # straight from the statistics table.
                if pinyin and self.query_engine:
                    print(f"\n📖 拼音 '{pinyin}' 的常见汉字:")
                    pinyin_results = self.query_engine.query_by_pinyin(pinyin, limit=10)
                    if pinyin_results:
                        for j, (pid, char, count) in enumerate(pinyin_results):
                            print(f" {char} (频次: {count:,})")
                    else:
                        print(" (无匹配结果)")

                # Continue or quit (Enter defaults to continue).
                print("\n" + "-" * 40)
                continue_input = input("是否继续推理?(y/n): ").strip().lower()
                if continue_input not in ["y", "yes", ""]:
                    print("退出交互模式")
                    break

            except KeyboardInterrupt:
                print("\n\n退出交互模式")
                break
            except Exception as e:
                print(f"\n❌ 推理出错: {e}")
                import traceback

                traceback.print_exc()

                # Offer to continue after an error as well.
                continue_input = input("\n是否继续?(y/n): ").strip().lower()
                if continue_input not in ["y", "yes", ""]:
                    print("退出交互模式")
                    break
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse args, build the inference wrapper, run modes."""
    parser = argparse.ArgumentParser(description="输入法模型推理")
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="模型checkpoint路径"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        choices=["auto", "cpu", "cuda"],
        help="推理设备 (默认: auto)",
    )
    # NOTE(review): store_true with default=True means this flag can never
    # be turned off from the command line — interactive mode always runs.
    parser.add_argument(
        "--interactive", action="store_true", default=True, help="交互模式 (默认: True)"
    )
    parser.add_argument("--test", action="store_true", help="运行测试推理")

    args = parser.parse_args()

    # Resolve "auto" to whatever hardware is available.
    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device

    # Build the inference wrapper (loads model, tokenizer, query engine).
    inference = InputMethodInference(args.checkpoint, device)

    # Optional canned test scenario before entering interactive mode.
    if args.test:
        print("\n🧪 运行测试推理...")
        print("测试场景: 输入'shanghai',已确认第一个字'上',继续输入'tian'")
        print("上下文提示: 张三、李四(模型不掌握的专有名词)")
        test_predictions, test_time = inference.predict(
            context_prompts=["张三", "李四"],
            text_before="今天天气",
            text_after="很好",
            pinyin="tian",
            slot_chars=["上"],  # the user has already confirmed "上"
        )
        print(f"测试推理耗时: {test_time:.2f}ms")
        print(f"Top-5 结果:")
        for i, (char, prob, idx) in enumerate(test_predictions[:5]):
            if char == "//":
                print(f" {i + 1}. // (结束符) - 概率: {prob:.4f}")
            else:
                print(f" {i + 1}. {char} (ID: {idx}) - 概率: {prob:.4f}")

    # Interactive REPL (currently always on — see the flag note above).
    if args.interactive:
        inference.interactive_mode()


if __name__ == "__main__":
    main()
|
||||
|
|
@ -311,35 +311,54 @@ class MoELayer(nn.Module):
|
|||
out: [batch, seq_len, dim]
|
||||
"""
|
||||
B, L, D = x.shape
|
||||
num_tokens = B * L
|
||||
|
||||
# 展平输入以便处理
|
||||
x_flat = x.view(num_tokens, D) # [B*L, D]
|
||||
# 1. Compute Gating Scores
|
||||
gates = self.gate(x) # [B, L, num_experts]
|
||||
|
||||
# 1. 计算门控分数
|
||||
gates = self.gate(x_flat) # [B*L, num_experts]
|
||||
# 2. Select Top-K Experts
|
||||
topk_vals, topk_indices = torch.topk(gates, self.top_k, dim=-1) # [B, L, K]
|
||||
|
||||
# 2. 选择 Top-K 专家
|
||||
topk_weights, topk_indices = torch.topk(gates, self.top_k, dim=-1) # [B*L, K]
|
||||
# Normalize weights for selected experts
|
||||
weights = F.softmax(topk_vals, dim=-1) # [B, L, K]
|
||||
|
||||
# 归一化权重
|
||||
topk_weights = F.softmax(topk_weights, dim=-1) # [B*L, K]
|
||||
# 3. Dispatch and Compute
|
||||
# Initialize output
|
||||
out = torch.zeros_like(x)
|
||||
|
||||
# 3. 并行计算所有专家(消除 Python 循环中的动态控制流)
|
||||
# torch.compile 会展开此列表推导式,因为 num_experts 是编译时常量
|
||||
expert_outputs = torch.stack(
|
||||
[expert(x_flat) for expert in self.experts], dim=1
|
||||
) # [B*L, num_experts, D]
|
||||
# Reshape for easier processing: flatten batch and sequence dimensions
|
||||
x_flat = x.view(-1, D) # [B*L, D]
|
||||
weights_flat = weights.view(-1, self.top_k) # [B*L, K]
|
||||
topk_indices_flat = topk_indices.view(-1, self.top_k) # [B*L, K]
|
||||
|
||||
# 4. 使用 gather 选择对应专家的输出
|
||||
# 扩展索引以匹配 expert_outputs 的维度 [B*L, num_experts, D]
|
||||
indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, D) # [B*L, K, D]
|
||||
selected_outputs = torch.gather(
|
||||
expert_outputs, 1, indices_expanded
|
||||
) # [B*L, K, D]
|
||||
# 5. 加权求和
|
||||
weighted_outputs = selected_outputs * topk_weights.unsqueeze(-1) # [B*L, K, D]
|
||||
out_flat = weighted_outputs.sum(dim=1) # [B*L, D]
|
||||
# For each of the top-k positions
|
||||
for k in range(self.top_k):
|
||||
# Get expert indices and weights for this position
|
||||
expert_indices = topk_indices_flat[:, k] # [B*L]
|
||||
expert_weights = weights_flat[:, k].unsqueeze(-1) # [B*L, 1]
|
||||
|
||||
# 恢复原始形状
|
||||
return out_flat.view(B, L, D)
|
||||
# Process each expert separately
|
||||
for e_idx in range(self.num_experts):
|
||||
# Mask for tokens assigned to this expert at position k
|
||||
mask = expert_indices == e_idx # [B*L]
|
||||
if not mask.any():
|
||||
continue
|
||||
|
||||
# Extract tokens for this expert
|
||||
x_selected = x_flat[mask] # [N_selected, D]
|
||||
if x_selected.numel() == 0:
|
||||
continue
|
||||
|
||||
# Pass through expert
|
||||
expert_out = self.experts[e_idx](x_selected) # [N_selected, D]
|
||||
|
||||
# Apply expert weights and add to output
|
||||
weighted_out = expert_out * expert_weights[mask]
|
||||
|
||||
# Scatter back to flat output
|
||||
out_flat = out.view(-1, D)
|
||||
out_flat[mask] += weighted_out
|
||||
|
||||
# Reshape back to original shape
|
||||
out = out.view(B, L, D)
|
||||
|
||||
return out
|
||||
|
|
|
|||
|
|
@ -1,12 +1,16 @@
|
|||
import http.server
|
||||
import json
|
||||
import os
|
||||
import socketserver
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import webbrowser
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import Callable, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import typer
|
||||
|
||||
|
|
@ -335,6 +339,215 @@ def check_status(
|
|||
raise typer.Exit(code=1)
|
||||
|
||||
|
||||
def create_http_handler(status_file_path: str, enable_cors: bool = True):
    """Build a SimpleHTTPRequestHandler subclass serving the training-status JSON file.

    Args:
        status_file_path: Absolute path of the status JSON file to expose.
        enable_cors: When True, permissive CORS headers are added to responses.

    Returns:
        A handler class suitable for passing to ``socketserver.TCPServer``.
    """

    class TrainingStatusHTTPHandler(http.server.SimpleHTTPRequestHandler):
        def do_GET(self):
            # Only the root path and the status file itself are served.
            parsed_path = urlparse(self.path)
            if parsed_path.path not in ["/", "/training_status.json"]:
                self.send_error(404, "File not found")
                return

            try:
                # The writer uses "<file>.tmp" + atomic rename, so the temp file
                # may be the only one present while a write is in flight.
                main_file = status_file_path
                temp_file = f"{status_file_path}.tmp"

                # Read with a short retry loop to ride out concurrent writes.
                max_retries = 3
                retry_delay = 0.1  # 100 ms between attempts
                content = None
                json_valid = False

                for attempt in range(max_retries):
                    try:
                        if not os.path.exists(main_file):
                            if os.path.exists(temp_file):
                                # Only the temp file exists: a write is in progress.
                                file_to_read = temp_file
                            else:
                                # Neither file exists yet.
                                break
                        else:
                            file_to_read = main_file

                        with open(file_to_read, "r", encoding="utf-8") as f:
                            content = f.read()

                        if content:
                            json.loads(content)  # validate before serving
                            json_valid = True
                            break
                        else:
                            # Empty file: the writer may not have flushed yet.
                            time.sleep(retry_delay)

                    except (json.JSONDecodeError, IOError) as e:
                        # Partial write or transient IO error: pause and retry.
                        if attempt < max_retries - 1:
                            time.sleep(retry_delay)
                        else:
                            raise e

                if not content or not json_valid:
                    self.send_error(404, "Status file not found or invalid JSON")
                    return

                # BUGFIX: Content-Length must be the byte length of the encoded
                # body, not the character count — the file is UTF-8 Chinese text,
                # so len(content) undercounts and truncates the response.
                body = content.encode("utf-8")

                self.send_response(200)
                self.send_header("Content-type", "application/json")
                self.send_header("Content-Length", str(len(body)))

                # Disable caching so pollers always see fresh data.
                self.send_header("Cache-Control", "no-cache, no-store, must-revalidate")
                self.send_header("Pragma", "no-cache")
                self.send_header("Expires", "0")

                if enable_cors:
                    self.send_header("Access-Control-Allow-Origin", "*")
                    self.send_header("Access-Control-Allow-Methods", "GET, OPTIONS")
                    self.send_header("Access-Control-Allow-Headers", "Content-Type")

                self.end_headers()

                self.wfile.write(body)

            except Exception as e:
                self.send_error(500, f"Internal server error: {str(e)}")

        def do_OPTIONS(self):
            """Answer CORS preflight requests."""
            self.send_response(200)
            self.send_header("Access-Control-Allow-Origin", "*")
            self.send_header("Access-Control-Allow-Methods", "GET, OPTIONS")
            self.send_header("Access-Control-Allow-Headers", "Content-Type")
            self.end_headers()

        def log_message(self, format, *args):
            """Silence per-request logging to keep console output clean."""
            # Selectively re-enable if needed:
            # typer.echo(f"HTTP Server: {format % args}")
            pass

    return TrainingStatusHTTPHandler
|
||||
|
||||
|
||||
def start_http_server(
    status_file: str,
    port: int,
    host: str,
    enable_cors: bool = True,
) -> Callable:
    """
    Start a background HTTP server that serves the training-status JSON file.

    Args:
        status_file: Path of the status file to serve.
        port: TCP port to listen on.
        host: Host/interface address to bind.
        enable_cors: Whether to add permissive CORS headers.

    Returns:
        A zero-argument function that shuts the server down.
    """
    # Resolve to an absolute path so the handler is independent of the CWD.
    status_file_path = os.path.abspath(status_file)

    # Build the request-handler class bound to this file path.
    handler = create_http_handler(status_file_path, enable_cors)

    # Create the TCP server on the requested interface/port.
    server = socketserver.TCPServer((host, port), handler)

    # Serve from a daemon thread so the server dies with the main process.
    server_thread = threading.Thread(target=server.serve_forever)
    server_thread.daemon = True
    server_thread.start()

    typer.echo(f"🌐 HTTP服务器已启动")
    typer.echo(f" 📁 状态文件: {status_file_path}")
    typer.echo(f" 🔗 访问地址: http://{host}:{port}/training_status.json")
    typer.echo(f" 🌍 CORS支持: {'已启用' if enable_cors else '已禁用'}")
    typer.echo("\n按 Ctrl+C 停止服务器\n")

    # Closure that performs an orderly shutdown of the background server.
    def stop_server():
        typer.echo("\n🛑 正在停止HTTP服务器...")
        server.shutdown()
        server.server_close()
        typer.echo("✅ HTTP服务器已停止")

    return stop_server
|
||||
|
||||
|
||||
@app.command(name="serve")
def serve_status_file(
    status_file: str = typer.Option(
        "./output/training_status.json",
        "--status-file",
        "-s",
        help="训练状态JSON文件路径",
    ),
    port: int = typer.Option(8080, "--port", "-p", help="HTTP服务端口号"),
    host: str = typer.Option("0.0.0.0", "--host", help="HTTP服务主机地址"),
    cors: bool = typer.Option(True, "--cors", help="启用CORS支持"),
):
    """
    Serve the training-status JSON file over HTTP.

    Once running, the data is reachable at
    http://<host>:<port>/training_status.json
    """
    # If the status file is missing, create an empty one so the server can be
    # started before training begins.
    if not os.path.exists(status_file):
        typer.echo(f"⚠️ 警告: 状态文件不存在: {status_file}")
        typer.echo("开始训练后,训练脚本会自动创建此文件。")
        typer.echo("您可以先启动HTTP服务,然后开始训练。")

        # BUGFIX: os.makedirs("") raises FileNotFoundError when the path has
        # no directory component (e.g. a bare "training_status.json"); only
        # create the directory when one is actually present.
        status_dir = os.path.dirname(status_file)
        if status_dir:
            os.makedirs(status_dir, exist_ok=True)

        # Seed the file with an empty JSON list so readers get valid JSON.
        with open(status_file, "w", encoding="utf-8") as f:
            json.dump([], f)
        typer.echo(f"✅ 已创建空状态文件: {status_file}")

    try:
        # Start the background HTTP server.
        stop_server = start_http_server(
            status_file=status_file,
            port=port,
            host=host,
            enable_cors=cors,
        )

        # Block the main thread until Ctrl+C, then shut down cleanly.
        try:
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            stop_server()

    except OSError as e:
        # Give a targeted hint for the common "port already taken" failure.
        if "Address already in use" in str(e):
            typer.echo(f"❌ 错误: 端口 {port} 已被占用")
            typer.echo("请使用其他端口: --port <端口号>")
        else:
            typer.echo(f"❌ 启动HTTP服务器时出错: {e}")
        raise typer.Exit(code=1)
    except Exception as e:
        typer.echo(f"❌ 启动HTTP服务器时出错: {e}")
        raise typer.Exit(code=1)
|
||||
|
||||
|
||||
def main():
    """CLI entry point: dispatch to the typer application."""
    app()
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import json
|
|||
import math
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
|
@ -490,9 +491,13 @@ class Trainer:
|
|||
[self.training_status_data] if self.training_status_data else []
|
||||
)
|
||||
|
||||
# 写入文件
|
||||
with open(self.status_file, "w", encoding="utf-8") as f:
|
||||
# 使用原子写入避免读取不完整JSON
|
||||
# 先写入临时文件,然后原子重命名
|
||||
temp_file = Path(f"{self.status_file}.tmp")
|
||||
with open(temp_file, "w", encoding="utf-8") as f:
|
||||
json.dump(self.training_status_data, f, indent=2, ensure_ascii=False)
|
||||
# 原子重命名(Unix系统是原子操作)
|
||||
temp_file.rename(self.status_file)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write training status: {e}")
|
||||
|
|
|
|||
|
|
@ -6,10 +6,15 @@ from datetime import datetime
|
|||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import plotly.express as px
|
||||
|
||||
# import plotly.express as px # 暂时未使用
|
||||
import plotly.graph_objects as go
|
||||
import streamlit as st
|
||||
from plotly.subplots import make_subplots
|
||||
|
||||
# from plotly.subplots import make_subplots # 暂时未使用
|
||||
|
||||
# 用于HTTP URL支持
|
||||
# requests模块将在load_from_url函数中动态导入
|
||||
|
||||
# 添加项目路径到系统路径,以便导入模型
|
||||
sys.path.append(str(Path(__file__).parent.parent))
|
||||
|
|
@ -95,7 +100,46 @@ st.markdown(
|
|||
|
||||
|
||||
def load_training_data(file_path):
    """Load training-status data from a local file or an HTTP(S) URL."""
    # Dispatch on the scheme: remote URLs go through the requests-based
    # loader, everything else is treated as a local filesystem path.
    is_remote = file_path.startswith(("http://", "https://"))
    loader = load_from_url if is_remote else load_from_local_file
    return loader(file_path)
|
||||
|
||||
|
||||
def load_from_url(url):
    """Fetch training-status data from an HTTP(S) URL and return a DataFrame."""
    # requests is an optional dependency; import it lazily so the dashboard
    # keeps working for local files even when it is not installed.
    try:
        import requests
    except ImportError:
        st.error(
            "requests库未安装,无法从HTTP URL加载数据。请安装:pip install requests"
        )
        return pd.DataFrame()

    try:
        # Bounded timeout so a dead server cannot hang the dashboard.
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        return convert_to_dataframe(response.json())
    except requests.exceptions.RequestException as err:
        st.warning(f"从URL加载数据失败: {err}")
        return pd.DataFrame()
    except json.JSONDecodeError:
        st.warning("远程返回的数据不是有效的JSON格式")
        return pd.DataFrame()
    except Exception as err:
        st.warning(f"从URL加载数据时发生错误: {err}")
        return pd.DataFrame()
|
||||
|
||||
|
||||
def load_from_local_file(file_path):
|
||||
"""从本地文件加载数据"""
|
||||
try:
|
||||
if not os.path.exists(file_path):
|
||||
st.warning(f"文件不存在: {file_path}")
|
||||
|
|
@ -104,6 +148,16 @@ def load_training_data(file_path):
|
|||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
return convert_to_dataframe(data)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
return pd.DataFrame()
|
||||
except Exception:
|
||||
return pd.DataFrame()
|
||||
|
||||
|
||||
def convert_to_dataframe(data):
|
||||
"""将数据转换为DataFrame,包含数据验证和清理"""
|
||||
# 检测是否是配置文件(检查是否有典型的配置键)
|
||||
config_keys = [
|
||||
"train_data_path",
|
||||
|
|
@ -113,6 +167,7 @@ def load_training_data(file_path):
|
|||
"batch_size",
|
||||
"num_epochs",
|
||||
]
|
||||
|
||||
if isinstance(data, dict):
|
||||
# 检查是否是配置文件
|
||||
if any(key in data for key in config_keys):
|
||||
|
|
@ -169,11 +224,6 @@ def load_training_data(file_path):
|
|||
|
||||
return df
|
||||
|
||||
except json.JSONDecodeError:
|
||||
return pd.DataFrame()
|
||||
except Exception:
|
||||
return pd.DataFrame()
|
||||
|
||||
|
||||
def create_metric_card(label, value, delta=None, help_text=None):
|
||||
"""创建指标卡片"""
|
||||
|
|
@ -381,9 +431,9 @@ def main():
|
|||
"TRAINING_STATUS_FILE", "./output/training_status.json"
|
||||
)
|
||||
status_file = st.text_input(
|
||||
"状态文件路径",
|
||||
"状态文件路径或URL",
|
||||
value=default_status_file,
|
||||
help="训练过程中生成的JSON状态文件路径",
|
||||
help="可以是:\n1. 本地文件路径(如 ./output/training_status.json)\n2. HTTP/HTTPS URL(如 http://服务器IP:端口/training_status.json)",
|
||||
)
|
||||
|
||||
# 刷新间隔
|
||||
|
|
|
|||
103
test.py
103
test.py
|
|
@ -5,20 +5,11 @@ from torch.utils.data import DataLoader
|
|||
from tqdm import tqdm
|
||||
|
||||
from model.dataset import PinyinInputDataset
|
||||
from model.model import InputMethodEngine
|
||||
from model.trainer import collate_fn, worker_init_fn
|
||||
|
||||
# Try to import DataLoader2 from torchdata, fallback to standard DataLoader
|
||||
try:
|
||||
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
|
||||
|
||||
DATA_LOADER2_AVAILABLE = True
|
||||
print("✅ Using DataLoader2 from torchdata")
|
||||
except ImportError:
|
||||
DATA_LOADER2_AVAILABLE = False
|
||||
print("⚠️ torchdata not installed, falling back to standard DataLoader")
|
||||
|
||||
max_iter_length = 128 * 128
|
||||
batch_size = 1024
|
||||
max_iter_length = 5
|
||||
batch_size = 1
|
||||
|
||||
if sys.platform == "win32":
|
||||
dataset_path = "data"
|
||||
|
|
@ -29,40 +20,10 @@ dataset = PinyinInputDataset(dataset_path, max_iter_length=max_iter_length)
|
|||
|
||||
|
||||
def create_dataloader():
|
||||
"""
|
||||
Create dataloader with DataLoader2 if available, otherwise fallback to DataLoader.
|
||||
This function tries to handle streaming datasets better with DataLoader2.
|
||||
"""
|
||||
if DATA_LOADER2_AVAILABLE:
|
||||
try:
|
||||
# DataLoader2 configuration for streaming datasets
|
||||
# Use MultiProcessingReadingService with careful worker settings
|
||||
reading_service = MultiProcessingReadingService(
|
||||
num_workers=2, # Start with 2 workers for streaming dataset
|
||||
prefetch_factor=2, # Reduced prefetch for better memory management
|
||||
persistent_workers=True,
|
||||
pin_memory=torch.cuda.is_available(),
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
dataloader = DataLoader2(
|
||||
dataset,
|
||||
reading_service=reading_service,
|
||||
batch_size=batch_size,
|
||||
collate_fn=collate_fn,
|
||||
shuffle=False, # Dataset handles shuffling internally
|
||||
)
|
||||
print(f"✅ Created DataLoader2 with {2} workers")
|
||||
return dataloader
|
||||
except Exception as e:
|
||||
print(f"⚠️ DataLoader2 creation failed: {e}, falling back to DataLoader")
|
||||
|
||||
# Fallback to standard DataLoader
|
||||
print("📊 Using standard DataLoader")
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=2, # Limited to 2 for streaming dataset compatibility
|
||||
num_workers=1, # Limited to 2 for streaming dataset compatibility
|
||||
pin_memory=torch.cuda.is_available(),
|
||||
worker_init_fn=worker_init_fn,
|
||||
collate_fn=collate_fn,
|
||||
|
|
@ -72,32 +33,52 @@ def create_dataloader():
|
|||
return dataloader
|
||||
|
||||
|
||||
samples = []
|
||||
|
||||
# Create the dataloader
|
||||
dataloader = create_dataloader()
|
||||
|
||||
# Test the dataloader
|
||||
print(f"🔍 Testing dataloader with batch_size={batch_size}")
|
||||
print(f" Dataset max_iter_length: {max_iter_length}")
|
||||
print(f" Expected batches: {max_iter_length / batch_size:.0f}")
|
||||
|
||||
try:
|
||||
# Convert to list to test loading (as in original code)
|
||||
dataloader_list = list([i for i in dataloader])
|
||||
print(f"✅ Successfully loaded {len(dataloader_list)} batches")
|
||||
|
||||
# Process batches
|
||||
for i, line in tqdm(enumerate(dataloader_list), total=len(dataloader_list)):
|
||||
zero_labels = (line["labels"] == 0).sum()
|
||||
print(f"Batch {i}: labels==0 count = {zero_labels.item()}")
|
||||
# Early exit for testing
|
||||
if i >= 5: # Limit to 5 batches for quick testing
|
||||
print("⚠️ Limited to 5 batches for testing")
|
||||
break
|
||||
samples.append(line)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error during dataloader iteration: {e}")
|
||||
import traceback
|
||||
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
|
||||
|
||||
traceback.print_exc()
|
||||
checkpoint = torch.load("/home/songsenand/下载/best_model.pt", map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
sample = samples[0]
|
||||
input_ids = sample["input_ids"]
|
||||
token_type_ids = sample["token_type_ids"]
|
||||
attention_mask = sample["attention_mask"]
|
||||
pinyin_ids = sample["pinyin_ids"]
|
||||
history_slot_ids = sample["history_slot_ids"]
|
||||
for k, v in sample.items():
|
||||
if isinstance(v, str):
|
||||
print(f"{k}: {v}")
|
||||
res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
|
||||
sort_res = sorted([(i + 1, v) for i, v in enumerate(res[0])], key=lambda x: x[1])
|
||||
print(sort_res[0:5])
|
||||
|
||||
print("🏁 Test completed")
|
||||
# 在test.py的res计算后添加:
|
||||
import torch.nn.functional as F
|
||||
|
||||
# 计算softmax概率
|
||||
probs = F.softmax(res, dim=-1)
|
||||
|
||||
print(f"\n📊 概率分布分析:")
|
||||
print(f" 形状: {probs.shape}")
|
||||
print(f" 总概率和: {probs.sum().item():.6f}")
|
||||
print(f" 最大概率: {probs.max().item():.6f}")
|
||||
print(f" 最小概率: {probs.min().item():.6f}")
|
||||
print(f" 平均概率: {probs.mean().item():.6f}")
|
||||
|
||||
# 获取top-20概率
|
||||
top_probs, top_indices = torch.topk(probs, k=20)
|
||||
print(f"\n🏆 Top-20预测:")
|
||||
for i in range(20):
|
||||
idx = top_indices[0, i].item()
|
||||
prob = top_probs[0, i].item()
|
||||
print(f" {i + 1:2d}. ID {idx:5d}: {prob:.6f}")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,289 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
测试GPU训练的模型在CPU上推理
|
||||
解决设备转换和权重加载问题
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def test_device_conversion(checkpoint_path):
|
||||
"""测试设备转换"""
|
||||
print("=" * 60)
|
||||
print("GPU->CPU设备转换测试")
|
||||
print("=" * 60)
|
||||
|
||||
# 方法1:直接加载到CPU
|
||||
print("\n🔍 方法1: 直接加载到CPU")
|
||||
try:
|
||||
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
||||
print("✅ 直接加载到CPU成功")
|
||||
|
||||
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
|
||||
state_dict = checkpoint["model_state_dict"]
|
||||
# 检查是否有CUDA tensor
|
||||
cuda_tensors = 0
|
||||
for key, tensor in state_dict.items():
|
||||
if tensor.is_cuda:
|
||||
cuda_tensors += 1
|
||||
|
||||
print(f" 模型状态字典包含CUDA tensor: {cuda_tensors}/{len(state_dict)}")
|
||||
|
||||
# 将权重移动到CPU
|
||||
cpu_state_dict = {k: v.cpu() for k, v in state_dict.items()}
|
||||
print(" 已将所有权重移动到CPU")
|
||||
|
||||
return cpu_state_dict, checkpoint
|
||||
except Exception as e:
|
||||
print(f"❌ 方法1失败: {e}")
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
def test_model_creation(state_dict, config=None):
    """Instantiate InputMethodEngine and load *state_dict* into it.

    Args:
        state_dict: CPU state dict to load into the freshly built model.
        config: Optional full checkpoint dict; when it carries a "config"
            entry those hyperparameters are used, otherwise model defaults.

    Returns:
        The model with weights loaded, or None on any failure.
    """
    print("\n" + "=" * 60)
    print("模型创建和权重加载测试")
    print("=" * 60)

    try:
        # Project-local import; failure is caught and reported below.
        from src.model.model import InputMethodEngine

        # Prefer the hyperparameters stored in the checkpoint so the model
        # architecture matches the saved weights exactly.
        if config and "config" in config:
            model_config = config["config"]
            print("📋 使用checkpoint中的配置:")
            for key, value in model_config.items():
                print(f" {key}: {value}")

            model = InputMethodEngine(
                vocab_size=model_config.get("vocab_size", 10019),
                pinyin_vocab_size=model_config.get("pinyin_vocab_size", 30),
                dim=model_config.get("dim", 512),
                num_slots=model_config.get("num_slots", 8),
                n_layers=model_config.get("n_layers", 4),
                n_heads=model_config.get("n_heads", 4),
                num_experts=model_config.get("num_experts", 20),
                max_seq_len=model_config.get("max_seq_len", 128),
                compile=False,
            )
        else:
            print("📋 使用默认配置")
            model = InputMethodEngine(compile=False)

        print("✅ 模型创建成功")
        print(f" 总参数量: {sum(p.numel() for p in model.parameters()):,}")
        print(f" 词汇表大小: {model.vocab_size}")

        # Load the weights: strict first, then strict=False so a partially
        # matching checkpoint still loads everything it can.
        print("\n🔄 加载权重...")
        try:
            model.load_state_dict(state_dict)
            print("✅ 权重加载成功 (strict=True)")
        except RuntimeError as e:
            print(f"⚠️ 严格模式加载失败: {e}")
            print("🔄 尝试非严格模式加载...")
            model.load_state_dict(state_dict, strict=False)
            print("✅ 权重加载成功 (strict=False)")

        # Inspect the classifier head: near-zero std suggests untrained or
        # incorrectly loaded weights.
        print("\n📊 分类头检查:")
        classifier_weight = model.classifier.weight.data
        classifier_bias = model.classifier.bias.data
        print(f" 权重形状: {classifier_weight.shape}")
        print(
            f" 权重范围: [{classifier_weight.min():.6f}, {classifier_weight.max():.6f}]"
        )
        print(
            f" 权重均值: {classifier_weight.mean():.6f} ± {classifier_weight.std():.6f}"
        )
        print(f" 偏置形状: {classifier_bias.shape}")
        print(f" 偏置范围: [{classifier_bias.min():.6f}, {classifier_bias.max():.6f}]")

        return model
    except Exception as e:
        print(f"❌ 模型创建失败: {e}")
        import traceback

        traceback.print_exc()
        return None
|
||||
|
||||
|
||||
def test_forward_pass(model):
|
||||
"""测试前向传播"""
|
||||
print("\n" + "=" * 60)
|
||||
print("前向传播测试")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
model.eval()
|
||||
|
||||
# 创建简单的测试输入
|
||||
batch_size = 2
|
||||
seq_len = 64
|
||||
|
||||
# 使用随机但合理的输入
|
||||
input_ids = torch.randint(0, 1000, (batch_size, seq_len), dtype=torch.long)
|
||||
token_type_ids = torch.zeros((batch_size, seq_len), dtype=torch.long)
|
||||
attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)
|
||||
pinyin_ids = torch.randint(1, 30, (batch_size, 24), dtype=torch.long) # 避免0
|
||||
history_slot_ids = torch.randint(0, 100, (batch_size, 8), dtype=torch.long)
|
||||
|
||||
print("📋 测试输入:")
|
||||
print(f" input_ids: {input_ids.shape}")
|
||||
print(f" token_type_ids: {token_type_ids.shape}")
|
||||
print(f" attention_mask: {attention_mask.shape}")
|
||||
print(f" pinyin_ids: {pinyin_ids.shape}")
|
||||
print(f" history_slot_ids: {history_slot_ids.shape}")
|
||||
|
||||
# 执行前向传播
|
||||
with torch.no_grad():
|
||||
logits = model(
|
||||
input_ids=input_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
attention_mask=attention_mask,
|
||||
pinyin_ids=pinyin_ids,
|
||||
history_slot_ids=history_slot_ids,
|
||||
)
|
||||
|
||||
print("\n✅ 前向传播成功")
|
||||
print(f" Logits形状: {logits.shape}")
|
||||
print(f" Logits范围: [{logits.min():.6f}, {logits.max():.6f}]")
|
||||
print(f" Logits均值: {logits.mean():.6f} ± {logits.std():.6f}")
|
||||
|
||||
# 检查输出是否合理
|
||||
if logits.abs().max() < 1e-6:
|
||||
print("⚠️ 警告: Logits值非常小,可能权重未正确加载")
|
||||
|
||||
# 计算softmax概率
|
||||
probs = torch.nn.functional.softmax(logits, dim=-1)
|
||||
print("\n📊 概率分布:")
|
||||
print(f" 概率总和: {probs.sum(dim=-1).mean().item():.6f} (应为1.0)")
|
||||
print(f" 最大概率: {probs.max().item():.6f}")
|
||||
print(f" 最小概率: {probs.min().item():.6f}")
|
||||
print(
|
||||
f" 平均概率: {probs.mean().item():.6f} (预期: ~{1.0 / logits.size(-1):.6f})"
|
||||
)
|
||||
|
||||
# 检查top-5预测
|
||||
top_probs, top_indices = torch.topk(probs, k=5, dim=-1)
|
||||
print("\n🏆 Batch 0的Top-5预测:")
|
||||
for i in range(5):
|
||||
print(
|
||||
f" {i + 1}. ID {top_indices[0, i].item()}: {top_probs[0, i].item():.6f}"
|
||||
)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"❌ 前向传播失败: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
def test_id_mapping():
    """Verify the char->ID statistics that back the query engine."""
    print("\n" + "=" * 60)
    print("ID映射测试")
    print("=" * 60)

    try:
        from src.model.query import QueryEngine

        query_engine = QueryEngine()
        stats_path = (
            Path(__file__).parent
            / "src"
            / "model"
            / "assets"
            / "pinyin_char_statistics.json"
        )

        # Guard clause: nothing to test without the statistics file.
        if not stats_path.exists():
            print(f"❌ 统计文件不存在: {stats_path}")
            return

        query_engine.load(stats_path)
        print("✅ 查询引擎加载成功")

        # Spot-check the IDs of a handful of very common characters.
        print("📋 常见字ID映射:")
        for char in ["的", "是", "一", "天", "上", "人", "好"]:
            results = query_engine.query_by_char(char, limit=1)
            if results:
                char_id, pinyin, count = results[0]
                print(
                    f" '{char}' -> ID: {char_id}, 拼音: {pinyin}, 频次: {count:,}"
                )
            else:
                print(f" '{char}' -> 未找到")

        all_ids = list(query_engine._id_to_info.keys())
        if all_ids:
            print("\n📊 ID范围统计:")
            print(f" 最小ID: {min(all_ids)}")
            print(f" 最大ID: {max(all_ids)}")
            print(f" ID总数: {len(all_ids)}")

            # Look for holes in the ID space.
            sorted_ids = sorted(all_ids)
            gaps = [
                (sorted_ids[i - 1], sorted_ids[i])
                for i in range(1, len(sorted_ids))
                if sorted_ids[i] - sorted_ids[i - 1] > 1
            ]

            if gaps:
                print(f" ⚠️ 发现ID间隙: {len(gaps)}处")
                for i, (prev, curr) in enumerate(gaps[:3]):
                    print(f" 间隙{i + 1}: {prev} -> {curr} (差: {curr - prev})")
            else:
                print(" ✅ ID基本连续")
    except Exception as e:
        print(f"❌ ID映射测试失败: {e}")
|
||||
|
||||
|
||||
def main():
    """CLI entry point: run the CPU-inference test suite on a checkpoint."""
    if len(sys.argv) < 2:
        print("使用方法: python test_cpu_inference.py <checkpoint_path>")
        print("示例: python test_cpu_inference.py ~/下载/best_model.pt")
        return

    checkpoint_path = sys.argv[1]

    # Step 1: load the checkpoint and move all weights to the CPU.
    state_dict, full_checkpoint = test_device_conversion(checkpoint_path)
    if state_dict is None:
        print("\n❌ 设备转换失败,无法继续测试")
        return

    # Step 2: build the model and load the weights into it.
    model = test_model_creation(state_dict, full_checkpoint)
    if model is None:
        print("\n❌ 模型创建失败,无法继续测试")
        return

    # Step 3: run a synthetic forward pass.
    if test_forward_pass(model):
        print("\n✅ CPU推理测试通过!")
    else:
        print("\n❌ CPU推理测试失败")

    # Step 4: check the char->ID mapping regardless of the forward result.
    test_id_mapping()

    print("\n" + "=" * 60)
    print("测试完成")
    print("=" * 60)
|
||||
|
||||
|
||||
# Script entry point.
if __name__ == "__main__":
    main()
|
||||
Loading…
Reference in New Issue