feat: 添加拼音转ID函数并优化结果排序逻辑
This commit is contained in:
parent
6ee28e0aa5
commit
504353e895
108
test.py
108
test.py
|
|
@ -7,59 +7,99 @@ from tqdm import tqdm
|
||||||
from model.dataset import PinyinInputDataset
|
from model.dataset import PinyinInputDataset
|
||||||
from model.model import InputMethodEngine
|
from model.model import InputMethodEngine
|
||||||
from model.trainer import collate_fn, worker_init_fn
|
from model.trainer import collate_fn, worker_init_fn
|
||||||
|
from .query import QueryEngine
|
||||||
|
|
||||||
max_iter_length = 5
|
import random
|
||||||
batch_size = 1
|
import re
|
||||||
|
from importlib.resources import files
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
if sys.platform == "win32":
|
import numpy as np
|
||||||
dataset_path = "data"
|
import torch
|
||||||
|
from datasets import load_dataset
|
||||||
|
from loguru import logger
|
||||||
|
from modelscope import AutoTokenizer
|
||||||
|
from pypinyin import lazy_pinyin
|
||||||
|
from pypinyin.contrib.tone_convert import to_initials
|
||||||
|
from torch.utils.data import IterableDataset
|
||||||
|
|
||||||
|
# Load the tokenizer from the assets directory bundled with the source tree.
# NOTE(review): the previous expression had an unbalanced ")" and appended the
# path segments to the script file itself. Anchoring at the script's directory
# is presumably what was intended -- confirm the asset location.
tokenizer = AutoTokenizer.from_pretrained(
    str(Path(__file__).resolve().parent / "src" / "model" / "assets" / "tokenizer")
)
|
||||||
|
|
||||||
|
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
|
||||||
|
|
||||||
|
# Character -> ID table for pinyin input:
#   a-z -> 1..26, backtick -> 27, apostrophe -> 28, hyphen -> 29.
# ID 0 is intentionally left unassigned; it is what unknown characters map to.
CHAR_TO_ID: Dict[str, int] = {
    ch: idx for idx, ch in enumerate("abcdefghijklmnopqrstuvwxyz`'-", start=1)
}


def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
    """Convert a pinyin string into a list of integer IDs, one per character.

    Supported characters are lowercase ``a``-``z`` (1-26), the backtick (27),
    the apostrophe (28) and the hyphen (29). Every other character, including
    uppercase letters, maps to 0 (PAD/UNK).

    Args:
        pinyin_str: The raw pinyin text to encode.

    Returns:
        A list with the same length as ``pinyin_str`` holding the ID of each
        character.
    """
    # dict.get with a default of 0 covers characters missing from the table.
    return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]
