feat: 添加拼音转ID函数并优化结果排序逻辑
This commit is contained in:
parent
6ee28e0aa5
commit
504353e895
108
test.py
108
test.py
|
|
@ -7,59 +7,99 @@ from tqdm import tqdm
|
|||
from model.dataset import PinyinInputDataset
|
||||
from model.model import InputMethodEngine
|
||||
from model.trainer import collate_fn, worker_init_fn
|
||||
from .query import QueryEngine
|
||||
|
||||
max_iter_length = 5
|
||||
batch_size = 1
|
||||
import random
|
||||
import re
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
if sys.platform == "win32":
|
||||
dataset_path = "data"
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from loguru import logger
|
||||
from modelscope import AutoTokenizer
|
||||
from pypinyin import lazy_pinyin
|
||||
from pypinyin.contrib.tone_convert import to_initials
|
||||
from torch.utils.data import IterableDataset
|
||||
|
||||
# Load the tokenizer from the bundled assets directory.
# NOTE(review): the scraped source had unbalanced parentheses here
# (`from_pretrained(Path(str(__file__))) / 'src' / ...`), which would divide
# the tokenizer object by a string and leave a stray `)`. Reconstructed as
# "tokenizer assets directory relative to this file" — confirm against the
# actual repository layout.
tokenizer = AutoTokenizer.from_pretrained(
    Path(__file__).parent / "src" / "model" / "assets" / "tokenizer"
)
|
||||
|
||||
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
|
||||
|
||||
# Pinyin character vocabulary: 'a'..'z' map to 1..26, and three special
# separator characters get their own IDs. ID 0 is reserved for PAD/UNK.
CHAR_TO_ID: Dict[str, int] = {
    c: i for i, c in enumerate(map(chr, range(ord("a"), ord("z") + 1)), start=1)
}
CHAR_TO_ID["`"] = 27  # backtick separator
CHAR_TO_ID["'"] = 28  # apostrophe separator
CHAR_TO_ID["-"] = 29  # hyphen separator


def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
    """Convert a pinyin string into a list of character IDs.

    Supports a-z plus the special characters ` ' and -.
    Any other character maps to 0 (PAD/UNK).
    """
    # dict.get(key, 0) handles unknown characters without raising.
    return [CHAR_TO_ID.get(ch, 0) for ch in pinyin_str]
|
||||
|
||||
|
||||
# Build one demo sample: hanzi context preceding the cursor plus the raw
# pinyin to decode.
part1 = "他是一名大学生,在上海读"
part2 = "shu"

# Encode the pinyin and pad/truncate to a fixed length of 24 IDs.
# NOTE(review): the scraped diff interleaved an unrelated
# `dataset_path = "/home/..."` assignment inside the else-branch; dropped
# here as diff residue — confirm against the original file.
pinyin_ids = text_to_pinyin_ids(part2)
len_py = len(pinyin_ids)
if len_py < 24:
    pinyin_ids.extend([0] * (24 - len_py))  # right-pad with 0 (PAD/UNK)
else:
    pinyin_ids = pinyin_ids[:24]  # truncate over-long input
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long)
|
||||
|
||||
part3 = "。"  # text after the cursor (segment B)
part4 = ""   # history text, joined to the prefix with "|"

# Tokenize as a sentence pair, padded/truncated to a fixed length of 128.
encoded = tokenizer(
    f"{part4}|{part1}",  # segment A: history "|" prefix
    part3,               # segment B: suffix
    max_length=128,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
    return_token_type_ids=True,
)
|
||||
|
||||
# Assemble a single-sample batch in the shape collate_fn would produce
# (stack of one, so every tensor keeps a leading batch dimension).
# NOTE(review): the scraped source was missing the closing `),` on the three
# torch.stack entries below — restored here.
sample = {
    "input_ids": torch.stack([encoded["input_ids"].squeeze(0)]),
    "token_type_ids": torch.stack([encoded["token_type_ids"].squeeze(0)]),
    "attention_mask": torch.stack([encoded["attention_mask"].squeeze(0)]),
    # NOTE(review): `masked_labels` is not defined anywhere in the visible
    # file — this will raise NameError at runtime; confirm where it is
    # supposed to come from before running.
    "history_slot_ids": torch.tensor(
        masked_labels, dtype=torch.long
    ),
    "prefix": f"{part4}^{part1}",
    "suffix": part3,
    "pinyin": part2,
    "pinyin_ids": pinyin_ids,
}

# Streaming dataset over the corpus directory; max_iter_length bounds
# iteration (module-level constant, =5).
dataset = PinyinInputDataset(dataset_path, max_iter_length=max_iter_length)
|
||||
|
||||
|
||||
def create_dataloader():
    """Wrap the module-level streaming `dataset` in a DataLoader.

    Reads the module-level `dataset`, `batch_size`, `collate_fn` and
    `worker_init_fn`; returns a ready-to-iterate DataLoader.
    """
    loader_kwargs = {
        "batch_size": batch_size,
        # Keep the worker count low for streaming-dataset compatibility.
        "num_workers": 1,
        "pin_memory": torch.cuda.is_available(),
        "worker_init_fn": worker_init_fn,
        "collate_fn": collate_fn,
        "prefetch_factor": 2,  # kept small to avoid memory pressure
        "persistent_workers": True,
    }
    return DataLoader(dataset, **loader_kwargs)
|
||||
|
||||
|
||||
samples = []

# Create the dataloader and materialize every batch up front to verify that
# the whole streaming pipeline loads cleanly.
dataloader = create_dataloader()
# Was `list([i for i in dataloader])` — redundant double wrap; `list()` on
# the iterable is equivalent and builds no throwaway comprehension.
dataloader_list = list(dataloader)
print(f"✅ Successfully loaded {len(dataloader_list)} batches")

# Collect the batches (tqdm is only for progress display).
for i, line in tqdm(enumerate(dataloader_list), total=len(dataloader_list)):
    samples.append(line)
|
||||
|
||||
# Build the engine and restore the best checkpoint on CPU.
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)

checkpoint = torch.load("/home/songsenand/下载/best_model.pt", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])

# Unpack the first collected batch into the model's positional inputs.
sample = samples[0]
input_ids = sample["input_ids"]
token_type_ids = sample["token_type_ids"]
attention_mask = sample["attention_mask"]
pinyin_ids = sample["pinyin_ids"]
history_slot_ids = sample["history_slot_ids"]

# Show the human-readable fields of the sample for manual inspection.
for k, v in sample.items():
    if isinstance(v, str):
        print(f"{k}: {v}")

res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
# Rank candidates by score, best first, as (1-based index, score) pairs.
# The scraped diff carried both the old ascending sort and the new
# `reverse=True` sort; the ascending result was dead (immediately
# overwritten), so only the descending version is kept.
sort_res = sorted(
    [(i + 1, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True
)
print(sort_res[0:5])
|
||||
|
||||
# 在test.py的res计算后添加:
|
||||
|
|
|
|||
Loading…
Reference in New Issue