SUInput/src/suinput/query.py

414 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# file name: query_engine.py
import gzip
import heapq
import json
import os
import pickle
import time
import zlib
from operator import itemgetter
from typing import Dict, List, Optional, Tuple, Any

import msgpack

from .char_info import CharInfo, PinyinCharPairsCounter
class QueryEngine:
    """
    Efficient pinyin <-> character query engine.

    All indices are built once by ``load()``. Afterwards:

    1. Lookup by record ID is a single dict access (O(1)).
    2. Lookup by character or by pinyin is one dict access plus O(k) to
       materialise the k matching records. The per-key ID lists are sorted
       by count once at index time, so queries never re-sort.
    3. Character, pinyin and (character, pinyin) frequency lookups are O(1).
    4. Batch queries and a linear-scan prefix search are also provided.
    """

    # First two bytes of every gzip stream (RFC 1952). Used to decide whether
    # the payload must be decompressed before format detection.
    _GZIP_MAGIC = b"\x1f\x8b"

    # Error message raised by query methods called before load().
    _NOT_LOADED_MSG = "数据未加载请先调用load()方法"

    def __init__(self):
        """Create an empty engine; call ``load()`` before querying."""
        self._counter_data: Optional[PinyinCharPairsCounter] = None
        # Core indices
        self._id_to_info: Dict[int, CharInfo] = {}  # ID -> CharInfo
        self._char_to_ids: Dict[str, List[int]] = {}  # char -> IDs, count-descending
        self._pinyin_to_ids: Dict[str, List[int]] = {}  # pinyin -> IDs, count-descending
        # Auxiliary indices
        self._char_freq: Dict[str, int] = {}  # char -> total count (copy of counter.chars)
        self._pinyin_freq: Dict[str, int] = {}  # pinyin -> total count (copy of counter.pinyins)
        self._char_pinyin_map: Dict[Tuple[str, str], int] = {}  # (char, pinyin) -> count
        # Bookkeeping
        self._loaded = False
        self._total_pairs = 0
        self._load_time = 0.0
        self._index_time = 0.0

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------
    def load(self, file_path: str) -> Dict[str, Any]:
        """
        Load a statistics file and build all query indices.

        The format is auto-detected. Gzip compression is recognised by its
        magic bytes. Then msgpack, pickle and JSON are tried in that order.

        Security: the pickle fallback can execute arbitrary code embedded in
        the file. Only load files from a trusted source.

        Args:
            file_path: Path of the statistics file.

        Returns:
            The ``metadata`` attribute of the loaded counter.

        Raises:
            FileNotFoundError: ``file_path`` does not exist.
            ValueError: The file is corrupt or in no supported format.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"文件不存在: {file_path}")
        start_time = time.perf_counter()
        counter = self._parse_file(file_path)
        # Report "not loaded" while rebuilding. If _build_indices() fails, the
        # engine must not keep serving old indices next to new counter data.
        self._loaded = False
        self._counter_data = counter
        self._build_indices()
        self._load_time = time.perf_counter() - start_time
        self._loaded = True
        return counter.metadata

    def _parse_file(self, file_path: str) -> PinyinCharPairsCounter:
        """
        Read ``file_path`` and decode it into a ``PinyinCharPairsCounter``.

        Raises:
            ValueError: The gzip stream is corrupt, or no parser accepted the
                payload. The message lists why each parser rejected it.
        """
        with open(file_path, "rb") as f:
            data = f.read()
        # Decompress only when the gzip magic is present. A truncated or
        # corrupt gzip file is then reported as such, rather than falling
        # through to the format parsers and producing a misleading error.
        if data[:2] == self._GZIP_MAGIC:
            try:
                data = gzip.decompress(data)
            except (OSError, EOFError, zlib.error) as err:
                raise ValueError(f"gzip解压失败: {file_path}") from err
        errors: List[str] = []
        for parser_name, parser in (
            ("msgpack", self._parse_msgpack),
            ("pickle", self._parse_pickle),
            ("json", self._parse_json),
        ):
            try:
                return parser(data)
            except Exception as err:  # format sniffing: any failure -> try next format
                errors.append(f"{parser_name}: {type(err).__name__}: {err}")
        raise ValueError("无法解析文件格式 (" + "; ".join(errors) + ")")

    def _parse_msgpack(self, data: bytes) -> PinyinCharPairsCounter:
        """Decode a msgpack payload holding the counter's fields as a map."""
        # strict_map_key=False: msgpack >= 1.0 otherwise rejects non-str map
        # keys (e.g. integer record IDs in "pairs"). _dict_to_counter
        # normalises those keys with int() anyway.
        data_dict = msgpack.unpackb(data, raw=False, strict_map_key=False)
        return self._dict_to_counter(data_dict)

    def _parse_pickle(self, data: bytes) -> PinyinCharPairsCounter:
        """Unpickle a payload expected to be a ``PinyinCharPairsCounter``."""
        # SECURITY: pickle.loads can run arbitrary code contained in the
        # payload. Only trusted statistics files should ever reach this point.
        return pickle.loads(data)

    def _parse_json(self, data: bytes) -> PinyinCharPairsCounter:
        """Decode a UTF-8 JSON payload holding the counter's fields."""
        return self._dict_to_counter(json.loads(data.decode("utf-8")))

    def _dict_to_counter(self, data_dict: Dict) -> PinyinCharPairsCounter:
        """
        Build a ``PinyinCharPairsCounter`` from a decoded mapping.

        The ``pairs`` entry maps record IDs to CharInfo keyword dicts. IDs are
        coerced with ``int()`` because JSON turns integer keys into strings.
        A missing or empty ``pairs`` becomes ``{}``. The input is not mutated.
        """
        pairs = {
            int(record_id): CharInfo(**info_dict)
            for record_id, info_dict in (data_dict.get("pairs") or {}).items()
        }
        return PinyinCharPairsCounter(**{**data_dict, "pairs": pairs})

    def _build_indices(self) -> None:
        """
        Rebuild every lookup index from ``self._counter_data``.

        The new indices are assembled in locals and assigned to the instance
        only once complete, so an exception leaves the previous indices
        untouched.

        Per-character and per-pinyin ID lists are sorted here by descending
        count. The sort is stable, so ties keep insertion order. This lets
        the query methods slice instead of re-sorting on every call.
        """
        start_time = time.perf_counter()
        counter = self._counter_data
        id_to_info: Dict[int, CharInfo] = {}
        char_to_ids: Dict[str, List[int]] = {}
        pinyin_to_ids: Dict[str, List[int]] = {}
        char_pinyin_map: Dict[Tuple[str, str], int] = {}
        # `pairs` may be empty or None in a degenerate file: treat it as no records.
        for info in (counter.pairs or {}).values():
            id_to_info[info.id] = info
            char_to_ids.setdefault(info.char, []).append(info.id)
            pinyin_to_ids.setdefault(info.pinyin, []).append(info.id)
            char_pinyin_map[(info.char, info.pinyin)] = info.count

        def count_of(record_id: int) -> int:
            return id_to_info[record_id].count

        for index in (char_to_ids, pinyin_to_ids):
            for ids in index.values():
                ids.sort(key=count_of, reverse=True)

        char_freq = counter.chars.copy() if counter.chars else {}
        pinyin_freq = counter.pinyins.copy() if counter.pinyins else {}

        # Commit everything at once.
        self._id_to_info = id_to_info
        self._char_to_ids = char_to_ids
        self._pinyin_to_ids = pinyin_to_ids
        self._char_pinyin_map = char_pinyin_map
        self._char_freq = char_freq
        self._pinyin_freq = pinyin_freq
        self._total_pairs = len(id_to_info)
        self._index_time = time.perf_counter() - start_time

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _ensure_loaded(self) -> None:
        """Raise RuntimeError unless ``load()`` has completed successfully."""
        if not self._loaded:
            raise RuntimeError(self._NOT_LOADED_MSG)

    def _ranked(self, ids: List[int], field: str, limit: int) -> List[Tuple[int, str, int]]:
        """
        Turn a pre-sorted ID list into ``(id, <field>, count)`` tuples.

        ``limit <= 0`` means no limit. Slicing copies, so the index list
        itself is never modified.
        """
        if limit > 0:
            ids = ids[:limit]
        infos = self._id_to_info
        return [(record_id, getattr(infos[record_id], field), infos[record_id].count)
                for record_id in ids]

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------
    def query_by_id(self, id: int) -> Optional[CharInfo]:
        """
        Look up a record by ID (O(1)).

        Args:
            id: Record ID. The name shadows the builtin but is kept because
                it is part of the public keyword interface.

        Returns:
            The CharInfo, or None if no record has this ID.

        Raises:
            RuntimeError: ``load()`` has not been called.
        """
        self._ensure_loaded()
        return self._id_to_info.get(id)

    def query_by_char(self, char: str, limit: int = 0) -> List[Tuple[int, str, int]]:
        """
        List the pinyin readings of a character (O(1) + O(k), k = results).

        Args:
            char: The character to look up.
            limit: Maximum number of results; 0 (or negative) returns all.

        Returns:
            ``(id, pinyin, count)`` tuples sorted by count, descending.

        Raises:
            RuntimeError: ``load()`` has not been called.
        """
        self._ensure_loaded()
        return self._ranked(self._char_to_ids.get(char, []), "pinyin", limit)

    def query_by_pinyin(self, pinyin: str, limit: int = 0) -> List[Tuple[int, str, int]]:
        """
        List the characters read as ``pinyin`` (O(1) + O(k), k = results).

        Args:
            pinyin: The pinyin string to look up.
            limit: Maximum number of results; 0 (or negative) returns all.

        Returns:
            ``(id, char, count)`` tuples sorted by count, descending.

        Raises:
            RuntimeError: ``load()`` has not been called.
        """
        self._ensure_loaded()
        return self._ranked(self._pinyin_to_ids.get(pinyin, []), "char", limit)

    def get_char_frequency(self, char: str) -> int:
        """
        Return the character's total count across all its readings (O(1)).

        The value comes from the counter's ``chars`` table; 0 if unknown.

        Raises:
            RuntimeError: ``load()`` has not been called.
        """
        self._ensure_loaded()
        return self._char_freq.get(char, 0)

    def get_pinyin_frequency(self, pinyin: str) -> int:
        """
        Return the pinyin's total count across all characters (O(1)).

        The value comes from the counter's ``pinyins`` table; 0 if unknown.

        Raises:
            RuntimeError: ``load()`` has not been called.
        """
        self._ensure_loaded()
        return self._pinyin_freq.get(pinyin, 0)

    def get_char_pinyin_count(self, char: str, pinyin: str) -> int:
        """
        Return the count of one (character, pinyin) pair (O(1)); 0 if absent.

        Raises:
            RuntimeError: ``load()`` has not been called.
        """
        self._ensure_loaded()
        return self._char_pinyin_map.get((char, pinyin), 0)

    def batch_query_by_ids(self, ids: List[int]) -> Dict[int, Optional[CharInfo]]:
        """
        Look up many IDs at once (O(n)).

        Returns:
            A dict mapping each requested ID to its CharInfo, or to None when
            the ID is unknown.

        Raises:
            RuntimeError: ``load()`` has not been called.
        """
        self._ensure_loaded()
        lookup = self._id_to_info.get
        return {record_id: lookup(record_id) for record_id in ids}

    def batch_query_by_chars(self, chars: List[str], limit_per_char: int = 0) -> Dict[str, List[Tuple[int, str, int]]]:
        """
        Run ``query_by_char`` for every character in ``chars``.

        Args:
            chars: Characters to look up.
            limit_per_char: Per-character result limit (0 = all).

        Returns:
            A dict mapping each character to its ``query_by_char`` result.

        Raises:
            RuntimeError: ``load()`` has not been called.
        """
        self._ensure_loaded()
        return {char: self.query_by_char(char, limit_per_char) for char in chars}

    def search_chars_by_prefix(self, prefix: str, limit: int = 20) -> List[Tuple[str, int]]:
        """
        Find characters starting with ``prefix``.

        This is a linear scan over the character frequency table (O(n)).

        Args:
            prefix: Required prefix; an empty prefix matches everything.
            limit: Maximum number of results; 0 (or negative) returns all.

        Returns:
            ``(char, total_count)`` tuples sorted by count, descending.

        Raises:
            RuntimeError: ``load()`` has not been called.
        """
        self._ensure_loaded()
        matches = [(char, freq) for char, freq in self._char_freq.items()
                   if char.startswith(prefix)]
        matches.sort(key=itemgetter(1), reverse=True)
        return matches[:limit] if limit > 0 else matches

    # ------------------------------------------------------------------
    # Introspection / lifecycle
    # ------------------------------------------------------------------
    def get_statistics(self, top_n: int = 10) -> Dict[str, Any]:
        """
        Summarise the loaded data.

        Args:
            top_n: How many of the most frequent characters and pinyins to
                include (default 10).

        Returns:
            ``{"status": "not_loaded"}`` before a successful ``load()``.
            Otherwise, a dict with counts, timings, the top-N lists and the
            counter's metadata.
        """
        if not self._loaded:
            return {"status": "not_loaded"}
        by_count = itemgetter(1)
        counter = self._counter_data
        # heapq.nlargest is documented as equivalent to
        # sorted(..., reverse=True)[:n], without sorting the whole table.
        top_chars = heapq.nlargest(top_n, self._char_freq.items(), key=by_count)
        top_pinyins = heapq.nlargest(top_n, self._pinyin_freq.items(), key=by_count)
        return {
            "status": "loaded",
            "timestamp": counter.timestamp,
            "total_pairs": self._total_pairs,
            "total_characters": len(self._char_freq),
            "total_pinyins": len(self._pinyin_freq),
            "valid_input_character_count": counter.valid_input_character_count,
            "load_time_seconds": self._load_time,
            "index_time_seconds": self._index_time,
            "top_chars": top_chars,
            "top_pinyins": top_pinyins,
            "metadata": counter.metadata,
        }

    def is_loaded(self) -> bool:
        """Return True once ``load()`` has completed successfully."""
        return self._loaded

    def clear(self):
        """Drop the loaded data and every index, returning to the initial state."""
        self._counter_data = None
        for index in (self._id_to_info, self._char_to_ids, self._pinyin_to_ids,
                      self._char_freq, self._pinyin_freq, self._char_pinyin_map):
            index.clear()
        self._loaded = False
        self._total_pairs = 0
        self._load_time = 0.0
        self._index_time = 0.0