SUimeModelTraner/check_weights.py

23 lines
525 B
Python
Executable File

from model.dataset import PinyinInputDataset
from torch.utils.data import DataLoader
from model.trainer import collate_fn, worker_init_fn
# Corpus directory used when no path is given on the command line.
DEFAULT_DATA_DIR = '/home/songsenand/Data/corpus/CCI-Data/'


def main(data_dir=DEFAULT_DATA_DIR):
    """Load one batch through the data pipeline and print a label count.

    Builds a ``PinyinInputDataset`` over *data_dir*, wraps it in a
    ``DataLoader`` using the trainer's ``collate_fn`` and ``worker_init_fn``,
    reads the first batch, and prints how many entries of its ``'labels'``
    equal 1.

    Args:
        data_dir: Corpus directory passed to ``PinyinInputDataset``.
    """
    data = PinyinInputDataset(data_dir)
    dataloader = DataLoader(
        data,
        batch_size=1024,
        num_workers=2,
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        # NOTE: 2 is PyTorch's default when num_workers > 0, so this does not
        # reduce prefetching; lower it (or num_workers) if memory is tight.
        prefetch_factor=2,
        # Only one batch is read, so keeping workers alive across epochs buys
        # nothing; non-persistent workers are released with the iterator.
        persistent_workers=False,
        shuffle=False,
    )

    batch = next(iter(dataloader), None)
    if batch is None:
        print('dataloader produced no batches')
        return

    # Number of label entries equal to 1 in the first batch.
    # NOTE(review): presumably 1 marks a particular class/weight value —
    # confirm against the dataset's label encoding.
    print(int((batch['labels'] == 1).sum()))


# Guarded so that DataLoader worker processes started with the 'spawn'
# method (which re-import the main module) do not re-run the check.
if __name__ == "__main__":
    import sys

    main(sys.argv[1] if len(sys.argv) > 1 else DEFAULT_DATA_DIR)