From 68a6fc3533b1ffb2436f4a4fec649a38decbe11d Mon Sep 17 00:00:00 2001 From: songsenand Date: Sat, 11 Apr 2026 20:28:31 +0800 Subject: [PATCH] =?UTF-8?q?feat(eval):=20=E6=B7=BB=E5=8A=A0=E6=8C=89?= =?UTF-8?q?=E6=A7=BD=E4=BD=8D=E6=95=B0=E9=87=8F=E8=AF=84=E4=BC=B0=E6=96=87?= =?UTF-8?q?=E6=9C=AC=E7=9A=84=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- eval.py | 158 +++++++++++++++++++++++++++++++++++-------- src/model/dataset.py | 19 +++++- src/model/model.py | 4 ++ 3 files changed, 153 insertions(+), 28 deletions(-) diff --git a/eval.py b/eval.py index 657dcc0..28d914f 100644 --- a/eval.py +++ b/eval.py @@ -222,7 +222,10 @@ class TextEvaluator: return len(mask_pinyin), mask_pinyin def create_sample_from_text( - self, text: str, position: Optional[int] = None + self, + text: str, + position: Optional[int] = None, + force_slot_count: Optional[int] = None, ) -> Dict: """ 从文本创建单个样本,参考dataset.py的样本生成逻辑。 @@ -230,6 +233,7 @@ class TextEvaluator: Args: text: 输入文本 position: 预测位置(汉字位置),如果为None则随机选择 + force_slot_count: 强制设置有效槽位数量(0-8),如果为None则自动计算 Returns: 样本字典,包含模型所需的所有字段 @@ -330,40 +334,51 @@ class TextEvaluator: # 历史槽位:当前预测字符之前已确认的字符 # 从位置i之前的字符中提取历史槽位(最多8个) + # 注意:这里的逻辑与训练数据一致 - 从已确认的字符中提取 history_slot_ids = [] if self.query_engine: # 收集i之前的汉字字符(最多8个) - for j in range(i - 1, max(-1, i - 9), -1): + for j in range(i - 1, max(-1, i - 100), -1): if j < 0: break char = text[j] if self.query_engine.is_chinese_char(char): - # 获取该字符的第一个拼音变体的ID + # 尝试获取该字符的ID(与训练数据一致) results = self.query_engine.query_by_char(char, limit=1) if results: - history_slot_ids.append(results[0][0]) - else: - # 如果查询失败,使用0 - history_slot_ids.append(0) - else: - # 非汉字字符,使用0 - history_slot_ids.append(0) + slot_id = results[0][0] + # 只添加有效的非零ID(与训练数据一致) + if slot_id > 0: + history_slot_ids.append(slot_id) if len(history_slot_ids) >= 8: break - # 如果历史槽位为空,随机添加一些常见字符ID(模拟已确认的输入) - if not history_slot_ids and random.random() < 0.5: - # 随机选择1-3个常见字符ID(范围1-1000) - num_slots = random.randint(1, 3) - for _ in range(num_slots): - history_slot_ids.append(random.randint(1, 1000)) - - # 填充到8个槽位 + # 填充到8个槽位(使用0填充,与训练数据一致) if len(history_slot_ids) < 8: history_slot_ids.extend([0] * (8 - len(history_slot_ids))) else: history_slot_ids = history_slot_ids[:8] + # 强制设置有效槽位数量(用于按槽位数量评估) + # 注意:这里不再使用随机ID填充,而是复制现有的有效槽位 + if force_slot_count is not None: + force_slot_count = max(0, min(8, force_slot_count)) + # 获取所有非零槽位ID + valid_ids = [s for s in history_slot_ids if s != 0] + + if len(valid_ids) >= force_slot_count: + # 如果有效槽位足够,直接取前force_slot_count个 + history_slot_ids = valid_ids[:force_slot_count] + else: + # 如果有效槽位不足,保留现有有效槽位,然后用0填充 + # 这样更符合训练数据的分布(槽位数量由实际历史决定) + history_slot_ids = valid_ids[:] + + # 填充到8个槽位 + while len(history_slot_ids) < 8: + history_slot_ids.append(0) + history_slot_ids = history_slot_ids[:8] + # Tokenize输入 encoded = self.tokenizer( f"{part4}|{part1}", @@ -375,8 +390,8 @@ class TextEvaluator: return_token_type_ids=True, ) - # 计算有效槽位数量(非零元素) - valid_slot_count = sum(1 for slot_id in history_slot_ids if slot_id != 0) + # 计算有效槽位数量(ID > 0 的槽位,与训练数据一致) + valid_slot_count = sum(1 for slot_id in history_slot_ids if slot_id > 0) # 拼音输入长度(字符数) pinyin_input_length = len(part2) @@ -597,14 +612,10 @@ class TextEvaluator: print(f"\n[样本 {sample_idx + 1}/{num_samples}]", end=" ") try: - # 创建样本 sample = self.create_sample_from_text(text) - - # 评估样本 result = self.evaluate_sample(sample, print_details=True) results.append(result) - # 记录低方差样本 if result["analysis"]["low_variance"]: low_variance_samples.append( { @@ -621,12 +632,105 @@ class TextEvaluator: traceback.print_exc() continue - # 打印汇总统计 if results: self.print_summary(results, low_variance_samples) return results + def evaluate_by_slot_count( + self, text: str, samples_per_slot: int = 100 + ) -> Dict[int, Dict]: + """ + 按槽位数量分别评估,每个槽位数量测试 samples_per_slot 个样本。 + + Args: + text: 输入文本 + samples_per_slot: 每个槽位数量的测试样本数 + + Returns: + 按槽位数量分组的评估结果 {slot_count: {results, accuracy, ...}} + """ + if len(text) > 300: + start = random.randint(0, len(text) - 300) + text = text[start : start + 300] + + print(f"\n{'=' * 60}") + print(f"按槽位数量评估 | 每组 {samples_per_slot} 个样本") + print(f"{'=' * 60}") + + slot_results = {} + + for slot_count in range(0, 9): + results = [] + low_var_count = 0 + failed = 0 + + for sample_idx in range(samples_per_slot): + try: + sample = self.create_sample_from_text( + text, force_slot_count=slot_count + ) + result = self.evaluate_sample(sample, print_details=False) + results.append(result) + if result["analysis"]["low_variance"]: + low_var_count += 1 + except Exception: + failed += 1 + continue + + if results: + correct_count = sum(1 for r in results if r["correct"]) + accuracy = correct_count / len(results) + max_probs = [r["analysis"]["max_prob"] for r in results] + mean_max_prob = np.mean(max_probs) + entropies = [r["analysis"]["entropy"] for r in results] + mean_entropy = np.mean(entropies) + + slot_results[slot_count] = { + "results": results, + "accuracy": accuracy, + "correct": correct_count, + "total": len(results), + "mean_max_prob": mean_max_prob, + "mean_entropy": mean_entropy, + "low_variance_count": low_var_count, + "failed": failed, + } + + bar_len = 30 + filled = int(bar_len * accuracy) + bar = "█" * filled + "░" * (bar_len - filled) + print( + f" 槽位 {slot_count}/8 | {bar} {accuracy:6.1%} " + f"({correct_count:2d}/{len(results):2d}) " + f"| 平均最高概率: {mean_max_prob:.4f} " + f"| 平均熵: {mean_entropy:.2f} " + f"| 低方差: {low_var_count}" + ) + else: + slot_results[slot_count] = { + "results": [], + "accuracy": 0, + "correct": 0, + "total": 0, + "mean_max_prob": 0, + "mean_entropy": 0, + "low_variance_count": 0, + "failed": failed, + } + print(f" 槽位 {slot_count}/8 | 全部样本生成失败") + + print(f"\n{'=' * 60}") + overall_correct = sum(s["correct"] for s in slot_results.values()) + overall_total = sum(s["total"] for s in slot_results.values()) + if overall_total > 0: + print( + f" 总体准确率: {overall_correct / overall_total:.1%} ({overall_correct}/{overall_total})" + ) + print(f"{'=' * 60}") + + return slot_results + def print_summary(self, results: List[Dict], low_variance_samples: List[Dict]): """打印评估汇总""" print(f"\n--- 汇总 ({len(results)}样本) ---") @@ -685,7 +789,7 @@ def main(): help="模型checkpoint路径 (默认: /home/songsenand/下载/best_model.pt)", ) parser.add_argument( - "--num-samples", type=int, default=20, help="评估样本数量 (默认: 20)" + "--num-samples", type=int, default=100, help="评估样本数量 (默认: 20)" ) parser.add_argument( "--device", @@ -732,7 +836,7 @@ def main(): sys.exit(1) # 执行评估 - evaluator.evaluate_text(text, args.num_samples) + evaluator.evaluate_by_slot_count(text, samples_per_slot=args.num_samples) if __name__ == "__main__": diff --git a/src/model/dataset.py b/src/model/dataset.py index 88d3f4e..935f447 100644 --- a/src/model/dataset.py +++ b/src/model/dataset.py @@ -332,6 +332,22 @@ class PinyinInputDataset(IterableDataset): if random.random() <= 0.1: labels.append(0) + # 提取历史槽位:从预测位置i之前的字符中获取(与eval.py一致) + history_slot_list = [] + for j in range(i - 1, max(-1, i - 100), -1): + if j < 0: + break + char = text[j] + if self.query_engine.is_chinese_char(char): + try: + results = self.query_engine.query_by_char(char, limit=1) + if results and results[0][0] > 0: + history_slot_list.append(results[0][0]) + except Exception: + pass + if len(history_slot_list) >= 8: + break + encoded = self.tokenizer( f"{part4}|{part1}", part3, @@ -345,7 +361,8 @@ class PinyinInputDataset(IterableDataset): # 修复变量名冲突:将内层循环变量i重命名为label_idx for label_idx, label in enumerate(labels): repeats = self.adjust_frequency(label) - masked_labels = labels[:label_idx] + # 使用从text[0:i]提取的历史槽位(与eval.py一致) + masked_labels = history_slot_list[:] len_l = len(masked_labels) masked_labels.extend([0] * (8 - len_l)) diff --git a/src/model/model.py b/src/model/model.py index 70f630c..b65834c 100644 --- a/src/model/model.py +++ b/src/model/model.py @@ -80,6 +80,10 @@ class InputMethodEngine(nn.Module): # 开启 torch.compile 优化 (如果请求) # 在模型编译时添加优化选项 if compile: + from torch._inductor.select_algorithm import TritonTemplate + + TritonTemplate.all_templates.clear() + self.forward = torch.compile( self.forward, mode="reduce-overhead",