fix(trainer): 使用固定最大序列长度的collate函数以避免内存问题
This commit is contained in:
parent
722912f296
commit
71ef54e3d4
|
|
@@ -1100,13 +1100,14 @@ def create_dataloader(
|
|||
)
|
||||
|
||||
logger.info(f"📊 使用标准DataLoader,worker数量: {num_workers}")
|
||||
+ fixed_max_seq_length = getattr(dataset, "max_seq_length", 128)
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
pin_memory=pin_memory,
|
||||
worker_init_fn=worker_init_fn,
|
||||
- collate_fn=collate_fn,
|
||||
+ collate_fn=preprocess_collate_fn(fixed_max_seq_length),
|
||||
prefetch_factor=2, # 减少预取以避免内存问题
|
||||
persistent_workers=True,
|
||||
shuffle=shuffle,
|
||||
|
|
|
|||
Loading…
Reference in New Issue