添加拼音丢弃率参数并根据该参数决定是否丢弃拼音

This commit is contained in:
songsenand 2026-02-22 15:36:06 +08:00
parent 398155721d
commit a82279b02a
1 changed file with 3 additions and 1 deletion

View File

@@ -71,6 +71,7 @@ class PinyinInputDataset(IterableDataset):
max_drop_prob: float = 0.8, # 最大丢弃概率
max_repeat_expect: float = 50.0, # 最大重复期望
sample_context_section=[0.90, 0.95, 1],
drop_py_rate: float = 0.30,
):
"""
初始化数据集
@@ -126,6 +127,7 @@ class PinyinInputDataset(IterableDataset):
# 上下文采样方式概率区间
self.sample_context_section = sample_context_section
self.drop_py_rate = drop_py_rate
def get_next_chinese_chars(
self,
@@ -440,7 +442,7 @@ class PinyinInputDataset(IterableDataset):
pg = self.pg_groups[processed_pinyin[0]] if processed_pinyin else 12
prob = random.random()
if prob < 0.3:
if prob < self.drop_py_rate:
py = ""
else:
py = processed_pinyin