From a82279b02a504f0ab98e50cb7cf74321a7ae2fce Mon Sep 17 00:00:00 2001 From: songsenand Date: Sun, 22 Feb 2026 15:36:06 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=8B=BC=E9=9F=B3=E4=B8=A2?= =?UTF-8?q?=E5=BC=83=E7=8E=87=E5=8F=82=E6=95=B0=E5=B9=B6=E6=A0=B9=E6=8D=AE?= =?UTF-8?q?=E8=AF=A5=E5=8F=82=E6=95=B0=E5=86=B3=E5=AE=9A=E6=98=AF=E5=90=A6?= =?UTF-8?q?=E4=B8=A2=E5=BC=83=E6=8B=BC=E9=9F=B3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/suinput/dataset.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/suinput/dataset.py b/src/suinput/dataset.py index 8efea57..ffce948 100644 --- a/src/suinput/dataset.py +++ b/src/suinput/dataset.py @@ -71,6 +71,7 @@ class PinyinInputDataset(IterableDataset): max_drop_prob: float = 0.8, # 最大丢弃概率 max_repeat_expect: float = 50.0, # 最大重复期望 sample_context_section=[0.90, 0.95, 1], + drop_py_rate: float = 0.30, ): """ 初始化数据集 @@ -126,6 +127,7 @@ class PinyinInputDataset(IterableDataset): # 上下文采样方式概率区间 self.sample_context_section = sample_context_section + self.drop_py_rate = drop_py_rate def get_next_chinese_chars( self, @@ -440,7 +442,7 @@ class PinyinInputDataset(IterableDataset): pg = self.pg_groups[processed_pinyin[0]] if processed_pinyin else 12 prob = random.random() - if prob < 0.3: + if prob < self.drop_py_rate: py = "" else: py = processed_pinyin