diff --git a/src/suinput/dataset.py b/src/suinput/dataset.py index 8efea57..ffce948 100644 --- a/src/suinput/dataset.py +++ b/src/suinput/dataset.py @@ -71,6 +71,7 @@ class PinyinInputDataset(IterableDataset): max_drop_prob: float = 0.8, # 最大丢弃概率 max_repeat_expect: float = 50.0, # 最大重复期望 sample_context_section=[0.90, 0.95, 1], + drop_py_rate: float = 0.30, ): """ 初始化数据集 @@ -126,6 +127,7 @@ class PinyinInputDataset(IterableDataset): # 上下文采样方式概率区间 self.sample_context_section = sample_context_section + self.drop_py_rate = drop_py_rate def get_next_chinese_chars( self, @@ -440,7 +442,7 @@ class PinyinInputDataset(IterableDataset): pg = self.pg_groups[processed_pinyin[0]] if processed_pinyin else 12 prob = random.random() - if prob < 0.3: + if prob < self.drop_py_rate: py = "" else: py = processed_pinyin