"""Smoke-test script for the pinyin input data pipeline.

Builds a ``PinyinInputDataset``, wraps it in a multi-worker ``DataLoader``
and, for every batch, prints how many entries of ``batch["labels"]`` are
equal to 0.
"""

import sys

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from model.dataset import PinyinInputDataset
from model.trainer import collate_fn, worker_init_fn

# Passed to PinyinInputDataset as ``max_iter_length``.
max_iter_length = 128 * 128
# Number of samples per batch requested from the DataLoader.
batch_size = 1024

# Corpus location differs between the Windows and non-Windows machines.
if sys.platform == "win32":
    dataset_path = "data"
else:
    dataset_path = "/home/songsenand/Data/corpus/CCI-Data/"


def main():
    """Iterate the dataset once and print the zero-label count per batch."""
    dataset = PinyinInputDataset(dataset_path, max_iter_length=max_iter_length)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        # Each worker prefetches 64 batches; suited to machines with
        # plenty of RAM.
        prefetch_factor=64,
        # Keep worker processes alive to avoid re-creation overhead.
        persistent_workers=True,
    )

    # Progress-bar total only. Assumes max_iter_length is the number of
    # samples the dataset yields -- TODO confirm against PinyinInputDataset.
    # Ceil division keeps the total an integer even when batch_size does not
    # divide max_iter_length evenly.
    expected_batches = (max_iter_length + batch_size - 1) // batch_size

    # Stream batches straight from the loader instead of materialising all of
    # them in a list first, so memory use stays bounded by the prefetch queue.
    for batch in tqdm(dataloader, total=expected_batches):
        print((batch["labels"] == 0).sum())


# The guard is required with num_workers > 0 on platforms that start worker
# processes with "spawn" (e.g. Windows): workers re-import this module, and
# without the guard each of them would try to build and iterate the loader.
if __name__ == "__main__":
    main()