fix(dataset, trainer): 调整数据集和训练参数以提高模型效果
This commit is contained in:
parent
3175ace9c5
commit
710cfe7fc2
|
|
@@ -43,7 +43,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
text_field: str = "text",
|
text_field: str = "text",
|
||||||
py_style_weight=(9, 2, 1),
|
py_style_weight=(9, 2, 1),
|
||||||
shuffle_buffer_size: int = 100000,
|
shuffle_buffer_size: int = 100000,
|
||||||
retention_ratio: float = 0.5,
|
retention_ratio: float = 0.8,
|
||||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||||
):
|
):
|
||||||
# 频率调整参数 (可根据需要调整)
|
# 频率调整参数 (可根据需要调整)
|
||||||
|
|
|
||||||
|
|
@@ -1221,7 +1221,7 @@ def train(
|
||||||
max_seq_length=max_seq_len,
|
max_seq_length=max_seq_len,
|
||||||
text_field="text",
|
text_field="text",
|
||||||
py_style_weight=(9, 2, 1),
|
py_style_weight=(9, 2, 1),
|
||||||
shuffle_buffer_size=100000,
|
shuffle_buffer_size=2000000,
|
||||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue