diff --git a/src/suinput/dataset.py b/src/suinput/dataset.py index 91620e5..611ad24 100644 --- a/src/suinput/dataset.py +++ b/src/suinput/dataset.py @@ -440,7 +440,9 @@ class PinyinInputDataset(IterableDataset): "char_id": torch.tensor([char_info["id"]]), "char": char, "freq": char_info["freq"], - "pg": torch.tensor([self.pg_groups[char_info["pinyin"][0]]]), + "pg": torch.tensor( + [self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8] + ), } # 根据调整因子重复样本