# NOTE(review): removed duplicated file-listing metadata artifact
# ("125 lines / 3.6 KiB / Python") — it was not valid Python source.
# Standard library
import random
import re
import sys
from importlib.resources import files
from pathlib import Path
from typing import Dict, List, Tuple

# Third-party
import numpy as np
import torch
from datasets import load_dataset
from loguru import logger
from modelscope import AutoTokenizer
from pypinyin import lazy_pinyin
from pypinyin.contrib.tone_convert import to_initials
from torch.utils.data import DataLoader, IterableDataset
from tqdm import tqdm

# Project-local
from model.dataset import PinyinInputDataset
from model.model import InputMethodEngine
from model.trainer import collate_fn, worker_init_fn
# NOTE(review): a relative import only works when this file runs as part of a
# package (python -m ...), not as a standalone script — confirm intended usage.
from .query import QueryEngine
# Load the IME tokenizer from the packaged assets directory.
# NOTE(review): the original expression had unbalanced parentheses
# (`Path(str(__file__))) / 'src' / ...`), which is a SyntaxError.  Rebuilt as
# a path relative to this file's directory; confirm the intended base —
# `importlib.resources.files` is imported above and may have been the original
# mechanism for locating the assets.
tokenizer = AutoTokenizer.from_pretrained(
    Path(__file__).parent / "src" / "model" / "assets" / "tokenizer"
)

# Matches one or more CJK unified ideographs (common Han range U+4E00-U+9FFF).
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
# Pinyin-character vocabulary: 'a'..'z' map to 1..26, three extra input
# symbols occupy 27-29.  ID 0 is reserved for PAD/UNK.
CHAR_TO_ID: Dict[str, int] = {chr(ord("a") + offset): offset + 1 for offset in range(26)}
CHAR_TO_ID["`"] = 27  # backtick
CHAR_TO_ID["'"] = 28  # apostrophe
CHAR_TO_ID["-"] = 29  # hyphen


def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
    """Convert a pinyin string into its integer-ID sequence.

    Supports a-z plus the extra symbols registered in ``CHAR_TO_ID``; any
    character outside the vocabulary maps to 0 (PAD/UNK).
    """
    unknown_id = 0
    return [CHAR_TO_ID.get(ch, unknown_id) for ch in pinyin_str]
# Demo context: a Chinese prefix plus the raw pinyin currently being typed.
part1 = "他是一名大学生,在上海读"
part2 = "shu"

# Encode the pinyin, then pad with 0 (PAD) or truncate to exactly 24 slots
# before converting to a long tensor for the model.
pinyin_ids = (text_to_pinyin_ids(part2) + [0] * 24)[:24]
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long)
# Suffix after the cursor and the (empty) history string.
part3 = "。"
part4 = ""

# Tokenize "<history>|<prefix>" as segment A and the suffix as segment B,
# padded/truncated to a fixed 128-token window with segment ids returned.
tokenizer_options = dict(
    max_length=128,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
    return_token_type_ids=True,
)
encoded = tokenizer(f"{part4}|{part1}", part3, **tokenizer_options)
# Assemble a single inference sample.  torch.stack([t]) re-adds the batch
# dimension removed by squeeze(0), so each tensor entry ends up (1, seq_len).
#
# NOTE(review): the original dict literal had unbalanced brackets on the three
# stacked entries (SyntaxError); closed them here without changing intent.
sample = {
    "input_ids": torch.stack([encoded["input_ids"].squeeze(0)]),
    "token_type_ids": torch.stack([encoded["token_type_ids"].squeeze(0)]),
    "attention_mask": torch.stack([encoded["attention_mask"].squeeze(0)]),
    # NOTE(review): `masked_labels` is not defined anywhere in this file —
    # this raises NameError at runtime; confirm where it should come from.
    "history_slot_ids": torch.tensor(
        masked_labels, dtype=torch.long
    ),
    # NOTE(review): the prefix here is joined with '^' while the tokenizer
    # input above used '|' — confirm which separator is intended.
    "prefix": f"{part4}^{part1}",
    "suffix": part3,
    "pinyin": part2,
    "pinyin_ids": pinyin_ids,
}
# NOTE(review): `dataset_path` and `max_iter_length` are not defined in this
# file — NameError at runtime.  `dataset` is also never used below; confirm
# whether this line belongs in this script at all.
dataset = PinyinInputDataset(dataset_path, max_iter_length=max_iter_length)

# 30 pinyin-character ids: 0 (PAD/UNK) + 26 letters + 3 extra symbols.
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)

# Restore trained weights on CPU.
# NOTE(review): torch.load unpickles arbitrary objects — only load checkpoints
# from trusted sources (or pass weights_only=True if the checkpoint is a
# plain state-dict container).
checkpoint = torch.load("/home/songsenand/下载/best_model.pt", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # fix: disable dropout/batch-norm training behavior for inference
# Unpack the model inputs from the assembled sample.
input_ids = sample["input_ids"]
token_type_ids = sample["token_type_ids"]
attention_mask = sample["attention_mask"]
pinyin_ids = sample["pinyin_ids"]
history_slot_ids = sample["history_slot_ids"]

# Echo the human-readable (string) fields of the sample.
for k, v in sample.items():
    if isinstance(v, str):
        print(f"{k}: {v}")

# fix: run the forward pass without autograd — this is pure inference, and
# building the graph wastes time and memory.
with torch.no_grad():
    res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)

# Rank candidate ids by score, highest first.  Ids are emitted 1-based
# (i + 1) — TODO confirm this off-by-one matches the model's label space.
sort_res = sorted([(i + 1, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True)
print(sort_res[0:5])
# Appended after `res` is computed: inspect the model's output distribution.
import torch.nn.functional as F

# Normalize logits into probabilities over the last dimension.
probs = F.softmax(res, dim=-1)

print(f"\n📊 概率分布分析:")
print(f" 形状: {probs.shape}")
print(f" 总概率和: {probs.sum().item():.6f}")
print(f" 最大概率: {probs.max().item():.6f}")
print(f" 最小概率: {probs.min().item():.6f}")
print(f" 平均概率: {probs.mean().item():.6f}")

# Show the 20 highest-probability ids for the first (only) batch row.
top_probs, top_indices = torch.topk(probs, k=20)
print(f"\n🏆 Top-20预测:")
for rank, (idx_t, prob_t) in enumerate(zip(top_indices[0], top_probs[0]), start=1):
    print(f" {rank:2d}. ID {idx_t.item():5d}: {prob_t.item():.6f}")