import sys
import traceback

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from model.dataset import PinyinInputDataset
from model.trainer import collate_fn, worker_init_fn

# Try to import DataLoader2 from torchdata, fall back to the standard DataLoader.
try:
    from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService

    DATA_LOADER2_AVAILABLE = True
    print("✅ Using DataLoader2 from torchdata")
except ImportError:
    DATA_LOADER2_AVAILABLE = False
    print("⚠️ torchdata not installed, falling back to standard DataLoader")

max_iter_length = 128 * 128
batch_size = 1024

if sys.platform == "win32":
    dataset_path = "data"
else:
    dataset_path = "/home/songsenand/Data/corpus/CCI-Data/"

dataset = PinyinInputDataset(dataset_path, max_iter_length=max_iter_length)


def create_dataloader():
    """
    Create a dataloader with DataLoader2 if available, otherwise fall back to DataLoader.

    DataLoader2 is tried first because it is meant to handle streaming datasets better.
    """
    if DATA_LOADER2_AVAILABLE:
        try:
            # DataLoader2 configuration for streaming datasets:
            # MultiProcessingReadingService with conservative worker settings.
            #
            # NOTE: on recent torchdata releases MultiProcessingReadingService only
            # accepts num_workers / multiprocessing_context / prefetch counts /
            # worker_init_fn, and DataLoader2 only accepts
            # (datapipe, datapipe_adapter_fn, reading_service). On those versions the
            # extra keyword arguments below raise a TypeError, which is caught here and
            # triggers the standard DataLoader fallback.
            reading_service = MultiProcessingReadingService(
                num_workers=2,  # start with 2 workers for the streaming dataset
                prefetch_factor=2,  # reduced prefetch for better memory management
                persistent_workers=True,
                pin_memory=torch.cuda.is_available(),
                worker_init_fn=worker_init_fn,
            )
            dataloader = DataLoader2(
                dataset,
                reading_service=reading_service,
                batch_size=batch_size,
                collate_fn=collate_fn,
                shuffle=False,  # the dataset handles shuffling internally
            )
            print("✅ Created DataLoader2 with 2 workers")
            return dataloader
        except Exception as e:
            print(f"⚠️ DataLoader2 creation failed: {e}, falling back to DataLoader")

    # Fallback to the standard DataLoader.
    print("📊 Using standard DataLoader")
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=2,  # limited to 2 for streaming dataset compatibility
        pin_memory=torch.cuda.is_available(),
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        prefetch_factor=2,  # reduced from 64 to avoid memory issues
        persistent_workers=True,
    )
    return dataloader


# Create the dataloader.
dataloader = create_dataloader()

# Test the dataloader.
print(f"🔍 Testing dataloader with batch_size={batch_size}")
print(f"   Dataset max_iter_length: {max_iter_length}")
print(f"   Expected batches: {max_iter_length / batch_size:.0f}")

try:
    # Materialize every batch to verify that the full pipeline loads.
    dataloader_list = list(dataloader)
    print(f"✅ Successfully loaded {len(dataloader_list)} batches")

    # Inspect the first few batches.
    for i, batch in tqdm(enumerate(dataloader_list), total=len(dataloader_list)):
        zero_labels = (batch["labels"] == 0).sum()
        print(f"Batch {i}: labels==0 count = {zero_labels.item()}")

        if i >= 4:  # stop after 5 batches for quick testing
            print("⚠️ Limited to 5 batches for testing")
            break
except Exception as e:
    print(f"❌ Error during dataloader iteration: {e}")
    traceback.print_exc()

print("🏁 Test completed")
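

# ---------------------------------------------------------------------------
# Reference sketch (not executed): the DataPipe-centric way DataLoader2 is
# usually driven. Since DataLoader2 itself takes no batch_size/collate_fn,
# batching and collation are applied to the DataPipe before it is handed over.
# This is a minimal sketch using a toy in-memory pipe as a stand-in for
# PinyinInputDataset (an assumption for illustration only); the helper name
# below is likewise illustrative, not part of the project.
# ---------------------------------------------------------------------------
def _dataloader2_datapipe_sketch():
    from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
    from torchdata.datapipes.iter import IterableWrapper

    # Toy samples shaped like the batches the test loop above expects ({"labels": ...}).
    samples = [{"labels": torch.tensor([i % 2])} for i in range(8)]

    # sharding_filter lets the reading service split work across workers;
    # batch + collate replace the batch_size/collate_fn keyword arguments.
    pipe = IterableWrapper(samples).sharding_filter().batch(4).collate()

    dl = DataLoader2(
        pipe,
        reading_service=MultiProcessingReadingService(num_workers=2),
    )
    for batch in dl:
        print("sketch batch labels shape:", batch["labels"].shape)
    dl.shutdown()


# To try the sketch, call _dataloader2_datapipe_sketch() under an
# `if __name__ == "__main__":` guard so worker processes do not re-run the script.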