feat(eval): 添加按槽位数量评估文本的功能
This commit is contained in:
parent
0fea985b45
commit
68a6fc3533
158
eval.py
158
eval.py
|
|
@ -222,7 +222,10 @@ class TextEvaluator:
|
|||
return len(mask_pinyin), mask_pinyin
|
||||
|
||||
def create_sample_from_text(
|
||||
self, text: str, position: Optional[int] = None
|
||||
self,
|
||||
text: str,
|
||||
position: Optional[int] = None,
|
||||
force_slot_count: Optional[int] = None,
|
||||
) -> Dict:
|
||||
"""
|
||||
从文本创建单个样本,参考dataset.py的样本生成逻辑。
|
||||
|
|
@ -230,6 +233,7 @@ class TextEvaluator:
|
|||
Args:
|
||||
text: 输入文本
|
||||
position: 预测位置(汉字位置),如果为None则随机选择
|
||||
force_slot_count: 强制设置有效槽位数量(0-8),如果为None则自动计算
|
||||
|
||||
Returns:
|
||||
样本字典,包含模型所需的所有字段
|
||||
|
|
@ -330,40 +334,51 @@ class TextEvaluator:
|
|||
|
||||
# 历史槽位:当前预测字符之前已确认的字符
|
||||
# 从位置i之前的字符中提取历史槽位(最多8个)
|
||||
# 注意:这里的逻辑与训练数据一致 - 从已确认的字符中提取
|
||||
history_slot_ids = []
|
||||
if self.query_engine:
|
||||
# 收集i之前的汉字字符(最多8个)
|
||||
for j in range(i - 1, max(-1, i - 9), -1):
|
||||
for j in range(i - 1, max(-1, i - 100), -1):
|
||||
if j < 0:
|
||||
break
|
||||
char = text[j]
|
||||
if self.query_engine.is_chinese_char(char):
|
||||
# 获取该字符的第一个拼音变体的ID
|
||||
# 尝试获取该字符的ID(与训练数据一致)
|
||||
results = self.query_engine.query_by_char(char, limit=1)
|
||||
if results:
|
||||
history_slot_ids.append(results[0][0])
|
||||
else:
|
||||
# 如果查询失败,使用0
|
||||
history_slot_ids.append(0)
|
||||
else:
|
||||
# 非汉字字符,使用0
|
||||
history_slot_ids.append(0)
|
||||
slot_id = results[0][0]
|
||||
# 只添加有效的非零ID(与训练数据一致)
|
||||
if slot_id > 0:
|
||||
history_slot_ids.append(slot_id)
|
||||
if len(history_slot_ids) >= 8:
|
||||
break
|
||||
|
||||
# 如果历史槽位为空,随机添加一些常见字符ID(模拟已确认的输入)
|
||||
if not history_slot_ids and random.random() < 0.5:
|
||||
# 随机选择1-3个常见字符ID(范围1-1000)
|
||||
num_slots = random.randint(1, 3)
|
||||
for _ in range(num_slots):
|
||||
history_slot_ids.append(random.randint(1, 1000))
|
||||
|
||||
# 填充到8个槽位
|
||||
# 填充到8个槽位(使用0填充,与训练数据一致)
|
||||
if len(history_slot_ids) < 8:
|
||||
history_slot_ids.extend([0] * (8 - len(history_slot_ids)))
|
||||
else:
|
||||
history_slot_ids = history_slot_ids[:8]
|
||||
|
||||
# 强制设置有效槽位数量(用于按槽位数量评估)
|
||||
# 注意:这里不再使用随机ID填充,而是复制现有的有效槽位
|
||||
if force_slot_count is not None:
|
||||
force_slot_count = max(0, min(8, force_slot_count))
|
||||
# 获取所有非零槽位ID
|
||||
valid_ids = [s for s in history_slot_ids if s != 0]
|
||||
|
||||
if len(valid_ids) >= force_slot_count:
|
||||
# 如果有效槽位足够,直接取前force_slot_count个
|
||||
history_slot_ids = valid_ids[:force_slot_count]
|
||||
else:
|
||||
# 如果有效槽位不足,保留现有有效槽位,然后用0填充
|
||||
# 这样更符合训练数据的分布(槽位数量由实际历史决定)
|
||||
history_slot_ids = valid_ids[:]
|
||||
|
||||
# 填充到8个槽位
|
||||
while len(history_slot_ids) < 8:
|
||||
history_slot_ids.append(0)
|
||||
history_slot_ids = history_slot_ids[:8]
|
||||
|
||||
# Tokenize输入
|
||||
encoded = self.tokenizer(
|
||||
f"{part4}|{part1}",
|
||||
|
|
@ -375,8 +390,8 @@ class TextEvaluator:
|
|||
return_token_type_ids=True,
|
||||
)
|
||||
|
||||
# 计算有效槽位数量(非零元素)
|
||||
valid_slot_count = sum(1 for slot_id in history_slot_ids if slot_id != 0)
|
||||
# 计算有效槽位数量(ID > 0 的槽位,与训练数据一致)
|
||||
valid_slot_count = sum(1 for slot_id in history_slot_ids if slot_id > 0)
|
||||
|
||||
# 拼音输入长度(字符数)
|
||||
pinyin_input_length = len(part2)
|
||||
|
|
@ -597,14 +612,10 @@ class TextEvaluator:
|
|||
print(f"\n[样本 {sample_idx + 1}/{num_samples}]", end=" ")
|
||||
|
||||
try:
|
||||
# 创建样本
|
||||
sample = self.create_sample_from_text(text)
|
||||
|
||||
# 评估样本
|
||||
result = self.evaluate_sample(sample, print_details=True)
|
||||
results.append(result)
|
||||
|
||||
# 记录低方差样本
|
||||
if result["analysis"]["low_variance"]:
|
||||
low_variance_samples.append(
|
||||
{
|
||||
|
|
@ -621,12 +632,105 @@ class TextEvaluator:
|
|||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
# 打印汇总统计
|
||||
if results:
|
||||
self.print_summary(results, low_variance_samples)
|
||||
|
||||
return results
|
||||
|
||||
def evaluate_by_slot_count(
|
||||
self, text: str, samples_per_slot: int = 100
|
||||
) -> Dict[int, Dict]:
|
||||
"""
|
||||
按槽位数量分别评估,每个槽位数量测试 samples_per_slot 个样本。
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
samples_per_slot: 每个槽位数量的测试样本数
|
||||
|
||||
Returns:
|
||||
按槽位数量分组的评估结果 {slot_count: {results, accuracy, ...}}
|
||||
"""
|
||||
if len(text) > 300:
|
||||
start = random.randint(0, len(text) - 300)
|
||||
text = text[start : start + 300]
|
||||
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"按槽位数量评估 | 每组 {samples_per_slot} 个样本")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
slot_results = {}
|
||||
|
||||
for slot_count in range(0, 9):
|
||||
results = []
|
||||
low_var_count = 0
|
||||
failed = 0
|
||||
|
||||
for sample_idx in range(samples_per_slot):
|
||||
try:
|
||||
sample = self.create_sample_from_text(
|
||||
text, force_slot_count=slot_count
|
||||
)
|
||||
result = self.evaluate_sample(sample, print_details=False)
|
||||
results.append(result)
|
||||
if result["analysis"]["low_variance"]:
|
||||
low_var_count += 1
|
||||
except Exception:
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
if results:
|
||||
correct_count = sum(1 for r in results if r["correct"])
|
||||
accuracy = correct_count / len(results)
|
||||
max_probs = [r["analysis"]["max_prob"] for r in results]
|
||||
mean_max_prob = np.mean(max_probs)
|
||||
entropies = [r["analysis"]["entropy"] for r in results]
|
||||
mean_entropy = np.mean(entropies)
|
||||
|
||||
slot_results[slot_count] = {
|
||||
"results": results,
|
||||
"accuracy": accuracy,
|
||||
"correct": correct_count,
|
||||
"total": len(results),
|
||||
"mean_max_prob": mean_max_prob,
|
||||
"mean_entropy": mean_entropy,
|
||||
"low_variance_count": low_var_count,
|
||||
"failed": failed,
|
||||
}
|
||||
|
||||
bar_len = 30
|
||||
filled = int(bar_len * accuracy)
|
||||
bar = "█" * filled + "░" * (bar_len - filled)
|
||||
print(
|
||||
f" 槽位 {slot_count}/8 | {bar} {accuracy:6.1%} "
|
||||
f"({correct_count:2d}/{len(results):2d}) "
|
||||
f"| 平均最高概率: {mean_max_prob:.4f} "
|
||||
f"| 平均熵: {mean_entropy:.2f} "
|
||||
f"| 低方差: {low_var_count}"
|
||||
)
|
||||
else:
|
||||
slot_results[slot_count] = {
|
||||
"results": [],
|
||||
"accuracy": 0,
|
||||
"correct": 0,
|
||||
"total": 0,
|
||||
"mean_max_prob": 0,
|
||||
"mean_entropy": 0,
|
||||
"low_variance_count": 0,
|
||||
"failed": failed,
|
||||
}
|
||||
print(f" 槽位 {slot_count}/8 | 全部样本生成失败")
|
||||
|
||||
print(f"\n{'=' * 60}")
|
||||
overall_correct = sum(s["correct"] for s in slot_results.values())
|
||||
overall_total = sum(s["total"] for s in slot_results.values())
|
||||
if overall_total > 0:
|
||||
print(
|
||||
f" 总体准确率: {overall_correct / overall_total:.1%} ({overall_correct}/{overall_total})"
|
||||
)
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
return slot_results
|
||||
|
||||
def print_summary(self, results: List[Dict], low_variance_samples: List[Dict]):
|
||||
"""打印评估汇总"""
|
||||
print(f"\n--- 汇总 ({len(results)}样本) ---")
|
||||
|
|
@ -685,7 +789,7 @@ def main():
|
|||
help="模型checkpoint路径 (默认: /home/songsenand/下载/best_model.pt)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-samples", type=int, default=20, help="评估样本数量 (默认: 20)"
|
||||
"--num-samples", type=int, default=100, help="评估样本数量 (默认: 20)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
|
|
@ -732,7 +836,7 @@ def main():
|
|||
sys.exit(1)
|
||||
|
||||
# 执行评估
|
||||
evaluator.evaluate_text(text, args.num_samples)
|
||||
evaluator.evaluate_by_slot_count(text, samples_per_slot=args.num_samples)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -332,6 +332,22 @@ class PinyinInputDataset(IterableDataset):
|
|||
if random.random() <= 0.1:
|
||||
labels.append(0)
|
||||
|
||||
# 提取历史槽位:从预测位置i之前的字符中获取(与eval.py一致)
|
||||
history_slot_list = []
|
||||
for j in range(i - 1, max(-1, i - 100), -1):
|
||||
if j < 0:
|
||||
break
|
||||
char = text[j]
|
||||
if self.query_engine.is_chinese_char(char):
|
||||
try:
|
||||
results = self.query_engine.query_by_char(char, limit=1)
|
||||
if results and results[0][0] > 0:
|
||||
history_slot_list.append(results[0][0])
|
||||
except Exception:
|
||||
pass
|
||||
if len(history_slot_list) >= 8:
|
||||
break
|
||||
|
||||
encoded = self.tokenizer(
|
||||
f"{part4}|{part1}",
|
||||
part3,
|
||||
|
|
@ -345,7 +361,8 @@ class PinyinInputDataset(IterableDataset):
|
|||
# 修复变量名冲突:将内层循环变量i重命名为label_idx
|
||||
for label_idx, label in enumerate(labels):
|
||||
repeats = self.adjust_frequency(label)
|
||||
masked_labels = labels[:label_idx]
|
||||
# 使用从text[0:i]提取的历史槽位(与eval.py一致)
|
||||
masked_labels = history_slot_list[:]
|
||||
len_l = len(masked_labels)
|
||||
masked_labels.extend([0] * (8 - len_l))
|
||||
|
||||
|
|
|
|||
|
|
@ -80,6 +80,10 @@ class InputMethodEngine(nn.Module):
|
|||
# 开启 torch.compile 优化 (如果请求)
|
||||
# 在模型编译时添加优化选项
|
||||
if compile:
|
||||
from torch._inductor.select_algorithm import TritonTemplate
|
||||
|
||||
TritonTemplate.all_templates.clear()
|
||||
|
||||
self.forward = torch.compile(
|
||||
self.forward,
|
||||
mode="reduce-overhead",
|
||||
|
|
|
|||
Loading…
Reference in New Issue