Compare commits

...

2 Commits

3 changed files with 60 additions and 19 deletions

Binary file not shown.


@@ -128,8 +128,16 @@ class InputMethodEngine(nn.Module):
         batch_size = input_ids.size(0)
         slot_mask = (history_slot_ids != 0).float().view(batch_size, self.num_slots, 1)
         numerator = (moe_out * slot_mask).sum(dim=1)  # [batch, dim]
-        denominator = slot_mask.view(batch_size, -1).sum(dim=1) + 1e-8  # [batch]
-        pooled = numerator / denominator.unsqueeze(-1)  # [batch, dim]
+        denominator = slot_mask.view(batch_size, -1).sum(dim=1)  # [batch]
+        # If there are no valid slots, use the masked mean of the context H
+        if torch.all(denominator == 0):
+            # H: [batch, seq_len, dim], attention_mask: [batch, seq_len]
+            ctx_mask = attention_mask.float().unsqueeze(-1)  # [batch, seq_len, 1]
+            ctx_sum = (H * ctx_mask).sum(dim=1)  # [batch, dim]
+            ctx_cnt = ctx_mask.sum(dim=1) + 1e-8  # [batch, 1]
+            pooled = ctx_sum / ctx_cnt
+        else:
+            pooled = numerator / (denominator.unsqueeze(-1) + 1e-8)  # [batch, dim]
         logits = self.classifier(pooled)  # [batch, vocab_size]
         return logits
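
Note that the new branch is batch-global: torch.all(denominator == 0) only takes the fallback when every sample in the batch has an all-zero history, otherwise the whole batch goes through the slot-mask path. A minimal standalone sketch of the same pooling logic on toy tensors (shapes and values are made up for illustration; only the masked-mean arithmetic matches the diff):

import torch

# Toy shapes, assumed only for illustration
batch, num_slots, seq_len, dim = 2, 8, 5, 4
moe_out = torch.randn(batch, num_slots, dim)            # per-slot expert outputs
H = torch.randn(batch, seq_len, dim)                     # encoder context
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])
history_slot_ids = torch.zeros(batch, num_slots, dtype=torch.long)  # all-zero history

slot_mask = (history_slot_ids != 0).float().view(batch, num_slots, 1)
numerator = (moe_out * slot_mask).sum(dim=1)             # [batch, dim]
denominator = slot_mask.view(batch, -1).sum(dim=1)       # [batch]

if torch.all(denominator == 0):
    # fallback: attention-masked mean of the context H
    ctx_mask = attention_mask.float().unsqueeze(-1)      # [batch, seq_len, 1]
    pooled = (H * ctx_mask).sum(dim=1) / (ctx_mask.sum(dim=1) + 1e-8)
else:
    pooled = numerator / (denominator.unsqueeze(-1) + 1e-8)

# With all-zero history, pooled equals the masked mean of H
expected = (H * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=1, keepdim=True)
print(torch.allclose(pooled, expected, atol=1e-5))       # True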

test.py

@@ -1,13 +1,13 @@
 import sys
+sys.path.append("src")
 import torch
 from torch.utils.data import DataLoader
 from tqdm import tqdm
-from model.dataset import PinyinInputDataset
 from model.model import InputMethodEngine
-from model.trainer import collate_fn, worker_init_fn
-from .query import QueryEngine
+from model.query import QueryEngine
 import random
 import re
@@ -25,8 +25,8 @@ from pypinyin.contrib.tone_convert import to_initials
 from torch.utils.data import IterableDataset
 tokenizer = AutoTokenizer.from_pretrained(
-    Path(str(__file__))) / 'src' / 'model' / "assets" / "tokenizer"
+    Path(str(__file__)).parent / "src" / "model" / "assets" / "tokenizer"
 )
 _HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
@@ -47,17 +47,17 @@ def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
 part1 = "他是一名大学生,在上海读"
-part2 = "shu"
+part2 = "dayi"
 pinyin_ids = text_to_pinyin_ids(part2)
 len_py = len(pinyin_ids)
 if len_py < 24:
     pinyin_ids.extend([0] * (24 - len_py))
 else:
     pinyin_ids = pinyin_ids[:24]
-pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long)
-masked_labels = [0] * 8
+pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
+masked_labels = [15, 4, 0, 0, 0, 0, 0, 0]
 part3 = ""
-part4 = ""
+part4 = "可行|特别|伤害"
 encoded = tokenizer(
     f"{part4}|{part1}",
@@ -70,12 +70,10 @@ encoded = tokenizer(
 )
 sample = {
-    "input_ids": torch.stack([encoded["input_ids"].squeeze(0)]
-    "token_type_ids": torch.stack([encoded["token_type_ids"].squeeze(0)],
-    "attention_mask": torch.stack([encoded["attention_mask"].squeeze(0)],
-    "history_slot_ids": torch.tensor(
-        masked_labels, dtype=torch.long
-    ),
+    "input_ids": torch.stack([encoded["input_ids"].squeeze(0)]),
+    "token_type_ids": torch.stack([encoded["token_type_ids"].squeeze(0)]),
+    "attention_mask": torch.stack([encoded["attention_mask"].squeeze(0)]),
+    "history_slot_ids": torch.tensor(masked_labels, dtype=torch.long).unsqueeze(0),
     "prefix": f"{part4}^{part1}",
     "suffix": part3,
     "pinyin": part2,
@@ -98,9 +96,13 @@ for k, v in sample.items():
     print(f"{k}: {v}")
 res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
-sort_res = sorted([(i + 1, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True)
+sort_res = sorted(
+    [(i, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True
+)
 print(sort_res[0:5])
+query_engine = QueryEngine()
+query_engine.load()
 # Added after the res computation in test.py
 import torch.nn.functional as F
@@ -120,4 +122,35 @@ print(f"\n🏆 Top-20 predictions:")
 for i in range(20):
     idx = top_indices[0, i].item()
     prob = top_probs[0, i].item()
-    print(f" {i + 1:2d}. ID {idx:5d}: {prob:.6f}")
+    print(
+        f" {i + 1:2d}. ID:{idx}\tchar: {query_engine.query_by_id(idx).char}\tprob: {prob:.6f}"
+    )
+
+print("\n" + "=" * 60)
+print("Testing the all-zero history_slot_ids case")
+print("=" * 60)
+masked_labels = [0, 0, 0, 0, 0, 0, 0, 0]
+history_slot_ids = torch.tensor(masked_labels, dtype=torch.long).unsqueeze(0)
+sample["history_slot_ids"] = history_slot_ids
+res_zero = model(
+    input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids
+)
+probs_zero = F.softmax(res_zero, dim=-1)
+print(f"\n📊 Probability distribution analysis (all-zero history):")
+print(f" shape: {probs_zero.shape}")
+print(f" total probability: {probs_zero.sum().item():.6f}")
+print(f" max probability: {probs_zero.max().item():.6f}")
+print(f" min probability: {probs_zero.min().item():.6f}")
+print(f" mean probability: {probs_zero.mean().item():.6f}")
+top_probs_zero, top_indices_zero = torch.topk(probs_zero, k=20)
+print(f"\n🏆 Top-20 predictions (all-zero history):")
+for i in range(20):
+    idx = top_indices_zero[0, i].item()
+    prob = top_probs_zero[0, i].item()
+    print(
+        f" {i + 1:2d}. ID:{idx}\tchar: {query_engine.query_by_id(idx).char}\tprob: {prob:.6f}"
+    )
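
As a follow-up, the eyeballed printouts above could be turned into assertions. A sketch, where check_probs and the 6000-entry vocab size are hypothetical and only the checks themselves mirror what the test prints:

import torch
import torch.nn.functional as F

def check_probs(logits: torch.Tensor) -> None:
    # Hypothetical helper mirroring the printed diagnostics: the softmax output
    # must be finite and each row must sum to ~1, also for all-zero history.
    probs = F.softmax(logits, dim=-1)
    assert torch.isfinite(probs).all(), "NaN/Inf in probabilities"
    row_sums = probs.sum(dim=-1)
    assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-4)

# Random logits stand in for res / res_zero; 6000 is a made-up vocab size
check_probs(torch.randn(1, 6000))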