From 51f9ddbc707931f08ca5f3840b9b7f0567aab28a Mon Sep 17 00:00:00 2001 From: songsenand Date: Sat, 21 Feb 2026 22:01:28 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=8B=BC=E9=9F=B3=E7=BB=84?= =?UTF-8?q?=E5=A4=84=E7=90=86=E9=80=BB=E8=BE=91=EF=BC=8C=E9=81=BF=E5=85=8D?= =?UTF-8?q?=E6=9C=AA=E5=A4=84=E7=90=86=E6=8B=BC=E9=9F=B3=E5=AF=BC=E8=87=B4?= =?UTF-8?q?=E7=9A=84=E7=B4=A2=E5=BC=95=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/suinput/dataset.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/suinput/dataset.py b/src/suinput/dataset.py index 91620e5..611ad24 100644 --- a/src/suinput/dataset.py +++ b/src/suinput/dataset.py @@ -440,7 +440,9 @@ class PinyinInputDataset(IterableDataset): "char_id": torch.tensor([char_info["id"]]), "char": char, "freq": char_info["freq"], - "pg": torch.tensor([self.pg_groups[char_info["pinyin"][0]]]), + "pg": torch.tensor( + [self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8] + ), } # 根据调整因子重复样本