# NOTE(review): removed duplicated page-scrape metadata ("55 lines / 1.6 KiB / Python")
# that was not valid Python and did not belong in the file.
"""Usage example: stream a PinyinInputDataset through a DataLoader and dump
the first few collated batches to pickle files for offline evaluation."""

import pickle
from pathlib import Path

from loguru import logger
from torch.utils.data import DataLoader
from tqdm import tqdm

from suinput.dataset import PinyinInputDataset, custom_collate_with_txt, worker_init_fn
from suinput.query import QueryEngine

# Usage example
if __name__ == "__main__":
    # Initialize the query engine (load() presumably reads its lookup data
    # from disk — confirm against QueryEngine's implementation).
    query_engine = QueryEngine()
    query_engine.load()

    # Build the dataset. Shuffling is done inside the dataset via its own
    # shuffle buffer, which is why the DataLoader below uses shuffle=False.
    dataset = PinyinInputDataset(
        data_dir="/home/songsenand/DataSet/data",
        query_engine=query_engine,
        tokenizer_name="iic/nlp_structbert_backbone_lite_std",
        max_len=88,
        batch_query_size=300,
        shuffle=True,
        shuffle_buffer_size=4000,
        drop_py_rate=0,
    )
    logger.info("数据集初始化")

    dataloader = DataLoader(
        dataset,
        batch_size=1024,
        num_workers=1,
        worker_init_fn=worker_init_fn,
        # pin_memory=True if torch.cuda.is_available() else False,
        collate_fn=custom_collate_with_txt,
        prefetch_factor=8,
        persistent_workers=True,
        shuffle=False,  # shuffling is already implemented inside the dataset
    )

    # FIX: create the output directory up front — the original assumed it
    # existed and would crash with FileNotFoundError on a fresh checkout.
    out_dir = Path(__file__).parent.parent / "trainer" / "eval_dataset"
    out_dir.mkdir(parents=True, exist_ok=True)

    total = 5  # number of batches to dump
    dumped = 0
    # FIX: total=total instead of a second hard-coded 5.
    for i, sample in tqdm(enumerate(dataloader), total=total):
        if i >= total:
            break
        # FIX: use a context manager so the file handle is always closed
        # (the original leaked the handle returned by open()).
        with open(out_dir / f"sample_{i}.pkl", "wb") as f:
            pickle.dump(sample, f)
        dumped += 1

    # FIX: the original wrapped the loop in `except StopIteration`, but a
    # for-statement consumes StopIteration internally, so that handler was
    # unreachable. Detect an empty dataset explicitly instead.
    if dumped == 0:
        print("数据集为空")