SUimeModelTraner/test.py

30 lines
901 B
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import sys
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from model.dataset import PinyinInputDataset
from model.trainer import collate_fn, worker_init_fn
# Root directory of the training corpus. On Windows, a relative "data"
# directory is used; on every other platform, a hard-coded absolute path is
# used (machine-specific — adjust it when running elsewhere).
dataset_path = (
    "data" if sys.platform == "win32" else "/home/songsenand/Data/corpus/CCI-Data/"
)
def main() -> None:
    """Smoke-test the pinyin data pipeline by printing what the DataLoader yields.

    Builds a ``PinyinInputDataset`` over the module-level ``dataset_path`` and
    wraps it in a multi-worker ``DataLoader`` that uses the project's
    ``collate_fn`` and ``worker_init_fn``. It then prints the length of the
    first batch's ``"labels"`` entry, followed by the shape of every batch's
    ``"pinyin_ids"`` after ``squeeze(-1)``. If no batch is produced, a
    message naming the dataset path is printed instead.
    """
    dataset = PinyinInputDataset(dataset_path, max_iter_length=128 * 128)
    dataloader = DataLoader(
        dataset,
        batch_size=128,
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        prefetch_factor=64,  # each worker prefetches 64 batches; suited to large-memory machines
        persistent_workers=True,  # keep worker processes alive to avoid re-creation overhead
    )

    num_batches = 0
    # Stream batches rather than materialising them all first. This keeps
    # memory flat and makes the progress bar measure real loading throughput.
    # tqdm infers the total from len(dataloader) when the dataset supports
    # it, instead of relying on a hard-coded batch count.
    for num_batches, batch in enumerate(tqdm(dataloader), start=1):
        if num_batches == 1:
            tqdm.write(str(len(batch["labels"])))
        # tqdm.write prints without breaking the progress bar, unlike print().
        tqdm.write(str(batch["pinyin_ids"].squeeze(-1).shape))

    if num_batches == 0:
        print(f"DataLoader yielded no batches; check dataset_path: {dataset_path!r}")


# DataLoader workers are separate processes. Where they are started with
# "spawn" (the default on Windows, which this script explicitly supports),
# each worker re-imports this module. Without this guard, every worker would
# rebuild the DataLoader and crash while bootstrapping.
if __name__ == "__main__":
    main()