fix(model): 处理无有效槽位时的池化计算逻辑
This commit is contained in:
parent
504353e895
commit
526cfc8477
Binary file not shown.
|
|
@ -128,8 +128,16 @@ class InputMethodEngine(nn.Module):
|
|||
batch_size = input_ids.size(0)
|
||||
slot_mask = (history_slot_ids != 0).float().view(batch_size, self.num_slots, 1)
|
||||
numerator = (moe_out * slot_mask).sum(dim=1) # [batch, dim]
|
||||
denominator = slot_mask.view(batch_size, -1).sum(dim=1) + 1e-8 # [batch]
|
||||
pooled = numerator / denominator.unsqueeze(-1) # [batch, dim]
|
||||
denominator = slot_mask.view(batch_size, -1).sum(dim=1) # [batch]
|
||||
# 若无有效槽位,使用上下文 H 的掩码均值
|
||||
if torch.all(denominator == 0):
|
||||
# H: [batch, seq_len, dim], attention_mask: [batch, seq_len]
|
||||
ctx_mask = attention_mask.float().unsqueeze(-1) # [batch, seq_len, 1]
|
||||
ctx_sum = (H * ctx_mask).sum(dim=1) # [batch, dim]
|
||||
ctx_cnt = ctx_mask.sum(dim=1) + 1e-8 # [batch, 1]
|
||||
pooled = ctx_sum / ctx_cnt
|
||||
else:
|
||||
pooled = numerator / (denominator.unsqueeze(-1) + 1e-8) # [batch, dim]
|
||||
|
||||
logits = self.classifier(pooled) # [batch, vocab_size]
|
||||
return logits
|
||||
|
|
|
|||
68
test.py
68
test.py
|
|
@ -1,13 +1,13 @@
|
|||
import sys
|
||||
|
||||
sys.path.append("src")
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from model.dataset import PinyinInputDataset
|
||||
from model.model import InputMethodEngine
|
||||
from model.trainer import collate_fn, worker_init_fn
|
||||
from .query import QueryEngine
|
||||
from model.query import QueryEngine
|
||||
|
||||
import random
|
||||
import re
|
||||
|
|
@ -25,8 +25,8 @@ from pypinyin.contrib.tone_convert import to_initials
|
|||
from torch.utils.data import IterableDataset
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
Path(str(__file__))) / 'src' / 'model' / "assets" / "tokenizer"
|
||||
)
|
||||
Path(str(__file__)).parent / "src" / "model" / "assets" / "tokenizer"
|
||||
)
|
||||
|
||||
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
|
||||
|
||||
|
|
@ -47,17 +47,17 @@ def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
|
|||
|
||||
|
||||
part1 = "他是一名大学生,在上海读"
|
||||
part2 = "shu"
|
||||
part2 = "dayi"
|
||||
pinyin_ids = text_to_pinyin_ids(part2)
|
||||
len_py = len(pinyin_ids)
|
||||
if len_py < 24:
|
||||
pinyin_ids.extend([0] * (24 - len_py))
|
||||
else:
|
||||
pinyin_ids = pinyin_ids[:24]
|
||||
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long)
|
||||
|
||||
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
|
||||
masked_labels = [15, 4, 0, 0, 0, 0, 0, 0]
|
||||
part3 = "。"
|
||||
part4 = ""
|
||||
part4 = "可行|特别|伤害"
|
||||
|
||||
encoded = tokenizer(
|
||||
f"{part4}|{part1}",
|
||||
|
|
@ -70,19 +70,16 @@ encoded = tokenizer(
|
|||
)
|
||||
|
||||
sample = {
|
||||
"input_ids": torch.stack([encoded["input_ids"].squeeze(0)]
|
||||
"token_type_ids": torch.stack([encoded["token_type_ids"].squeeze(0)],
|
||||
"attention_mask": torch.stack([encoded["attention_mask"].squeeze(0)],
|
||||
"history_slot_ids": torch.tensor(
|
||||
masked_labels, dtype=torch.long
|
||||
),
|
||||
"input_ids": torch.stack([encoded["input_ids"].squeeze(0)]),
|
||||
"token_type_ids": torch.stack([encoded["token_type_ids"].squeeze(0)]),
|
||||
"attention_mask": torch.stack([encoded["attention_mask"].squeeze(0)]),
|
||||
"history_slot_ids": torch.tensor(masked_labels, dtype=torch.long).unsqueeze(0),
|
||||
"prefix": f"{part4}^{part1}",
|
||||
"suffix": part3,
|
||||
"pinyin": part2,
|
||||
"pinyin_ids": pinyin_ids,
|
||||
}
|
||||
|
||||
dataset = PinyinInputDataset(dataset_path, max_iter_length=max_iter_length)
|
||||
|
||||
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
|
||||
|
||||
|
|
@ -99,9 +96,13 @@ for k, v in sample.items():
|
|||
print(f"{k}: {v}")
|
||||
|
||||
res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
|
||||
sort_res = sorted([(i + 1, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True)
|
||||
sort_res = sorted(
|
||||
[(i, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True
|
||||
)
|
||||
print(sort_res[0:5])
|
||||
|
||||
query_engine = QueryEngine()
|
||||
query_engine.load()
|
||||
# 在test.py的res计算后添加:
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
|
@ -121,4 +122,35 @@ print(f"\n🏆 Top-20预测:")
|
|||
for i in range(20):
|
||||
idx = top_indices[0, i].item()
|
||||
prob = top_probs[0, i].item()
|
||||
print(f" {i + 1:2d}. ID {idx:5d}: {prob:.6f}")
|
||||
print(
|
||||
f" {i + 1:2d}. ID:{idx}\t字符: {query_engine.query_by_id(idx).char}\t概率: {prob:.6f}"
|
||||
)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("测试 history_slot_ids 全零情况")
|
||||
print("=" * 60)
|
||||
|
||||
masked_labels = [0, 0, 0, 0, 0, 0, 0, 0]
|
||||
history_slot_ids = torch.tensor(masked_labels, dtype=torch.long).unsqueeze(0)
|
||||
sample["history_slot_ids"] = history_slot_ids
|
||||
|
||||
res_zero = model(
|
||||
input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids
|
||||
)
|
||||
probs_zero = F.softmax(res_zero, dim=-1)
|
||||
|
||||
print(f"\n📊 概率分布分析 (全零历史):")
|
||||
print(f" 形状: {probs_zero.shape}")
|
||||
print(f" 总概率和: {probs_zero.sum().item():.6f}")
|
||||
print(f" 最大概率: {probs_zero.max().item():.6f}")
|
||||
print(f" 最小概率: {probs_zero.min().item():.6f}")
|
||||
print(f" 平均概率: {probs_zero.mean().item():.6f}")
|
||||
|
||||
top_probs_zero, top_indices_zero = torch.topk(probs_zero, k=20)
|
||||
print(f"\n🏆 Top-20预测 (全零历史):")
|
||||
for i in range(20):
|
||||
idx = top_indices_zero[0, i].item()
|
||||
prob = top_probs_zero[0, i].item()
|
||||
print(
|
||||
f" {i + 1:2d}. ID:{idx}\t字符: {query_engine.query_by_id(idx).char}\t概率: {prob:.6f}"
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue