feat(eval): 添加按槽位数量评估文本的功能
This commit is contained in:
parent
0fea985b45
commit
68a6fc3533
158
eval.py
158
eval.py
|
|
@ -222,7 +222,10 @@ class TextEvaluator:
|
||||||
return len(mask_pinyin), mask_pinyin
|
return len(mask_pinyin), mask_pinyin
|
||||||
|
|
||||||
def create_sample_from_text(
|
def create_sample_from_text(
|
||||||
self, text: str, position: Optional[int] = None
|
self,
|
||||||
|
text: str,
|
||||||
|
position: Optional[int] = None,
|
||||||
|
force_slot_count: Optional[int] = None,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""
|
"""
|
||||||
从文本创建单个样本,参考dataset.py的样本生成逻辑。
|
从文本创建单个样本,参考dataset.py的样本生成逻辑。
|
||||||
|
|
@ -230,6 +233,7 @@ class TextEvaluator:
|
||||||
Args:
|
Args:
|
||||||
text: 输入文本
|
text: 输入文本
|
||||||
position: 预测位置(汉字位置),如果为None则随机选择
|
position: 预测位置(汉字位置),如果为None则随机选择
|
||||||
|
force_slot_count: 强制设置有效槽位数量(0-8),如果为None则自动计算
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
样本字典,包含模型所需的所有字段
|
样本字典,包含模型所需的所有字段
|
||||||
|
|
@ -330,40 +334,51 @@ class TextEvaluator:
|
||||||
|
|
||||||
# 历史槽位:当前预测字符之前已确认的字符
|
# 历史槽位:当前预测字符之前已确认的字符
|
||||||
# 从位置i之前的字符中提取历史槽位(最多8个)
|
# 从位置i之前的字符中提取历史槽位(最多8个)
|
||||||
|
# 注意:这里的逻辑与训练数据一致 - 从已确认的字符中提取
|
||||||
history_slot_ids = []
|
history_slot_ids = []
|
||||||
if self.query_engine:
|
if self.query_engine:
|
||||||
# 收集i之前的汉字字符(最多8个)
|
# 收集i之前的汉字字符(最多8个)
|
||||||
for j in range(i - 1, max(-1, i - 9), -1):
|
for j in range(i - 1, max(-1, i - 100), -1):
|
||||||
if j < 0:
|
if j < 0:
|
||||||
break
|
break
|
||||||
char = text[j]
|
char = text[j]
|
||||||
if self.query_engine.is_chinese_char(char):
|
if self.query_engine.is_chinese_char(char):
|
||||||
# 获取该字符的第一个拼音变体的ID
|
# 尝试获取该字符的ID(与训练数据一致)
|
||||||
results = self.query_engine.query_by_char(char, limit=1)
|
results = self.query_engine.query_by_char(char, limit=1)
|
||||||
if results:
|
if results:
|
||||||
history_slot_ids.append(results[0][0])
|
slot_id = results[0][0]
|
||||||
else:
|
# 只添加有效的非零ID(与训练数据一致)
|
||||||
# 如果查询失败,使用0
|
if slot_id > 0:
|
||||||
history_slot_ids.append(0)
|
history_slot_ids.append(slot_id)
|
||||||
else:
|
|
||||||
# 非汉字字符,使用0
|
|
||||||
history_slot_ids.append(0)
|
|
||||||
if len(history_slot_ids) >= 8:
|
if len(history_slot_ids) >= 8:
|
||||||
break
|
break
|
||||||
|
|
||||||
# 如果历史槽位为空,随机添加一些常见字符ID(模拟已确认的输入)
|
# 填充到8个槽位(使用0填充,与训练数据一致)
|
||||||
if not history_slot_ids and random.random() < 0.5:
|
|
||||||
# 随机选择1-3个常见字符ID(范围1-1000)
|
|
||||||
num_slots = random.randint(1, 3)
|
|
||||||
for _ in range(num_slots):
|
|
||||||
history_slot_ids.append(random.randint(1, 1000))
|
|
||||||
|
|
||||||
# 填充到8个槽位
|
|
||||||
if len(history_slot_ids) < 8:
|
if len(history_slot_ids) < 8:
|
||||||
history_slot_ids.extend([0] * (8 - len(history_slot_ids)))
|
history_slot_ids.extend([0] * (8 - len(history_slot_ids)))
|
||||||
else:
|
else:
|
||||||
history_slot_ids = history_slot_ids[:8]
|
history_slot_ids = history_slot_ids[:8]
|
||||||
|
|
||||||
|
# 强制设置有效槽位数量(用于按槽位数量评估)
|
||||||
|
# 注意:这里不再使用随机ID填充,而是复制现有的有效槽位
|
||||||
|
if force_slot_count is not None:
|
||||||
|
force_slot_count = max(0, min(8, force_slot_count))
|
||||||
|
# 获取所有非零槽位ID
|
||||||
|
valid_ids = [s for s in history_slot_ids if s != 0]
|
||||||
|
|
||||||
|
if len(valid_ids) >= force_slot_count:
|
||||||
|
# 如果有效槽位足够,直接取前force_slot_count个
|
||||||
|
history_slot_ids = valid_ids[:force_slot_count]
|
||||||
|
else:
|
||||||
|
# 如果有效槽位不足,保留现有有效槽位,然后用0填充
|
||||||
|
# 这样更符合训练数据的分布(槽位数量由实际历史决定)
|
||||||
|
history_slot_ids = valid_ids[:]
|
||||||
|
|
||||||
|
# 填充到8个槽位
|
||||||
|
while len(history_slot_ids) < 8:
|
||||||
|
history_slot_ids.append(0)
|
||||||
|
history_slot_ids = history_slot_ids[:8]
|
||||||
|
|
||||||
# Tokenize输入
|
# Tokenize输入
|
||||||
encoded = self.tokenizer(
|
encoded = self.tokenizer(
|
||||||
f"{part4}|{part1}",
|
f"{part4}|{part1}",
|
||||||
|
|
@ -375,8 +390,8 @@ class TextEvaluator:
|
||||||
return_token_type_ids=True,
|
return_token_type_ids=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 计算有效槽位数量(非零元素)
|
# 计算有效槽位数量(ID > 0 的槽位,与训练数据一致)
|
||||||
valid_slot_count = sum(1 for slot_id in history_slot_ids if slot_id != 0)
|
valid_slot_count = sum(1 for slot_id in history_slot_ids if slot_id > 0)
|
||||||
|
|
||||||
# 拼音输入长度(字符数)
|
# 拼音输入长度(字符数)
|
||||||
pinyin_input_length = len(part2)
|
pinyin_input_length = len(part2)
|
||||||
|
|
@ -597,14 +612,10 @@ class TextEvaluator:
|
||||||
print(f"\n[样本 {sample_idx + 1}/{num_samples}]", end=" ")
|
print(f"\n[样本 {sample_idx + 1}/{num_samples}]", end=" ")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 创建样本
|
|
||||||
sample = self.create_sample_from_text(text)
|
sample = self.create_sample_from_text(text)
|
||||||
|
|
||||||
# 评估样本
|
|
||||||
result = self.evaluate_sample(sample, print_details=True)
|
result = self.evaluate_sample(sample, print_details=True)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
|
||||||
# 记录低方差样本
|
|
||||||
if result["analysis"]["low_variance"]:
|
if result["analysis"]["low_variance"]:
|
||||||
low_variance_samples.append(
|
low_variance_samples.append(
|
||||||
{
|
{
|
||||||
|
|
@ -621,12 +632,105 @@ class TextEvaluator:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 打印汇总统计
|
|
||||||
if results:
|
if results:
|
||||||
self.print_summary(results, low_variance_samples)
|
self.print_summary(results, low_variance_samples)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
def evaluate_by_slot_count(
    self, text: str, samples_per_slot: int = 100
) -> Dict[int, Dict]:
    """
    Evaluate the model separately for each requested history-slot count (0-8).

    For every slot count, ``samples_per_slot`` samples are built with
    ``self.create_sample_from_text(text, force_slot_count=slot_count)`` and
    scored with ``self.evaluate_sample(sample, print_details=False)``. One
    progress row per slot count is printed, followed by an overall accuracy
    line.

    Args:
        text: Input text. Texts longer than 300 characters are cut down to
            a randomly positioned 300-character window before sampling.
        samples_per_slot: Number of samples attempted for each slot count.

    Returns:
        A dict keyed by the requested slot count (0..8). Each value holds
        ``results`` (the raw per-sample results), ``accuracy``, ``correct``,
        ``total``, ``mean_max_prob``, ``mean_entropy``,
        ``low_variance_count`` and ``failed`` (samples that raised).

    NOTE(review): results are grouped by the *requested* slot count; whether
    each generated sample really carries that many valid slots depends on
    create_sample_from_text - confirm when interpreting the per-slot rows.
    """
    # Keep the evaluated window bounded for long inputs.
    if len(text) > 300:
        start = random.randint(0, len(text) - 300)
        text = text[start : start + 300]

    print(f"\n{'=' * 60}")
    print(f"按槽位数量评估 | 每组 {samples_per_slot} 个样本")
    print(f"{'=' * 60}")

    # Width of the "correct/total" columns: at least 2 (the historical
    # layout), wider when the group size needs more digits (e.g. 100).
    count_width = max(2, len(str(samples_per_slot)))
    bar_len = 30

    slot_results = {}

    for slot_count in range(0, 9):
        results = []
        low_var_count = 0
        failed = 0
        last_error = None  # most recent exception, reported if samples fail

        for _ in range(samples_per_slot):
            try:
                sample = self.create_sample_from_text(
                    text, force_slot_count=slot_count
                )
                result = self.evaluate_sample(sample, print_details=False)
                results.append(result)
                if result["analysis"]["low_variance"]:
                    low_var_count += 1
            except Exception as exc:
                # Best effort: one bad sample must not abort the sweep, but
                # remember it so the failure is visible in the report.
                failed += 1
                last_error = exc
                continue

        if results:
            correct_count = sum(1 for r in results if r["correct"])
            accuracy = correct_count / len(results)
            # Plain floats keep the summary independent of numpy scalars.
            mean_max_prob = float(
                np.mean([r["analysis"]["max_prob"] for r in results])
            )
            mean_entropy = float(
                np.mean([r["analysis"]["entropy"] for r in results])
            )

            slot_results[slot_count] = {
                "results": results,
                "accuracy": accuracy,
                "correct": correct_count,
                "total": len(results),
                "mean_max_prob": mean_max_prob,
                "mean_entropy": mean_entropy,
                "low_variance_count": low_var_count,
                "failed": failed,
            }

            filled = int(bar_len * accuracy)
            bar = "█" * filled + "░" * (bar_len - filled)
            row = (
                f" 槽位 {slot_count}/8 | {bar} {accuracy:6.1%} "
                f"({correct_count:{count_width}d}/{len(results):{count_width}d}) "
                f"| 平均最高概率: {mean_max_prob:.4f} "
                f"| 平均熵: {mean_entropy:.2f} "
                f"| 低方差: {low_var_count}"
            )
            if failed:
                # Surface partial failures instead of hiding them in the dict.
                row += f" | 失败: {failed}"
            print(row)
        else:
            slot_results[slot_count] = {
                "results": [],
                "accuracy": 0.0,
                "correct": 0,
                "total": 0,
                "mean_max_prob": 0.0,
                "mean_entropy": 0.0,
                "low_variance_count": 0,
                "failed": failed,
            }
            if failed:
                print(f" 槽位 {slot_count}/8 | 全部样本生成失败")
                print(
                    f"   最后一次错误: {type(last_error).__name__}: {last_error}"
                )
            else:
                # Nothing was attempted (samples_per_slot <= 0); do not
                # report this as a failure.
                print(f" 槽位 {slot_count}/8 | 未生成任何样本")

    print(f"\n{'=' * 60}")
    overall_correct = sum(s["correct"] for s in slot_results.values())
    overall_total = sum(s["total"] for s in slot_results.values())
    if overall_total > 0:
        print(
            f" 总体准确率: {overall_correct / overall_total:.1%} ({overall_correct}/{overall_total})"
        )
    print(f"{'=' * 60}")

    return slot_results
|
||||||
|
|
||||||
def print_summary(self, results: List[Dict], low_variance_samples: List[Dict]):
|
def print_summary(self, results: List[Dict], low_variance_samples: List[Dict]):
|
||||||
"""打印评估汇总"""
|
"""打印评估汇总"""
|
||||||
print(f"\n--- 汇总 ({len(results)}样本) ---")
|
print(f"\n--- 汇总 ({len(results)}样本) ---")
|
||||||
|
|
@ -685,7 +789,7 @@ def main():
|
||||||
help="模型checkpoint路径 (默认: /home/songsenand/下载/best_model.pt)",
|
help="模型checkpoint路径 (默认: /home/songsenand/下载/best_model.pt)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-samples", type=int, default=20, help="评估样本数量 (默认: 20)"
|
"--num-samples", type=int, default=100, help="评估样本数量 (默认: 20)"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--device",
|
"--device",
|
||||||
|
|
@ -732,7 +836,7 @@ def main():
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
# 执行评估
|
# 执行评估
|
||||||
evaluator.evaluate_text(text, args.num_samples)
|
evaluator.evaluate_by_slot_count(text, samples_per_slot=args.num_samples)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -332,6 +332,22 @@ class PinyinInputDataset(IterableDataset):
|
||||||
if random.random() <= 0.1:
|
if random.random() <= 0.1:
|
||||||
labels.append(0)
|
labels.append(0)
|
||||||
|
|
||||||
|
# 提取历史槽位:从预测位置i之前的字符中获取(与eval.py一致)
|
||||||
|
history_slot_list = []
|
||||||
|
for j in range(i - 1, max(-1, i - 100), -1):
|
||||||
|
if j < 0:
|
||||||
|
break
|
||||||
|
char = text[j]
|
||||||
|
if self.query_engine.is_chinese_char(char):
|
||||||
|
try:
|
||||||
|
results = self.query_engine.query_by_char(char, limit=1)
|
||||||
|
if results and results[0][0] > 0:
|
||||||
|
history_slot_list.append(results[0][0])
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if len(history_slot_list) >= 8:
|
||||||
|
break
|
||||||
|
|
||||||
encoded = self.tokenizer(
|
encoded = self.tokenizer(
|
||||||
f"{part4}|{part1}",
|
f"{part4}|{part1}",
|
||||||
part3,
|
part3,
|
||||||
|
|
@ -345,7 +361,8 @@ class PinyinInputDataset(IterableDataset):
|
||||||
# 修复变量名冲突:将内层循环变量i重命名为label_idx
|
# 修复变量名冲突:将内层循环变量i重命名为label_idx
|
||||||
for label_idx, label in enumerate(labels):
|
for label_idx, label in enumerate(labels):
|
||||||
repeats = self.adjust_frequency(label)
|
repeats = self.adjust_frequency(label)
|
||||||
masked_labels = labels[:label_idx]
|
# 使用从text[0:i]提取的历史槽位(与eval.py一致)
|
||||||
|
masked_labels = history_slot_list[:]
|
||||||
len_l = len(masked_labels)
|
len_l = len(masked_labels)
|
||||||
masked_labels.extend([0] * (8 - len_l))
|
masked_labels.extend([0] * (8 - len_l))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -80,6 +80,10 @@ class InputMethodEngine(nn.Module):
|
||||||
# 开启 torch.compile 优化 (如果请求)
|
# 开启 torch.compile 优化 (如果请求)
|
||||||
# 在模型编译时添加优化选项
|
# 在模型编译时添加优化选项
|
||||||
if compile:
|
if compile:
|
||||||
|
from torch._inductor.select_algorithm import TritonTemplate
|
||||||
|
|
||||||
|
TritonTemplate.all_templates.clear()
|
||||||
|
|
||||||
self.forward = torch.compile(
|
self.forward = torch.compile(
|
||||||
self.forward,
|
self.forward,
|
||||||
mode="reduce-overhead",
|
mode="reduce-overhead",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue