#!/usr/bin/env python3 """ eval.py - 评估模型在给定文本上的表现 使用方法: python eval.py --text-file path/to/text.txt [--checkpoint path/to/model.pt] [--num-samples N] 功能: 1. 读取文本文件,限制长度300,随机抽取 2. 参考dataset.py将文本内容转化为模型可接受数据,包含part1/part2/part3/part4/槽位历史等信息 3. 使用模型推理,参考test.py 4. 错误的打印,对于方差极小,概率和猜没什么差别的,重点标记 """ import argparse import random import re import sys from pathlib import Path from typing import Dict, List, Tuple, Optional import numpy as np import torch import torch.nn.functional as F from modelscope import AutoTokenizer from pypinyin import lazy_pinyin from pypinyin.contrib.tone_convert import to_initials # 添加src目录到路径 sys.path.append("src") from src.model.model import InputMethodEngine from src.model.query import QueryEngine from src.model.dataset import text_to_pinyin_ids _HANZI_RE = re.compile(r"[\u4e00-\u9fff]+") class TextEvaluator: def __init__( self, checkpoint_path: str, device: str = "cuda" if torch.cuda.is_available() else "cpu", ): self.device = torch.device(device) self.checkpoint_path = checkpoint_path # 加载组件 self.load_model() self.load_tokenizer() self.load_query_engine() # 拼音风格权重(与dataset.py一致) self.py_style_weight = np.array([9, 2, 1]) / sum([9, 2, 1]) print(f"✅ 评估器初始化完成 (设备: {self.device})") # 调试:运行一个固定样本测试 import os if os.environ.get("EVAL_DEBUG"): self._debug_sample() def load_model(self): """加载训练好的模型""" self.model = InputMethodEngine(pinyin_vocab_size=30, compile=False) checkpoint = torch.load(self.checkpoint_path, map_location="cpu") if "model_state_dict" in checkpoint: self.model.load_state_dict(checkpoint["model_state_dict"]) else: self.model.load_state_dict(checkpoint) self.model.eval() self.model.to(self.device) print( f"✅ 模型加载完成,参数量: {sum(p.numel() for p in self.model.parameters()):,}" ) print(f"✅ 模型词汇表大小: {self.model.vocab_size}") def load_tokenizer(self): """加载tokenizer""" tokenizer_path = ( Path(__file__).parent / "src" / "model" / "assets" / "tokenizer" ) self.tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path)) print(f"✅ Tokenizer加载完成,词汇表大小: {self.tokenizer.vocab_size}") def load_query_engine(self): """加载查询引擎用于字符-ID转换""" try: self.query_engine = QueryEngine() stats_path = ( Path(__file__).parent / "src" / "model" / "assets" / "pinyin_char_statistics.json" ) if stats_path.exists(): self.query_engine.load(stats_path) print( f"✅ 查询引擎加载完成,字符对数量: {len(self.query_engine._id_to_info)}" ) else: print(f"⚠️ 统计文件不存在: {stats_path}") self.query_engine = None except Exception as e: print(f"⚠️ 无法加载查询引擎: {e}") self.query_engine = None def _debug_sample(self): """调试:运行与test.py相同的固定样本""" print("\n🧪 调试:运行固定样本(与test.py相同)") # 复制test.py中的样本 part1 = "他是一名大学生,在上海读" part2 = "dayi" pinyin_ids = text_to_pinyin_ids(part2) len_py = len(pinyin_ids) if len_py < 24: pinyin_ids.extend([0] * (24 - len_py)) else: pinyin_ids = pinyin_ids[:24] pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0) masked_labels = [15, 4, 0, 0, 0, 0, 0, 0] part3 = "。" part4 = "可行|特别|伤害" encoded = self.tokenizer( f"{part4}|{part1}", part3, max_length=128, padding="max_length", truncation=True, return_tensors="pt", return_token_type_ids=True, ) sample = { "input_ids": torch.stack([encoded["input_ids"].squeeze(0)]), "token_type_ids": torch.stack([encoded["token_type_ids"].squeeze(0)]), "attention_mask": torch.stack([encoded["attention_mask"].squeeze(0)]), "history_slot_ids": torch.tensor(masked_labels, dtype=torch.long).unsqueeze( 0 ), "prefix": f"{part4}^{part1}", "suffix": part3, "pinyin": part2, "pinyin_ids": pinyin_ids, } # 推理 logits, probs = self.inference(sample) analysis = self.analyze_probability_distribution(probs) print(f"📊 调试样本结果:") print(f" 最高概率: {analysis['max_prob']:.6f}") print(f" 平均概率: {analysis['mean_prob']:.6f}") print(f" 标准差: {analysis['std_prob']:.6f}") print(f" 低方差: {analysis['low_variance']}") # 打印top-5 print(f" Top-5预测:") for j in range(5): idx = analysis["top_indices"][j] prob = analysis["top_probs"][j] char_info = ( self.query_engine.query_by_id(idx) if self.query_engine else None ) char = char_info.char if char_info else f"[ID:{idx}]" print(f" {j + 1}. {char} (ID: {idx}) - 概率: {prob:.6f}") def generate_pinyin(self, text: str) -> List[str]: """ 流式处理单条文本,转换为拼音列表。 参考dataset.py中的generate_pinyin方法。 """ if not text: return [] text_len = len(text) result: List[str] = [""] * text_len # 遍历所有连续汉字片段 for match in _HANZI_RE.finditer(text): start_idx = match.start() hanzi_segment = match.group() pinyin_list = lazy_pinyin(hanzi_segment) if len(pinyin_list) != len(hanzi_segment): pinyin_list = [lazy_pinyin(c)[0] for c in hanzi_segment] for i, py in enumerate(pinyin_list): result[start_idx + i] = py # 填充非汉字字符 for i, char in enumerate(text): if not result[i]: result[i] = char return result def get_mask_pinyin( self, text: str, pinyin_list: List[str] ) -> Tuple[int, List[str]]: """ 生成需要预测汉字对应的拼音,并进行加强。 参考dataset.py中的get_mask_pinyin方法。 """ mask_pinyin = [] for i in range(len(text)): if not self.query_engine or not self.query_engine.is_chinese_char(text[i]): break else: py = np.random.choice( (pinyin_list[i], to_initials(pinyin_list[i]), pinyin_list[i][0]), p=self.py_style_weight, ) if py == "": py = pinyin_list[i][0] mask_pinyin.append(py) return len(mask_pinyin), mask_pinyin def create_sample_from_text( self, text: str, position: Optional[int] = None, force_slot_count: Optional[int] = None, ) -> Dict: """ 从文本创建单个样本,参考dataset.py的样本生成逻辑。 Args: text: 输入文本 position: 预测位置(汉字位置),如果为None则随机选择 force_slot_count: 强制设置有效槽位数量(0-8),如果为None则自动计算 Returns: 样本字典,包含模型所需的所有字段 """ # 确保文本长度至少为1 if not text: raise ValueError("文本不能为空") # 生成拼音列表 pinyin_list = self.generate_pinyin(text) # 找到所有汉字位置 chinese_positions = [ i for i, char in enumerate(text) if self.query_engine and self.query_engine.is_chinese_char(char) ] if not chinese_positions: raise ValueError("文本中没有汉字") # 随机选择预测位置 if position is None: position = random.choice(chinese_positions) elif position >= len(text) or position < 0: raise ValueError(f"位置 {position} 超出文本范围") elif not self.query_engine or not self.query_engine.is_chinese_char( text[position] ): raise ValueError(f"位置 {position} 不是汉字") i = position # part1: 光标前文本(最多48个字符) if i < 48: part1 = text[0:i] else: part1 = text[i - 48 : i] # part2: 拼音输入(随机长度1-8,高斯分布) pinyin_len_probs = [0.05, 0.16, 0.45, 0.16, 0.08, 0.05, 0.03, 0.02] pinyin_len = np.random.choice(range(1, 9), p=pinyin_len_probs) py_end = min(i + pinyin_len, len(text)) # 获取拼音 pinyin_len_actual, part2_pinyin = self.get_mask_pinyin( text[i:py_end], pinyin_list[i:py_end] ) # 添加分隔符 split_char = np.random.choice(["", "`", "'", "-"], p=[0.9, 0.04, 0.04, 0.02]) part2 = split_char.join(part2_pinyin) # 转换为拼音ID pinyin_ids = text_to_pinyin_ids(part2) len_py = len(pinyin_ids) if len_py < 24: pinyin_ids.extend([0] * (24 - len_py)) else: pinyin_ids = pinyin_ids[:24] pinyin_ids_tensor = torch.tensor(pinyin_ids, dtype=torch.long) # part3: 光标后文本(70%概率为空) part3 = "" if random.random() > 0.7: part3 = text[ i + pinyin_len_actual : i + pinyin_len_actual + np.random.choice(range(1, 17)) ] # part4: 上下文提示(50%概率为空) part4 = "" if random.random() > 0.5: num_strings = random.randint(1, 5) string_list = [] for _ in range(num_strings): start_pos = random.randint(0, len(text) - 1) x = random.randint(2, 6) end_pos = min(start_pos + x + 1, len(text)) string_list.append(text[start_pos:end_pos]) part4 = "|".join(string_list) # 标签:预测字符的ID labels = [] try: labels = [ self.query_engine.get_char_info_by_char_pinyin(c, p).id for c, p in zip( text[i : i + pinyin_len_actual], pinyin_list[i : i + pinyin_len_actual], ) ] except (AttributeError, TypeError) as e: # 如果查询失败,使用默认ID print(f"⚠️ 获取标签失败: {e}") labels = [0] * pinyin_len_actual # 历史槽位:当前预测字符之前已确认的字符 # 从位置i之前的字符中提取历史槽位(最多8个) # 注意:这里的逻辑与训练数据一致 - 从已确认的字符中提取 history_slot_ids = [] if self.query_engine: # 收集i之前的汉字字符(最多8个) for j in range(i - 1, max(-1, i - 100), -1): if j < 0: break char = text[j] if self.query_engine.is_chinese_char(char): # 尝试获取该字符的ID(与训练数据一致) results = self.query_engine.query_by_char(char, limit=1) if results: slot_id = results[0][0] # 只添加有效的非零ID(与训练数据一致) if slot_id > 0: history_slot_ids.append(slot_id) if len(history_slot_ids) >= 8: break # 填充到8个槽位(使用0填充,与训练数据一致) if len(history_slot_ids) < 8: history_slot_ids.extend([0] * (8 - len(history_slot_ids))) else: history_slot_ids = history_slot_ids[:8] # 强制设置有效槽位数量(用于按槽位数量评估) # 注意:这里不再使用随机ID填充,而是复制现有的有效槽位 if force_slot_count is not None: force_slot_count = max(0, min(8, force_slot_count)) # 获取所有非零槽位ID valid_ids = [s for s in history_slot_ids if s != 0] if len(valid_ids) >= force_slot_count: # 如果有效槽位足够,直接取前force_slot_count个 history_slot_ids = valid_ids[:force_slot_count] else: # 如果有效槽位不足,保留现有有效槽位,然后用0填充 # 这样更符合训练数据的分布(槽位数量由实际历史决定) history_slot_ids = valid_ids[:] # 填充到8个槽位 while len(history_slot_ids) < 8: history_slot_ids.append(0) history_slot_ids = history_slot_ids[:8] # Tokenize输入 encoded = self.tokenizer( f"{part4}|{part1}", part3, max_length=128, padding="max_length", truncation=True, return_tensors="pt", return_token_type_ids=True, ) # 计算有效槽位数量(ID > 0 的槽位,与训练数据一致) valid_slot_count = sum(1 for slot_id in history_slot_ids if slot_id > 0) # 拼音输入长度(字符数) pinyin_input_length = len(part2) # 构建样本 sample = { "input_ids": encoded["input_ids"], "token_type_ids": encoded["token_type_ids"], "attention_mask": encoded["attention_mask"], "history_slot_ids": torch.tensor( history_slot_ids, dtype=torch.long ).unsqueeze(0), "prefix": f"{part4}^{part1}", "suffix": part3, "pinyin": part2, "pinyin_ids": pinyin_ids_tensor.unsqueeze(0), "true_labels": labels, # 真实标签(多个字符) "position": i, "target_char": text[i] if i < len(text) else "", "valid_slot_count": valid_slot_count, # 有效槽位数量 "pinyin_input_length": pinyin_input_length, # 拼音输入长度 } return sample def inference(self, sample: Dict) -> Tuple[torch.Tensor, torch.Tensor]: """ 执行模型推理。 Returns: (logits, probs) """ # 准备输入张量并移动到设备 input_ids = sample["input_ids"].to(self.device) token_type_ids = sample["token_type_ids"].to(self.device) attention_mask = sample["attention_mask"].to(self.device) pinyin_ids = sample["pinyin_ids"].to(self.device) history_slot_ids = sample["history_slot_ids"].to(self.device) with torch.no_grad(): if self.device.type == "cuda": with torch.autocast(device_type="cuda"): logits = self.model( input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids, ) else: logits = self.model( input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids, ) probs = F.softmax(logits, dim=-1) return logits, probs def analyze_probability_distribution( self, probs: torch.Tensor, threshold_low_var: float = 0.01 ) -> Dict: """ 分析概率分布,检测低方差情况。 Args: probs: 概率张量,形状为 (1, vocab_size) threshold_low_var: 低方差阈值,最高概率低于此值则标记为低方差 Returns: 分析结果字典 """ probs_np = probs.cpu().numpy().flatten() # 基本统计 max_prob = np.max(probs_np) min_prob = np.min(probs_np) mean_prob = np.mean(probs_np) std_prob = np.std(probs_np) # 熵 entropy = -np.sum(probs_np * np.log(probs_np + 1e-12)) # 低方差检测 low_variance = max_prob < threshold_low_var # Top-k 分析 top_k = 20 top_indices = np.argsort(probs_np)[-top_k:][::-1] top_probs = probs_np[top_indices] # 检查前5个概率是否接近均匀 top5_probs = top_probs[:5] top5_ratio = ( top5_probs[0] / (top5_probs[1] + 1e-12) if len(top5_probs) > 1 else 1.0 ) analysis = { "max_prob": max_prob, "min_prob": min_prob, "mean_prob": mean_prob, "std_prob": std_prob, "entropy": entropy, "low_variance": low_variance, "top_indices": top_indices, "top_probs": top_probs, "top5_ratio": top5_ratio, "vocab_size": len(probs_np), } return analysis def evaluate_sample(self, sample: Dict, print_details: bool = True) -> Dict: """ 评估单个样本。 Returns: 评估结果字典 """ # 执行推理 logits, probs = self.inference(sample) # 分析概率分布 analysis = self.analyze_probability_distribution(probs) # 获取真实标签 true_labels = sample.get("true_labels", []) target_char = sample.get("target_char", "") # 检查预测是否正确(只检查第一个字符) correct = False pred_top_idx = analysis["top_indices"][0] if true_labels and len(true_labels) > 0: correct = pred_top_idx == true_labels[0] # 打印结果 if print_details: # 获取槽位信息 valid_slot_count = sample.get("valid_slot_count", 0) pinyin_input_length = sample.get("pinyin_input_length", 0) # 简化输出格式(在同一行继续输出) print(f" {target_char}", end="") if true_labels: print(f"(ID:{true_labels[0]})", end="") print( f" | 拼音:{sample.get('pinyin', '')[:10]}{'...' if len(sample.get('pinyin', '')) > 10 else ''}", end="", ) print(f" | 槽位:{valid_slot_count}/8 拼音长:{pinyin_input_length}", end="") # 关键概率信息 print(f" | 最高概率:{analysis['max_prob']:.3f}", end="") if analysis["low_variance"]: print(f" ⚠️", end="") # Top-2 预测 top_strs = [] for j in range(min(2, len(analysis["top_indices"]))): idx = analysis["top_indices"][j] prob = analysis["top_probs"][j] # 处理特殊ID if idx == 0: char = "[空]" else: char_info = ( self.query_engine.query_by_id(idx) if self.query_engine else None ) char = char_info.char if char_info else f"[{idx}]" is_correct = ( (true_labels and idx == true_labels[0]) if true_labels else False ) correct_mark = "✓" if is_correct else "" top_strs.append(f"{char}{correct_mark}({prob:.2f})") print(f" | Top-2: {' '.join(top_strs)}", end="") print(f" | 结果:{'✓' if correct else '✗'}") # 返回评估结果 result = { "sample": sample, "analysis": analysis, "correct": correct, "target_char": target_char, "true_label": true_labels[0] if true_labels else None, "predicted_label": pred_top_idx, } return result def evaluate_text(self, text: str, num_samples: int = 10) -> List[Dict]: """ 评估给定文本,生成多个样本进行评估。 Args: text: 输入文本 num_samples: 生成的样本数量 Returns: 评估结果列表 """ # 限制文本长度 if len(text) > 300: start = random.randint(0, len(text) - 300) text = text[start : start + 300] print(f"📝 文本长度超过300,随机截取300字符: {text[:50]}...") else: print(f"📝 文本长度: {len(text)} 字符") results = [] low_variance_samples = [] for sample_idx in range(num_samples): print(f"\n[样本 {sample_idx + 1}/{num_samples}]", end=" ") try: sample = self.create_sample_from_text(text) result = self.evaluate_sample(sample, print_details=True) results.append(result) if result["analysis"]["low_variance"]: low_variance_samples.append( { "sample_idx": sample_idx, "max_prob": result["analysis"]["max_prob"], "entropy": result["analysis"]["entropy"], } ) except Exception as e: print(f"❌ 样本生成或评估失败: {e}") import traceback traceback.print_exc() continue if results: self.print_summary(results, low_variance_samples) return results def evaluate_by_slot_count( self, text: str, samples_per_slot: int = 100 ) -> Dict[int, Dict]: """ 按槽位数量分别评估,每个槽位数量测试 samples_per_slot 个样本。 Args: text: 输入文本 samples_per_slot: 每个槽位数量的测试样本数 Returns: 按槽位数量分组的评估结果 {slot_count: {results, accuracy, ...}} """ if len(text) > 300: start = random.randint(0, len(text) - 300) text = text[start : start + 300] print(f"\n{'=' * 60}") print(f"按槽位数量评估 | 每组 {samples_per_slot} 个样本") print(f"{'=' * 60}") slot_results = {} for slot_count in range(0, 9): results = [] low_var_count = 0 failed = 0 for sample_idx in range(samples_per_slot): try: sample = self.create_sample_from_text( text, force_slot_count=slot_count ) result = self.evaluate_sample(sample, print_details=False) results.append(result) if result["analysis"]["low_variance"]: low_var_count += 1 except Exception: failed += 1 continue if results: correct_count = sum(1 for r in results if r["correct"]) accuracy = correct_count / len(results) max_probs = [r["analysis"]["max_prob"] for r in results] mean_max_prob = np.mean(max_probs) entropies = [r["analysis"]["entropy"] for r in results] mean_entropy = np.mean(entropies) slot_results[slot_count] = { "results": results, "accuracy": accuracy, "correct": correct_count, "total": len(results), "mean_max_prob": mean_max_prob, "mean_entropy": mean_entropy, "low_variance_count": low_var_count, "failed": failed, } bar_len = 30 filled = int(bar_len * accuracy) bar = "█" * filled + "░" * (bar_len - filled) print( f" 槽位 {slot_count}/8 | {bar} {accuracy:6.1%} " f"({correct_count:2d}/{len(results):2d}) " f"| 平均最高概率: {mean_max_prob:.4f} " f"| 平均熵: {mean_entropy:.2f} " f"| 低方差: {low_var_count}" ) else: slot_results[slot_count] = { "results": [], "accuracy": 0, "correct": 0, "total": 0, "mean_max_prob": 0, "mean_entropy": 0, "low_variance_count": 0, "failed": failed, } print(f" 槽位 {slot_count}/8 | 全部样本生成失败") print(f"\n{'=' * 60}") overall_correct = sum(s["correct"] for s in slot_results.values()) overall_total = sum(s["total"] for s in slot_results.values()) if overall_total > 0: print( f" 总体准确率: {overall_correct / overall_total:.1%} ({overall_correct}/{overall_total})" ) print(f"{'=' * 60}") return slot_results def print_summary(self, results: List[Dict], low_variance_samples: List[Dict]): """打印评估汇总""" print(f"\n--- 汇总 ({len(results)}样本) ---") # 计算准确率 correct_count = sum(1 for r in results if r["correct"]) accuracy = correct_count / len(results) if results else 0 print(f"准确率: {accuracy:.1%} ({correct_count}/{len(results)})") # 概率统计 max_probs = [r["analysis"]["max_prob"] for r in results] mean_max_prob = np.mean(max_probs) if max_probs else 0 print(f"平均最高概率: {mean_max_prob:.4f}") # 预测ID:0的比例 zero_pred_count = sum(1 for r in results if r.get("predicted_label") == 0) zero_pred_ratio = zero_pred_count / len(results) if results else 0 print(f"预测[空]比例: {zero_pred_ratio:.1%} ({zero_pred_count}/{len(results)})") # 平均熵 entropies = [r["analysis"]["entropy"] for r in results] mean_entropy = np.mean(entropies) if entropies else 0 print(f"平均熵: {mean_entropy:.2f}") # 槽位和拼音长度统计 if results and "sample" in results[0]: valid_slots = [r["sample"].get("valid_slot_count", 0) for r in results] pinyin_lengths = [ r["sample"].get("pinyin_input_length", 0) for r in results ] avg_slots = np.mean(valid_slots) if valid_slots else 0 avg_pinyin = np.mean(pinyin_lengths) if pinyin_lengths else 0 print(f"平均槽位: {avg_slots:.1f}/8, 平均拼音长度: {avg_pinyin:.1f}") # 低方差样本 if low_variance_samples: print(f"低方差样本: {len(low_variance_samples)}个") if len(low_variance_samples) <= 5: for lv in low_variance_samples: print(f" 样本{lv['sample_idx'] + 1}: 最高概率{lv['max_prob']:.4f}") else: for lv in low_variance_samples[:3]: print(f" 样本{lv['sample_idx'] + 1}: 最高概率{lv['max_prob']:.4f}") print(f" ... 还有{len(low_variance_samples) - 3}个") else: print(f"低方差样本: 0个") def main(): parser = argparse.ArgumentParser(description="评估模型在文本上的表现") parser.add_argument("--text-file", type=str, required=True, help="文本文件路径") parser.add_argument( "--checkpoint", type=str, default="/home/songsenand/下载/best_model.pt", help="模型checkpoint路径 (默认: /home/songsenand/下载/best_model.pt)", ) parser.add_argument( "--num-samples", type=int, default=100, help="评估样本数量 (默认: 20)" ) parser.add_argument( "--device", type=str, default="auto", choices=["auto", "cpu", "cuda"], help="推理设备 (默认: auto)", ) args = parser.parse_args() # 选择设备 if args.device == "auto": device = "cuda" if torch.cuda.is_available() else "cpu" else: device = args.device # 读取文本文件 try: with open(args.text_file, "r", encoding="utf-8") as f: text = f.read().strip() except Exception as e: print(f"❌ 无法读取文本文件: {e}") sys.exit(1) if not text: print("❌ 文本文件为空") sys.exit(1) print(f"📄 读取文本文件: {args.text_file}") print(f"📏 原始文本长度: {len(text)} 字符") print(f"🔧 使用设备: {device}") print(f"📦 模型checkpoint: {args.checkpoint}") print(f"🔬 评估样本数: {args.num_samples}") # 初始化评估器 try: evaluator = TextEvaluator(args.checkpoint, device) except Exception as e: print(f"❌ 初始化评估器失败: {e}") import traceback traceback.print_exc() sys.exit(1) # 执行评估 evaluator.evaluate_by_slot_count(text, samples_per_slot=args.num_samples) if __name__ == "__main__": main()