SUimeModelTraner/test_cpu_inference.py

#!/usr/bin/env python3
"""
测试GPU训练的模型在CPU上推理
解决设备转换和权重加载问题
"""
import sys
from pathlib import Path
import torch


def test_device_conversion(checkpoint_path):
    """Test GPU -> CPU device conversion."""
    print("=" * 60)
    print("GPU -> CPU device conversion test")
    print("=" * 60)
    # Method 1: load directly onto the CPU.
    print("\n🔍 Method 1: load directly onto the CPU")
    try:
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        print("✅ Loaded directly onto the CPU")
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
        else:
            # Some checkpoints are a bare state dict rather than a wrapper dict.
            state_dict = checkpoint
        # Count tensors still on CUDA (should be 0 after map_location="cpu").
        cuda_tensors = sum(1 for t in state_dict.values() if t.is_cuda)
        print(f"  CUDA tensors in state dict: {cuda_tensors}/{len(state_dict)}")
        # Move every weight to the CPU explicitly, just in case.
        cpu_state_dict = {k: v.cpu() for k, v in state_dict.items()}
        print("  All weights moved to CPU")
        return cpu_state_dict, checkpoint
    except Exception as e:
        print(f"❌ Method 1 failed: {e}")
        return None, None
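

# The conversion this script exercises, distilled into a standalone helper.
# A minimal sketch for reuse elsewhere (not called by the tests below); it
# assumes the checkpoint is either a bare state dict or wraps one under
# "model_state_dict". map_location="cpu" already remaps CUDA storages during
# deserialization, so the explicit .cpu() copies are defensive, not required.
def load_checkpoint_cpu(path):
    ckpt = torch.load(path, map_location="cpu")
    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        ckpt = ckpt["model_state_dict"]
    return {k: v.cpu() for k, v in ckpt.items()}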


def test_model_creation(state_dict, config=None):
    """Test model creation and weight loading."""
    print("\n" + "=" * 60)
    print("Model creation and weight loading test")
    print("=" * 60)
    try:
        from src.model.model import InputMethodEngine

        # Take the config from the checkpoint, or fall back to defaults.
        if config and "config" in config:
            model_config = config["config"]
            print("📋 Using the config stored in the checkpoint:")
            for key, value in model_config.items():
                print(f"  {key}: {value}")
            model = InputMethodEngine(
                vocab_size=model_config.get("vocab_size", 10019),
                pinyin_vocab_size=model_config.get("pinyin_vocab_size", 30),
                dim=model_config.get("dim", 512),
                num_slots=model_config.get("num_slots", 8),
                n_layers=model_config.get("n_layers", 4),
                n_heads=model_config.get("n_heads", 4),
                num_experts=model_config.get("num_experts", 20),
                max_seq_len=model_config.get("max_seq_len", 128),
                compile=False,
            )
        else:
            print("📋 Using the default config")
            model = InputMethodEngine(compile=False)
        print("✅ Model created")
        print(f"  Total parameters: {sum(p.numel() for p in model.parameters()):,}")
        print(f"  Vocab size: {model.vocab_size}")
        # Try to load the weights.
        print("\n🔄 Loading weights...")
        try:
            model.load_state_dict(state_dict)
            print("✅ Weights loaded (strict=True)")
        except RuntimeError as e:
            print(f"⚠️ Strict loading failed: {e}")
            print("🔄 Retrying with strict=False...")
            model.load_state_dict(state_dict, strict=False)
            print("✅ Weights loaded (strict=False)")
        # Inspect the classification head as a sanity check.
        print("\n📊 Classifier head check:")
        classifier_weight = model.classifier.weight.data
        classifier_bias = model.classifier.bias.data
        print(f"  Weight shape: {classifier_weight.shape}")
        print(
            f"  Weight range: [{classifier_weight.min():.6f}, {classifier_weight.max():.6f}]"
        )
        print(
            f"  Weight mean: {classifier_weight.mean():.6f} ± {classifier_weight.std():.6f}"
        )
        print(f"  Bias shape: {classifier_bias.shape}")
        print(f"  Bias range: [{classifier_bias.min():.6f}, {classifier_bias.max():.6f}]")
        return model
    except Exception as e:
        print(f"❌ Model creation failed: {e}")
        import traceback

        traceback.print_exc()
        return None
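

# Note: load_state_dict(strict=False) does not raise on missing or unexpected
# keys, so a "successful" non-strict load can still leave some layers at their
# random initialization. The classifier statistics printed above are the
# sanity check for exactly that case.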


def test_forward_pass(model):
    """Test a single forward pass."""
    print("\n" + "=" * 60)
    print("Forward pass test")
    print("=" * 60)
    try:
        model.eval()
        # Build a simple test batch.
        batch_size = 2
        seq_len = 64
        # Random but plausible inputs.
        input_ids = torch.randint(0, 1000, (batch_size, seq_len), dtype=torch.long)
        token_type_ids = torch.zeros((batch_size, seq_len), dtype=torch.long)
        attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)
        pinyin_ids = torch.randint(1, 30, (batch_size, 24), dtype=torch.long)  # avoid 0
        history_slot_ids = torch.randint(0, 100, (batch_size, 8), dtype=torch.long)
        print("📋 Test inputs:")
        print(f"  input_ids: {input_ids.shape}")
        print(f"  token_type_ids: {token_type_ids.shape}")
        print(f"  attention_mask: {attention_mask.shape}")
        print(f"  pinyin_ids: {pinyin_ids.shape}")
        print(f"  history_slot_ids: {history_slot_ids.shape}")
        # Run the forward pass.
        with torch.no_grad():
            logits = model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                pinyin_ids=pinyin_ids,
                history_slot_ids=history_slot_ids,
            )
        print("\n✅ Forward pass succeeded")
        print(f"  Logits shape: {logits.shape}")
        print(f"  Logits range: [{logits.min():.6f}, {logits.max():.6f}]")
        print(f"  Logits mean: {logits.mean():.6f} ± {logits.std():.6f}")
        # Sanity-check the output.
        if logits.abs().max() < 1e-6:
            print("⚠️ Warning: logits are near zero; weights may not have loaded correctly")
        # Softmax probabilities.
        probs = torch.nn.functional.softmax(logits, dim=-1)
        print("\n📊 Probability distribution:")
        print(f"  Probability sum: {probs.sum(dim=-1).mean().item():.6f} (should be 1.0)")
        print(f"  Max probability: {probs.max().item():.6f}")
        print(f"  Min probability: {probs.min().item():.6f}")
        print(
            f"  Mean probability: {probs.mean().item():.6f} (expected: ~{1.0 / logits.size(-1):.6f})"
        )
        # Top-5 predictions.
        top_probs, top_indices = torch.topk(probs, k=5, dim=-1)
        print("\n🏆 Top-5 predictions for batch 0:")
        for i in range(5):
            print(
                f"  {i + 1}. ID {top_indices[0, i].item()}: {top_probs[0, i].item():.6f}"
            )
        return True
    except Exception as e:
        print(f"❌ Forward pass failed: {e}")
        import traceback

        traceback.print_exc()
        return False


def test_id_mapping():
    """Test the character-ID mapping."""
    print("\n" + "=" * 60)
    print("ID mapping test")
    print("=" * 60)
    try:
        from src.model.query import QueryEngine

        query_engine = QueryEngine()
        stats_path = (
            Path(__file__).parent
            / "src"
            / "model"
            / "assets"
            / "pinyin_char_statistics.json"
        )
        if stats_path.exists():
            query_engine.load(stats_path)
            print("✅ Query engine loaded")
            # Look up the IDs of a few common characters.
            # (The sample below is a set of stand-in high-frequency characters.)
            test_chars = ["的", "一", "是", "在", "不", "了", "我"]
print("📋 常见字ID映射:")
for char in test_chars:
results = query_engine.query_by_char(char, limit=1)
if results:
char_id, pinyin, count = results[0]
print(
f" '{char}' -> ID: {char_id}, 拼音: {pinyin}, 频次: {count:,}"
)
else:
print(f" '{char}' -> 未找到")
# 检查ID范围
all_ids = list(query_engine._id_to_info.keys())
if all_ids:
print("\n📊 ID范围统计:")
print(f" 最小ID: {min(all_ids)}")
print(f" 最大ID: {max(all_ids)}")
print(f" ID总数: {len(all_ids)}")
# 检查ID是否连续
sorted_ids = sorted(all_ids)
gaps = []
for i in range(1, len(sorted_ids)):
if sorted_ids[i] - sorted_ids[i - 1] > 1:
gaps.append((sorted_ids[i - 1], sorted_ids[i]))
if gaps:
print(f" ⚠️ 发现ID间隙: {len(gaps)}")
for i, (prev, curr) in enumerate(gaps[:3]):
print(f" 间隙{i + 1}: {prev} -> {curr} (差: {curr - prev})")
else:
print(" ✅ ID基本连续")
else:
print(f"❌ 统计文件不存在: {stats_path}")
except Exception as e:
print(f"❌ ID映射测试失败: {e}")


def main():
    if len(sys.argv) < 2:
        print("Usage: python test_cpu_inference.py <checkpoint_path>")
        print("Example: python test_cpu_inference.py ~/Downloads/best_model.pt")
        return
    checkpoint_path = sys.argv[1]
    # Device conversion test.
    state_dict, full_checkpoint = test_device_conversion(checkpoint_path)
    if state_dict is None:
        print("\n❌ Device conversion failed; cannot continue")
        return
    # Model creation test.
    model = test_model_creation(state_dict, full_checkpoint)
    if model is None:
        print("\n❌ Model creation failed; cannot continue")
        return
    # Forward pass test.
    success = test_forward_pass(model)
    if success:
        print("\n✅ CPU inference test passed!")
    else:
        print("\n❌ CPU inference test failed")
    # ID mapping test.
    test_id_mapping()
    print("\n" + "=" * 60)
    print("Tests complete")
    print("=" * 60)


if __name__ == "__main__":
    main()