feat: 添加拼音转ID函数并优化结果排序逻辑

This commit is contained in:
songsenand 2026-04-09 17:28:52 +08:00
parent 6ee28e0aa5
commit 504353e895
1 changed files with 74 additions and 34 deletions

108
test.py
View File

@ -7,59 +7,99 @@ from tqdm import tqdm
from model.dataset import PinyinInputDataset
from model.model import InputMethodEngine
from model.trainer import collate_fn, worker_init_fn
from .query import QueryEngine
max_iter_length = 5
batch_size = 1
import random
import re
from importlib.resources import files
from pathlib import Path
from typing import Dict, List, Tuple
if sys.platform == "win32":
dataset_path = "data"
import numpy as np
import torch
from datasets import load_dataset
from loguru import logger
from modelscope import AutoTokenizer
from pypinyin import lazy_pinyin
from pypinyin.contrib.tone_convert import to_initials
from torch.utils.data import IterableDataset
tokenizer = AutoTokenizer.from_pretrained(
Path(str(__file__))) / 'src' / 'model' / "assets" / "tokenizer"
)
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
CHAR_TO_ID: Dict[str, int] = {chr(i): i - 96 for i in range(97, 123)} # a-z -> 1-26
CHAR_TO_ID["`"] = 27 # 显式添加反引号
CHAR_TO_ID["'"] = 28 # 显式添加引号
CHAR_TO_ID["-"] = 29 # 显式添加短横
def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
"""
将拼音字符串转换为 ID 列表
支持 a-z `
未知字符映射为 0 (PAD/UNK)
"""
# 使用 dict.get(key, default) 处理未知字符,默认返回 0
return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]
part1 = "他是一名大学生,在上海读"
part2 = "shu"
pinyin_ids = text_to_pinyin_ids(part2)
len_py = len(pinyin_ids)
if len_py < 24:
pinyin_ids.extend([0] * (24 - len_py))
else:
dataset_path = "/home/songsenand/Data/corpus/CCI-Data/"
pinyin_ids = pinyin_ids[:24]
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long)
part3 = ""
part4 = ""
encoded = tokenizer(
f"{part4}|{part1}",
part3,
max_length=128,
padding="max_length",
truncation=True,
return_tensors="pt",
return_token_type_ids=True,
)
sample = {
"input_ids": torch.stack([encoded["input_ids"].squeeze(0)]
"token_type_ids": torch.stack([encoded["token_type_ids"].squeeze(0)],
"attention_mask": torch.stack([encoded["attention_mask"].squeeze(0)],
"history_slot_ids": torch.tensor(
masked_labels, dtype=torch.long
),
"prefix": f"{part4}^{part1}",
"suffix": part3,
"pinyin": part2,
"pinyin_ids": pinyin_ids,
}
dataset = PinyinInputDataset(dataset_path, max_iter_length=max_iter_length)
def create_dataloader():
dataloader = DataLoader(
dataset,
batch_size=batch_size,
num_workers=1, # Limited to 2 for streaming dataset compatibility
pin_memory=torch.cuda.is_available(),
worker_init_fn=worker_init_fn,
collate_fn=collate_fn,
prefetch_factor=2, # Reduced from 64 to avoid memory issues
persistent_workers=True,
)
return dataloader
samples = []
# Create the dataloader
dataloader = create_dataloader()
# Convert to list to test loading (as in original code)
dataloader_list = list([i for i in dataloader])
print(f"✅ Successfully loaded {len(dataloader_list)} batches")
# Process batches
for i, line in tqdm(enumerate(dataloader_list), total=len(dataloader_list)):
samples.append(line)
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
checkpoint = torch.load("/home/songsenand/下载/best_model.pt", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
sample = samples[0]
input_ids = sample["input_ids"]
token_type_ids = sample["token_type_ids"]
attention_mask = sample["attention_mask"]
pinyin_ids = sample["pinyin_ids"]
history_slot_ids = sample["history_slot_ids"]
for k, v in sample.items():
if isinstance(v, str):
print(f"{k}: {v}")
res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
sort_res = sorted([(i + 1, v) for i, v in enumerate(res[0])], key=lambda x: x[1])
sort_res = sorted([(i + 1, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True)
print(sort_res[0:5])
# 在test.py的res计算后添加