32 lines
941 B
Plaintext
32 lines
941 B
Plaintext
"""Sanity-check the pinyin input data pipeline.

Builds a ``PinyinInputDataset`` and a ``DataLoader`` using the training
``collate_fn`` / ``worker_init_fn``, iterates over the batches once, and prints,
for each batch, how many entries of its ``labels`` tensor are equal to 0.
"""
import sys

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from model.dataset import PinyinInputDataset
from model.trainer import collate_fn, worker_init_fn

max_iter_length = 128 * 128
batch_size = 1024

if sys.platform == "win32":
    dataset_path = "data"
else:
    dataset_path = "/home/songsenand/Data/corpus/CCI-Data/"


def main() -> None:
    """Iterate the data pipeline once and print the zero-label count per batch."""
    dataset = PinyinInputDataset(dataset_path, max_iter_length=max_iter_length)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        # Each worker prefetches 64 batches; suited to machines with lots of RAM.
        prefetch_factor=64,
        # Keep worker processes alive to avoid the cost of re-creating them.
        persistent_workers=True,
    )

    # Expected number of batches for the progress bar. This assumes
    # max_iter_length is the total number of samples the dataset yields
    # (TODO: confirm against PinyinInputDataset). Integer ceiling division is
    # used because tqdm's ``total`` is a count; the old code passed a float.
    expected_batches = (max_iter_length + batch_size - 1) // batch_size

    # Stream the batches rather than collecting them all into a list first.
    # This way the progress bar tracks the actual loading work, and only the
    # DataLoader's prefetch buffer is held in memory instead of every batch.
    for batch in tqdm(dataloader, total=expected_batches):
        # tqdm.write prints without corrupting the active progress bar.
        tqdm.write(str((batch["labels"] == 0).sum()))


# This guard is required because num_workers > 0. Under the "spawn" start
# method (the default on Windows and macOS), each worker process re-imports
# this module. Without the guard, every worker would re-run the script and try
# to start its own workers while still bootstrapping.
if __name__ == "__main__":
    main()