# SUimeModelTraner/test.py
#
# Header below was captured from a web code viewer and is kept as comments so
# the file remains valid Python:
#   125 lines · 3.6 KiB · Python · Raw Blame History
#
# NOTE(review): the original page warned that this file contains ambiguous
# Unicode characters (characters that might be confused with others). Verify
# that any non-ASCII characters in string literals are intentional.

import sys
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from model.dataset import PinyinInputDataset
from model.model import InputMethodEngine
from model.trainer import collate_fn, worker_init_fn
from .query import QueryEngine
import random
import re
from importlib.resources import files
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np
import torch
from datasets import load_dataset
from loguru import logger
from modelscope import AutoTokenizer
from pypinyin import lazy_pinyin
from pypinyin.contrib.tone_convert import to_initials
from torch.utils.data import IterableDataset
# Load the tokenizer shipped with the package assets.
# FIX: the original had mismatched parentheses — `from_pretrained(` was closed
# right after `Path(str(__file__))`, leaving a dangling `/ 'src' / ...`
# expression and a stray `)` (SyntaxError). The path is now assembled fully
# before the call. Path(__file__) already yields a Path, so str() is redundant.
# NOTE(review): joining onto the file itself would point at "test.py/src/...",
# so `.parent` is assumed to be the intent — confirm against the repo layout.
tokenizer = AutoTokenizer.from_pretrained(
    Path(__file__).parent / "src" / "model" / "assets" / "tokenizer"
)
# Regex matching runs of CJK unified ideographs (hanzi).
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")

# Pinyin character vocabulary: 'a'..'z' -> 1..26, then three extra symbols.
# Any character not in this table maps to 0 (PAD/UNK) at lookup time.
CHAR_TO_ID: Dict[str, int] = {
    chr(code): code - ord("a") + 1 for code in range(ord("a"), ord("z") + 1)
}
CHAR_TO_ID["`"] = 27  # backtick
CHAR_TO_ID["'"] = 28  # apostrophe
CHAR_TO_ID["-"] = 29  # hyphen


def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
    """Convert a pinyin string into a list of character IDs.

    Supports a-z plus ` ' - ; every unknown character becomes 0 (PAD/UNK).
    """
    ids: List[int] = []
    for ch in pinyin_str:
        # dict.get with a default of 0 covers the unknown-character case.
        ids.append(CHAR_TO_ID.get(ch, 0))
    return ids
# ---- Build one inference sample -------------------------------------------
# part1: left-hand hanzi context, part2: the pinyin being typed,
# part3/part4: suffix / history slots (empty in this test).
part1 = "他是一名大学生,在上海读"
part2 = "shu"

# Encode the pinyin and pad (with 0) or truncate to the fixed length of 24.
pinyin_ids = text_to_pinyin_ids(part2)
len_py = len(pinyin_ids)
if len_py < 24:
    pinyin_ids.extend([0] * (24 - len_py))
else:
    pinyin_ids = pinyin_ids[:24]
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long)

part3 = ""
part4 = ""
encoded = tokenizer(
    f"{part4}|{part1}",
    part3,
    max_length=128,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
    return_token_type_ids=True,
)

# FIX: `masked_labels` was referenced but never defined (NameError). An
# all-zero slot vector matching the encoded sequence length is used here.
# NOTE(review): confirm the expected history-slot encoding against the trainer.
masked_labels = [0] * encoded["input_ids"].shape[1]

# FIX: the torch.stack(...) calls had unbalanced brackets (SyntaxError).
# torch.stack([t.squeeze(0)]) is also a no-op for a batch of one, so the
# encoded tensors (already shaped [1, 128]) are used directly.
sample = {
    "input_ids": encoded["input_ids"],
    "token_type_ids": encoded["token_type_ids"],
    "attention_mask": encoded["attention_mask"],
    "history_slot_ids": torch.tensor(
        masked_labels, dtype=torch.long
    ),
    # NOTE(review): '^' separator here vs '|' in the encoded text above —
    # confirm which one the model was trained with.
    "prefix": f"{part4}^{part1}",
    "suffix": part3,
    "pinyin": part2,
    "pinyin_ids": pinyin_ids,
}
# FIX: `dataset_path` / `max_iter_length` were never defined in this script,
# so this line raised a NameError — and `dataset` was never used afterwards.
# Kept for reference:
# dataset = PinyinInputDataset(dataset_path, max_iter_length=max_iter_length)

# Build the model and restore trained weights onto CPU.
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
checkpoint = torch.load("/home/songsenand/下载/best_model.pt", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # inference only: disable dropout / switch norm layers to eval

input_ids = sample["input_ids"]
token_type_ids = sample["token_type_ids"]
attention_mask = sample["attention_mask"]
pinyin_ids = sample["pinyin_ids"]
history_slot_ids = sample["history_slot_ids"]

# Echo the human-readable fields of the sample for debugging.
for k, v in sample.items():
    if isinstance(v, str):
        print(f"{k}: {v}")

# Forward pass; no_grad avoids building an autograd graph during inference.
with torch.no_grad():
    res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)

# Rank scores in descending order. NOTE(review): the (i + 1) offset presumably
# aligns enumeration indices with token IDs — confirm the model's output indexing.
sort_res = sorted([(i + 1, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True)
print(sort_res[0:5])
# Diagnostics over the raw model scores: convert to a softmax probability
# distribution, print summary statistics, then list the 20 likeliest IDs.
import torch.nn.functional as F

# torch.softmax is equivalent to F.softmax here.
probs = torch.softmax(res, dim=-1)

print(f"\n📊 概率分布分析:")
print(f" 形状: {probs.shape}")
print(f" 总概率和: {probs.sum().item():.6f}")
print(f" 最大概率: {probs.max().item():.6f}")
print(f" 最小概率: {probs.min().item():.6f}")
print(f" 平均概率: {probs.mean().item():.6f}")

# Top-20 entries by probability (indices are vocabulary IDs).
top_probs, top_indices = torch.topk(probs, k=20)
print(f"\n🏆 Top-20预测:")
i = 0
while i < 20:
    idx = top_indices[0, i].item()
    prob = top_probs[0, i].item()
    print(f" {i + 1:2d}. ID {idx:5d}: {prob:.6f}")
    i += 1