diff --git a/src/suinput/dataset.py b/src/suinput/dataset.py index 0ed0886..ff5cafb 100644 --- a/src/suinput/dataset.py +++ b/src/suinput/dataset.py @@ -1,14 +1,14 @@ -import torch -from torch.utils.data import IterableDataset, DataLoader -from datasets import load_dataset -from pypinyin import lazy_pinyin import random -from modelscope import AutoTokenizer -from typing import Tuple, List, Dict, Any import re -import numpy as np +from typing import Any, Dict, List, Tuple +import numpy as np +import torch +from datasets import load_dataset from loguru import logger +from modelscope import AutoTokenizer +from pypinyin import lazy_pinyin +from torch.utils.data import DataLoader, IterableDataset class PinyinInputDataset(IterableDataset): @@ -101,7 +101,11 @@ class PinyinInputDataset(IterableDataset): return bool(self.chinese_pattern.match(char)) def get_next_chinese_chars( - self, text: str, start_idx: int, max_count: int = 3 + self, + text: str, + start_idx: int, + max_count: int = 3, + pinyin_list: List[str] = None, ) -> List[Tuple[str, str]]: """ 获取后续的中文字符及其拼音 @@ -127,7 +131,8 @@ class PinyinInputDataset(IterableDataset): try: # 重新计算整个text的拼音可能效率低,但确保准确 # 实际实现中可以考虑缓存或优化 - pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x]) + if pinyin_list is None: + pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x]) if i < len(pinyin_list): result.append((char, pinyin_list[i])) count += 1 @@ -478,7 +483,9 @@ class PinyinInputDataset(IterableDataset): continue # 获取后续最多3个中文字符的拼音 - next_chars = self.get_next_chinese_chars(text, i, max_count=3) + next_chars = self.get_next_chinese_chars( + text, i, max_count=3, pinyin_list=pinyin_list + ) next_pinyins = [py] + [p for _, p in next_chars] # 获取前文上下文(最多100字符) context = text[max(0, i - 100) : i]