# Ad-hoc inference test script (export metadata: 157 lines, 4.6 KiB, Python).
import sys

# Must run before the `model.*` imports below can resolve.
sys.path.append("src")

import random
import re
from importlib.resources import files
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import torch
from datasets import load_dataset
from loguru import logger
from modelscope import AutoTokenizer
from pypinyin import lazy_pinyin
from pypinyin.contrib.tone_convert import to_initials
from torch.utils.data import DataLoader, IterableDataset
from tqdm import tqdm

from model.model import InputMethodEngine
from model.query import QueryEngine
# Load the pretrained tokenizer bundled with the model assets.
# NOTE: __file__ is already a str, so the redundant str() wrapper is dropped.
tokenizer = AutoTokenizer.from_pretrained(
    Path(__file__).parent / "src" / "model" / "assets" / "tokenizer"
)
# Matches runs of CJK unified ideographs (the basic Han block).
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")

# Pinyin character vocabulary; ID 0 is reserved for PAD/UNK.
CHAR_TO_ID: Dict[str, int] = {chr(i): i - 96 for i in range(97, 123)}  # a-z -> 1-26
CHAR_TO_ID["`"] = 27  # backtick
CHAR_TO_ID["'"] = 28  # apostrophe
CHAR_TO_ID["-"] = 29  # hyphen


def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
    """Convert a pinyin string into a list of integer IDs.

    Supported characters: a-z (1-26), backtick (27), apostrophe (28),
    and hyphen (29).  Any other character maps to 0 (PAD/UNK).
    """
    # dict.get(key, 0) maps unknown characters to the PAD/UNK id.
    return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]
# Context before the cursor and the raw pinyin being typed.
part1 = "他是一名大学生,在上海读"
part2 = "dayi"

# Encode the pinyin, then truncate and right-pad with zeros to exactly 24 ids.
pinyin_ids = text_to_pinyin_ids(part2)[:24]
pinyin_ids = pinyin_ids + [0] * (24 - len(pinyin_ids))
# Shape (1, 24): add the batch dimension expected by the model.
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
# History slot labels for this sample (0 marks an empty slot).
masked_labels = [15, 4, 0, 0, 0, 0, 0, 0]
part3 = "。"
part4 = "可行|特别|伤害"

# Encode "<candidates>|<prefix>" as segment A and the suffix as segment B,
# padded/truncated to a fixed length of 128 tokens.
encode_kwargs = dict(
    max_length=128,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
    return_token_type_ids=True,
)
encoded = tokenizer(f"{part4}|{part1}", part3, **encode_kwargs)
# Assemble a single-sample batch.  The encoded tensors already carry a
# leading batch dimension of 1 (single input, return_tensors="pt"), so the
# original squeeze(0) + stack([...]) round trip was a no-op and is dropped.
sample = {
    "input_ids": encoded["input_ids"],
    "token_type_ids": encoded["token_type_ids"],
    "attention_mask": encoded["attention_mask"],
    "history_slot_ids": torch.tensor(masked_labels, dtype=torch.long).unsqueeze(0),
    # NOTE(review): prefix joins with "^" while the encoded text joined with
    # "|" — confirm the separator mismatch is intentional.
    "prefix": f"{part4}^{part1}",
    "suffix": part3,
    "pinyin": part2,
    "pinyin_ids": pinyin_ids,
}
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)

# TODO(review): machine-specific absolute path — make this configurable
# (CLI argument / environment variable) before sharing the script.
CHECKPOINT_PATH = "/home/songsenand/下载/best_model.pt"

# NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
input_ids = sample["input_ids"]
token_type_ids = sample["token_type_ids"]
attention_mask = sample["attention_mask"]
pinyin_ids = sample["pinyin_ids"]
history_slot_ids = sample["history_slot_ids"]

# Echo the human-readable fields of the sample for inspection.
for k, v in sample.items():
    if isinstance(v, str):
        print(f"{k}: {v}")

# Forward pass; res[0] is indexed below, so res is presumably
# (batch, vocab) logits — TODO confirm against InputMethodEngine.forward.
res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)

# Rank vocabulary ids by score, highest first.  sorted() accepts any
# iterable, so the intermediate list comprehension was unnecessary.
sort_res = sorted(enumerate(res[0]), key=lambda x: x[1], reverse=True)
print(sort_res[:5])
import torch.nn.functional as F  # needed for the softmax analysis below

# Lookup service that maps vocabulary ids back to candidate characters.
query_engine = QueryEngine()
query_engine.load()
# Softmax over the vocabulary dimension turns logits into probabilities.
probs = F.softmax(res, dim=-1)

# Useless f-prefixes on the constant strings below were removed (F541);
# the printed bytes are unchanged.
print("\n📊 概率分布分析:")
print(f" 形状: {probs.shape}")
print(f" 总概率和: {probs.sum().item():.6f}")
print(f" 最大概率: {probs.max().item():.6f}")
print(f" 最小概率: {probs.min().item():.6f}")
print(f" 平均概率: {probs.mean().item():.6f}")

# Inspect the 20 most probable candidate characters.
top_probs, top_indices = torch.topk(probs, k=20)
print("\n🏆 Top-20预测:")
for i in range(20):
    idx = top_indices[0, i].item()
    prob = top_probs[0, i].item()
    print(
        f" {i + 1:2d}. ID:{idx}\t字符: {query_engine.query_by_id(idx).char}\t概率: {prob:.6f}"
    )
print("\n" + "=" * 60)
print("测试 history_slot_ids 全零情况")
print("=" * 60)

# Re-run the same sample with an all-zero history to measure its influence.
masked_labels = [0] * 8
history_slot_ids = torch.tensor(masked_labels, dtype=torch.long).unsqueeze(0)
# NOTE(review): this dict entry is never read again below (the model is
# called with the local variable) — confirm the update is still wanted.
sample["history_slot_ids"] = history_slot_ids

res_zero = model(
    input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids
)
probs_zero = F.softmax(res_zero, dim=-1)

# Constant strings lost their pointless f-prefixes (F541); output unchanged.
print("\n📊 概率分布分析 (全零历史):")
print(f" 形状: {probs_zero.shape}")
print(f" 总概率和: {probs_zero.sum().item():.6f}")
print(f" 最大概率: {probs_zero.max().item():.6f}")
print(f" 最小概率: {probs_zero.min().item():.6f}")
print(f" 平均概率: {probs_zero.mean().item():.6f}")

# Top-20 under the zero-history condition, for comparison with the first run.
top_probs_zero, top_indices_zero = torch.topk(probs_zero, k=20)
print("\n🏆 Top-20预测 (全零历史):")
for i in range(20):
    idx = top_indices_zero[0, i].item()
    prob = top_probs_zero[0, i].item()
    print(
        f" {i + 1:2d}. ID:{idx}\t字符: {query_engine.query_by_id(idx).char}\t概率: {prob:.6f}"
    )