SUimeModelTraner/scripts/check_weights.py

28 lines
618 B
Python
Executable File

import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
from model.dataset import PinyinInputDataset
from torch.utils.data import DataLoader
from model.trainer import collate_fn, worker_init_fn
# Corpus location used when no path is given on the command line.
DEFAULT_CORPUS_DIR = "/home/songsenand/Data/corpus/CCI-Data/"


def main(corpus_dir=DEFAULT_CORPUS_DIR):
    """Print how many label entries in the first batch are equal to 1.

    Builds a ``PinyinInputDataset`` from ``corpus_dir``, wraps it in a
    two-worker ``DataLoader`` and inspects only the first batch it yields.

    Args:
        corpus_dir: Path handed to ``PinyinInputDataset``. Defaults to the
            location this script originally had hard-coded.
    """
    data = PinyinInputDataset(corpus_dir)
    dataloader = DataLoader(
        data,
        batch_size=1024,
        num_workers=2,
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        prefetch_factor=2,  # keep prefetching low to avoid memory problems
        persistent_workers=True,
        shuffle=False,
    )
    # Only the first batch is inspected; an empty loader prints nothing.
    for batch in dataloader:
        # NOTE(review): assumes collate_fn yields a mapping whose "labels"
        # entry supports elementwise == and .sum() (e.g. a tensor) - confirm.
        print((batch["labels"] == 1).sum())
        break


if __name__ == "__main__":
    # The guard is required because the DataLoader starts worker processes:
    # under the "spawn" start method each worker re-imports this module, and
    # unguarded top-level code would rebuild the dataset and loader in every
    # worker.
    main(sys.argv[1] if len(sys.argv) > 1 else DEFAULT_CORPUS_DIR)