Compare commits
2 Commits
ea54c7da39
...
5b1c6fcb2b
| Author | SHA1 | Date |
|---|---|---|
|
|
5b1c6fcb2b | |
|
|
1cbb0b07c4 |
|
|
@ -1,8 +1,9 @@
|
|||
import random
|
||||
from typing import Any, Dict, List, Tuple
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
|
@ -12,9 +13,9 @@ from modelscope import AutoTokenizer
|
|||
from pypinyin import lazy_pinyin
|
||||
from torch.utils.data import DataLoader, IterableDataset
|
||||
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
class PinyinInputDataset(IterableDataset):
|
||||
"""
|
||||
拼音输入法模拟数据集
|
||||
|
|
@ -46,7 +47,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
repeat_end_freq: int = 10000, # 开始重复的阈值
|
||||
max_drop_prob: float = 0.8, # 最大丢弃概率
|
||||
max_repeat_expect: float = 50.0, # 最大重复期望
|
||||
py_group_json_file: Path = Path(__file__).parent.parent.parent / "py_group.json",
|
||||
py_group_json_file: Optional[Dict[str, int]] = None,
|
||||
):
|
||||
"""
|
||||
初始化数据集
|
||||
|
|
@ -97,9 +98,32 @@ class PinyinInputDataset(IterableDataset):
|
|||
# 加载数据集
|
||||
self.dataset = load_dataset(data_dir, split="train", streaming=True)
|
||||
|
||||
with open(py_group_json_file, "r") as f:
|
||||
self.py_groups = json.load(f)
|
||||
|
||||
# 加载拼音分组
|
||||
self.pg_groups = {
|
||||
"y": 0,
|
||||
"k": 0,
|
||||
"e": 0,
|
||||
"l": 1,
|
||||
"w": 1,
|
||||
"f": 1,
|
||||
"q": 2,
|
||||
"a": 2,
|
||||
"s": 2,
|
||||
"x": 3,
|
||||
"b": 3,
|
||||
"r": 3,
|
||||
"o": 4,
|
||||
"m": 4,
|
||||
"z": 4,
|
||||
"g": 5,
|
||||
"n": 5,
|
||||
"c": 5,
|
||||
"t": 6,
|
||||
"p": 6,
|
||||
"d": 6,
|
||||
"j": 7,
|
||||
"h": 7,
|
||||
}
|
||||
|
||||
def get_next_chinese_chars(
|
||||
self,
|
||||
|
|
@ -391,7 +415,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
if not char_info:
|
||||
continue
|
||||
|
||||
logger.info(f'获取字符信息: {char_info}')
|
||||
logger.info(f"获取字符信息: {char_info}")
|
||||
# 削峰填谷调整
|
||||
adjust_factor = self.adjust_frequency(char_info["freq"])
|
||||
if adjust_factor <= 0:
|
||||
|
|
@ -402,8 +426,6 @@ class PinyinInputDataset(IterableDataset):
|
|||
|
||||
# 拼音处理
|
||||
processed_pinyin = self.process_pinyin_sequence(next_pinyins)
|
||||
if not processed_pinyin:
|
||||
continue
|
||||
|
||||
# Tokenize
|
||||
hint = self.tokenizer(
|
||||
|
|
@ -423,6 +445,9 @@ class PinyinInputDataset(IterableDataset):
|
|||
"char_id": torch.tensor([char_info["id"]]),
|
||||
"char": char,
|
||||
"freq": char_info["freq"],
|
||||
"pg": torch.tensor(
|
||||
self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8
|
||||
),
|
||||
}
|
||||
|
||||
# 根据调整因子重复样本
|
||||
|
|
@ -470,7 +495,8 @@ class PinyinInputDataset(IterableDataset):
|
|||
pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
|
||||
# 批量收集需要查询的字符信息
|
||||
chinese_positions = [
|
||||
i for i, char in enumerate(text)
|
||||
i
|
||||
for i, char in enumerate(text)
|
||||
if self.query_engine.is_chinese_char(char)
|
||||
]
|
||||
|
||||
|
|
@ -518,7 +544,6 @@ class PinyinInputDataset(IterableDataset):
|
|||
)
|
||||
yield from self._shuffle_and_yield(batch_samples)
|
||||
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
由于是流式数据集,无法预先知道长度
|
||||
|
|
@ -591,7 +616,7 @@ if __name__ == "__main__":
|
|||
|
||||
# 初始化查询引擎
|
||||
query_engine = QueryEngine()
|
||||
query_engine.load("./pinyin_char_pairs_info.json")
|
||||
query_engine.load()
|
||||
|
||||
# 创建数据集
|
||||
dataset = PinyinInputDataset(
|
||||
|
|
@ -645,4 +670,3 @@ if __name__ == "__main__":
|
|||
"""
|
||||
except StopIteration:
|
||||
print("数据集为空")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,14 @@
|
|||
# file name: query_engine.py
|
||||
import json
|
||||
import pickle
|
||||
import msgpack
|
||||
import gzip
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
import time
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import msgpack
|
||||
|
||||
from .char_info import CharInfo, PinyinCharPairsCounter
|
||||
|
||||
|
|
@ -27,14 +30,14 @@ class QueryEngine:
|
|||
self._counter_data: Optional[PinyinCharPairsCounter] = None
|
||||
|
||||
# 核心索引 - 提供O(1)查询
|
||||
self._id_to_info: Dict[int, CharInfo] = {} # ID -> CharInfo
|
||||
self._char_to_ids: Dict[str, List[int]] = {} # 字符 -> ID列表
|
||||
self._pinyin_to_ids: Dict[str, List[int]] = {} # 拼音 -> ID列表
|
||||
self._id_to_info: Dict[int, CharInfo] = {} # ID -> CharInfo
|
||||
self._char_to_ids: Dict[str, List[int]] = {} # 字符 -> ID列表
|
||||
self._pinyin_to_ids: Dict[str, List[int]] = {} # 拼音 -> ID列表
|
||||
self._char_pinyin_to_ids: Dict[Tuple[str, str], int] = {}
|
||||
|
||||
# 辅助索引 - 快速获取详细信息
|
||||
self._char_freq: Dict[str, int] = {} # 字符总频率
|
||||
self._pinyin_freq: Dict[str, int] = {} # 拼音总频率
|
||||
self._char_freq: Dict[str, int] = {} # 字符总频率
|
||||
self._pinyin_freq: Dict[str, int] = {} # 拼音总频率
|
||||
self._char_pinyin_map: Dict[Tuple[str, str], int] = {} # (字符, 拼音) -> count
|
||||
|
||||
# 统计信息
|
||||
|
|
@ -45,12 +48,17 @@ class QueryEngine:
|
|||
|
||||
self.min_count = min_count
|
||||
|
||||
def load(self, file_path: str) -> Dict[str, Any]:
|
||||
def load(
|
||||
self,
|
||||
file_path: Union[str, Path] = (
|
||||
files(__package__) / "data" / "pinyin_char_statistics.json"
|
||||
),
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
加载统计结果文件
|
||||
|
||||
Args:
|
||||
file_path: 文件路径,支持msgpack/pickle/json格式,自动检测压缩
|
||||
file_path: 文件路径,文件支持msgpack/pickle/json格式,自动检测压缩
|
||||
|
||||
Returns:
|
||||
元数据字典
|
||||
|
|
@ -75,9 +83,9 @@ class QueryEngine:
|
|||
|
||||
return self._counter_data.metadata
|
||||
|
||||
def parse_file(self, file_path: str) -> PinyinCharPairsCounter:
|
||||
def parse_file(self, file_path: Union[str, Path]) -> PinyinCharPairsCounter:
|
||||
"""解析文件,支持多种格式"""
|
||||
with open(file_path, 'rb') as f:
|
||||
with open(file_path, "rb") as f:
|
||||
data = f.read()
|
||||
|
||||
# 尝试解压
|
||||
|
|
@ -88,9 +96,9 @@ class QueryEngine:
|
|||
|
||||
# 尝试不同格式
|
||||
for parser_name, parser in [
|
||||
('msgpack', self._parse_msgpack),
|
||||
('pickle', self._parse_pickle),
|
||||
('json', self._parse_json)
|
||||
("msgpack", self._parse_msgpack),
|
||||
("pickle", self._parse_pickle),
|
||||
("json", self._parse_json),
|
||||
]:
|
||||
try:
|
||||
return parser(data)
|
||||
|
|
@ -110,7 +118,7 @@ class QueryEngine:
|
|||
|
||||
def _parse_json(self, data: bytes) -> PinyinCharPairsCounter:
|
||||
"""解析json格式"""
|
||||
data_str = data.decode('utf-8')
|
||||
data_str = data.decode("utf-8")
|
||||
data_dict = json.loads(data_str)
|
||||
return self._dict_to_counter(data_dict)
|
||||
|
||||
|
|
@ -118,10 +126,10 @@ class QueryEngine:
|
|||
"""字典转PinyinCharPairsCounter"""
|
||||
# 转换CharInfo字典
|
||||
pairs_dict = {}
|
||||
if 'pairs' in data_dict and data_dict['pairs']:
|
||||
for id_str, info_dict in data_dict['pairs'].items():
|
||||
if "pairs" in data_dict and data_dict["pairs"]:
|
||||
for id_str, info_dict in data_dict["pairs"].items():
|
||||
pairs_dict[int(id_str)] = CharInfo(**info_dict)
|
||||
data_dict['pairs'] = pairs_dict
|
||||
data_dict["pairs"] = pairs_dict
|
||||
|
||||
return PinyinCharPairsCounter(**data_dict)
|
||||
|
||||
|
|
@ -222,7 +230,9 @@ class QueryEngine:
|
|||
|
||||
return results
|
||||
|
||||
def query_by_pinyin(self, pinyin: str, limit: int = 0) -> List[Tuple[int, str, int]]:
|
||||
def query_by_pinyin(
|
||||
self, pinyin: str, limit: int = 0
|
||||
) -> List[Tuple[int, str, int]]:
|
||||
"""
|
||||
通过拼音查询字符信息 - O(1) + O(k)时间复杂度
|
||||
|
||||
|
|
@ -303,7 +313,9 @@ class QueryEngine:
|
|||
|
||||
return self._char_pinyin_map.get((char, pinyin), 0)
|
||||
|
||||
def get_char_info_by_char_pinyin(self, char: str, pinyin: str) -> Optional[CharInfo]:
|
||||
def get_char_info_by_char_pinyin(
|
||||
self, char: str, pinyin: str
|
||||
) -> Optional[CharInfo]:
|
||||
"""获取特定字符-拼音对对应的ID和频率 - O(1)时间复杂度
|
||||
|
||||
Args:
|
||||
|
|
@ -319,7 +331,9 @@ class QueryEngine:
|
|||
char_info_id = self._char_pinyin_to_ids.get((char, pinyin), None)
|
||||
return self.query_by_id(char_info_id)
|
||||
|
||||
def batch_get_char_pinyin_info(self, pairs: List[Tuple[str, str]]) -> Dict[Tuple[str, str], CharInfo]:
|
||||
def batch_get_char_pinyin_info(
|
||||
self, pairs: List[Tuple[str, str]]
|
||||
) -> Dict[Tuple[str, str], CharInfo]:
|
||||
"""批量获取汉字-拼音信息
|
||||
|
||||
Args:
|
||||
|
|
@ -340,7 +354,6 @@ class QueryEngine:
|
|||
result[pair] = None
|
||||
return result
|
||||
|
||||
|
||||
def batch_query_by_ids(self, ids: List[int]) -> Dict[int, Optional[CharInfo]]:
|
||||
"""
|
||||
批量ID查询 - O(n)时间复杂度
|
||||
|
|
@ -360,7 +373,9 @@ class QueryEngine:
|
|||
|
||||
return results
|
||||
|
||||
def batch_query_by_chars(self, chars: List[str], limit_per_char: int = 0) -> Dict[str, List[Tuple[int, str, int]]]:
|
||||
def batch_query_by_chars(
|
||||
self, chars: List[str], limit_per_char: int = 0
|
||||
) -> Dict[str, List[Tuple[int, str, int]]]:
|
||||
"""
|
||||
批量字符查询
|
||||
|
||||
|
|
@ -380,7 +395,9 @@ class QueryEngine:
|
|||
|
||||
return results
|
||||
|
||||
def search_chars_by_prefix(self, prefix: str, limit: int = 20) -> List[Tuple[str, int]]:
|
||||
def search_chars_by_prefix(
|
||||
self, prefix: str, limit: int = 20
|
||||
) -> List[Tuple[str, int]]:
|
||||
"""
|
||||
根据字符前缀搜索 - O(n)时间复杂度,n为字符总数
|
||||
|
||||
|
|
@ -409,7 +426,7 @@ class QueryEngine:
|
|||
判断是否是汉字
|
||||
"""
|
||||
if not self.is_loaded():
|
||||
raise ValueError('请先调用 load() 方法加载数据')
|
||||
raise ValueError("请先调用 load() 方法加载数据")
|
||||
return char in self._char_to_ids
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
|
|
@ -422,16 +439,12 @@ class QueryEngine:
|
|||
if not self._loaded:
|
||||
return {"status": "not_loaded"}
|
||||
|
||||
top_chars = sorted(
|
||||
self._char_freq.items(),
|
||||
key=lambda x: x[1],
|
||||
reverse=True
|
||||
)[:10]
|
||||
top_chars = sorted(self._char_freq.items(), key=lambda x: x[1], reverse=True)[
|
||||
:10
|
||||
]
|
||||
|
||||
top_pinyins = sorted(
|
||||
self._pinyin_freq.items(),
|
||||
key=lambda x: x[1],
|
||||
reverse=True
|
||||
self._pinyin_freq.items(), key=lambda x: x[1], reverse=True
|
||||
)[:10]
|
||||
|
||||
return {
|
||||
|
|
@ -445,7 +458,7 @@ class QueryEngine:
|
|||
"index_time_seconds": self._index_time,
|
||||
"top_chars": top_chars,
|
||||
"top_pinyins": top_pinyins,
|
||||
"metadata": self._counter_data.metadata
|
||||
"metadata": self._counter_data.metadata,
|
||||
}
|
||||
|
||||
def is_loaded(self) -> bool:
|
||||
|
|
|
|||
Loading…
Reference in New Issue