#!/usr/bin/env python3
"""Test CPU inference with a model that was trained on a GPU.

Resolves device-conversion (CUDA -> CPU) and weight-loading issues.
"""

import sys
from pathlib import Path

import torch
||
def test_device_conversion(checkpoint_path):
|
||
"""测试设备转换"""
|
||
print("=" * 60)
|
||
print("GPU->CPU设备转换测试")
|
||
print("=" * 60)
|
||
|
||
# 方法1:直接加载到CPU
|
||
print("\n🔍 方法1: 直接加载到CPU")
|
||
try:
|
||
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
||
print("✅ 直接加载到CPU成功")
|
||
|
||
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
|
||
state_dict = checkpoint["model_state_dict"]
|
||
# 检查是否有CUDA tensor
|
||
cuda_tensors = 0
|
||
for key, tensor in state_dict.items():
|
||
if tensor.is_cuda:
|
||
cuda_tensors += 1
|
||
|
||
print(f" 模型状态字典包含CUDA tensor: {cuda_tensors}/{len(state_dict)}")
|
||
|
||
# 将权重移动到CPU
|
||
cpu_state_dict = {k: v.cpu() for k, v in state_dict.items()}
|
||
print(" 已将所有权重移动到CPU")
|
||
|
||
return cpu_state_dict, checkpoint
|
||
except Exception as e:
|
||
print(f"❌ 方法1失败: {e}")
|
||
|
||
return None, None
|
||
|
||
|
||
def test_model_creation(state_dict, config=None):
|
||
"""测试模型创建和权重加载"""
|
||
print("\n" + "=" * 60)
|
||
print("模型创建和权重加载测试")
|
||
print("=" * 60)
|
||
|
||
try:
|
||
from src.model.model import InputMethodEngine
|
||
|
||
# 从checkpoint获取配置或使用默认值
|
||
if config and "config" in config:
|
||
model_config = config["config"]
|
||
print("📋 使用checkpoint中的配置:")
|
||
for key, value in model_config.items():
|
||
print(f" {key}: {value}")
|
||
|
||
model = InputMethodEngine(
|
||
vocab_size=model_config.get("vocab_size", 10019),
|
||
pinyin_vocab_size=model_config.get("pinyin_vocab_size", 30),
|
||
dim=model_config.get("dim", 512),
|
||
num_slots=model_config.get("num_slots", 8),
|
||
n_layers=model_config.get("n_layers", 4),
|
||
n_heads=model_config.get("n_heads", 4),
|
||
num_experts=model_config.get("num_experts", 20),
|
||
max_seq_len=model_config.get("max_seq_len", 128),
|
||
compile=False,
|
||
)
|
||
else:
|
||
print("📋 使用默认配置")
|
||
model = InputMethodEngine(compile=False)
|
||
|
||
print("✅ 模型创建成功")
|
||
print(f" 总参数量: {sum(p.numel() for p in model.parameters()):,}")
|
||
print(f" 词汇表大小: {model.vocab_size}")
|
||
|
||
# 尝试加载权重
|
||
print("\n🔄 加载权重...")
|
||
try:
|
||
model.load_state_dict(state_dict)
|
||
print("✅ 权重加载成功 (strict=True)")
|
||
except RuntimeError as e:
|
||
print(f"⚠️ 严格模式加载失败: {e}")
|
||
print("🔄 尝试非严格模式加载...")
|
||
model.load_state_dict(state_dict, strict=False)
|
||
print("✅ 权重加载成功 (strict=False)")
|
||
|
||
# 检查分类头
|
||
print("\n📊 分类头检查:")
|
||
classifier_weight = model.classifier.weight.data
|
||
classifier_bias = model.classifier.bias.data
|
||
print(f" 权重形状: {classifier_weight.shape}")
|
||
print(
|
||
f" 权重范围: [{classifier_weight.min():.6f}, {classifier_weight.max():.6f}]"
|
||
)
|
||
print(
|
||
f" 权重均值: {classifier_weight.mean():.6f} ± {classifier_weight.std():.6f}"
|
||
)
|
||
print(f" 偏置形状: {classifier_bias.shape}")
|
||
print(f" 偏置范围: [{classifier_bias.min():.6f}, {classifier_bias.max():.6f}]")
|
||
|
||
return model
|
||
except Exception as e:
|
||
print(f"❌ 模型创建失败: {e}")
|
||
import traceback
|
||
|
||
traceback.print_exc()
|
||
return None
|
||
|
||
|
||
def test_forward_pass(model):
|
||
"""测试前向传播"""
|
||
print("\n" + "=" * 60)
|
||
print("前向传播测试")
|
||
print("=" * 60)
|
||
|
||
try:
|
||
model.eval()
|
||
|
||
# 创建简单的测试输入
|
||
batch_size = 2
|
||
seq_len = 64
|
||
|
||
# 使用随机但合理的输入
|
||
input_ids = torch.randint(0, 1000, (batch_size, seq_len), dtype=torch.long)
|
||
token_type_ids = torch.zeros((batch_size, seq_len), dtype=torch.long)
|
||
attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)
|
||
pinyin_ids = torch.randint(1, 30, (batch_size, 24), dtype=torch.long) # 避免0
|
||
history_slot_ids = torch.randint(0, 100, (batch_size, 8), dtype=torch.long)
|
||
|
||
print("📋 测试输入:")
|
||
print(f" input_ids: {input_ids.shape}")
|
||
print(f" token_type_ids: {token_type_ids.shape}")
|
||
print(f" attention_mask: {attention_mask.shape}")
|
||
print(f" pinyin_ids: {pinyin_ids.shape}")
|
||
print(f" history_slot_ids: {history_slot_ids.shape}")
|
||
|
||
# 执行前向传播
|
||
with torch.no_grad():
|
||
logits = model(
|
||
input_ids=input_ids,
|
||
token_type_ids=token_type_ids,
|
||
attention_mask=attention_mask,
|
||
pinyin_ids=pinyin_ids,
|
||
history_slot_ids=history_slot_ids,
|
||
)
|
||
|
||
print("\n✅ 前向传播成功")
|
||
print(f" Logits形状: {logits.shape}")
|
||
print(f" Logits范围: [{logits.min():.6f}, {logits.max():.6f}]")
|
||
print(f" Logits均值: {logits.mean():.6f} ± {logits.std():.6f}")
|
||
|
||
# 检查输出是否合理
|
||
if logits.abs().max() < 1e-6:
|
||
print("⚠️ 警告: Logits值非常小,可能权重未正确加载")
|
||
|
||
# 计算softmax概率
|
||
probs = torch.nn.functional.softmax(logits, dim=-1)
|
||
print("\n📊 概率分布:")
|
||
print(f" 概率总和: {probs.sum(dim=-1).mean().item():.6f} (应为1.0)")
|
||
print(f" 最大概率: {probs.max().item():.6f}")
|
||
print(f" 最小概率: {probs.min().item():.6f}")
|
||
print(
|
||
f" 平均概率: {probs.mean().item():.6f} (预期: ~{1.0 / logits.size(-1):.6f})"
|
||
)
|
||
|
||
# 检查top-5预测
|
||
top_probs, top_indices = torch.topk(probs, k=5, dim=-1)
|
||
print("\n🏆 Batch 0的Top-5预测:")
|
||
for i in range(5):
|
||
print(
|
||
f" {i + 1}. ID {top_indices[0, i].item()}: {top_probs[0, i].item():.6f}"
|
||
)
|
||
|
||
return True
|
||
except Exception as e:
|
||
print(f"❌ 前向传播失败: {e}")
|
||
import traceback
|
||
|
||
traceback.print_exc()
|
||
return False
|
||
|
||
|
||
def test_id_mapping():
    """Check the character->ID mapping exposed by the QueryEngine.

    Loads the pinyin/char statistics file, prints the IDs of a few common
    characters, and reports the overall ID range plus any gaps in the ID
    sequence. Never raises: failures are printed instead.
    """
    print("\n" + "=" * 60)
    print("ID映射测试")
    print("=" * 60)

    try:
        from src.model.query import QueryEngine

        query_engine = QueryEngine()
        # Statistics file shipped next to this script under src/model/assets.
        stats_path = (
            Path(__file__).parent
            / "src"
            / "model"
            / "assets"
            / "pinyin_char_statistics.json"
        )

        if stats_path.exists():
            query_engine.load(stats_path)
            print("✅ 查询引擎加载成功")

            # Look up the IDs of a few very common characters.
            test_chars = ["的", "是", "一", "天", "上", "人", "好"]
            print("📋 常见字ID映射:")
            for char in test_chars:
                results = query_engine.query_by_char(char, limit=1)
                if results:
                    # NOTE(review): assumes each result row is
                    # (char_id, pinyin, count) — confirm against QueryEngine.
                    char_id, pinyin, count = results[0]
                    print(
                        f" '{char}' -> ID: {char_id}, 拼音: {pinyin}, 频次: {count:,}"
                    )
                else:
                    print(f" '{char}' -> 未找到")

            # Inspect the full ID range.
            # NOTE(review): reaches into QueryEngine's private _id_to_info
            # mapping — consider exposing a public accessor instead.
            all_ids = list(query_engine._id_to_info.keys())
            if all_ids:
                print("\n📊 ID范围统计:")
                print(f" 最小ID: {min(all_ids)}")
                print(f" 最大ID: {max(all_ids)}")
                print(f" ID总数: {len(all_ids)}")

                # Check whether the IDs are contiguous.
                sorted_ids = sorted(all_ids)
                gaps = []
                for i in range(1, len(sorted_ids)):
                    if sorted_ids[i] - sorted_ids[i - 1] > 1:
                        gaps.append((sorted_ids[i - 1], sorted_ids[i]))

                if gaps:
                    print(f" ⚠️ 发现ID间隙: {len(gaps)}处")
                    # Report at most the first three gaps.
                    for i, (prev, curr) in enumerate(gaps[:3]):
                        print(f" 间隙{i + 1}: {prev} -> {curr} (差: {curr - prev})")
                else:
                    print(" ✅ ID基本连续")
        else:
            print(f"❌ 统计文件不存在: {stats_path}")
    except Exception as e:
        print(f"❌ ID映射测试失败: {e}")
|
||
|
||
|
||
def main():
    """CLI entry point: run the CPU-inference test suite on a checkpoint.

    Usage: ``python test_cpu_inference.py <checkpoint_path>``
    """
    if len(sys.argv) < 2:
        print("使用方法: python test_cpu_inference.py <checkpoint_path>")
        print("示例: python test_cpu_inference.py ~/下载/best_model.pt")
        return

    # Fix: expand "~" so the documented example works even when the shell
    # did not expand it, and fail fast on a missing file instead of deep
    # inside torch.load.
    checkpoint_path = Path(sys.argv[1]).expanduser()
    if not checkpoint_path.exists():
        print(f"❌ checkpoint文件不存在: {checkpoint_path}")
        return

    # Step 1: load the checkpoint and force every weight onto the CPU.
    state_dict, full_checkpoint = test_device_conversion(checkpoint_path)
    if state_dict is None:
        print("\n❌ 设备转换失败,无法继续测试")
        return

    # Step 2: rebuild the model and load the converted weights.
    model = test_model_creation(state_dict, full_checkpoint)
    if model is None:
        print("\n❌ 模型创建失败,无法继续测试")
        return

    # Step 3: run a dummy batch through the model.
    if test_forward_pass(model):
        print("\n✅ CPU推理测试通过!")
    else:
        print("\n❌ CPU推理测试失败")

    # Step 4: independent check of the char->ID mapping.
    test_id_mapping()

    print("\n" + "=" * 60)
    print("测试完成")
    print("=" * 60)
|
||
|
||
|
||
# Run the test suite only when executed as a script.
if __name__ == "__main__":
    main()
|