diff --git a/src/model/dataset.py b/src/model/dataset.py index 131ed34..4a07de5 100644 --- a/src/model/dataset.py +++ b/src/model/dataset.py @@ -43,7 +43,7 @@ class PinyinInputDataset(IterableDataset): text_field: str = "text", py_style_weight=(9, 2, 1), shuffle_buffer_size: int = 100000, - retention_ratio: float = 0.5, + retention_ratio: float = 0.8, length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, ): # 频率调整参数 (可根据需要调整) diff --git a/src/model/trainer.py b/src/model/trainer.py index f63309d..1509618 100644 --- a/src/model/trainer.py +++ b/src/model/trainer.py @@ -1221,7 +1221,7 @@ def train( max_seq_length=max_seq_len, text_field="text", py_style_weight=(9, 2, 1), - shuffle_buffer_size=100000, + shuffle_buffer_size=2000000, length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, )