#!/usr/bin/env python3 """ eval.py - 评估模型在给定文本上的表现 使用方法: python eval.py --text-file path/to/text.txt [--checkpoint path/to/model.pt] [--num-samples N] 功能: 1. 读取文本文件,限制长度300,随机抽取 2. 参考dataset.py将文本内容转化为模型可接受数据,包含part1/part2/part3/part4/槽位历史等信息 3. 使用模型推理,参考test.py 4. 错误的打印,对于方差极小,概率和猜没什么差别的,重点标记 """ import argparse import random import re import sys from pathlib import Path from typing import Dict, List, Tuple, Optional import numpy as np import torch import torch.nn.functional as F from modelscope import AutoTokenizer from pypinyin import lazy_pinyin from pypinyin.contrib.tone_convert import to_initials # 添加src目录到路径 sys.path.append("src") from src.model.model import InputMethodEngine from src.model.query import QueryEngine from src.model.dataset import text_to_pinyin_ids _HANZI_RE = re.compile(r"[\u4e00-\u9fff]+") class TextEvaluator: def __init__( self, checkpoint_path: str, device: str = "cuda" if torch.cuda.is_available() else "cpu", ): self.device = torch.device(device) self.checkpoint_path = checkpoint_path # 加载组件 self.load_model() self.load_tokenizer() self.load_query_engine() # 拼音风格权重(与dataset.py一致) self.py_style_weight = np.array([9, 2, 1]) / sum([9, 2, 1]) print(f"✅ 评估器初始化完成 (设备: {self.device})") # 调试:运行一个固定样本测试 import os if os.environ.get("EVAL_DEBUG"): self._debug_sample() def load_model(self): """加载训练好的模型""" self.model = InputMethodEngine(pinyin_vocab_size=30, compile=False) checkpoint = torch.load(self.checkpoint_path, map_location="cpu") if "model_state_dict" in checkpoint: self.model.load_state_dict(checkpoint["model_state_dict"]) else: self.model.load_state_dict(checkpoint) self.model.eval() self.model.to(self.device) print( f"✅ 模型加载完成,参数量: {sum(p.numel() for p in self.model.parameters()):,}" ) print(f"✅ 模型词汇表大小: {self.model.vocab_size}") def load_tokenizer(self): """加载tokenizer""" tokenizer_path = ( Path(__file__).parent / "src" / "model" / "assets" / "tokenizer" ) self.tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path)) print(f"✅ Tokenizer加载完成,词汇表大小: {self.tokenizer.vocab_size}") def load_query_engine(self): """加载查询引擎用于字符-ID转换""" try: self.query_engine = QueryEngine() stats_path = ( Path(__file__).parent / "src" / "model" / "assets" / "pinyin_char_statistics.json" ) if stats_path.exists(): self.query_engine.load(stats_path) print( f"✅ 查询引擎加载完成,字符对数量: {len(self.query_engine._id_to_info)}" ) else: print(f"⚠️ 统计文件不存在: {stats_path}") self.query_engine = None except Exception as e: print(f"⚠️ 无法加载查询引擎: {e}") self.query_engine = None def _debug_sample(self): """调试:运行与test.py相同的固定样本""" print("\n🧪 调试:运行固定样本(与test.py相同)") # 复制test.py中的样本 part1 = "他是一名大学生,在上海读" part2 = "dayi" pinyin_ids = text_to_pinyin_ids(part2) len_py = len(pinyin_ids) if len_py < 24: pinyin_ids.extend([0] * (24 - len_py)) else: pinyin_ids = pinyin_ids[:24] pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0) masked_labels = [15, 4, 0, 0, 0, 0, 0, 0] part3 = "。" part4 = "可行|特别|伤害" encoded = self.tokenizer( f"{part4}|{part1}", part3, max_length=128, padding="max_length", truncation=True, return_tensors="pt", return_token_type_ids=True, ) sample = { "input_ids": torch.stack([encoded["input_ids"].squeeze(0)]), "token_type_ids": torch.stack([encoded["token_type_ids"].squeeze(0)]), "attention_mask": torch.stack([encoded["attention_mask"].squeeze(0)]), "history_slot_ids": torch.tensor(masked_labels, dtype=torch.long).unsqueeze( 0 ), "prefix": f"{part4}^{part1}", "suffix": part3, "pinyin": part2, "pinyin_ids": pinyin_ids, } # 推理 logits, probs = self.inference(sample) analysis = self.analyze_probability_distribution(probs) print(f"📊 调试样本结果:") print(f" 最高概率: {analysis['max_prob']:.6f}") print(f" 平均概率: {analysis['mean_prob']:.6f}") print(f" 标准差: {analysis['std_prob']:.6f}") print(f" 低方差: {analysis['low_variance']}") # 打印top-5 print(f" Top-5预测:") for j in range(5): idx = analysis["top_indices"][j] prob = analysis["top_probs"][j] char_info = ( self.query_engine.query_by_id(idx) if self.query_engine else None ) char = char_info.char if char_info else f"[ID:{idx}]" print(f" {j + 1}. {char} (ID: {idx}) - 概率: {prob:.6f}") def generate_pinyin(self, text: str) -> List[str]: """ 流式处理单条文本,转换为拼音列表。 参考dataset.py中的generate_pinyin方法。 """ if not text: return [] text_len = len(text) result: List[str] = [""] * text_len # 遍历所有连续汉字片段 for match in _HANZI_RE.finditer(text): start_idx = match.start() hanzi_segment = match.group() pinyin_list = lazy_pinyin(hanzi_segment) if len(pinyin_list) != len(hanzi_segment): pinyin_list = [lazy_pinyin(c)[0] for c in hanzi_segment] for i, py in enumerate(pinyin_list): result[start_idx + i] = py # 填充非汉字字符 for i, char in enumerate(text): if not result[i]: result[i] = char return result def get_mask_pinyin( self, text: str, pinyin_list: List[str] ) -> Tuple[int, List[str]]: """ 生成需要预测汉字对应的拼音,并进行加强。 参考dataset.py中的get_mask_pinyin方法。 """ mask_pinyin = [] for i in range(len(text)): if not self.query_engine or not self.query_engine.is_chinese_char(text[i]): break else: py = np.random.choice( (pinyin_list[i], to_initials(pinyin_list[i]), pinyin_list[i][0]), p=self.py_style_weight, ) if py == "": py = pinyin_list[i][0] mask_pinyin.append(py) return len(mask_pinyin), mask_pinyin def create_sample_from_text( self, text: str, position: Optional[int] = None, force_slot_count: Optional[int] = None, ) -> Dict: """ 从文本创建单个样本,参考dataset.py的样本生成逻辑。 Args: text: 输入文本 position: 预测位置(汉字位置),如果为None则随机选择 force_slot_count: 强制设置有效槽位数量(0-8),如果为None则自动计算 Returns: 样本字典,包含模型所需的所有字段 """ # 确保文本长度至少为1 if not text: raise ValueError("文本不能为空") # 生成拼音列表 pinyin_list = self.generate_pinyin(text) # 找到所有汉字位置 chinese_positions = [ i for i, char in enumerate(text) if self.query_engine and self.query_engine.is_chinese_char(char) ] if not chinese_positions: raise ValueError("文本中没有汉字") # 随机选择预测位置 if position is None: position = random.choice(chinese_positions) elif position >= len(text) or position < 0: raise ValueError(f"位置 {position} 超出文本范围") elif not self.query_engine or not self.query_engine.is_chinese_char( text[position] ): raise ValueError(f"位置 {position} 不是汉字") i = position # part1: 光标前文本(最多48个字符) if i < 48: part1 = text[0:i] else: part1 = text[i - 48 : i] # part2: 拼音输入(随机长度1-8,高斯分布) # 当强制指定历史槽位时,需要确保拼音长度足够 pinyin_len_probs = [0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04] min_pinyin_len = (force_slot_count + 1) if force_slot_count is not None else 1 min_pinyin_len = max(1, min(8, min_pinyin_len)) # 检查剩余可用字符数量 remaining_chars = len(text) - i max_valid_len = 0 for j in range(i, min(i + 8, len(text))): if self.query_engine and self.query_engine.is_chinese_char(text[j]): max_valid_len += 1 else: break # 调整拼音长度:必须至少有 min_pinyin_len 个有效字符 if max_valid_len < min_pinyin_len: # 重新选择位置,确保有足够字符 raise ValueError( f"位置 {i} 后只有 {max_valid_len} 个有效汉字,不足以测试历史槽位 {force_slot_count}" ) # 随机选择拼音长度(至少 min_pinyin_len) if force_slot_count is not None: # 强制模式:随机选择比 min_pinyin_len 大的长度 valid_lengths = [ l for l in range(min_pinyin_len, min(max_valid_len + 1, 9)) ] if not valid_lengths: valid_lengths = [min_pinyin_len] # 使用原来的概率分布,但只选择有效长度 adjusted_probs = [pinyin_len_probs[l - 1] for l in valid_lengths] adjusted_probs = np.array(adjusted_probs) / sum(adjusted_probs) pinyin_len = np.random.choice(valid_lengths, p=adjusted_probs) else: pinyin_len = np.random.choice(range(1, 9), p=pinyin_len_probs) pinyin_len = min(pinyin_len, max_valid_len) py_end = min(i + pinyin_len, len(text)) # 获取拼音 pinyin_len_actual, part2_pinyin = self.get_mask_pinyin( text[i:py_end], pinyin_list[i:py_end] ) # 添加分隔符 split_char = np.random.choice(["", "`", "'", "-"], p=[0.9, 0.04, 0.04, 0.02]) part2 = split_char.join(part2_pinyin) # 转换为拼音ID pinyin_ids = text_to_pinyin_ids(part2) len_py = len(pinyin_ids) if len_py < 24: pinyin_ids.extend([0] * (24 - len_py)) else: pinyin_ids = pinyin_ids[:24] pinyin_ids_tensor = torch.tensor(pinyin_ids, dtype=torch.long) # part3: 光标后文本(70%概率为空) part3 = "" if random.random() > 0.7: part3 = text[ i + pinyin_len_actual : i + pinyin_len_actual + np.random.choice(range(1, 17)) ] # part4: 上下文提示(50%概率为空) part4 = "" if random.random() > 0.5: num_strings = random.randint(1, 5) string_list = [] for _ in range(num_strings): start_pos = random.randint(0, len(text) - 1) x = random.randint(2, 6) end_pos = min(start_pos + x + 1, len(text)) string_list.append(text[start_pos:end_pos]) part4 = "|".join(string_list) # 标签:预测字符的ID(整个拼音序列的所有字符) labels = [] try: labels = [ self.query_engine.get_char_info_by_char_pinyin(c, p).id for c, p in zip( text[i : i + pinyin_len_actual], pinyin_list[i : i + pinyin_len_actual], ) ] except (AttributeError, TypeError) as e: print(f"⚠️ 获取标签失败: {e}") labels = [0] * pinyin_len_actual # 历史槽位:与训练数据生成逻辑一致 # 当 force_slot_count=k 时,历史槽位为 labels[:k],预测目标为 labels[k] history_slot_ids = [] predict_idx = 0 # 当前预测的标签索引 if force_slot_count is not None: force_slot_count = max(0, min(8, force_slot_count)) # 检查 labels 长度是否足够 if len(labels) <= force_slot_count: raise ValueError( f"标签数量 {len(labels)} 不足以支持历史槽位 {force_slot_count}" ) predict_idx = force_slot_count history_slot_ids = labels[:force_slot_count] else: # 正常评估:随机选择历史长度,模拟训练数据分布 history_weights = [0.2, 0.2, 0.2, 0.9, 1.2, 1.8, 2.5, 3.5, 4.0] max_history = min(len(labels) - 1, 8) # 至少留一个字符预测 if max_history < 0: max_history = 0 valid_history_lens = list(range(0, max_history + 1)) if valid_history_lens: adjusted_probs = [history_weights[h] for h in valid_history_lens] adjusted_probs = np.array(adjusted_probs) / sum(adjusted_probs) history_len = np.random.choice(valid_history_lens, p=adjusted_probs) else: history_len = 0 predict_idx = history_len history_slot_ids = labels[:history_len] # 填充到8个槽位 if len(history_slot_ids) < 8: history_slot_ids.extend([0] * (8 - len(history_slot_ids))) else: history_slot_ids = history_slot_ids[:8] # Tokenize输入 encoded = self.tokenizer( f"{part4}|{part1}", part3, max_length=128, padding="max_length", truncation=True, return_tensors="pt", return_token_type_ids=True, ) # 计算有效槽位数量(ID > 0 的槽位,与训练数据一致) valid_slot_count = sum(1 for slot_id in history_slot_ids if slot_id > 0) # 拼音输入长度(字符数) pinyin_input_length = len(part2) # 当前预测的字符和标签 current_target_char = ( text[i + predict_idx] if (i + predict_idx) < len(text) else "" ) current_target_label = labels[predict_idx] if predict_idx < len(labels) else 0 # 构建样本 sample = { "input_ids": encoded["input_ids"], "token_type_ids": encoded["token_type_ids"], "attention_mask": encoded["attention_mask"], "history_slot_ids": torch.tensor( history_slot_ids, dtype=torch.long ).unsqueeze(0), "prefix": f"{part4}^{part1}", "suffix": part3, "pinyin": part2, "pinyin_ids": pinyin_ids_tensor.unsqueeze(0), "true_labels": labels, # 完整标签列表(所有字符) "predict_idx": predict_idx, # 当前预测的标签索引 "current_target_label": current_target_label, # 当前预测的标签 "position": i, "target_char": current_target_char, # 当前预测的字符 "valid_slot_count": valid_slot_count, "pinyin_input_length": pinyin_input_length, } return sample def inference(self, sample: Dict) -> Tuple[torch.Tensor, torch.Tensor]: """ 执行模型推理。 Returns: (logits, probs) """ # 准备输入张量并移动到设备 input_ids = sample["input_ids"].to(self.device) token_type_ids = sample["token_type_ids"].to(self.device) attention_mask = sample["attention_mask"].to(self.device) pinyin_ids = sample["pinyin_ids"].to(self.device) history_slot_ids = sample["history_slot_ids"].to(self.device) with torch.no_grad(): if self.device.type == "cuda": with torch.autocast(device_type="cuda"): logits = self.model( input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids, ) else: logits = self.model( input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids, ) probs = F.softmax(logits, dim=-1) return logits, probs def analyze_probability_distribution( self, probs: torch.Tensor, threshold_low_var: float = 0.01 ) -> Dict: """ 分析概率分布,检测低方差情况。 Args: probs: 概率张量,形状为 (1, vocab_size) threshold_low_var: 低方差阈值,最高概率低于此值则标记为低方差 Returns: 分析结果字典 """ probs_np = probs.cpu().numpy().flatten() # 基本统计 max_prob = np.max(probs_np) min_prob = np.min(probs_np) mean_prob = np.mean(probs_np) std_prob = np.std(probs_np) # 熵 entropy = -np.sum(probs_np * np.log(probs_np + 1e-12)) # 低方差检测 low_variance = max_prob < threshold_low_var # Top-k 分析 top_k = 20 top_indices = np.argsort(probs_np)[-top_k:][::-1] top_probs = probs_np[top_indices] # 检查前5个概率是否接近均匀 top5_probs = top_probs[:5] top5_ratio = ( top5_probs[0] / (top5_probs[1] + 1e-12) if len(top5_probs) > 1 else 1.0 ) analysis = { "max_prob": max_prob, "min_prob": min_prob, "mean_prob": mean_prob, "std_prob": std_prob, "entropy": entropy, "low_variance": low_variance, "top_indices": top_indices, "top_probs": top_probs, "top5_ratio": top5_ratio, "vocab_size": len(probs_np), } return analysis def evaluate_sample_single_char( self, sample: Dict, print_details: bool = True ) -> Dict: """ 单字符评估(与验证集 trainer.py 评估逻辑一致)。 使用 argmax 选择预测字符,与训练时验证逻辑完全一致。 根据 predict_idx 选择正确的预测目标。 Returns: 评估结果字典 """ logits, probs = self.inference(sample) analysis = self.analyze_probability_distribution(probs) true_labels = sample.get("true_labels", []) predict_idx = sample.get("predict_idx", 0) current_target_label = sample.get( "current_target_label", true_labels[predict_idx] if true_labels else 0 ) target_char = sample.get("target_char", "") pred_idx = analysis["top_indices"][0] correct = pred_idx == current_target_label if print_details: valid_slot_count = sample.get("valid_slot_count", 0) pinyin_input_length = sample.get("pinyin_input_length", 0) print(f" {target_char}", end="") print(f"(ID:{current_target_label})", end="") print( f" | 拼音:{sample.get('pinyin', '')[:10]}{'...' if len(sample.get('pinyin', '')) > 10 else ''}", end="", ) print(f" | 槽位:{valid_slot_count}/8 拼音长:{pinyin_input_length}", end="") print(f" | 最高概率:{analysis['max_prob']:.3f}", end="") if analysis["low_variance"]: print(f" ⚠️", end="") top_strs = [] for j in range(min(2, len(analysis["top_indices"]))): idx = analysis["top_indices"][j] prob = analysis["top_probs"][j] if idx == 0: char = "[空]" else: char_info = ( self.query_engine.query_by_id(idx) if self.query_engine else None ) char = char_info.char if char_info else f"[{idx}]" is_correct = idx == current_target_label correct_mark = "✓" if is_correct else "" top_strs.append(f"{char}{correct_mark}({prob:.2f})") print(f" | Top-2: {' '.join(top_strs)}", end="") print(f" | 结果:{'✓' if correct else '✗'}") return { "sample": sample, "analysis": analysis, "correct": correct, "target_char": target_char, "true_label": current_target_label, "predicted_label": pred_idx, "predict_idx": predict_idx, } def evaluate_sample(self, sample: Dict, print_details: bool = True) -> Dict: """ 评估单个样本(兼容旧接口,调用单字符评估)。 Returns: 评估结果字典 """ return self.evaluate_sample_single_char(sample, print_details) def evaluate_sample_with_sequential_confirmation( self, sample: Dict, print_details: bool = True, use_oracle_history: bool = True ) -> Dict: """ 评估样本并模拟用户逐步确认过程(用于多字符预测)。 Args: sample: 样本字典 print_details: 是否打印详细信息 use_oracle_history: True=使用真实标签作为历史(与训练/验证一致),False=使用预测结果 对于拼音长度>1的情况,模拟用户逐步选择过程。 Returns: 评估结果字典 """ # 获取真实标签和拼音长度 true_labels = sample.get("true_labels", []) pinyin_len = len(true_labels) if true_labels else 1 if pinyin_len <= 1: # 单字符预测,直接使用普通评估 return self.evaluate_sample(sample, print_details) # 多字符预测:模拟逐步确认 confirmed_history = [] # 已确认的字符ID all_predictions = [] # 每一步的预测结果 all_correct = [] # 每一步是否正确 # 复制样本用于逐步推理 sample_copy = sample.copy() for step in range(pinyin_len): # 更新历史槽位: # - oracle模式:使用真实标签作为历史(与训练一致) # - user模式:使用模型预测作为历史(模拟实际使用) if use_oracle_history: # 使用真实标签作为历史(与训练数据生成逻辑一致) history_slot_ids = true_labels[:step] else: # 使用预测结果作为历史 history_slot_ids = confirmed_history[:] if len(history_slot_ids) < 8: history_slot_ids.extend([0] * (8 - len(history_slot_ids))) else: history_slot_ids = history_slot_ids[:8] # 更新样本的历史槽位 sample_copy["history_slot_ids"] = torch.tensor( history_slot_ids, dtype=torch.long ).unsqueeze(0) # 执行推理 logits, probs = self.inference(sample_copy) analysis = self.analyze_probability_distribution(probs) # 获取Top-1预测 pred_top_idx = analysis["top_indices"][0] pred_top_prob = analysis["top_probs"][0] # 检查是否正确 correct = False if step < len(true_labels): correct = pred_top_idx == true_labels[step] # 记录结果 all_predictions.append( { "predicted_id": pred_top_idx, "probability": pred_top_prob, "correct": correct, } ) all_correct.append(correct) # 更新已确认历史(用于下一轮) if pred_top_idx > 0: confirmed_history.append(pred_top_idx) # 计算整体准确率:所有步骤都正确才算正确 overall_correct = all(all_correct) # 打印结果 if print_details: target_char = sample.get("target_char", "") pinyin = sample.get("pinyin", "") valid_slot_count = sample.get("valid_slot_count", 0) print(f" {target_char}", end="") print(f" | 拼音:{pinyin[:10]}{'...' if len(pinyin) > 10 else ''}", end="") print(f" | 槽位:{valid_slot_count}/8 拼音长:{pinyin_len}", end="") mode_str = "oracle" if use_oracle_history else "user" print(f" | [{mode_str}]", end="") # 显示逐步预测结果 step_results = [] for i, pred in enumerate(all_predictions): char_info = ( self.query_engine.query_by_id(pred["predicted_id"]) if self.query_engine else None ) char = char_info.char if char_info else f"[{pred['predicted_id']}]" correct_mark = "✓" if pred["correct"] else "✗" step_results.append(f"{char}{correct_mark}") print(f" | 逐步:{'→'.join(step_results)}", end="") print(f" | 结果:{'✓' if overall_correct else '✗'}") # 返回评估结果 result = { "sample": sample, "sequential_predictions": all_predictions, "correct": overall_correct, "step_correct": all_correct, "target_char": sample.get("target_char", ""), "true_labels": true_labels, "pinyin_len": pinyin_len, } return result def evaluate_text( self, text: str, num_samples: int = 10, use_sequential: bool = False ) -> List[Dict]: """ 评估给定文本,生成多个样本进行评估。 Args: text: 输入文本 num_samples: 生成的样本数量 use_sequential: True=使用逐步确认评估,False=使用单字符评估(与验证集一致) Returns: 评估结果列表 """ # 限制文本长度 if len(text) > 300: start = random.randint(0, len(text) - 300) text = text[start : start + 300] print(f"📝 文本长度超过300,随机截取300字符: {text[:50]}...") else: print(f"📝 文本长度: {len(text)} 字符") results = [] low_variance_samples = [] for sample_idx in range(num_samples): print(f"\n[样本 {sample_idx + 1}/{num_samples}]", end=" ") try: sample = self.create_sample_from_text(text) # 默认使用单字符评估(与验证集一致) if use_sequential: result = self.evaluate_sample_with_sequential_confirmation( sample, print_details=True, use_oracle_history=True ) else: result = self.evaluate_sample_single_char( sample, print_details=True ) results.append(result) except Exception as e: print(f"❌ 样本生成或评估失败: {e}") import traceback traceback.print_exc() continue if results: self.print_summary(results, low_variance_samples) return results def evaluate_by_slot_count( self, text: str, samples_per_slot: int = 100, use_sequential: bool = False ) -> Dict[int, Dict]: """ 按槽位数量分别评估,每个槽位数量测试 samples_per_slot 个样本。 Args: text: 输入文本 samples_per_slot: 每个槽位数量的测试样本数 use_sequential: True=使用逐步确认评估,False=使用单字符评估(与验证集一致) Returns: 按槽位数量分组的评估结果 {slot_count: {results, accuracy, ...}} """ if len(text) > 300: start = random.randint(0, len(text) - 300) text = text[start : start + 300] print(f"\n{'=' * 60}") mode_str = "逐步确认" if use_sequential else "单字符(与验证集一致)" print(f"按槽位数量评估 [{mode_str}] | 每组 {samples_per_slot} 个样本") print(f"{'=' * 60}") slot_results = {} for slot_count in range(0, 9): results = [] low_var_count = 0 failed = 0 for sample_idx in range(samples_per_slot): try: sample = self.create_sample_from_text( text, force_slot_count=slot_count ) # 默认使用单字符评估(与验证集一致) if use_sequential: result = self.evaluate_sample_with_sequential_confirmation( sample, print_details=False, use_oracle_history=True ) else: result = self.evaluate_sample_single_char( sample, print_details=False ) results.append(result) # 统计低方差样本 if "analysis" in result and result["analysis"].get("low_variance"): low_var_count += 1 except Exception: failed += 1 continue if results: correct_count = sum(1 for r in results if r["correct"]) accuracy = correct_count / len(results) # 计算平均步数准确率(仅用于逐步确认模式) step_accuracies = [] if use_sequential: for r in results: if "step_correct" in r: step_correct = r["step_correct"] if step_correct: step_acc = sum(step_correct) / len(step_correct) step_accuracies.append(step_acc) mean_step_accuracy = np.mean(step_accuracies) if step_accuracies else 0 slot_results[slot_count] = { "results": results, "accuracy": accuracy, "correct": correct_count, "total": len(results), "mean_step_accuracy": mean_step_accuracy, "low_var_count": low_var_count, "failed": failed, } bar_len = 30 filled = int(bar_len * accuracy) bar = "█" * filled + "░" * (bar_len - filled) if use_sequential: print( f" 槽位 {slot_count}/8 | {bar} {accuracy:6.1%} " f"({correct_count:2d}/{len(results):2d}) " f"| 平均步数准确率: {mean_step_accuracy:.1%}" ) else: print( f" 槽位 {slot_count}/8 | {bar} {accuracy:6.1%} " f"({correct_count:2d}/{len(results):2d}) " f"| 低方差: {low_var_count}" ) else: slot_results[slot_count] = { "results": [], "accuracy": 0, "correct": 0, "total": 0, "mean_step_accuracy": 0, "low_var_count": 0, "failed": failed, } print(f" 槽位 {slot_count}/8 | 全部样本生成失败") print(f"\n{'=' * 60}") overall_correct = sum(s["correct"] for s in slot_results.values()) overall_total = sum(s["total"] for s in slot_results.values()) if overall_total > 0: print( f" 总体准确率: {overall_correct / overall_total:.1%} ({overall_correct}/{overall_total})" ) print(f"{'=' * 60}") return slot_results def print_summary(self, results: List[Dict], low_variance_samples: List[Dict]): """打印评估汇总""" print(f"\n--- 汇总 ({len(results)}样本) ---") # 计算准确率 correct_count = sum(1 for r in results if r["correct"]) accuracy = correct_count / len(results) if results else 0 print(f"准确率: {accuracy:.1%} ({correct_count}/{len(results)})") # 计算平均步数准确率(对于多字符预测) step_accuracies = [] for r in results: if "step_correct" in r: step_correct = r["step_correct"] if step_correct: step_acc = sum(step_correct) / len(step_correct) step_accuracies.append(step_acc) if step_accuracies: mean_step_accuracy = np.mean(step_accuracies) print(f"平均步数准确率: {mean_step_accuracy:.1%}") # 槽位和拼音长度统计 if results and "sample" in results[0]: valid_slots = [r["sample"].get("valid_slot_count", 0) for r in results] pinyin_lengths = [r.get("pinyin_len", 1) for r in results] avg_slots = np.mean(valid_slots) if valid_slots else 0 avg_pinyin = np.mean(pinyin_lengths) if pinyin_lengths else 0 print(f"平均槽位: {avg_slots:.1f}/8, 平均拼音长度: {avg_pinyin:.1f}") def main(): parser = argparse.ArgumentParser(description="评估模型在文本上的表现") parser.add_argument("--text-file", type=str, required=True, help="文本文件路径") parser.add_argument( "--checkpoint", type=str, default="/home/songsenand/下载/best_model.ptrom", help="模型checkpoint路径 (默认: /home/songsenand/下载/best_model.ptrom)", ) parser.add_argument( "--num-samples", type=int, default=100, help="评估样本数量 (默认: 100)" ) parser.add_argument( "--device", type=str, default="auto", choices=["auto", "cpu", "cuda"], help="推理设备 (默认: auto)", ) parser.add_argument( "--use-sequential", action="store_true", help="使用逐步确认评估模式(默认使用单字符评估,与验证集一致)", ) args = parser.parse_args() # 选择设备 if args.device == "auto": device = "cuda" if torch.cuda.is_available() else "cpu" else: device = args.device # 读取文本文件 try: with open(args.text_file, "r", encoding="utf-8") as f: text = f.read().strip() except Exception as e: print(f"❌ 无法读取文本文件: {e}") sys.exit(1) if not text: print("❌ 文本文件为空") sys.exit(1) print(f"📄 读取文本文件: {args.text_file}") print(f"📏 原始文本长度: {len(text)} 字符") print(f"🔧 使用设备: {device}") print(f"📦 模型checkpoint: {args.checkpoint}") print(f"🔬 评估样本数: {args.num_samples}") print( f"📊 评估模式: {'逐步确认' if args.use_sequential else '单字符(与验证集一致)'}" ) # 初始化评估器 try: evaluator = TextEvaluator(args.checkpoint, device) except Exception as e: print(f"❌ 初始化评估器失败: {e}") import traceback traceback.print_exc() sys.exit(1) # 执行评估 evaluator.evaluate_by_slot_count( text, samples_per_slot=args.num_samples, use_sequential=args.use_sequential ) if __name__ == "__main__": main()