414 lines
13 KiB
Python
414 lines
13 KiB
Python
# file name: query_engine.py
|
||
import json
|
||
import pickle
|
||
import msgpack
|
||
import gzip
|
||
from typing import Dict, List, Optional, Tuple, Any
|
||
import time
|
||
import os
|
||
|
||
from .char_info import CharInfo, PinyinCharPairsCounter
|
||
|
||
|
||
class QueryEngine:
    """
    Efficient pinyin-character query engine.

    Features:
    1. O(1) lookup by record id
    2. O(1) lookup by character
    3. O(1) lookup by pinyin
    4. Memory-friendly, efficient index construction
    5. Batch queries and prefix search
    """
|
||
|
||
def __init__(self):
|
||
"""初始化查询引擎"""
|
||
self._counter_data: Optional[PinyinCharPairsCounter] = None
|
||
|
||
# 核心索引 - 提供O(1)查询
|
||
self._id_to_info: Dict[int, CharInfo] = {} # ID -> CharInfo
|
||
self._char_to_ids: Dict[str, List[int]] = {} # 字符 -> ID列表
|
||
self._pinyin_to_ids: Dict[str, List[int]] = {} # 拼音 -> ID列表
|
||
|
||
# 辅助索引 - 快速获取详细信息
|
||
self._char_freq: Dict[str, int] = {} # 字符总频率
|
||
self._pinyin_freq: Dict[str, int] = {} # 拼音总频率
|
||
self._char_pinyin_map: Dict[Tuple[str, str], int] = {} # (字符, 拼音) -> count
|
||
|
||
# 统计信息
|
||
self._loaded = False
|
||
self._total_pairs = 0
|
||
self._load_time = 0.0
|
||
self._index_time = 0.0
|
||
|
||
def load(self, file_path: str) -> Dict[str, Any]:
|
||
"""
|
||
加载统计结果文件
|
||
|
||
Args:
|
||
file_path: 文件路径,支持msgpack/pickle/json格式,自动检测压缩
|
||
|
||
Returns:
|
||
元数据字典
|
||
|
||
Raises:
|
||
FileNotFoundError: 文件不存在
|
||
ValueError: 文件格式不支持
|
||
"""
|
||
if not os.path.exists(file_path):
|
||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||
|
||
start_time = time.time()
|
||
|
||
# 读取并解析文件
|
||
self._counter_data = self._parse_file(file_path)
|
||
|
||
# 构建索引
|
||
self._build_indices()
|
||
|
||
self._load_time = time.time() - start_time
|
||
self._loaded = True
|
||
|
||
return self._counter_data.metadata
|
||
|
||
def _parse_file(self, file_path: str) -> PinyinCharPairsCounter:
|
||
"""解析文件,支持多种格式"""
|
||
with open(file_path, 'rb') as f:
|
||
data = f.read()
|
||
|
||
# 尝试解压
|
||
try:
|
||
data = gzip.decompress(data)
|
||
except Exception:
|
||
pass
|
||
|
||
# 尝试不同格式
|
||
for parser_name, parser in [
|
||
('msgpack', self._parse_msgpack),
|
||
('pickle', self._parse_pickle),
|
||
('json', self._parse_json)
|
||
]:
|
||
try:
|
||
return parser(data)
|
||
except Exception:
|
||
continue
|
||
|
||
raise ValueError("无法解析文件格式")
|
||
|
||
def _parse_msgpack(self, data: bytes) -> PinyinCharPairsCounter:
|
||
"""解析msgpack格式"""
|
||
data_dict = msgpack.unpackb(data, raw=False)
|
||
return self._dict_to_counter(data_dict)
|
||
|
||
    def _parse_pickle(self, data: bytes) -> PinyinCharPairsCounter:
        """Deserialize pickle bytes into a PinyinCharPairsCounter.

        SECURITY NOTE: pickle.loads executes arbitrary code embedded in
        the payload — only load statistics files from trusted sources.
        """
        return pickle.loads(data)
