feat: 优化拼音获取逻辑,添加 pinyin_list 参数提升性能
This commit is contained in:
parent
f6c58ab4c6
commit
1bdbbe284c
|
|
@ -1,14 +1,14 @@
|
|||
import torch
|
||||
from torch.utils.data import IterableDataset, DataLoader
|
||||
from datasets import load_dataset
|
||||
from pypinyin import lazy_pinyin
|
||||
import random
|
||||
from modelscope import AutoTokenizer
|
||||
from typing import Tuple, List, Dict, Any
|
||||
import re
|
||||
import numpy as np
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from loguru import logger
|
||||
from modelscope import AutoTokenizer
|
||||
from pypinyin import lazy_pinyin
|
||||
from torch.utils.data import DataLoader, IterableDataset
|
||||
|
||||
|
||||
class PinyinInputDataset(IterableDataset):
|
||||
|
|
@ -101,7 +101,11 @@ class PinyinInputDataset(IterableDataset):
|
|||
return bool(self.chinese_pattern.match(char))
|
||||
|
||||
def get_next_chinese_chars(
|
||||
self, text: str, start_idx: int, max_count: int = 3
|
||||
self,
|
||||
text: str,
|
||||
start_idx: int,
|
||||
max_count: int = 3,
|
||||
pinyin_list: List[str] = None,
|
||||
) -> List[Tuple[str, str]]:
|
||||
"""
|
||||
获取后续的中文字符及其拼音
|
||||
|
|
@ -127,6 +131,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
try:
|
||||
# 重新计算整个text的拼音可能效率低,但确保准确
|
||||
# 实际实现中可以考虑缓存或优化
|
||||
if pinyin_list is None:
|
||||
pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
|
||||
if i < len(pinyin_list):
|
||||
result.append((char, pinyin_list[i]))
|
||||
|
|
@ -478,7 +483,9 @@ class PinyinInputDataset(IterableDataset):
|
|||
continue
|
||||
|
||||
# 获取后续最多3个中文字符的拼音
|
||||
next_chars = self.get_next_chinese_chars(text, i, max_count=3)
|
||||
next_chars = self.get_next_chinese_chars(
|
||||
text, i, max_count=3, pinyin_list=pinyin_list
|
||||
)
|
||||
next_pinyins = [py] + [p for _, p in next_chars]
|
||||
# 获取前文上下文(最多100字符)
|
||||
context = text[max(0, i - 100) : i]
|
||||
|
|
|
|||
Loading…
Reference in New Issue