"""Ad-hoc smoke-test script.

Loads a trained ``InputMethodEngine`` checkpoint, builds one pinyin-completion
sample by hand (context ``part1`` + pinyin prefix ``part2``), runs a single
forward pass, and prints the top predictions plus a softmax probability
breakdown.

NOTE(review): this file arrived with all statements collapsed onto two
physical lines and several unbalanced brackets; it has been reformatted and
the syntax errors fixed without changing intended behavior.
"""

import random
import re
import sys
from importlib.resources import files
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from datasets import load_dataset
from loguru import logger
from modelscope import AutoTokenizer
from pypinyin import lazy_pinyin
from pypinyin.contrib.tone_convert import to_initials
from torch.utils.data import DataLoader, IterableDataset
from tqdm import tqdm

from model.dataset import PinyinInputDataset
from model.model import InputMethodEngine
from model.trainer import collate_fn, worker_init_fn

from .query import QueryEngine

# NOTE(review): the original had a misplaced ')' that closed from_pretrained()
# after Path(...); the path segments were dangling outside the call. Also,
# Path(__file__) is the script file itself, so '.parent' is presumably needed
# to reach the sibling asset directory -- TODO confirm against the repo layout.
tokenizer = AutoTokenizer.from_pretrained(
    Path(str(__file__)).parent / "src" / "model" / "assets" / "tokenizer"
)

_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")

# a-z -> 1-26; a few extra symbols get explicit ids. 0 is reserved as PAD/UNK.
CHAR_TO_ID: Dict[str, int] = {chr(i): i - 96 for i in range(97, 123)}
CHAR_TO_ID["`"] = 27  # explicit backtick
CHAR_TO_ID["'"] = 28  # explicit apostrophe
CHAR_TO_ID["-"] = 29  # explicit hyphen


def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
    """Convert a pinyin string into a list of integer ids.

    Supports a-z plus the ` ' - symbols registered in ``CHAR_TO_ID``.
    Any unknown character maps to 0 (PAD/UNK).
    """
    return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]


part1 = "他是一名大学生,在上海读"
part2 = "shu"

# Encode the pinyin prefix and pad/truncate to a fixed width of 24 ids.
pinyin_ids = text_to_pinyin_ids(part2)
len_py = len(pinyin_ids)
if len_py < 24:
    pinyin_ids.extend([0] * (24 - len_py))
else:
    pinyin_ids = pinyin_ids[:24]
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long)

masked_labels = [0] * 8
part3 = "。"
part4 = ""

encoded = tokenizer(
    f"{part4}|{part1}",
    part3,
    max_length=128,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
    return_token_type_ids=True,
)

# NOTE(review): the original dict literal was missing the closing brackets and
# the commas between entries; restored here. 'prefix' uses '^' as separator
# while the encoded text above uses '|' -- looks inconsistent, verify intent.
sample = {
    "input_ids": torch.stack([encoded["input_ids"].squeeze(0)]),
    "token_type_ids": torch.stack([encoded["token_type_ids"].squeeze(0)]),
    "attention_mask": torch.stack([encoded["attention_mask"].squeeze(0)]),
    "history_slot_ids": torch.tensor(masked_labels, dtype=torch.long),
    "prefix": f"{part4}^{part1}",
    "suffix": part3,
    "pinyin": part2,
    "pinyin_ids": pinyin_ids,
}

# Restore the trained weights (checkpoint path is machine-specific).
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
checkpoint = torch.load("/home/songsenand/下载/best_model.pt", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])

input_ids = sample["input_ids"]
token_type_ids = sample["token_type_ids"]
attention_mask = sample["attention_mask"]
pinyin_ids = sample["pinyin_ids"]
history_slot_ids = sample["history_slot_ids"]

for k, v in sample.items():
    if isinstance(v, str):
        print(f"{k}: {v}")

res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)

# Rank candidate ids (1-based) by score, highest first, and show the top 5.
sort_res = sorted(
    [(i + 1, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True
)
print(sort_res[0:5])

# Softmax probability analysis of the raw scores.
probs = F.softmax(res, dim=-1)

print(f"\n📊 概率分布分析:")
print(f"  形状: {probs.shape}")
print(f"  总概率和: {probs.sum().item():.6f}")
print(f"  最大概率: {probs.max().item():.6f}")
print(f"  最小概率: {probs.min().item():.6f}")
print(f"  平均概率: {probs.mean().item():.6f}")

# Top-20 predictions with their probabilities.
top_probs, top_indices = torch.topk(probs, k=20)
print(f"\n🏆 Top-20预测:")
for i in range(20):
    idx = top_indices[0, i].item()
    prob = top_probs[0, i].item()
    print(f"  {i + 1:2d}. ID {idx:5d}: {prob:.6f}")