Compare commits
2 Commits
05e440bcfe
...
8b0beeb56c
| Author | SHA1 | Date |
|---|---|---|
|
|
8b0beeb56c | |
|
|
526cfc8477 |
Binary file not shown.
|
|
@ -128,8 +128,16 @@ class InputMethodEngine(nn.Module):
|
||||||
batch_size = input_ids.size(0)
|
batch_size = input_ids.size(0)
|
||||||
slot_mask = (history_slot_ids != 0).float().view(batch_size, self.num_slots, 1)
|
slot_mask = (history_slot_ids != 0).float().view(batch_size, self.num_slots, 1)
|
||||||
numerator = (moe_out * slot_mask).sum(dim=1) # [batch, dim]
|
numerator = (moe_out * slot_mask).sum(dim=1) # [batch, dim]
|
||||||
denominator = slot_mask.view(batch_size, -1).sum(dim=1) + 1e-8 # [batch]
|
denominator = slot_mask.view(batch_size, -1).sum(dim=1) # [batch]
|
||||||
pooled = numerator / denominator.unsqueeze(-1) # [batch, dim]
|
# 若无有效槽位,使用上下文 H 的掩码均值
|
||||||
|
if torch.all(denominator == 0):
|
||||||
|
# H: [batch, seq_len, dim], attention_mask: [batch, seq_len]
|
||||||
|
ctx_mask = attention_mask.float().unsqueeze(-1) # [batch, seq_len, 1]
|
||||||
|
ctx_sum = (H * ctx_mask).sum(dim=1) # [batch, dim]
|
||||||
|
ctx_cnt = ctx_mask.sum(dim=1) + 1e-8 # [batch, 1]
|
||||||
|
pooled = ctx_sum / ctx_cnt
|
||||||
|
else:
|
||||||
|
pooled = numerator / (denominator.unsqueeze(-1) + 1e-8) # [batch, dim]
|
||||||
|
|
||||||
logits = self.classifier(pooled) # [batch, vocab_size]
|
logits = self.classifier(pooled) # [batch, vocab_size]
|
||||||
return logits
|
return logits
|
||||||
|
|
|
||||||
65
test.py
65
test.py
|
|
@ -1,13 +1,13 @@
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
sys.path.append("src")
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from model.dataset import PinyinInputDataset
|
|
||||||
from model.model import InputMethodEngine
|
from model.model import InputMethodEngine
|
||||||
from model.trainer import collate_fn, worker_init_fn
|
from model.query import QueryEngine
|
||||||
from .query import QueryEngine
|
|
||||||
|
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
|
|
@ -25,7 +25,7 @@ from pypinyin.contrib.tone_convert import to_initials
|
||||||
from torch.utils.data import IterableDataset
|
from torch.utils.data import IterableDataset
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
Path(str(__file__))) / 'src' / 'model' / "assets" / "tokenizer"
|
Path(str(__file__)).parent / "src" / "model" / "assets" / "tokenizer"
|
||||||
)
|
)
|
||||||
|
|
||||||
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
|
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
|
||||||
|
|
@ -47,17 +47,17 @@ def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
|
||||||
|
|
||||||
|
|
||||||
part1 = "他是一名大学生,在上海读"
|
part1 = "他是一名大学生,在上海读"
|
||||||
part2 = "shu"
|
part2 = "dayi"
|
||||||
pinyin_ids = text_to_pinyin_ids(part2)
|
pinyin_ids = text_to_pinyin_ids(part2)
|
||||||
len_py = len(pinyin_ids)
|
len_py = len(pinyin_ids)
|
||||||
if len_py < 24:
|
if len_py < 24:
|
||||||
pinyin_ids.extend([0] * (24 - len_py))
|
pinyin_ids.extend([0] * (24 - len_py))
|
||||||
else:
|
else:
|
||||||
pinyin_ids = pinyin_ids[:24]
|
pinyin_ids = pinyin_ids[:24]
|
||||||
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long)
|
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
|
||||||
masked_labels = [0] * 8
|
masked_labels = [15, 4, 0, 0, 0, 0, 0, 0]
|
||||||
part3 = "。"
|
part3 = "。"
|
||||||
part4 = ""
|
part4 = "可行|特别|伤害"
|
||||||
|
|
||||||
encoded = tokenizer(
|
encoded = tokenizer(
|
||||||
f"{part4}|{part1}",
|
f"{part4}|{part1}",
|
||||||
|
|
@ -70,12 +70,10 @@ encoded = tokenizer(
|
||||||
)
|
)
|
||||||
|
|
||||||
sample = {
|
sample = {
|
||||||
"input_ids": torch.stack([encoded["input_ids"].squeeze(0)]
|
"input_ids": torch.stack([encoded["input_ids"].squeeze(0)]),
|
||||||
"token_type_ids": torch.stack([encoded["token_type_ids"].squeeze(0)],
|
"token_type_ids": torch.stack([encoded["token_type_ids"].squeeze(0)]),
|
||||||
"attention_mask": torch.stack([encoded["attention_mask"].squeeze(0)],
|
"attention_mask": torch.stack([encoded["attention_mask"].squeeze(0)]),
|
||||||
"history_slot_ids": torch.tensor(
|
"history_slot_ids": torch.tensor(masked_labels, dtype=torch.long).unsqueeze(0),
|
||||||
masked_labels, dtype=torch.long
|
|
||||||
),
|
|
||||||
"prefix": f"{part4}^{part1}",
|
"prefix": f"{part4}^{part1}",
|
||||||
"suffix": part3,
|
"suffix": part3,
|
||||||
"pinyin": part2,
|
"pinyin": part2,
|
||||||
|
|
@ -98,9 +96,13 @@ for k, v in sample.items():
|
||||||
print(f"{k}: {v}")
|
print(f"{k}: {v}")
|
||||||
|
|
||||||
res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
|
res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
|
||||||
sort_res = sorted([(i + 1, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True)
|
sort_res = sorted(
|
||||||
|
[(i, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True
|
||||||
|
)
|
||||||
print(sort_res[0:5])
|
print(sort_res[0:5])
|
||||||
|
|
||||||
|
query_engine = QueryEngine()
|
||||||
|
query_engine.load()
|
||||||
# 在test.py的res计算后添加:
|
# 在test.py的res计算后添加:
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
@ -120,4 +122,35 @@ print(f"\n🏆 Top-20预测:")
|
||||||
for i in range(20):
|
for i in range(20):
|
||||||
idx = top_indices[0, i].item()
|
idx = top_indices[0, i].item()
|
||||||
prob = top_probs[0, i].item()
|
prob = top_probs[0, i].item()
|
||||||
print(f" {i + 1:2d}. ID {idx:5d}: {prob:.6f}")
|
print(
|
||||||
|
f" {i + 1:2d}. ID:{idx}\t字符: {query_engine.query_by_id(idx).char}\t概率: {prob:.6f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("测试 history_slot_ids 全零情况")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
masked_labels = [0, 0, 0, 0, 0, 0, 0, 0]
|
||||||
|
history_slot_ids = torch.tensor(masked_labels, dtype=torch.long).unsqueeze(0)
|
||||||
|
sample["history_slot_ids"] = history_slot_ids
|
||||||
|
|
||||||
|
res_zero = model(
|
||||||
|
input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids
|
||||||
|
)
|
||||||
|
probs_zero = F.softmax(res_zero, dim=-1)
|
||||||
|
|
||||||
|
print(f"\n📊 概率分布分析 (全零历史):")
|
||||||
|
print(f" 形状: {probs_zero.shape}")
|
||||||
|
print(f" 总概率和: {probs_zero.sum().item():.6f}")
|
||||||
|
print(f" 最大概率: {probs_zero.max().item():.6f}")
|
||||||
|
print(f" 最小概率: {probs_zero.min().item():.6f}")
|
||||||
|
print(f" 平均概率: {probs_zero.mean().item():.6f}")
|
||||||
|
|
||||||
|
top_probs_zero, top_indices_zero = torch.topk(probs_zero, k=20)
|
||||||
|
print(f"\n🏆 Top-20预测 (全零历史):")
|
||||||
|
for i in range(20):
|
||||||
|
idx = top_indices_zero[0, i].item()
|
||||||
|
prob = top_probs_zero[0, i].item()
|
||||||
|
print(
|
||||||
|
f" {i + 1:2d}. ID:{idx}\t字符: {query_engine.query_by_id(idx).char}\t概率: {prob:.6f}"
|
||||||
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue