fix(trainer): 使用固定最大序列长度的collate函数以避免内存问题
This commit is contained in:
parent
722912f296
commit
71ef54e3d4
|
|
@ -1100,13 +1100,14 @@ def create_dataloader(
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"📊 使用标准DataLoader,worker数量: {num_workers}")
|
logger.info(f"📊 使用标准DataLoader,worker数量: {num_workers}")
|
||||||
|
fixed_max_seq_length = getattr(dataset, "max_seq_length", 128)
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
worker_init_fn=worker_init_fn,
|
worker_init_fn=worker_init_fn,
|
||||||
collate_fn=collate_fn,
|
collate_fn=preprocess_collate_fn(fixed_max_seq_length),
|
||||||
prefetch_factor=2, # 减少预取以避免内存问题
|
prefetch_factor=2, # 减少预取以避免内存问题
|
||||||
persistent_workers=True,
|
persistent_workers=True,
|
||||||
shuffle=shuffle,
|
shuffle=shuffle,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue