添加拼音丢弃率参数并根据该参数决定是否丢弃拼音
This commit is contained in:
parent
398155721d
commit
a82279b02a
|
|
@ -71,6 +71,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
max_drop_prob: float = 0.8, # 最大丢弃概率
|
||||
max_repeat_expect: float = 50.0, # 最大重复期望
|
||||
sample_context_section=[0.90, 0.95, 1],
|
||||
drop_py_rate: float = 0.30,
|
||||
):
|
||||
"""
|
||||
初始化数据集
|
||||
|
|
@ -126,6 +127,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
|
||||
# 上下文采样方式概率区间
|
||||
self.sample_context_section = sample_context_section
|
||||
self.drop_py_rate = drop_py_rate
|
||||
|
||||
def get_next_chinese_chars(
|
||||
self,
|
||||
|
|
@ -440,7 +442,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
|
||||
pg = self.pg_groups[processed_pinyin[0]] if processed_pinyin else 12
|
||||
prob = random.random()
|
||||
if prob < 0.3:
|
||||
if prob < self.drop_py_rate:
|
||||
py = ""
|
||||
else:
|
||||
py = processed_pinyin
|
||||
|
|
|
|||
Loading…
Reference in New Issue