#!/usr/bin/env python3
"""
Test CPU inference for a model trained on GPU.

Covers device conversion and weight-loading issues.
"""

import sys
from pathlib import Path

import torch


def test_device_conversion(checkpoint_path):
    """Test GPU -> CPU device conversion."""
    print("=" * 60)
    print("GPU->CPU device conversion test")
    print("=" * 60)

    # Method 1: load directly onto the CPU via map_location
    print("\n🔍 Method 1: load directly onto CPU")
    try:
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        print("✅ Loaded directly onto CPU")

        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]

            # map_location="cpu" remaps tensors at load time; this is a sanity check
            cuda_tensors = 0
            for key, tensor in state_dict.items():
                if tensor.is_cuda:
                    cuda_tensors += 1
            print(f"  CUDA tensors in state dict: {cuda_tensors}/{len(state_dict)}")

            # Move all weights to the CPU explicitly
            cpu_state_dict = {k: v.cpu() for k, v in state_dict.items()}
            print("  All weights moved to CPU")

            return cpu_state_dict, checkpoint

        # A bare state dict (no wrapper) is also a valid checkpoint format;
        # previously this case fell through and returned None silently
        if isinstance(checkpoint, dict):
            print("  No 'model_state_dict' key; treating checkpoint as a bare state dict")
            cpu_state_dict = {k: v.cpu() for k, v in checkpoint.items()}
            return cpu_state_dict, None

        print("❌ Unrecognized checkpoint format")
        return None, None
    except Exception as e:
        print(f"❌ Method 1 failed: {e}")
        return None, None


def test_model_creation(state_dict, config=None):
    """Test model construction and weight loading."""
    print("\n" + "=" * 60)
    print("Model construction and weight-loading test")
    print("=" * 60)

    try:
        from src.model.model import InputMethodEngine

        # Read the configuration from the checkpoint, falling back to defaults
        if config and "config" in config:
            model_config = config["config"]
            print("📋 Using configuration from checkpoint:")
            for key, value in model_config.items():
                print(f"  {key}: {value}")

            model = InputMethodEngine(
                vocab_size=model_config.get("vocab_size", 10019),
                pinyin_vocab_size=model_config.get("pinyin_vocab_size", 30),
                dim=model_config.get("dim", 512),
                num_slots=model_config.get("num_slots", 8),
                n_layers=model_config.get("n_layers", 4),
                n_heads=model_config.get("n_heads", 4),
                num_experts=model_config.get("num_experts", 20),
                max_seq_len=model_config.get("max_seq_len", 128),
                compile=False,
            )
        else:
            print("📋 Using default configuration")
            model = InputMethodEngine(compile=False)

        print("✅ Model created")
        print(f"  Total parameters: {sum(p.numel() for p in model.parameters()):,}")
        print(f"  Vocabulary size: {model.vocab_size}")

        # Try strict loading first, then fall back to non-strict
        print("\n🔄 Loading weights...")
        try:
            model.load_state_dict(state_dict)
            print("✅ Weights loaded (strict=True)")
        except RuntimeError as e:
            print(f"⚠️ Strict loading failed: {e}")
            print("🔄 Retrying with strict=False...")
            model.load_state_dict(state_dict, strict=False)
            print("✅ Weights loaded (strict=False)")

        # Sanity-check the classifier head
        print("\n📊 Classifier head check:")
        classifier_weight = model.classifier.weight.data
        classifier_bias = model.classifier.bias.data
        print(f"  Weight shape: {classifier_weight.shape}")
        print(
            f"  Weight range: [{classifier_weight.min():.6f}, {classifier_weight.max():.6f}]"
        )
        print(
            f"  Weight mean: {classifier_weight.mean():.6f} ± {classifier_weight.std():.6f}"
        )
        print(f"  Bias shape: {classifier_bias.shape}")
        print(f"  Bias range: [{classifier_bias.min():.6f}, {classifier_bias.max():.6f}]")

        return model

    except Exception as e:
        print(f"❌ Model creation failed: {e}")
        import traceback

        traceback.print_exc()
        return None


def test_forward_pass(model):
    """Test a forward pass on the CPU."""
    print("\n" + "=" * 60)
    print("Forward-pass test")
    print("=" * 60)

    try:
        model.eval()

        # Build a small batch of random but well-formed inputs
        batch_size = 2
        seq_len = 64

        input_ids = torch.randint(0, 1000, (batch_size, seq_len), dtype=torch.long)
        token_type_ids = torch.zeros((batch_size, seq_len), dtype=torch.long)
        attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)
        pinyin_ids = torch.randint(1, 30, (batch_size, 24), dtype=torch.long)  # avoid 0
        history_slot_ids = torch.randint(0, 100, (batch_size, 8), dtype=torch.long)

        print("📋 Test inputs:")
        print(f"  input_ids: {input_ids.shape}")
        print(f"  token_type_ids: {token_type_ids.shape}")
        print(f"  attention_mask: {attention_mask.shape}")
        print(f"  pinyin_ids: {pinyin_ids.shape}")
        print(f"  history_slot_ids: {history_slot_ids.shape}")

        # Run the forward pass without gradients
        with torch.no_grad():
            logits = model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                pinyin_ids=pinyin_ids,
                history_slot_ids=history_slot_ids,
            )

        print("\n✅ Forward pass succeeded")
        print(f"  Logits shape: {logits.shape}")
        print(f"  Logits range: [{logits.min():.6f}, {logits.max():.6f}]")
        print(f"  Logits mean: {logits.mean():.6f} ± {logits.std():.6f}")

        # Near-zero logits usually mean the weights were never actually loaded
        if logits.abs().max() < 1e-6:
            print("⚠️ Warning: logits are near zero; weights may not have loaded correctly")

        # Softmax probabilities
        probs = torch.nn.functional.softmax(logits, dim=-1)
        print("\n📊 Probability distribution:")
        print(f"  Probability sum: {probs.sum(dim=-1).mean().item():.6f} (should be 1.0)")
        print(f"  Max probability: {probs.max().item():.6f}")
        print(f"  Min probability: {probs.min().item():.6f}")
        print(
            f"  Mean probability: {probs.mean().item():.6f} (expected: ~{1.0 / logits.size(-1):.6f})"
        )

        # Top-5 predictions for the first batch element
        top_probs, top_indices = torch.topk(probs, k=5, dim=-1)
        print("\n🏆 Top-5 predictions for batch 0:")
        for i in range(5):
            print(
                f"  {i + 1}. ID {top_indices[0, i].item()}: {top_probs[0, i].item():.6f}"
            )

        return True

    except Exception as e:
        print(f"❌ Forward pass failed: {e}")
        import traceback

        traceback.print_exc()
        return False


def test_id_mapping():
    """Test character-to-ID mapping."""
    print("\n" + "=" * 60)
    print("ID-mapping test")
    print("=" * 60)

    try:
        from src.model.query import QueryEngine

        query_engine = QueryEngine()
        stats_path = (
            Path(__file__).parent
            / "src"
            / "model"
            / "assets"
            / "pinyin_char_statistics.json"
        )

        if stats_path.exists():
            query_engine.load(stats_path)
            print("✅ Query engine loaded")

            # Look up the IDs of some common characters
            test_chars = ["的", "是", "一", "天", "上", "人", "好"]
            print("📋 ID mapping for common characters:")
            for char in test_chars:
                results = query_engine.query_by_char(char, limit=1)
                if results:
                    char_id, pinyin, count = results[0]
                    print(
                        f"  '{char}' -> ID: {char_id}, pinyin: {pinyin}, count: {count:,}"
                    )
                else:
                    print(f"  '{char}' -> not found")

            # Check the ID range
            all_ids = list(query_engine._id_to_info.keys())
            if all_ids:
                print("\n📊 ID range statistics:")
                print(f"  Min ID: {min(all_ids)}")
                print(f"  Max ID: {max(all_ids)}")
                print(f"  Total IDs: {len(all_ids)}")

                # Check whether the IDs are contiguous
                sorted_ids = sorted(all_ids)
                gaps = []
                for i in range(1, len(sorted_ids)):
                    if sorted_ids[i] - sorted_ids[i - 1] > 1:
                        gaps.append((sorted_ids[i - 1], sorted_ids[i]))

                if gaps:
                    print(f"  ⚠️ Found ID gaps: {len(gaps)}")
                    for i, (prev, curr) in enumerate(gaps[:3]):
                        print(f"    Gap {i + 1}: {prev} -> {curr} (difference: {curr - prev})")
                else:
                    print("  ✅ IDs are contiguous")
        else:
            print(f"❌ Statistics file not found: {stats_path}")

    except Exception as e:
        print(f"❌ ID-mapping test failed: {e}")


def main():
    if len(sys.argv) < 2:
        print("Usage: python test_cpu_inference.py <checkpoint_path>")
        print("Example: python test_cpu_inference.py ~/Downloads/best_model.pt")
        return

    checkpoint_path = sys.argv[1]

    # Device conversion
    state_dict, full_checkpoint = test_device_conversion(checkpoint_path)
    if state_dict is None:
        print("\n❌ Device conversion failed; aborting")
        return

    # Model creation
    model = test_model_creation(state_dict, full_checkpoint)
    if model is None:
        print("\n❌ Model creation failed; aborting")
        return

    # Forward pass
    success = test_forward_pass(model)
    if success:
        print("\n✅ CPU inference test passed!")
    else:
        print("\n❌ CPU inference test failed")

    # ID mapping
    test_id_mapping()

    print("\n" + "=" * 60)
    print("Testing complete")
    print("=" * 60)


if __name__ == "__main__":
    main()
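

# ---------------------------------------------------------------------------
# Optional follow-up (sketch, not wired into main): once the conversion above
# is verified, the checkpoint can be re-saved with CPU tensors so that future
# loads skip the conversion step entirely. The function name and default
# output path are illustrative assumptions; the "model_state_dict" layout
# mirrors what this script already expects.
def resave_checkpoint_on_cpu(checkpoint_path, output_path="best_model_cpu.pt"):
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        # map_location already places tensors on CPU; .cpu() is a no-op safeguard
        checkpoint["model_state_dict"] = {
            k: v.cpu() for k, v in checkpoint["model_state_dict"].items()
        }
    torch.save(checkpoint, output_path)
    print(f"Saved CPU checkpoint to {output_path}")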