Compare commits

..

2 Commits

3 changed files with 60 additions and 19 deletions

Binary file not shown.

View File

@@ -128,8 +128,16 @@ class InputMethodEngine(nn.Module):
batch_size = input_ids.size(0)
slot_mask = (history_slot_ids != 0).float().view(batch_size, self.num_slots, 1)
numerator = (moe_out * slot_mask).sum(dim=1) # [batch, dim]
denominator = slot_mask.view(batch_size, -1).sum(dim=1) + 1e-8 # [batch]
pooled = numerator / denominator.unsqueeze(-1) # [batch, dim]
denominator = slot_mask.view(batch_size, -1).sum(dim=1) # [batch]
# 若无有效槽位,使用上下文 H 的掩码均值
if torch.all(denominator == 0):
# H: [batch, seq_len, dim], attention_mask: [batch, seq_len]
ctx_mask = attention_mask.float().unsqueeze(-1) # [batch, seq_len, 1]
ctx_sum = (H * ctx_mask).sum(dim=1) # [batch, dim]
ctx_cnt = ctx_mask.sum(dim=1) + 1e-8 # [batch, 1]
pooled = ctx_sum / ctx_cnt
else:
pooled = numerator / (denominator.unsqueeze(-1) + 1e-8) # [batch, dim]
logits = self.classifier(pooled) # [batch, vocab_size]
return logits

67
test.py
View File

@@ -1,13 +1,13 @@
import sys
sys.path.append("src")
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from model.dataset import PinyinInputDataset
from model.model import InputMethodEngine
from model.trainer import collate_fn, worker_init_fn
from .query import QueryEngine
from model.query import QueryEngine
import random
import re
@@ -25,8 +25,8 @@ from pypinyin.contrib.tone_convert import to_initials
from torch.utils.data import IterableDataset
tokenizer = AutoTokenizer.from_pretrained(
Path(str(__file__))) / 'src' / 'model' / "assets" / "tokenizer"
)
Path(str(__file__)).parent / "src" / "model" / "assets" / "tokenizer"
)
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
@@ -47,17 +47,17 @@ def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
part1 = "他是一名大学生,在上海读"
part2 = "shu"
part2 = "dayi"
pinyin_ids = text_to_pinyin_ids(part2)
len_py = len(pinyin_ids)
if len_py < 24:
pinyin_ids.extend([0] * (24 - len_py))
else:
pinyin_ids = pinyin_ids[:24]
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long)
masked_labels = [0] * 8
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
masked_labels = [15, 4, 0, 0, 0, 0, 0, 0]
part3 = ""
part4 = ""
part4 = "可行|特别|伤害"
encoded = tokenizer(
f"{part4}|{part1}",
@@ -70,12 +70,10 @@ encoded = tokenizer(
)
sample = {
"input_ids": torch.stack([encoded["input_ids"].squeeze(0)]
"token_type_ids": torch.stack([encoded["token_type_ids"].squeeze(0)],
"attention_mask": torch.stack([encoded["attention_mask"].squeeze(0)],
"history_slot_ids": torch.tensor(
masked_labels, dtype=torch.long
),
"input_ids": torch.stack([encoded["input_ids"].squeeze(0)]),
"token_type_ids": torch.stack([encoded["token_type_ids"].squeeze(0)]),
"attention_mask": torch.stack([encoded["attention_mask"].squeeze(0)]),
"history_slot_ids": torch.tensor(masked_labels, dtype=torch.long).unsqueeze(0),
"prefix": f"{part4}^{part1}",
"suffix": part3,
"pinyin": part2,
@@ -98,9 +96,13 @@ for k, v in sample.items():
print(f"{k}: {v}")
res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
sort_res = sorted([(i + 1, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True)
sort_res = sorted(
[(i, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True
)
print(sort_res[0:5])
query_engine = QueryEngine()
query_engine.load()
# 在test.py的res计算后添加
import torch.nn.functional as F
@@ -120,4 +122,35 @@ print(f"\n🏆 Top-20预测:")
for i in range(20):
idx = top_indices[0, i].item()
prob = top_probs[0, i].item()
print(f" {i + 1:2d}. ID {idx:5d}: {prob:.6f}")
print(
f" {i + 1:2d}. ID:{idx}\t字符: {query_engine.query_by_id(idx).char}\t概率: {prob:.6f}"
)
print("\n" + "=" * 60)
print("测试 history_slot_ids 全零情况")
print("=" * 60)
masked_labels = [0, 0, 0, 0, 0, 0, 0, 0]
history_slot_ids = torch.tensor(masked_labels, dtype=torch.long).unsqueeze(0)
sample["history_slot_ids"] = history_slot_ids
res_zero = model(
input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids
)
probs_zero = F.softmax(res_zero, dim=-1)
print(f"\n📊 概率分布分析 (全零历史):")
print(f" 形状: {probs_zero.shape}")
print(f" 总概率和: {probs_zero.sum().item():.6f}")
print(f" 最大概率: {probs_zero.max().item():.6f}")
print(f" 最小概率: {probs_zero.min().item():.6f}")
print(f" 平均概率: {probs_zero.mean().item():.6f}")
top_probs_zero, top_indices_zero = torch.topk(probs_zero, k=20)
print(f"\n🏆 Top-20预测 (全零历史):")
for i in range(20):
idx = top_indices_zero[0, i].item()
prob = top_probs_zero[0, i].item()
print(
f" {i + 1:2d}. ID:{idx}\t字符: {query_engine.query_by_id(idx).char}\t概率: {prob:.6f}"
)