diff --git a/check_weights.py b/check_weights.py new file mode 100755 index 0000000..f69d629 --- /dev/null +++ b/check_weights.py @@ -0,0 +1,147 @@ +#!/usr/bin/env python3 +""" +快速检查模型权重加载情况的脚本 +""" + +from pathlib import Path + +import numpy as np +import torch + + +def analyze_checkpoint(checkpoint_path): + """分析checkpoint文件""" + print(f"🔍 分析checkpoint: {checkpoint_path}") + + if not Path(checkpoint_path).exists(): + print(f"❌ 文件不存在") + return + + try: + checkpoint = torch.load(checkpoint_path, map_location="cpu") + print(f"✅ 加载成功") + print(f" 类型: {type(checkpoint)}") + + if isinstance(checkpoint, dict): + print(f" 键名: {list(checkpoint.keys())}") + + # 找到模型状态字典 + state_dict = None + if "model_state_dict" in checkpoint: + state_dict = checkpoint["model_state_dict"] + print(f" 🔍 使用'model_state_dict'键") + elif "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + print(f" 🔍 使用'state_dict'键") + else: + # 可能是直接的状态字典 + state_dict = checkpoint + print(f" 🔍 使用直接状态字典") + + if state_dict: + print(f" 总权重数: {len(state_dict)}") + + # 分析分类头权重 + classifier_keys = [] + for key in state_dict.keys(): + if "classifier" in key: + classifier_keys.append(key) + + if classifier_keys: + print(f" 📊 分类头相关权重:") + for key in classifier_keys: + weight = state_dict[key] + print(f" {key}: shape={weight.shape}") + print(f" 范围: [{weight.min():.6f}, {weight.max():.6f}]") + print(f" 均值: {weight.mean():.6f}") + print(f" 标准差: {weight.std():.6f}") + + # 检查权重是否接近随机初始化 + if weight.std() < 0.01: + print(f" ⚠️ 警告: 权重标准差很小,可能未正确训练") + + # 检查模型架构键名 + print(f"\n 🔑 模型架构键名示例(前20个):") + for i, key in enumerate(list(state_dict.keys())[:20]): + weight = state_dict[key] + print(f" {i + 1:2d}. {key:40} shape={str(weight.shape):15}") + + # 检查是否有预期的组件 + expected_components = [ + "context_encoder", + "slot_memory", + "cross_attn", + "moe", + "classifier", + ] + found_components = [] + for comp in expected_components: + found = any(comp in key for key in state_dict.keys()) + if found: + found_components.append(comp) + + print(f"\n 📋 找到的模型组件: {found_components}") + missing = set(expected_components) - set(found_components) + if missing: + print(f" ❌ 缺失的组件: {missing}") + + return state_dict + else: + print(f"❌ checkpoint不是字典类型") + + except Exception as e: + print(f"❌ 加载失败: {e}") + import traceback + + traceback.print_exc() + + +def check_weight_distribution(state_dict): + """检查权重分布""" + print(f"\n📊 权重分布统计:") + + weight_stats = [] + for key, weight in state_dict.items(): + if "weight" in key and len(weight.shape) >= 2: # 只检查权重矩阵,不包括偏置 + stats = { + "key": key, + "shape": weight.shape, + "min": weight.min().item(), + "max": weight.max().item(), + "mean": weight.mean().item(), + "std": weight.std().item(), + "abs_mean": weight.abs().mean().item(), + } + weight_stats.append(stats) + + # 打印前10个权重 + for i, stats in enumerate(weight_stats[:10]): + print(f" {i + 1:2d}. {stats['key']:40}") + print(f" 形状: {stats['shape']}") + print(f" 范围: [{stats['min']:.6f}, {stats['max']:.6f}]") + print(f" 均值: {stats['mean']:.6f} ± {stats['std']:.6f}") + + # 检查是否接近随机初始化 + if stats["std"] < 0.01: + print(f" ⚠️ 警告: 标准差很小,可能未训练") + + return weight_stats + + +def main(): + import sys + + if len(sys.argv) < 2: + print("使用方法: python check_weights.py ") + print("示例: python check_weights.py ./output/checkpoints/best_model.pt") + return + + checkpoint_path = sys.argv[1] + state_dict = analyze_checkpoint(checkpoint_path) + + if state_dict: + check_weight_distribution(state_dict) + + +if __name__ == "__main__": + main() diff --git a/debug_inference.py b/debug_inference.py new file mode 100644 index 0000000..3cc5b0e --- /dev/null +++ b/debug_inference.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python3 +""" +输入法模型推理调试脚本 +用于诊断为什么模型预测结果异常 +""" + +from pathlib import Path + +import torch +import torch.nn.functional as F + + +def debug_model_checkpoint(checkpoint_path: str): + """调试checkpoint文件""" + print(f"\n🔍 调试checkpoint: {checkpoint_path}") + + if not Path(checkpoint_path).exists(): + print(f"❌ Checkpoint文件不存在: {checkpoint_path}") + return None + + # 加载checkpoint + checkpoint = torch.load(checkpoint_path, map_location="cpu") + print(f"Checkpoint类型: {type(checkpoint)}") + + if isinstance(checkpoint, dict): + print(f"Checkpoint键名: {list(checkpoint.keys())}") + + # 检查是否有模型状态字典 + if "model_state_dict" in checkpoint: + state_dict = checkpoint["model_state_dict"] + print(f"使用'model_state_dict'键,包含{len(state_dict)}个权重") + elif "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + print(f"使用'state_dict'键,包含{len(state_dict)}个权重") + else: + # 可能是直接的状态字典 + state_dict = checkpoint + print(f"使用直接状态字典,包含{len(state_dict)}个权重") + + # 打印前10个权重键名和形状 + print("\n前10个权重键名和形状:") + for i, (key, value) in enumerate(list(state_dict.items())[:10]): + print(f" {i + 1}. {key}: {value.shape}") + + # 特别检查分类头 + classifier_keys = [k for k in state_dict.keys() if "classifier" in k] + if classifier_keys: + print(f"\n分类头相关权重:") + for key in classifier_keys: + weight = state_dict[key] + print( + f" {key}: shape={weight.shape}, range=[{weight.min():.4f}, {weight.max():.4f}]" + ) + else: + print(f"\n⚠️ 未找到分类头权重") + + return state_dict + else: + print(f"❌ Checkpoint不是字典类型: {type(checkpoint)}") + return None + + +def debug_input_preparation(): + """调试输入数据准备""" + print("\n🔍 调试输入数据准备") + + # 测试text_to_pinyin_ids函数 + from src.model.dataset import text_to_pinyin_ids + + test_pinyins = ["tian", "shang", "ha", "ni", "hao"] + for pinyin in test_pinyins: + ids = text_to_pinyin_ids(pinyin) + print(f"拼音 '{pinyin}' -> ID列表: {ids}") + + # 测试tokenizer + from modelscope import AutoTokenizer + + try: + tokenizer_path = ( + Path(__file__).parent / "src" / "model" / "assets" / "tokenizer" + ) + tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path)) + print(f"\n✅ Tokenizer加载成功,词汇表大小: {tokenizer.vocab_size}") + + # 测试tokenizer + test_texts = ["今天天气", "我们去公园", "张三|李四|今天天气"] + for text in test_texts: + encoded = tokenizer( + text, max_length=128, truncation=True, return_tensors="pt" + ) + print(f"文本 '{text}' -> input_ids形状: {encoded['input_ids'].shape}") + except Exception as e: + print(f"❌ Tokenizer加载失败: {e}") + + +def debug_query_engine(): + """调试查询引擎""" + print("\n🔍 调试查询引擎") + + try: + from src.model.query import QueryEngine + + query_engine = QueryEngine() + stats_path = ( + Path(__file__).parent + / "src" + / "model" + / "assets" + / "pinyin_char_statistics.json" + ) + + if stats_path.exists(): + query_engine.load(stats_path) + print(f"✅ 查询引擎加载成功") + + # 测试字符查询 + test_chars = ["天", "上", "人", "好", "的"] + for char in test_chars: + results = query_engine.query_by_char(char, limit=3) + if results: + print(f"字符 '{char}': {[(r[0], r[1], r[2]) for r in results[:3]]}") + else: + print(f"字符 '{char}': 无结果") + + # 测试拼音查询 + test_pinyins = ["tian", "shang", "ren", "hao", "de"] + for pinyin in test_pinyins: + results = query_engine.query_by_pinyin(pinyin, limit=3) + if results: + print( + f"拼音 '{pinyin}': {[(r[0], r[1], r[2]) for r in results[:3]]}" + ) + else: + print(f"拼音 '{pinyin}': 无结果") + else: + print(f"❌ 统计文件不存在: {stats_path}") + except Exception as e: + print(f"❌ 查询引擎调试失败: {e}") + import traceback + + traceback.print_exc() + + +def create_minimal_test_model(): + """创建最小测试模型,检查基本功能""" + print("\n🔍 创建最小测试模型") + + try: + from src.model.model import InputMethodEngine + + # 创建小模型 + model = InputMethodEngine( + vocab_size=100, + pinyin_vocab_size=30, + dim=64, # 小维度测试 + num_slots=8, + n_layers=2, + n_heads=2, + num_experts=4, + max_seq_len=64, + compile=False, + ) + + print(f"✅ 最小模型创建成功") + print(f" 总参数量: {sum(p.numel() for p in model.parameters()):,}") + print(f" 分类头形状: {model.classifier.weight.shape}") + + # 测试前向传播 + batch_size = 2 + input_ids = torch.randint(0, 100, (batch_size, 64)) + token_type_ids = torch.zeros((batch_size, 64), dtype=torch.long) + attention_mask = torch.ones((batch_size, 64), dtype=torch.long) + pinyin_ids = torch.randint(0, 30, (batch_size, 24)) + history_slot_ids = torch.randint(0, 100, (batch_size, 8)) + + with torch.no_grad(): + logits = model( + input_ids=input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + pinyin_ids=pinyin_ids, + history_slot_ids=history_slot_ids, + ) + + print(f" 前向传播成功,logits形状: {logits.shape}") + print(f" logits范围: [{logits.min():.4f}, {logits.max():.4f}]") + + # 检查输出 + probs = F.softmax(logits, dim=-1) + print(f" 概率范围: [{probs.min():.4f}, {probs.max():.4f}]") + + return model + + except Exception as e: + print(f"❌ 最小模型测试失败: {e}") + import traceback + + traceback.print_exc() + return None + + +def main(): + """主调试函数""" + import argparse + + parser = argparse.ArgumentParser(description="输入法模型推理调试") + parser.add_argument("--checkpoint", type=str, help="模型checkpoint路径") + + args = parser.parse_args() + + print("=" * 60) + print("输入法模型推理调试工具") + print("=" * 60) + + # 调试checkpoint + if args.checkpoint: + state_dict = debug_model_checkpoint(args.checkpoint) + else: + print("\n⚠️ 未提供checkpoint路径,跳过checkpoint调试") + state_dict = None + + # 调试输入准备 + debug_input_preparation() + + # 调试查询引擎 + debug_query_engine() + + # 创建最小测试模型 + model = create_minimal_test_model() + + print("\n" + "=" * 60) + print("调试完成") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/inference.py b/inference.py new file mode 100644 index 0000000..f97ebb2 --- /dev/null +++ b/inference.py @@ -0,0 +1,602 @@ +#!/usr/bin/env python3 +""" +输入法模型推理脚本 + +使用方法: + python inference.py --checkpoint ./output/checkpoints/best_model.pt + +交互模式: 分步询问输入 + 1. 上下文提示: 模型不掌握的专有词汇、姓名等(可为空) + 2. 光标前文本: 光标前的连续文本 + 3. 光标后文本: 光标后的连续文本 + 4. 拼音: 当前输入的拼音 + 5. 槽位历史: 用户已确认的输入历史(如输入shanghai已确认"上") + +示例场景: + 输入"shanghai"已确认"上",继续输入"tian" + 上下文提示: 张三,李四 + 光标前文本: 今天天气很好 + 光标后文本: 我们去公园玩 + 拼音: tian + 槽位历史: 上 +""" + +import argparse +import time +from pathlib import Path +from typing import List, Optional, Tuple + +import torch +import torch.nn.functional as F +from modelscope import AutoTokenizer + +from src.model.dataset import text_to_pinyin_ids +from src.model.model import InputMethodEngine +from src.model.query import QueryEngine + + +class InputMethodInference: + """输入法模型推理器""" + + def __init__( + self, + checkpoint_path: str, + device: str = "cuda" if torch.cuda.is_available() else "cpu", + ): + self.device = torch.device(device) + self.checkpoint_path = checkpoint_path + + # 加载组件 + print(f"正在加载模型从: {checkpoint_path}") + self.load_model() + + # 加载tokenizer + print("正在加载tokenizer...") + self.load_tokenizer() + + # 加载查询引擎 + print("正在加载查询引擎...") + self.load_query_engine() + + print(f"✅ 推理器初始化完成 (设备: {self.device})") + + def load_model(self): + """加载训练好的模型""" + # 创建模型实例(不编译) + self.model = InputMethodEngine(pinyin_vocab_size=30, compile=False) + + # 加载checkpoint + # 加载训练好的权重(强制先加载到CPU,再移动到目标设备) + # 这样确保GPU训练的权重能正确转换到CPU + checkpoint = torch.load(self.checkpoint_path, map_location="cpu") + if "model_state_dict" in checkpoint: + self.model.load_state_dict(checkpoint["model_state_dict"]) + else: + self.model.load_state_dict(checkpoint) + + self.model.eval() + self.model.to(self.device) + print( + f"✅ 模型加载完成,参数量: {sum(p.numel() for p in self.model.parameters()):,}" + ) + + def load_tokenizer(self): + """加载tokenizer""" + try: + # 从assets/tokenizer加载tokenizer + tokenizer_path = ( + Path(__file__).parent / "src" / "model" / "assets" / "tokenizer" + ) + self.tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path)) + print(f"✅ Tokenizer加载完成,词汇表大小: {self.tokenizer.vocab_size}") + except Exception as e: + print(f"⚠️ 无法加载tokenizer: {e}") + print("使用默认的bert-base-chinese tokenizer") + self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese") + + def load_query_engine(self): + """加载查询引擎用于字符-ID转换""" + try: + self.query_engine = QueryEngine() + # 加载默认的统计文件 + stats_path = ( + Path(__file__).parent + / "src" + / "model" + / "assets" + / "pinyin_char_statistics.json" + ) + if stats_path.exists(): + self.query_engine.load(stats_path) + print( + f"✅ 查询引擎加载完成,字符对数量: {len(self.query_engine._id_to_info)}" + ) + else: + print(f"⚠️ 统计文件不存在: {stats_path}") + self.query_engine = None + except Exception as e: + print(f"⚠️ 无法加载查询引擎: {e}") + self.query_engine = None + + def char_to_id(self, char: str, pinyin: Optional[str] = None) -> int: + """将汉字转换为ID,如果提供拼音则更精确""" + # 处理结束符 + if char == "//": + return 0 # 假设0是结束符ID + + if self.query_engine is None: + # 简单回退:使用unicode编码 + return ord(char) if len(char) == 1 else 0 + + try: + if pinyin is not None: + # 使用精确的字符-拼音对 + info = self.query_engine.get_char_info_by_char_pinyin(char, pinyin) + if info: + return info.id + + # 回退:获取字符的第一个拼音变体 + results = self.query_engine.query_by_char(char, limit=1) + if results: + return results[0][0] # 返回ID + + return 0 + except: + return 0 + + def id_to_char(self, id: int) -> str: + """将ID转换为汉字""" + # 处理结束符ID (假设0是结束符) + if id == 0: + return "//" + + if self.query_engine is None: + return chr(id) if id < 0x110000 else "" + + try: + info = self.query_engine.query_by_id(id) + return info.char if info else f"[ID:{id}]" + except: + return f"[ID:{id}]" + + def prepare_inputs( + self, + context_prompts: List[str], + text_before: str, + text_after: str, + pinyin: str, + slot_chars: List[str], + max_seq_len: int = 128, + ): + """ + 准备模型输入 + + Args: + context_prompts: 上下文提示(专有词汇、姓名等,用|分隔) + text_before: 光标前文本 + text_after: 光标后文本 + pinyin: 当前输入的拼音 + slot_chars: 槽位内的汉字列表(用户已确认的输入历史) + max_seq_len: 最大序列长度 + + Returns: + 模型输入字典 + """ + + # 1. 构建tokenizer输入 + # 根据dataset.py,格式为: "part4|part1" 和 part3 + # part4: 上下文提示(专有词汇、姓名等,模型不掌握) + # part1: text_before + # part3: text_after + + # 处理上下文提示 + context_text = "|".join(context_prompts) if context_prompts else "" + + # 构建输入文本 + if context_text: + input_text = f"{context_text}|{text_before}" + else: + input_text = text_before + + # 2. Tokenize + encoded = self.tokenizer( + input_text, + text_after, + max_length=max_seq_len, + padding="max_length", + truncation=True, + return_tensors="pt", + return_token_type_ids=True, + ) + + # 3. 处理拼音输入 + pinyin_ids = text_to_pinyin_ids(pinyin) + if len(pinyin_ids) < 24: + pinyin_ids.extend([0] * (24 - len(pinyin_ids))) + else: + pinyin_ids = pinyin_ids[:24] + pinyin_tensor = torch.tensor([pinyin_ids], dtype=torch.long) + + # 4. 处理历史槽位(用户已确认的输入历史) + history_slot_ids = [] + for char in slot_chars: + # 为每个槽位汉字查找ID(用户已确认的输入历史) + char_id = self.char_to_id(char) + history_slot_ids.append(char_id) + + # 填充到8个槽位 + if len(history_slot_ids) < 8: + history_slot_ids.extend([0] * (8 - len(history_slot_ids))) + else: + history_slot_ids = history_slot_ids[:8] + + history_tensor = torch.tensor([history_slot_ids], dtype=torch.long) + + # 5. 移动到设备 + inputs = { + "input_ids": encoded["input_ids"].to(self.device), + "token_type_ids": encoded["token_type_ids"].to(self.device), + "attention_mask": encoded["attention_mask"].to(self.device), + "pinyin_ids": pinyin_tensor.to(self.device), + "history_slot_ids": history_tensor.to(self.device), + } + + return inputs + + def predict( + self, + context_prompts: List[str], + text_before: str, + text_after: str, + pinyin: str, + slot_chars: List[str], + top_k: int = 20, + ) -> Tuple[List[Tuple[str, float, int]], float]: + """ + 执行推理 + + Args: + context_prompts: 上下文提示(专有词汇、姓名等,用|分隔) + text_before: 光标前文本 + text_after: 光标后文本 + pinyin: 当前输入的拼音 + slot_chars: 槽位内的汉字列表(用户已确认的输入历史,最大8个) + top_k: 返回top-k个预测结果 + + Returns: + (predictions, inference_time_ms) + predictions: List[Tuple[char, score, id]] + """ + start_time = time.perf_counter() + + # 准备输入 + inputs = self.prepare_inputs( + context_prompts, text_before, text_after, pinyin, slot_chars + ) + + # 调试信息:打印输入形状 + print("\n🔍 调试信息 - 输入检查:") + for key, tensor in inputs.items(): + print( + f" {key}: shape={tensor.shape}, dtype={tensor.dtype}, device={tensor.device}" + ) + if key in ["history_slot_ids", "pinyin_ids"]: + print(f" 值: {tensor.cpu().numpy().tolist()}") + + # 特别检查拼音和槽位输入 + print("\n🔍 拼音输入详细分析:") + pinyin_tensor = inputs["pinyin_ids"] + pinyin_values = pinyin_tensor.cpu().numpy()[0] + # 将拼音ID转换回字符 + from src.model.dataset import text_to_pinyin_ids + + # 逆向转换可能需要一个反向映射,这里先简单显示 + print(f" 拼音ID列表: {pinyin_values}") + print(f" 拼音非零ID: {[id for id in pinyin_values if id != 0]}") + + print("\n🔍 槽位历史详细分析:") + slot_tensor = inputs["history_slot_ids"] + slot_values = slot_tensor.cpu().numpy()[0] + print(f" 槽位ID列表: {slot_values}") + print(f" 槽位非零ID: {[id for id in slot_values if id != 0]}") + # 将ID转换为汉字 + slot_chars_converted = [self.id_to_char(id) for id in slot_values] + print(f" 槽位汉字: {slot_chars_converted}") + + # 调试:检查模型权重 + print(f"\n🔍 调试信息 - 模型检查:") + print(f" 模型设备: {self.device}") + print(f" 模型是否训练模式: {self.model.training}") + + # 检查模型词汇表大小 + vocab_size = self.model.vocab_size + print(f" 模型词汇表大小: {vocab_size}") + + # 检查前5个ID对应的字符 + print(f" 前5个ID对应的字符:") + for i in range(1, 6): + char = self.id_to_char(i) + print(f" ID {i}: '{char}'") + + # 检查分类头权重 + # 调试:检查分类头权重 + classifier_weight = self.model.classifier.weight.data + classifier_bias = self.model.classifier.bias.data + print(f" 分类头权重形状: {classifier_weight.shape}") + print( + f" 分类头权重范围: [{classifier_weight.min():.4f}, {classifier_weight.max():.4f}]" + ) + print(f" 分类头权重均值: {classifier_weight.mean():.4f}") + print(f" 分类头权重标准差: {classifier_weight.std():.4f}") + print(f" 分类头偏置形状: {classifier_bias.shape}") + print( + f" 分类头偏置范围: [{classifier_bias.min():.4f}, {classifier_bias.max():.4f}]" + ) + print(f" 分类头偏置均值: {classifier_bias.mean():.4f}") + print(f" 分类头偏置标准差: {classifier_bias.std():.4f}") + + # 推理(CPU推理时禁用混合精度) + with torch.no_grad(): + if self.device.type == "cuda": + with torch.autocast(device_type="cuda"): + logits = self.model(**inputs) + else: + # CPU推理时不使用autocast + logits = self.model(**inputs) + + # 调试:检查logits + print(f"\n🔍 调试信息 - 输出检查:") + print(f" Logits形状: {logits.shape}") + print(f" Logits范围: [{logits.min():.4f}, {logits.max():.4f}]") + print(f" Logits均值: {logits.mean():.4f}") + + # 检查logits中最大值和对应的ID + max_val, max_idx = torch.max(logits, dim=-1) + print(f" 最大logit值: {max_val.item():.4f}, 对应ID: {max_idx.item()}") + + # 获取top-k预测 + probs = F.softmax(logits, dim=-1) + top_probs, top_indices = torch.topk(probs, k=top_k, dim=-1) + + # 调试:检查概率分布 + print(f" 概率总和: {probs.sum().item():.4f}") + top_probs_array = top_probs.cpu().numpy().flatten() + top_indices_array = top_indices.cpu().numpy().flatten() + print(f" Top-{top_k}概率: {top_probs_array}") + print(f" Top-{top_k} ID: {top_indices_array}") + + # 检查概率是否均匀分布 + print(f" 概率分布分析:") + print(f" 平均概率: {probs.mean().item():.6f}") + print(f" 最大概率: {probs.max().item():.6f}") + print(f" 最小概率: {probs.min().item():.6f}") + print(f" 标准差: {probs.std().item():.6f}") + + # 检查top-20概率是否都很小 + if top_probs_array[0] < 0.01: + print( + f" ⚠️ 警告: 最高概率 ({top_probs_array[0]:.6f}) 小于0.01,模型可能未正确训练" + ) + print(f" 💡 可能原因: 1) 权重未正确加载 2) 输入格式错误 3) 模型配置不匹配") + + inference_time_ms = (time.perf_counter() - start_time) * 1000 + + # 转换为可读结果 + predictions = [] + for i in range(top_k): + idx = int(top_indices[0, i].item()) + prob = top_probs[0, i].item() + char = self.id_to_char(idx) + predictions.append((char, prob, idx)) + + return predictions, inference_time_ms + + def interactive_mode(self): + """交互式推理模式 - 分步询问输入""" + print("\n" + "=" * 60) + print("输入法模型推理 - 交互模式") + print("=" * 60) + print("说明:") + print(" - 上下文提示: 模型不掌握的专有词汇、姓名等(可为空)") + print(" - 光标前文本: 光标前的连续文本") + print(" - 光标后文本: 光标后的连续文本") + print(" - 拼音: 当前输入的拼音") + print(" - 槽位历史: 用户已确认的输入历史,如输入'shanghai'已确认'上'") + print("提示: 输入 'quit' 或 'exit' 或 'q' 可随时退出") + print("-" * 60) + + while True: + try: + print("\n" + "=" * 60) + print("第1步: 上下文提示(模型不掌握的专有词汇、姓名等)") + print("格式: 用逗号分隔多个词汇,可为空") + print("示例: 张三,李四,北京大学") + context_input = input("请输入上下文提示(直接回车跳过): ").strip() + + if context_input.lower() in ["quit", "exit", "q"]: + print("退出交互模式") + break + + # 解析上下文提示 + context_prompts = [] + if context_input: + context_prompts = [ + item.strip() + for item in context_input.split(",") + if item.strip() + ] + + print( + f"✅ 已记录上下文提示: {context_prompts if context_prompts else '无'}" + ) + + print("\n" + "-" * 40) + print("第2步: 光标前文本") + print("说明: 光标前的连续文本内容") + print("示例: 今天天气很好") + text_before = input("请输入光标前文本: ").strip() + + if text_before.lower() in ["quit", "exit", "q"]: + print("退出交互模式") + break + + print(f"✅ 已记录光标前文本: '{text_before}'") + + print("\n" + "-" * 40) + print("第3步: 光标后文本") + print("说明: 光标后的连续文本内容") + print("示例: 我们去公园玩") + text_after = input("请输入光标后文本: ").strip() + + if text_after.lower() in ["quit", "exit", "q"]: + print("退出交互模式") + break + + print(f"✅ 已记录光标后文本: '{text_after}'") + + print("\n" + "-" * 40) + print("第4步: 拼音输入") + print("说明: 当前正在输入的拼音") + print("示例: tian, shang, hao") + pinyin = input("请输入拼音: ").strip() + + if pinyin.lower() in ["quit", "exit", "q"]: + print("退出交互模式") + break + + print(f"✅ 已记录拼音: '{pinyin}'") + + print("\n" + "-" * 40) + print("第5步: 槽位历史(已确认的输入)") + print("说明: 用户已确认的输入历史,用逗号分隔") + print("示例: 上 (表示输入'shanghai'已确认'上')") + print(" 今天,天气 (表示已确认两个词)") + slot_input = input("请输入槽位历史(直接回车表示无): ").strip() + + if slot_input.lower() in ["quit", "exit", "q"]: + print("退出交互模式") + break + + # 解析槽位历史 + slot_chars = [] + if slot_input: + slot_chars = [ + char.strip() for char in slot_input.split(",") if char.strip() + ] + + print(f"✅ 已记录槽位历史: {slot_chars if slot_chars else '无'}") + + print("\n" + "=" * 60) + print("📝 输入汇总:") + print(f" 上下文提示: {context_prompts if context_prompts else '无'}") + print(f" 光标前文本: '{text_before}'") + print(f" 光标后文本: '{text_after}'") + print(f" 拼音: '{pinyin}'") + print(f" 槽位历史: {slot_chars if slot_chars else '无'}") + + # 执行推理 + print("\n🔮 推理中...") + predictions, inference_time = self.predict( + context_prompts, text_before, text_after, pinyin, slot_chars + ) + + # 显示结果 + print(f"\n✅ 推理完成 (耗时: {inference_time:.2f}ms)") + print("\n🏆 Top-20 预测结果:") + print("-" * 50) + for i, (char, prob, idx) in enumerate(predictions): + if char == "//": + print(f"{i + 1:2d}. {'//':<4} (结束符) - 概率: {prob:.4f}") + else: + print( + f"{i + 1:2d}. {char:<4} (ID: {idx:>5}) - 概率: {prob:.4f}" + ) + + # 显示原始拼音对应的可能汉字 + if pinyin and self.query_engine: + print(f"\n📖 拼音 '{pinyin}' 的常见汉字:") + pinyin_results = self.query_engine.query_by_pinyin(pinyin, limit=10) + if pinyin_results: + for j, (pid, char, count) in enumerate(pinyin_results): + print(f" {char} (频次: {count:,})") + else: + print(" (无匹配结果)") + + # 询问是否继续 + print("\n" + "-" * 40) + continue_input = input("是否继续推理?(y/n): ").strip().lower() + if continue_input not in ["y", "yes", ""]: + print("退出交互模式") + break + + except KeyboardInterrupt: + print("\n\n退出交互模式") + break + except Exception as e: + print(f"\n❌ 推理出错: {e}") + import traceback + + traceback.print_exc() + + # 询问是否继续 + continue_input = input("\n是否继续?(y/n): ").strip().lower() + if continue_input not in ["y", "yes", ""]: + print("退出交互模式") + break + + +def main(): + parser = argparse.ArgumentParser(description="输入法模型推理") + parser.add_argument( + "--checkpoint", type=str, required=True, help="模型checkpoint路径" + ) + parser.add_argument( + "--device", + type=str, + default="auto", + choices=["auto", "cpu", "cuda"], + help="推理设备 (默认: auto)", + ) + parser.add_argument( + "--interactive", action="store_true", default=True, help="交互模式 (默认: True)" + ) + parser.add_argument("--test", action="store_true", help="运行测试推理") + + args = parser.parse_args() + + # 选择设备 + if args.device == "auto": + device = "cuda" if torch.cuda.is_available() else "cpu" + else: + device = args.device + + # 初始化推理器 + inference = InputMethodInference(args.checkpoint, device) + + # 测试推理 + if args.test: + print("\n🧪 运行测试推理...") + print("测试场景: 输入'shanghai',已确认第一个字'上',继续输入'tian'") + print("上下文提示: 张三、李四(模型不掌握的专有名词)") + test_predictions, test_time = inference.predict( + context_prompts=["张三", "李四"], + text_before="今天天气", + text_after="很好", + pinyin="tian", + slot_chars=["上"], # 用户已确认输入"上" + ) + print(f"测试推理耗时: {test_time:.2f}ms") + print(f"Top-5 结果:") + for i, (char, prob, idx) in enumerate(test_predictions[:5]): + if char == "//": + print(f" {i + 1}. // (结束符) - 概率: {prob:.4f}") + else: + print(f" {i + 1}. {char} (ID: {idx}) - 概率: {prob:.4f}") + + # 交互模式 + if args.interactive: + inference.interactive_mode() + + +if __name__ == "__main__": + main() diff --git a/src/model/components.py b/src/model/components.py index b5ed7a5..dd87fc0 100644 --- a/src/model/components.py +++ b/src/model/components.py @@ -311,35 +311,54 @@ class MoELayer(nn.Module): out: [batch, seq_len, dim] """ B, L, D = x.shape - num_tokens = B * L - # 展平输入以便处理 - x_flat = x.view(num_tokens, D) # [B*L, D] + # 1. Compute Gating Scores + gates = self.gate(x) # [B, L, num_experts] - # 1. 计算门控分数 - gates = self.gate(x_flat) # [B*L, num_experts] + # 2. Select Top-K Experts + topk_vals, topk_indices = torch.topk(gates, self.top_k, dim=-1) # [B, L, K] - # 2. 选择 Top-K 专家 - topk_weights, topk_indices = torch.topk(gates, self.top_k, dim=-1) # [B*L, K] + # Normalize weights for selected experts + weights = F.softmax(topk_vals, dim=-1) # [B, L, K] - # 归一化权重 - topk_weights = F.softmax(topk_weights, dim=-1) # [B*L, K] + # 3. Dispatch and Compute + # Initialize output + out = torch.zeros_like(x) - # 3. 并行计算所有专家(消除 Python 循环中的动态控制流) - # torch.compile 会展开此列表推导式,因为 num_experts 是编译时常量 - expert_outputs = torch.stack( - [expert(x_flat) for expert in self.experts], dim=1 - ) # [B*L, num_experts, D] + # Reshape for easier processing: flatten batch and sequence dimensions + x_flat = x.view(-1, D) # [B*L, D] + weights_flat = weights.view(-1, self.top_k) # [B*L, K] + topk_indices_flat = topk_indices.view(-1, self.top_k) # [B*L, K] - # 4. 使用 gather 选择对应专家的输出 - # 扩展索引以匹配 expert_outputs 的维度 [B*L, num_experts, D] - indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, D) # [B*L, K, D] - selected_outputs = torch.gather( - expert_outputs, 1, indices_expanded - ) # [B*L, K, D] - # 5. 加权求和 - weighted_outputs = selected_outputs * topk_weights.unsqueeze(-1) # [B*L, K, D] - out_flat = weighted_outputs.sum(dim=1) # [B*L, D] + # For each of the top-k positions + for k in range(self.top_k): + # Get expert indices and weights for this position + expert_indices = topk_indices_flat[:, k] # [B*L] + expert_weights = weights_flat[:, k].unsqueeze(-1) # [B*L, 1] - # 恢复原始形状 - return out_flat.view(B, L, D) + # Process each expert separately + for e_idx in range(self.num_experts): + # Mask for tokens assigned to this expert at position k + mask = expert_indices == e_idx # [B*L] + if not mask.any(): + continue + + # Extract tokens for this expert + x_selected = x_flat[mask] # [N_selected, D] + if x_selected.numel() == 0: + continue + + # Pass through expert + expert_out = self.experts[e_idx](x_selected) # [N_selected, D] + + # Apply expert weights and add to output + weighted_out = expert_out * expert_weights[mask] + + # Scatter back to flat output + out_flat = out.view(-1, D) + out_flat[mask] += weighted_out + + # Reshape back to original shape + out = out.view(B, L, D) + + return out diff --git a/src/model/monitor.py b/src/model/monitor.py index 409cf6f..62c7dfb 100644 --- a/src/model/monitor.py +++ b/src/model/monitor.py @@ -1,12 +1,16 @@ +import http.server import json import os +import socketserver import subprocess import sys +import threading import time import webbrowser from datetime import datetime from pathlib import Path -from typing import Optional, Union +from typing import Callable, Optional, Union +from urllib.parse import urlparse import typer @@ -335,6 +339,215 @@ def check_status( raise typer.Exit(code=1) +def create_http_handler(status_file_path: str, enable_cors: bool = True): + """创建HTTP请求处理器""" + + class TrainingStatusHTTPHandler(http.server.SimpleHTTPRequestHandler): + def do_GET(self): + # 只允许访问状态文件 + parsed_path = urlparse(self.path) + if parsed_path.path not in ["/", "/training_status.json"]: + self.send_error(404, "File not found") + return + + try: + # 检查文件是否存在(包括临时文件) + main_file = status_file_path + temp_file = f"{status_file_path}.tmp" + + # 尝试读取数据,最多重试3次 + max_retries = 3 + retry_delay = 0.1 # 100ms + content = None + json_valid = False + + for attempt in range(max_retries): + try: + # 首先检查主文件是否存在 + if not os.path.exists(main_file): + # 检查临时文件是否存在(可能正在写入) + if os.path.exists(temp_file): + # 如果只有临时文件存在,尝试读取临时文件 + file_to_read = temp_file + else: + # 两个文件都不存在 + break + else: + file_to_read = main_file + + # 读取文件内容 + with open(file_to_read, "r", encoding="utf-8") as f: + content = f.read() + + # 验证JSON格式 + if content: + json.loads(content) # 验证JSON格式 + json_valid = True + break # JSON有效,跳出重试循环 + else: + # 空内容,等待重试 + time.sleep(retry_delay) + + except (json.JSONDecodeError, IOError) as e: + # JSON解析错误或IO错误,等待后重试 + if attempt < max_retries - 1: + time.sleep(retry_delay) + else: + # 最后一次尝试也失败 + raise e + + if not content or not json_valid: + self.send_error(404, "Status file not found or invalid JSON") + return + + # 设置响应头 + self.send_response(200) + self.send_header("Content-type", "application/json") + self.send_header("Content-Length", str(len(content))) + + # 添加缓存控制头,避免浏览器缓存 + self.send_header("Cache-Control", "no-cache, no-store, must-revalidate") + self.send_header("Pragma", "no-cache") + self.send_header("Expires", "0") + + # 添加CORS头 + if enable_cors: + self.send_header("Access-Control-Allow-Origin", "*") + self.send_header("Access-Control-Allow-Methods", "GET, OPTIONS") + self.send_header("Access-Control-Allow-Headers", "Content-Type") + + self.end_headers() + + # 发送内容 + self.wfile.write(content.encode("utf-8")) + + except Exception as e: + self.send_error(500, f"Internal server error: {str(e)}") + + def do_OPTIONS(self): + """处理OPTIONS请求(用于CORS预检)""" + self.send_response(200) + self.send_header("Access-Control-Allow-Origin", "*") + self.send_header("Access-Control-Allow-Methods", "GET, OPTIONS") + self.send_header("Access-Control-Allow-Headers", "Content-Type") + self.end_headers() + + def log_message(self, format, *args): + """重写日志方法,减少日志输出""" + # 可以选择性地记录日志 + # typer.echo(f"HTTP Server: {format % args}") + pass + + return TrainingStatusHTTPHandler + + +def start_http_server( + status_file: str, + port: int, + host: str, + enable_cors: bool = True, +) -> Callable: + """ + 启动HTTP服务器 + + Args: + status_file: 状态文件路径 + port: 端口号 + host: 主机地址 + enable_cors: 是否启用CORS + + Returns: + 停止服务器的函数 + """ + # 获取绝对路径 + status_file_path = os.path.abspath(status_file) + + # 创建自定义处理器 + handler = create_http_handler(status_file_path, enable_cors) + + # 创建服务器 + server = socketserver.TCPServer((host, port), handler) + + # 在后台启动服务器 + server_thread = threading.Thread(target=server.serve_forever) + server_thread.daemon = True + server_thread.start() + + typer.echo(f"🌐 HTTP服务器已启动") + typer.echo(f" 📁 状态文件: {status_file_path}") + typer.echo(f" 🔗 访问地址: http://{host}:{port}/training_status.json") + typer.echo(f" 🌍 CORS支持: {'已启用' if enable_cors else '已禁用'}") + typer.echo("\n按 Ctrl+C 停止服务器\n") + + # 返回停止函数 + def stop_server(): + typer.echo("\n🛑 正在停止HTTP服务器...") + server.shutdown() + server.server_close() + typer.echo("✅ HTTP服务器已停止") + + return stop_server + + +@app.command(name="serve") +def serve_status_file( + status_file: str = typer.Option( + "./output/training_status.json", + "--status-file", + "-s", + help="训练状态JSON文件路径", + ), + port: int = typer.Option(8080, "--port", "-p", help="HTTP服务端口号"), + host: str = typer.Option("0.0.0.0", "--host", help="HTTP服务主机地址"), + cors: bool = typer.Option(True, "--cors", help="启用CORS支持"), +): + """ + 启动HTTP服务,提供训练状态JSON文件访问 + + 启动后可通过 http://:/training_status.json 访问数据 + """ + # 检查状态文件是否存在 + if not os.path.exists(status_file): + typer.echo(f"⚠️ 警告: 状态文件不存在: {status_file}") + typer.echo("开始训练后,训练脚本会自动创建此文件。") + typer.echo("您可以先启动HTTP服务,然后开始训练。") + + # 创建目录(如果不存在) + os.makedirs(os.path.dirname(status_file), exist_ok=True) + + # 创建空的JSON文件 + with open(status_file, "w", encoding="utf-8") as f: + json.dump([], f) + typer.echo(f"✅ 已创建空状态文件: {status_file}") + + try: + # 启动HTTP服务器 + stop_server = start_http_server( + status_file=status_file, + port=port, + host=host, + enable_cors=cors, + ) + + # 等待用户中断 + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + stop_server() + + except OSError as e: + if "Address already in use" in str(e): + typer.echo(f"❌ 错误: 端口 {port} 已被占用") + typer.echo("请使用其他端口: --port <端口号>") + else: + typer.echo(f"❌ 启动HTTP服务器时出错: {e}") + raise typer.Exit(code=1) + except Exception as e: + typer.echo(f"❌ 启动HTTP服务器时出错: {e}") + raise typer.Exit(code=1) + + def main(): """主函数""" app() diff --git a/src/model/trainer.py b/src/model/trainer.py index c855c0f..cf679a5 100644 --- a/src/model/trainer.py +++ b/src/model/trainer.py @@ -2,6 +2,7 @@ import json import math import os import random +import tempfile from datetime import datetime from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -490,9 +491,13 @@ class Trainer: [self.training_status_data] if self.training_status_data else [] ) - # 写入文件 - with open(self.status_file, "w", encoding="utf-8") as f: + # 使用原子写入避免读取不完整JSON + # 先写入临时文件,然后原子重命名 + temp_file = Path(f"{self.status_file}.tmp") + with open(temp_file, "w", encoding="utf-8") as f: json.dump(self.training_status_data, f, indent=2, ensure_ascii=False) + # 原子重命名(Unix系统是原子操作) + temp_file.rename(self.status_file) except Exception as e: logger.error(f"Failed to write training status: {e}") diff --git a/src/model/training_monitor.py b/src/model/training_monitor.py index 737921d..e710b8d 100644 --- a/src/model/training_monitor.py +++ b/src/model/training_monitor.py @@ -6,10 +6,15 @@ from datetime import datetime from pathlib import Path import pandas as pd -import plotly.express as px + +# import plotly.express as px # 暂时未使用 import plotly.graph_objects as go import streamlit as st -from plotly.subplots import make_subplots + +# from plotly.subplots import make_subplots # 暂时未使用 + +# 用于HTTP URL支持 +# requests模块将在load_from_url函数中动态导入 # 添加项目路径到系统路径,以便导入模型 sys.path.append(str(Path(__file__).parent.parent)) @@ -95,7 +100,46 @@ st.markdown( def load_training_data(file_path): - """加载训练状态数据""" + """从本地文件或HTTP URL加载训练状态数据""" + # 检查是否是HTTP/HTTPS URL + if file_path.startswith(("http://", "https://")): + return load_from_url(file_path) + else: + return load_from_local_file(file_path) + + +def load_from_url(url): + """从HTTP URL加载数据""" + # 动态导入requests模块 + try: + import requests + except ImportError: + st.error( + "requests库未安装,无法从HTTP URL加载数据。请安装:pip install requests" + ) + return pd.DataFrame() + + try: + # 设置合理的超时 + response = requests.get(url, timeout=10) + response.raise_for_status() + + data = response.json() + return convert_to_dataframe(data) + + except requests.exceptions.RequestException as e: + st.warning(f"从URL加载数据失败: {e}") + return pd.DataFrame() + except json.JSONDecodeError: + st.warning("远程返回的数据不是有效的JSON格式") + return pd.DataFrame() + except Exception as e: + st.warning(f"从URL加载数据时发生错误: {e}") + return pd.DataFrame() + + +def load_from_local_file(file_path): + """从本地文件加载数据""" try: if not os.path.exists(file_path): st.warning(f"文件不存在: {file_path}") @@ -104,70 +148,7 @@ def load_training_data(file_path): with open(file_path, "r", encoding="utf-8") as f: data = json.load(f) - # 检测是否是配置文件(检查是否有典型的配置键) - config_keys = [ - "train_data_path", - "eval_data_path", - "output_dir", - "vocab_size", - "batch_size", - "num_epochs", - ] - if isinstance(data, dict): - # 检查是否是配置文件 - if any(key in data for key in config_keys): - st.error("❌ 检测到配置文件,请检查文件路径是否正确") - return pd.DataFrame() - - # 如果是单个训练状态字典,包装成列表 - data = [data] - elif isinstance(data, list): - # 检查列表中的第一个元素是否是配置文件 - if data and isinstance(data[0], dict): - if any(key in data[0] for key in config_keys): - st.error("❌ 检测到配置文件,请检查文件路径是否正确") - return pd.DataFrame() - - # 已经是列表,检查是否为空 - if len(data) == 0: - return pd.DataFrame() - else: - return pd.DataFrame() - - # 确保data是列表(经过上面的处理,它应该是列表) - if not isinstance(data, list): - return pd.DataFrame() - - # 清理数据:确保所有元素都是字典 - cleaned_data = [] - for i, item in enumerate(data): - if isinstance(item, dict): - # 检查是否是训练状态数据(包含训练指标) - if "step" in item or "train/loss" in item or "timestamp" in item: - cleaned_data.append(item) - else: - continue - - if len(cleaned_data) == 0: - return pd.DataFrame() - - # 使用清理后的数据创建DataFrame - try: - df = pd.DataFrame(cleaned_data, index=range(len(cleaned_data))) - except Exception: - try: - df = pd.DataFrame.from_records(cleaned_data) - except Exception: - return pd.DataFrame() - - # 确保时间戳为datetime类型 - if "timestamp" in df.columns: - try: - df["timestamp"] = pd.to_datetime(df["timestamp"]) - except Exception: - pass - - return df + return convert_to_dataframe(data) except json.JSONDecodeError: return pd.DataFrame() @@ -175,6 +156,75 @@ def load_training_data(file_path): return pd.DataFrame() +def convert_to_dataframe(data): + """将数据转换为DataFrame,包含数据验证和清理""" + # 检测是否是配置文件(检查是否有典型的配置键) + config_keys = [ + "train_data_path", + "eval_data_path", + "output_dir", + "vocab_size", + "batch_size", + "num_epochs", + ] + + if isinstance(data, dict): + # 检查是否是配置文件 + if any(key in data for key in config_keys): + st.error("❌ 检测到配置文件,请检查文件路径是否正确") + return pd.DataFrame() + + # 如果是单个训练状态字典,包装成列表 + data = [data] + elif isinstance(data, list): + # 检查列表中的第一个元素是否是配置文件 + if data and isinstance(data[0], dict): + if any(key in data[0] for key in config_keys): + st.error("❌ 检测到配置文件,请检查文件路径是否正确") + return pd.DataFrame() + + # 已经是列表,检查是否为空 + if len(data) == 0: + return pd.DataFrame() + else: + return pd.DataFrame() + + # 确保data是列表(经过上面的处理,它应该是列表) + if not isinstance(data, list): + return pd.DataFrame() + + # 清理数据:确保所有元素都是字典 + cleaned_data = [] + for i, item in enumerate(data): + if isinstance(item, dict): + # 检查是否是训练状态数据(包含训练指标) + if "step" in item or "train/loss" in item or "timestamp" in item: + cleaned_data.append(item) + else: + continue + + if len(cleaned_data) == 0: + return pd.DataFrame() + + # 使用清理后的数据创建DataFrame + try: + df = pd.DataFrame(cleaned_data, index=range(len(cleaned_data))) + except Exception: + try: + df = pd.DataFrame.from_records(cleaned_data) + except Exception: + return pd.DataFrame() + + # 确保时间戳为datetime类型 + if "timestamp" in df.columns: + try: + df["timestamp"] = pd.to_datetime(df["timestamp"]) + except Exception: + pass + + return df + + def create_metric_card(label, value, delta=None, help_text=None): """创建指标卡片""" col1, col2 = st.columns([3, 1]) @@ -381,9 +431,9 @@ def main(): "TRAINING_STATUS_FILE", "./output/training_status.json" ) status_file = st.text_input( - "状态文件路径", + "状态文件路径或URL", value=default_status_file, - help="训练过程中生成的JSON状态文件路径", + help="可以是:\n1. 本地文件路径(如 ./output/training_status.json)\n2. HTTP/HTTPS URL(如 http://服务器IP:端口/training_status.json)", ) # 刷新间隔 diff --git a/test.py b/test.py index f74f6c1..4e391e7 100644 --- a/test.py +++ b/test.py @@ -5,20 +5,11 @@ from torch.utils.data import DataLoader from tqdm import tqdm from model.dataset import PinyinInputDataset +from model.model import InputMethodEngine from model.trainer import collate_fn, worker_init_fn -# Try to import DataLoader2 from torchdata, fallback to standard DataLoader -try: - from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService - - DATA_LOADER2_AVAILABLE = True - print("✅ Using DataLoader2 from torchdata") -except ImportError: - DATA_LOADER2_AVAILABLE = False - print("⚠️ torchdata not installed, falling back to standard DataLoader") - -max_iter_length = 128 * 128 -batch_size = 1024 +max_iter_length = 5 +batch_size = 1 if sys.platform == "win32": dataset_path = "data" @@ -29,40 +20,10 @@ dataset = PinyinInputDataset(dataset_path, max_iter_length=max_iter_length) def create_dataloader(): - """ - Create dataloader with DataLoader2 if available, otherwise fallback to DataLoader. - This function tries to handle streaming datasets better with DataLoader2. - """ - if DATA_LOADER2_AVAILABLE: - try: - # DataLoader2 configuration for streaming datasets - # Use MultiProcessingReadingService with careful worker settings - reading_service = MultiProcessingReadingService( - num_workers=2, # Start with 2 workers for streaming dataset - prefetch_factor=2, # Reduced prefetch for better memory management - persistent_workers=True, - pin_memory=torch.cuda.is_available(), - worker_init_fn=worker_init_fn, - ) - - dataloader = DataLoader2( - dataset, - reading_service=reading_service, - batch_size=batch_size, - collate_fn=collate_fn, - shuffle=False, # Dataset handles shuffling internally - ) - print(f"✅ Created DataLoader2 with {2} workers") - return dataloader - except Exception as e: - print(f"⚠️ DataLoader2 creation failed: {e}, falling back to DataLoader") - - # Fallback to standard DataLoader - print("📊 Using standard DataLoader") dataloader = DataLoader( dataset, batch_size=batch_size, - num_workers=2, # Limited to 2 for streaming dataset compatibility + num_workers=1, # Limited to 2 for streaming dataset compatibility pin_memory=torch.cuda.is_available(), worker_init_fn=worker_init_fn, collate_fn=collate_fn, @@ -72,32 +33,52 @@ def create_dataloader(): return dataloader +samples = [] + # Create the dataloader dataloader = create_dataloader() +# Convert to list to test loading (as in original code) +dataloader_list = list([i for i in dataloader]) +print(f"✅ Successfully loaded {len(dataloader_list)} batches") -# Test the dataloader -print(f"🔍 Testing dataloader with batch_size={batch_size}") -print(f" Dataset max_iter_length: {max_iter_length}") -print(f" Expected batches: {max_iter_length / batch_size:.0f}") +# Process batches +for i, line in tqdm(enumerate(dataloader_list), total=len(dataloader_list)): + samples.append(line) -try: - # Convert to list to test loading (as in original code) - dataloader_list = list([i for i in dataloader]) - print(f"✅ Successfully loaded {len(dataloader_list)} batches") +model = InputMethodEngine(pinyin_vocab_size=30, compile=False) - # Process batches - for i, line in tqdm(enumerate(dataloader_list), total=len(dataloader_list)): - zero_labels = (line["labels"] == 0).sum() - print(f"Batch {i}: labels==0 count = {zero_labels.item()}") - # Early exit for testing - if i >= 5: # Limit to 5 batches for quick testing - print("⚠️ Limited to 5 batches for testing") - break +checkpoint = torch.load("/home/songsenand/下载/best_model.pt", map_location="cpu") +model.load_state_dict(checkpoint["model_state_dict"]) +sample = samples[0] +input_ids = sample["input_ids"] +token_type_ids = sample["token_type_ids"] +attention_mask = sample["attention_mask"] +pinyin_ids = sample["pinyin_ids"] +history_slot_ids = sample["history_slot_ids"] +for k, v in sample.items(): + if isinstance(v, str): + print(f"{k}: {v}") +res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids) +sort_res = sorted([(i + 1, v) for i, v in enumerate(res[0])], key=lambda x: x[1]) +print(sort_res[0:5]) -except Exception as e: - print(f"❌ Error during dataloader iteration: {e}") - import traceback +# 在test.py的res计算后添加: +import torch.nn.functional as F - traceback.print_exc() +# 计算softmax概率 +probs = F.softmax(res, dim=-1) -print("🏁 Test completed") +print(f"\n📊 概率分布分析:") +print(f" 形状: {probs.shape}") +print(f" 总概率和: {probs.sum().item():.6f}") +print(f" 最大概率: {probs.max().item():.6f}") +print(f" 最小概率: {probs.min().item():.6f}") +print(f" 平均概率: {probs.mean().item():.6f}") + +# 获取top-20概率 +top_probs, top_indices = torch.topk(probs, k=20) +print(f"\n🏆 Top-20预测:") +for i in range(20): + idx = top_indices[0, i].item() + prob = top_probs[0, i].item() + print(f" {i + 1:2d}. ID {idx:5d}: {prob:.6f}") diff --git a/test_cpu_inference.py b/test_cpu_inference.py new file mode 100644 index 0000000..17e983e --- /dev/null +++ b/test_cpu_inference.py @@ -0,0 +1,289 @@ +#!/usr/bin/env python3 +""" +测试GPU训练的模型在CPU上推理 +解决设备转换和权重加载问题 +""" + +import sys +from pathlib import Path + +import torch + + +def test_device_conversion(checkpoint_path): + """测试设备转换""" + print("=" * 60) + print("GPU->CPU设备转换测试") + print("=" * 60) + + # 方法1:直接加载到CPU + print("\n🔍 方法1: 直接加载到CPU") + try: + checkpoint = torch.load(checkpoint_path, map_location="cpu") + print("✅ 直接加载到CPU成功") + + if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint: + state_dict = checkpoint["model_state_dict"] + # 检查是否有CUDA tensor + cuda_tensors = 0 + for key, tensor in state_dict.items(): + if tensor.is_cuda: + cuda_tensors += 1 + + print(f" 模型状态字典包含CUDA tensor: {cuda_tensors}/{len(state_dict)}") + + # 将权重移动到CPU + cpu_state_dict = {k: v.cpu() for k, v in state_dict.items()} + print(" 已将所有权重移动到CPU") + + return cpu_state_dict, checkpoint + except Exception as e: + print(f"❌ 方法1失败: {e}") + + return None, None + + +def test_model_creation(state_dict, config=None): + """测试模型创建和权重加载""" + print("\n" + "=" * 60) + print("模型创建和权重加载测试") + print("=" * 60) + + try: + from src.model.model import InputMethodEngine + + # 从checkpoint获取配置或使用默认值 + if config and "config" in config: + model_config = config["config"] + print("📋 使用checkpoint中的配置:") + for key, value in model_config.items(): + print(f" {key}: {value}") + + model = InputMethodEngine( + vocab_size=model_config.get("vocab_size", 10019), + pinyin_vocab_size=model_config.get("pinyin_vocab_size", 30), + dim=model_config.get("dim", 512), + num_slots=model_config.get("num_slots", 8), + n_layers=model_config.get("n_layers", 4), + n_heads=model_config.get("n_heads", 4), + num_experts=model_config.get("num_experts", 20), + max_seq_len=model_config.get("max_seq_len", 128), + compile=False, + ) + else: + print("📋 使用默认配置") + model = InputMethodEngine(compile=False) + + print("✅ 模型创建成功") + print(f" 总参数量: {sum(p.numel() for p in model.parameters()):,}") + print(f" 词汇表大小: {model.vocab_size}") + + # 尝试加载权重 + print("\n🔄 加载权重...") + try: + model.load_state_dict(state_dict) + print("✅ 权重加载成功 (strict=True)") + except RuntimeError as e: + print(f"⚠️ 严格模式加载失败: {e}") + print("🔄 尝试非严格模式加载...") + model.load_state_dict(state_dict, strict=False) + print("✅ 权重加载成功 (strict=False)") + + # 检查分类头 + print("\n📊 分类头检查:") + classifier_weight = model.classifier.weight.data + classifier_bias = model.classifier.bias.data + print(f" 权重形状: {classifier_weight.shape}") + print( + f" 权重范围: [{classifier_weight.min():.6f}, {classifier_weight.max():.6f}]" + ) + print( + f" 权重均值: {classifier_weight.mean():.6f} ± {classifier_weight.std():.6f}" + ) + print(f" 偏置形状: {classifier_bias.shape}") + print(f" 偏置范围: [{classifier_bias.min():.6f}, {classifier_bias.max():.6f}]") + + return model + except Exception as e: + print(f"❌ 模型创建失败: {e}") + import traceback + + traceback.print_exc() + return None + + +def test_forward_pass(model): + """测试前向传播""" + print("\n" + "=" * 60) + print("前向传播测试") + print("=" * 60) + + try: + model.eval() + + # 创建简单的测试输入 + batch_size = 2 + seq_len = 64 + + # 使用随机但合理的输入 + input_ids = torch.randint(0, 1000, (batch_size, seq_len), dtype=torch.long) + token_type_ids = torch.zeros((batch_size, seq_len), dtype=torch.long) + attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long) + pinyin_ids = torch.randint(1, 30, (batch_size, 24), dtype=torch.long) # 避免0 + history_slot_ids = torch.randint(0, 100, (batch_size, 8), dtype=torch.long) + + print("📋 测试输入:") + print(f" input_ids: {input_ids.shape}") + print(f" token_type_ids: {token_type_ids.shape}") + print(f" attention_mask: {attention_mask.shape}") + print(f" pinyin_ids: {pinyin_ids.shape}") + print(f" history_slot_ids: {history_slot_ids.shape}") + + # 执行前向传播 + with torch.no_grad(): + logits = model( + input_ids=input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + pinyin_ids=pinyin_ids, + history_slot_ids=history_slot_ids, + ) + + print("\n✅ 前向传播成功") + print(f" Logits形状: {logits.shape}") + print(f" Logits范围: [{logits.min():.6f}, {logits.max():.6f}]") + print(f" Logits均值: {logits.mean():.6f} ± {logits.std():.6f}") + + # 检查输出是否合理 + if logits.abs().max() < 1e-6: + print("⚠️ 警告: Logits值非常小,可能权重未正确加载") + + # 计算softmax概率 + probs = torch.nn.functional.softmax(logits, dim=-1) + print("\n📊 概率分布:") + print(f" 概率总和: {probs.sum(dim=-1).mean().item():.6f} (应为1.0)") + print(f" 最大概率: {probs.max().item():.6f}") + print(f" 最小概率: {probs.min().item():.6f}") + print( + f" 平均概率: {probs.mean().item():.6f} (预期: ~{1.0 / logits.size(-1):.6f})" + ) + + # 检查top-5预测 + top_probs, top_indices = torch.topk(probs, k=5, dim=-1) + print("\n🏆 Batch 0的Top-5预测:") + for i in range(5): + print( + f" {i + 1}. ID {top_indices[0, i].item()}: {top_probs[0, i].item():.6f}" + ) + + return True + except Exception as e: + print(f"❌ 前向传播失败: {e}") + import traceback + + traceback.print_exc() + return False + + +def test_id_mapping(): + """测试ID映射""" + print("\n" + "=" * 60) + print("ID映射测试") + print("=" * 60) + + try: + from src.model.query import QueryEngine + + query_engine = QueryEngine() + stats_path = ( + Path(__file__).parent + / "src" + / "model" + / "assets" + / "pinyin_char_statistics.json" + ) + + if stats_path.exists(): + query_engine.load(stats_path) + print("✅ 查询引擎加载成功") + + # 测试一些常见字的ID + test_chars = ["的", "是", "一", "天", "上", "人", "好"] + print("📋 常见字ID映射:") + for char in test_chars: + results = query_engine.query_by_char(char, limit=1) + if results: + char_id, pinyin, count = results[0] + print( + f" '{char}' -> ID: {char_id}, 拼音: {pinyin}, 频次: {count:,}" + ) + else: + print(f" '{char}' -> 未找到") + + # 检查ID范围 + all_ids = list(query_engine._id_to_info.keys()) + if all_ids: + print("\n📊 ID范围统计:") + print(f" 最小ID: {min(all_ids)}") + print(f" 最大ID: {max(all_ids)}") + print(f" ID总数: {len(all_ids)}") + + # 检查ID是否连续 + sorted_ids = sorted(all_ids) + gaps = [] + for i in range(1, len(sorted_ids)): + if sorted_ids[i] - sorted_ids[i - 1] > 1: + gaps.append((sorted_ids[i - 1], sorted_ids[i])) + + if gaps: + print(f" ⚠️ 发现ID间隙: {len(gaps)}处") + for i, (prev, curr) in enumerate(gaps[:3]): + print(f" 间隙{i + 1}: {prev} -> {curr} (差: {curr - prev})") + else: + print(" ✅ ID基本连续") + else: + print(f"❌ 统计文件不存在: {stats_path}") + except Exception as e: + print(f"❌ ID映射测试失败: {e}") + + +def main(): + if len(sys.argv) < 2: + print("使用方法: python test_cpu_inference.py ") + print("示例: python test_cpu_inference.py ~/下载/best_model.pt") + return + + checkpoint_path = sys.argv[1] + + # 测试设备转换 + state_dict, full_checkpoint = test_device_conversion(checkpoint_path) + + if state_dict is None: + print("\n❌ 设备转换失败,无法继续测试") + return + + # 测试模型创建 + model = test_model_creation(state_dict, full_checkpoint) + + if model is None: + print("\n❌ 模型创建失败,无法继续测试") + return + + # 测试前向传播 + success = test_forward_pass(model) + + if success: + print("\n✅ CPU推理测试通过!") + else: + print("\n❌ CPU推理测试失败") + + # 测试ID映射 + test_id_mapping() + + print("\n" + "=" * 60) + print("测试完成") + print("=" * 60) + + +if __name__ == "__main__": + main()