添加拼音丢弃率参数并根据该参数决定是否丢弃拼音
This commit is contained in:
parent
398155721d
commit
a82279b02a
|
|
@ -71,6 +71,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
max_drop_prob: float = 0.8, # 最大丢弃概率
|
max_drop_prob: float = 0.8, # 最大丢弃概率
|
||||||
max_repeat_expect: float = 50.0, # 最大重复期望
|
max_repeat_expect: float = 50.0, # 最大重复期望
|
||||||
sample_context_section=[0.90, 0.95, 1],
|
sample_context_section=[0.90, 0.95, 1],
|
||||||
|
drop_py_rate: float = 0.30,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
初始化数据集
|
初始化数据集
|
||||||
|
|
@ -126,6 +127,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
|
|
||||||
# 上下文采样方式概率区间
|
# 上下文采样方式概率区间
|
||||||
self.sample_context_section = sample_context_section
|
self.sample_context_section = sample_context_section
|
||||||
|
self.drop_py_rate = drop_py_rate
|
||||||
|
|
||||||
def get_next_chinese_chars(
|
def get_next_chinese_chars(
|
||||||
self,
|
self,
|
||||||
|
|
@ -440,7 +442,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
|
|
||||||
pg = self.pg_groups[processed_pinyin[0]] if processed_pinyin else 12
|
pg = self.pg_groups[processed_pinyin[0]] if processed_pinyin else 12
|
||||||
prob = random.random()
|
prob = random.random()
|
||||||
if prob < 0.3:
|
if prob < self.drop_py_rate:
|
||||||
py = ""
|
py = ""
|
||||||
else:
|
else:
|
||||||
py = processed_pinyin
|
py = processed_pinyin
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue