# SUInput/src/tmp_utils/gen_eval_dataset.py
# (Python script, ~55 lines / 1.6 KiB — header restored from file-viewer metadata.)
import pickle
from pathlib import Path
from loguru import logger
from torch.utils.data import DataLoader
from tqdm import tqdm
from suinput.dataset import PinyinInputDataset, custom_collate_with_txt, worker_init_fn
from suinput.query import QueryEngine
# Example usage: dump a handful of batches from the pinyin dataset
# to pickle files for offline evaluation.
if __name__ == "__main__":
    # Initialize the query engine the dataset uses for lookups.
    query_engine = QueryEngine()
    query_engine.load()

    # Build the dataset; shuffling is performed inside the dataset itself.
    dataset = PinyinInputDataset(
        data_dir="/home/songsenand/DataSet/data",
        query_engine=query_engine,
        tokenizer_name="iic/nlp_structbert_backbone_lite_std",
        max_len=88,
        batch_query_size=300,
        shuffle=True,
        shuffle_buffer_size=4000,
        drop_py_rate=0,
    )
    logger.info("数据集初始化")

    dataloader = DataLoader(
        dataset,
        batch_size=1024,
        num_workers=1,
        worker_init_fn=worker_init_fn,
        # pin_memory=True if torch.cuda.is_available() else False,
        collate_fn=custom_collate_with_txt,
        prefetch_factor=8,
        persistent_workers=True,
        shuffle=False,  # shuffling is already handled inside the dataset
    )

    # Destination directory for the pickled batches; create it up front so
    # the dumps below cannot fail with FileNotFoundError on a fresh checkout.
    out_dir = Path(__file__).parent.parent / "trainer" / "eval_dataset"
    out_dir.mkdir(parents=True, exist_ok=True)

    total = 5  # number of batches to persist
    n_dumped = 0
    for i, sample in tqdm(enumerate(dataloader), total=total):
        if i >= total:
            break
        # print(sample)
        # Context manager guarantees the file handle is closed
        # (the original left one open handle per dumped sample).
        with open(out_dir / f"sample_{i}.pkl", "wb") as f:
            pickle.dump(sample, f)
        n_dumped += 1

    # NOTE(review): the original wrapped the loop in `except StopIteration`,
    # but a `for` statement consumes StopIteration itself, so that handler
    # was unreachable. Detect an empty dataset by counting yielded batches.
    if n_dumped == 0:
        print("数据集为空")