#!/usr/bin/env python3
"""
eval.py - Evaluate the model's performance on a given text.

Usage:
    python eval.py --text-file path/to/text.txt [--checkpoint path/to/model.pt] [--num-samples N]

What it does:
1. Reads a text file, caps its length at 300 characters (random window).
2. Converts the text into model-ready samples following dataset.py,
   including part1/part2/part3/part4 and history-slot information.
3. Runs model inference, following test.py.
4. Prints errors; samples whose probability distribution has tiny variance
   (the prediction is barely better than guessing) are specially flagged.
"""

import argparse
import random
import re
import sys
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
import torch.nn.functional as F
from modelscope import AutoTokenizer
from pypinyin import lazy_pinyin
from pypinyin.contrib.tone_convert import to_initials

# Make the project's src directory importable so the `src.*` imports resolve.
sys.path.append("src")

from src.model.model import InputMethodEngine
from src.model.query import QueryEngine
from src.model.dataset import text_to_pinyin_ids

# Matches one or more consecutive characters in the CJK Unified Ideographs
# block (the common Han range U+4E00..U+9FFF).
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
|
||
|
||
|
||
class TextEvaluator:
|
||
def __init__(
    self,
    checkpoint_path: str,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """Initialize the evaluator: load model, tokenizer and query engine.

    Args:
        checkpoint_path: Path to the trained model checkpoint (.pt).
        device: Inference device name; defaults to CUDA when available.
            NOTE(review): the default expression is evaluated once at
            import time, not per call.
    """
    self.device = torch.device(device)
    self.checkpoint_path = checkpoint_path

    # Load components (order matters: load_model reads self.device and
    # self.checkpoint_path set just above).
    self.load_model()
    self.load_tokenizer()
    self.load_query_engine()

    # Pinyin style sampling weights (must stay consistent with dataset.py):
    # full pinyin : initials-only : first letter = 9 : 2 : 1.
    self.py_style_weight = np.array([9, 2, 1]) / sum([9, 2, 1])

    print(f"✅ 评估器初始化完成 (设备: {self.device})")

    # Debug hook: run one fixed sample when EVAL_DEBUG is set in the env.
    import os

    if os.environ.get("EVAL_DEBUG"):
        self._debug_sample()
|
||
|
||
def load_model(self):
    """Load the trained model from self.checkpoint_path onto self.device."""
    self.model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
    # SECURITY NOTE(review): torch.load unpickles arbitrary objects — only
    # load trusted checkpoints (consider weights_only=True on newer PyTorch).
    checkpoint = torch.load(self.checkpoint_path, map_location="cpu")
    # Accept both a full training checkpoint and a bare state dict.
    if "model_state_dict" in checkpoint:
        self.model.load_state_dict(checkpoint["model_state_dict"])
    else:
        self.model.load_state_dict(checkpoint)
    self.model.eval()
    self.model.to(self.device)
    print(
        f"✅ 模型加载完成,参数量: {sum(p.numel() for p in self.model.parameters()):,}"
    )
    print(f"✅ 模型词汇表大小: {self.model.vocab_size}")
|
||
|
||
def load_tokenizer(self):
    """Load the tokenizer bundled under src/model/assets/tokenizer."""
    assets_dir = Path(__file__).parent.joinpath(
        "src", "model", "assets", "tokenizer"
    )
    self.tokenizer = AutoTokenizer.from_pretrained(str(assets_dir))
    print(f"✅ Tokenizer加载完成,词汇表大小: {self.tokenizer.vocab_size}")
|
||
|
||
def load_query_engine(self):
    """Load the query engine used for char<->ID lookups.

    Sets self.query_engine to None when the statistics file is missing or
    loading fails, so the rest of the evaluator degrades gracefully.
    """
    try:
        self.query_engine = QueryEngine()
        stats_path = (
            Path(__file__).parent
            / "src"
            / "model"
            / "assets"
            / "pinyin_char_statistics.json"
        )
        if stats_path.exists():
            self.query_engine.load(stats_path)
            print(
                f"✅ 查询引擎加载完成,字符对数量: {len(self.query_engine._id_to_info)}"
            )
        else:
            # Missing statistics file: run without a query engine.
            print(f"⚠️ 统计文件不存在: {stats_path}")
            self.query_engine = None
    except Exception as e:
        # Best-effort: any failure here disables char<->ID resolution only.
        print(f"⚠️ 无法加载查询引擎: {e}")
        self.query_engine = None
|
||
|
||
def _debug_sample(self):
    """Debug helper: run the same fixed sample as test.py and print stats."""
    print("\n🧪 调试:运行固定样本(与test.py相同)")

    # Fixed inputs copied verbatim from test.py.
    part1 = "他是一名大学生,在上海读"
    part2 = "dayi"
    pinyin_ids = text_to_pinyin_ids(part2)
    len_py = len(pinyin_ids)
    # Pad / truncate the pinyin IDs to the fixed length of 24.
    if len_py < 24:
        pinyin_ids.extend([0] * (24 - len_py))
    else:
        pinyin_ids = pinyin_ids[:24]
    pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
    masked_labels = [15, 4, 0, 0, 0, 0, 0, 0]
    part3 = "。"
    part4 = "可行|特别|伤害"

    # Sentence A is "<hints>|<prefix>", sentence B is the suffix.
    encoded = self.tokenizer(
        f"{part4}|{part1}",
        part3,
        max_length=128,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
        return_token_type_ids=True,
    )

    # NOTE(review): stack([t.squeeze(0)]) is an identity round-trip on a
    # batch of one — presumably kept to mirror test.py exactly; confirm.
    sample = {
        "input_ids": torch.stack([encoded["input_ids"].squeeze(0)]),
        "token_type_ids": torch.stack([encoded["token_type_ids"].squeeze(0)]),
        "attention_mask": torch.stack([encoded["attention_mask"].squeeze(0)]),
        "history_slot_ids": torch.tensor(masked_labels, dtype=torch.long).unsqueeze(
            0
        ),
        "prefix": f"{part4}^{part1}",
        "suffix": part3,
        "pinyin": part2,
        "pinyin_ids": pinyin_ids,
    }

    # Inference + probability-distribution analysis.
    logits, probs = self.inference(sample)
    analysis = self.analyze_probability_distribution(probs)

    print(f"📊 调试样本结果:")
    print(f" 最高概率: {analysis['max_prob']:.6f}")
    print(f" 平均概率: {analysis['mean_prob']:.6f}")
    print(f" 标准差: {analysis['std_prob']:.6f}")
    print(f" 低方差: {analysis['low_variance']}")

    # Top-5 candidates, resolved to characters when the engine is loaded.
    print(f" Top-5预测:")
    for j in range(5):
        idx = analysis["top_indices"][j]
        prob = analysis["top_probs"][j]
        char_info = (
            self.query_engine.query_by_id(idx) if self.query_engine else None
        )
        char = char_info.char if char_info else f"[ID:{idx}]"
        print(f" {j + 1}. {char} (ID: {idx}) - 概率: {prob:.6f}")
|
||
|
||
def generate_pinyin(self, text: str) -> List[str]:
    """Convert one text string into a per-character pinyin list.

    Chinese characters map to their lazy (toneless) pinyin; every other
    character maps to itself. Mirrors generate_pinyin in dataset.py.
    """
    if not text:
        return []

    out: List[str] = ["" for _ in text]

    # Transcribe each run of consecutive Chinese characters.
    for seg in _HANZI_RE.finditer(text):
        chars = seg.group()
        pys = lazy_pinyin(chars)
        if len(pys) != len(chars):
            # Length mismatch (e.g. multi-char readings): redo per character.
            pys = [lazy_pinyin(ch)[0] for ch in chars]
        base = seg.start()
        for offset, py in enumerate(pys):
            out[base + offset] = py

    # Any position still empty holds a non-Chinese character; keep it verbatim.
    return [py if py else ch for py, ch in zip(out, text)]
|
||
|
||
def get_mask_pinyin(
    self, text: str, pinyin_list: List[str]
) -> Tuple[int, List[str]]:
    """Produce the pinyin the user would type for the characters to predict.

    Walks `text` from the start and stops at the first non-Chinese character
    (or immediately when no query engine is loaded). For each Chinese
    character it randomly picks one of three styles — full pinyin, initials
    only, or first letter — weighted by self.py_style_weight.
    Mirrors get_mask_pinyin in dataset.py.

    Returns:
        (number of characters covered, list of chosen pinyin strings)
    """
    mask_pinyin = []
    for i in range(len(text)):
        if not self.query_engine or not self.query_engine.is_chinese_char(text[i]):
            break
        else:
            py = np.random.choice(
                (pinyin_list[i], to_initials(pinyin_list[i]), pinyin_list[i][0]),
                p=self.py_style_weight,
            )
            # to_initials returns "" for vowel-initial syllables; fall back
            # to the first letter so the typed pinyin is never empty.
            if py == "":
                py = pinyin_list[i][0]
            mask_pinyin.append(py)
    return len(mask_pinyin), mask_pinyin
|
||
|
||
def create_sample_from_text(
    self, text: str, position: Optional[int] = None
) -> Dict:
    """Build one model-ready sample from raw text (mirrors dataset.py).

    Args:
        text: Input text.
        position: Index of the character to predict; randomly chosen among
            the Chinese characters when None.

    Returns:
        Sample dict with every field the model expects plus evaluation
        metadata (true labels, position, target char, slot/pinyin stats).

    Raises:
        ValueError: If text is empty, contains no Chinese characters, or
            `position` is out of range / not a Chinese character.
    """
    # Ensure the text is non-empty.
    if not text:
        raise ValueError("文本不能为空")

    # Per-character pinyin for the whole text.
    pinyin_list = self.generate_pinyin(text)

    # All indices holding a Chinese character (prediction candidates).
    chinese_positions = [
        i
        for i, char in enumerate(text)
        if self.query_engine and self.query_engine.is_chinese_char(char)
    ]

    if not chinese_positions:
        raise ValueError("文本中没有汉字")

    # Pick / validate the prediction position.
    if position is None:
        position = random.choice(chinese_positions)
    elif position >= len(text) or position < 0:
        raise ValueError(f"位置 {position} 超出文本范围")
    elif not self.query_engine or not self.query_engine.is_chinese_char(
        text[position]
    ):
        raise ValueError(f"位置 {position} 不是汉字")

    i = position

    # part1: text before the cursor (at most 48 characters).
    if i < 48:
        part1 = text[0:i]
    else:
        part1 = text[i - 48 : i]

    # part2: pinyin input; length 1-8 drawn from a bell-shaped distribution.
    pinyin_len_probs = [0.05, 0.16, 0.45, 0.16, 0.08, 0.05, 0.03, 0.02]
    pinyin_len = np.random.choice(range(1, 9), p=pinyin_len_probs)
    py_end = min(i + pinyin_len, len(text))

    # Pinyin actually typed (may stop early at a non-Chinese character).
    pinyin_len_actual, part2_pinyin = self.get_mask_pinyin(
        text[i:py_end], pinyin_list[i:py_end]
    )

    # Occasionally join syllables with a separator a user might type.
    split_char = np.random.choice(["", "`", "'", "-"], p=[0.9, 0.04, 0.04, 0.02])
    part2 = split_char.join(part2_pinyin)

    # Encode the pinyin string to IDs, padded / truncated to 24.
    pinyin_ids = text_to_pinyin_ids(part2)
    len_py = len(pinyin_ids)
    if len_py < 24:
        pinyin_ids.extend([0] * (24 - len_py))
    else:
        pinyin_ids = pinyin_ids[:24]
    pinyin_ids_tensor = torch.tensor(pinyin_ids, dtype=torch.long)

    # part3: text after the cursor (empty with 70% probability).
    part3 = ""
    if random.random() > 0.7:
        part3 = text[
            i + pinyin_len_actual : i
            + pinyin_len_actual
            + np.random.choice(range(1, 17))
        ]

    # part4: context hint strings (empty with 50% probability).
    part4 = ""
    if random.random() > 0.5:
        num_strings = random.randint(1, 5)
        string_list = []
        for _ in range(num_strings):
            start_pos = random.randint(0, len(text) - 1)
            x = random.randint(2, 6)
            end_pos = min(start_pos + x + 1, len(text))
            string_list.append(text[start_pos:end_pos])
        part4 = "|".join(string_list)

    # Labels: char-pinyin pair IDs of the characters being predicted.
    labels = []
    try:
        labels = [
            self.query_engine.get_char_info_by_char_pinyin(c, p).id
            for c, p in zip(
                text[i : i + pinyin_len_actual],
                pinyin_list[i : i + pinyin_len_actual],
            )
        ]
    except (AttributeError, TypeError) as e:
        # Lookup failed (e.g. query engine missing); fall back to ID 0.
        print(f"⚠️ 获取标签失败: {e}")
        labels = [0] * pinyin_len_actual

    # History slots: up to 8 already-confirmed characters before position i,
    # collected from nearest to farthest.
    history_slot_ids = []
    if self.query_engine:
        # Collect Chinese characters before i (at most 8 positions).
        for j in range(i - 1, max(-1, i - 9), -1):
            # NOTE(review): j < 0 cannot occur given the range stop above;
            # this guard looks redundant.
            if j < 0:
                break
            char = text[j]
            if self.query_engine.is_chinese_char(char):
                # Use the ID of the character's first pinyin variant.
                results = self.query_engine.query_by_char(char, limit=1)
                if results:
                    history_slot_ids.append(results[0][0])
                else:
                    # Lookup failed: occupy the slot with ID 0.
                    history_slot_ids.append(0)
            else:
                # Non-Chinese character: occupy the slot with ID 0.
                history_slot_ids.append(0)
            if len(history_slot_ids) >= 8:
                break

    # If no history, sometimes fabricate 1-3 common-character IDs to
    # simulate previously confirmed input.
    if not history_slot_ids and random.random() < 0.5:
        # Randomly choose 1-3 IDs in the range 1-1000.
        num_slots = random.randint(1, 3)
        for _ in range(num_slots):
            history_slot_ids.append(random.randint(1, 1000))

    # Pad / truncate to exactly 8 slots.
    if len(history_slot_ids) < 8:
        history_slot_ids.extend([0] * (8 - len(history_slot_ids)))
    else:
        history_slot_ids = history_slot_ids[:8]

    # Tokenize "<hints>|<prefix>" as sentence A and the suffix as sentence B.
    encoded = self.tokenizer(
        f"{part4}|{part1}",
        part3,
        max_length=128,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
        return_token_type_ids=True,
    )

    # Number of non-empty (non-zero) history slots.
    valid_slot_count = sum(1 for slot_id in history_slot_ids if slot_id != 0)

    # Length of the typed pinyin string in characters.
    pinyin_input_length = len(part2)

    # Assemble the sample.
    sample = {
        "input_ids": encoded["input_ids"],
        "token_type_ids": encoded["token_type_ids"],
        "attention_mask": encoded["attention_mask"],
        "history_slot_ids": torch.tensor(
            history_slot_ids, dtype=torch.long
        ).unsqueeze(0),
        "prefix": f"{part4}^{part1}",
        "suffix": part3,
        "pinyin": part2,
        "pinyin_ids": pinyin_ids_tensor.unsqueeze(0),
        "true_labels": labels,  # ground-truth IDs (possibly several chars)
        "position": i,
        "target_char": text[i] if i < len(text) else "",
        "valid_slot_count": valid_slot_count,  # number of valid slots
        "pinyin_input_length": pinyin_input_length,  # typed pinyin length
    }

    return sample
|
||
|
||
def inference(self, sample: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the model on one prepared sample.

    Args:
        sample: Dict holding the tensors produced by create_sample_from_text
            (input_ids, token_type_ids, attention_mask, pinyin_ids,
            history_slot_ids), each with a leading batch dimension.

    Returns:
        (logits, probs): the raw model output and its softmax over the
        last dimension.
    """
    from contextlib import nullcontext

    # Move every model input onto the inference device.
    input_ids = sample["input_ids"].to(self.device)
    token_type_ids = sample["token_type_ids"].to(self.device)
    attention_mask = sample["attention_mask"].to(self.device)
    pinyin_ids = sample["pinyin_ids"].to(self.device)
    history_slot_ids = sample["history_slot_ids"].to(self.device)

    # Mixed precision on CUDA only; a null context elsewhere. This removes
    # the duplicated model call the original cuda/cpu branches carried.
    amp_ctx = (
        torch.autocast(device_type="cuda")
        if self.device.type == "cuda"
        else nullcontext()
    )
    with torch.no_grad(), amp_ctx:
        logits = self.model(
            input_ids,
            token_type_ids,
            attention_mask,
            pinyin_ids,
            history_slot_ids,
        )

    probs = F.softmax(logits, dim=-1)
    return logits, probs
|
||
|
||
def analyze_probability_distribution(
    self, probs: torch.Tensor, threshold_low_var: float = 0.01
) -> Dict:
    """Summarize a probability distribution and flag near-uniform cases.

    Args:
        probs: Probability tensor of shape (1, vocab_size).
        threshold_low_var: If the highest probability falls below this
            value the distribution is flagged as "low variance" — i.e. the
            prediction is barely better than guessing.

    Returns:
        Dict with basic statistics, entropy, the low-variance flag, top-k
        indices/probabilities, the top-1/top-2 ratio and the vocab size.
    """
    flat = probs.cpu().numpy().flatten()

    # Basic order statistics over the whole vocabulary.
    highest = np.max(flat)
    lowest = np.min(flat)
    average = np.mean(flat)
    spread = np.std(flat)

    # Shannon entropy; the epsilon guards against log(0).
    entropy = -np.sum(flat * np.log(flat + 1e-12))

    # A tiny maximum probability means the model is effectively guessing.
    low_variance = highest < threshold_low_var

    # Indices of the 20 most probable classes, most probable first.
    top_k = 20
    top_indices = np.argsort(flat)[-top_k:][::-1]
    top_probs = flat[top_indices]

    # Confidence gap between the best and second-best candidates.
    leading = top_probs[:5]
    top5_ratio = (
        leading[0] / (leading[1] + 1e-12) if len(leading) > 1 else 1.0
    )

    return {
        "max_prob": highest,
        "min_prob": lowest,
        "mean_prob": average,
        "std_prob": spread,
        "entropy": entropy,
        "low_variance": low_variance,
        "top_indices": top_indices,
        "top_probs": top_probs,
        "top5_ratio": top5_ratio,
        "vocab_size": len(flat),
    }
|
||
|
||
def evaluate_sample(self, sample: Dict, print_details: bool = True) -> Dict:
    """Run inference on one sample and score the top-1 prediction.

    Args:
        sample: Sample dict from create_sample_from_text.
        print_details: When True, print a compact one-line report
            (continues the line the caller started with end="").

    Returns:
        Dict with the sample, the probability analysis, the correctness
        flag, the target character and the true/predicted label IDs.
    """
    # Run inference.
    logits, probs = self.inference(sample)

    # Analyze the probability distribution.
    analysis = self.analyze_probability_distribution(probs)

    # Ground truth attached to the sample.
    true_labels = sample.get("true_labels", [])
    target_char = sample.get("target_char", "")

    # Correctness is judged on the first character only.
    correct = False
    pred_top_idx = analysis["top_indices"][0]
    if true_labels and len(true_labels) > 0:
        correct = pred_top_idx == true_labels[0]

    # Print the per-sample report.
    if print_details:
        # Slot / pinyin metadata recorded at sample creation.
        valid_slot_count = sample.get("valid_slot_count", 0)
        pinyin_input_length = sample.get("pinyin_input_length", 0)

        # Compact single-line output (everything uses end="").
        print(f" {target_char}", end="")
        if true_labels:
            print(f"(ID:{true_labels[0]})", end="")
        print(
            f" | 拼音:{sample.get('pinyin', '')[:10]}{'...' if len(sample.get('pinyin', '')) > 10 else ''}",
            end="",
        )
        print(f" | 槽位:{valid_slot_count}/8 拼音长:{pinyin_input_length}", end="")

        # Key probability info; warn when the distribution is near-uniform.
        print(f" | 最高概率:{analysis['max_prob']:.3f}", end="")
        if analysis["low_variance"]:
            print(f" ⚠️", end="")

        # Render the two most probable candidates.
        top_strs = []
        for j in range(min(2, len(analysis["top_indices"]))):
            idx = analysis["top_indices"][j]
            prob = analysis["top_probs"][j]
            # ID 0 denotes the empty prediction.
            if idx == 0:
                char = "[空]"
            else:
                char_info = (
                    self.query_engine.query_by_id(idx)
                    if self.query_engine
                    else None
                )
                char = char_info.char if char_info else f"[{idx}]"
            is_correct = (
                (true_labels and idx == true_labels[0]) if true_labels else False
            )
            correct_mark = "✓" if is_correct else ""
            top_strs.append(f"{char}{correct_mark}({prob:.2f})")
        print(f" | Top-2: {' '.join(top_strs)}", end="")

        print(f" | 结果:{'✓' if correct else '✗'}")

    # Assemble the evaluation result.
    result = {
        "sample": sample,
        "analysis": analysis,
        "correct": correct,
        "target_char": target_char,
        "true_label": true_labels[0] if true_labels else None,
        "predicted_label": pred_top_idx,
    }

    return result
|
||
|
||
def evaluate_text(self, text: str, num_samples: int = 10) -> List[Dict]:
    """Evaluate the model on `text` by drawing `num_samples` random samples.

    Args:
        text: Input text (randomly truncated to a 300-char window when longer).
        num_samples: Number of samples to generate and evaluate.

    Returns:
        List of per-sample result dicts; failed samples are skipped.
    """
    # Cap the text at 300 characters, taking a random window when longer.
    if len(text) > 300:
        start = random.randint(0, len(text) - 300)
        text = text[start : start + 300]
        print(f"📝 文本长度超过300,随机截取300字符: {text[:50]}...")
    else:
        print(f"📝 文本长度: {len(text)} 字符")

    results = []
    low_variance_samples = []

    for sample_idx in range(num_samples):
        print(f"\n[样本 {sample_idx + 1}/{num_samples}]", end=" ")

        try:
            # Each iteration draws a fresh random sample from the text.
            sample = self.create_sample_from_text(text)

            # Evaluate it (prints its own one-line report).
            result = self.evaluate_sample(sample, print_details=True)
            results.append(result)

            # Track samples whose distribution is too flat to trust.
            if result["analysis"]["low_variance"]:
                low_variance_samples.append(
                    {
                        "sample_idx": sample_idx,
                        "max_prob": result["analysis"]["max_prob"],
                        "entropy": result["analysis"]["entropy"],
                    }
                )

        except Exception as e:
            # Best-effort: report the failure and continue with the rest.
            print(f"❌ 样本生成或评估失败: {e}")
            import traceback

            traceback.print_exc()
            continue

    # Aggregate report over all successful samples.
    if results:
        self.print_summary(results, low_variance_samples)

    return results
|
||
|
||
def print_summary(self, results: List[Dict], low_variance_samples: List[Dict]):
    """Print an aggregate report: accuracy, probability statistics, slot
    usage, and the list of low-variance (near-guessing) samples."""
    total = len(results)
    print(f"\n--- 汇总 ({total}样本) ---")

    # Top-1 accuracy.
    n_correct = sum(1 for item in results if item["correct"])
    accuracy = n_correct / total if results else 0
    print(f"准确率: {accuracy:.1%} ({n_correct}/{total})")

    # Mean of the per-sample maximum probabilities.
    peak_probs = [item["analysis"]["max_prob"] for item in results]
    mean_peak = np.mean(peak_probs) if peak_probs else 0
    print(f"平均最高概率: {mean_peak:.4f}")

    # Share of samples whose top prediction is the empty ID 0.
    n_zero = sum(1 for item in results if item.get("predicted_label") == 0)
    zero_ratio = n_zero / total if results else 0
    print(f"预测[空]比例: {zero_ratio:.1%} ({n_zero}/{total})")

    # Mean entropy of the probability distributions.
    entropy_values = [item["analysis"]["entropy"] for item in results]
    mean_entropy = np.mean(entropy_values) if entropy_values else 0
    print(f"平均熵: {mean_entropy:.2f}")

    # Slot / pinyin-length statistics (only when metadata is present).
    if results and "sample" in results[0]:
        slot_counts = [item["sample"].get("valid_slot_count", 0) for item in results]
        py_lengths = [
            item["sample"].get("pinyin_input_length", 0) for item in results
        ]
        avg_slots = np.mean(slot_counts) if slot_counts else 0
        avg_pinyin = np.mean(py_lengths) if py_lengths else 0
        print(f"平均槽位: {avg_slots:.1f}/8, 平均拼音长度: {avg_pinyin:.1f}")

    # Low-variance samples: list all when few, otherwise the first three.
    if not low_variance_samples:
        print(f"低方差样本: 0个")
        return
    print(f"低方差样本: {len(low_variance_samples)}个")
    shown = (
        low_variance_samples
        if len(low_variance_samples) <= 5
        else low_variance_samples[:3]
    )
    for lv in shown:
        print(f" 样本{lv['sample_idx'] + 1}: 最高概率{lv['max_prob']:.4f}")
    if len(low_variance_samples) > 5:
        print(f" ... 还有{len(low_variance_samples) - 3}个")
|
||
|
||
|
||
def main():
    """CLI entry point: parse arguments, load the text file, run evaluation."""
    parser = argparse.ArgumentParser(description="评估模型在文本上的表现")
    parser.add_argument("--text-file", type=str, required=True, help="文本文件路径")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="/home/songsenand/下载/best_model.pt",
        help="模型checkpoint路径 (默认: /home/songsenand/下载/best_model.pt)",
    )
    parser.add_argument(
        "--num-samples", type=int, default=20, help="评估样本数量 (默认: 20)"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        choices=["auto", "cpu", "cuda"],
        help="推理设备 (默认: auto)",
    )

    args = parser.parse_args()

    # Resolve "auto" to a concrete device.
    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device

    # Read the input text file.
    try:
        with open(args.text_file, "r", encoding="utf-8") as f:
            text = f.read().strip()
    except Exception as e:
        print(f"❌ 无法读取文本文件: {e}")
        sys.exit(1)

    if not text:
        print("❌ 文本文件为空")
        sys.exit(1)

    print(f"📄 读取文本文件: {args.text_file}")
    print(f"📏 原始文本长度: {len(text)} 字符")
    print(f"🔧 使用设备: {device}")
    print(f"📦 模型checkpoint: {args.checkpoint}")
    print(f"🔬 评估样本数: {args.num_samples}")

    # Build the evaluator (loads model/tokenizer/engine; may take a while).
    try:
        evaluator = TextEvaluator(args.checkpoint, device)
    except Exception as e:
        print(f"❌ 初始化评估器失败: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)

    # Run the evaluation.
    evaluator.evaluate_text(text, args.num_samples)
|
||
|
||
|
||
# Script entry point.
if __name__ == "__main__":
    main()
|