Compare commits
No commits in common. "5b1c6fcb2b0b3556d6032fb7be082ba45f991d68" and "ea54c7da393c68c1e214541216fba47a410d1b96" have entirely different histories.
5b1c6fcb2b
...
ea54c7da39
|
|
@ -1,9 +1,8 @@
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import random
|
import random
|
||||||
from importlib.resources import files
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List, Tuple
|
from typing import Any, Dict, List, Tuple
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -13,8 +12,8 @@ from modelscope import AutoTokenizer
|
||||||
from pypinyin import lazy_pinyin
|
from pypinyin import lazy_pinyin
|
||||||
from torch.utils.data import DataLoader, IterableDataset
|
from torch.utils.data import DataLoader, IterableDataset
|
||||||
|
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
||||||
|
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
class PinyinInputDataset(IterableDataset):
|
class PinyinInputDataset(IterableDataset):
|
||||||
"""
|
"""
|
||||||
|
|
@ -47,7 +46,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
repeat_end_freq: int = 10000, # 开始重复的阈值
|
repeat_end_freq: int = 10000, # 开始重复的阈值
|
||||||
max_drop_prob: float = 0.8, # 最大丢弃概率
|
max_drop_prob: float = 0.8, # 最大丢弃概率
|
||||||
max_repeat_expect: float = 50.0, # 最大重复期望
|
max_repeat_expect: float = 50.0, # 最大重复期望
|
||||||
py_group_json_file: Optional[Dict[str, int]] = None,
|
py_group_json_file: Path = Path(__file__).parent.parent.parent / "py_group.json",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
初始化数据集
|
初始化数据集
|
||||||
|
|
@ -98,32 +97,9 @@ class PinyinInputDataset(IterableDataset):
|
||||||
# 加载数据集
|
# 加载数据集
|
||||||
self.dataset = load_dataset(data_dir, split="train", streaming=True)
|
self.dataset = load_dataset(data_dir, split="train", streaming=True)
|
||||||
|
|
||||||
# 加载拼音分组
|
with open(py_group_json_file, "r") as f:
|
||||||
self.pg_groups = {
|
self.py_groups = json.load(f)
|
||||||
"y": 0,
|
|
||||||
"k": 0,
|
|
||||||
"e": 0,
|
|
||||||
"l": 1,
|
|
||||||
"w": 1,
|
|
||||||
"f": 1,
|
|
||||||
"q": 2,
|
|
||||||
"a": 2,
|
|
||||||
"s": 2,
|
|
||||||
"x": 3,
|
|
||||||
"b": 3,
|
|
||||||
"r": 3,
|
|
||||||
"o": 4,
|
|
||||||
"m": 4,
|
|
||||||
"z": 4,
|
|
||||||
"g": 5,
|
|
||||||
"n": 5,
|
|
||||||
"c": 5,
|
|
||||||
"t": 6,
|
|
||||||
"p": 6,
|
|
||||||
"d": 6,
|
|
||||||
"j": 7,
|
|
||||||
"h": 7,
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_next_chinese_chars(
|
def get_next_chinese_chars(
|
||||||
self,
|
self,
|
||||||
|
|
@ -414,8 +390,8 @@ class PinyinInputDataset(IterableDataset):
|
||||||
char_info = char_info_map.get((char, pinyin))
|
char_info = char_info_map.get((char, pinyin))
|
||||||
if not char_info:
|
if not char_info:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
logger.info(f"获取字符信息: {char_info}")
|
logger.info(f'获取字符信息: {char_info}')
|
||||||
# 削峰填谷调整
|
# 削峰填谷调整
|
||||||
adjust_factor = self.adjust_frequency(char_info["freq"])
|
adjust_factor = self.adjust_frequency(char_info["freq"])
|
||||||
if adjust_factor <= 0:
|
if adjust_factor <= 0:
|
||||||
|
|
@ -426,6 +402,8 @@ class PinyinInputDataset(IterableDataset):
|
||||||
|
|
||||||
# 拼音处理
|
# 拼音处理
|
||||||
processed_pinyin = self.process_pinyin_sequence(next_pinyins)
|
processed_pinyin = self.process_pinyin_sequence(next_pinyins)
|
||||||
|
if not processed_pinyin:
|
||||||
|
continue
|
||||||
|
|
||||||
# Tokenize
|
# Tokenize
|
||||||
hint = self.tokenizer(
|
hint = self.tokenizer(
|
||||||
|
|
@ -445,9 +423,6 @@ class PinyinInputDataset(IterableDataset):
|
||||||
"char_id": torch.tensor([char_info["id"]]),
|
"char_id": torch.tensor([char_info["id"]]),
|
||||||
"char": char,
|
"char": char,
|
||||||
"freq": char_info["freq"],
|
"freq": char_info["freq"],
|
||||||
"pg": torch.tensor(
|
|
||||||
self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# 根据调整因子重复样本
|
# 根据调整因子重复样本
|
||||||
|
|
@ -461,7 +436,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
if not self.shuffle:
|
if not self.shuffle:
|
||||||
yield from batch_samples
|
yield from batch_samples
|
||||||
return
|
return
|
||||||
|
|
||||||
# 使用numpy批量操作代替random.shuffle
|
# 使用numpy批量操作代替random.shuffle
|
||||||
if batch_samples:
|
if batch_samples:
|
||||||
indices = np.random.permutation(len(batch_samples))
|
indices = np.random.permutation(len(batch_samples))
|
||||||
|
|
@ -495,15 +470,14 @@ class PinyinInputDataset(IterableDataset):
|
||||||
pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
|
pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
|
||||||
# 批量收集需要查询的字符信息
|
# 批量收集需要查询的字符信息
|
||||||
chinese_positions = [
|
chinese_positions = [
|
||||||
i
|
i for i, char in enumerate(text)
|
||||||
for i, char in enumerate(text)
|
|
||||||
if self.query_engine.is_chinese_char(char)
|
if self.query_engine.is_chinese_char(char)
|
||||||
]
|
]
|
||||||
|
|
||||||
# 批量收集需要查询的字符信息
|
# 批量收集需要查询的字符信息
|
||||||
char_pinyin_batch = []
|
char_pinyin_batch = []
|
||||||
char_positions = []
|
char_positions = []
|
||||||
|
|
||||||
for i in chinese_positions: # 只遍历中文字符位置
|
for i in chinese_positions: # 只遍历中文字符位置
|
||||||
char = text[i]
|
char = text[i]
|
||||||
py = pinyin_list[i]
|
py = pinyin_list[i]
|
||||||
|
|
@ -544,6 +518,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
)
|
)
|
||||||
yield from self._shuffle_and_yield(batch_samples)
|
yield from self._shuffle_and_yield(batch_samples)
|
||||||
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
"""
|
"""
|
||||||
由于是流式数据集,无法预先知道长度
|
由于是流式数据集,无法预先知道长度
|
||||||
|
|
@ -616,7 +591,7 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# 初始化查询引擎
|
# 初始化查询引擎
|
||||||
query_engine = QueryEngine()
|
query_engine = QueryEngine()
|
||||||
query_engine.load()
|
query_engine.load("./pinyin_char_pairs_info.json")
|
||||||
|
|
||||||
# 创建数据集
|
# 创建数据集
|
||||||
dataset = PinyinInputDataset(
|
dataset = PinyinInputDataset(
|
||||||
|
|
@ -650,7 +625,7 @@ if __name__ == "__main__":
|
||||||
break
|
break
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
cProfile.run('profile_func(dataloader)')
|
cProfile.run('profile_func(dataloader)')
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
@ -670,3 +645,4 @@ if __name__ == "__main__":
|
||||||
"""
|
"""
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
print("数据集为空")
|
print("数据集为空")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,11 @@
|
||||||
# file name: query_engine.py
|
# file name: query_engine.py
|
||||||
import gzip
|
|
||||||
import json
|
import json
|
||||||
import os
|
|
||||||
import pickle
|
import pickle
|
||||||
import time
|
|
||||||
from importlib.resources import files
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import msgpack
|
import msgpack
|
||||||
|
import gzip
|
||||||
|
from typing import Dict, List, Optional, Tuple, Any
|
||||||
|
import time
|
||||||
|
import os
|
||||||
|
|
||||||
from .char_info import CharInfo, PinyinCharPairsCounter
|
from .char_info import CharInfo, PinyinCharPairsCounter
|
||||||
|
|
||||||
|
|
@ -16,7 +13,7 @@ from .char_info import CharInfo, PinyinCharPairsCounter
|
||||||
class QueryEngine:
|
class QueryEngine:
|
||||||
"""
|
"""
|
||||||
高效拼音-字符查询引擎
|
高效拼音-字符查询引擎
|
||||||
|
|
||||||
特性:
|
特性:
|
||||||
1. O(1)时间复杂度的ID查询
|
1. O(1)时间复杂度的ID查询
|
||||||
2. O(1)时间复杂度的字符查询
|
2. O(1)时间复杂度的字符查询
|
||||||
|
|
@ -24,22 +21,22 @@ class QueryEngine:
|
||||||
4. 内存友好,构建高效索引
|
4. 内存友好,构建高效索引
|
||||||
5. 支持批量查询和前缀搜索
|
5. 支持批量查询和前缀搜索
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, min_count: int = 109):
|
def __init__(self, min_count: int = 109):
|
||||||
"""初始化查询引擎"""
|
"""初始化查询引擎"""
|
||||||
self._counter_data: Optional[PinyinCharPairsCounter] = None
|
self._counter_data: Optional[PinyinCharPairsCounter] = None
|
||||||
|
|
||||||
# 核心索引 - 提供O(1)查询
|
# 核心索引 - 提供O(1)查询
|
||||||
self._id_to_info: Dict[int, CharInfo] = {} # ID -> CharInfo
|
self._id_to_info: Dict[int, CharInfo] = {} # ID -> CharInfo
|
||||||
self._char_to_ids: Dict[str, List[int]] = {} # 字符 -> ID列表
|
self._char_to_ids: Dict[str, List[int]] = {} # 字符 -> ID列表
|
||||||
self._pinyin_to_ids: Dict[str, List[int]] = {} # 拼音 -> ID列表
|
self._pinyin_to_ids: Dict[str, List[int]] = {} # 拼音 -> ID列表
|
||||||
self._char_pinyin_to_ids: Dict[Tuple[str, str], int] = {}
|
self._char_pinyin_to_ids: Dict[Tuple[str, str], int] = {}
|
||||||
|
|
||||||
# 辅助索引 - 快速获取详细信息
|
# 辅助索引 - 快速获取详细信息
|
||||||
self._char_freq: Dict[str, int] = {} # 字符总频率
|
self._char_freq: Dict[str, int] = {} # 字符总频率
|
||||||
self._pinyin_freq: Dict[str, int] = {} # 拼音总频率
|
self._pinyin_freq: Dict[str, int] = {} # 拼音总频率
|
||||||
self._char_pinyin_map: Dict[Tuple[str, str], int] = {} # (字符, 拼音) -> count
|
self._char_pinyin_map: Dict[Tuple[str, str], int] = {} # (字符, 拼音) -> count
|
||||||
|
|
||||||
# 统计信息
|
# 统计信息
|
||||||
self._loaded = False
|
self._loaded = False
|
||||||
self._total_pairs = 0
|
self._total_pairs = 0
|
||||||
|
|
@ -47,96 +44,91 @@ class QueryEngine:
|
||||||
self._index_time = 0.0
|
self._index_time = 0.0
|
||||||
|
|
||||||
self.min_count = min_count
|
self.min_count = min_count
|
||||||
|
|
||||||
def load(
|
def load(self, file_path: str) -> Dict[str, Any]:
|
||||||
self,
|
|
||||||
file_path: Union[str, Path] = (
|
|
||||||
files(__package__) / "data" / "pinyin_char_statistics.json"
|
|
||||||
),
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
"""
|
||||||
加载统计结果文件
|
加载统计结果文件
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_path: 文件路径,文件支持msgpack/pickle/json格式,自动检测压缩
|
file_path: 文件路径,支持msgpack/pickle/json格式,自动检测压缩
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
元数据字典
|
元数据字典
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
FileNotFoundError: 文件不存在
|
FileNotFoundError: 文件不存在
|
||||||
ValueError: 文件格式不支持
|
ValueError: 文件格式不支持
|
||||||
"""
|
"""
|
||||||
if not os.path.exists(file_path):
|
if not os.path.exists(file_path):
|
||||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# 读取并解析文件
|
# 读取并解析文件
|
||||||
self._counter_data = self.parse_file(file_path)
|
self._counter_data = self.parse_file(file_path)
|
||||||
|
|
||||||
# 构建索引
|
# 构建索引
|
||||||
self._build_indices()
|
self._build_indices()
|
||||||
|
|
||||||
self._load_time = time.time() - start_time
|
self._load_time = time.time() - start_time
|
||||||
self._loaded = True
|
self._loaded = True
|
||||||
|
|
||||||
return self._counter_data.metadata
|
return self._counter_data.metadata
|
||||||
|
|
||||||
def parse_file(self, file_path: Union[str, Path]) -> PinyinCharPairsCounter:
|
def parse_file(self, file_path: str) -> PinyinCharPairsCounter:
|
||||||
"""解析文件,支持多种格式"""
|
"""解析文件,支持多种格式"""
|
||||||
with open(file_path, "rb") as f:
|
with open(file_path, 'rb') as f:
|
||||||
data = f.read()
|
data = f.read()
|
||||||
|
|
||||||
# 尝试解压
|
# 尝试解压
|
||||||
try:
|
try:
|
||||||
data = gzip.decompress(data)
|
data = gzip.decompress(data)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# 尝试不同格式
|
# 尝试不同格式
|
||||||
for parser_name, parser in [
|
for parser_name, parser in [
|
||||||
("msgpack", self._parse_msgpack),
|
('msgpack', self._parse_msgpack),
|
||||||
("pickle", self._parse_pickle),
|
('pickle', self._parse_pickle),
|
||||||
("json", self._parse_json),
|
('json', self._parse_json)
|
||||||
]:
|
]:
|
||||||
try:
|
try:
|
||||||
return parser(data)
|
return parser(data)
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
raise ValueError("无法解析文件格式")
|
raise ValueError("无法解析文件格式")
|
||||||
|
|
||||||
def _parse_msgpack(self, data: bytes) -> PinyinCharPairsCounter:
|
def _parse_msgpack(self, data: bytes) -> PinyinCharPairsCounter:
|
||||||
"""解析msgpack格式"""
|
"""解析msgpack格式"""
|
||||||
data_dict = msgpack.unpackb(data, raw=False)
|
data_dict = msgpack.unpackb(data, raw=False)
|
||||||
return self._dict_to_counter(data_dict)
|
return self._dict_to_counter(data_dict)
|
||||||
|
|
||||||
def _parse_pickle(self, data: bytes) -> PinyinCharPairsCounter:
|
def _parse_pickle(self, data: bytes) -> PinyinCharPairsCounter:
|
||||||
"""解析pickle格式"""
|
"""解析pickle格式"""
|
||||||
return pickle.loads(data)
|
return pickle.loads(data)
|
||||||
|
|
||||||
def _parse_json(self, data: bytes) -> PinyinCharPairsCounter:
|
def _parse_json(self, data: bytes) -> PinyinCharPairsCounter:
|
||||||
"""解析json格式"""
|
"""解析json格式"""
|
||||||
data_str = data.decode("utf-8")
|
data_str = data.decode('utf-8')
|
||||||
data_dict = json.loads(data_str)
|
data_dict = json.loads(data_str)
|
||||||
return self._dict_to_counter(data_dict)
|
return self._dict_to_counter(data_dict)
|
||||||
|
|
||||||
def _dict_to_counter(self, data_dict: Dict) -> PinyinCharPairsCounter:
|
def _dict_to_counter(self, data_dict: Dict) -> PinyinCharPairsCounter:
|
||||||
"""字典转PinyinCharPairsCounter"""
|
"""字典转PinyinCharPairsCounter"""
|
||||||
# 转换CharInfo字典
|
# 转换CharInfo字典
|
||||||
pairs_dict = {}
|
pairs_dict = {}
|
||||||
if "pairs" in data_dict and data_dict["pairs"]:
|
if 'pairs' in data_dict and data_dict['pairs']:
|
||||||
for id_str, info_dict in data_dict["pairs"].items():
|
for id_str, info_dict in data_dict['pairs'].items():
|
||||||
pairs_dict[int(id_str)] = CharInfo(**info_dict)
|
pairs_dict[int(id_str)] = CharInfo(**info_dict)
|
||||||
data_dict["pairs"] = pairs_dict
|
data_dict['pairs'] = pairs_dict
|
||||||
|
|
||||||
return PinyinCharPairsCounter(**data_dict)
|
return PinyinCharPairsCounter(**data_dict)
|
||||||
|
|
||||||
def _build_indices(self):
|
def _build_indices(self):
|
||||||
"""构建所有查询索引"""
|
"""构建所有查询索引"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# 重置索引
|
# 重置索引
|
||||||
self._id_to_info.clear()
|
self._id_to_info.clear()
|
||||||
self._char_to_ids.clear()
|
self._char_to_ids.clear()
|
||||||
|
|
@ -145,13 +137,13 @@ class QueryEngine:
|
||||||
self._char_freq.clear()
|
self._char_freq.clear()
|
||||||
self._pinyin_freq.clear()
|
self._pinyin_freq.clear()
|
||||||
self._char_pinyin_map.clear()
|
self._char_pinyin_map.clear()
|
||||||
|
|
||||||
# 复制频率数据
|
# 复制频率数据
|
||||||
if self._counter_data.chars:
|
if self._counter_data.chars:
|
||||||
self._char_freq = self._counter_data.chars.copy()
|
self._char_freq = self._counter_data.chars.copy()
|
||||||
if self._counter_data.pinyins:
|
if self._counter_data.pinyins:
|
||||||
self._pinyin_freq = self._counter_data.pinyins.copy()
|
self._pinyin_freq = self._counter_data.pinyins.copy()
|
||||||
|
|
||||||
# 构建核心索引
|
# 构建核心索引
|
||||||
for char_info in self._counter_data.pairs.values():
|
for char_info in self._counter_data.pairs.values():
|
||||||
if char_info.count < self.min_count:
|
if char_info.count < self.min_count:
|
||||||
|
|
@ -159,169 +151,165 @@ class QueryEngine:
|
||||||
char = char_info.char
|
char = char_info.char
|
||||||
pinyin = char_info.pinyin
|
pinyin = char_info.pinyin
|
||||||
char_info_id = char_info.id
|
char_info_id = char_info.id
|
||||||
|
|
||||||
# ID索引
|
# ID索引
|
||||||
self._id_to_info[char_info_id] = char_info
|
self._id_to_info[char_info_id] = char_info
|
||||||
|
|
||||||
# 字符索引
|
# 字符索引
|
||||||
if char not in self._char_to_ids:
|
if char not in self._char_to_ids:
|
||||||
self._char_to_ids[char] = []
|
self._char_to_ids[char] = []
|
||||||
self._char_to_ids[char].append(char_info_id)
|
self._char_to_ids[char].append(char_info_id)
|
||||||
|
|
||||||
# 拼音索引
|
# 拼音索引
|
||||||
if pinyin not in self._pinyin_to_ids:
|
if pinyin not in self._pinyin_to_ids:
|
||||||
self._pinyin_to_ids[pinyin] = []
|
self._pinyin_to_ids[pinyin] = []
|
||||||
self._pinyin_to_ids[pinyin].append(char_info_id)
|
self._pinyin_to_ids[pinyin].append(char_info_id)
|
||||||
|
|
||||||
# 字符-拼音映射
|
# 字符-拼音映射
|
||||||
self._char_pinyin_map[(char, pinyin)] = char_info.count
|
self._char_pinyin_map[(char, pinyin)] = char_info.count
|
||||||
self._char_pinyin_to_ids[(char, pinyin)] = char_info_id
|
self._char_pinyin_to_ids[(char, pinyin)] = char_info_id
|
||||||
|
|
||||||
self._total_pairs = len(self._id_to_info)
|
self._total_pairs = len(self._id_to_info)
|
||||||
self._index_time = time.time() - start_time
|
self._index_time = time.time() - start_time
|
||||||
|
|
||||||
def query_by_id(self, id: int) -> Optional[CharInfo]:
|
def query_by_id(self, id: int) -> Optional[CharInfo]:
|
||||||
"""
|
"""
|
||||||
通过ID查询字符信息 - O(1)时间复杂度
|
通过ID查询字符信息 - O(1)时间复杂度
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
id: 记录ID
|
id: 记录ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
CharInfo对象,不存在则返回None
|
CharInfo对象,不存在则返回None
|
||||||
"""
|
"""
|
||||||
if not self._loaded:
|
if not self._loaded:
|
||||||
raise RuntimeError("数据未加载,请先调用load()方法")
|
raise RuntimeError("数据未加载,请先调用load()方法")
|
||||||
|
|
||||||
return self._id_to_info.get(id)
|
return self._id_to_info.get(id)
|
||||||
|
|
||||||
def query_by_char(self, char: str, limit: int = 0) -> List[Tuple[int, str, int]]:
|
def query_by_char(self, char: str, limit: int = 0) -> List[Tuple[int, str, int]]:
|
||||||
"""
|
"""
|
||||||
通过字符查询拼音信息 - O(1) + O(k)时间复杂度,k为结果数
|
通过字符查询拼音信息 - O(1) + O(k)时间复杂度,k为结果数
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
char: 汉字字符
|
char: 汉字字符
|
||||||
limit: 返回结果数量限制,0表示返回所有
|
limit: 返回结果数量限制,0表示返回所有
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
列表,每个元素为(id, 拼音, 次数),按次数降序排序
|
列表,每个元素为(id, 拼音, 次数),按次数降序排序
|
||||||
"""
|
"""
|
||||||
if not self._loaded:
|
if not self._loaded:
|
||||||
raise RuntimeError("数据未加载,请先调用load()方法")
|
raise RuntimeError("数据未加载,请先调用load()方法")
|
||||||
|
|
||||||
if char not in self._char_to_ids:
|
if char not in self._char_to_ids:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 获取所有相关ID
|
# 获取所有相关ID
|
||||||
ids = self._char_to_ids[char]
|
ids = self._char_to_ids[char]
|
||||||
|
|
||||||
# 构建结果并排序
|
# 构建结果并排序
|
||||||
results = []
|
results = []
|
||||||
for char_info_id in ids:
|
for char_info_id in ids:
|
||||||
char_info = self._id_to_info[char_info_id]
|
char_info = self._id_to_info[char_info_id]
|
||||||
results.append((char_info_id, char_info.pinyin, char_info.count))
|
results.append((char_info_id, char_info.pinyin, char_info.count))
|
||||||
|
|
||||||
# 按次数降序排序
|
# 按次数降序排序
|
||||||
results.sort(key=lambda x: x[2], reverse=True)
|
results.sort(key=lambda x: x[2], reverse=True)
|
||||||
|
|
||||||
# 应用限制
|
# 应用限制
|
||||||
if limit > 0 and len(results) > limit:
|
if limit > 0 and len(results) > limit:
|
||||||
results = results[:limit]
|
results = results[:limit]
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def query_by_pinyin(
|
def query_by_pinyin(self, pinyin: str, limit: int = 0) -> List[Tuple[int, str, int]]:
|
||||||
self, pinyin: str, limit: int = 0
|
|
||||||
) -> List[Tuple[int, str, int]]:
|
|
||||||
"""
|
"""
|
||||||
通过拼音查询字符信息 - O(1) + O(k)时间复杂度
|
通过拼音查询字符信息 - O(1) + O(k)时间复杂度
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pinyin: 拼音字符串
|
pinyin: 拼音字符串
|
||||||
limit: 返回结果数量限制,0表示返回所有
|
limit: 返回结果数量限制,0表示返回所有
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
列表,每个元素为(id, 字符, 次数),按次数降序排序
|
列表,每个元素为(id, 字符, 次数),按次数降序排序
|
||||||
"""
|
"""
|
||||||
if not self._loaded:
|
if not self._loaded:
|
||||||
raise RuntimeError("数据未加载,请先调用load()方法")
|
raise RuntimeError("数据未加载,请先调用load()方法")
|
||||||
|
|
||||||
if pinyin not in self._pinyin_to_ids:
|
if pinyin not in self._pinyin_to_ids:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 获取所有相关ID
|
# 获取所有相关ID
|
||||||
ids = self._pinyin_to_ids[pinyin]
|
ids = self._pinyin_to_ids[pinyin]
|
||||||
|
|
||||||
# 构建结果并排序
|
# 构建结果并排序
|
||||||
results = []
|
results = []
|
||||||
for char_info_id in ids:
|
for char_info_id in ids:
|
||||||
char_info = self._id_to_info[char_info_id]
|
char_info = self._id_to_info[char_info_id]
|
||||||
results.append((char_info_id, char_info.char, char_info.count))
|
results.append((char_info_id, char_info.char, char_info.count))
|
||||||
|
|
||||||
# 按次数降序排序
|
# 按次数降序排序
|
||||||
results.sort(key=lambda x: x[2], reverse=True)
|
results.sort(key=lambda x: x[2], reverse=True)
|
||||||
|
|
||||||
# 应用限制
|
# 应用限制
|
||||||
if limit > 0 and len(results) > limit:
|
if limit > 0 and len(results) > limit:
|
||||||
results = results[:limit]
|
results = results[:limit]
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def get_char_frequency(self, char: str) -> int:
|
def get_char_frequency(self, char: str) -> int:
|
||||||
"""
|
"""
|
||||||
获取字符的总出现频率(所有拼音变体之和) - O(1)时间复杂度
|
获取字符的总出现频率(所有拼音变体之和) - O(1)时间复杂度
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
char: 汉字字符
|
char: 汉字字符
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
总出现次数
|
总出现次数
|
||||||
"""
|
"""
|
||||||
if not self._loaded:
|
if not self._loaded:
|
||||||
raise RuntimeError("数据未加载,请先调用load()方法")
|
raise RuntimeError("数据未加载,请先调用load()方法")
|
||||||
|
|
||||||
return self._char_freq.get(char, 0)
|
return self._char_freq.get(char, 0)
|
||||||
|
|
||||||
def get_pinyin_frequency(self, pinyin: str) -> int:
|
def get_pinyin_frequency(self, pinyin: str) -> int:
|
||||||
"""
|
"""
|
||||||
获取拼音的总出现频率(所有字符之和) - O(1)时间复杂度
|
获取拼音的总出现频率(所有字符之和) - O(1)时间复杂度
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pinyin: 拼音字符串
|
pinyin: 拼音字符串
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
总出现次数
|
总出现次数
|
||||||
"""
|
"""
|
||||||
if not self._loaded:
|
if not self._loaded:
|
||||||
raise RuntimeError("数据未加载,请先调用load()方法")
|
raise RuntimeError("数据未加载,请先调用load()方法")
|
||||||
|
|
||||||
return self._pinyin_freq.get(pinyin, 0)
|
return self._pinyin_freq.get(pinyin, 0)
|
||||||
|
|
||||||
def get_char_pinyin_count(self, char: str, pinyin: str) -> int:
|
def get_char_pinyin_count(self, char: str, pinyin: str) -> int:
|
||||||
"""
|
"""
|
||||||
获取特定字符-拼音对的出现次数 - O(1)时间复杂度
|
获取特定字符-拼音对的出现次数 - O(1)时间复杂度
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
char: 汉字字符
|
char: 汉字字符
|
||||||
pinyin: 拼音字符串
|
pinyin: 拼音字符串
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
出现次数
|
出现次数
|
||||||
"""
|
"""
|
||||||
if not self._loaded:
|
if not self._loaded:
|
||||||
raise RuntimeError("数据未加载,请先调用load()方法")
|
raise RuntimeError("数据未加载,请先调用load()方法")
|
||||||
|
|
||||||
return self._char_pinyin_map.get((char, pinyin), 0)
|
return self._char_pinyin_map.get((char, pinyin), 0)
|
||||||
|
|
||||||
def get_char_info_by_char_pinyin(
|
def get_char_info_by_char_pinyin(self, char: str, pinyin: str) -> Optional[CharInfo]:
|
||||||
self, char: str, pinyin: str
|
|
||||||
) -> Optional[CharInfo]:
|
|
||||||
"""获取特定字符-拼音对对应的ID和频率 - O(1)时间复杂度
|
"""获取特定字符-拼音对对应的ID和频率 - O(1)时间复杂度
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
char: 汉字字符
|
char: 汉字字符
|
||||||
pinyin: 拼音字符串
|
pinyin: 拼音字符串
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ID和频率
|
ID和频率
|
||||||
"""
|
"""
|
||||||
|
|
@ -331,14 +319,12 @@ class QueryEngine:
|
||||||
char_info_id = self._char_pinyin_to_ids.get((char, pinyin), None)
|
char_info_id = self._char_pinyin_to_ids.get((char, pinyin), None)
|
||||||
return self.query_by_id(char_info_id)
|
return self.query_by_id(char_info_id)
|
||||||
|
|
||||||
def batch_get_char_pinyin_info(
|
def batch_get_char_pinyin_info(self, pairs: List[Tuple[str, str]]) -> Dict[Tuple[str, str], CharInfo]:
|
||||||
self, pairs: List[Tuple[str, str]]
|
|
||||||
) -> Dict[Tuple[str, str], CharInfo]:
|
|
||||||
"""批量获取汉字-拼音信息
|
"""批量获取汉字-拼音信息
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pairs: 汉字-拼音列表
|
pairs: 汉字-拼音列表
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
字典,key为汉字-拼音对,value为CharInfo对象(不存在则为None)
|
字典,key为汉字-拼音对,value为CharInfo对象(不存在则为None)
|
||||||
"""
|
"""
|
||||||
|
|
@ -353,72 +339,69 @@ class QueryEngine:
|
||||||
else:
|
else:
|
||||||
result[pair] = None
|
result[pair] = None
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def batch_query_by_ids(self, ids: List[int]) -> Dict[int, Optional[CharInfo]]:
|
def batch_query_by_ids(self, ids: List[int]) -> Dict[int, Optional[CharInfo]]:
|
||||||
"""
|
"""
|
||||||
批量ID查询 - O(n)时间复杂度
|
批量ID查询 - O(n)时间复杂度
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ids: ID列表
|
ids: ID列表
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
字典,key为ID,value为CharInfo对象(不存在则为None)
|
字典,key为ID,value为CharInfo对象(不存在则为None)
|
||||||
"""
|
"""
|
||||||
if not self._loaded:
|
if not self._loaded:
|
||||||
raise RuntimeError("数据未加载,请先调用load()方法")
|
raise RuntimeError("数据未加载,请先调用load()方法")
|
||||||
|
|
||||||
results = {}
|
results = {}
|
||||||
for id_value in ids:
|
for id_value in ids:
|
||||||
results[id_value] = self._id_to_info.get(id_value, None)
|
results[id_value] = self._id_to_info.get(id_value, None)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def batch_query_by_chars(
|
def batch_query_by_chars(self, chars: List[str], limit_per_char: int = 0) -> Dict[str, List[Tuple[int, str, int]]]:
|
||||||
self, chars: List[str], limit_per_char: int = 0
|
|
||||||
) -> Dict[str, List[Tuple[int, str, int]]]:
|
|
||||||
"""
|
"""
|
||||||
批量字符查询
|
批量字符查询
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
chars: 字符列表
|
chars: 字符列表
|
||||||
limit_per_char: 每个字符的结果数量限制
|
limit_per_char: 每个字符的结果数量限制
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
字典,key为字符,value为查询结果列表
|
字典,key为字符,value为查询结果列表
|
||||||
"""
|
"""
|
||||||
if not self._loaded:
|
if not self._loaded:
|
||||||
raise RuntimeError("数据未加载,请先调用load()方法")
|
raise RuntimeError("数据未加载,请先调用load()方法")
|
||||||
|
|
||||||
results = {}
|
results = {}
|
||||||
for char in chars:
|
for char in chars:
|
||||||
results[char] = self.query_by_char(char, limit_per_char)
|
results[char] = self.query_by_char(char, limit_per_char)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def search_chars_by_prefix(
|
def search_chars_by_prefix(self, prefix: str, limit: int = 20) -> List[Tuple[str, int]]:
|
||||||
self, prefix: str, limit: int = 20
|
|
||||||
) -> List[Tuple[str, int]]:
|
|
||||||
"""
|
"""
|
||||||
根据字符前缀搜索 - O(n)时间复杂度,n为字符总数
|
根据字符前缀搜索 - O(n)时间复杂度,n为字符总数
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prefix: 字符前缀
|
prefix: 字符前缀
|
||||||
limit: 返回结果数量限制
|
limit: 返回结果数量限制
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
列表,每个元素为(字符, 总频率),按频率降序排序
|
列表,每个元素为(字符, 总频率),按频率降序排序
|
||||||
"""
|
"""
|
||||||
if not self._loaded:
|
if not self._loaded:
|
||||||
raise RuntimeError("数据未加载,请先调用load()方法")
|
raise RuntimeError("数据未加载,请先调用load()方法")
|
||||||
|
|
||||||
matches = []
|
matches = []
|
||||||
for char, freq in self._char_freq.items():
|
for char, freq in self._char_freq.items():
|
||||||
if char.startswith(prefix):
|
if char.startswith(prefix):
|
||||||
matches.append((char, freq))
|
matches.append((char, freq))
|
||||||
|
|
||||||
# 按频率降序排序
|
# 按频率降序排序
|
||||||
matches.sort(key=lambda x: x[1], reverse=True)
|
matches.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
return matches[:limit] if limit > 0 else matches
|
return matches[:limit] if limit > 0 else matches
|
||||||
|
|
||||||
def is_chinese_char(self, char: str) -> bool:
|
def is_chinese_char(self, char: str) -> bool:
|
||||||
|
|
@ -426,27 +409,31 @@ class QueryEngine:
|
||||||
判断是否是汉字
|
判断是否是汉字
|
||||||
"""
|
"""
|
||||||
if not self.is_loaded():
|
if not self.is_loaded():
|
||||||
raise ValueError("请先调用 load() 方法加载数据")
|
raise ValueError('请先调用 load() 方法加载数据')
|
||||||
return char in self._char_to_ids
|
return char in self._char_to_ids
|
||||||
|
|
||||||
def get_statistics(self) -> Dict[str, Any]:
|
def get_statistics(self) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
获取系统统计信息
|
获取系统统计信息
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
统计信息字典
|
统计信息字典
|
||||||
"""
|
"""
|
||||||
if not self._loaded:
|
if not self._loaded:
|
||||||
return {"status": "not_loaded"}
|
return {"status": "not_loaded"}
|
||||||
|
|
||||||
top_chars = sorted(self._char_freq.items(), key=lambda x: x[1], reverse=True)[
|
top_chars = sorted(
|
||||||
:10
|
self._char_freq.items(),
|
||||||
]
|
key=lambda x: x[1],
|
||||||
|
reverse=True
|
||||||
top_pinyins = sorted(
|
|
||||||
self._pinyin_freq.items(), key=lambda x: x[1], reverse=True
|
|
||||||
)[:10]
|
)[:10]
|
||||||
|
|
||||||
|
top_pinyins = sorted(
|
||||||
|
self._pinyin_freq.items(),
|
||||||
|
key=lambda x: x[1],
|
||||||
|
reverse=True
|
||||||
|
)[:10]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "loaded",
|
"status": "loaded",
|
||||||
"timestamp": self._counter_data.timestamp,
|
"timestamp": self._counter_data.timestamp,
|
||||||
|
|
@ -458,13 +445,13 @@ class QueryEngine:
|
||||||
"index_time_seconds": self._index_time,
|
"index_time_seconds": self._index_time,
|
||||||
"top_chars": top_chars,
|
"top_chars": top_chars,
|
||||||
"top_pinyins": top_pinyins,
|
"top_pinyins": top_pinyins,
|
||||||
"metadata": self._counter_data.metadata,
|
"metadata": self._counter_data.metadata
|
||||||
}
|
}
|
||||||
|
|
||||||
def is_loaded(self) -> bool:
|
def is_loaded(self) -> bool:
|
||||||
"""检查数据是否已加载"""
|
"""检查数据是否已加载"""
|
||||||
return self._loaded
|
return self._loaded
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
"""清除所有数据和索引,释放内存"""
|
"""清除所有数据和索引,释放内存"""
|
||||||
self._counter_data = None
|
self._counter_data = None
|
||||||
|
|
@ -478,4 +465,4 @@ class QueryEngine:
|
||||||
self._loaded = False
|
self._loaded = False
|
||||||
self._total_pairs = 0
|
self._total_pairs = 0
|
||||||
self._load_time = 0.0
|
self._load_time = 0.0
|
||||||
self._index_time = 0.0
|
self._index_time = 0.0
|
||||||
Loading…
Reference in New Issue