feat: 优化拼音获取逻辑,添加 pinyin_list 参数提升性能

This commit is contained in:
songsenand 2026-02-09 10:53:27 +08:00
parent f6c58ab4c6
commit 1bdbbe284c
1 changed files with 17 additions and 10 deletions

View File

@ -1,14 +1,14 @@
import torch
from torch.utils.data import IterableDataset, DataLoader
from datasets import load_dataset
from pypinyin import lazy_pinyin
import random
from modelscope import AutoTokenizer
from typing import Tuple, List, Dict, Any
import re
import numpy as np
from typing import Any, Dict, List, Tuple
import numpy as np
import torch
from datasets import load_dataset
from loguru import logger
from modelscope import AutoTokenizer
from pypinyin import lazy_pinyin
from torch.utils.data import DataLoader, IterableDataset
class PinyinInputDataset(IterableDataset):
@ -101,7 +101,11 @@ class PinyinInputDataset(IterableDataset):
return bool(self.chinese_pattern.match(char))
def get_next_chinese_chars(
self, text: str, start_idx: int, max_count: int = 3
self,
text: str,
start_idx: int,
max_count: int = 3,
pinyin_list: List[str] = None,
) -> List[Tuple[str, str]]:
"""
获取后续的中文字符及其拼音
@ -127,6 +131,7 @@ class PinyinInputDataset(IterableDataset):
try:
# 重新计算整个text的拼音可能效率低但确保准确
# 实际实现中可以考虑缓存或优化
if pinyin_list is None:
pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
if i < len(pinyin_list):
result.append((char, pinyin_list[i]))
@ -478,7 +483,9 @@ class PinyinInputDataset(IterableDataset):
continue
# 获取后续最多3个中文字符的拼音
next_chars = self.get_next_chinese_chars(text, i, max_count=3)
next_chars = self.get_next_chinese_chars(
text, i, max_count=3, pinyin_list=pinyin_list
)
next_pinyins = [py] + [p for _, p in next_chars]
# 获取前文上下文最多100字符
context = text[max(0, i - 100) : i]