From c22313748cc416321bf36cf57bf5e41da3eb80c8 Mon Sep 17 00:00:00 2001 From: songsenand Date: Mon, 2 Feb 2026 07:17:29 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E6=B1=89=E5=AD=97?= =?UTF-8?q?=E6=8B=BC=E9=9F=B3=E7=BB=9F=E8=AE=A1=E5=B7=A5=E5=85=B7=EF=BC=8C?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E5=A4=9A=E7=BA=BF=E7=A8=8B=E5=A4=84=E7=90=86?= =?UTF-8?q?=E4=B8=8E=E5=A4=9A=E7=A7=8D=E6=A0=BC=E5=BC=8F=E5=AF=BC=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 5 + .python-version | 1 + pyproject.toml | 21 ++ src/suinput/__init__.py | 0 src/suinput/__main__.py | 52 +++++ src/suinput/char_info.py | 28 +++ src/suinput/counter.py | 345 ++++++++++++++++++++++++++++++++ src/suinput/query.py | 414 +++++++++++++++++++++++++++++++++++++++ 8 files changed, 866 insertions(+) create mode 100644 .python-version create mode 100644 pyproject.toml create mode 100644 src/suinput/__init__.py create mode 100644 src/suinput/__main__.py create mode 100644 src/suinput/char_info.py create mode 100644 src/suinput/counter.py create mode 100644 src/suinput/query.py diff --git a/.gitignore b/.gitignore index 1c2eeeb..1d9903f 100644 --- a/.gitignore +++ b/.gitignore @@ -210,3 +210,8 @@ cython_debug/ *.out *.app +uv.lock + +*.json +*.log +marimo/ \ No newline at end of file diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..24ee5b1 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.13 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..6478c70 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,21 @@ +[project] +name = "suinput" +version = "0.1.0" +description = "Add your description here" +readme = "README.md" +requires-python = ">=3.13" +dependencies = [ + "datasets>=4.5.0", + "loguru>=0.7.3", + "modelscope>=1.34.0", + "msgpack>=1.1.2", + "pypinyin>=0.55.0", + "rich>=14.3.1", + "typer>=0.21.1", +] + +[dependency-groups] +dev = [ + "pytest>=9.0.2", +] + diff --git 
"""Command-line entry point for ``python -m suinput``.

Exposes a single Typer command that counts Chinese characters and their
pinyin readings in a dataset and exports the resulting statistics.
"""
from multiprocessing import cpu_count

import typer
from rich.console import Console
from rich.table import Table

from .counter import counter_dataset

app = typer.Typer()
console = Console()

# Leave one core for the parent process, but never default to zero workers:
# ``cpu_count() - 1`` evaluates to 0 on a single-core machine.
DEFAULT_NUM_WORKERS = max(1, cpu_count() - 1)


@app.command()
def app_counter_dataset(
    dataset_path: str = typer.Option(..., help="数据集路径"),
    num_workers: int = typer.Option(DEFAULT_NUM_WORKERS, min=1, help="工作线程数"),
    output_path: str = typer.Option(None, help="输出文件路径(可选)"),
    format: str = typer.Option("msgpack", help="输出格式:json/msgpack/pickle"),
    compress: bool = typer.Option(False, help="是否压缩输出文件"),
    max_queue_size: int = typer.Option(5000, help="队列最大容量"),
    max_lines: int = typer.Option(10000, help="处理的最大行数"),
):
    """
    统计数据集中汉字和拼音的频率,并导出结果。
    """
    # NOTE: the docstring above is user-facing: Typer prints it as the command's
    # ``--help`` text, so it stays in the same language as the option help.
    #
    # Parameters (all supplied as CLI options):
    #   dataset_path   -- dataset identifier/path handed to ``counter_dataset``.
    #   num_workers    -- number of worker processes (at least 1).
    #   output_path    -- destination file; when omitted only a summary is shown.
    #   format         -- serialization format: json, msgpack or pickle.
    #   compress       -- gzip the serialized output when true.
    #   max_queue_size -- capacity of the task queue feeding the workers.
    #   max_lines      -- maximum number of dataset rows to process.
    console.print("[bold green]🚀 开始处理数据集...[/bold green]")

    # Echo the effective configuration before the (potentially long) run.
    table = Table(title="📊 参数配置")
    table.add_column("参数", style="cyan")
    table.add_column("值", style="magenta")
    rows = (
        ("数据集路径", dataset_path),
        ("工作线程数", str(num_workers)),
        ("输出路径", output_path or "默认输出"),
        ("输出格式", format),
        ("是否压缩", str(compress)),
        ("队列容量", str(max_queue_size)),
        ("最大行数", str(max_lines)),
    )
    for label, value in rows:
        table.add_row(label, value)
    console.print(table)

    counter_dataset(
        dataset_path=dataset_path,
        num_workers=num_workers,
        output_path=output_path,
        format=format,
        compress=compress,
        max_queue_size=max_queue_size,
        max_lines=max_lines,
    )

    console.print("[bold green]✅ 数据处理完成![/bold green]")


# Guard the entry point: with the ``spawn`` start method, multiprocessing
# re-imports the main module inside every worker process, and an unguarded
# ``app()`` would re-run the whole CLI in each child.
if __name__ == "__main__":
    app()
# NOTE: the definitions below belong to src/suinput/char_info.py (the two
# dataclasses) and src/suinput/counter.py (everything else).  They rely on the
# names those modules import at file level: dataclass / field / asdict,
# datetime, Dict / Any, json, pickle, msgpack, Process / Lock / Queue /
# cpu_count, Counter, time, queue, datasets.load_dataset,
# pypinyin.lazy_pinyin, loguru.logger and tqdm.trange.


@dataclass
class CharInfo:
    """One (character, pinyin) pair together with its occurrence count."""

    id: int = 0  # 1-based rank assigned by ``PinyinCharStatistics.export``
    char: str = ""
    pinyin: str = ""
    count: int = 0


@dataclass
class PinyinCharPairsCounter:
    """Exportable snapshot of character / pinyin statistics."""

    timestamp: str = ""
    total_characters: int = 0
    total_pinyins: int = 0
    valid_input_character_count: int = 0
    pairs: Dict[int, CharInfo] = field(default_factory=dict)
    chars: Dict[str, int] = field(default_factory=dict)
    pinyins: Dict[str, int] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # Only stamp fresh snapshots.  Unconditionally overwriting the field
        # would discard the timestamp of a snapshot rebuilt from stored data.
        if not self.timestamp:
            self.timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")


class PinyinCharStatistics:
    """Count Chinese characters, pinyin readings and (char, pinyin) pairs.

    In single-process mode texts are counted synchronously in the caller.
    In multi-process mode (``is_multithreaded=True``) texts are pushed onto a
    bounded task queue, counted by worker processes with private counters, and
    merged into this object when the results are collected.
    """

    MIN_CHINESE_RATIO = 0.6      # texts with a smaller CJK share are skipped
    RESULT_TIMEOUT_SECONDS = 15  # wait per attempt for one worker result
    RESULT_ROUNDS = 3            # attempts allowed per worker
    JOIN_TIMEOUT_SECONDS = 5     # grace period before a worker is terminated

    def __init__(
        self, is_multithreaded=False, num_workers: int = 8, max_queue_size: int = 5000
    ):
        """Create the counters and, in multi-process mode, start the workers.

        Args:
            is_multithreaded: when true, count in worker processes.
            num_workers: number of worker processes to start.
            max_queue_size: capacity of the task queue.
        """
        self.max_queue_size = max_queue_size
        self.num_workers = num_workers
        self.is_multithreaded = is_multithreaded
        self._is_shutting_down = False
        self.running = True
        self.workers = []

        # Aggregated counters owned by the parent process.
        self._char_counter = Counter()
        self._pinyin_counter = Counter()
        self.pinyin_char_counter = Counter()

        if is_multithreaded:
            self.task_queue = Queue(maxsize=max_queue_size)
            self.result_queue = Queue()  # one result dict per worker
            self.lock = Lock()
            self._start_workers()

    def _start_workers(self):
        """Start ``num_workers`` daemon processes running ``_worker_loop``."""
        for _ in range(self.num_workers):
            worker = Process(
                target=self._worker_loop,
                args=(self.task_queue, self.result_queue),
                daemon=True,
            )
            worker.start()
            self.workers.append(worker)

    @staticmethod
    def _extract_counts(text: str):
        """Return ``(chinese_chars, pairs)`` for *text*, or ``None`` to skip it.

        A text is skipped when it is empty or when fewer than
        ``MIN_CHINESE_RATIO`` of its characters are in U+4E00..U+9FFF.
        ``pairs`` holds the ``(char, pinyin)`` tuples for which the reading
        differs from the character itself.
        """
        if not text:
            # Guard the ratio below against division by zero.
            return None
        chinese_chars = [c for c in text if "\u4e00" <= c <= "\u9fff"]
        if len(chinese_chars) / len(text) < PinyinCharStatistics.MIN_CHINESE_RATIO:
            return None
        # The error handler echoes unconvertible input back character by
        # character, so those positions compare equal and are filtered out.
        pinyin_list = lazy_pinyin(text, errors=lambda x: list(x))
        pairs = [
            (char, pinyin) for char, pinyin in zip(text, pinyin_list)
            if char != pinyin
        ]
        return chinese_chars, pairs

    @staticmethod
    def _worker_loop(task_queue: Queue, result_queue: Queue):
        """Worker main loop: count texts until a ``None`` sentinel arrives.

        Exactly one result dict is posted before the worker exits, even when
        nothing was counted, so the parent never waits for a missing result.
        """
        local_char_counter = Counter()
        local_pinyin_counter = Counter()
        local_pair_counter = Counter()

        while True:
            try:
                text = task_queue.get(timeout=1)
            except queue.Empty:
                continue
            except (EOFError, OSError):
                # The queue was closed underneath us; nothing more will come.
                break
            if text is None:  # termination sentinel
                break
            try:
                extracted = PinyinCharStatistics._extract_counts(text)
            except Exception as e:
                # One bad text must not kill the worker.
                logger.error(f"处理任务时出错: {e}")
                continue
            if extracted is None:
                continue
            chinese_chars, pairs = extracted
            local_char_counter.update(chinese_chars)
            local_pinyin_counter.update(p for _, p in pairs)
            local_pair_counter.update(pairs)

        result_queue.put({
            'char_counter': dict(local_char_counter),
            'pinyin_counter': dict(local_pinyin_counter),
            'pair_counter': dict(local_pair_counter),
        })

    @staticmethod
    def _parse_pair_key(key: str):
        """Parse a stringified ``"(char, pinyin)"`` key; ``None`` if malformed."""
        content = key[1:-1]
        separator = ", " if ", " in content else ","
        char, found, pinyin = content.partition(separator)
        if not found:
            return None
        return char.strip("'\""), pinyin.strip("'\"")

    def _merge_result(self, result):
        """Fold one worker result dict into the parent's counters."""
        self._char_counter.update(result['char_counter'])
        self._pinyin_counter.update(result['pinyin_counter'])
        for pairs, count in result['pair_counter'].items():
            if isinstance(pairs, tuple):
                self.pinyin_char_counter[pairs] += count
            elif isinstance(pairs, str):
                # Tolerate keys that were stringified somewhere on the way.
                parsed = self._parse_pair_key(pairs)
                if parsed is None:
                    logger.warning(f"无法解析pair: {pairs}")
                else:
                    self.pinyin_char_counter[parsed] += count
            else:
                logger.warning(f"无法处理pair: {pairs}")

    @staticmethod
    def _log_worker_state(pid):
        """Log status/CPU of a stuck worker when ``psutil`` is available."""
        try:
            # psutil is optional; importing it inside the ``try`` keeps a
            # missing package from aborting result collection.
            import psutil

            p = psutil.Process(pid)
            logger.warning(f"进程状态: {p.status()}, CPU: {p.cpu_percent()}")
        except Exception:
            # Diagnostics are best-effort only.
            pass

    def _collect_results(self):
        """Stop the workers and merge their counters into this object.

        Safe to call repeatedly: once the workers are gone it does nothing.
        """
        if not self.workers:
            return

        expected = len(self.workers)
        for _ in range(expected):
            self.task_queue.put(None)  # one sentinel per worker

        results_collected = 0
        max_retries = expected * self.RESULT_ROUNDS
        for retry in range(max_retries):
            if results_collected >= expected:
                break
            try:
                result = self.result_queue.get(timeout=self.RESULT_TIMEOUT_SECONDS)
            except queue.Empty:
                if not any(worker.is_alive() for worker in self.workers):
                    # Every worker has exited; no further result can arrive.
                    logger.error(f"等待结果超时,已尝试 {retry + 1} 次")
                    break
                if retry < max_retries - 1:
                    logger.warning(f"第 {retry + 1} 次等待结果超时,继续尝试...")
                    continue
                logger.error(f"等待结果超时,已尝试 {max_retries} 次")
                break
            if result:
                self._merge_result(result)
                results_collected += 1

        alive_count = 0
        for worker in self.workers:
            worker.join(timeout=self.JOIN_TIMEOUT_SECONDS)
            if worker.is_alive():
                logger.warning(f"工作进程 {worker.pid} 仍在运行")
                self._log_worker_state(worker.pid)
                worker.terminate()
                alive_count += 1

        if alive_count > 0:
            logger.warning(f"强制终止了 {alive_count} 个工作进程")
        else:
            logger.info("所有工作进程正常结束")

        self.workers.clear()
        # No worker is left to drain the task queue, so refuse new tasks.
        self.running = False

    def _process_single_text(self, text: str):
        """Count one text synchronously in the calling process."""
        extracted = self._extract_counts(text)
        if extracted is None:
            return
        chinese_chars, pairs = extracted
        self.pinyin_char_counter.update(pairs)
        self._char_counter.update(chinese_chars)
        self._pinyin_counter.update(p for _, p in pairs)

    def _is_chinese_char(self, char: str) -> bool:
        """Return True when *char* lies in U+4E00..U+9FFF."""
        return "\u4e00" <= char <= "\u9fff"

    def process_text(self, text: str):
        """Count *text*, either inline or by queueing it for the workers.

        Raises:
            RuntimeError: if the processor is shutting down or already closed.
        """
        if self._is_shutting_down:
            raise RuntimeError("处理器正在关闭,无法接受新任务")

        if not self.is_multithreaded:
            self._process_single_text(text)
            return

        if not self.running:
            raise RuntimeError("处理器已关闭,无法接受新任务")
        self.task_queue.put(text)
        try:
            nearly_full = self.task_queue.qsize() > self.max_queue_size * 0.8
        except NotImplementedError:
            # qsize() is unavailable on some platforms (e.g. macOS); the
            # bounded queue's blocking put() still provides back-pressure.
            nearly_full = False
        if nearly_full:
            time.sleep(0.02)

    def shutdown(self, wait: bool = True):
        """Stop accepting work; with ``wait`` also collect worker results."""
        if self._is_shutting_down:
            return

        self._is_shutting_down = True
        self.running = False

        if wait and self.is_multithreaded:
            self._collect_results()

    def _build_counter_data(self, format: str, compress: bool) -> PinyinCharPairsCounter:
        """Assemble the exportable snapshot from the aggregated counters."""
        counter_data = PinyinCharPairsCounter()
        counter_data.total_characters = len(self._char_counter)
        counter_data.total_pinyins = len(self._pinyin_counter)
        counter_data.valid_input_character_count = sum(self.pinyin_char_counter.values())

        # IDs are ranks: the most frequent pair gets id 1.
        ranked = sorted(
            self.pinyin_char_counter.items(), key=lambda x: x[1], reverse=True
        )
        counter_data.pairs = {
            idx: CharInfo(id=idx, char=char, pinyin=pinyin, count=count)
            for idx, ((char, pinyin), count) in enumerate(ranked, start=1)
        }
        counter_data.chars = dict(self._char_counter)
        counter_data.pinyins = dict(self._pinyin_counter)
        counter_data.metadata = {
            "format": format,
            "compressed": compress,
            "pair_count": len(counter_data.pairs),
        }
        return counter_data

    def export(
        self, output_path: str = None, format: str = "msgpack", compress: bool = False
    ):
        """Export the statistics and return the snapshot metadata.

        Args:
            output_path: destination file.  When falsy, only the summary
                fields are printed and nothing is written.
            format: ``"json"``, ``"msgpack"`` or ``"pickle"``.
            compress: gzip the serialized bytes before writing.

        Returns:
            The metadata dict (``format``, ``compressed``, ``pair_count``).

        Raises:
            ValueError: if *format* is not supported.
        """
        logger.info("开始导出快照...")

        # In multi-process mode the counts live in the workers until collected.
        if self.is_multithreaded:
            self._collect_results()

        counter_data = self._build_counter_data(format, compress)

        if not output_path:
            summary = asdict(counter_data)
            for bulky in ("pairs", "chars", "pinyins"):
                summary.pop(bulky)
            print(summary)
            return counter_data.metadata

        if format == "json":
            data = json.dumps(
                asdict(counter_data), ensure_ascii=False, indent=2
            ).encode("utf-8")
        elif format == "msgpack":
            data = msgpack.packb(asdict(counter_data), use_bin_type=True)
        elif format == "pickle":
            data = pickle.dumps(counter_data)
        else:
            raise ValueError(f"不支持的格式: {format}")

        if compress:
            import gzip
            data = gzip.compress(data)

        with open(output_path, "wb") as f:
            f.write(data)

        logger.info(f"快照导出完成: {output_path} (格式: {format}, 压缩: {compress})")
        return counter_data.metadata


def counter_dataset(
    dataset_path: str,
    num_workers: int = cpu_count(),
    output_path: str = None,
    format: str = "msgpack",
    compress: bool = True,
    max_queue_size: int = 50000,
    max_lines: int = 10000,
):
    """Stream a dataset's ``train`` split through the counter and export it.

    Args:
        dataset_path: identifier/path passed to ``load_dataset``.
        num_workers: number of worker processes (values below 1 become 1).
        output_path: export destination; see ``PinyinCharStatistics.export``.
        format: export format.
        compress: gzip the exported file.
        max_queue_size: capacity of the task queue.
        max_lines: maximum number of rows to read from the dataset.
    """
    dataset = load_dataset(dataset_path, split="train", streaming=True)
    iter_data = iter(dataset)

    processor = PinyinCharStatistics(
        is_multithreaded=True,
        num_workers=max(1, num_workers),  # honour the caller's setting
        max_queue_size=max_queue_size,
    )

    try:
        for _ in trange(max_lines, desc="正在处理数据"):
            try:
                line = next(iter_data)
            except StopIteration:
                print("⚠️ 数据集已提前结束。")
                break
            text = line.get("text", "")
            if text:
                processor.process_text(text)

        # export() collects the workers' results before serializing.
        processor.export(output_path, format=format, compress=compress)
    finally:
        # Always stop the workers, including when reading or exporting fails.
        processor.shutdown()
# NOTE: this class belongs to src/suinput/query.py and relies on that module's
# file-level imports: json, pickle, msgpack, gzip, typing (Dict, List,
# Optional, Tuple, Any), time, os, and CharInfo / PinyinCharPairsCounter from
# .char_info.


class QueryEngine:
    """In-memory lookup over an exported pinyin/character statistics file.

    After ``load()`` the engine keeps dictionary indices keyed by pair ID,
    by character and by pinyin, so each lookup is a single dict access
    followed by a sort of the (usually short) list of matches.
    """

    _NOT_LOADED_MESSAGE = "数据未加载,请先调用load()方法"
    _GZIP_MAGIC = b"\x1f\x8b"

    def __init__(self):
        """Create an empty engine; call ``load()`` before querying."""
        self._counter_data: Optional[PinyinCharPairsCounter] = None

        # Primary indices.
        self._id_to_info: Dict[int, CharInfo] = {}       # id -> CharInfo
        self._char_to_ids: Dict[str, List[int]] = {}     # char -> pair ids
        self._pinyin_to_ids: Dict[str, List[int]] = {}   # pinyin -> pair ids

        # Frequency tables copied from the loaded snapshot.
        self._char_freq: Dict[str, int] = {}
        self._pinyin_freq: Dict[str, int] = {}
        self._char_pinyin_map: Dict[Tuple[str, str], int] = {}  # (char, pinyin) -> count

        self._loaded = False
        self._total_pairs = 0
        self._load_time = 0.0
        self._index_time = 0.0

    # ------------------------------------------------------------------ loading

    def load(self, file_path: str) -> Dict[str, Any]:
        """Load a statistics file and build the query indices.

        Args:
            file_path: msgpack, JSON or pickle file, optionally gzip-compressed.

        Returns:
            The metadata dict stored in the file.

        Raises:
            FileNotFoundError: if *file_path* does not exist.
            ValueError: if the content matches none of the supported formats.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"文件不存在: {file_path}")

        started = time.perf_counter()
        self._counter_data = self._parse_file(file_path)
        self._build_indices()
        self._load_time = time.perf_counter() - started
        self._loaded = True

        return self._counter_data.metadata

    def _parse_file(self, file_path: str) -> "PinyinCharPairsCounter":
        """Read *file_path*, undo optional gzip, and try each known format."""
        with open(file_path, 'rb') as f:
            data = f.read()

        if data[:2] == self._GZIP_MAGIC:
            try:
                data = gzip.decompress(data)
            except Exception:
                # Not actually gzip (or corrupt): keep the raw bytes so the
                # parsers below produce the documented ValueError.
                pass

        # SECURITY: unpickling executes arbitrary code from the file.  The
        # data-only formats are tried first so pickle is only the last resort;
        # never load files from an untrusted source.
        for parser in (self._parse_msgpack, self._parse_json, self._parse_pickle):
            try:
                return parser(data)
            except Exception:
                continue

        raise ValueError("无法解析文件格式")

    def _parse_msgpack(self, data: bytes) -> "PinyinCharPairsCounter":
        """Decode msgpack bytes into a counter object."""
        # ``pairs`` is keyed by integer IDs; msgpack rejects non-string map
        # keys unless strict_map_key is disabled.
        data_dict = msgpack.unpackb(data, raw=False, strict_map_key=False)
        return self._dict_to_counter(data_dict)

    def _parse_pickle(self, data: bytes) -> "PinyinCharPairsCounter":
        """Decode pickle bytes; only trusted files should reach this point."""
        obj = pickle.loads(data)
        if not isinstance(obj, PinyinCharPairsCounter):
            raise ValueError("pickle数据不是PinyinCharPairsCounter对象")
        return obj

    def _parse_json(self, data: bytes) -> "PinyinCharPairsCounter":
        """Decode UTF-8 JSON bytes into a counter object."""
        return self._dict_to_counter(json.loads(data.decode('utf-8')))

    def _dict_to_counter(self, data_dict: Dict) -> "PinyinCharPairsCounter":
        """Rebuild a ``PinyinCharPairsCounter`` from its plain-dict form."""
        fields = dict(data_dict)  # leave the caller's dict untouched
        raw_pairs = fields.get('pairs') or {}
        # JSON turns the integer IDs into strings; normalise them back.
        fields['pairs'] = {
            int(pair_id): CharInfo(**info) for pair_id, info in raw_pairs.items()
        }
        counter = PinyinCharPairsCounter(**fields)

        # Keep the export time recorded in the file even if the dataclass
        # re-stamps itself on construction.
        stored_timestamp = data_dict.get('timestamp')
        if stored_timestamp:
            counter.timestamp = stored_timestamp
        return counter

    def _build_indices(self):
        """Rebuild every index from ``self._counter_data``."""
        started = time.perf_counter()

        self._id_to_info = {}
        self._char_to_ids = {}
        self._pinyin_to_ids = {}
        self._char_pinyin_map = {}
        self._char_freq = dict(self._counter_data.chars or {})
        self._pinyin_freq = dict(self._counter_data.pinyins or {})

        for char_info in self._counter_data.pairs.values():
            pair_id = char_info.id
            self._id_to_info[pair_id] = char_info
            self._char_to_ids.setdefault(char_info.char, []).append(pair_id)
            self._pinyin_to_ids.setdefault(char_info.pinyin, []).append(pair_id)
            self._char_pinyin_map[(char_info.char, char_info.pinyin)] = char_info.count

        self._total_pairs = len(self._id_to_info)
        self._index_time = time.perf_counter() - started

    # ------------------------------------------------------------------ helpers

    def _require_loaded(self):
        """Raise ``RuntimeError`` unless ``load()`` has completed."""
        if not self._loaded:
            raise RuntimeError(self._NOT_LOADED_MESSAGE)

    def _ranked(self, ids: List[int], attr: str, limit: int) -> List[Tuple[int, str, int]]:
        """Return ``(id, <attr>, count)`` rows for *ids*, most frequent first."""
        rows = [
            (pair_id, getattr(self._id_to_info[pair_id], attr), self._id_to_info[pair_id].count)
            for pair_id in ids
        ]
        rows.sort(key=lambda row: row[2], reverse=True)
        return rows[:limit] if limit > 0 else rows

    # ------------------------------------------------------------------ queries

    def query_by_id(self, id: int) -> Optional["CharInfo"]:
        """Return the ``CharInfo`` with the given ID, or ``None`` if absent.

        Raises:
            RuntimeError: if no data has been loaded.
        """
        self._require_loaded()
        return self._id_to_info.get(id)

    def query_by_char(self, char: str, limit: int = 0) -> List[Tuple[int, str, int]]:
        """Return the readings recorded for *char*.

        Args:
            char: the character to look up.
            limit: maximum number of rows; 0 (or less) returns all of them.

        Returns:
            ``(id, pinyin, count)`` tuples sorted by count, descending.

        Raises:
            RuntimeError: if no data has been loaded.
        """
        self._require_loaded()
        return self._ranked(self._char_to_ids.get(char, []), "pinyin", limit)

    def query_by_pinyin(self, pinyin: str, limit: int = 0) -> List[Tuple[int, str, int]]:
        """Return the characters recorded for *pinyin*.

        Args:
            pinyin: the reading to look up.
            limit: maximum number of rows; 0 (or less) returns all of them.

        Returns:
            ``(id, char, count)`` tuples sorted by count, descending.

        Raises:
            RuntimeError: if no data has been loaded.
        """
        self._require_loaded()
        return self._ranked(self._pinyin_to_ids.get(pinyin, []), "char", limit)

    def get_char_frequency(self, char: str) -> int:
        """Return the stored total count of *char* (0 if unknown).

        Raises:
            RuntimeError: if no data has been loaded.
        """
        self._require_loaded()
        return self._char_freq.get(char, 0)

    def get_pinyin_frequency(self, pinyin: str) -> int:
        """Return the stored total count of *pinyin* (0 if unknown).

        Raises:
            RuntimeError: if no data has been loaded.
        """
        self._require_loaded()
        return self._pinyin_freq.get(pinyin, 0)

    def get_char_pinyin_count(self, char: str, pinyin: str) -> int:
        """Return the count of the exact ``(char, pinyin)`` pair (0 if unknown).

        Raises:
            RuntimeError: if no data has been loaded.
        """
        self._require_loaded()
        return self._char_pinyin_map.get((char, pinyin), 0)

    def batch_query_by_ids(self, ids: List[int]) -> Dict[int, Optional["CharInfo"]]:
        """Look up several IDs at once; missing IDs map to ``None``.

        Raises:
            RuntimeError: if no data has been loaded.
        """
        self._require_loaded()
        return {pair_id: self._id_to_info.get(pair_id) for pair_id in ids}

    def batch_query_by_chars(self, chars: List[str], limit_per_char: int = 0) -> Dict[str, List[Tuple[int, str, int]]]:
        """Run ``query_by_char`` for every character in *chars*.

        Raises:
            RuntimeError: if no data has been loaded.
        """
        self._require_loaded()
        return {char: self.query_by_char(char, limit_per_char) for char in chars}

    def search_chars_by_prefix(self, prefix: str, limit: int = 20) -> List[Tuple[str, int]]:
        """Return frequency-table keys that start with *prefix*.

        This is a linear scan over every key of the character table.

        Args:
            prefix: prefix to match.
            limit: maximum number of rows; 0 (or less) returns all of them.

        Returns:
            ``(char, total_count)`` tuples sorted by count, descending.

        Raises:
            RuntimeError: if no data has been loaded.
        """
        self._require_loaded()
        matches = [
            (char, freq) for char, freq in self._char_freq.items()
            if char.startswith(prefix)
        ]
        matches.sort(key=lambda item: item[1], reverse=True)
        return matches[:limit] if limit > 0 else matches

    # --------------------------------------------------------------- management

    def get_statistics(self) -> Dict[str, Any]:
        """Return a summary of the loaded data.

        Returns ``{"status": "not_loaded"}`` when nothing is loaded instead
        of raising.
        """
        if not self._loaded:
            return {"status": "not_loaded"}

        def top_ten(table: Dict[str, int]) -> List[Tuple[str, int]]:
            return sorted(table.items(), key=lambda item: item[1], reverse=True)[:10]

        return {
            "status": "loaded",
            "timestamp": self._counter_data.timestamp,
            "total_pairs": self._total_pairs,
            "total_characters": len(self._char_freq),
            "total_pinyins": len(self._pinyin_freq),
            "valid_input_character_count": self._counter_data.valid_input_character_count,
            "load_time_seconds": self._load_time,
            "index_time_seconds": self._index_time,
            "top_chars": top_ten(self._char_freq),
            "top_pinyins": top_ten(self._pinyin_freq),
            "metadata": self._counter_data.metadata,
        }

    def is_loaded(self) -> bool:
        """Return True once ``load()`` has completed successfully."""
        return self._loaded

    def clear(self):
        """Drop the loaded data and all indices."""
        self._counter_data = None
        for index in (
            self._id_to_info,
            self._char_to_ids,
            self._pinyin_to_ids,
            self._char_freq,
            self._pinyin_freq,
            self._char_pinyin_map,
        ):
            index.clear()
        self._loaded = False
        self._total_pairs = 0
        self._load_time = 0.0
        self._index_time = 0.0