feat: 优化拼音获取逻辑,添加 pinyin_list 参数提升性能
This commit is contained in:
parent
f6c58ab4c6
commit
1bdbbe284c
|
|
@ -1,14 +1,14 @@
|
||||||
import torch
|
|
||||||
from torch.utils.data import IterableDataset, DataLoader
|
|
||||||
from datasets import load_dataset
|
|
||||||
from pypinyin import lazy_pinyin
|
|
||||||
import random
|
import random
|
||||||
from modelscope import AutoTokenizer
|
|
||||||
from typing import Tuple, List, Dict, Any
|
|
||||||
import re
|
import re
|
||||||
import numpy as np
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from datasets import load_dataset
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
from modelscope import AutoTokenizer
|
||||||
|
from pypinyin import lazy_pinyin
|
||||||
|
from torch.utils.data import DataLoader, IterableDataset
|
||||||
|
|
||||||
|
|
||||||
class PinyinInputDataset(IterableDataset):
|
class PinyinInputDataset(IterableDataset):
|
||||||
|
|
@ -101,7 +101,11 @@ class PinyinInputDataset(IterableDataset):
|
||||||
return bool(self.chinese_pattern.match(char))
|
return bool(self.chinese_pattern.match(char))
|
||||||
|
|
||||||
def get_next_chinese_chars(
|
def get_next_chinese_chars(
|
||||||
self, text: str, start_idx: int, max_count: int = 3
|
self,
|
||||||
|
text: str,
|
||||||
|
start_idx: int,
|
||||||
|
max_count: int = 3,
|
||||||
|
pinyin_list: List[str] = None,
|
||||||
) -> List[Tuple[str, str]]:
|
) -> List[Tuple[str, str]]:
|
||||||
"""
|
"""
|
||||||
获取后续的中文字符及其拼音
|
获取后续的中文字符及其拼音
|
||||||
|
|
@ -127,6 +131,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
try:
|
try:
|
||||||
# 重新计算整个text的拼音可能效率低,但确保准确
|
# 重新计算整个text的拼音可能效率低,但确保准确
|
||||||
# 实际实现中可以考虑缓存或优化
|
# 实际实现中可以考虑缓存或优化
|
||||||
|
if pinyin_list is None:
|
||||||
pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
|
pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
|
||||||
if i < len(pinyin_list):
|
if i < len(pinyin_list):
|
||||||
result.append((char, pinyin_list[i]))
|
result.append((char, pinyin_list[i]))
|
||||||
|
|
@ -478,7 +483,9 @@ class PinyinInputDataset(IterableDataset):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 获取后续最多3个中文字符的拼音
|
# 获取后续最多3个中文字符的拼音
|
||||||
next_chars = self.get_next_chinese_chars(text, i, max_count=3)
|
next_chars = self.get_next_chinese_chars(
|
||||||
|
text, i, max_count=3, pinyin_list=pinyin_list
|
||||||
|
)
|
||||||
next_pinyins = [py] + [p for _, p in next_chars]
|
next_pinyins = [py] + [p for _, p in next_chars]
|
||||||
# 获取前文上下文(最多100字符)
|
# 获取前文上下文(最多100字符)
|
||||||
context = text[max(0, i - 100) : i]
|
context = text[max(0, i - 100) : i]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue