diff --git a/src/tmp_utils/gen_eval_dataset.py b/src/tmp_utils/gen_eval_dataset.py index 24a1852..d13b4c3 100644 --- a/src/tmp_utils/gen_eval_dataset.py +++ b/src/tmp_utils/gen_eval_dataset.py @@ -1,13 +1,13 @@ -from tqdm import tqdm -from loguru import logger -import torch -from torch.utils.data import DataLoader import pickle from pathlib import Path -from suinput.dataset import PinyinInputDataset, worker_init_fn, custom_collate_with_txt -from suinput.query import QueryEngine +import torch +from loguru import logger +from torch.utils.data import DataLoader +from tqdm import tqdm +from suinput.dataset import PinyinInputDataset, custom_collate_with_txt, worker_init_fn +from suinput.query import QueryEngine # 使用示例 if __name__ == "__main__": @@ -42,7 +42,13 @@ if __name__ == "__main__": for i, sample in tqdm(enumerate(dataloader), total=5): if i >= total: break - print(sample) - # pickle.dump(sample, open(f"{str(Path(__file__).parent.parent / 'trainer' / 'eval_dataset')}/sample_{i}.pkl", "wb")) + # print(sample) + pickle.dump( + sample, + open( + f"{str(Path(__file__).parent.parent / 'trainer' / 'eval_dataset')}/sample_{i}.pkl", + "wb", + ), + ) except StopIteration: print("数据集为空")