Compare commits

...

2 Commits

7 changed files with 186 additions and 149 deletions

View File

@ -1,8 +1,9 @@
import random
from typing import Any, Dict, List, Tuple
import os
import json import json
import os
import random
from importlib.resources import files
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Tuple
import numpy as np import numpy as np
import torch import torch
@ -12,9 +13,9 @@ from modelscope import AutoTokenizer
from pypinyin import lazy_pinyin from pypinyin import lazy_pinyin
from torch.utils.data import DataLoader, IterableDataset from torch.utils.data import DataLoader, IterableDataset
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
class PinyinInputDataset(IterableDataset): class PinyinInputDataset(IterableDataset):
""" """
拼音输入法模拟数据集 拼音输入法模拟数据集
@ -46,7 +47,7 @@ class PinyinInputDataset(IterableDataset):
repeat_end_freq: int = 10000, # 开始重复的阈值 repeat_end_freq: int = 10000, # 开始重复的阈值
max_drop_prob: float = 0.8, # 最大丢弃概率 max_drop_prob: float = 0.8, # 最大丢弃概率
max_repeat_expect: float = 50.0, # 最大重复期望 max_repeat_expect: float = 50.0, # 最大重复期望
py_group_json_file: Path = Path(__file__).parent.parent.parent / "py_group.json", py_group_json_file: Optional[Dict[str, int]] = None,
): ):
""" """
初始化数据集 初始化数据集
@ -97,9 +98,32 @@ class PinyinInputDataset(IterableDataset):
# 加载数据集 # 加载数据集
self.dataset = load_dataset(data_dir, split="train", streaming=True) self.dataset = load_dataset(data_dir, split="train", streaming=True)
with open(py_group_json_file, "r") as f: # 加载拼音分组
self.py_groups = json.load(f) self.pg_groups = {
"y": 0,
"k": 0,
"e": 0,
"l": 1,
"w": 1,
"f": 1,
"q": 2,
"a": 2,
"s": 2,
"x": 3,
"b": 3,
"r": 3,
"o": 4,
"m": 4,
"z": 4,
"g": 5,
"n": 5,
"c": 5,
"t": 6,
"p": 6,
"d": 6,
"j": 7,
"h": 7,
}
def get_next_chinese_chars( def get_next_chinese_chars(
self, self,
@ -390,8 +414,8 @@ class PinyinInputDataset(IterableDataset):
char_info = char_info_map.get((char, pinyin)) char_info = char_info_map.get((char, pinyin))
if not char_info: if not char_info:
continue continue
logger.info(f'获取字符信息: {char_info}') logger.info(f"获取字符信息: {char_info}")
# 削峰填谷调整 # 削峰填谷调整
adjust_factor = self.adjust_frequency(char_info["freq"]) adjust_factor = self.adjust_frequency(char_info["freq"])
if adjust_factor <= 0: if adjust_factor <= 0:
@ -402,8 +426,6 @@ class PinyinInputDataset(IterableDataset):
# 拼音处理 # 拼音处理
processed_pinyin = self.process_pinyin_sequence(next_pinyins) processed_pinyin = self.process_pinyin_sequence(next_pinyins)
if not processed_pinyin:
continue
# Tokenize # Tokenize
hint = self.tokenizer( hint = self.tokenizer(
@ -423,6 +445,9 @@ class PinyinInputDataset(IterableDataset):
"char_id": torch.tensor([char_info["id"]]), "char_id": torch.tensor([char_info["id"]]),
"char": char, "char": char,
"freq": char_info["freq"], "freq": char_info["freq"],
"pg": torch.tensor(
self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8
),
} }
# 根据调整因子重复样本 # 根据调整因子重复样本
@ -436,7 +461,7 @@ class PinyinInputDataset(IterableDataset):
if not self.shuffle: if not self.shuffle:
yield from batch_samples yield from batch_samples
return return
# 使用numpy批量操作代替random.shuffle # 使用numpy批量操作代替random.shuffle
if batch_samples: if batch_samples:
indices = np.random.permutation(len(batch_samples)) indices = np.random.permutation(len(batch_samples))
@ -470,14 +495,15 @@ class PinyinInputDataset(IterableDataset):
pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x]) pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
# 批量收集需要查询的字符信息 # 批量收集需要查询的字符信息
chinese_positions = [ chinese_positions = [
i for i, char in enumerate(text) i
for i, char in enumerate(text)
if self.query_engine.is_chinese_char(char) if self.query_engine.is_chinese_char(char)
] ]
# 批量收集需要查询的字符信息 # 批量收集需要查询的字符信息
char_pinyin_batch = [] char_pinyin_batch = []
char_positions = [] char_positions = []
for i in chinese_positions: # 只遍历中文字符位置 for i in chinese_positions: # 只遍历中文字符位置
char = text[i] char = text[i]
py = pinyin_list[i] py = pinyin_list[i]
@ -518,7 +544,6 @@ class PinyinInputDataset(IterableDataset):
) )
yield from self._shuffle_and_yield(batch_samples) yield from self._shuffle_and_yield(batch_samples)
def __len__(self): def __len__(self):
""" """
由于是流式数据集无法预先知道长度 由于是流式数据集无法预先知道长度
@ -591,7 +616,7 @@ if __name__ == "__main__":
# 初始化查询引擎 # 初始化查询引擎
query_engine = QueryEngine() query_engine = QueryEngine()
query_engine.load("./pinyin_char_pairs_info.json") query_engine.load()
# 创建数据集 # 创建数据集
dataset = PinyinInputDataset( dataset = PinyinInputDataset(
@ -625,7 +650,7 @@ if __name__ == "__main__":
break break
return return
cProfile.run('profile_func(dataloader)') cProfile.run('profile_func(dataloader)')
""" """
@ -645,4 +670,3 @@ if __name__ == "__main__":
""" """
except StopIteration: except StopIteration:
print("数据集为空") print("数据集为空")

View File

@ -1,11 +1,14 @@
# file name: query_engine.py # file name: query_engine.py
import json
import pickle
import msgpack
import gzip import gzip
from typing import Dict, List, Optional, Tuple, Any import json
import time
import os import os
import pickle
import time
from importlib.resources import files
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import msgpack
from .char_info import CharInfo, PinyinCharPairsCounter from .char_info import CharInfo, PinyinCharPairsCounter
@ -13,7 +16,7 @@ from .char_info import CharInfo, PinyinCharPairsCounter
class QueryEngine: class QueryEngine:
""" """
高效拼音-字符查询引擎 高效拼音-字符查询引擎
特性: 特性:
1. O(1)时间复杂度的ID查询 1. O(1)时间复杂度的ID查询
2. O(1)时间复杂度的字符查询 2. O(1)时间复杂度的字符查询
@ -21,22 +24,22 @@ class QueryEngine:
4. 内存友好构建高效索引 4. 内存友好构建高效索引
5. 支持批量查询和前缀搜索 5. 支持批量查询和前缀搜索
""" """
def __init__(self, min_count: int = 109): def __init__(self, min_count: int = 109):
"""初始化查询引擎""" """初始化查询引擎"""
self._counter_data: Optional[PinyinCharPairsCounter] = None self._counter_data: Optional[PinyinCharPairsCounter] = None
# 核心索引 - 提供O(1)查询 # 核心索引 - 提供O(1)查询
self._id_to_info: Dict[int, CharInfo] = {} # ID -> CharInfo self._id_to_info: Dict[int, CharInfo] = {} # ID -> CharInfo
self._char_to_ids: Dict[str, List[int]] = {} # 字符 -> ID列表 self._char_to_ids: Dict[str, List[int]] = {} # 字符 -> ID列表
self._pinyin_to_ids: Dict[str, List[int]] = {} # 拼音 -> ID列表 self._pinyin_to_ids: Dict[str, List[int]] = {} # 拼音 -> ID列表
self._char_pinyin_to_ids: Dict[Tuple[str, str], int] = {} self._char_pinyin_to_ids: Dict[Tuple[str, str], int] = {}
# 辅助索引 - 快速获取详细信息 # 辅助索引 - 快速获取详细信息
self._char_freq: Dict[str, int] = {} # 字符总频率 self._char_freq: Dict[str, int] = {} # 字符总频率
self._pinyin_freq: Dict[str, int] = {} # 拼音总频率 self._pinyin_freq: Dict[str, int] = {} # 拼音总频率
self._char_pinyin_map: Dict[Tuple[str, str], int] = {} # (字符, 拼音) -> count self._char_pinyin_map: Dict[Tuple[str, str], int] = {} # (字符, 拼音) -> count
# 统计信息 # 统计信息
self._loaded = False self._loaded = False
self._total_pairs = 0 self._total_pairs = 0
@ -44,91 +47,96 @@ class QueryEngine:
self._index_time = 0.0 self._index_time = 0.0
self.min_count = min_count self.min_count = min_count
def load(self, file_path: str) -> Dict[str, Any]: def load(
self,
file_path: Union[str, Path] = (
files(__package__) / "data" / "pinyin_char_statistics.json"
),
) -> Dict[str, Any]:
""" """
加载统计结果文件 加载统计结果文件
Args: Args:
file_path: 文件路径支持msgpack/pickle/json格式自动检测压缩 file_path: 文件路径文件支持msgpack/pickle/json格式自动检测压缩
Returns: Returns:
元数据字典 元数据字典
Raises: Raises:
FileNotFoundError: 文件不存在 FileNotFoundError: 文件不存在
ValueError: 文件格式不支持 ValueError: 文件格式不支持
""" """
if not os.path.exists(file_path): if not os.path.exists(file_path):
raise FileNotFoundError(f"文件不存在: {file_path}") raise FileNotFoundError(f"文件不存在: {file_path}")
start_time = time.time() start_time = time.time()
# 读取并解析文件 # 读取并解析文件
self._counter_data = self.parse_file(file_path) self._counter_data = self.parse_file(file_path)
# 构建索引 # 构建索引
self._build_indices() self._build_indices()
self._load_time = time.time() - start_time self._load_time = time.time() - start_time
self._loaded = True self._loaded = True
return self._counter_data.metadata return self._counter_data.metadata
def parse_file(self, file_path: str) -> PinyinCharPairsCounter: def parse_file(self, file_path: Union[str, Path]) -> PinyinCharPairsCounter:
"""解析文件,支持多种格式""" """解析文件,支持多种格式"""
with open(file_path, 'rb') as f: with open(file_path, "rb") as f:
data = f.read() data = f.read()
# 尝试解压 # 尝试解压
try: try:
data = gzip.decompress(data) data = gzip.decompress(data)
except Exception: except Exception:
pass pass
# 尝试不同格式 # 尝试不同格式
for parser_name, parser in [ for parser_name, parser in [
('msgpack', self._parse_msgpack), ("msgpack", self._parse_msgpack),
('pickle', self._parse_pickle), ("pickle", self._parse_pickle),
('json', self._parse_json) ("json", self._parse_json),
]: ]:
try: try:
return parser(data) return parser(data)
except Exception: except Exception:
continue continue
raise ValueError("无法解析文件格式") raise ValueError("无法解析文件格式")
def _parse_msgpack(self, data: bytes) -> PinyinCharPairsCounter: def _parse_msgpack(self, data: bytes) -> PinyinCharPairsCounter:
"""解析msgpack格式""" """解析msgpack格式"""
data_dict = msgpack.unpackb(data, raw=False) data_dict = msgpack.unpackb(data, raw=False)
return self._dict_to_counter(data_dict) return self._dict_to_counter(data_dict)
def _parse_pickle(self, data: bytes) -> PinyinCharPairsCounter: def _parse_pickle(self, data: bytes) -> PinyinCharPairsCounter:
"""解析pickle格式""" """解析pickle格式"""
return pickle.loads(data) return pickle.loads(data)
def _parse_json(self, data: bytes) -> PinyinCharPairsCounter: def _parse_json(self, data: bytes) -> PinyinCharPairsCounter:
"""解析json格式""" """解析json格式"""
data_str = data.decode('utf-8') data_str = data.decode("utf-8")
data_dict = json.loads(data_str) data_dict = json.loads(data_str)
return self._dict_to_counter(data_dict) return self._dict_to_counter(data_dict)
def _dict_to_counter(self, data_dict: Dict) -> PinyinCharPairsCounter: def _dict_to_counter(self, data_dict: Dict) -> PinyinCharPairsCounter:
"""字典转PinyinCharPairsCounter""" """字典转PinyinCharPairsCounter"""
# 转换CharInfo字典 # 转换CharInfo字典
pairs_dict = {} pairs_dict = {}
if 'pairs' in data_dict and data_dict['pairs']: if "pairs" in data_dict and data_dict["pairs"]:
for id_str, info_dict in data_dict['pairs'].items(): for id_str, info_dict in data_dict["pairs"].items():
pairs_dict[int(id_str)] = CharInfo(**info_dict) pairs_dict[int(id_str)] = CharInfo(**info_dict)
data_dict['pairs'] = pairs_dict data_dict["pairs"] = pairs_dict
return PinyinCharPairsCounter(**data_dict) return PinyinCharPairsCounter(**data_dict)
def _build_indices(self): def _build_indices(self):
"""构建所有查询索引""" """构建所有查询索引"""
start_time = time.time() start_time = time.time()
# 重置索引 # 重置索引
self._id_to_info.clear() self._id_to_info.clear()
self._char_to_ids.clear() self._char_to_ids.clear()
@ -137,13 +145,13 @@ class QueryEngine:
self._char_freq.clear() self._char_freq.clear()
self._pinyin_freq.clear() self._pinyin_freq.clear()
self._char_pinyin_map.clear() self._char_pinyin_map.clear()
# 复制频率数据 # 复制频率数据
if self._counter_data.chars: if self._counter_data.chars:
self._char_freq = self._counter_data.chars.copy() self._char_freq = self._counter_data.chars.copy()
if self._counter_data.pinyins: if self._counter_data.pinyins:
self._pinyin_freq = self._counter_data.pinyins.copy() self._pinyin_freq = self._counter_data.pinyins.copy()
# 构建核心索引 # 构建核心索引
for char_info in self._counter_data.pairs.values(): for char_info in self._counter_data.pairs.values():
if char_info.count < self.min_count: if char_info.count < self.min_count:
@ -151,165 +159,169 @@ class QueryEngine:
char = char_info.char char = char_info.char
pinyin = char_info.pinyin pinyin = char_info.pinyin
char_info_id = char_info.id char_info_id = char_info.id
# ID索引 # ID索引
self._id_to_info[char_info_id] = char_info self._id_to_info[char_info_id] = char_info
# 字符索引 # 字符索引
if char not in self._char_to_ids: if char not in self._char_to_ids:
self._char_to_ids[char] = [] self._char_to_ids[char] = []
self._char_to_ids[char].append(char_info_id) self._char_to_ids[char].append(char_info_id)
# 拼音索引 # 拼音索引
if pinyin not in self._pinyin_to_ids: if pinyin not in self._pinyin_to_ids:
self._pinyin_to_ids[pinyin] = [] self._pinyin_to_ids[pinyin] = []
self._pinyin_to_ids[pinyin].append(char_info_id) self._pinyin_to_ids[pinyin].append(char_info_id)
# 字符-拼音映射 # 字符-拼音映射
self._char_pinyin_map[(char, pinyin)] = char_info.count self._char_pinyin_map[(char, pinyin)] = char_info.count
self._char_pinyin_to_ids[(char, pinyin)] = char_info_id self._char_pinyin_to_ids[(char, pinyin)] = char_info_id
self._total_pairs = len(self._id_to_info) self._total_pairs = len(self._id_to_info)
self._index_time = time.time() - start_time self._index_time = time.time() - start_time
def query_by_id(self, id: int) -> Optional[CharInfo]: def query_by_id(self, id: int) -> Optional[CharInfo]:
""" """
通过ID查询字符信息 - O(1)时间复杂度 通过ID查询字符信息 - O(1)时间复杂度
Args: Args:
id: 记录ID id: 记录ID
Returns: Returns:
CharInfo对象不存在则返回None CharInfo对象不存在则返回None
""" """
if not self._loaded: if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法") raise RuntimeError("数据未加载请先调用load()方法")
return self._id_to_info.get(id) return self._id_to_info.get(id)
def query_by_char(self, char: str, limit: int = 0) -> List[Tuple[int, str, int]]: def query_by_char(self, char: str, limit: int = 0) -> List[Tuple[int, str, int]]:
""" """
通过字符查询拼音信息 - O(1) + O(k)时间复杂度k为结果数 通过字符查询拼音信息 - O(1) + O(k)时间复杂度k为结果数
Args: Args:
char: 汉字字符 char: 汉字字符
limit: 返回结果数量限制0表示返回所有 limit: 返回结果数量限制0表示返回所有
Returns: Returns:
列表每个元素为(id, 拼音, 次数)按次数降序排序 列表每个元素为(id, 拼音, 次数)按次数降序排序
""" """
if not self._loaded: if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法") raise RuntimeError("数据未加载请先调用load()方法")
if char not in self._char_to_ids: if char not in self._char_to_ids:
return [] return []
# 获取所有相关ID # 获取所有相关ID
ids = self._char_to_ids[char] ids = self._char_to_ids[char]
# 构建结果并排序 # 构建结果并排序
results = [] results = []
for char_info_id in ids: for char_info_id in ids:
char_info = self._id_to_info[char_info_id] char_info = self._id_to_info[char_info_id]
results.append((char_info_id, char_info.pinyin, char_info.count)) results.append((char_info_id, char_info.pinyin, char_info.count))
# 按次数降序排序 # 按次数降序排序
results.sort(key=lambda x: x[2], reverse=True) results.sort(key=lambda x: x[2], reverse=True)
# 应用限制 # 应用限制
if limit > 0 and len(results) > limit: if limit > 0 and len(results) > limit:
results = results[:limit] results = results[:limit]
return results return results
def query_by_pinyin(self, pinyin: str, limit: int = 0) -> List[Tuple[int, str, int]]: def query_by_pinyin(
self, pinyin: str, limit: int = 0
) -> List[Tuple[int, str, int]]:
""" """
通过拼音查询字符信息 - O(1) + O(k)时间复杂度 通过拼音查询字符信息 - O(1) + O(k)时间复杂度
Args: Args:
pinyin: 拼音字符串 pinyin: 拼音字符串
limit: 返回结果数量限制0表示返回所有 limit: 返回结果数量限制0表示返回所有
Returns: Returns:
列表每个元素为(id, 字符, 次数)按次数降序排序 列表每个元素为(id, 字符, 次数)按次数降序排序
""" """
if not self._loaded: if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法") raise RuntimeError("数据未加载请先调用load()方法")
if pinyin not in self._pinyin_to_ids: if pinyin not in self._pinyin_to_ids:
return [] return []
# 获取所有相关ID # 获取所有相关ID
ids = self._pinyin_to_ids[pinyin] ids = self._pinyin_to_ids[pinyin]
# 构建结果并排序 # 构建结果并排序
results = [] results = []
for char_info_id in ids: for char_info_id in ids:
char_info = self._id_to_info[char_info_id] char_info = self._id_to_info[char_info_id]
results.append((char_info_id, char_info.char, char_info.count)) results.append((char_info_id, char_info.char, char_info.count))
# 按次数降序排序 # 按次数降序排序
results.sort(key=lambda x: x[2], reverse=True) results.sort(key=lambda x: x[2], reverse=True)
# 应用限制 # 应用限制
if limit > 0 and len(results) > limit: if limit > 0 and len(results) > limit:
results = results[:limit] results = results[:limit]
return results return results
def get_char_frequency(self, char: str) -> int: def get_char_frequency(self, char: str) -> int:
""" """
获取字符的总出现频率所有拼音变体之和 - O(1)时间复杂度 获取字符的总出现频率所有拼音变体之和 - O(1)时间复杂度
Args: Args:
char: 汉字字符 char: 汉字字符
Returns: Returns:
总出现次数 总出现次数
""" """
if not self._loaded: if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法") raise RuntimeError("数据未加载请先调用load()方法")
return self._char_freq.get(char, 0) return self._char_freq.get(char, 0)
def get_pinyin_frequency(self, pinyin: str) -> int: def get_pinyin_frequency(self, pinyin: str) -> int:
""" """
获取拼音的总出现频率所有字符之和 - O(1)时间复杂度 获取拼音的总出现频率所有字符之和 - O(1)时间复杂度
Args: Args:
pinyin: 拼音字符串 pinyin: 拼音字符串
Returns: Returns:
总出现次数 总出现次数
""" """
if not self._loaded: if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法") raise RuntimeError("数据未加载请先调用load()方法")
return self._pinyin_freq.get(pinyin, 0) return self._pinyin_freq.get(pinyin, 0)
def get_char_pinyin_count(self, char: str, pinyin: str) -> int: def get_char_pinyin_count(self, char: str, pinyin: str) -> int:
""" """
获取特定字符-拼音对的出现次数 - O(1)时间复杂度 获取特定字符-拼音对的出现次数 - O(1)时间复杂度
Args: Args:
char: 汉字字符 char: 汉字字符
pinyin: 拼音字符串 pinyin: 拼音字符串
Returns: Returns:
出现次数 出现次数
""" """
if not self._loaded: if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法") raise RuntimeError("数据未加载请先调用load()方法")
return self._char_pinyin_map.get((char, pinyin), 0) return self._char_pinyin_map.get((char, pinyin), 0)
def get_char_info_by_char_pinyin(self, char: str, pinyin: str) -> Optional[CharInfo]: def get_char_info_by_char_pinyin(
self, char: str, pinyin: str
) -> Optional[CharInfo]:
"""获取特定字符-拼音对对应的ID和频率 - O(1)时间复杂度 """获取特定字符-拼音对对应的ID和频率 - O(1)时间复杂度
Args: Args:
char: 汉字字符 char: 汉字字符
pinyin: 拼音字符串 pinyin: 拼音字符串
Returns: Returns:
ID和频率 ID和频率
""" """
@ -319,12 +331,14 @@ class QueryEngine:
char_info_id = self._char_pinyin_to_ids.get((char, pinyin), None) char_info_id = self._char_pinyin_to_ids.get((char, pinyin), None)
return self.query_by_id(char_info_id) return self.query_by_id(char_info_id)
def batch_get_char_pinyin_info(self, pairs: List[Tuple[str, str]]) -> Dict[Tuple[str, str], CharInfo]: def batch_get_char_pinyin_info(
self, pairs: List[Tuple[str, str]]
) -> Dict[Tuple[str, str], CharInfo]:
"""批量获取汉字-拼音信息 """批量获取汉字-拼音信息
Args: Args:
pairs: 汉字-拼音列表 pairs: 汉字-拼音列表
Returns: Returns:
字典key为汉字-拼音对value为CharInfo对象不存在则为None 字典key为汉字-拼音对value为CharInfo对象不存在则为None
""" """
@ -339,69 +353,72 @@ class QueryEngine:
else: else:
result[pair] = None result[pair] = None
return result return result
def batch_query_by_ids(self, ids: List[int]) -> Dict[int, Optional[CharInfo]]: def batch_query_by_ids(self, ids: List[int]) -> Dict[int, Optional[CharInfo]]:
""" """
批量ID查询 - O(n)时间复杂度 批量ID查询 - O(n)时间复杂度
Args: Args:
ids: ID列表 ids: ID列表
Returns: Returns:
字典key为IDvalue为CharInfo对象不存在则为None 字典key为IDvalue为CharInfo对象不存在则为None
""" """
if not self._loaded: if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法") raise RuntimeError("数据未加载请先调用load()方法")
results = {} results = {}
for id_value in ids: for id_value in ids:
results[id_value] = self._id_to_info.get(id_value, None) results[id_value] = self._id_to_info.get(id_value, None)
return results return results
def batch_query_by_chars(self, chars: List[str], limit_per_char: int = 0) -> Dict[str, List[Tuple[int, str, int]]]: def batch_query_by_chars(
self, chars: List[str], limit_per_char: int = 0
) -> Dict[str, List[Tuple[int, str, int]]]:
""" """
批量字符查询 批量字符查询
Args: Args:
chars: 字符列表 chars: 字符列表
limit_per_char: 每个字符的结果数量限制 limit_per_char: 每个字符的结果数量限制
Returns: Returns:
字典key为字符value为查询结果列表 字典key为字符value为查询结果列表
""" """
if not self._loaded: if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法") raise RuntimeError("数据未加载请先调用load()方法")
results = {} results = {}
for char in chars: for char in chars:
results[char] = self.query_by_char(char, limit_per_char) results[char] = self.query_by_char(char, limit_per_char)
return results return results
def search_chars_by_prefix(self, prefix: str, limit: int = 20) -> List[Tuple[str, int]]: def search_chars_by_prefix(
self, prefix: str, limit: int = 20
) -> List[Tuple[str, int]]:
""" """
根据字符前缀搜索 - O(n)时间复杂度n为字符总数 根据字符前缀搜索 - O(n)时间复杂度n为字符总数
Args: Args:
prefix: 字符前缀 prefix: 字符前缀
limit: 返回结果数量限制 limit: 返回结果数量限制
Returns: Returns:
列表每个元素为(字符, 总频率)按频率降序排序 列表每个元素为(字符, 总频率)按频率降序排序
""" """
if not self._loaded: if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法") raise RuntimeError("数据未加载请先调用load()方法")
matches = [] matches = []
for char, freq in self._char_freq.items(): for char, freq in self._char_freq.items():
if char.startswith(prefix): if char.startswith(prefix):
matches.append((char, freq)) matches.append((char, freq))
# 按频率降序排序 # 按频率降序排序
matches.sort(key=lambda x: x[1], reverse=True) matches.sort(key=lambda x: x[1], reverse=True)
return matches[:limit] if limit > 0 else matches return matches[:limit] if limit > 0 else matches
def is_chinese_char(self, char: str) -> bool: def is_chinese_char(self, char: str) -> bool:
@ -409,31 +426,27 @@ class QueryEngine:
判断是否是汉字 判断是否是汉字
""" """
if not self.is_loaded(): if not self.is_loaded():
raise ValueError('请先调用 load() 方法加载数据') raise ValueError("请先调用 load() 方法加载数据")
return char in self._char_to_ids return char in self._char_to_ids
def get_statistics(self) -> Dict[str, Any]: def get_statistics(self) -> Dict[str, Any]:
""" """
获取系统统计信息 获取系统统计信息
Returns: Returns:
统计信息字典 统计信息字典
""" """
if not self._loaded: if not self._loaded:
return {"status": "not_loaded"} return {"status": "not_loaded"}
top_chars = sorted( top_chars = sorted(self._char_freq.items(), key=lambda x: x[1], reverse=True)[
self._char_freq.items(), :10
key=lambda x: x[1], ]
reverse=True
)[:10]
top_pinyins = sorted( top_pinyins = sorted(
self._pinyin_freq.items(), self._pinyin_freq.items(), key=lambda x: x[1], reverse=True
key=lambda x: x[1],
reverse=True
)[:10] )[:10]
return { return {
"status": "loaded", "status": "loaded",
"timestamp": self._counter_data.timestamp, "timestamp": self._counter_data.timestamp,
@ -445,13 +458,13 @@ class QueryEngine:
"index_time_seconds": self._index_time, "index_time_seconds": self._index_time,
"top_chars": top_chars, "top_chars": top_chars,
"top_pinyins": top_pinyins, "top_pinyins": top_pinyins,
"metadata": self._counter_data.metadata "metadata": self._counter_data.metadata,
} }
def is_loaded(self) -> bool: def is_loaded(self) -> bool:
"""检查数据是否已加载""" """检查数据是否已加载"""
return self._loaded return self._loaded
def clear(self): def clear(self):
"""清除所有数据和索引,释放内存""" """清除所有数据和索引,释放内存"""
self._counter_data = None self._counter_data = None
@ -465,4 +478,4 @@ class QueryEngine:
self._loaded = False self._loaded = False
self._total_pairs = 0 self._total_pairs = 0
self._load_time = 0.0 self._load_time = 0.0
self._index_time = 0.0 self._index_time = 0.0