From 710cfe7fc20951a99141d71cbb6091db536ebd5f Mon Sep 17 00:00:00 2001 From: songsenand Date: Thu, 16 Apr 2026 22:35:59 +0800 Subject: [PATCH] =?UTF-8?q?fix(dataset,=20trainer):=20=E8=B0=83=E6=95=B4?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E9=9B=86=E5=92=8C=E8=AE=AD=E7=BB=83=E5=8F=82?= =?UTF-8?q?=E6=95=B0=E4=BB=A5=E6=8F=90=E9=AB=98=E6=A8=A1=E5=9E=8B=E6=95=88?= =?UTF-8?q?=E6=9E=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/model/dataset.py | 2 +- src/model/trainer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/model/dataset.py b/src/model/dataset.py index 131ed34..4a07de5 100644 --- a/src/model/dataset.py +++ b/src/model/dataset.py @@ -43,7 +43,7 @@ class PinyinInputDataset(IterableDataset): text_field: str = "text", py_style_weight=(9, 2, 1), shuffle_buffer_size: int = 100000, - retention_ratio: float = 0.5, + retention_ratio: float = 0.8, length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, ): # 频率调整参数 (可根据需要调整) diff --git a/src/model/trainer.py b/src/model/trainer.py index f63309d..1509618 100644 --- a/src/model/trainer.py +++ b/src/model/trainer.py @@ -1221,7 +1221,7 @@ def train( max_seq_length=max_seq_len, text_field="text", py_style_weight=(9, 2, 1), - shuffle_buffer_size=100000, + shuffle_buffer_size=2000000, length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, )