feat(eval): 添加按槽位数量评估文本的功能

This commit is contained in:
songsenand 2026-04-11 20:28:31 +08:00
parent 0fea985b45
commit 68a6fc3533
3 changed files with 153 additions and 28 deletions

158
eval.py
View File

@@ -222,7 +222,10 @@ class TextEvaluator:
return len(mask_pinyin), mask_pinyin
def create_sample_from_text(
self, text: str, position: Optional[int] = None
self,
text: str,
position: Optional[int] = None,
force_slot_count: Optional[int] = None,
) -> Dict:
"""
从文本创建单个样本参考dataset.py的样本生成逻辑
@@ -230,6 +233,7 @@ class TextEvaluator:
Args:
text: 输入文本
position: 预测位置汉字位置如果为None则随机选择
force_slot_count: 强制设置有效槽位数量(0-8)如果为None则自动计算
Returns:
样本字典包含模型所需的所有字段
@@ -330,40 +334,51 @@ class TextEvaluator:
# 历史槽位:当前预测字符之前已确认的字符
# 从位置i之前的字符中提取历史槽位最多8个
# 注意:这里的逻辑与训练数据一致 - 从已确认的字符中提取
history_slot_ids = []
if self.query_engine:
# 收集i之前的汉字字符最多8个
for j in range(i - 1, max(-1, i - 9), -1):
for j in range(i - 1, max(-1, i - 100), -1):
if j < 0:
break
char = text[j]
if self.query_engine.is_chinese_char(char):
# 获取该字符的第一个拼音变体的ID
# 尝试获取该字符的ID(与训练数据一致)
results = self.query_engine.query_by_char(char, limit=1)
if results:
history_slot_ids.append(results[0][0])
else:
# 如果查询失败使用0
history_slot_ids.append(0)
else:
# 非汉字字符使用0
history_slot_ids.append(0)
slot_id = results[0][0]
# 只添加有效的非零ID与训练数据一致
if slot_id > 0:
history_slot_ids.append(slot_id)
if len(history_slot_ids) >= 8:
break
# 如果历史槽位为空随机添加一些常见字符ID模拟已确认的输入
if not history_slot_ids and random.random() < 0.5:
# 随机选择1-3个常见字符ID范围1-1000
num_slots = random.randint(1, 3)
for _ in range(num_slots):
history_slot_ids.append(random.randint(1, 1000))
# 填充到8个槽位
# 填充到8个槽位使用0填充与训练数据一致
if len(history_slot_ids) < 8:
history_slot_ids.extend([0] * (8 - len(history_slot_ids)))
else:
history_slot_ids = history_slot_ids[:8]
# 强制设置有效槽位数量(用于按槽位数量评估)
# 注意这里不再使用随机ID填充而是复制现有的有效槽位
if force_slot_count is not None:
force_slot_count = max(0, min(8, force_slot_count))
# 获取所有非零槽位ID
valid_ids = [s for s in history_slot_ids if s != 0]
if len(valid_ids) >= force_slot_count:
# 如果有效槽位足够直接取前force_slot_count个
history_slot_ids = valid_ids[:force_slot_count]
else:
# 如果有效槽位不足保留现有有效槽位然后用0填充
# 这样更符合训练数据的分布(槽位数量由实际历史决定)
history_slot_ids = valid_ids[:]
# 填充到8个槽位
while len(history_slot_ids) < 8:
history_slot_ids.append(0)
history_slot_ids = history_slot_ids[:8]
# Tokenize输入
encoded = self.tokenizer(
f"{part4}|{part1}",
@@ -375,8 +390,8 @@ class TextEvaluator:
return_token_type_ids=True,
)
# 计算有效槽位数量(非零元素
valid_slot_count = sum(1 for slot_id in history_slot_ids if slot_id != 0)
# 计算有效槽位数量(ID > 0 的槽位,与训练数据一致
valid_slot_count = sum(1 for slot_id in history_slot_ids if slot_id > 0)
# 拼音输入长度(字符数)
pinyin_input_length = len(part2)
@@ -597,14 +612,10 @@ class TextEvaluator:
print(f"\n[样本 {sample_idx + 1}/{num_samples}]", end=" ")
try:
# 创建样本
sample = self.create_sample_from_text(text)
# 评估样本
result = self.evaluate_sample(sample, print_details=True)
results.append(result)
# 记录低方差样本
if result["analysis"]["low_variance"]:
low_variance_samples.append(
{
@@ -621,12 +632,105 @@ class TextEvaluator:
traceback.print_exc()
continue
# 打印汇总统计
if results:
self.print_summary(results, low_variance_samples)
return results
def evaluate_by_slot_count(
    self, text: str, samples_per_slot: int = 100
) -> Dict[int, Dict]:
    """
    Evaluate the model separately for each history-slot count (0-8).

    For every slot count, ``samples_per_slot`` samples are generated from
    *text* with ``force_slot_count`` fixed to that count, each sample is
    evaluated, and per-group statistics are printed and collected.

    Args:
        text: Input text used to generate evaluation samples. Texts longer
            than 300 characters are truncated to a random 300-char window.
        samples_per_slot: Number of samples tested per slot count.

    Returns:
        Mapping ``{slot_count: stats}`` where *stats* holds the raw
        per-sample results plus aggregates: ``accuracy``, ``correct``,
        ``total``, ``mean_max_prob``, ``mean_entropy``,
        ``low_variance_count`` and ``failed``.
    """
    # Cap the text length to keep sample generation fast; a random window
    # lets repeated runs exercise different parts of a long text.
    if len(text) > 300:
        start = random.randint(0, len(text) - 300)
        text = text[start : start + 300]

    print(f"\n{'=' * 60}")
    print(f"按槽位数量评估 | 每组 {samples_per_slot} 个样本")
    print(f"{'=' * 60}")

    slot_results: Dict[int, Dict] = {}

    for slot_count in range(0, 9):
        results = []
        low_var_count = 0
        failed = 0

        for _ in range(samples_per_slot):
            try:
                sample = self.create_sample_from_text(
                    text, force_slot_count=slot_count
                )
                result = self.evaluate_sample(sample, print_details=False)
                results.append(result)
                if result["analysis"]["low_variance"]:
                    low_var_count += 1
            except Exception:
                # Best-effort sweep: one bad sample must not abort the
                # whole slot-count group — just count the failure.
                failed += 1
                continue

        if results:
            correct_count = sum(1 for r in results if r["correct"])
            accuracy = correct_count / len(results)
            # Cast numpy scalars to plain floats so the returned dict is
            # directly JSON-serializable.
            mean_max_prob = float(
                np.mean([r["analysis"]["max_prob"] for r in results])
            )
            mean_entropy = float(
                np.mean([r["analysis"]["entropy"] for r in results])
            )

            slot_results[slot_count] = {
                "results": results,
                "accuracy": accuracy,
                "correct": correct_count,
                "total": len(results),
                "mean_max_prob": mean_max_prob,
                "mean_entropy": mean_entropy,
                "low_variance_count": low_var_count,
                "failed": failed,
            }

            # Text progress bar. Fix: the original concatenated two EMPTY
            # string literals, so the bar rendered as nothing — use visible
            # block/shade glyphs instead.
            bar_len = 30
            filled = int(bar_len * accuracy)
            bar = "█" * filled + "░" * (bar_len - filled)
            print(
                f" 槽位 {slot_count}/8 | {bar} {accuracy:6.1%} "
                f"({correct_count:2d}/{len(results):2d}) "
                f"| 平均最高概率: {mean_max_prob:.4f} "
                f"| 平均熵: {mean_entropy:.2f} "
                f"| 低方差: {low_var_count}"
            )
        else:
            # Every sample for this slot count failed to generate; record
            # an all-zero stats entry so callers still see the key.
            slot_results[slot_count] = {
                "results": [],
                "accuracy": 0,
                "correct": 0,
                "total": 0,
                "mean_max_prob": 0,
                "mean_entropy": 0,
                "low_variance_count": 0,
                "failed": failed,
            }
            print(f" 槽位 {slot_count}/8 | 全部样本生成失败")

    # Overall summary across all slot counts.
    print(f"\n{'=' * 60}")
    overall_correct = sum(s["correct"] for s in slot_results.values())
    overall_total = sum(s["total"] for s in slot_results.values())
    if overall_total > 0:
        print(
            f" 总体准确率: {overall_correct / overall_total:.1%} ({overall_correct}/{overall_total})"
        )
    print(f"{'=' * 60}")
    return slot_results
def print_summary(self, results: List[Dict], low_variance_samples: List[Dict]):
"""打印评估汇总"""
print(f"\n--- 汇总 ({len(results)}样本) ---")
@@ -685,7 +789,7 @@ def main():
help="模型checkpoint路径 (默认: /home/songsenand/下载/best_model.pt)",
)
parser.add_argument(
"--num-samples", type=int, default=20, help="评估样本数量 (默认: 20)"
"--num-samples", type=int, default=100, help="评估样本数量 (默认: 20)"
)
parser.add_argument(
"--device",
@@ -732,7 +836,7 @@ def main():
sys.exit(1)
# 执行评估
evaluator.evaluate_text(text, args.num_samples)
evaluator.evaluate_by_slot_count(text, samples_per_slot=args.num_samples)
if __name__ == "__main__":

View File

@@ -332,6 +332,22 @@ class PinyinInputDataset(IterableDataset):
if random.random() <= 0.1:
labels.append(0)
# 提取历史槽位从预测位置i之前的字符中获取与eval.py一致
history_slot_list = []
for j in range(i - 1, max(-1, i - 100), -1):
if j < 0:
break
char = text[j]
if self.query_engine.is_chinese_char(char):
try:
results = self.query_engine.query_by_char(char, limit=1)
if results and results[0][0] > 0:
history_slot_list.append(results[0][0])
except Exception:
pass
if len(history_slot_list) >= 8:
break
encoded = self.tokenizer(
f"{part4}|{part1}",
part3,
@@ -345,7 +361,8 @@ class PinyinInputDataset(IterableDataset):
# 修复变量名冲突将内层循环变量i重命名为label_idx
for label_idx, label in enumerate(labels):
repeats = self.adjust_frequency(label)
masked_labels = labels[:label_idx]
# 使用从text[0:i]提取的历史槽位与eval.py一致
masked_labels = history_slot_list[:]
len_l = len(masked_labels)
masked_labels.extend([0] * (8 - len_l))

View File

@@ -80,6 +80,10 @@ class InputMethodEngine(nn.Module):
# 开启 torch.compile 优化 (如果请求)
# 在模型编译时添加优化选项
if compile:
from torch._inductor.select_algorithm import TritonTemplate
TritonTemplate.all_templates.clear()
self.forward = torch.compile(
self.forward,
mode="reduce-overhead",