def load_history_snapshot(self, file_path: str):
    """Load a saved statistics snapshot and initialise this instance from it.

    The loading logic has not been written yet.  The previous body was the
    bare expression ``self``, i.e. a silent no-op, so a caller had no way to
    notice that nothing was restored.  Failing loudly makes the missing
    feature visible instead of letting counting continue from an empty state.

    Args:
        file_path: Path of the snapshot file to restore from.

    Raises:
        NotImplementedError: Always, until snapshot loading is implemented.
    """
    # TODO: implement snapshot loading (the patch imports QueryEngine in this
    # module, presumably for this purpose - confirm with the author).
    raise NotImplementedError(
        f"load_history_snapshot is not implemented yet (file_path={file_path!r})"
    )
class PinyinInputDataset(IterableDataset):
    """Streaming dataset that simulates typing with a pinyin input method.

    Features:
        1. Streams the corpus instead of loading it into memory.
        2. Converts text to pinyin on the fly (via ``pypinyin.lazy_pinyin``).
        3. Samples the preceding context in three different ways.
        4. Truncates pinyin to imitate partially typed input.
        5. Rebalances the character distribution: very frequent characters
           are randomly dropped, rare characters are repeated.
        6. Shuffles through a buffer and reseeds the RNGs per worker process.
    """

    # Sizes used by ``sample_context``.
    _CONTEXT_TAIL = 54  # strategy 1: characters right before the target
    _CONTEXT_WINDOW = 46  # strategy 2: random contiguous window
    _CONTEXT_LAST = 12  # strategy 3: always-kept tail ...
    _CONTEXT_SEGMENTS = 6  # ... preceded by this many random segments ...
    _CONTEXT_SEGMENT_LEN = 7  # ... of this length (6 * 7 + 12 == 54)

    # Upper bound on the concatenated pinyin hint length.
    _MAX_PINYIN_LEN = 18

    def __init__(
        self,
        data_dir: str,
        query_engine,
        tokenizer_name: str = "iic/nlp_structbert_backbone_tiny_std",
        max_len: int = 88,
        text_field: str = "text",
        batch_query_size: int = 1000,
        # shuffling
        shuffle: bool = True,
        shuffle_buffer_size: int = 100,
        # frequency rebalancing
        max_freq: int = 434748359,  # frequency of "的"
        min_freq: int = 109,  # frequency of "蓚"
        drop_start_freq: int = 30000000,  # start dropping at this frequency
        repeat_end_freq: int = 10000,  # start repeating at/below this frequency
        max_drop_prob: float = 0.8,  # highest drop probability
        max_repeat_expect: float = 50.0,  # highest expected repeat count
    ):
        """Set up the tokenizer, the statistics lookup and the streamed corpus.

        Args:
            data_dir: Dataset location passed to ``datasets.load_dataset``.
            query_engine: Loaded ``QueryEngine`` used for (char, pinyin) lookups.
            tokenizer_name: Name handed to ``AutoTokenizer.from_pretrained``.
            max_len: Maximum tokenized sequence length (inputs are padded to it).
            text_field: Key of the text inside each dataset record.
            batch_query_size: Number of characters per statistics lookup batch.
            shuffle: Whether samples go through the shuffle buffer.
            shuffle_buffer_size: Buffer size that triggers a shuffled flush.
            max_freq: Frequency treated as the most frequent character.
            min_freq: Frequency treated as the rarest character.
            drop_start_freq: Frequency from which samples start being dropped.
            repeat_end_freq: Frequency at or below which samples are repeated.
            max_drop_prob: Drop probability applied at ``max_freq`` and above.
            max_repeat_expect: Expected repeat count at ``min_freq`` and below.
        """
        self.query_engine = query_engine
        self.max_len = max_len
        self.text_field = text_field
        self.batch_query_size = batch_query_size

        # Shuffle state.
        self.shuffle = shuffle
        self.shuffle_buffer_size = shuffle_buffer_size
        self.shuffle_buffer = []

        # Rebalancing parameters.
        self.max_freq = max_freq
        self.min_freq = min_freq
        self.drop_start_freq = drop_start_freq
        self.repeat_end_freq = repeat_end_freq
        self.max_drop_prob = max_drop_prob
        self.max_repeat_expect = max_repeat_expect

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # Total character count reported by the statistics (0 when absent).
        stats = query_engine.get_statistics()
        self.total_chars = stats.get("valid_input_character_count", 0)

        # Matches one CJK Unified Ideograph (basic block only).
        self.chinese_pattern = re.compile(r"[\u4e00-\u9fff]")

        # (char, pinyin) -> info dict or None; filled by batch_get_char_info.
        self.char_info_cache = {}

        self.dataset = load_dataset(data_dir, split="train", streaming=True)

    @staticmethod
    def _text_to_pinyin(text: str) -> List[str]:
        """Convert ``text`` to a pinyin list, passing unknown characters through.

        The ``errors`` callback returns the characters of each unconvertible
        run one by one.  Callers index the result by character position, i.e.
        they assume one entry per character of ``text`` (pypinyin behaviour -
        TODO confirm for unusual input).
        """
        return lazy_pinyin(text, errors=lambda run: list(run))

    def is_chinese_char(self, char: str) -> bool:
        """Return True when ``char`` starts with a character in U+4E00..U+9FFF."""
        return bool(self.chinese_pattern.match(char))

    def get_next_chinese_chars(
        self,
        text: str,
        start_idx: int,
        max_count: int = 3,
        pinyin_list: "List[str] | None" = None,
    ) -> List[Tuple[str, str]]:
        """Collect the Chinese characters that follow ``start_idx``.

        Args:
            text: The full text.
            start_idx: Index of the current character; the scan starts right
                after it.
            max_count: Maximum number of characters to return.
            pinyin_list: Optional per-character pinyin of ``text``.  Pass it
                when the caller has already converted the text; otherwise the
                text is converted here - once, and only if a Chinese character
                is actually found (the conversion used to be repeated for
                every character found).

        Returns:
            Up to ``max_count`` ``(character, pinyin)`` tuples in text order.
            Non-Chinese characters are skipped.  If the conversion fails the
            characters collected so far (none) are returned.
        """
        found: List[Tuple[str, str]] = []

        for i in range(start_idx + 1, len(text)):
            if len(found) >= max_count:
                break

            char = text[i]
            if not self.is_chinese_char(char):
                continue

            if pinyin_list is None:
                try:
                    pinyin_list = self._text_to_pinyin(text)
                except Exception:
                    # Best effort, as before: no hint characters rather than
                    # aborting the whole stream on one odd document.
                    break

            # Guard against a pinyin list shorter than the text.
            if i < len(pinyin_list):
                found.append((char, pinyin_list[i]))

        return found

    def sample_context(self, context: str) -> str:
        """Pick part of the preceding text using one of three strategies.

        Each strategy is chosen with (roughly) one third probability:
            1. the last 54 characters;
            2. a random contiguous window of 46 characters;
            3. six random 7-character segments taken from everything except
               the last 12 characters, followed by those last 12 characters
               (54 characters in total).
        Contexts too short for the chosen strategy are returned whole (for
        strategy 3, only the last 12 characters when fewer than 7 precede
        them).

        Args:
            context: Preceding text (the caller passes at most 100 characters).

        Returns:
            The sampled context; an empty string for empty input.
        """
        if not context:
            return ""

        context_len = len(context)
        choice = random.random()

        if choice < 0.333:
            # Slicing already returns the whole string when it is shorter.
            return context[-self._CONTEXT_TAIL :]

        if choice < 0.667:
            if context_len <= self._CONTEXT_WINDOW:
                return context
            start = random.randint(0, context_len - self._CONTEXT_WINDOW)
            return context[start : start + self._CONTEXT_WINDOW]

        if context_len < self._CONTEXT_LAST:
            return context

        tail = context[-self._CONTEXT_LAST :]
        head = context[: -self._CONTEXT_LAST]
        last_start = len(head) - self._CONTEXT_SEGMENT_LEN
        if last_start < 0:
            return tail

        # Segments are drawn independently, so they may overlap or repeat.
        segments = []
        for _ in range(self._CONTEXT_SEGMENTS):
            start = random.randint(0, last_start)
            segments.append(head[start : start + self._CONTEXT_SEGMENT_LEN])

        # 6 * 7 + 12 is always exactly 54, so no padding or trimming is needed.
        return "".join(segments) + tail

    def truncate_pinyin(self, pinyin: str) -> str:
        """Randomly shorten one pinyin syllable.

        Probabilities: 10% empty string, 50% unchanged, 40% a random proper
        prefix (1 .. len-1 characters; single-character syllables are kept).

        Args:
            pinyin: The full syllable.

        Returns:
            The possibly truncated syllable; empty input stays empty.
        """
        if not pinyin:
            return ""

        roll = random.random()
        if roll < 0.1:
            return ""
        if roll < 0.6:
            return pinyin

        if len(pinyin) <= 1:
            return pinyin
        return pinyin[: random.randint(1, len(pinyin) - 1)]

    def process_pinyin_sequence(self, pinyin_list: List[str]) -> str:
        """Truncate each syllable and join them into one typed string.

        Processing stops at the first syllable that truncates to nothing,
        because nothing can be typed after an empty syllable.

        Args:
            pinyin_list: Syllables of the target character and its followers.

        Returns:
            The joined string, capped at 18 characters; may be empty.
        """
        parts = []
        for pinyin in pinyin_list:
            piece = self.truncate_pinyin(pinyin)
            if not piece:
                break
            parts.append(piece)

        return "".join(parts)[: self._MAX_PINYIN_LEN]

    def adjust_frequency(self, freq: int) -> int:
        """Decide how many copies of a sample to emit for a given frequency.

        * ``freq >= drop_start_freq``: the sample is dropped with a
          probability that grows linearly from 0 at ``drop_start_freq`` to
          ``max_drop_prob`` at ``max_freq``.  The probability is capped at
          ``max_drop_prob`` so frequencies above ``max_freq`` are not dropped
          more often than configured.
        * ``freq <= repeat_end_freq``: the sample is repeated; the expected
          count grows linearly from 0 at ``repeat_end_freq`` to
          ``max_repeat_expect`` at ``min_freq`` (and below).  The actual count
          is Poisson distributed and never less than 1.
        * otherwise: exactly one copy.

        Args:
            freq: Frequency of the (character, pinyin) pair.

        Returns:
            Number of copies to emit; 0 means "drop the sample".
        """
        if freq >= self.drop_start_freq:
            span = self.max_freq - self.drop_start_freq
            if span == 0:
                drop_prob = 0.0  # avoid dividing by zero
            else:
                drop_prob = (
                    self.max_drop_prob * (freq - self.drop_start_freq) / span
                )
                drop_prob = min(drop_prob, self.max_drop_prob)
            return 0 if random.random() < drop_prob else 1

        if freq <= self.repeat_end_freq:
            if freq <= self.min_freq:
                repeat_expect = self.max_repeat_expect
            else:
                # min_freq < freq <= repeat_end_freq, so the span is positive.
                repeat_expect = (
                    self.max_repeat_expect
                    * (self.repeat_end_freq - freq)
                    / (self.repeat_end_freq - self.min_freq)
                )
            return max(1, int(np.random.poisson(repeat_expect)))

        return 1

    def batch_get_char_info(
        self, char_pinyin_pairs: List[Tuple[str, str]]
    ) -> Dict[Tuple[str, str], Any]:
        """Look up statistics for several (character, pinyin) pairs.

        Results (including misses) are cached in ``self.char_info_cache`` so
        each distinct pair is sent to the query engine only once.

        Args:
            char_pinyin_pairs: ``[(character, pinyin), ...]``.

        Returns:
            Mapping from pair to ``{"id", "freq", "char", "pinyin"}`` or to
            ``None`` when the engine does not know the pair.
        """
        results = {}
        missing = []
        for pair in char_pinyin_pairs:
            if pair in self.char_info_cache:
                results[pair] = self.char_info_cache[pair]
            else:
                missing.append(pair)

        if missing:
            fetched = self.query_engine.batch_get_char_pinyin_info(missing)
            for pair, char_info in fetched.items():
                if char_info:
                    info = {
                        "id": char_info.id,
                        "freq": char_info.count,
                        "char": char_info.char,
                        "pinyin": char_info.pinyin,
                    }
                else:
                    info = None
                results[pair] = info
                self.char_info_cache[pair] = info

        return results

    def _process_batch(self, char_pinyin_batch, char_positions, text):
        """Turn collected character positions into training samples.

        Args:
            char_pinyin_batch: ``(character, pinyin)`` pairs to look up.
            char_positions: One dict per character with the keys ``char``,
                ``pinyin``, ``next_pinyins`` and ``context``.
            text: Source text (currently unused, kept for the call signature).

        Returns:
            List of sample dicts.  A sample that is repeated appears as the
            same dict object several times.
        """
        char_info_map = self.batch_get_char_info(char_pinyin_batch)
        batch_samples = []

        for pos_info in char_positions:
            char = pos_info["char"]
            char_info = char_info_map.get((char, pos_info["pinyin"]))
            if not char_info:
                continue  # pair unknown to the statistics

            copies = self.adjust_frequency(char_info["freq"])
            if copies <= 0:
                continue  # dropped by rebalancing

            sampled_context = self.sample_context(pos_info["context"])

            processed_pinyin = self.process_pinyin_sequence(
                pos_info["next_pinyins"]
            )
            if not processed_pinyin:
                continue  # nothing typed, nothing to predict from

            hint = self.tokenizer(
                sampled_context,
                processed_pinyin,
                max_length=self.max_len,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )

            sample = {
                "hint": hint,
                "txt": sampled_context,
                "py": processed_pinyin,
                "char_id": torch.tensor([char_info["id"]]),
                "char": char,
                "freq": char_info["freq"],
            }
            batch_samples.extend([sample] * copies)

        return batch_samples

    def _shuffle_and_yield(self, batch_samples):
        """Yield samples, optionally through the shuffle buffer.

        Without shuffling the samples are yielded as they come.  With
        shuffling they are appended to ``self.shuffle_buffer``; once the
        buffer holds at least ``shuffle_buffer_size`` samples it is shuffled
        and yielded in full.
        """
        if not self.shuffle:
            yield from batch_samples
            return

        self.shuffle_buffer.extend(batch_samples)
        if len(self.shuffle_buffer) >= self.shuffle_buffer_size:
            # Detach the buffer first so it is already reset even if the
            # consumer stops iterating half-way through the flush.
            ready, self.shuffle_buffer = self.shuffle_buffer, []
            random.shuffle(ready)
            yield from ready

    def __iter__(self):
        """Iterate over the corpus and yield one sample at a time.

        Inside a DataLoader worker the ``random`` and NumPy generators are
        reseeded from the torch seed plus the worker id, so workers draw
        different random sequences.
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            seed = (torch.initial_seed() + worker_info.id) % (2**32)
            random.seed(seed)
            np.random.seed(seed)

        self.shuffle_buffer = []

        for item in self.dataset:
            text = item.get(self.text_field, "")
            if not text:
                continue

            # Converted once per document and shared with the look-ahead.
            pinyin_list = self._text_to_pinyin(text)

            char_pinyin_batch = []
            char_positions = []
            for i, (char, py) in enumerate(zip(text, pinyin_list)):
                if not self.is_chinese_char(char):
                    continue

                # Pinyin of this character plus up to three followers.
                next_chars = self.get_next_chinese_chars(
                    text, i, max_count=3, pinyin_list=pinyin_list
                )
                next_pinyins = [py] + [p for _, p in next_chars]

                char_pinyin_batch.append((char, py))
                char_positions.append(
                    {
                        "index": i,
                        "char": char,
                        "pinyin": py,
                        "next_pinyins": next_pinyins,
                        # Up to 100 characters preceding the target.
                        "context": text[max(0, i - 100) : i],
                        "next_chars": next_chars,
                    }
                )

                if len(char_pinyin_batch) >= self.batch_query_size:
                    batch_samples = self._process_batch(
                        char_pinyin_batch, char_positions, text
                    )
                    yield from self._shuffle_and_yield(batch_samples)
                    char_pinyin_batch = []
                    char_positions = []

            # Remainder of this document.
            if char_pinyin_batch:
                batch_samples = self._process_batch(
                    char_pinyin_batch, char_positions, text
                )
                yield from self._shuffle_and_yield(batch_samples)

        # Flush what is left in the buffer once the corpus is exhausted.
        if self.shuffle_buffer:
            leftover, self.shuffle_buffer = self.shuffle_buffer, []
            random.shuffle(leftover)
            yield from leftover

    def __len__(self):
        """Signal that the length of the stream is unknown.

        Raises:
            TypeError: Always.  ``__len__`` must not return a negative number
                (``len()`` rejects it with a confusing ``ValueError``), so the
                conventional "unsized object" error is raised instead.
        """
        raise TypeError(
            f"{type(self).__name__} is a streaming dataset and has no length"
        )


def worker_init_fn(worker_id):
    """Seed ``random``, NumPy and torch differently in every DataLoader worker."""
    seed = (torch.initial_seed() + worker_id) % (2**32)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def custom_collate(batch):
    """Merge samples produced by ``PinyinInputDataset`` into one batch.

    Tensor fields are concatenated along dimension 0, text fields are
    collected into lists.

    Args:
        batch: List of sample dicts.

    Returns:
        Dict with ``hint`` (``input_ids``, ``attention_mask`` and, when
        present, ``token_type_ids``), ``char_id``, ``char``, ``txt``, ``py``
        and, when present, ``freq``.  An empty dict for an empty batch.
    """
    if not batch:
        return {}

    hints = [item["hint"] for item in batch]

    result = {
        "hint": {
            "input_ids": torch.cat([h["input_ids"] for h in hints]),
            "attention_mask": torch.cat([h["attention_mask"] for h in hints]),
        },
        "char_id": torch.cat([item["char_id"] for item in batch]),
        "char": [item["char"] for item in batch],
        "txt": [item["txt"] for item in batch],
        "py": [item["py"] for item in batch],
    }

    # Optional fields are detected on the first sample only.
    if "token_type_ids" in hints[0]:
        result["hint"]["token_type_ids"] = torch.cat(
            [h["token_type_ids"] for h in hints]
        )

    if "freq" in batch[0]:
        result["freq"] = torch.tensor([item["freq"] for item in batch])

    return result


# Usage example / smoke test.
if __name__ == "__main__":
    from query import QueryEngine
    from tqdm import tqdm

    query_engine = QueryEngine()
    query_engine.load("./pinyin_char_statistics.json")

    dataset = PinyinInputDataset(
        data_dir="/home/songsenand/Data/corpus/CCI-Data/",
        query_engine=query_engine,
        tokenizer_name="iic/nlp_structbert_backbone_tiny_std",
        max_len=88,
        batch_query_size=500,
        shuffle=True,
        shuffle_buffer_size=1000,
    )

    logger.info("数据集初始化")
    dataloader = DataLoader(
        dataset,
        batch_size=256,
        num_workers=12,
        worker_init_fn=worker_init_fn,
        collate_fn=custom_collate,
        prefetch_factor=32,
        persistent_workers=True,
        shuffle=False,  # the dataset shuffles internally
        timeout=60,
    )

    # Pull up to 200 batches through the loader.
    logger.info("测试数据集")
    batches_seen = 0
    for i, _ in tqdm(enumerate(dataloader), total=200):
        if i >= 200:
            break
        batches_seen += 1

    # A for-loop never raises StopIteration, so emptiness is checked here.
    if batches_seen == 0:
        print("数据集为空")
def get_char_info_by_char_pinyin(self, char: str, pinyin: str) -> Optional[CharInfo]:
    """Return the record of one (character, pinyin) pair.

    The pair is resolved through the ``_char_pinyin_to_ids`` index (a single
    dictionary lookup) and the record is then fetched with ``query_by_id``.

    Args:
        char: The Chinese character.
        pinyin: Its pinyin reading.

    Returns:
        The ``CharInfo`` of the pair, or ``None`` when the pair is unknown.

    Raises:
        RuntimeError: If no data has been loaded yet.
    """
    if not self._loaded:
        raise RuntimeError("数据未加载,请先调用load()方法")

    char_info_id = self._char_pinyin_to_ids.get((char, pinyin))
    if char_info_id is None:
        # Unknown pair: answer here instead of forwarding None as an ID.
        return None
    return self.query_by_id(char_info_id)


def batch_get_char_pinyin_info(
    self, pairs: List[Tuple[str, str]]
) -> Dict[Tuple[str, str], Optional[CharInfo]]:
    """Return the records of several (character, pinyin) pairs.

    Args:
        pairs: ``[(character, pinyin), ...]``.

    Returns:
        Dict keyed by pair; the value is the pair's ``CharInfo`` or ``None``
        when the pair is unknown.  Duplicate pairs collapse into one key.

    Raises:
        RuntimeError: If no data has been loaded yet.
    """
    if not self._loaded:
        raise RuntimeError("数据未加载,请先调用load()方法")

    result = {}
    for pair in pairs:
        char_info_id = self._char_pinyin_to_ids.get(pair)
        if char_info_id is None:
            result[pair] = None
        else:
            result[pair] = self._id_to_info.get(char_info_id)
    return result
@pytest.fixture
def json_file_path():
    """Path of the statistics snapshot used by the tests (relative to the CWD)."""
    return "pinyin_char_statistics.json"


def _load_engine(path):
    """Build a QueryEngine and load the snapshot at ``path`` into it."""
    engine = QueryEngine()
    engine.load(path)
    return engine


class TestQueryEngine:
    """Checks of QueryEngine against the snapshot's known values."""

    def test_load_from_json(self, json_file_path):
        """Loading a JSON snapshot reports its format and pair count."""
        engine = QueryEngine()
        metadata = engine.load(json_file_path)

        assert engine.is_loaded() is True
        assert metadata["format"] == "json"
        assert metadata["pair_count"] == 20646

    def test_query_by_id(self, json_file_path):
        """An ID resolves to its record; an unknown ID gives None."""
        engine = _load_engine(json_file_path)

        hit = engine.query_by_id(8)
        assert hit is not None
        assert hit.char == "中"
        assert hit.pinyin == "zhong"
        assert hit.count == 73927282

        # ID that does not exist
        assert engine.query_by_id(100000) is None

    def test_query_by_char(self, json_file_path):
        """A character lists its readings, optionally limited."""
        engine = _load_engine(json_file_path)
        zhang = (159, "zhang", 15424264)
        chang = (414, "chang", 6663465)

        readings = engine.query_by_char("长")
        assert len(readings) == 2
        assert readings[0] == zhang
        assert readings[1] == chang

        limited = engine.query_by_char("长", limit=1)
        assert len(limited) == 1
        assert limited[0] == zhang

        # character that does not exist
        assert engine.query_by_char("X") == []

    def test_query_by_pinyin(self, json_file_path):
        """A pinyin lists its characters."""
        engine = _load_engine(json_file_path)

        matches = engine.query_by_pinyin("zhong")
        assert len(matches) == 57
        assert matches[0] == (8, "中", 73927282)

        # pinyin that does not exist
        assert engine.query_by_pinyin("xxx") == []

    def test_get_char_frequency(self, json_file_path):
        """Total frequency of a character; 0 when unknown."""
        engine = _load_engine(json_file_path)

        assert engine.get_char_frequency("中") == 73927282
        assert engine.get_char_frequency("X") == 0

    def test_get_pinyin_frequency(self, json_file_path):
        """Total frequency of a pinyin; 0 when unknown."""
        engine = _load_engine(json_file_path)

        assert engine.get_pinyin_frequency("zhong") == 136246123
        assert engine.get_pinyin_frequency("xxx") == 0

    def test_get_char_pinyin_count(self, json_file_path):
        """Occurrence count of a (character, pinyin) pair; 0 when unknown."""
        engine = _load_engine(json_file_path)

        assert engine.get_char_pinyin_count("中", "zhong") == 73927282
        assert engine.get_char_pinyin_count("中", "xxx") == 0

    def test_batch_query_by_ids(self, json_file_path):
        """Batch ID lookup returns one entry per requested ID."""
        engine = _load_engine(json_file_path)

        by_id = engine.batch_query_by_ids([8, 9, 10000000])
        assert len(by_id) == 3
        assert by_id[9].char == "为"

    def test_search_chars_by_prefix(self, json_file_path):
        """Prefix search over characters."""
        engine = _load_engine(json_file_path)

        found = engine.search_chars_by_prefix("中")
        assert len(found) == 1
        assert found[0] == ("中", 73927282)

        # prefix that does not exist
        assert engine.search_chars_by_prefix("X") == []

    def test_get_statistics(self, json_file_path):
        """Statistics of a loaded engine."""
        engine = _load_engine(json_file_path)

        stats = engine.get_statistics()
        assert stats["status"] == "loaded"
        assert stats["total_pairs"] == 20646
        assert stats["total_characters"] == 18240
        assert stats["top_chars"][0] == ("的", 439524694)

    def test_clear(self, json_file_path):
        """clear() returns the engine to the not-loaded state."""
        engine = _load_engine(json_file_path)
        assert engine.is_loaded() is True

        engine.clear()
        assert engine.is_loaded() is False
        assert engine.get_statistics()["status"] == "not_loaded"

    def test_batch_get_char_pinyin_info(self, json_file_path):
        """The batch pair lookup agrees with the single pair lookup."""
        engine = _load_engine(json_file_path)
        assert engine.is_loaded() is True

        wanted = [("我", "wo"), ("你", "ni"), ("他", "ta")]
        pairs = engine.batch_get_char_pinyin_info(wanted)
        for char, pinyin in wanted:
            assert pairs[(char, pinyin)] == engine.get_char_info_by_char_pinyin(
                char, pinyin
            )