From 526cfc8477451adee1c15ae6ee03bb119f41914b Mon Sep 17 00:00:00 2001 From: songsenand Date: Thu, 9 Apr 2026 22:11:50 +0800 Subject: [PATCH] =?UTF-8?q?fix(model):=20=E5=A4=84=E7=90=86=E6=97=A0?= =?UTF-8?q?=E6=9C=89=E6=95=88=E6=A7=BD=E4=BD=8D=E6=97=B6=E7=9A=84=E6=B1=A0?= =?UTF-8?q?=E5=8C=96=E8=AE=A1=E7=AE=97=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../.pinyin_char_statistics.json.kate-swp | Bin 0 -> 77 bytes src/model/model.py | 12 +++- test.py | 68 +++++++++++++----- 3 files changed, 60 insertions(+), 20 deletions(-) create mode 100644 src/model/assets/.pinyin_char_statistics.json.kate-swp diff --git a/src/model/assets/.pinyin_char_statistics.json.kate-swp b/src/model/assets/.pinyin_char_statistics.json.kate-swp new file mode 100644 index 0000000000000000000000000000000000000000..8091be63aeba017542c162933e20c565937a4b6b GIT binary patch literal 77 zcmZQzU=Z?7EJ;-eE>A2_aLdd|RWQ;sU|?VnF}ND9FZTGe(UTvRvz^P&PS_N>ZcVT! V0|UoTpe!2@Gp4u(2O;s*U%6x9F# literal 0 HcmV?d00001 diff --git a/src/model/model.py b/src/model/model.py index 082ad32..379eed0 100644 --- a/src/model/model.py +++ b/src/model/model.py @@ -128,8 +128,16 @@ class InputMethodEngine(nn.Module): batch_size = input_ids.size(0) slot_mask = (history_slot_ids != 0).float().view(batch_size, self.num_slots, 1) numerator = (moe_out * slot_mask).sum(dim=1) # [batch, dim] - denominator = slot_mask.view(batch_size, -1).sum(dim=1) + 1e-8 # [batch] - pooled = numerator / denominator.unsqueeze(-1) # [batch, dim] + denominator = slot_mask.view(batch_size, -1).sum(dim=1) # [batch] + # 若无有效槽位,使用上下文 H 的掩码均值 + if torch.all(denominator == 0): + # H: [batch, seq_len, dim], attention_mask: [batch, seq_len] + ctx_mask = attention_mask.float().unsqueeze(-1) # [batch, seq_len, 1] + ctx_sum = (H * ctx_mask).sum(dim=1) # [batch, dim] + ctx_cnt = ctx_mask.sum(dim=1) + 1e-8 # [batch, 1] + pooled = ctx_sum / ctx_cnt + else: + pooled = numerator / (denominator.unsqueeze(-1) + 1e-8) # [batch, dim] logits = self.classifier(pooled) # [batch, vocab_size] return logits diff --git a/test.py b/test.py index b71e675..9702402 100644 --- a/test.py +++ b/test.py @@ -1,13 +1,13 @@ import sys +sys.path.append("src") + import torch from torch.utils.data import DataLoader from tqdm import tqdm -from model.dataset import PinyinInputDataset from model.model import InputMethodEngine -from model.trainer import collate_fn, worker_init_fn -from .query import QueryEngine +from model.query import QueryEngine import random import re @@ -25,8 +25,8 @@ from pypinyin.contrib.tone_convert import to_initials from torch.utils.data import IterableDataset tokenizer = AutoTokenizer.from_pretrained( - Path(str(__file__))) / 'src' / 'model' / "assets" / "tokenizer" - ) + Path(str(__file__)).parent / "src" / "model" / "assets" / "tokenizer" +) _HANZI_RE = re.compile(r"[\u4e00-\u9fff]+") @@ -47,17 +47,17 @@ def text_to_pinyin_ids(pinyin_str: str) -> List[int]: part1 = "他是一名大学生,在上海读" -part2 = "shu" +part2 = "dayi" pinyin_ids = text_to_pinyin_ids(part2) len_py = len(pinyin_ids) if len_py < 24: pinyin_ids.extend([0] * (24 - len_py)) else: pinyin_ids = pinyin_ids[:24] -pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long) - +pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0) +masked_labels = [15, 4, 0, 0, 0, 0, 0, 0] part3 = "。" -part4 = "" +part4 = "可行|特别|伤害" encoded = tokenizer( f"{part4}|{part1}", @@ -70,19 +70,16 @@ encoded = tokenizer( ) sample = { - "input_ids": torch.stack([encoded["input_ids"].squeeze(0)] - "token_type_ids": torch.stack([encoded["token_type_ids"].squeeze(0)], - "attention_mask": torch.stack([encoded["attention_mask"].squeeze(0)], - "history_slot_ids": torch.tensor( - masked_labels, dtype=torch.long - ), + "input_ids": torch.stack([encoded["input_ids"].squeeze(0)]), + "token_type_ids": torch.stack([encoded["token_type_ids"].squeeze(0)]), + "attention_mask": torch.stack([encoded["attention_mask"].squeeze(0)]), + "history_slot_ids": torch.tensor(masked_labels, dtype=torch.long).unsqueeze(0), "prefix": f"{part4}^{part1}", "suffix": part3, "pinyin": part2, "pinyin_ids": pinyin_ids, } -dataset = PinyinInputDataset(dataset_path, max_iter_length=max_iter_length) model = InputMethodEngine(pinyin_vocab_size=30, compile=False) @@ -99,9 +96,13 @@ for k, v in sample.items(): print(f"{k}: {v}") res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids) -sort_res = sorted([(i + 1, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True) +sort_res = sorted( + [(i, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True +) print(sort_res[0:5]) +query_engine = QueryEngine() +query_engine.load() # 在test.py的res计算后添加: import torch.nn.functional as F @@ -121,4 +122,35 @@ print(f"\n🏆 Top-20预测:") for i in range(20): idx = top_indices[0, i].item() prob = top_probs[0, i].item() - print(f" {i + 1:2d}. ID {idx:5d}: {prob:.6f}") + print( + f" {i + 1:2d}. ID:{idx}\t字符: {query_engine.query_by_id(idx).char}\t概率: {prob:.6f}" + ) + +print("\n" + "=" * 60) +print("测试 history_slot_ids 全零情况") +print("=" * 60) + +masked_labels = [0, 0, 0, 0, 0, 0, 0, 0] +history_slot_ids = torch.tensor(masked_labels, dtype=torch.long).unsqueeze(0) +sample["history_slot_ids"] = history_slot_ids + +res_zero = model( + input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids +) +probs_zero = F.softmax(res_zero, dim=-1) + +print(f"\n📊 概率分布分析 (全零历史):") +print(f" 形状: {probs_zero.shape}") +print(f" 总概率和: {probs_zero.sum().item():.6f}") +print(f" 最大概率: {probs_zero.max().item():.6f}") +print(f" 最小概率: {probs_zero.min().item():.6f}") +print(f" 平均概率: {probs_zero.mean().item():.6f}") + +top_probs_zero, top_indices_zero = torch.topk(probs_zero, k=20) +print(f"\n🏆 Top-20预测 (全零历史):") +for i in range(20): + idx = top_indices_zero[0, i].item() + prob = top_probs_zero[0, i].item() + print( + f" {i + 1:2d}. ID:{idx}\t字符: {query_engine.query_by_id(idx).char}\t概率: {prob:.6f}" + )