fix(trainer): 使用固定最大序列长度的collate函数以避免内存问题

This commit is contained in:
songsenand 2026-05-15 14:47:31 +08:00
parent 722912f296
commit 71ef54e3d4
1 changed file with 2 additions and 1 deletion

View File

@@ -1100,13 +1100,14 @@ def create_dataloader(
)
logger.info(f"📊 使用标准DataLoaderworker数量: {num_workers}")
+fixed_max_seq_length = getattr(dataset, "max_seq_length", 128)
dataloader = DataLoader(
dataset,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory,
worker_init_fn=worker_init_fn,
-collate_fn=collate_fn,
+collate_fn=preprocess_collate_fn(fixed_max_seq_length),
prefetch_factor=2, # 减少预取以避免内存问题
persistent_workers=True,
shuffle=shuffle,