# SUimeModelTraner/test.py — dataloader smoke test (pasted metadata: 104 lines, 3.5 KiB, Python)
import sys
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from model.dataset import PinyinInputDataset
from model.trainer import collate_fn, worker_init_fn
# Prefer torchdata's DataLoader2; degrade gracefully to torch's DataLoader
# when torchdata is not installed.
try:
    from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
except ImportError:
    DATA_LOADER2_AVAILABLE = False
    print("⚠️ torchdata not installed, falling back to standard DataLoader")
else:
    DATA_LOADER2_AVAILABLE = True
    print("✅ Using DataLoader2 from torchdata")
# Smoke-test configuration: cap the streaming iteration length and pick the
# corpus root that exists on the current platform.
max_iter_length = 128 * 128
batch_size = 1024

dataset_path = "data" if sys.platform == "win32" else "/home/songsenand/Data/corpus/CCI-Data/"
dataset = PinyinInputDataset(dataset_path, max_iter_length=max_iter_length)
def create_dataloader(num_workers=2, prefetch_factor=2):
    """Build a dataloader over the module-level ``dataset``.

    Tries torchdata's DataLoader2 first (when importable); on any failure it
    falls back to the standard ``torch.utils.data.DataLoader``.

    Args:
        num_workers: Worker processes for either loader. Default 2 — kept low
            for streaming-dataset compatibility.
        prefetch_factor: Batches prefetched per worker. Kept small to limit
            memory use.

    Returns:
        An iterable yielding collated batches of ``dataset``.
    """
    if DATA_LOADER2_AVAILABLE:
        try:
            # NOTE(review): torchdata's MultiProcessingReadingService does not
            # accept pin_memory/persistent_workers, and DataLoader2 does not
            # accept batch_size/collate_fn/shuffle — if the installed version
            # rejects these kwargs, the except below silently falls back to
            # the standard DataLoader. Confirm against the torchdata API.
            reading_service = MultiProcessingReadingService(
                num_workers=num_workers,
                prefetch_factor=prefetch_factor,
                persistent_workers=True,
                pin_memory=torch.cuda.is_available(),
                worker_init_fn=worker_init_fn,
            )
            dataloader = DataLoader2(
                dataset,
                reading_service=reading_service,
                batch_size=batch_size,
                collate_fn=collate_fn,
                shuffle=False,  # dataset handles shuffling internally
            )
            # Report the actual worker count (the original hard-coded "2").
            print(f"✅ Created DataLoader2 with {num_workers} workers")
            return dataloader
        except Exception as e:
            print(f"⚠️ DataLoader2 creation failed: {e}, falling back to DataLoader")
    # Fallback path: standard DataLoader with the same worker/prefetch budget.
    print("📊 Using standard DataLoader")
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        prefetch_factor=prefetch_factor,
        persistent_workers=True,
    )
# --- Smoke-test the dataloader -------------------------------------------
dataloader = create_dataloader()

print(f"🔍 Testing dataloader with batch_size={batch_size}")
print(f" Dataset max_iter_length: {max_iter_length}")
print(f" Expected batches: {max_iter_length / batch_size:.0f}")

try:
    # Materialize every batch up front to exercise the full loading path.
    # (Plain list(dataloader) — the original wrapped a redundant
    # comprehension in list(), which builds the same list twice.)
    dataloader_list = list(dataloader)
    print(f"✅ Successfully loaded {len(dataloader_list)} batches")

    for i, line in tqdm(enumerate(dataloader_list), total=len(dataloader_list)):
        # Count zero-valued labels in this batch; assumes each batch is a
        # dict whose "labels" value is a tensor — TODO confirm against
        # collate_fn's output format.
        zero_labels = (line["labels"] == 0).sum()
        print(f"Batch {i}: labels==0 count = {zero_labels.item()}")
        if i >= 5:  # inspect only the first few batches for a quick check
            print("⚠️ Limited to 5 batches for testing")
            break
except Exception as e:
    # Broad catch is deliberate here: this is a diagnostic script, so print
    # the failure and full traceback instead of crashing bare.
    print(f"❌ Error during dataloader iteration: {e}")
    import traceback

    traceback.print_exc()
print("🏁 Test completed")