SUimeModelTraner/test.py

85 lines
2.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import sys

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from model.dataset import PinyinInputDataset
from model.model import InputMethodEngine
from model.trainer import collate_fn, worker_init_fn
# Settings for this test script: `max_iter_length` is forwarded to the dataset
# and `batch_size` to the DataLoader built in `create_dataloader`.
max_iter_length = 5
batch_size = 1

# Corpus location: a relative "data" directory on Windows, otherwise a fixed
# absolute path.
dataset_path = (
    "data"
    if sys.platform == "win32"
    else "/home/songsenand/Data/corpus/CCI-Data/"
)
dataset = PinyinInputDataset(dataset_path, max_iter_length=max_iter_length)
def create_dataloader():
    """Build the DataLoader that feeds the module-level ``dataset``.

    Returns:
        A ``DataLoader`` over ``dataset`` that uses the module-level
        ``batch_size`` together with the project's ``collate_fn`` and
        ``worker_init_fn``.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        # One worker process; kept low for streaming dataset compatibility.
        num_workers=1,
        # Pin host memory only when a CUDA device is available.
        pin_memory=torch.cuda.is_available(),
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        # Small prefetch depth (previously 64) to avoid memory issues.
        prefetch_factor=2,
        persistent_workers=True,
    )
# Build the dataloader and drain it eagerly so every batch is held in memory.
# The progress bar wraps the actual loading, which is the slow step; tqdm
# falls back to an open-ended bar when the loader has no usable length.
dataloader = create_dataloader()
dataloader_list = list(tqdm(dataloader))
print(f"✅ Successfully loaded {len(dataloader_list)} batches")

# Independent list holding the same batches; later code indexes into it.
samples = list(dataloader_list)
# Restore the trained weights on CPU.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
checkpoint = torch.load("/home/songsenand/下载/best_model.pt", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
# Put train-only layers (e.g. dropout, batch norm) into inference behavior so
# the test output is deterministic.
model.eval()

# Fail with a clear message instead of a bare IndexError on an empty load.
if not samples:
    raise RuntimeError("No batches were loaded; cannot run the model test.")

# Only the first loaded batch is used for this check.
sample = samples[0]
input_ids = sample["input_ids"]
token_type_ids = sample["token_type_ids"]
attention_mask = sample["attention_mask"]
pinyin_ids = sample["pinyin_ids"]
history_slot_ids = sample["history_slot_ids"]

# Echo the plain-string fields of the batch for manual inspection.
for key, value in sample.items():
    if isinstance(value, str):
        print(f"{key}: {value}")

# No gradients are needed for a forward-only check; skip autograd tracking.
with torch.no_grad():
    res = model(
        input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids
    )

# Pair every score in the first row with a 1-based position and sort by score.
# NOTE(review): the sort is ascending, so this prints the five LOWEST-scoring
# entries — confirm whether the highest-scoring ones were intended.
sort_res = sorted(
    ((i + 1, v) for i, v in enumerate(res[0])),
    key=lambda pair: pair[1],
)
print(sort_res[0:5])
# Added after the `res` computation: inspect the predicted distribution.
# Convert the raw model outputs to probabilities along the last dimension.
probs = F.softmax(res, dim=-1)
print("\n📊 概率分布分析:")
print(f" 形状: {probs.shape}")
print(f" 总概率和: {probs.sum().item():.6f}")
print(f" 最大概率: {probs.max().item():.6f}")
print(f" 最小概率: {probs.min().item():.6f}")
print(f" 平均概率: {probs.mean().item():.6f}")

# Report the most probable candidates. k is clamped to the size of the last
# dimension because torch.topk raises when k exceeds it.
top_k = min(20, probs.shape[-1])
top_probs, top_indices = torch.topk(probs, k=top_k)
print(f"\n🏆 Top-{top_k}预测:")
# Only the first row is listed; like the original indexing, this assumes
# `probs` is 2-D (presumably batch x candidates — TODO confirm).
ranked = zip(top_indices[0].tolist(), top_probs[0].tolist())
for rank, (idx, prob) in enumerate(ranked, start=1):
    print(f" {rank:2d}. ID {idx:5d}: {prob:.6f}")