#!/usr/bin/env python3
#!/usr/bin/env python3
|
||
"""
|
||
eval.py - 评估模型在给定文本上的表现
|
||
|
||
使用方法:
|
||
python eval.py --text-file path/to/text.txt [--checkpoint path/to/model.pt] [--num-samples N]
|
||
|
||
功能:
|
||
1. 读取文本文件,限制长度300,随机抽取
|
||
2. 参考dataset.py将文本内容转化为模型可接受数据,包含part1/part2/part3/part4/槽位历史等信息
|
||
3. 使用模型推理,参考test.py
|
||
4. 错误的打印,对于方差极小,概率和猜没什么差别的,重点标记
|
||
"""
|
||
|
||
import argparse
|
||
import random
|
||
import re
|
||
import sys
|
||
from pathlib import Path
|
||
from typing import Dict, List, Tuple, Optional
|
||
|
||
import numpy as np
|
||
import torch
|
||
import torch.nn.functional as F
|
||
from modelscope import AutoTokenizer
|
||
from pypinyin import lazy_pinyin
|
||
from pypinyin.contrib.tone_convert import to_initials
|
||
|
||
# 添加src目录到路径
|
||
sys.path.append("src")
|
||
|
||
from src.model.model import InputMethodEngine
|
||
from src.model.query import QueryEngine
|
||
from src.model.dataset import text_to_pinyin_ids
|
||
|
||
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
|
||
|
||
|
||
class TextEvaluator:
    """Evaluates a trained input-method model on free-form text.

    Loads the model checkpoint, tokenizer and the char/pinyin query engine,
    then generates dataset.py-style samples from text and runs inference.
    """

    def __init__(
        self,
        checkpoint_path: str,
        # NOTE(review): this default is evaluated once at import time, so the
        # CUDA check is frozen when the module is loaded — confirm intended.
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        """Initialize the evaluator.

        Args:
            checkpoint_path: Path to the trained model checkpoint (.pt).
            device: Inference device string ("cuda" or "cpu").
        """
        self.device = torch.device(device)
        self.checkpoint_path = checkpoint_path

        # Load components (model weights, tokenizer, char/pinyin statistics).
        self.load_model()
        self.load_tokenizer()
        self.load_query_engine()

        # Pinyin style weights (must match dataset.py): full pinyin vs.
        # initials vs. first letter, chosen with probability 9:2:1.
        self.py_style_weight = np.array([9, 2, 1]) / sum([9, 2, 1])

        print(f"✅ 评估器初始化完成 (设备: {self.device})")

        # Debug hook: run one fixed sample (same as test.py) when the
        # EVAL_DEBUG environment variable is set.
        import os

        if os.environ.get("EVAL_DEBUG"):
            self._debug_sample()
    def load_model(self):
        """Load the trained model from ``self.checkpoint_path``.

        Accepts either a raw state dict or a trainer checkpoint that wraps it
        under the "model_state_dict" key. The model is put in eval mode and
        moved to ``self.device``.
        """
        self.model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
        # NOTE(review): torch.load without weights_only=True unpickles
        # arbitrary objects — only load trusted checkpoints.
        checkpoint = torch.load(self.checkpoint_path, map_location="cpu")
        if "model_state_dict" in checkpoint:
            self.model.load_state_dict(checkpoint["model_state_dict"])
        else:
            self.model.load_state_dict(checkpoint)
        self.model.eval()
        self.model.to(self.device)
        print(
            f"✅ 模型加载完成,参数量: {sum(p.numel() for p in self.model.parameters()):,}"
        )
        print(f"✅ 模型词汇表大小: {self.model.vocab_size}")
    def load_tokenizer(self):
        """Load the tokenizer bundled with the project.

        Reads the tokenizer assets from src/model/assets/tokenizer relative
        to this file, via modelscope's AutoTokenizer.
        """
        tokenizer_path = (
            Path(__file__).parent / "src" / "model" / "assets" / "tokenizer"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path))
        print(f"✅ Tokenizer加载完成,词汇表大小: {self.tokenizer.vocab_size}")
    def load_query_engine(self):
        """Load the query engine used for char <-> ID conversion.

        Best-effort: if the statistics file is missing or loading fails,
        ``self.query_engine`` is set to ``None`` and evaluation degrades
        gracefully (ID-only output, no char lookup). The broad ``except`` is
        deliberate for that reason.
        """
        try:
            self.query_engine = QueryEngine()
            stats_path = (
                Path(__file__).parent
                / "src"
                / "model"
                / "assets"
                / "pinyin_char_statistics.json"
            )
            if stats_path.exists():
                self.query_engine.load(stats_path)
                print(
                    f"✅ 查询引擎加载完成,字符对数量: {len(self.query_engine._id_to_info)}"
                )
            else:
                print(f"⚠️ 统计文件不存在: {stats_path}")
                self.query_engine = None
        except Exception as e:
            print(f"⚠️ 无法加载查询引擎: {e}")
            self.query_engine = None
    def _debug_sample(self):
        """Debug helper: run the same fixed sample as test.py.

        Builds one hard-coded sample (prefix/pinyin/suffix/context and a
        fixed history-slot vector), runs inference and prints the probability
        statistics plus the top-5 predictions.
        """
        print("\n🧪 调试:运行固定样本(与test.py相同)")

        # Fixed sample copied from test.py.
        part1 = "他是一名大学生,在上海读"
        part2 = "dayi"
        pinyin_ids = text_to_pinyin_ids(part2)
        len_py = len(pinyin_ids)
        # Pad/truncate the pinyin id sequence to exactly 24 entries.
        if len_py < 24:
            pinyin_ids.extend([0] * (24 - len_py))
        else:
            pinyin_ids = pinyin_ids[:24]
        pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
        masked_labels = [15, 4, 0, 0, 0, 0, 0, 0]
        part3 = "。"
        part4 = "可行|特别|伤害"

        # Tokenize "<context>|<prefix>" as sentence A and the suffix as
        # sentence B, padded to a fixed length of 128.
        encoded = self.tokenizer(
            f"{part4}|{part1}",
            part3,
            max_length=128,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            return_token_type_ids=True,
        )

        sample = {
            "input_ids": torch.stack([encoded["input_ids"].squeeze(0)]),
            "token_type_ids": torch.stack([encoded["token_type_ids"].squeeze(0)]),
            "attention_mask": torch.stack([encoded["attention_mask"].squeeze(0)]),
            "history_slot_ids": torch.tensor(masked_labels, dtype=torch.long).unsqueeze(
                0
            ),
            "prefix": f"{part4}^{part1}",
            "suffix": part3,
            "pinyin": part2,
            "pinyin_ids": pinyin_ids,
        }

        # Inference + distribution analysis.
        logits, probs = self.inference(sample)
        analysis = self.analyze_probability_distribution(probs)

        print(f"📊 调试样本结果:")
        print(f" 最高概率: {analysis['max_prob']:.6f}")
        print(f" 平均概率: {analysis['mean_prob']:.6f}")
        print(f" 标准差: {analysis['std_prob']:.6f}")
        print(f" 低方差: {analysis['low_variance']}")

        # Print the top-5 predictions (char lookup falls back to raw IDs
        # when the query engine is unavailable).
        print(f" Top-5预测:")
        for j in range(5):
            idx = analysis["top_indices"][j]
            prob = analysis["top_probs"][j]
            char_info = (
                self.query_engine.query_by_id(idx) if self.query_engine else None
            )
            char = char_info.char if char_info else f"[ID:{idx}]"
            print(f" {j + 1}. {char} (ID: {idx}) - 概率: {prob:.6f}")
def generate_pinyin(self, text: str) -> List[str]:
|
||
"""
|
||
流式处理单条文本,转换为拼音列表。
|
||
参考dataset.py中的generate_pinyin方法。
|
||
"""
|
||
if not text:
|
||
return []
|
||
|
||
text_len = len(text)
|
||
result: List[str] = [""] * text_len
|
||
|
||
# 遍历所有连续汉字片段
|
||
for match in _HANZI_RE.finditer(text):
|
||
start_idx = match.start()
|
||
hanzi_segment = match.group()
|
||
|
||
pinyin_list = lazy_pinyin(hanzi_segment)
|
||
|
||
if len(pinyin_list) != len(hanzi_segment):
|
||
pinyin_list = [lazy_pinyin(c)[0] for c in hanzi_segment]
|
||
|
||
for i, py in enumerate(pinyin_list):
|
||
result[start_idx + i] = py
|
||
|
||
# 填充非汉字字符
|
||
for i, char in enumerate(text):
|
||
if not result[i]:
|
||
result[i] = char
|
||
|
||
return result
|
||
|
||
def get_mask_pinyin(
|
||
self, text: str, pinyin_list: List[str]
|
||
) -> Tuple[int, List[str]]:
|
||
"""
|
||
生成需要预测汉字对应的拼音,并进行加强。
|
||
参考dataset.py中的get_mask_pinyin方法。
|
||
"""
|
||
mask_pinyin = []
|
||
for i in range(len(text)):
|
||
if not self.query_engine or not self.query_engine.is_chinese_char(text[i]):
|
||
break
|
||
else:
|
||
py = np.random.choice(
|
||
(pinyin_list[i], to_initials(pinyin_list[i]), pinyin_list[i][0]),
|
||
p=self.py_style_weight,
|
||
)
|
||
if py == "":
|
||
py = pinyin_list[i][0]
|
||
mask_pinyin.append(py)
|
||
return len(mask_pinyin), mask_pinyin
|
||
|
||
    def create_sample_from_text(
        self,
        text: str,
        position: Optional[int] = None,
        force_slot_count: Optional[int] = None,
    ) -> Dict:
        """Create a single model sample from text, mirroring dataset.py.

        Args:
            text: Input text.
            position: Prediction position (must be a Han character); random
                Han position when None.
            force_slot_count: Force the number of valid history slots (0-8);
                computed randomly when None.

        Returns:
            Sample dict with all fields the model needs plus evaluation
            metadata (true_labels, predict_idx, target_char, ...).

        Raises:
            ValueError: Empty text, no Han characters, invalid position, or
                not enough Han characters after the position for the
                requested slot count.
        """
        # Guard: text must be non-empty.
        if not text:
            raise ValueError("文本不能为空")

        # Per-character pinyin for the whole text.
        pinyin_list = self.generate_pinyin(text)

        # All Han character positions (empty when query engine is missing).
        chinese_positions = [
            i
            for i, char in enumerate(text)
            if self.query_engine and self.query_engine.is_chinese_char(char)
        ]

        if not chinese_positions:
            raise ValueError("文本中没有汉字")

        # Pick or validate the prediction position.
        if position is None:
            position = random.choice(chinese_positions)
        elif position >= len(text) or position < 0:
            raise ValueError(f"位置 {position} 超出文本范围")
        elif not self.query_engine or not self.query_engine.is_chinese_char(
            text[position]
        ):
            raise ValueError(f"位置 {position} 不是汉字")

        i = position

        # part1: text before the cursor (at most 48 characters).
        if i < 48:
            part1 = text[0:i]
        else:
            part1 = text[i - 48 : i]

        # part2: pinyin input of random length 1-8 (weighted distribution).
        # When history slots are forced, the pinyin must be long enough to
        # cover history + one character to predict.
        pinyin_len_probs = [0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04]
        min_pinyin_len = (force_slot_count + 1) if force_slot_count is not None else 1
        min_pinyin_len = max(1, min(8, min_pinyin_len))

        # Count consecutive Han characters available from position i
        # (at most 8).
        remaining_chars = len(text) - i  # NOTE(review): unused — candidate for removal
        max_valid_len = 0
        for j in range(i, min(i + 8, len(text))):
            if self.query_engine and self.query_engine.is_chinese_char(text[j]):
                max_valid_len += 1
            else:
                break

        # Must have at least min_pinyin_len usable characters.
        if max_valid_len < min_pinyin_len:
            raise ValueError(
                f"位置 {i} 后只有 {max_valid_len} 个有效汉字,不足以测试历史槽位 {force_slot_count}"
            )

        # Choose the pinyin length (at least min_pinyin_len).
        if force_slot_count is not None:
            # Forced mode: restrict to lengths that fit, reusing the original
            # length distribution renormalized over the valid choices.
            valid_lengths = [
                l for l in range(min_pinyin_len, min(max_valid_len + 1, 9))
            ]
            if not valid_lengths:
                valid_lengths = [min_pinyin_len]
            adjusted_probs = [pinyin_len_probs[l - 1] for l in valid_lengths]
            adjusted_probs = np.array(adjusted_probs) / sum(adjusted_probs)
            pinyin_len = np.random.choice(valid_lengths, p=adjusted_probs)
        else:
            pinyin_len = np.random.choice(range(1, 9), p=pinyin_len_probs)
            pinyin_len = min(pinyin_len, max_valid_len)

        py_end = min(i + pinyin_len, len(text))

        # Generate the (possibly abbreviated) pinyin for the target span.
        pinyin_len_actual, part2_pinyin = self.get_mask_pinyin(
            text[i:py_end], pinyin_list[i:py_end]
        )

        # Occasionally join syllables with an explicit separator.
        split_char = np.random.choice(["", "`", "'", "-"], p=[0.9, 0.04, 0.04, 0.02])
        part2 = split_char.join(part2_pinyin)

        # Convert pinyin text to IDs, padded/truncated to 24 entries.
        pinyin_ids = text_to_pinyin_ids(part2)
        len_py = len(pinyin_ids)
        if len_py < 24:
            pinyin_ids.extend([0] * (24 - len_py))
        else:
            pinyin_ids = pinyin_ids[:24]
        pinyin_ids_tensor = torch.tensor(pinyin_ids, dtype=torch.long)

        # part3: text after the cursor (empty with 70% probability).
        part3 = ""
        if random.random() > 0.7:
            part3 = text[
                i + pinyin_len_actual : i
                + pinyin_len_actual
                + np.random.choice(range(1, 17))
            ]

        # part4: context hints (empty with 50% probability) — a few random
        # substrings of the text joined with "|".
        part4 = ""
        if random.random() > 0.5:
            num_strings = random.randint(1, 5)
            string_list = []
            for _ in range(num_strings):
                start_pos = random.randint(0, len(text) - 1)
                x = random.randint(2, 6)
                end_pos = min(start_pos + x + 1, len(text))
                string_list.append(text[start_pos:end_pos])
            part4 = "|".join(string_list)

        # Labels: char IDs for every character covered by the pinyin input.
        labels = []
        try:
            labels = [
                self.query_engine.get_char_info_by_char_pinyin(c, p).id
                for c, p in zip(
                    text[i : i + pinyin_len_actual],
                    pinyin_list[i : i + pinyin_len_actual],
                )
            ]
        except (AttributeError, TypeError) as e:
            # Best effort: fall back to all-zero labels on lookup failure.
            print(f"⚠️ 获取标签失败: {e}")
            labels = [0] * pinyin_len_actual

        # History slots: consistent with training-data generation.
        # With force_slot_count=k, history is labels[:k] and the prediction
        # target is labels[k].
        history_slot_ids = []
        predict_idx = 0  # index of the label currently being predicted

        if force_slot_count is not None:
            force_slot_count = max(0, min(8, force_slot_count))
            # Need at least one label beyond the history.
            if len(labels) <= force_slot_count:
                raise ValueError(
                    f"标签数量 {len(labels)} 不足以支持历史槽位 {force_slot_count}"
                )
            predict_idx = force_slot_count
            history_slot_ids = labels[:force_slot_count]
        else:
            # Normal evaluation: random history length mimicking the
            # training-data distribution (longer histories more likely).
            history_weights = [0.2, 0.2, 0.2, 0.9, 1.2, 1.8, 2.5, 3.5, 4.0]
            max_history = min(len(labels) - 1, 8)  # leave >=1 char to predict
            if max_history < 0:
                max_history = 0
            valid_history_lens = list(range(0, max_history + 1))
            if valid_history_lens:
                adjusted_probs = [history_weights[h] for h in valid_history_lens]
                adjusted_probs = np.array(adjusted_probs) / sum(adjusted_probs)
                history_len = np.random.choice(valid_history_lens, p=adjusted_probs)
            else:
                history_len = 0
            predict_idx = history_len
            history_slot_ids = labels[:history_len]

        # Pad/truncate history to exactly 8 slots.
        if len(history_slot_ids) < 8:
            history_slot_ids.extend([0] * (8 - len(history_slot_ids)))
        else:
            history_slot_ids = history_slot_ids[:8]

        # Tokenize: "<context>|<prefix>" as sentence A, suffix as sentence B.
        encoded = self.tokenizer(
            f"{part4}|{part1}",
            part3,
            max_length=128,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            return_token_type_ids=True,
        )

        # Number of valid slots (ID > 0), matching the training data.
        valid_slot_count = sum(1 for slot_id in history_slot_ids if slot_id > 0)

        # Pinyin input length in characters.
        pinyin_input_length = len(part2)

        # Current target character and label (guarded against overruns).
        current_target_char = (
            text[i + predict_idx] if (i + predict_idx) < len(text) else ""
        )
        current_target_label = labels[predict_idx] if predict_idx < len(labels) else 0

        # Assemble the sample.
        sample = {
            "input_ids": encoded["input_ids"],
            "token_type_ids": encoded["token_type_ids"],
            "attention_mask": encoded["attention_mask"],
            "history_slot_ids": torch.tensor(
                history_slot_ids, dtype=torch.long
            ).unsqueeze(0),
            "prefix": f"{part4}^{part1}",
            "suffix": part3,
            "pinyin": part2,
            "pinyin_ids": pinyin_ids_tensor.unsqueeze(0),
            "true_labels": labels,  # full label list (all characters)
            "predict_idx": predict_idx,  # index of the current target label
            "current_target_label": current_target_label,  # current target label
            "position": i,
            "target_char": current_target_char,  # current target character
            "valid_slot_count": valid_slot_count,
            "pinyin_input_length": pinyin_input_length,
        }

        return sample
def inference(self, sample: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
|
||
"""
|
||
执行模型推理。
|
||
|
||
Returns:
|
||
(logits, probs)
|
||
"""
|
||
# 准备输入张量并移动到设备
|
||
input_ids = sample["input_ids"].to(self.device)
|
||
token_type_ids = sample["token_type_ids"].to(self.device)
|
||
attention_mask = sample["attention_mask"].to(self.device)
|
||
pinyin_ids = sample["pinyin_ids"].to(self.device)
|
||
history_slot_ids = sample["history_slot_ids"].to(self.device)
|
||
|
||
with torch.no_grad():
|
||
if self.device.type == "cuda":
|
||
with torch.autocast(device_type="cuda"):
|
||
logits = self.model(
|
||
input_ids,
|
||
token_type_ids,
|
||
attention_mask,
|
||
pinyin_ids,
|
||
history_slot_ids,
|
||
)
|
||
else:
|
||
logits = self.model(
|
||
input_ids,
|
||
token_type_ids,
|
||
attention_mask,
|
||
pinyin_ids,
|
||
history_slot_ids,
|
||
)
|
||
|
||
probs = F.softmax(logits, dim=-1)
|
||
return logits, probs
|
||
|
||
def analyze_probability_distribution(
|
||
self, probs: torch.Tensor, threshold_low_var: float = 0.01
|
||
) -> Dict:
|
||
"""
|
||
分析概率分布,检测低方差情况。
|
||
|
||
Args:
|
||
probs: 概率张量,形状为 (1, vocab_size)
|
||
threshold_low_var: 低方差阈值,最高概率低于此值则标记为低方差
|
||
|
||
Returns:
|
||
分析结果字典
|
||
"""
|
||
probs_np = probs.cpu().numpy().flatten()
|
||
|
||
# 基本统计
|
||
max_prob = np.max(probs_np)
|
||
min_prob = np.min(probs_np)
|
||
mean_prob = np.mean(probs_np)
|
||
std_prob = np.std(probs_np)
|
||
|
||
# 熵
|
||
entropy = -np.sum(probs_np * np.log(probs_np + 1e-12))
|
||
|
||
# 低方差检测
|
||
low_variance = max_prob < threshold_low_var
|
||
|
||
# Top-k 分析
|
||
top_k = 20
|
||
top_indices = np.argsort(probs_np)[-top_k:][::-1]
|
||
top_probs = probs_np[top_indices]
|
||
|
||
# 检查前5个概率是否接近均匀
|
||
top5_probs = top_probs[:5]
|
||
top5_ratio = (
|
||
top5_probs[0] / (top5_probs[1] + 1e-12) if len(top5_probs) > 1 else 1.0
|
||
)
|
||
|
||
analysis = {
|
||
"max_prob": max_prob,
|
||
"min_prob": min_prob,
|
||
"mean_prob": mean_prob,
|
||
"std_prob": std_prob,
|
||
"entropy": entropy,
|
||
"low_variance": low_variance,
|
||
"top_indices": top_indices,
|
||
"top_probs": top_probs,
|
||
"top5_ratio": top5_ratio,
|
||
"vocab_size": len(probs_np),
|
||
}
|
||
|
||
return analysis
|
||
|
||
    def evaluate_sample_single_char(
        self, sample: Dict, print_details: bool = True
    ) -> Dict:
        """Single-character evaluation (matches the trainer.py validation).

        Uses argmax to pick the predicted character, exactly like the
        validation logic during training, and selects the prediction target
        via ``predict_idx``.

        Args:
            sample: Sample dict from create_sample_from_text.
            print_details: When True, print a one-line per-sample report.

        Returns:
            Dict with the sample, the distribution analysis, correctness,
            and the target/predicted labels.
        """
        logits, probs = self.inference(sample)
        analysis = self.analyze_probability_distribution(probs)

        true_labels = sample.get("true_labels", [])
        predict_idx = sample.get("predict_idx", 0)
        current_target_label = sample.get(
            "current_target_label", true_labels[predict_idx] if true_labels else 0
        )
        target_char = sample.get("target_char", "")

        # Argmax prediction = the highest-ranked index from the analysis.
        pred_idx = analysis["top_indices"][0]
        correct = pred_idx == current_target_label

        if print_details:
            valid_slot_count = sample.get("valid_slot_count", 0)
            pinyin_input_length = sample.get("pinyin_input_length", 0)

            print(f" {target_char}", end="")
            print(f"(ID:{current_target_label})", end="")
            print(
                f" | 拼音:{sample.get('pinyin', '')[:10]}{'...' if len(sample.get('pinyin', '')) > 10 else ''}",
                end="",
            )
            print(f" | 槽位:{valid_slot_count}/8 拼音长:{pinyin_input_length}", end="")
            print(f" | 最高概率:{analysis['max_prob']:.3f}", end="")
            # Highlight samples where the model is barely better than chance.
            if analysis["low_variance"]:
                print(f" ⚠️", end="")

            # Render the top-2 candidates (char lookup falls back to the raw
            # ID when the query engine is unavailable; ID 0 is the empty slot).
            top_strs = []
            for j in range(min(2, len(analysis["top_indices"]))):
                idx = analysis["top_indices"][j]
                prob = analysis["top_probs"][j]
                if idx == 0:
                    char = "[空]"
                else:
                    char_info = (
                        self.query_engine.query_by_id(idx)
                        if self.query_engine
                        else None
                    )
                    char = char_info.char if char_info else f"[{idx}]"
                is_correct = idx == current_target_label
                correct_mark = "✓" if is_correct else ""
                top_strs.append(f"{char}{correct_mark}({prob:.2f})")
            print(f" | Top-2: {' '.join(top_strs)}", end="")
            print(f" | 结果:{'✓' if correct else '✗'}")

        return {
            "sample": sample,
            "analysis": analysis,
            "correct": correct,
            "target_char": target_char,
            "true_label": current_target_label,
            "predicted_label": pred_idx,
            "predict_idx": predict_idx,
        }
def evaluate_sample(self, sample: Dict, print_details: bool = True) -> Dict:
|
||
"""
|
||
评估单个样本(兼容旧接口,调用单字符评估)。
|
||
|
||
Returns:
|
||
评估结果字典
|
||
"""
|
||
return self.evaluate_sample_single_char(sample, print_details)
|
||
|
||
    def evaluate_sample_with_sequential_confirmation(
        self, sample: Dict, print_details: bool = True, use_oracle_history: bool = True
    ) -> Dict:
        """Evaluate a sample by simulating step-by-step user confirmation.

        For pinyin covering more than one character, runs one inference per
        character, feeding the growing history into the history slots.

        Args:
            sample: Sample dict from create_sample_from_text.
            print_details: Whether to print a per-sample report line.
            use_oracle_history: True = feed the true labels as history
                (matches training/validation); False = feed the model's own
                predictions (simulates real usage).

        Returns:
            Result dict with the per-step predictions, per-step correctness,
            and an overall "correct" flag (all steps must be right).
        """
        # Number of characters covered by the pinyin input.
        true_labels = sample.get("true_labels", [])
        pinyin_len = len(true_labels) if true_labels else 1

        if pinyin_len <= 1:
            # Single-character prediction — plain evaluation suffices.
            return self.evaluate_sample(sample, print_details)

        # Multi-character prediction: simulate stepwise confirmation.
        confirmed_history = []  # character IDs confirmed so far (user mode)
        all_predictions = []  # per-step prediction records
        all_correct = []  # per-step correctness flags

        # Shallow copy so we can swap history_slot_ids per step without
        # mutating the caller's sample.
        sample_copy = sample.copy()

        for step in range(pinyin_len):
            # History for this step:
            # - oracle mode: true labels (consistent with training)
            # - user mode: the model's own confirmed predictions
            if use_oracle_history:
                # Slicing copies, so the extend below never mutates
                # true_labels.
                history_slot_ids = true_labels[:step]
            else:
                history_slot_ids = confirmed_history[:]

            # Pad/truncate the history to exactly 8 slots.
            if len(history_slot_ids) < 8:
                history_slot_ids.extend([0] * (8 - len(history_slot_ids)))
            else:
                history_slot_ids = history_slot_ids[:8]

            # Install this step's history in the working sample.
            sample_copy["history_slot_ids"] = torch.tensor(
                history_slot_ids, dtype=torch.long
            ).unsqueeze(0)

            # Run inference for this step.
            logits, probs = self.inference(sample_copy)
            analysis = self.analyze_probability_distribution(probs)

            # Top-1 prediction.
            pred_top_idx = analysis["top_indices"][0]
            pred_top_prob = analysis["top_probs"][0]

            # Compare against the true label for this step, if any.
            correct = False
            if step < len(true_labels):
                correct = pred_top_idx == true_labels[step]

            # Record this step.
            all_predictions.append(
                {
                    "predicted_id": pred_top_idx,
                    "probability": pred_top_prob,
                    "correct": correct,
                }
            )
            all_correct.append(correct)

            # Grow the confirmed history (used in user mode next step);
            # ID 0 is the empty slot and is never confirmed.
            if pred_top_idx > 0:
                confirmed_history.append(pred_top_idx)

        # Overall correctness: every step must be right.
        overall_correct = all(all_correct)

        # Per-sample report.
        if print_details:
            target_char = sample.get("target_char", "")
            pinyin = sample.get("pinyin", "")
            valid_slot_count = sample.get("valid_slot_count", 0)

            print(f" {target_char}", end="")
            print(f" | 拼音:{pinyin[:10]}{'...' if len(pinyin) > 10 else ''}", end="")
            print(f" | 槽位:{valid_slot_count}/8 拼音长:{pinyin_len}", end="")
            mode_str = "oracle" if use_oracle_history else "user"
            print(f" | [{mode_str}]", end="")

            # Show the step-by-step predictions.
            step_results = []
            for i, pred in enumerate(all_predictions):
                char_info = (
                    self.query_engine.query_by_id(pred["predicted_id"])
                    if self.query_engine
                    else None
                )
                char = char_info.char if char_info else f"[{pred['predicted_id']}]"
                correct_mark = "✓" if pred["correct"] else "✗"
                step_results.append(f"{char}{correct_mark}")

            print(f" | 逐步:{'→'.join(step_results)}", end="")
            print(f" | 结果:{'✓' if overall_correct else '✗'}")

        # Assemble the evaluation result.
        result = {
            "sample": sample,
            "sequential_predictions": all_predictions,
            "correct": overall_correct,
            "step_correct": all_correct,
            "target_char": sample.get("target_char", ""),
            "true_labels": true_labels,
            "pinyin_len": pinyin_len,
        }

        return result
    def evaluate_text(
        self, text: str, num_samples: int = 10, use_sequential: bool = False
    ) -> List[Dict]:
        """Evaluate the model on a text by generating multiple random samples.

        Args:
            text: Input text (randomly truncated to 300 chars when longer).
            num_samples: Number of samples to generate and evaluate.
            use_sequential: True = stepwise-confirmation evaluation;
                False = single-character evaluation (matches validation).

        Returns:
            List of per-sample evaluation result dicts (failed samples are
            skipped).
        """
        # Cap the text length at 300 characters (random window).
        if len(text) > 300:
            start = random.randint(0, len(text) - 300)
            text = text[start : start + 300]
            print(f"📝 文本长度超过300,随机截取300字符: {text[:50]}...")
        else:
            print(f"📝 文本长度: {len(text)} 字符")

        results = []
        # NOTE(review): low_variance_samples is never populated — it is
        # passed to print_summary, which also ignores it. Dead plumbing.
        low_variance_samples = []

        for sample_idx in range(num_samples):
            print(f"\n[样本 {sample_idx + 1}/{num_samples}]", end=" ")

            try:
                sample = self.create_sample_from_text(text)
                # Default: single-character evaluation (matches validation).
                if use_sequential:
                    result = self.evaluate_sample_with_sequential_confirmation(
                        sample, print_details=True, use_oracle_history=True
                    )
                else:
                    result = self.evaluate_sample_single_char(
                        sample, print_details=True
                    )
                results.append(result)

            except Exception as e:
                # Best effort: report and move on to the next sample.
                print(f"❌ 样本生成或评估失败: {e}")
                import traceback

                traceback.print_exc()
                continue

        if results:
            self.print_summary(results, low_variance_samples)

        return results
    def evaluate_by_slot_count(
        self, text: str, samples_per_slot: int = 100, use_sequential: bool = False
    ) -> Dict[int, Dict]:
        """Evaluate separately for each history-slot count (0 through 8).

        Args:
            text: Input text (randomly truncated to 300 chars when longer).
            samples_per_slot: Number of samples per slot count.
            use_sequential: True = stepwise-confirmation evaluation;
                False = single-character evaluation (matches validation).

        Returns:
            Results grouped by slot count:
            {slot_count: {results, accuracy, correct, total, ...}}.
        """
        # Cap the text length at 300 characters (random window).
        if len(text) > 300:
            start = random.randint(0, len(text) - 300)
            text = text[start : start + 300]

        print(f"\n{'=' * 60}")
        mode_str = "逐步确认" if use_sequential else "单字符(与验证集一致)"
        print(f"按槽位数量评估 [{mode_str}] | 每组 {samples_per_slot} 个样本")
        print(f"{'=' * 60}")

        slot_results = {}

        # One evaluation batch per forced slot count 0..8.
        for slot_count in range(0, 9):
            results = []
            low_var_count = 0
            failed = 0

            for sample_idx in range(samples_per_slot):
                try:
                    sample = self.create_sample_from_text(
                        text, force_slot_count=slot_count
                    )
                    # Default: single-character evaluation.
                    if use_sequential:
                        result = self.evaluate_sample_with_sequential_confirmation(
                            sample, print_details=False, use_oracle_history=True
                        )
                    else:
                        result = self.evaluate_sample_single_char(
                            sample, print_details=False
                        )
                    results.append(result)
                    # Count low-variance samples (single-char results only —
                    # sequential results carry no "analysis" key).
                    if "analysis" in result and result["analysis"].get("low_variance"):
                        low_var_count += 1
                except Exception:
                    # Sample generation can legitimately fail (e.g. not
                    # enough Han characters for this slot count) — count it.
                    failed += 1
                    continue

            if results:
                correct_count = sum(1 for r in results if r["correct"])
                accuracy = correct_count / len(results)

                # Mean per-step accuracy (sequential mode only).
                step_accuracies = []
                if use_sequential:
                    for r in results:
                        if "step_correct" in r:
                            step_correct = r["step_correct"]
                            if step_correct:
                                step_acc = sum(step_correct) / len(step_correct)
                                step_accuracies.append(step_acc)

                mean_step_accuracy = np.mean(step_accuracies) if step_accuracies else 0

                slot_results[slot_count] = {
                    "results": results,
                    "accuracy": accuracy,
                    "correct": correct_count,
                    "total": len(results),
                    "mean_step_accuracy": mean_step_accuracy,
                    "low_var_count": low_var_count,
                    "failed": failed,
                }

                # Render a 30-character accuracy bar.
                bar_len = 30
                filled = int(bar_len * accuracy)
                bar = "█" * filled + "░" * (bar_len - filled)
                if use_sequential:
                    print(
                        f" 槽位 {slot_count}/8 | {bar} {accuracy:6.1%} "
                        f"({correct_count:2d}/{len(results):2d}) "
                        f"| 平均步数准确率: {mean_step_accuracy:.1%}"
                    )
                else:
                    print(
                        f" 槽位 {slot_count}/8 | {bar} {accuracy:6.1%} "
                        f"({correct_count:2d}/{len(results):2d}) "
                        f"| 低方差: {low_var_count}"
                    )
            else:
                # No sample survived for this slot count.
                slot_results[slot_count] = {
                    "results": [],
                    "accuracy": 0,
                    "correct": 0,
                    "total": 0,
                    "mean_step_accuracy": 0,
                    "low_var_count": 0,
                    "failed": failed,
                }
                print(f" 槽位 {slot_count}/8 | 全部样本生成失败")

        # Overall accuracy across all slot counts.
        print(f"\n{'=' * 60}")
        overall_correct = sum(s["correct"] for s in slot_results.values())
        overall_total = sum(s["total"] for s in slot_results.values())
        if overall_total > 0:
            print(
                f" 总体准确率: {overall_correct / overall_total:.1%} ({overall_correct}/{overall_total})"
            )
        print(f"{'=' * 60}")

        return slot_results
def print_summary(self, results: List[Dict], low_variance_samples: List[Dict]):
|
||
"""打印评估汇总"""
|
||
print(f"\n--- 汇总 ({len(results)}样本) ---")
|
||
|
||
# 计算准确率
|
||
correct_count = sum(1 for r in results if r["correct"])
|
||
accuracy = correct_count / len(results) if results else 0
|
||
print(f"准确率: {accuracy:.1%} ({correct_count}/{len(results)})")
|
||
|
||
# 计算平均步数准确率(对于多字符预测)
|
||
step_accuracies = []
|
||
for r in results:
|
||
if "step_correct" in r:
|
||
step_correct = r["step_correct"]
|
||
if step_correct:
|
||
step_acc = sum(step_correct) / len(step_correct)
|
||
step_accuracies.append(step_acc)
|
||
|
||
if step_accuracies:
|
||
mean_step_accuracy = np.mean(step_accuracies)
|
||
print(f"平均步数准确率: {mean_step_accuracy:.1%}")
|
||
|
||
# 槽位和拼音长度统计
|
||
if results and "sample" in results[0]:
|
||
valid_slots = [r["sample"].get("valid_slot_count", 0) for r in results]
|
||
pinyin_lengths = [r.get("pinyin_len", 1) for r in results]
|
||
avg_slots = np.mean(valid_slots) if valid_slots else 0
|
||
avg_pinyin = np.mean(pinyin_lengths) if pinyin_lengths else 0
|
||
print(f"平均槽位: {avg_slots:.1f}/8, 平均拼音长度: {avg_pinyin:.1f}")
|
||
|
||
|
||
def main():
    """CLI entry point: parse arguments, load the text, run the evaluation.

    Reads the text file, builds a TextEvaluator from the checkpoint, and
    runs the per-slot-count evaluation. Exits with status 1 on unreadable
    or empty input, or when the evaluator fails to initialize.
    """
    parser = argparse.ArgumentParser(description="评估模型在文本上的表现")
    parser.add_argument("--text-file", type=str, required=True, help="文本文件路径")
    parser.add_argument(
        "--checkpoint",
        type=str,
        # NOTE(review): the ".ptrom" extension looks like a typo for ".pt" —
        # confirm the actual checkpoint filename.
        default="/home/songsenand/下载/best_model.ptrom",
        help="模型checkpoint路径 (默认: /home/songsenand/下载/best_model.ptrom)",
    )
    parser.add_argument(
        "--num-samples", type=int, default=100, help="评估样本数量 (默认: 100)"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        choices=["auto", "cpu", "cuda"],
        help="推理设备 (默认: auto)",
    )
    parser.add_argument(
        "--use-sequential",
        action="store_true",
        help="使用逐步确认评估模式(默认使用单字符评估,与验证集一致)",
    )

    args = parser.parse_args()

    # Resolve the inference device.
    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device

    # Read the input text file.
    try:
        with open(args.text_file, "r", encoding="utf-8") as f:
            text = f.read().strip()
    except Exception as e:
        print(f"❌ 无法读取文本文件: {e}")
        sys.exit(1)

    if not text:
        print("❌ 文本文件为空")
        sys.exit(1)

    print(f"📄 读取文本文件: {args.text_file}")
    print(f"📏 原始文本长度: {len(text)} 字符")
    print(f"🔧 使用设备: {device}")
    print(f"📦 模型checkpoint: {args.checkpoint}")
    print(f"🔬 评估样本数: {args.num_samples}")
    print(
        f"📊 评估模式: {'逐步确认' if args.use_sequential else '单字符(与验证集一致)'}"
    )

    # Build the evaluator (loads model/tokenizer/query engine).
    try:
        evaluator = TextEvaluator(args.checkpoint, device)
    except Exception as e:
        print(f"❌ 初始化评估器失败: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)

    # Run the per-slot-count evaluation.
    evaluator.evaluate_by_slot_count(
        text, samples_per_slot=args.num_samples, use_sequential=args.use_sequential
    )


if __name__ == "__main__":
    main()