|
||
|
||
def _parse_json(self, data: bytes) -> PinyinCharPairsCounter:
|
||
"""解析json格式"""
|
||
data_str = data.decode('utf-8')
|
||
data_dict = json.loads(data_str)
|
||
return self._dict_to_counter(data_dict)
|
||
|
||
    def _dict_to_counter(self, data_dict: Dict) -> PinyinCharPairsCounter:
        """Convert a plain dict (from msgpack/json) into a PinyinCharPairsCounter.

        The 'pairs' mapping is rebuilt with int keys and CharInfo values,
        because a msgpack/json round-trip turns them into str keys and
        plain dicts.

        NOTE(review): this mutates the caller's dict in place — fine for
        the throwaway dicts produced by the parsers, but worth confirming
        no caller reuses data_dict afterwards.
        """
        pairs_dict = {}
        if 'pairs' in data_dict and data_dict['pairs']:
            for id_str, info_dict in data_dict['pairs'].items():
                # msgpack/json keys arrive as strings; restore int ids.
                pairs_dict[int(id_str)] = CharInfo(**info_dict)
            data_dict['pairs'] = pairs_dict

        return PinyinCharPairsCounter(**data_dict)
|
||
|
||
def _build_indices(self):
|
||
"""构建所有查询索引"""
|
||
start_time = time.time()
|
||
|
||
# 重置索引
|
||
self._id_to_info.clear()
|
||
self._char_to_ids.clear()
|
||
self._pinyin_to_ids.clear()
|
||
self._char_freq.clear()
|
||
self._pinyin_freq.clear()
|
||
self._char_pinyin_map.clear()
|
||
|
||
# 复制频率数据
|
||
if self._counter_data.chars:
|
||
self._char_freq = self._counter_data.chars.copy()
|
||
if self._counter_data.pinyins:
|
||
self._pinyin_freq = self._counter_data.pinyins.copy()
|
||
|
||
# 构建核心索引
|
||
for char_info in self._counter_data.pairs.values():
|
||
char = char_info.char
|
||
pinyin = char_info.pinyin
|
||
char_info_id = char_info.id
|
||
|
||
# ID索引
|
||
self._id_to_info[char_info_id] = char_info
|
||
|
||
# 字符索引
|
||
if char not in self._char_to_ids:
|
||
self._char_to_ids[char] = []
|
||
self._char_to_ids[char].append(char_info_id)
|
||
|
||
# 拼音索引
|
||
if pinyin not in self._pinyin_to_ids:
|
||
self._pinyin_to_ids[pinyin] = []
|
||
self._pinyin_to_ids[pinyin].append(char_info_id)
|
||
|
||
# 字符-拼音映射
|
||
self._char_pinyin_map[(char, pinyin)] = char_info.count
|
||
|
||
self._total_pairs = len(self._id_to_info)
|
||
self._index_time = time.time() - start_time
|
||
|
||
def query_by_id(self, id: int) -> Optional[CharInfo]:
|
||
"""
|
||
通过ID查询字符信息 - O(1)时间复杂度
|
||
|
||
Args:
|
||
id: 记录ID
|
||
|
||
Returns:
|
||
CharInfo对象,不存在则返回None
|
||
"""
|
||
if not self._loaded:
|
||
raise RuntimeError("数据未加载,请先调用load()方法")
|
||
|
||
return self._id_to_info.get(id)
|
||
|
||
def query_by_char(self, char: str, limit: int = 0) -> List[Tuple[int, str, int]]:
|
||
"""
|
||
通过字符查询拼音信息 - O(1) + O(k)时间复杂度,k为结果数
|
||
|
||
Args:
|
||
char: 汉字字符
|
||
limit: 返回结果数量限制,0表示返回所有
|
||
|
||
Returns:
|
||
列表,每个元素为(id, 拼音, 次数),按次数降序排序
|
||
"""
|
||
if not self._loaded:
|
||
raise RuntimeError("数据未加载,请先调用load()方法")
|
||
|
||
if char not in self._char_to_ids:
|
||
return []
|
||
|
||
# 获取所有相关ID
|
||
ids = self._char_to_ids[char]
|
||
|
||
# 构建结果并排序
|
||
results = []
|
||
for char_info_id in ids:
|
||
char_info = self._id_to_info[char_info_id]
|
||
results.append((char_info_id, char_info.pinyin, char_info.count))
|
||
|
||
# 按次数降序排序
|
||
results.sort(key=lambda x: x[2], reverse=True)
|
||
|
||
# 应用限制
|
||
if limit > 0 and len(results) > limit:
|
||
results = results[:limit]
|
||
|
||
return results
|
||
|
||
def query_by_pinyin(self, pinyin: str, limit: int = 0) -> List[Tuple[int, str, int]]:
|
||
"""
|
||
通过拼音查询字符信息 - O(1) + O(k)时间复杂度
|
||
|
||
Args:
|
||
pinyin: 拼音字符串
|
||
limit: 返回结果数量限制,0表示返回所有
|
||
|
||
Returns:
|
||
列表,每个元素为(id, 字符, 次数),按次数降序排序
|
||
"""
|
||
if not self._loaded:
|
||
raise RuntimeError("数据未加载,请先调用load()方法")
|
||
|
||
if pinyin not in self._pinyin_to_ids:
|
||
return []
|
||
|
||
# 获取所有相关ID
|
||
ids = self._pinyin_to_ids[pinyin]
|
||
|
||
# 构建结果并排序
|
||
results = []
|
||
for char_info_id in ids:
|
||
char_info = self._id_to_info[char_info_id]
|
||
results.append((char_info_id, char_info.char, char_info.count))
|
||
|
||
# 按次数降序排序
|
||
results.sort(key=lambda x: x[2], reverse=True)
|
||
|
||
# 应用限制
|
||
if limit > 0 and len(results) > limit:
|
||
results = results[:limit]
|
||
|
||
return results
|
||
|
||
def get_char_frequency(self, char: str) -> int:
|
||
"""
|
||
获取字符的总出现频率(所有拼音变体之和) - O(1)时间复杂度
|
||
|
||
Args:
|
||
char: 汉字字符
|
||
|
||
Returns:
|
||
总出现次数
|
||
"""
|
||
if not self._loaded:
|
||
raise RuntimeError("数据未加载,请先调用load()方法")
|
||
|
||
return self._char_freq.get(char, 0)
|
||
|
||
def get_pinyin_frequency(self, pinyin: str) -> int:
|
||
"""
|
||
获取拼音的总出现频率(所有字符之和) - O(1)时间复杂度
|
||
|
||
Args:
|
||
pinyin: 拼音字符串
|
||
|
||
Returns:
|
||
总出现次数
|
||
"""
|
||
if not self._loaded:
|
||
raise RuntimeError("数据未加载,请先调用load()方法")
|
||
|
||
return self._pinyin_freq.get(pinyin, 0)
|
||
|
||
def get_char_pinyin_count(self, char: str, pinyin: str) -> int:
|
||
"""
|
||
获取特定字符-拼音对的出现次数 - O(1)时间复杂度
|
||
|
||
Args:
|
||
char: 汉字字符
|
||
pinyin: 拼音字符串
|
||
|
||
Returns:
|
||
出现次数
|
||
"""
|
||
if not self._loaded:
|
||
raise RuntimeError("数据未加载,请先调用load()方法")
|
||
|
||
return self._char_pinyin_map.get((char, pinyin), 0)
|
||
|
||
def batch_query_by_ids(self, ids: List[int]) -> Dict[int, Optional[CharInfo]]:
|
||
"""
|
||
批量ID查询 - O(n)时间复杂度
|
||
|
||
Args:
|
||
ids: ID列表
|
||
|
||
Returns:
|
||
字典,key为ID,value为CharInfo对象(不存在则为None)
|
||
"""
|
||
if not self._loaded:
|
||
raise RuntimeError("数据未加载,请先调用load()方法")
|
||
|
||
results = {}
|
||
for id_value in ids:
|
||
results[id_value] = self._id_to_info.get(id_value)
|
||
|
||
return results
|
||
|
||
def batch_query_by_chars(self, chars: List[str], limit_per_char: int = 0) -> Dict[str, List[Tuple[int, str, int]]]:
|
||
"""
|
||
批量字符查询
|
||
|
||
Args:
|
||
chars: 字符列表
|
||
limit_per_char: 每个字符的结果数量限制
|
||
|
||
Returns:
|
||
字典,key为字符,value为查询结果列表
|
||
"""
|
||
if not self._loaded:
|
||
raise RuntimeError("数据未加载,请先调用load()方法")
|
||
|
||
results = {}
|
||
for char in chars:
|
||
results[char] = self.query_by_char(char, limit_per_char)
|
||
|
||
return results
|
||
|
||
def search_chars_by_prefix(self, prefix: str, limit: int = 20) -> List[Tuple[str, int]]:
|
||
"""
|
||
根据字符前缀搜索 - O(n)时间复杂度,n为字符总数
|
||
|
||
Args:
|
||
prefix: 字符前缀
|
||
limit: 返回结果数量限制
|
||
|
||
Returns:
|
||
列表,每个元素为(字符, 总频率),按频率降序排序
|
||
"""
|
||
if not self._loaded:
|
||
raise RuntimeError("数据未加载,请先调用load()方法")
|
||
|
||
matches = []
|
||
for char, freq in self._char_freq.items():
|
||
if char.startswith(prefix):
|
||
matches.append((char, freq))
|
||
|
||
# 按频率降序排序
|
||
matches.sort(key=lambda x: x[1], reverse=True)
|
||
|
||
return matches[:limit] if limit > 0 else matches
|
||
|
||
def get_statistics(self) -> Dict[str, Any]:
|
||
"""
|
||
获取系统统计信息
|
||
|
||
Returns:
|
||
统计信息字典
|
||
"""
|
||
if not self._loaded:
|
||
return {"status": "not_loaded"}
|
||
|
||
top_chars = sorted(
|
||
self._char_freq.items(),
|
||
key=lambda x: x[1],
|
||
reverse=True
|
||
)[:10]
|
||
|
||
top_pinyins = sorted(
|
||
self._pinyin_freq.items(),
|
||
key=lambda x: x[1],
|
||
reverse=True
|
||
)[:10]
|
||
|
||
return {
|
||
"status": "loaded",
|
||
"timestamp": self._counter_data.timestamp,
|
||
"total_pairs": self._total_pairs,
|
||
"total_characters": len(self._char_freq),
|
||
"total_pinyins": len(self._pinyin_freq),
|
||
"valid_input_character_count": self._counter_data.valid_input_character_count,
|
||
"load_time_seconds": self._load_time,
|
||
"index_time_seconds": self._index_time,
|
||
"top_chars": top_chars,
|
||
"top_pinyins": top_pinyins,
|
||
"metadata": self._counter_data.metadata
|
||
}
|
||
|
||
    def is_loaded(self) -> bool:
        """Return True once load() has completed successfully."""
        return self._loaded
|
||
|
||
def clear(self):
|
||
"""清除所有数据和索引,释放内存"""
|
||
self._counter_data = None
|
||
self._id_to_info.clear()
|
||
self._char_to_ids.clear()
|
||
self._pinyin_to_ids.clear()
|
||
self._char_freq.clear()
|
||
self._pinyin_freq.clear()
|
||
self._char_pinyin_map.clear()
|
||
self._loaded = False
|
||
self._total_pairs = 0
|
||
self._load_time = 0.0
|
||
self._index_time = 0.0 |