fix(trainer): 使用固定最大序列长度的collate函数以避免内存问题

This commit is contained in:
songsenand 2026-05-15 14:47:31 +08:00
parent 722912f296
commit 71ef54e3d4
1 changed file with 2 additions and 1 deletion

View File

@@ -1100,13 +1100,14 @@ def create_dataloader(
) )
logger.info(f"📊 使用标准DataLoaderworker数量: {num_workers}") logger.info(f"📊 使用标准DataLoaderworker数量: {num_workers}")
fixed_max_seq_length = getattr(dataset, "max_seq_length", 128)
dataloader = DataLoader( dataloader = DataLoader(
dataset, dataset,
batch_size=batch_size, batch_size=batch_size,
num_workers=num_workers, num_workers=num_workers,
pin_memory=pin_memory, pin_memory=pin_memory,
worker_init_fn=worker_init_fn, worker_init_fn=worker_init_fn,
collate_fn=collate_fn, collate_fn=preprocess_collate_fn(fixed_max_seq_length),
prefetch_factor=2, # 减少预取以避免内存问题 prefetch_factor=2, # 减少预取以避免内存问题
persistent_workers=True, persistent_workers=True,
shuffle=shuffle, shuffle=shuffle,