"""Smoke-test script (part 1): build one tokenized sample — context text,
pinyin ids, and history-slot labels — for the InputMethodEngine model."""

import sys

# Make the project package importable when running from the repo root.
sys.path.append("src")

import random
import re
import time
from importlib.resources import files
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import torch
from datasets import load_dataset
from loguru import logger
from modelscope import AutoTokenizer
from pypinyin import lazy_pinyin
from pypinyin.contrib.tone_convert import to_initials
from torch.utils.data import DataLoader, IterableDataset
from tqdm import tqdm

from model.model import InputMethodEngine
from model.query import QueryEngine

# NOTE(review): many imports above (random, files, Tuple, np, load_dataset,
# logger, lazy_pinyin, to_initials, DataLoader, IterableDataset, tqdm) are
# unused in this script — kept as-is, but they are candidates for removal.

# Tokenizer assets ship inside the package tree next to this script.
# __file__ is already a str, so the original Path(str(__file__)) was redundant.
tokenizer = AutoTokenizer.from_pretrained(
    Path(__file__).parent / "src" / "model" / "assets" / "tokenizer"
)

_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")

# Pinyin character vocabulary: a-z -> 1..26, plus three explicit symbols.
CHAR_TO_ID: Dict[str, int] = {chr(i): i - 96 for i in range(97, 123)}  # a-z -> 1-26
CHAR_TO_ID["`"] = 27  # explicit backtick
CHAR_TO_ID["'"] = 28  # explicit apostrophe
CHAR_TO_ID["-"] = 29  # explicit hyphen

# Fixed length the model expects for the pinyin id sequence
# (was a hard-coded 24 at the use site).
MAX_PINYIN_LEN = 24


def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
    """Convert a pinyin string to a list of ids.

    Supports a-z, `, ', and -. Any unknown character maps to 0 (PAD/UNK)
    via dict.get's default.
    """
    return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]


part1 = "他是一名大学生,在上海读"
part2 = "dayi"

# Truncate then right-pad with 0 to the fixed model length; add a batch dim.
pinyin_ids = text_to_pinyin_ids(part2)[:MAX_PINYIN_LEN]
pinyin_ids.extend([0] * (MAX_PINYIN_LEN - len(pinyin_ids)))
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)

# History slot labels for this sample, zero-padded to 8 slots.
# NOTE(review): the slot-id semantics (15, 4) are not visible here — confirm
# against the training data pipeline.
masked_labels = [15, 4, 0, 0, 0, 0, 0, 0]

part3 = "。"
part4 = "可行|特别|伤害"

encoded = tokenizer(
    f"{part4}|{part1}",
    part3,
    max_length=128,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
    return_token_type_ids=True,
)

# Batch of one: squeeze the tokenizer's batch dim, then stack back to (1, seq).
sample = {
    "input_ids": torch.stack([encoded["input_ids"].squeeze(0)]),
    "token_type_ids": torch.stack([encoded["token_type_ids"].squeeze(0)]),
    "attention_mask": torch.stack([encoded["attention_mask"].squeeze(0)]),
    "history_slot_ids": torch.tensor(masked_labels, dtype=torch.long).unsqueeze(0),
    "prefix": f"{part4}^{part1}",
    "suffix": part3,
    "pinyin": part2,
    "pinyin_ids": pinyin_ids,
}
# ---- Smoke-test part 2: load checkpoint, run one forward pass, inspect ----

model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
# NOTE(review): hard-coded user-local checkpoint path — parameterize (argv /
# env var) before sharing this script.
checkpoint = torch.load("/home/songsenand/下载/best_model.pt", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
# Inference mode: disable dropout and use running batch-norm statistics.
# (Original ran the forward pass with the module still in train mode.)
model.eval()

input_ids = sample["input_ids"]
token_type_ids = sample["token_type_ids"]
attention_mask = sample["attention_mask"]
pinyin_ids = sample["pinyin_ids"]
history_slot_ids = sample["history_slot_ids"]

# Echo the human-readable fields of the sample.
for k, v in sample.items():
    if isinstance(v, str):
        print(f"{k}: {v}")

start = time.time()
# No gradients needed for a smoke test — avoid building the autograd graph.
with torch.no_grad():
    res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
# Fixed format spec: original ':4f' meant "min width 4, default 6 decimals";
# '.4f' (4 decimal places) was intended.
print(f'计算时长: {(time.time() - start) * 1000:.4f}ms')

# Rank vocabulary ids of the first batch element by raw score, descending.
sort_res = sorted(
    [(i, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True
)
print(sort_res[0:5])

query_engine = QueryEngine()
query_engine.load()

import torch.nn.functional as F

# Softmax over the last (vocabulary) dim -> proper probability distribution.
probs = F.softmax(res, dim=-1)
print(f"\n📊 概率分布分析:")
print(f" 形状: {probs.shape}")
print(f" 总概率和: {probs.sum().item():.6f}")
print(f" 最大概率: {probs.max().item():.6f}")
print(f" 最小概率: {probs.min().item():.6f}")
print(f" 平均概率: {probs.mean().item():.6f}")

# Top-20 predictions with their characters resolved through the query engine.
top_probs, top_indices = torch.topk(probs, k=20)
print(f"\n🏆 Top-20预测:")
for i in range(20):
    idx = top_indices[0, i].item()
    prob = top_probs[0, i].item()
    print(
        f" {i + 1:2d}. ID:{idx}\t字符: {query_engine.query_by_id(idx).char}\t概率: {prob:.6f}"
    )