import sys

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from model.dataset import PinyinInputDataset
from model.model import InputMethodEngine
from model.trainer import collate_fn, worker_init_fn
max_iter_length = 5
|
||
batch_size = 1
|
||
|
||
if sys.platform == "win32":
|
||
dataset_path = "data"
|
||
else:
|
||
dataset_path = "/home/songsenand/Data/corpus/CCI-Data/"
|
||
|
||
dataset = PinyinInputDataset(dataset_path, max_iter_length=max_iter_length)
|
||
|
||
|
||
def create_dataloader():
|
||
dataloader = DataLoader(
|
||
dataset,
|
||
batch_size=batch_size,
|
||
num_workers=1, # Limited to 2 for streaming dataset compatibility
|
||
pin_memory=torch.cuda.is_available(),
|
||
worker_init_fn=worker_init_fn,
|
||
collate_fn=collate_fn,
|
||
prefetch_factor=2, # Reduced from 64 to avoid memory issues
|
||
persistent_workers=True,
|
||
)
|
||
return dataloader
|
||
|
||
|
||
samples = []
|
||
|
||
# Create the dataloader
|
||
dataloader = create_dataloader()
|
||
# Convert to list to test loading (as in original code)
|
||
dataloader_list = list([i for i in dataloader])
|
||
print(f"✅ Successfully loaded {len(dataloader_list)} batches")
|
||
|
||
# Process batches
|
||
for i, line in tqdm(enumerate(dataloader_list), total=len(dataloader_list)):
|
||
samples.append(line)
|
||
|
||
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
|
||
|
||
checkpoint = torch.load("/home/songsenand/下载/best_model.pt", map_location="cpu")
|
||
model.load_state_dict(checkpoint["model_state_dict"])
|
||
sample = samples[0]
|
||
input_ids = sample["input_ids"]
|
||
token_type_ids = sample["token_type_ids"]
|
||
attention_mask = sample["attention_mask"]
|
||
pinyin_ids = sample["pinyin_ids"]
|
||
history_slot_ids = sample["history_slot_ids"]
|
||
for k, v in sample.items():
|
||
if isinstance(v, str):
|
||
print(f"{k}: {v}")
|
||
res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
|
||
sort_res = sorted([(i + 1, v) for i, v in enumerate(res[0])], key=lambda x: x[1])
|
||
print(sort_res[0:5])
|
||
|
||
# 在test.py的res计算后添加:
|
||
import torch.nn.functional as F
|
||
|
||
# 计算softmax概率
|
||
probs = F.softmax(res, dim=-1)
|
||
|
||
print(f"\n📊 概率分布分析:")
|
||
print(f" 形状: {probs.shape}")
|
||
print(f" 总概率和: {probs.sum().item():.6f}")
|
||
print(f" 最大概率: {probs.max().item():.6f}")
|
||
print(f" 最小概率: {probs.min().item():.6f}")
|
||
print(f" 平均概率: {probs.mean().item():.6f}")
|
||
|
||
# 获取top-20概率
|
||
top_probs, top_indices = torch.topk(probs, k=20)
|
||
print(f"\n🏆 Top-20预测:")
|
||
for i in range(20):
|
||
idx = top_indices[0, i].item()
|
||
prob = top_probs[0, i].item()
|
||
print(f" {i + 1:2d}. ID {idx:5d}: {prob:.6f}")
|