import sys

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from model.dataset import PinyinInputDataset
from model.trainer import collate_fn, worker_init_fn

# Try to import DataLoader2 from torchdata, fallback to standard DataLoader
try:
    from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService

    DATA_LOADER2_AVAILABLE = True
    print("✅ Using DataLoader2 from torchdata")
except ImportError:
    DATA_LOADER2_AVAILABLE = False
    print("⚠️ torchdata not installed, falling back to standard DataLoader")

max_iter_length = 128 * 128
batch_size = 1024

if sys.platform == "win32":
    dataset_path = "data"
else:
    dataset_path = "/home/songsenand/Data/corpus/CCI-Data/"

dataset = PinyinInputDataset(dataset_path, max_iter_length=max_iter_length)
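
# Hedged sanity check: peek at one raw sample before building the dataloader.
# Assumptions (not confirmed by the surrounding code): PinyinInputDataset is a
# streaming/iterable dataset and yields dict-like samples, as the batch
# inspection near the end of the script also suggests. Wrapped in try/except so
# it can never break the actual dataloader test.
try:
    _sample = next(iter(dataset))
    if isinstance(_sample, dict):
        print({k: getattr(v, "shape", type(v).__name__) for k, v in _sample.items()})
    else:
        print(f"First raw sample type: {type(_sample).__name__}")
except Exception as _peek_err:
    print(f"⚠️ Could not peek at a raw sample: {_peek_err}")
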
def create_dataloader():
    """
    Create a dataloader with DataLoader2 if available, otherwise fall back to
    the standard DataLoader.

    This function tries to handle streaming datasets better with DataLoader2.
    """
    if DATA_LOADER2_AVAILABLE:
        try:
            # DataLoader2 consumes DataPipes rather than Datasets, so batching and
            # collation are applied on the pipe itself instead of being passed as
            # keyword arguments (this assumes PinyinInputDataset is iterable).
            from torchdata.datapipes.iter import IterableWrapper

            # Use MultiProcessingReadingService with careful worker settings.
            # Options such as prefetch_factor, persistent_workers, pin_memory, and
            # the DataLoader-style worker_init_fn belong to the standard DataLoader
            # path below, not to the reading service.
            reading_service = MultiProcessingReadingService(
                num_workers=2,  # Start with 2 workers for the streaming dataset
            )

            # The dataset handles shuffling internally, so no shuffle step is added;
            # sharding_filter() keeps the two workers from yielding duplicate samples.
            datapipe = (
                IterableWrapper(dataset, deepcopy=False)
                .sharding_filter()
                .batch(batch_size)
                .collate(collate_fn)
            )

            dataloader = DataLoader2(datapipe, reading_service=reading_service)
            print("✅ Created DataLoader2 with 2 workers")
            return dataloader
        except Exception as e:
            print(f"⚠️ DataLoader2 creation failed: {e}, falling back to DataLoader")

    # Fallback to standard DataLoader
    print("📊 Using standard DataLoader")
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=2,  # Limited to 2 for streaming dataset compatibility
        pin_memory=torch.cuda.is_available(),
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        prefetch_factor=2,  # Reduced from 64 to avoid memory issues
        persistent_workers=True,
    )
    return dataloader

# Create the dataloader
dataloader = create_dataloader()

# Test the dataloader
print(f"🔍 Testing dataloader with batch_size={batch_size}")
print(f"   Dataset max_iter_length: {max_iter_length}")
print(f"   Expected batches: {max_iter_length / batch_size:.0f}")

try:
    # Convert to list to test loading (as in original code)
    dataloader_list = list(dataloader)
    print(f"✅ Successfully loaded {len(dataloader_list)} batches")

    # Process batches
    for i, line in tqdm(enumerate(dataloader_list), total=len(dataloader_list)):
        zero_labels = (line["labels"] == 0).sum()
        print(f"Batch {i}: labels==0 count = {zero_labels.item()}")
        # Early exit for testing
        if i >= 5:  # Limit to 5 batches for quick testing
            print("⚠️ Limited to 5 batches for testing")
            break

except Exception as e:
    print(f"❌ Error during dataloader iteration: {e}")
    import traceback

    traceback.print_exc()

print("🏁 Test completed")