"""Smoke test for the pinyin input-method model.

Loads every batch from ``PinyinInputDataset`` through a ``DataLoader``,
restores ``InputMethodEngine`` weights from a checkpoint, runs one forward
pass on the first batch and prints the highest-scoring outputs together with
summary statistics of the softmax distribution.
"""

import sys

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from model.dataset import PinyinInputDataset
from model.model import InputMethodEngine
from model.trainer import collate_fn, worker_init_fn

max_iter_length = 5
batch_size = 1

if sys.platform == "win32":
    dataset_path = "data"
else:
    dataset_path = "/home/songsenand/Data/corpus/CCI-Data/"

# NOTE(review): built at import time, so DataLoader workers started with the
# "spawn" method rebuild it when they re-import this module. Move it into
# main() if constructing the dataset turns out to be expensive.
dataset = PinyinInputDataset(dataset_path, max_iter_length=max_iter_length)


def create_dataloader():
    """Return a ``DataLoader`` over the module-level ``dataset``.

    Uses the module-level ``batch_size`` and the project's ``collate_fn`` /
    ``worker_init_fn``. Memory pinning is enabled only when CUDA is available.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=1,  # Single worker, kept low for streaming dataset compatibility
        pin_memory=torch.cuda.is_available(),
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        prefetch_factor=2,  # Reduced from 64 to avoid memory issues
        persistent_workers=True,
    )


def main():
    """Load all batches, run the model on the first one and print a report.

    Raises:
        SystemExit: if the dataloader yields no batches.
    """
    dataloader = create_dataloader()

    # Materialise every batch once to verify that loading works end to end;
    # the progress bar tracks the real loading work.
    samples = list(tqdm(dataloader))
    print(f"✅ Successfully loaded {len(samples)} batches")
    if not samples:
        raise SystemExit("No batches were loaded; nothing to run the model on.")

    model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load("/home/songsenand/下载/best_model.pt", map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    # Inference only: put dropout/normalisation layers into evaluation mode.
    model.eval()

    sample = samples[0]
    input_ids = sample["input_ids"]
    token_type_ids = sample["token_type_ids"]
    attention_mask = sample["attention_mask"]
    pinyin_ids = sample["pinyin_ids"]
    history_slot_ids = sample["history_slot_ids"]

    # Show the human-readable (string) fields of the batch for context.
    for key, value in sample.items():
        if isinstance(value, str):
            print(f"{key}: {value}")

    # No gradients are needed for a forward-only check.
    with torch.no_grad():
        res = model(
            input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids
        )

    # Rank the first row's scores from highest to lowest so that the slice
    # below shows the best five candidates (IDs are reported 1-based here).
    scores = res[0].tolist()
    sort_res = sorted(
        enumerate(scores, start=1), key=lambda pair: pair[1], reverse=True
    )
    print(sort_res[0:5])

    # Compute softmax probabilities over the last dimension.
    probs = F.softmax(res, dim=-1)
    print("\n📊 概率分布分析:")
    print(f" 形状: {probs.shape}")
    print(f" 总概率和: {probs.sum().item():.6f}")
    print(f" 最大概率: {probs.max().item():.6f}")
    print(f" 最小概率: {probs.min().item():.6f}")
    print(f" 平均概率: {probs.mean().item():.6f}")

    # Top-20 probabilities, clamped so a smaller output dimension cannot
    # make torch.topk fail.
    k = min(20, probs.size(-1))
    top_probs, top_indices = torch.topk(probs, k=k)
    print("\n🏆 Top-20预测:")
    ranked = zip(top_indices[0].tolist(), top_probs[0].tolist())
    for rank, (idx, prob) in enumerate(ranked, start=1):
        print(f" {rank:2d}. ID {idx:5d}: {prob:.6f}")


# Required for multi-process DataLoader workers on platforms that use the
# "spawn" start method (e.g. Windows): workers re-import this module and must
# not re-run the script body.
if __name__ == "__main__":
    main()