import pickle
from pathlib import Path

from loguru import logger
from torch.utils.data import DataLoader
from tqdm import tqdm

from suinput.dataset import PinyinInputDataset, custom_collate_with_txt, worker_init_fn
from suinput.query import QueryEngine

# Usage example: materialize the first few collated batches of the streaming
# pinyin dataset and pickle them to disk as a fixed evaluation set.
if __name__ == "__main__":
    # Initialize the query engine (must be loaded before the dataset queries it).
    query_engine = QueryEngine()
    query_engine.load()

    # Build the dataset; shuffling is performed inside the dataset itself,
    # which is why the DataLoader below uses shuffle=False.
    dataset = PinyinInputDataset(
        data_dir="/home/songsenand/DataSet/data",
        query_engine=query_engine,
        tokenizer_name="iic/nlp_structbert_backbone_lite_std",
        max_len=88,
        batch_query_size=300,
        shuffle=True,
        shuffle_buffer_size=4000,
        drop_py_rate=0,
    )
    logger.info("数据集初始化")

    dataloader = DataLoader(
        dataset,
        batch_size=1024,
        num_workers=1,
        worker_init_fn=worker_init_fn,
        # pin_memory=True if torch.cuda.is_available() else False,
        collate_fn=custom_collate_with_txt,
        prefetch_factor=8,
        persistent_workers=True,
        shuffle=False,  # shuffling already done inside the dataset
    )

    # Hoist the output directory out of the loop and make sure it exists,
    # so the dumps below don't fail on a missing path.
    out_dir = Path(__file__).parent.parent / "trainer" / "eval_dataset"
    out_dir.mkdir(parents=True, exist_ok=True)

    total = 5
    written = 0
    for i, sample in tqdm(enumerate(dataloader), total=total):
        if i >= total:
            break
        # Use a context manager so the file handle is closed deterministically
        # (the original passed an open() result straight to pickle.dump and
        # leaked the handle).
        with open(out_dir / f"sample_{i}.pkl", "wb") as f:
            pickle.dump(sample, f)
        written += 1

    # NOTE(review): the original wrapped the loop in `try/except StopIteration`,
    # but a `for` loop never propagates StopIteration (PEP 479 turns a leaked
    # one into RuntimeError), so that branch was dead. An empty dataset simply
    # yields zero iterations — detect that explicitly instead.
    if written == 0:
        print("数据集为空")