调整导入顺序并修复pickle保存逻辑
This commit is contained in:
parent
134c8a09cf
commit
d2d65c7efa
|
|
@ -1,13 +1,13 @@
|
||||||
from tqdm import tqdm
|
|
||||||
from loguru import logger
|
|
||||||
import torch
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
import pickle
|
import pickle
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from suinput.dataset import PinyinInputDataset, worker_init_fn, custom_collate_with_txt
|
import torch
|
||||||
from suinput.query import QueryEngine
|
from loguru import logger
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from suinput.dataset import PinyinInputDataset, custom_collate_with_txt, worker_init_fn
|
||||||
|
from suinput.query import QueryEngine
|
||||||
|
|
||||||
# 使用示例
|
# 使用示例
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
@ -42,7 +42,13 @@ if __name__ == "__main__":
|
||||||
for i, sample in tqdm(enumerate(dataloader), total=5):
|
for i, sample in tqdm(enumerate(dataloader), total=5):
|
||||||
if i >= total:
|
if i >= total:
|
||||||
break
|
break
|
||||||
print(sample)
|
# print(sample)
|
||||||
# pickle.dump(sample, open(f"{str(Path(__file__).parent.parent / 'trainer' / 'eval_dataset')}/sample_{i}.pkl", "wb"))
|
pickle.dump(
|
||||||
|
sample,
|
||||||
|
open(
|
||||||
|
f"{str(Path(__file__).parent.parent / 'trainer' / 'eval_dataset')}/sample_{i}.pkl",
|
||||||
|
"wb",
|
||||||
|
),
|
||||||
|
)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
print("数据集为空")
|
print("数据集为空")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue