# file name: query_engine.py
import gzip
import heapq
import json
import os
import pickle
import time
import zlib
from typing import Any, Dict, List, Optional, Tuple

import msgpack

from .char_info import CharInfo, PinyinCharPairsCounter


class QueryEngine:
    """In-memory lookup engine for pinyin/character pair statistics.

    A statistics file is parsed into a ``PinyinCharPairsCounter`` and then
    indexed into plain dictionaries so that lookups by record ID, by
    character, and by pinyin are single dictionary accesses. Aggregate
    frequencies and per-(character, pinyin) counts are kept in auxiliary
    dictionaries. Batch queries and a prefix scan over the known characters
    are also provided.
    """

    # First two bytes of every gzip stream; used to detect compression.
    _GZIP_MAGIC = b"\x1f\x8b"

    # Message raised by every query method when no data has been loaded.
    _NOT_LOADED_MESSAGE = "数据未加载,请先调用load()方法"

    def __init__(self):
        """Create an empty engine; call ``load()`` before querying."""
        self._counter_data: Optional[PinyinCharPairsCounter] = None

        # Core indexes.
        self._id_to_info: Dict[int, CharInfo] = {}       # record ID -> CharInfo
        self._char_to_ids: Dict[str, List[int]] = {}     # character -> record IDs
        self._pinyin_to_ids: Dict[str, List[int]] = {}   # pinyin -> record IDs

        # Auxiliary indexes.
        self._char_freq: Dict[str, int] = {}             # character -> total count
        self._pinyin_freq: Dict[str, int] = {}           # pinyin -> total count
        self._char_pinyin_map: Dict[Tuple[str, str], int] = {}  # (char, pinyin) -> count

        # Bookkeeping.
        self._loaded = False
        self._total_pairs = 0
        self._load_time = 0.0
        self._index_time = 0.0

    def load(self, file_path: str) -> Dict[str, Any]:
        """Load a statistics file and build all indexes.

        Args:
            file_path: Path to a msgpack, pickle, or JSON file, optionally
                gzip-compressed. The format is detected from the content.

        Returns:
            The ``metadata`` attribute of the parsed counter.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
            ValueError: If none of the supported parsers accepts the file.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"文件不存在: {file_path}")

        # perf_counter is monotonic, so the measured duration cannot go
        # negative if the system clock is adjusted mid-load.
        started = time.perf_counter()

        # Parse before touching any state: a parse failure leaves whatever
        # was loaded previously fully intact.
        counter = self._parse_file(file_path)

        self._loaded = False
        self._counter_data = counter
        try:
            self._build_indices()
        except Exception:
            # Never expose half-built indexes as "loaded".
            self.clear()
            raise

        self._load_time = time.perf_counter() - started
        self._loaded = True
        return self._counter_data.metadata

    def _parse_file(self, file_path: str) -> PinyinCharPairsCounter:
        """Read ``file_path`` and decode it with the first parser that works.

        Parsers are tried in the order msgpack, pickle, JSON.

        Raises:
            ValueError: If every parser fails; the last parser error is
                attached as ``__cause__``.
        """
        with open(file_path, 'rb') as f:
            data = f.read()

        # Only attempt decompression when the gzip magic bytes are present.
        if data.startswith(self._GZIP_MAGIC):
            try:
                data = gzip.decompress(data)
            except (OSError, EOFError, zlib.error):
                # Best effort: fall through with the raw bytes and let the
                # parsers decide whether the content is usable.
                pass

        parsers = (self._parse_msgpack, self._parse_pickle, self._parse_json)
        last_error: Optional[Exception] = None
        for parser in parsers:
            try:
                return parser(data)
            except Exception as exc:
                # Each format raises its own family of errors on foreign
                # input, so a broad catch is intentional at this boundary.
                last_error = exc

        raise ValueError("无法解析文件格式") from last_error

    def _parse_msgpack(self, data: bytes) -> PinyinCharPairsCounter:
        """Decode msgpack bytes into a counter."""
        return self._dict_to_counter(msgpack.unpackb(data, raw=False))

    def _parse_pickle(self, data: bytes) -> PinyinCharPairsCounter:
        """Decode pickle bytes into a counter.

        Raises:
            TypeError: If the pickled object is not a
                ``PinyinCharPairsCounter``.
        """
        # SECURITY: pickle.loads can execute arbitrary code embedded in the
        # payload. Only load statistics files from trusted sources.
        obj = pickle.loads(data)
        if not isinstance(obj, PinyinCharPairsCounter):
            raise TypeError(
                f"unexpected pickled object type: {type(obj).__name__}"
            )
        return obj

    def _parse_json(self, data: bytes) -> PinyinCharPairsCounter:
        """Decode UTF-8 JSON bytes into a counter."""
        return self._dict_to_counter(json.loads(data.decode('utf-8')))

    def _dict_to_counter(self, data_dict: Dict) -> PinyinCharPairsCounter:
        """Build a ``PinyinCharPairsCounter`` from its dictionary form.

        Entries under ``'pairs'`` are converted to ``CharInfo`` objects keyed
        by integer ID. The input dictionary is not modified.
        """
        fields = dict(data_dict)  # shallow copy so the caller's dict is untouched
        raw_pairs = fields.get('pairs')
        if raw_pairs:
            fields['pairs'] = {
                int(pair_id): CharInfo(**info)
                for pair_id, info in raw_pairs.items()
            }
        return PinyinCharPairsCounter(**fields)

    def _build_indices(self):
        """Rebuild every lookup index from ``self._counter_data``."""
        started = time.perf_counter()
        counter = self._counter_data

        # Reset all indexes.
        self._id_to_info.clear()
        self._char_to_ids.clear()
        self._pinyin_to_ids.clear()
        self._char_freq.clear()
        self._pinyin_freq.clear()
        self._char_pinyin_map.clear()

        # Copy the aggregate frequency tables when present.
        if counter.chars:
            self._char_freq = counter.chars.copy()
        if counter.pinyins:
            self._pinyin_freq = counter.pinyins.copy()

        # A counter without pairs yields empty indexes instead of crashing.
        pairs = counter.pairs or {}
        for info in pairs.values():
            pair_id = info.id
            self._id_to_info[pair_id] = info
            self._char_to_ids.setdefault(info.char, []).append(pair_id)
            self._pinyin_to_ids.setdefault(info.pinyin, []).append(pair_id)
            self._char_pinyin_map[(info.char, info.pinyin)] = info.count

        self._total_pairs = len(self._id_to_info)
        self._index_time = time.perf_counter() - started

    def _ensure_loaded(self) -> None:
        """Raise ``RuntimeError`` unless data has been loaded."""
        if not self._loaded:
            raise RuntimeError(self._NOT_LOADED_MESSAGE)

    def _ranked_matches(self, ids: List[int], field: str,
                        limit: int) -> List[Tuple[int, str, int]]:
        """Return ``(id, <field>, count)`` rows for ``ids``, highest count first.

        ``field`` names the ``CharInfo`` attribute placed in the middle slot.
        A ``limit`` of 0 or less returns every row.
        """
        rows = []
        for pair_id in ids:
            info = self._id_to_info[pair_id]
            rows.append((pair_id, getattr(info, field), info.count))
        rows.sort(key=lambda row: row[2], reverse=True)
        return rows[:limit] if limit > 0 else rows

    # NOTE: the parameter name ``id`` shadows the builtin; it is kept because
    # callers may pass it by keyword.
    def query_by_id(self, id: int) -> Optional[CharInfo]:
        """Look up a record by ID.

        Args:
            id: Record ID.

        Returns:
            The matching ``CharInfo``, or ``None`` if the ID is unknown.

        Raises:
            RuntimeError: If no data has been loaded.
        """
        self._ensure_loaded()
        return self._id_to_info.get(id)

    def query_by_char(self, char: str, limit: int = 0) -> List[Tuple[int, str, int]]:
        """List the pinyin readings recorded for a character.

        Args:
            char: Character to look up.
            limit: Maximum number of results; 0 returns all.

        Returns:
            ``(id, pinyin, count)`` tuples sorted by count, descending.
            Empty if the character is unknown.

        Raises:
            RuntimeError: If no data has been loaded.
        """
        self._ensure_loaded()
        ids = self._char_to_ids.get(char)
        if ids is None:
            return []
        return self._ranked_matches(ids, 'pinyin', limit)

    def query_by_pinyin(self, pinyin: str, limit: int = 0) -> List[Tuple[int, str, int]]:
        """List the characters recorded for a pinyin.

        Args:
            pinyin: Pinyin string to look up.
            limit: Maximum number of results; 0 returns all.

        Returns:
            ``(id, character, count)`` tuples sorted by count, descending.
            Empty if the pinyin is unknown.

        Raises:
            RuntimeError: If no data has been loaded.
        """
        self._ensure_loaded()
        ids = self._pinyin_to_ids.get(pinyin)
        if ids is None:
            return []
        return self._ranked_matches(ids, 'char', limit)

    def get_char_frequency(self, char: str) -> int:
        """Return the total count stored for ``char`` (0 if unknown).

        Raises:
            RuntimeError: If no data has been loaded.
        """
        self._ensure_loaded()
        return self._char_freq.get(char, 0)

    def get_pinyin_frequency(self, pinyin: str) -> int:
        """Return the total count stored for ``pinyin`` (0 if unknown).

        Raises:
            RuntimeError: If no data has been loaded.
        """
        self._ensure_loaded()
        return self._pinyin_freq.get(pinyin, 0)

    def get_char_pinyin_count(self, char: str, pinyin: str) -> int:
        """Return the count of one (character, pinyin) pair (0 if unknown).

        Raises:
            RuntimeError: If no data has been loaded.
        """
        self._ensure_loaded()
        return self._char_pinyin_map.get((char, pinyin), 0)

    def batch_query_by_ids(self, ids: List[int]) -> Dict[int, Optional[CharInfo]]:
        """Look up several record IDs at once.

        Args:
            ids: Record IDs to look up.

        Returns:
            Mapping of each requested ID to its ``CharInfo``, or ``None``
            when the ID is unknown.

        Raises:
            RuntimeError: If no data has been loaded.
        """
        self._ensure_loaded()
        return {pair_id: self._id_to_info.get(pair_id) for pair_id in ids}

    def batch_query_by_chars(self, chars: List[str],
                             limit_per_char: int = 0) -> Dict[str, List[Tuple[int, str, int]]]:
        """Run ``query_by_char`` for several characters.

        Args:
            chars: Characters to look up.
            limit_per_char: Per-character result limit; 0 returns all.

        Returns:
            Mapping of each character to its ``query_by_char`` result.

        Raises:
            RuntimeError: If no data has been loaded.
        """
        self._ensure_loaded()
        return {char: self.query_by_char(char, limit_per_char) for char in chars}

    def search_chars_by_prefix(self, prefix: str, limit: int = 20) -> List[Tuple[str, int]]:
        """Find characters in the frequency table that start with ``prefix``.

        This is a linear scan over every key in the character frequency
        table.

        Args:
            prefix: Prefix to match.
            limit: Maximum number of results; 0 or less returns all.

        Returns:
            ``(character, total count)`` tuples sorted by count, descending.

        Raises:
            RuntimeError: If no data has been loaded.
        """
        self._ensure_loaded()
        matches = [
            (char, freq)
            for char, freq in self._char_freq.items()
            if char.startswith(prefix)
        ]
        matches.sort(key=lambda item: item[1], reverse=True)
        return matches[:limit] if limit > 0 else matches

    def get_statistics(self) -> Dict[str, Any]:
        """Summarise the loaded data set.

        Returns:
            ``{"status": "not_loaded"}`` when nothing is loaded; otherwise a
            dictionary with counts, timings, the ten most frequent
            characters and pinyins, and the counter's metadata.
        """
        if not self._loaded:
            return {"status": "not_loaded"}

        # nlargest(k, ...) is equivalent to sorted(..., reverse=True)[:k]
        # without sorting the whole table.
        top_chars = heapq.nlargest(
            10, self._char_freq.items(), key=lambda item: item[1]
        )
        top_pinyins = heapq.nlargest(
            10, self._pinyin_freq.items(), key=lambda item: item[1]
        )

        return {
            "status": "loaded",
            "timestamp": self._counter_data.timestamp,
            "total_pairs": self._total_pairs,
            "total_characters": len(self._char_freq),
            "total_pinyins": len(self._pinyin_freq),
            "valid_input_character_count": self._counter_data.valid_input_character_count,
            "load_time_seconds": self._load_time,
            "index_time_seconds": self._index_time,
            "top_chars": top_chars,
            "top_pinyins": top_pinyins,
            "metadata": self._counter_data.metadata,
        }

    def is_loaded(self) -> bool:
        """Return True when data has been loaded and indexed."""
        return self._loaded

    def clear(self):
        """Drop the loaded data and all indexes, and reset the statistics."""
        self._counter_data = None
        self._id_to_info.clear()
        self._char_to_ids.clear()
        self._pinyin_to_ids.clear()
        self._char_freq.clear()
        self._pinyin_freq.clear()
        self._char_pinyin_map.clear()
        self._loaded = False
        self._total_pairs = 0
        self._load_time = 0.0
        self._index_time = 0.0