"""Smoke test: pull one batch from the pinyin dataset and count labels equal to 1."""

from torch.utils.data import DataLoader

from model.dataset import PinyinInputDataset
from model.trainer import collate_fn, worker_init_fn

# Corpus location handed to PinyinInputDataset.
CORPUS_DIR = '/home/songsenand/Data/corpus/CCI-Data/'


def main():
    """Build the dataset and DataLoader, then print a count for the first batch.

    The printed value is ``(batch['labels'] == 1).sum()`` for the first batch
    produced by the loader; nothing is printed if the loader yields no batch.
    """
    data = PinyinInputDataset(CORPUS_DIR)
    dataloader = DataLoader(
        data,
        batch_size=1024,
        num_workers=2,
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        prefetch_factor=2,  # reduce prefetching to avoid memory problems
        persistent_workers=True,
        shuffle=False,
    )

    # Only the first batch is inspected; `break` stops after one iteration.
    # NOTE(review): assumes collate_fn yields a mapping whose 'labels' entry
    # supports element-wise `==` and `.sum()` (e.g. a tensor) -- confirm.
    for batch in dataloader:
        print((batch['labels'] == 1).sum())
        break


# The loader starts worker processes (num_workers=2). With the spawn or
# forkserver start methods each worker re-imports the main module, so the
# pipeline must not run at import time.
if __name__ == "__main__":
    main()