diff --git a/src/model/trainer.py b/src/model/trainer.py index cbea273..f20ef98 100644 --- a/src/model/trainer.py +++ b/src/model/trainer.py @@ -1100,13 +1100,14 @@ def create_dataloader( ) logger.info(f"📊 使用标准DataLoader,worker数量: {num_workers}") + fixed_max_seq_length = getattr(dataset, "max_seq_length", 128) dataloader = DataLoader( dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, worker_init_fn=worker_init_fn, - collate_fn=collate_fn, + collate_fn=preprocess_collate_fn(fixed_max_seq_length), prefetch_factor=2, # 减少预取以避免内存问题 persistent_workers=True, shuffle=shuffle,