feat: 添加拼音输入法模拟数据集及相关功能实现

This commit is contained in:
songsenand 2026-02-09 00:43:38 +08:00
parent 5ea0b0b31c
commit f2c260de72
6 changed files with 833 additions and 7 deletions

3
.gitignore vendored
View File

@ -213,4 +213,5 @@ cython_debug/
uv.lock
*.log
marimo/
marimo/
__marimo__/

View File

@ -11,6 +11,7 @@ dependencies = [
"msgpack>=1.1.2",
"pypinyin>=0.55.0",
"rich>=14.3.1",
"transformers>=5.1.0",
"typer>=0.21.1",
]

View File

@ -16,6 +16,7 @@ from loguru import logger
from tqdm import trange
from .char_info import PinyinCharPairsCounter, CharInfo
from .query import QueryEngine
@ -48,6 +49,12 @@ class PinyinCharStatistics:
# 启动工作进程
self._start_workers()
# TODO: implement loading a historical snapshot and initializing from it.
def load_history_snapshot(self, file_path: str) -> None:
    """Load a history snapshot and re-initialize state from it.

    Args:
        file_path: Path of the snapshot file to load.

    Raises:
        NotImplementedError: Always — this is an unfinished stub. The
            original body was the bare expression ``self``, a silent
            no-op that let callers believe a snapshot had been loaded.
    """
    # Failing loudly is safer than the previous silent no-op.
    raise NotImplementedError("load_history_snapshot is not implemented yet")
def _start_workers(self):
"""启动工作进程"""

624
src/suinput/dataset.py Normal file
View File

@ -0,0 +1,624 @@
import torch
from torch.utils.data import IterableDataset, DataLoader
from datasets import load_dataset
from pypinyin import lazy_pinyin
import random
from modelscope import AutoTokenizer
from typing import Tuple, List, Dict, Any
import re
import numpy as np
from loguru import logger
class PinyinInputDataset(IterableDataset):
    """Pinyin input-method (IME) simulation dataset.

    Features:
        1. Streams the corpus (memory friendly).
        2. On-the-fly pinyin conversion, heteronym-aware via the engine.
        3. Several strategies for sampling the preceding context.
        4. Pinyin truncation to simulate partial typing.
        5. Built-in peak-clipping / valley-filling to re-balance the
           character frequency distribution.
        6. Buffered shuffling; multi-process DataLoader friendly.
    """
    def __init__(
        self,
        data_dir: str,
        query_engine,
        tokenizer_name: str = "iic/nlp_structbert_backbone_tiny_std",
        max_len: int = 88,
        text_field: str = "text",
        batch_query_size: int = 1000,
        # shuffling parameters
        shuffle: bool = True,
        shuffle_buffer_size: int = 100,
        # peak-clipping / valley-filling parameters
        max_freq: int = 434748359,  # frequency of the most common char ("的")
        min_freq: int = 109,  # frequency of the rarest char ("蓚")
        drop_start_freq: int = 30000000,  # start dropping above this frequency
        repeat_end_freq: int = 10000,  # start repeating below this frequency
        max_drop_prob: float = 0.8,  # drop probability at max_freq
        max_repeat_expect: float = 50.0,  # repeat expectation at min_freq
    ):
        """Initialize the dataset.

        Args:
            data_dir: Dataset directory (passed to HF ``load_dataset``).
            query_engine: A loaded QueryEngine instance.
            tokenizer_name: modelscope tokenizer identifier.
            max_len: Maximum tokenized sequence length.
            text_field: Name of the text column.
            batch_query_size: Characters resolved per engine query.
            shuffle: Whether to shuffle via a buffer.
            shuffle_buffer_size: Shuffle buffer capacity.
            max_freq: Highest character frequency in the statistics.
            min_freq: Lowest character frequency in the statistics.
            drop_start_freq: Frequency where peak clipping begins.
            repeat_end_freq: Frequency where valley filling begins.
            max_drop_prob: Drop probability for the most frequent char.
            max_repeat_expect: Repeat expectation for the rarest char.
        """
        self.query_engine = query_engine
        self.max_len = max_len
        self.text_field = text_field
        self.batch_query_size = batch_query_size
        # Shuffling state.
        self.shuffle = shuffle
        self.shuffle_buffer_size = shuffle_buffer_size
        self.shuffle_buffer = []
        # Peak-clipping / valley-filling parameters.
        self.max_freq = max_freq
        self.min_freq = min_freq
        self.drop_start_freq = drop_start_freq
        self.repeat_end_freq = repeat_end_freq
        self.max_drop_prob = max_drop_prob
        self.max_repeat_expect = max_repeat_expect
        # Tokenizer for (context, pinyin) sentence pairs.
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        # Total character count, kept for downstream use.
        stats = query_engine.get_statistics()
        self.total_chars = stats.get("valid_input_character_count", 0)
        # CJK unified ideographs matcher.
        self.chinese_pattern = re.compile(r"[\u4e00-\u9fff]")
        # (char, pinyin) -> info cache, shared across texts.
        self.char_info_cache = {}
        # Streaming dataset — never fully materialized in memory.
        self.dataset = load_dataset(data_dir, split="train", streaming=True)