|
||||||
|
|
||||||
|
|
||||||
|
part1 = "他是一名大学生,在上海读"
|
||||||
|
part2 = "shu"
|
||||||
|
pinyin_ids = text_to_pinyin_ids(part2)
|
||||||
|
len_py = len(pinyin_ids)
|
||||||
|
if len_py < 24:
|
||||||
|
pinyin_ids.extend([0] * (24 - len_py))
|
||||||
else:
|
else:
|
||||||
dataset_path = "/home/songsenand/Data/corpus/CCI-Data/"
|
pinyin_ids = pinyin_ids[:24]
|
||||||
|
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long)
|
||||||
|
|
||||||
|
part3 = "。"
|
||||||
|
part4 = ""
|
||||||
|
|
||||||
|
encoded = tokenizer(
|
||||||
|
f"{part4}|{part1}",
|
||||||
|
part3,
|
||||||
|
max_length=128,
|
||||||
|
padding="max_length",
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
return_token_type_ids=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assemble a single hand-built sample for the model call.
# The three encoder tensors are squeezed and re-stacked, which leaves one
# leading dimension of size 1 in front of each squeezed tensor.
sample = {
    "input_ids": torch.stack([encoded["input_ids"].squeeze(0)]),
    "token_type_ids": torch.stack([encoded["token_type_ids"].squeeze(0)]),
    "attention_mask": torch.stack([encoded["attention_mask"].squeeze(0)]),
    # NOTE(review): masked_labels is not defined in the visible part of this
    # script -- confirm where it is supposed to come from.
    "history_slot_ids": torch.tensor(masked_labels, dtype=torch.long),
    # NOTE(review): the text passed to the tokenizer joins the same two parts
    # with "|", while this display string uses "^" -- confirm this is intended.
    "prefix": f"{part4}^{part1}",
    "suffix": part3,
    "pinyin": part2,
    "pinyin_ids": pinyin_ids,
}
|
||||||
|
|
||||||
dataset = PinyinInputDataset(dataset_path, max_iter_length=max_iter_length)
|
dataset = PinyinInputDataset(dataset_path, max_iter_length=max_iter_length)
|
||||||
|
|
||||||
|
|
||||||
def create_dataloader():
|
|
||||||
dataloader = DataLoader(
|
|
||||||
dataset,
|
|
||||||
batch_size=batch_size,
|
|
||||||
num_workers=1, # Limited to 2 for streaming dataset compatibility
|
|
||||||
pin_memory=torch.cuda.is_available(),
|
|
||||||
worker_init_fn=worker_init_fn,
|
|
||||||
collate_fn=collate_fn,
|
|
||||||
prefetch_factor=2, # Reduced from 64 to avoid memory issues
|
|
||||||
persistent_workers=True,
|
|
||||||
)
|
|
||||||
return dataloader
|
|
||||||
|
|
||||||
|
|
||||||
samples = []
|
|
||||||
|
|
||||||
# Create the dataloader
|
|
||||||
dataloader = create_dataloader()
|
|
||||||
# Convert to list to test loading (as in original code)
|
|
||||||
dataloader_list = list([i for i in dataloader])
|
|
||||||
print(f"✅ Successfully loaded {len(dataloader_list)} batches")
|
|
||||||
|
|
||||||
# Process batches
|
|
||||||
for i, line in tqdm(enumerate(dataloader_list), total=len(dataloader_list)):
|
|
||||||
samples.append(line)
|
|
||||||
|
|
||||||
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
|
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
|
||||||
|
|
||||||
checkpoint = torch.load("/home/songsenand/下载/best_model.pt", map_location="cpu")
|
checkpoint = torch.load("/home/songsenand/下载/best_model.pt", map_location="cpu")
|
||||||
model.load_state_dict(checkpoint["model_state_dict"])
|
model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
sample = samples[0]
|
|
||||||
input_ids = sample["input_ids"]
|
input_ids = sample["input_ids"]
|
||||||
token_type_ids = sample["token_type_ids"]
|
token_type_ids = sample["token_type_ids"]
|
||||||
attention_mask = sample["attention_mask"]
|
attention_mask = sample["attention_mask"]
|
||||||
pinyin_ids = sample["pinyin_ids"]
|
pinyin_ids = sample["pinyin_ids"]
|
||||||
history_slot_ids = sample["history_slot_ids"]
|
history_slot_ids = sample["history_slot_ids"]
|
||||||
|
|
||||||
for k, v in sample.items():
|
for k, v in sample.items():
|
||||||
if isinstance(v, str):
|
if isinstance(v, str):
|
||||||
print(f"{k}: {v}")
|
print(f"{k}: {v}")
|
||||||
|
|
||||||
res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
|
res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
|
||||||
sort_res = sorted([(i + 1, v) for i, v in enumerate(res[0])], key=lambda x: x[1])
|
sort_res = sorted([(i + 1, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True)
|
||||||
print(sort_res[0:5])
|
print(sort_res[0:5])
|
||||||
|
|
||||||
# 在test.py的res计算后添加:
|
# 在test.py的res计算后添加:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue