#!/usr/bin/env python3
"""
eval.py - Evaluate the model's performance on a given text.

Usage:
    python eval.py --text-file path/to/text.txt [--checkpoint path/to/model.pt] [--num-samples N]

What it does:
1. Reads a text file, truncating to 300 characters at a random offset.
2. Converts the text into model-ready samples (part1/part2/part3/part4,
   history slots, ...) following the sample-generation logic of dataset.py.
3. Runs model inference, following test.py.
4. Prints mistakes; samples whose output distribution is nearly flat
   (prediction no better than guessing) are specially highlighted.
"""

import argparse
import os
import random
import re
import sys
import traceback
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
import torch.nn.functional as F
from modelscope import AutoTokenizer
from pypinyin import lazy_pinyin
from pypinyin.contrib.tone_convert import to_initials

# Make the src directory importable
sys.path.append("src")

from src.model.model import InputMethodEngine
from src.model.query import QueryEngine
from src.model.dataset import text_to_pinyin_ids

# Matches runs of CJK unified ideographs (the characters we can pinyin-ize)
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")


class TextEvaluator:
    """Builds evaluation samples from raw text and scores model predictions."""

    def __init__(
        self,
        checkpoint_path: str,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        self.device = torch.device(device)
        self.checkpoint_path = checkpoint_path

        # Load model, tokenizer and the char<->id query engine
        self.load_model()
        self.load_tokenizer()
        self.load_query_engine()

        # Pinyin style sampling weights (full / initials / first letter),
        # kept consistent with dataset.py
        self.py_style_weight = np.array([9, 2, 1]) / sum([9, 2, 1])

        print(f"✅ 评估器初始化完成 (设备: {self.device})")

        # Debug hook: run one fixed sample when EVAL_DEBUG is set
        if os.environ.get("EVAL_DEBUG"):
            self._debug_sample()

    def load_model(self):
        """Load the trained model weights from the checkpoint file."""
        self.model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
        checkpoint = torch.load(self.checkpoint_path, map_location="cpu")
        # Support both raw state dicts and training checkpoints that wrap one
        if "model_state_dict" in checkpoint:
            self.model.load_state_dict(checkpoint["model_state_dict"])
        else:
            self.model.load_state_dict(checkpoint)
        self.model.eval()
        self.model.to(self.device)
        print(
            f"✅ 模型加载完成,参数量: {sum(p.numel() for p in self.model.parameters()):,}"
        )
        print(f"✅ 模型词汇表大小: {self.model.vocab_size}")

    def load_tokenizer(self):
        """Load the tokenizer bundled with the project assets."""
        tokenizer_path = (
            Path(__file__).parent / "src" / "model" / "assets" / "tokenizer"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path))
        print(f"✅ Tokenizer加载完成,词汇表大小: {self.tokenizer.vocab_size}")

    def load_query_engine(self):
        """Load the query engine used for character <-> ID conversion.

        Falls back to ``self.query_engine = None`` when the statistics file
        is missing or loading fails; downstream code guards on that.
        """
        try:
            self.query_engine = QueryEngine()
            stats_path = (
                Path(__file__).parent
                / "src"
                / "model"
                / "assets"
                / "pinyin_char_statistics.json"
            )
            if stats_path.exists():
                self.query_engine.load(stats_path)
                print(
                    f"✅ 查询引擎加载完成,字符对数量: {len(self.query_engine._id_to_info)}"
                )
            else:
                print(f"⚠️ 统计文件不存在: {stats_path}")
                self.query_engine = None
        except Exception as e:
            print(f"⚠️ 无法加载查询引擎: {e}")
            self.query_engine = None

    def _debug_sample(self):
        """Debug helper: run the same fixed sample as test.py."""
        print("\n🧪 调试:运行固定样本(与test.py相同)")
        # Sample copied verbatim from test.py
        part1 = "他是一名大学生,在上海读"
        part2 = "dayi"
        pinyin_ids = text_to_pinyin_ids(part2)
        len_py = len(pinyin_ids)
        # Pad/truncate pinyin IDs to a fixed length of 24
        if len_py < 24:
            pinyin_ids.extend([0] * (24 - len_py))
        else:
            pinyin_ids = pinyin_ids[:24]
        pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
        masked_labels = [15, 4, 0, 0, 0, 0, 0, 0]
        part3 = "。"
        part4 = "可行|特别|伤害"
        encoded = self.tokenizer(
            f"{part4}|{part1}",
            part3,
            max_length=128,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            return_token_type_ids=True,
        )
        sample = {
            "input_ids": torch.stack([encoded["input_ids"].squeeze(0)]),
            "token_type_ids": torch.stack([encoded["token_type_ids"].squeeze(0)]),
            "attention_mask": torch.stack([encoded["attention_mask"].squeeze(0)]),
            "history_slot_ids": torch.tensor(masked_labels, dtype=torch.long).unsqueeze(
                0
            ),
            "prefix": f"{part4}^{part1}",
            "suffix": part3,
            "pinyin": part2,
            "pinyin_ids": pinyin_ids,
        }
        # Run inference and summarize the distribution
        logits, probs = self.inference(sample)
        analysis = self.analyze_probability_distribution(probs)
        print(f"📊 调试样本结果:")
        print(f" 最高概率: {analysis['max_prob']:.6f}")
        print(f" 平均概率: {analysis['mean_prob']:.6f}")
        print(f" 标准差: {analysis['std_prob']:.6f}")
        print(f" 低方差: {analysis['low_variance']}")
        # Show the top-5 predictions
        print(f" Top-5预测:")
        for j in range(5):
            idx = analysis["top_indices"][j]
            prob = analysis["top_probs"][j]
            char_info = (
                self.query_engine.query_by_id(idx) if self.query_engine else None
            )
            char = char_info.char if char_info else f"[ID:{idx}]"
            print(f" {j + 1}. {char} (ID: {idx}) - 概率: {prob:.6f}")

    def generate_pinyin(self, text: str) -> List[str]:
        """
        Convert one text into a per-character pinyin list.

        Mirrors the generate_pinyin method in dataset.py: Chinese characters
        map to their pinyin, every other character maps to itself.
        """
        if not text:
            return []

        text_len = len(text)
        result: List[str] = [""] * text_len

        # Walk every maximal run of Chinese characters
        for match in _HANZI_RE.finditer(text):
            start_idx = match.start()
            hanzi_segment = match.group()
            pinyin_list = lazy_pinyin(hanzi_segment)
            # lazy_pinyin may merge/split; fall back to per-character calls
            if len(pinyin_list) != len(hanzi_segment):
                pinyin_list = [lazy_pinyin(c)[0] for c in hanzi_segment]
            for i, py in enumerate(pinyin_list):
                result[start_idx + i] = py

        # Non-Chinese characters map to themselves
        for i, char in enumerate(text):
            if not result[i]:
                result[i] = char

        return result

    def get_mask_pinyin(
        self, text: str, pinyin_list: List[str]
    ) -> Tuple[int, List[str]]:
        """
        Build the (possibly abbreviated) pinyin for the characters to predict.

        Mirrors get_mask_pinyin in dataset.py: for each leading Chinese
        character, randomly pick full pinyin / initials / first letter
        (weighted by self.py_style_weight); stops at the first non-Chinese
        character. Returns (count, pinyin_pieces).
        """
        mask_pinyin = []
        for i in range(len(text)):
            if not self.query_engine or not self.query_engine.is_chinese_char(text[i]):
                break
            else:
                py = np.random.choice(
                    (pinyin_list[i], to_initials(pinyin_list[i]), pinyin_list[i][0]),
                    p=self.py_style_weight,
                )
                # to_initials returns "" for zero-initial syllables; use the
                # first letter instead
                if py == "":
                    py = pinyin_list[i][0]
                mask_pinyin.append(py)
        return len(mask_pinyin), mask_pinyin

    def create_sample_from_text(
        self,
        text: str,
        position: Optional[int] = None,
        force_slot_count: Optional[int] = None,
    ) -> Dict:
        """
        Create one evaluation sample from text, following dataset.py.

        Args:
            text: input text
            position: prediction position (must be a Chinese character);
                      random when None
            force_slot_count: force the number of valid history slots (0-8);
                              computed automatically when None

        Returns:
            Sample dict with every field the model needs.

        Raises:
            ValueError: empty text, no Chinese characters, or bad position.
        """
        if not text:
            raise ValueError("文本不能为空")

        # Per-character pinyin for the whole text
        pinyin_list = self.generate_pinyin(text)

        # All positions holding a Chinese character
        chinese_positions = [
            i
            for i, char in enumerate(text)
            if self.query_engine and self.query_engine.is_chinese_char(char)
        ]

        if not chinese_positions:
            raise ValueError("文本中没有汉字")

        # Pick (or validate) the prediction position
        if position is None:
            position = random.choice(chinese_positions)
        elif position >= len(text) or position < 0:
            raise ValueError(f"位置 {position} 超出文本范围")
        elif not self.query_engine or not self.query_engine.is_chinese_char(
            text[position]
        ):
            raise ValueError(f"位置 {position} 不是汉字")

        i = position

        # part1: text before the cursor (at most 48 characters)
        if i < 48:
            part1 = text[0:i]
        else:
            part1 = text[i - 48 : i]

        # part2: pinyin input — random length 1-8 drawn from a bell-shaped
        # distribution
        pinyin_len_probs = [0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04]
        pinyin_len = np.random.choice(range(1, 9), p=pinyin_len_probs)
        py_end = min(i + pinyin_len, len(text))

        # Build the (possibly abbreviated) pinyin pieces
        pinyin_len_actual, part2_pinyin = self.get_mask_pinyin(
            text[i:py_end], pinyin_list[i:py_end]
        )

        # Optional separator between pinyin pieces
        split_char = np.random.choice(["", "`", "'", "-"], p=[0.9, 0.04, 0.04, 0.02])
        part2 = split_char.join(part2_pinyin)

        # Encode pinyin as IDs, padded/truncated to 24
        pinyin_ids = text_to_pinyin_ids(part2)
        len_py = len(pinyin_ids)
        if len_py < 24:
            pinyin_ids.extend([0] * (24 - len_py))
        else:
            pinyin_ids = pinyin_ids[:24]
        pinyin_ids_tensor = torch.tensor(pinyin_ids, dtype=torch.long)

        # part3: text after the cursor (empty with 70% probability)
        part3 = ""
        if random.random() > 0.7:
            part3 = text[
                i
                + pinyin_len_actual : i
                + pinyin_len_actual
                + np.random.choice(range(1, 17))
            ]

        # part4: context hints (empty with 50% probability) — 1-5 random
        # substrings of the text joined by '|'
        part4 = ""
        if random.random() > 0.5:
            num_strings = random.randint(1, 5)
            string_list = []
            for _ in range(num_strings):
                start_pos = random.randint(0, len(text) - 1)
                x = random.randint(2, 6)
                end_pos = min(start_pos + x + 1, len(text))
                string_list.append(text[start_pos:end_pos])
            part4 = "|".join(string_list)

        # Labels: IDs of the characters being predicted
        labels = []
        try:
            labels = [
                self.query_engine.get_char_info_by_char_pinyin(c, p).id
                for c, p in zip(
                    text[i : i + pinyin_len_actual],
                    pinyin_list[i : i + pinyin_len_actual],
                )
            ]
        except (AttributeError, TypeError) as e:
            # Lookup failed (missing engine or unknown pair) — use default IDs
            print(f"⚠️ 获取标签失败: {e}")
            labels = [0] * pinyin_len_actual

        # History slots: simulate the user's step-by-step confirmations.
        # At evaluation time we do not know the right answers, so the history
        # is either empty or randomly simulated.
        history_slot_ids = []

        if force_slot_count is not None:
            force_slot_count = max(0, min(8, force_slot_count))
            # Simulated history: first force_slot_count slots valid, rest 0.
            # Random IDs in 1-1000 match the training-time filling logic.
            for _ in range(force_slot_count):
                history_slot_ids.append(random.randint(1, 1000))
        else:
            # Normal evaluation: empty history (typing from scratch).
            # In real use the slots would come from previous user choices,
            # but starting empty keeps the evaluation fair.
            pass

        # Pad/truncate to exactly 8 slots
        if len(history_slot_ids) < 8:
            history_slot_ids.extend([0] * (8 - len(history_slot_ids)))
        else:
            history_slot_ids = history_slot_ids[:8]

        # Tokenize the textual context
        encoded = self.tokenizer(
            f"{part4}|{part1}",
            part3,
            max_length=128,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            return_token_type_ids=True,
        )

        # Valid slots are those with ID > 0 (same rule as the training data)
        valid_slot_count = sum(1 for slot_id in history_slot_ids if slot_id > 0)

        # Pinyin input length in characters
        pinyin_input_length = len(part2)

        sample = {
            "input_ids": encoded["input_ids"],
            "token_type_ids": encoded["token_type_ids"],
            "attention_mask": encoded["attention_mask"],
            "history_slot_ids": torch.tensor(
                history_slot_ids, dtype=torch.long
            ).unsqueeze(0),
            "prefix": f"{part4}^{part1}",
            "suffix": part3,
            "pinyin": part2,
            "pinyin_ids": pinyin_ids_tensor.unsqueeze(0),
            "true_labels": labels,  # ground-truth IDs (possibly several chars)
            "position": i,
            "target_char": text[i] if i < len(text) else "",
            "valid_slot_count": valid_slot_count,  # number of valid slots
            "pinyin_input_length": pinyin_input_length,  # pinyin input length
        }

        return sample

    def inference(self, sample: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run model inference on one sample.

        Returns:
            (logits, probs) — probs is softmax over the last dimension.
        """
        # Move every input tensor to the inference device
        input_ids = sample["input_ids"].to(self.device)
        token_type_ids = sample["token_type_ids"].to(self.device)
        attention_mask = sample["attention_mask"].to(self.device)
        pinyin_ids = sample["pinyin_ids"].to(self.device)
        history_slot_ids = sample["history_slot_ids"].to(self.device)

        with torch.no_grad():
            if self.device.type == "cuda":
                # Mixed-precision inference on GPU
                with torch.autocast(device_type="cuda"):
                    logits = self.model(
                        input_ids,
                        token_type_ids,
                        attention_mask,
                        pinyin_ids,
                        history_slot_ids,
                    )
            else:
                logits = self.model(
                    input_ids,
                    token_type_ids,
                    attention_mask,
                    pinyin_ids,
                    history_slot_ids,
                )
            probs = F.softmax(logits, dim=-1)

        return logits, probs

    def analyze_probability_distribution(
        self, probs: torch.Tensor, threshold_low_var: float = 0.01
    ) -> Dict:
        """
        Analyze the output distribution and detect near-flat ("low variance")
        cases.

        Args:
            probs: probability tensor of shape (1, vocab_size)
            threshold_low_var: flag the sample as low-variance when the top
                probability falls below this value

        Returns:
            Analysis dict (stats, entropy, top-k indices/probs, flags).
        """
        probs_np = probs.cpu().numpy().flatten()

        # Basic statistics
        max_prob = np.max(probs_np)
        min_prob = np.min(probs_np)
        mean_prob = np.mean(probs_np)
        std_prob = np.std(probs_np)

        # Shannon entropy (epsilon avoids log(0))
        entropy = -np.sum(probs_np * np.log(probs_np + 1e-12))

        # "Low variance": top probability barely above the noise floor
        low_variance = max_prob < threshold_low_var

        # Top-k predictions, highest probability first
        top_k = 20
        top_indices = np.argsort(probs_np)[-top_k:][::-1]
        top_probs = probs_np[top_indices]

        # Ratio of top-1 to top-2 — near 1.0 means the top-5 look uniform
        top5_probs = top_probs[:5]
        top5_ratio = (
            top5_probs[0] / (top5_probs[1] + 1e-12) if len(top5_probs) > 1 else 1.0
        )

        analysis = {
            "max_prob": max_prob,
            "min_prob": min_prob,
            "mean_prob": mean_prob,
            "std_prob": std_prob,
            "entropy": entropy,
            "low_variance": low_variance,
            "top_indices": top_indices,
            "top_probs": top_probs,
            "top5_ratio": top5_ratio,
            "vocab_size": len(probs_np),
        }

        return analysis

    def evaluate_sample(self, sample: Dict, print_details: bool = True) -> Dict:
        """
        Evaluate a single sample (single-step prediction).

        Returns:
            Result dict with the analysis, correctness flag and labels.
        """
        # Run inference and summarize the distribution
        logits, probs = self.inference(sample)
        analysis = self.analyze_probability_distribution(probs)

        # Ground truth
        true_labels = sample.get("true_labels", [])
        target_char = sample.get("target_char", "")

        # Correct iff the top-1 prediction matches the first true label
        correct = False
        pred_top_idx = analysis["top_indices"][0]
        if true_labels and len(true_labels) > 0:
            correct = pred_top_idx == true_labels[0]

        if print_details:
            valid_slot_count = sample.get("valid_slot_count", 0)
            pinyin_input_length = sample.get("pinyin_input_length", 0)

            # Compact one-line report (keeps appending to the current line)
            print(f" {target_char}", end="")
            if true_labels:
                print(f"(ID:{true_labels[0]})", end="")
            print(
                f" | 拼音:{sample.get('pinyin', '')[:10]}{'...' if len(sample.get('pinyin', '')) > 10 else ''}",
                end="",
            )
            print(f" | 槽位:{valid_slot_count}/8 拼音长:{pinyin_input_length}", end="")

            # Key probability info
            print(f" | 最高概率:{analysis['max_prob']:.3f}", end="")
            if analysis["low_variance"]:
                print(f" ⚠️", end="")

            # Top-2 predictions
            top_strs = []
            for j in range(min(2, len(analysis["top_indices"]))):
                idx = analysis["top_indices"][j]
                prob = analysis["top_probs"][j]
                # ID 0 is the "empty" placeholder
                if idx == 0:
                    char = "[空]"
                else:
                    char_info = (
                        self.query_engine.query_by_id(idx)
                        if self.query_engine
                        else None
                    )
                    char = char_info.char if char_info else f"[{idx}]"
                is_correct = (
                    (true_labels and idx == true_labels[0]) if true_labels else False
                )
                correct_mark = "✓" if is_correct else ""
                top_strs.append(f"{char}{correct_mark}({prob:.2f})")
            print(f" | Top-2: {' '.join(top_strs)}", end="")

            print(f" | 结果:{'✓' if correct else '✗'}")

        result = {
            "sample": sample,
            "analysis": analysis,
            "correct": correct,
            "target_char": target_char,
            "true_label": true_labels[0] if true_labels else None,
            "predicted_label": pred_top_idx,
        }

        return result

    def evaluate_sample_with_sequential_confirmation(
        self, sample: Dict, print_details: bool = True
    ) -> Dict:
        """
        Evaluate a sample while simulating the user's step-by-step
        confirmations (for multi-character predictions).

        When the pinyin covers more than one character, the user is assumed
        to always pick the Top-1 prediction, which is then fed back into the
        history slots for the next step.

        Returns:
            Result dict with per-step predictions and overall correctness.
        """
        true_labels = sample.get("true_labels", [])
        pinyin_len = len(true_labels) if true_labels else 1

        if pinyin_len <= 1:
            # Single character: plain evaluation is enough
            return self.evaluate_sample(sample, print_details)

        # Multi-character: simulate step-by-step confirmation
        confirmed_history = []  # character IDs the user has confirmed
        all_predictions = []  # per-step predictions
        all_correct = []  # per-step correctness

        # Shallow copy; only history_slot_ids is replaced per step
        sample_copy = sample.copy()

        for step in range(pinyin_len):
            # Rebuild the history slots from confirmed characters (pad to 8)
            history_slot_ids = confirmed_history[:]
            if len(history_slot_ids) < 8:
                history_slot_ids.extend([0] * (8 - len(history_slot_ids)))
            else:
                history_slot_ids = history_slot_ids[:8]

            sample_copy["history_slot_ids"] = torch.tensor(
                history_slot_ids, dtype=torch.long
            ).unsqueeze(0)

            # Inference for this step
            logits, probs = self.inference(sample_copy)
            analysis = self.analyze_probability_distribution(probs)

            # Top-1 prediction
            pred_top_idx = analysis["top_indices"][0]
            pred_top_prob = analysis["top_probs"][0]

            # Compare against the ground truth for this step
            correct = False
            if step < len(true_labels):
                correct = pred_top_idx == true_labels[step]

            all_predictions.append(
                {
                    "predicted_id": pred_top_idx,
                    "probability": pred_top_prob,
                    "correct": correct,
                }
            )
            all_correct.append(correct)

            # Simulate the user confirming the Top-1 result
            if pred_top_idx > 0:  # only keep valid IDs
                confirmed_history.append(pred_top_idx)

        # Overall correctness requires every step to be right
        overall_correct = all(all_correct)

        if print_details:
            target_char = sample.get("target_char", "")
            pinyin = sample.get("pinyin", "")
            valid_slot_count = sample.get("valid_slot_count", 0)

            print(f" {target_char}", end="")
            print(f" | 拼音:{pinyin[:10]}{'...' if len(pinyin) > 10 else ''}", end="")
            print(f" | 槽位:{valid_slot_count}/8 拼音长:{pinyin_len}", end="")

            # Per-step prediction trail
            step_results = []
            for i, pred in enumerate(all_predictions):
                char_info = (
                    self.query_engine.query_by_id(pred["predicted_id"])
                    if self.query_engine
                    else None
                )
                char = char_info.char if char_info else f"[{pred['predicted_id']}]"
                correct_mark = "✓" if pred["correct"] else "✗"
                step_results.append(f"{char}{correct_mark}")
            print(f" | 逐步:{'→'.join(step_results)}", end="")

            print(f" | 结果:{'✓' if overall_correct else '✗'}")

        result = {
            "sample": sample,
            "sequential_predictions": all_predictions,
            "correct": overall_correct,
            "step_correct": all_correct,
            "target_char": sample.get("target_char", ""),
            "true_labels": true_labels,
            "pinyin_len": pinyin_len,
        }

        return result

    def evaluate_text(self, text: str, num_samples: int = 10) -> List[Dict]:
        """
        Evaluate the given text with multiple randomly generated samples.

        Args:
            text: input text
            num_samples: number of samples to generate

        Returns:
            List of evaluation result dicts.
        """
        # Cap the text length at 300 characters (random window)
        if len(text) > 300:
            start = random.randint(0, len(text) - 300)
            text = text[start : start + 300]
            print(f"📝 文本长度超过300,随机截取300字符: {text[:50]}...")
        else:
            print(f"📝 文本长度: {len(text)} 字符")

        results = []
        # NOTE(review): never populated below — kept only because
        # print_summary takes (and ignores) it; consider wiring it up or
        # removing it.
        low_variance_samples = []

        for sample_idx in range(num_samples):
            print(f"\n[样本 {sample_idx + 1}/{num_samples}]", end=" ")

            try:
                sample = self.create_sample_from_text(text)
                # Step-by-step confirmation evaluation
                result = self.evaluate_sample_with_sequential_confirmation(
                    sample, print_details=True
                )
                results.append(result)
            except Exception as e:
                print(f"❌ 样本生成或评估失败: {e}")
                traceback.print_exc()
                continue

        if results:
            self.print_summary(results, low_variance_samples)

        return results

    def evaluate_by_slot_count(
        self, text: str, samples_per_slot: int = 100
    ) -> Dict[int, Dict]:
        """
        Evaluate separately for each history-slot count (0-8), testing
        samples_per_slot samples per count.

        Args:
            text: input text
            samples_per_slot: number of test samples per slot count

        Returns:
            Results grouped by slot count:
            {slot_count: {results, accuracy, ...}}
        """
        if len(text) > 300:
            start = random.randint(0, len(text) - 300)
            text = text[start : start + 300]

        print(f"\n{'=' * 60}")
        print(f"按槽位数量评估 | 每组 {samples_per_slot} 个样本")
        print(f"{'=' * 60}")

        slot_results = {}

        for slot_count in range(0, 9):
            results = []
            low_var_count = 0
            failed = 0

            for sample_idx in range(samples_per_slot):
                try:
                    sample = self.create_sample_from_text(
                        text, force_slot_count=slot_count
                    )
                    # Step-by-step evaluation better matches real usage
                    result = self.evaluate_sample_with_sequential_confirmation(
                        sample, print_details=False
                    )
                    results.append(result)
                    # NOTE: sequential evaluation has no low_variance analysis
                except Exception:
                    failed += 1
                    continue

            if results:
                correct_count = sum(1 for r in results if r["correct"])
                accuracy = correct_count / len(results)

                # Mean per-step accuracy over multi-character samples
                step_accuracies = []
                for r in results:
                    if "step_correct" in r:
                        step_correct = r["step_correct"]
                        if step_correct:
                            step_acc = sum(step_correct) / len(step_correct)
                            step_accuracies.append(step_acc)
                mean_step_accuracy = np.mean(step_accuracies) if step_accuracies else 0

                slot_results[slot_count] = {
                    "results": results,
                    "accuracy": accuracy,
                    "correct": correct_count,
                    "total": len(results),
                    "mean_step_accuracy": mean_step_accuracy,
                    "failed": failed,
                }

                # Text progress bar for this slot count
                bar_len = 30
                filled = int(bar_len * accuracy)
                bar = "█" * filled + "░" * (bar_len - filled)
                print(
                    f" 槽位 {slot_count}/8 | {bar} {accuracy:6.1%} "
                    f"({correct_count:2d}/{len(results):2d}) "
                    f"| 平均步数准确率: {mean_step_accuracy:.1%}"
                )
            else:
                slot_results[slot_count] = {
                    "results": [],
                    "accuracy": 0,
                    "correct": 0,
                    "total": 0,
                    "mean_step_accuracy": 0,
                    "failed": failed,
                }
                print(f" 槽位 {slot_count}/8 | 全部样本生成失败")

        # Overall summary across every slot count
        print(f"\n{'=' * 60}")
        overall_correct = sum(s["correct"] for s in slot_results.values())
        overall_total = sum(s["total"] for s in slot_results.values())
        if overall_total > 0:
            print(
                f" 总体准确率: {overall_correct / overall_total:.1%} ({overall_correct}/{overall_total})"
            )
        print(f"{'=' * 60}")

        return slot_results

    def print_summary(self, results: List[Dict], low_variance_samples: List[Dict]):
        """Print an evaluation summary (accuracy, step accuracy, averages)."""
        print(f"\n--- 汇总 ({len(results)}样本) ---")

        # Overall accuracy
        correct_count = sum(1 for r in results if r["correct"])
        accuracy = correct_count / len(results) if results else 0
        print(f"准确率: {accuracy:.1%} ({correct_count}/{len(results)})")

        # Mean per-step accuracy (multi-character predictions)
        step_accuracies = []
        for r in results:
            if "step_correct" in r:
                step_correct = r["step_correct"]
                if step_correct:
                    step_acc = sum(step_correct) / len(step_correct)
                    step_accuracies.append(step_acc)
        if step_accuracies:
            mean_step_accuracy = np.mean(step_accuracies)
            print(f"平均步数准确率: {mean_step_accuracy:.1%}")

        # Average slot usage and pinyin length
        if results and "sample" in results[0]:
            valid_slots = [r["sample"].get("valid_slot_count", 0) for r in results]
            pinyin_lengths = [r.get("pinyin_len", 1) for r in results]
            avg_slots = np.mean(valid_slots) if valid_slots else 0
            avg_pinyin = np.mean(pinyin_lengths) if pinyin_lengths else 0
            print(f"平均槽位: {avg_slots:.1f}/8, 平均拼音长度: {avg_pinyin:.1f}")


def main():
    """CLI entry point: parse args, load the evaluator, run the evaluation."""
    parser = argparse.ArgumentParser(description="评估模型在文本上的表现")
    parser.add_argument("--text-file", type=str, required=True, help="文本文件路径")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="/home/songsenand/下载/best_model.pt",
        help="模型checkpoint路径 (默认: /home/songsenand/下载/best_model.pt)",
    )
    parser.add_argument(
        # Help text fixed: it previously claimed the default was 20
        "--num-samples", type=int, default=100, help="评估样本数量 (默认: 100)"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        choices=["auto", "cpu", "cuda"],
        help="推理设备 (默认: auto)",
    )

    args = parser.parse_args()

    # Resolve the inference device
    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device

    # Read the text file
    try:
        with open(args.text_file, "r", encoding="utf-8") as f:
            text = f.read().strip()
    except Exception as e:
        print(f"❌ 无法读取文本文件: {e}")
        sys.exit(1)

    if not text:
        print("❌ 文本文件为空")
        sys.exit(1)

    print(f"📄 读取文本文件: {args.text_file}")
    print(f"📏 原始文本长度: {len(text)} 字符")
    print(f"🔧 使用设备: {device}")
    print(f"📦 模型checkpoint: {args.checkpoint}")
    print(f"🔬 评估样本数: {args.num_samples}")

    # Initialize the evaluator (loads model + tokenizer + query engine)
    try:
        evaluator = TextEvaluator(args.checkpoint, device)
    except Exception as e:
        print(f"❌ 初始化评估器失败: {e}")
        traceback.print_exc()
        sys.exit(1)

    # Run the per-slot-count evaluation
    evaluator.evaluate_by_slot_count(text, samples_per_slot=args.num_samples)


if __name__ == "__main__":
    main()