def is_chinese_char(self, char: str) -> bool:
"""判断是否为中文字符"""
return bool(self.chinese_pattern.match(char))
def get_next_chinese_chars(
    self, text: str, start_idx: int, max_count: int = 3
) -> List[Tuple[str, str]]:
    """Collect up to ``max_count`` Chinese characters after ``start_idx``.

    Args:
        text: Full text being scanned.
        start_idx: Index of the current character; scanning begins at
            ``start_idx + 1``.
        max_count: Maximum number of Chinese characters to collect.

    Returns:
        List of ``(character, pinyin)`` tuples in text order. Non-Chinese
        characters are skipped. Returns an empty list if the pinyin
        conversion fails.
    """
    # Fix: the original recomputed lazy_pinyin(text) over the WHOLE text
    # on every matching character (its own comment flagged this), making
    # each call O(len(text)^2). Convert once up front; the errors
    # callback maps non-Chinese chars 1:1 so indices stay aligned.
    try:
        pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
    except Exception:
        return []
    result: List[Tuple[str, str]] = []
    for i in range(start_idx + 1, len(text)):
        if len(result) >= max_count:
            break
        char = text[i]
        # Same guard as before: only count chars whose pinyin index exists.
        if self.is_chinese_char(char) and i < len(pinyin_list):
            result.append((char, pinyin_list[i]))
    return result
def sample_context(self, context: str) -> str:
    """Sample the preceding context with one of three strategies.

    Args:
        context: Raw preceding text (at most 100 characters, per caller).

    Returns:
        At most 54 characters (may be shorter when the input is short).
    """
    if not context:
        return ""
    # Available context length.
    context_len = len(context)
    # Pick one of three strategies (~1/3 probability each).
    choice = random.random()
    if choice < 0.333:
        # Strategy 1: the 54 characters closest to the target character.
        return context[-54:] if context_len >= 54 else context
    elif choice < 0.667:
        # Strategy 2: 46 consecutive characters from a random position.
        if context_len <= 46:
            return context
        start = random.randint(0, context_len - 46)
        return context[start : start + 46]
    else:
        # Strategy 3: last 12 chars + 6 random segments of 7 chars each.
        if context_len < 12:
            return context
        # The 12 characters immediately before the target.
        last_12 = context[-12:]
        # Sample 6 segments of 7 chars from the remaining prefix
        # (sampling is with replacement; segments may overlap).
        remaining = context[:-12] if context_len > 12 else ""
        remaining_len = len(remaining)
        if remaining_len < 7:
            # Not enough prefix for even one segment.
            return last_12
        segments = []
        for _ in range(6):
            if remaining_len < 7:
                # NOTE(review): remaining_len never changes inside this
                # loop, so this guard is effectively dead code.
                break
            start = random.randint(0, remaining_len - 7)
            segment = remaining[start : start + 7]
            segments.append(segment)
        # Concatenate the segments, then the trailing 12 characters.
        combined = "".join(segments)
        result = combined + last_12
        # Pad from the front of the context if under 54 characters
        # (may duplicate characters already sampled above).
        if len(result) < 54:
            needed = 54 - len(result)
            if context_len >= needed:
                result = context[:needed] + result
        # Clamp to 54 characters.
        return result[:54]
