From 504353e89544933ce961b877f8daa2643633f8c2 Mon Sep 17 00:00:00 2001
From: songsenand
Date: Thu, 9 Apr 2026 17:28:52 +0800
Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E6=8B=BC=E9=9F=B3?=
 =?UTF-8?q?=E8=BD=ACID=E5=87=BD=E6=95=B0=E5=B9=B6=E4=BC=98=E5=8C=96?=
 =?UTF-8?q?=E7=BB=93=E6=9E=9C=E6=8E=92=E5=BA=8F=E9=80=BB=E8=BE=91?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 test.py | 108 ++++++++++++++++++++++++++++++++++++++------------------
 1 file changed, 74 insertions(+), 34 deletions(-)

diff --git a/test.py b/test.py
index 4e391e7..b71e675 100644
--- a/test.py
+++ b/test.py
@@ -7,59 +7,99 @@ from tqdm import tqdm
 from model.dataset import PinyinInputDataset
 from model.model import InputMethodEngine
 from model.trainer import collate_fn, worker_init_fn
+from .query import QueryEngine
 
-max_iter_length = 5
-batch_size = 1
+import random
+import re
+from importlib.resources import files
+from pathlib import Path
+from typing import Dict, List, Tuple
 
-if sys.platform == "win32":
-    dataset_path = "data"
+import numpy as np
+import torch
+from datasets import load_dataset
+from loguru import logger
+from modelscope import AutoTokenizer
+from pypinyin import lazy_pinyin
+from pypinyin.contrib.tone_convert import to_initials
+from torch.utils.data import IterableDataset
+
+tokenizer = AutoTokenizer.from_pretrained(
+    Path(str(__file__)).parent / 'src' / 'model' / "assets" / "tokenizer"
+)
+
+_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
+
+CHAR_TO_ID: Dict[str, int] = {chr(i): i - 96 for i in range(97, 123)}  # a-z -> 1-26
+CHAR_TO_ID["`"] = 27  # 显式添加反引号
+CHAR_TO_ID["'"] = 28  # 显式添加引号
+CHAR_TO_ID["-"] = 29  # 显式添加短横
+
+
+def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
+    """
+    将拼音字符串转换为 ID 列表。
+    支持 a-z 和 `。
+    未知字符映射为 0 (PAD/UNK)。
+    """
+    # 使用 dict.get(key, default) 处理未知字符,默认返回 0
+    return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]
+
+
+part1 = "他是一名大学生,在上海读"
+part2 = "shu"
+pinyin_ids = text_to_pinyin_ids(part2)
+len_py = len(pinyin_ids)
+if len_py < 24:
+    pinyin_ids.extend([0] * (24 - len_py))
 else:
-    dataset_path = "/home/songsenand/Data/corpus/CCI-Data/"
+    pinyin_ids = pinyin_ids[:24]
+pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long)
+
+part3 = "。"
+part4 = ""
+
+encoded = tokenizer(
+    f"{part4}|{part1}",
+    part3,
+    max_length=128,
+    padding="max_length",
+    truncation=True,
+    return_tensors="pt",
+    return_token_type_ids=True,
+)
+
+sample = {
+    "input_ids": torch.stack([encoded["input_ids"].squeeze(0)]),
+    "token_type_ids": torch.stack([encoded["token_type_ids"].squeeze(0)]),
+    "attention_mask": torch.stack([encoded["attention_mask"].squeeze(0)]),
+    # NOTE(review): masked_labels is not defined anywhere in this patch — confirm
+    "history_slot_ids": torch.tensor(
+        masked_labels, dtype=torch.long
+    ),
+    # NOTE(review): separator '^' differs from the '|' used in the tokenizer input above — confirm
+    "prefix": f"{part4}^{part1}",
+    "suffix": part3,
+    "pinyin": part2,
+    "pinyin_ids": pinyin_ids,
+}
 
+# NOTE(review): dataset_path and max_iter_length are deleted above, so the next
+# (unchanged) line will raise NameError once this patch is applied — confirm intent
 dataset = PinyinInputDataset(dataset_path, max_iter_length=max_iter_length)
-
-def create_dataloader():
-    dataloader = DataLoader(
-        dataset,
-        batch_size=batch_size,
-        num_workers=1,  # Limited to 2 for streaming dataset compatibility
-        pin_memory=torch.cuda.is_available(),
-        worker_init_fn=worker_init_fn,
-        collate_fn=collate_fn,
-        prefetch_factor=2,  # Reduced from 64 to avoid memory issues
-        persistent_workers=True,
-    )
-    return dataloader
-
-
-samples = []
-
-# Create the dataloader
-dataloader = create_dataloader()
-# Convert to list to test loading (as in original code)
-dataloader_list = list([i for i in dataloader])
-print(f"✅ Successfully loaded {len(dataloader_list)} batches")
-
-# Process batches
-for i, line in tqdm(enumerate(dataloader_list), total=len(dataloader_list)):
-    samples.append(line)
-
 model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
 checkpoint = torch.load("/home/songsenand/下载/best_model.pt", map_location="cpu")
 model.load_state_dict(checkpoint["model_state_dict"])
-sample = samples[0]
 input_ids = sample["input_ids"]
 token_type_ids = sample["token_type_ids"]
 attention_mask = sample["attention_mask"]
 pinyin_ids = sample["pinyin_ids"]
 history_slot_ids = sample["history_slot_ids"]
+
 for k, v in sample.items():
     if isinstance(v, str):
         print(f"{k}: {v}")
+
 res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
-sort_res = sorted([(i + 1, v) for i, v in enumerate(res[0])], key=lambda x: x[1])
+sort_res = sorted([(i + 1, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True)
 print(sort_res[0:5])
 # 在test.py的res计算后添加: