From 1bdbbe284cb915002eab95dfcff3b0aed5e6b93e Mon Sep 17 00:00:00 2001
From: songsenand
Date: Mon, 9 Feb 2026 10:53:27 +0800
Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96=E6=8B=BC=E9=9F=B3?=
 =?UTF-8?q?=E8=8E=B7=E5=8F=96=E9=80=BB=E8=BE=91=EF=BC=8C=E6=B7=BB=E5=8A=A0?=
 =?UTF-8?q?=20pinyin=5Flist=20=E5=8F=82=E6=95=B0=E6=8F=90=E5=8D=87?=
 =?UTF-8?q?=E6=80=A7=E8=83=BD?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/suinput/dataset.py | 27 +++++++++++++++++----------
 1 file changed, 17 insertions(+), 10 deletions(-)

diff --git a/src/suinput/dataset.py b/src/suinput/dataset.py
index 0ed0886..ff5cafb 100644
--- a/src/suinput/dataset.py
+++ b/src/suinput/dataset.py
@@ -1,14 +1,14 @@
-import torch
-from torch.utils.data import IterableDataset, DataLoader
-from datasets import load_dataset
-from pypinyin import lazy_pinyin
 import random
-from modelscope import AutoTokenizer
-from typing import Tuple, List, Dict, Any
 import re
-import numpy as np
+from typing import Any, Dict, List, Tuple
 
+import numpy as np
+import torch
+from datasets import load_dataset
 from loguru import logger
+from modelscope import AutoTokenizer
+from pypinyin import lazy_pinyin
+from torch.utils.data import DataLoader, IterableDataset
 
 
 class PinyinInputDataset(IterableDataset):
@@ -101,7 +101,11 @@ class PinyinInputDataset(IterableDataset):
         return bool(self.chinese_pattern.match(char))
 
     def get_next_chinese_chars(
-        self, text: str, start_idx: int, max_count: int = 3
+        self,
+        text: str,
+        start_idx: int,
+        max_count: int = 3,
+        pinyin_list: List[str] = None,
     ) -> List[Tuple[str, str]]:
         """
         获取后续的中文字符及其拼音
@@ -127,7 +131,8 @@ class PinyinInputDataset(IterableDataset):
             try:
                 # 重新计算整个text的拼音可能效率低，但确保准确
                 # 实际实现中可以考虑缓存或优化
-                pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
+                if pinyin_list is None:
+                    pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
                 if i < len(pinyin_list):
                     result.append((char, pinyin_list[i]))
                 count += 1
@@ -478,7 +483,9 @@ class PinyinInputDataset(IterableDataset):
                 continue
 
             # 获取后续最多3个中文字符的拼音
-            next_chars = self.get_next_chinese_chars(text, i, max_count=3)
+            next_chars = self.get_next_chinese_chars(
+                text, i, max_count=3, pinyin_list=pinyin_list
+            )
             next_pinyins = [py] + [p for _, p in next_chars]
             # 获取前文上下文（最多100字符）
             context = text[max(0, i - 100) : i]