def truncate_pinyin(self, pinyin: str) -> str:
"""
截断拼音
Args:
pinyin: 原始拼音
Returns:
截断后的拼音可能为空字符串
"""
if not pinyin:
return ""
# 随机决定截断方式
rand_val = random.random()
if rand_val < 0.1:
# 10%概率截断为空
return ""
elif rand_val < 0.6:
# 50%概率不截断
return pinyin
else:
# 40%概率随机截断
# 均匀分配剩余概率给各种截断长度
max_len = len(pinyin)
if max_len <= 1:
return pinyin
# 随机选择截断长度 (1 到 max_len-1)
trunc_len = random.randint(1, max_len - 1)
return pinyin[:trunc_len]
def process_pinyin_sequence(self, pinyin_list: List[str]) -> str:
"""
处理拼音序列逐个截断并拼接
Args:
pinyin_list: 拼音列表长度1-4
Returns:
拼接后的拼音字符串
"""
result_parts = []
for pinyin in pinyin_list:
truncated = self.truncate_pinyin(pinyin)
if not truncated:
# 如果某个拼音截断为空,则停止
break
result_parts.append(truncated)
if not result_parts:
return ""
result = "".join(result_parts)
# 限制最大长度
if len(result) > 18:
result = result[:18]
return result
def adjust_frequency(self, freq: int) -> int:
"""
削峰填谷 - 根据频率调整采样
Args:
freq: 当前字符频率
Returns:
调整后的采样次数0表示丢弃
"""
# 1. 削峰处理(高频字,>= 3000W开始丢弃
if freq >= self.drop_start_freq:
# 线性丢弃概率3000W时丢弃概率为0434748359时丢弃概率为0.8
# 使用线性插值计算丢弃概率
if self.max_freq == self.drop_start_freq:
drop_prob = 0.0 # 防止除零
else:
drop_prob = (
self.max_drop_prob
* (freq - self.drop_start_freq)
/ (self.max_freq - self.drop_start_freq)
)
# 根据丢弃概率决定是否保留
if random.random() < drop_prob:
return 0 # 丢弃该样本
else:
return 1 # 保留,但不重复
# 2. 填谷处理(低频字,<= 1W开始重复
elif freq <= self.repeat_end_freq:
# 线性重复期望1W时重复期望为0109时重复期望为50
# 使用线性插值计算期望重复次数
if freq <= self.min_freq:
repeat_expect = self.max_repeat_expect # 最低频字重复期望为50
else:
if self.repeat_end_freq == self.min_freq:
repeat_expect = 0 # 防止除零
else:
# 线性插值公式
repeat_expect = (
self.max_repeat_expect
* (self.repeat_end_freq - freq)
/ (self.repeat_end_freq - self.min_freq)
)
# 期望重复次数转换为实际重复次数
# 使用泊松分布实现期望重复,确保有随机性
repeat_count = np.random.poisson(repeat_expect)
# 确保至少返回1次
return max(1, repeat_count)
# 3. 中等频率字1W < freq < 3000W
else:
return 1 # 保持原样
def batch_get_char_info(
self, char_pinyin_pairs: List[Tuple[str, str]]
) -> Dict[Tuple[str, str], Any]:
"""
批量获取字符信息
Args:
char_pinyin_pairs: [(字符, 拼音), ...]
Returns:
字典key为(字符, 拼音)value为(id, 频率)或None
"""
results = {}
# 先检查缓存
uncached_pairs = []
for pair in char_pinyin_pairs:
if pair in self.char_info_cache:
results[pair] = self.char_info_cache[pair]
else:
uncached_pairs.append(pair)
# 批量查询未缓存的
if uncached_pairs:
# 使用query_engine批量查询
char_infos = self.query_engine.batch_get_char_pinyin_info(uncached_pairs)
for pair, char_info in char_infos.items():
if char_info:
info = {
"id": char_info.id,
"freq": char_info.count,
"char": char_info.char,
"pinyin": char_info.pinyin,
}
else:
info = None
results[pair] = info
self.char_info_cache[pair] = info
return results
def _process_batch(self, char_pinyin_batch, char_positions, text):
    """Turn one batch of collected character positions into samples.

    Args:
        char_pinyin_batch: (char, pinyin) pairs to resolve in one query.
        char_positions: Per-character dicts with keys "char", "pinyin",
            "next_pinyins", "context" (built in __iter__).
        text: The source text (currently unused in this method).

    Returns:
        List of sample dicts; a sample may appear several times when
        valley-filling repeats it.
    """
    # Resolve ids/frequencies for the whole batch in one engine call.
    char_info_map = self.batch_get_char_info(char_pinyin_batch)
    batch_samples = []
    for pos_info in char_positions:
        char = pos_info["char"]
        pinyin = pos_info["pinyin"]
        next_pinyins = pos_info["next_pinyins"]
        context = pos_info["context"]
        # Skip characters unknown to the query engine.
        char_info = char_info_map.get((char, pinyin))
        if not char_info:
            continue
        # Peak-clipping / valley-filling: 0 drops, >1 repeats.
        adjust_factor = self.adjust_frequency(char_info["freq"])
        if adjust_factor <= 0:
            continue
        # Sample the preceding context.
        sampled_context = self.sample_context(context)
        # Simulate (possibly partial) pinyin typing.
        processed_pinyin = self.process_pinyin_sequence(next_pinyins)
        if not processed_pinyin:
            continue
        # Tokenize context + pinyin as a sentence pair.
        hint = self.tokenizer(
            sampled_context,
            processed_pinyin,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Assemble the training sample.
        sample = {
            "hint": hint,
            "txt": sampled_context,
            "py": processed_pinyin,
            "char_id": torch.tensor([char_info["id"]]),
            "char": char,
            "freq": char_info["freq"],
        }
        # Emit the sample adjust_factor times. NOTE(review): the SAME
        # dict object is appended repeatedly — consumers must not mutate
        # samples in place.
        for _ in range(adjust_factor):
            batch_samples.append(sample)
    return batch_samples
def _shuffle_and_yield(self, batch_samples):
"""打乱并yield样本"""
if not self.shuffle:
for sample in batch_samples:
yield sample
return
# 添加到打乱缓冲区
self.shuffle_buffer.extend(batch_samples)
# 如果缓冲区达到指定大小,打乱并输出
if len(self.shuffle_buffer) >= self.shuffle_buffer_size:
random.shuffle(self.shuffle_buffer)
for sample in self.shuffle_buffer:
yield sample
self.shuffle_buffer = []
def __iter__(self):
    """Stream (possibly shuffled) training samples.

    Yields:
        Sample dicts produced by _process_batch.
    """
    # Give each DataLoader worker its own deterministic RNG stream.
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is not None:
        worker_id = worker_info.id
        # base seed + worker id => different but reproducible per worker.
        base_seed = torch.initial_seed() if hasattr(torch, "initial_seed") else 42
        seed = base_seed + worker_id
        random.seed(seed % (2**32))
        np.random.seed(seed % (2**32))
    # Start each pass with an empty shuffle buffer.
    self.shuffle_buffer = []
    for item in self.dataset:
        text = item.get(self.text_field, "")
        if not text:
            continue
        # Pinyin for the whole text; non-Chinese chars map to themselves
        # so indices stay aligned with `text`.
        pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
        # Pairs to resolve in one batched engine query.
        char_pinyin_batch = []
        char_positions = []  # positional/context info per character
        # Walk every character of the text.
        for i, (char, py) in enumerate(zip(text, pinyin_list)):
            if not self.is_chinese_char(char):
                continue
            # Up to 3 following Chinese characters with their pinyin.
            next_chars = self.get_next_chinese_chars(text, i, max_count=3)
            next_pinyins = [py] + [p for _, p in next_chars]
            # Preceding context, at most 100 characters.
            context = text[max(0, i - 100) : i]
            # Queue for the batched lookup.
            char_pinyin_batch.append((char, py))
            char_positions.append(
                {
                    "index": i,
                    "char": char,
                    "pinyin": py,
                    "next_pinyins": next_pinyins,
                    "context": context,
                    "next_chars": next_chars,
                }
            )
            # Flush once the query batch is full.
            if len(char_pinyin_batch) >= self.batch_query_size:
                batch_samples = self._process_batch(
                    char_pinyin_batch, char_positions, text
                )
                yield from self._shuffle_and_yield(batch_samples)
                char_pinyin_batch = []
                char_positions = []
        # Flush the characters left over from this text.
        if char_pinyin_batch:
            batch_samples = self._process_batch(
                char_pinyin_batch, char_positions, text
            )
            yield from self._shuffle_and_yield(batch_samples)
    # Drain whatever is still sitting in the shuffle buffer.
    if self.shuffle_buffer:
        random.shuffle(self.shuffle_buffer)
        for sample in self.shuffle_buffer:
            yield sample
        self.shuffle_buffer = []
def __len__(self):
"""
由于是流式数据集无法预先知道长度
返回:
返回一个估计值或-1
"""
return -1
# 辅助函数用于DataLoader
def worker_init_fn(worker_id):
"""DataLoader worker初始化函数"""
# 设置每个worker的随机种子
seed = torch.initial_seed() + worker_id
random.seed(seed % (2**32))
np.random.seed(seed % (2**32))
torch.manual_seed(seed % (2**32))
def custom_collate(batch):
    """Collate PinyinInputDataset samples into one batch dict.

    Args:
        batch: List of sample dicts; each "hint" holds [1, max_len]
            tensors from the tokenizer.

    Returns:
        Dict with batched "hint" tensors, a "char_id" tensor, and
        list-valued "char"/"txt"/"py" fields; a "freq" tensor when the
        samples carry frequencies. An empty batch yields an empty dict.
    """
    if not batch:
        return {}
    encodings = [sample["hint"] for sample in batch]
    # Stack the per-sample [1, L] tensors along the batch dimension.
    hint = {
        "input_ids": torch.cat([enc["input_ids"] for enc in encodings]),
        "attention_mask": torch.cat([enc["attention_mask"] for enc in encodings]),
    }
    if "token_type_ids" in encodings[0]:
        # Only some tokenizers emit token_type_ids.
        hint["token_type_ids"] = torch.cat(
            [enc["token_type_ids"] for enc in encodings]
        )
    collated = {
        "hint": hint,
        "char_id": torch.cat([sample["char_id"] for sample in batch]),
        "char": [sample["char"] for sample in batch],
        "txt": [sample["txt"] for sample in batch],
        "py": [sample["py"] for sample in batch],
    }
    if "freq" in batch[0]:
        collated["freq"] = torch.tensor([sample["freq"] for sample in batch])
    return collated
# Usage example / smoke test.
if __name__ == "__main__":
    from query import QueryEngine
    from tqdm import tqdm

    # Load the character statistics into the query engine.
    query_engine = QueryEngine()
    query_engine.load("./pinyin_char_statistics.json")
    # Build the streaming dataset.
    dataset = PinyinInputDataset(
        data_dir="/home/songsenand/Data/corpus/CCI-Data/",
        query_engine=query_engine,
        tokenizer_name="iic/nlp_structbert_backbone_tiny_std",
        max_len=88,
        batch_query_size=500,
        shuffle=True,
        shuffle_buffer_size=1000,
    )
    logger.info("数据集初始化")
    dataloader = DataLoader(
        dataset,
        batch_size=256,
        num_workers=12,
        worker_init_fn=worker_init_fn,
        # pin_memory=True,
        collate_fn=custom_collate,
        prefetch_factor=32,
        persistent_workers=True,
        shuffle=False,  # the dataset shuffles internally
        timeout=60,
    )
    # Iterate a few batches to smoke-test the whole pipeline.
    try:
        # NOTE(review): `iterator` is never used; the dataloader below
        # iterates its own worker copies of the dataset.
        iterator = iter(dataset)
        logger.info("测试数据集")
        for i, _ in tqdm(enumerate(dataloader), total=200):
            if i >= 200:
                break
            """
            print(f"Sample {i+1}:")
            print(f"  Char: {sample['char']}, Id: {sample['char_id'].item()}, Freq: {sample.get('freq', 'N/A')}")
            print(f"  Pinyin: {sample['py']}")
            print(f"  Context length: {len(sample['txt'])}")
            print(f"  Hint shape: {sample['hint']['input_ids'].shape}")
            print()
            """
    except StopIteration:
        print("数据集为空")

View File

@ -7,7 +7,7 @@ from typing import Dict, List, Optional, Tuple, Any
import time
import os
from .char_info import CharInfo, PinyinCharPairsCounter
from char_info import CharInfo, PinyinCharPairsCounter
class QueryEngine:
@ -30,6 +30,7 @@ class QueryEngine:
self._id_to_info: Dict[int, CharInfo] = {} # ID -> CharInfo
self._char_to_ids: Dict[str, List[int]] = {} # 字符 -> ID列表
self._pinyin_to_ids: Dict[str, List[int]] = {} # 拼音 -> ID列表
self._char_pinyin_to_ids: Dict[Tuple[str, str], int] = {}
# 辅助索引 - 快速获取详细信息
self._char_freq: Dict[str, int] = {} # 字符总频率
@ -62,7 +63,7 @@ class QueryEngine:
start_time = time.time()
# 读取并解析文件
self._counter_data = self._parse_file(file_path)
self._counter_data = self.parse_file(file_path)
# 构建索引
self._build_indices()
@ -72,7 +73,7 @@ class QueryEngine:
return self._counter_data.metadata
def _parse_file(self, file_path: str) -> PinyinCharPairsCounter:
def parse_file(self, file_path: str) -> PinyinCharPairsCounter:
"""解析文件,支持多种格式"""
with open(file_path, 'rb') as f:
data = f.read()
@ -130,6 +131,7 @@ class QueryEngine:
self._id_to_info.clear()
self._char_to_ids.clear()
self._pinyin_to_ids.clear()
self._char_pinyin_to_ids.clear()
self._char_freq.clear()
self._pinyin_freq.clear()
self._char_pinyin_map.clear()
@ -161,6 +163,7 @@ class QueryEngine:
# 字符-拼音映射
self._char_pinyin_map[(char, pinyin)] = char_info.count
self._char_pinyin_to_ids[(char, pinyin)] = char_info_id
self._total_pairs = len(self._id_to_info)
self._index_time = time.time() - start_time
@ -295,7 +298,45 @@ class QueryEngine:
raise RuntimeError("数据未加载请先调用load()方法")
return self._char_pinyin_map.get((char, pinyin), 0)
def get_char_info_by_char_pinyin(self, char: str, pinyin: str) -> Optional[CharInfo]:
    """Look up the CharInfo for a specific (char, pinyin) pair — O(1).

    Args:
        char: Chinese character.
        pinyin: Pinyin string.

    Returns:
        The matching CharInfo, or None when the pair is unknown.
        (The previous docstring claimed "id and frequency"; the method
        actually returns a CharInfo object.)

    Raises:
        RuntimeError: If load() has not been called yet.
    """
    if not self._loaded:
        raise RuntimeError("数据未加载请先调用load()方法")
    char_info_id = self._char_pinyin_to_ids.get((char, pinyin), None)
    if char_info_id is None:
        # Don't delegate a None id to query_by_id — return early.
        return None
    return self.query_by_id(char_info_id)
def batch_get_char_pinyin_info(self, pairs: List[Tuple[str, str]]) -> Dict[Tuple[str, str], CharInfo]:
    """Resolve many (char, pinyin) pairs in a single call.

    Args:
        pairs: (char, pinyin) tuples to look up.

    Returns:
        Dict keyed by the input pairs; each value is the pair's CharInfo,
        or None when the pair has no id (or the id is not indexed).

    Raises:
        RuntimeError: If load() has not been called yet.
    """
    if not self._loaded:
        raise RuntimeError("数据未加载请先调用load()方法")
    result = {}
    for pair in pairs:
        info_id = self._char_pinyin_to_ids.get(pair)
        result[pair] = None if info_id is None else self._id_to_info.get(info_id)
    return result
def batch_query_by_ids(self, ids: List[int]) -> Dict[int, Optional[CharInfo]]:
"""
批量ID查询 - O(n)时间复杂度
@ -311,10 +352,10 @@ class QueryEngine:
results = {}
for id_value in ids:
results[id_value] = self._id_to_info.get(id_value)
results[id_value] = self._id_to_info.get(id_value, None)
return results
def batch_query_by_chars(self, chars: List[str], limit_per_char: int = 0) -> Dict[str, List[Tuple[int, str, int]]]:
"""
批量字符查询
@ -408,6 +449,7 @@ class QueryEngine:
self._char_freq.clear()
self._pinyin_freq.clear()
self._char_pinyin_map.clear()
self._char_pinyin_to_ids.clear()
self._loaded = False
self._total_pairs = 0
self._load_time = 0.0

151
test/test_query.py Normal file
View File

@ -0,0 +1,151 @@
# test_query_engine.py
import pytest
import tempfile
import os
import json
from suinput.query import QueryEngine
from suinput.char_info import CharInfo, PinyinCharPairsCounter
# 将测试数据保存为 JSON 文件
# Fixture: path of the pre-built statistics JSON file used by all tests.
@pytest.fixture
def json_file_path():
    # NOTE(review): depends on a real data file in the working directory
    # (the tempfile/os/json imports above are unused); every test fails
    # if this file is missing or regenerated with different counts.
    yield "pinyin_char_statistics.json"


# Basic QueryEngine behavior, exercised against the real statistics file.
# NOTE(review): several expected character literals below appear as empty
# strings "" — they look like CJK characters lost in transcoding; confirm
# them against the original source before trusting these assertions.
class TestQueryEngine:
    def test_load_from_json(self, json_file_path):
        """Loading from JSON succeeds and reports format/pair count."""
        engine = QueryEngine()
        metadata = engine.load(json_file_path)
        assert engine.is_loaded() is True
        assert metadata["format"] == "json"
        assert metadata["pair_count"] == 20646

    def test_query_by_id(self, json_file_path):
        """Look up CharInfo by numeric id."""
        engine = QueryEngine()
        engine.load(json_file_path)
        result = engine.query_by_id(8)
        assert result is not None
        assert result.char == ""
        assert result.pinyin == "zhong"
        assert result.count == 73927282
        result = engine.query_by_id(100000)  # nonexistent id
        assert result is None

    def test_query_by_char(self, json_file_path):
        """Look up every (id, pinyin, count) reading of one character."""
        engine = QueryEngine()
        engine.load(json_file_path)
        results = engine.query_by_char("")
        assert len(results) == 2
        assert results[0] == (159, "zhang", 15424264)
        assert results[1] == (414, "chang", 6663465)
        results_limited = engine.query_by_char("", limit=1)
        assert len(results_limited) == 1
        assert results_limited[0] == (159, "zhang", 15424264)
        results_empty = engine.query_by_char("X")  # unknown character
        assert results_empty == []

    def test_query_by_pinyin(self, json_file_path):
        """Look up every character sharing one pinyin."""
        engine = QueryEngine()
        engine.load(json_file_path)
        results = engine.query_by_pinyin("zhong")
        assert len(results) == 57
        assert results[0] == (8, "", 73927282)
        results_empty = engine.query_by_pinyin("xxx")  # unknown pinyin
        assert results_empty == []

    def test_get_char_frequency(self, json_file_path):
        """Total frequency of one character across all its readings."""
        engine = QueryEngine()
        engine.load(json_file_path)
        freq = engine.get_char_frequency("")
        assert freq == 73927282
        freq_zero = engine.get_char_frequency("X")  # unknown character
        assert freq_zero == 0

    def test_get_pinyin_frequency(self, json_file_path):
        """Total frequency of one pinyin across all characters."""
        engine = QueryEngine()
        engine.load(json_file_path)
        freq = engine.get_pinyin_frequency("zhong")
        assert freq == 136246123
        freq_zero = engine.get_pinyin_frequency("xxx")  # unknown pinyin
        assert freq_zero == 0

    def test_get_char_pinyin_count(self, json_file_path):
        """Occurrence count of one (char, pinyin) pair."""
        engine = QueryEngine()
        engine.load(json_file_path)
        count = engine.get_char_pinyin_count("", "zhong")
        assert count == 73927282
        count_zero = engine.get_char_pinyin_count("", "xxx")  # unknown pinyin
        assert count_zero == 0

    def test_batch_query_by_ids(self, json_file_path):
        """Batch id lookup returns one entry per requested id."""
        engine = QueryEngine()
        engine.load(json_file_path)
        results = engine.batch_query_by_ids([8, 9, 10000000])
        assert len(results) == 3
        assert results[9].char == ""

    def test_search_chars_by_prefix(self, json_file_path):
        """Prefix search over characters."""
        engine = QueryEngine()
        engine.load(json_file_path)
        results = engine.search_chars_by_prefix("")
        assert len(results) == 1
        assert results[0] == ("", 73927282)
        results_empty = engine.search_chars_by_prefix("X")  # unknown prefix
        assert results_empty == []

    def test_get_statistics(self, json_file_path):
        """Aggregate statistics of the loaded data."""
        engine = QueryEngine()
        engine.load(json_file_path)
        stats = engine.get_statistics()
        assert stats["status"] == "loaded"
        assert stats["total_pairs"] == 20646
        assert stats["total_characters"] == 18240
        assert stats["top_chars"][0] == ("", 439524694)

    def test_clear(self, json_file_path):
        """clear() unloads the data."""
        engine = QueryEngine()
        engine.load(json_file_path)
        assert engine.is_loaded() is True
        engine.clear()
        assert engine.is_loaded() is False
        assert engine.get_statistics()["status"] == "not_loaded"

    def test_batch_get_char_pinyin_info(self, json_file_path):
        """Batch pair lookup agrees with the single-pair lookup."""
        engine = QueryEngine()
        engine.load(json_file_path)
        assert engine.is_loaded() is True
        pairs = engine.batch_get_char_pinyin_info([("", "wo"), ("", "ni"), ("", "ta")])
        assert pairs[("", "wo")] == engine.get_char_info_by_char_pinyin("", "wo")
        assert pairs[("", "ni")] == engine.get_char_info_by_char_pinyin("", "ni")
        assert pairs[("", "ta")] == engine.get_char_info_by_char_pinyin("", "ta")