151 lines
5.1 KiB
Python
151 lines
5.1 KiB
Python
# test_query_engine.py
|
|
import pytest
|
|
import tempfile
|
|
import os
|
|
import json
|
|
from suinput.query import QueryEngine
|
|
from suinput.char_info import CharInfo, PinyinCharPairsCounter
|
|
|
|
# Location of the JSON statistics file used as test data
|
|
@pytest.fixture
def json_file_path():
    """Provide the path of the JSON statistics file that the tests load.

    The path is relative, so it is resolved against the working directory
    pytest is started from.
    """
    stats_file = "pinyin_char_statistics.json"
    yield stats_file
|
|
|
|
# Tests for the basic functionality of QueryEngine
|
|
class TestQueryEngine:
    """Checks of QueryEngine against the JSON statistics file.

    Every expected value below (IDs, counts, list lengths) is tied to the
    contents of the data file supplied by the ``json_file_path`` fixture.
    """

    @staticmethod
    def _loaded_engine(path):
        """Build a fresh QueryEngine and load the file at *path* into it."""
        engine = QueryEngine()
        engine.load(path)
        return engine

    def test_load_from_json(self, json_file_path):
        """Loading a JSON file marks the engine loaded and returns metadata."""
        engine = QueryEngine()
        meta = engine.load(json_file_path)

        assert engine.is_loaded() is True
        assert meta["format"] == "json"
        assert meta["pair_count"] == 20646

    def test_query_by_id(self, json_file_path):
        """Look up character info by ID, for a known and an unknown ID."""
        engine = self._loaded_engine(json_file_path)

        hit = engine.query_by_id(8)
        assert hit is not None
        assert hit.char == "中"
        assert hit.pinyin == "zhong"
        assert hit.count == 73927282

        # An ID that does not exist in the data.
        miss = engine.query_by_id(100000)
        assert miss is None

    def test_query_by_char(self, json_file_path):
        """Look up pinyin entries by character, with and without a limit."""
        engine = self._loaded_engine(json_file_path)
        zhang_entry = (159, "zhang", 15424264)
        chang_entry = (414, "chang", 6663465)

        all_readings = engine.query_by_char("长")
        assert len(all_readings) == 2
        assert all_readings[0] == zhang_entry
        assert all_readings[1] == chang_entry

        capped = engine.query_by_char("长", limit=1)
        assert len(capped) == 1
        assert capped[0] == zhang_entry

        # A character that does not exist in the data.
        nothing = engine.query_by_char("X")
        assert nothing == []

    def test_query_by_pinyin(self, json_file_path):
        """Look up characters by pinyin, for a known and an unknown pinyin."""
        engine = self._loaded_engine(json_file_path)

        matches = engine.query_by_pinyin("zhong")
        assert len(matches) == 57
        assert matches[0] == (8, "中", 73927282)

        # A pinyin that does not exist in the data.
        nothing = engine.query_by_pinyin("xxx")
        assert nothing == []

    def test_get_char_frequency(self, json_file_path):
        """Total frequency of a character; zero for an unknown character."""
        engine = self._loaded_engine(json_file_path)

        known_total = engine.get_char_frequency("中")
        assert known_total == 73927282

        # A character that does not exist in the data.
        unknown_total = engine.get_char_frequency("X")
        assert unknown_total == 0

    def test_get_pinyin_frequency(self, json_file_path):
        """Total frequency of a pinyin; zero for an unknown pinyin."""
        engine = self._loaded_engine(json_file_path)

        known_total = engine.get_pinyin_frequency("zhong")
        assert known_total == 136246123

        # A pinyin that does not exist in the data.
        unknown_total = engine.get_pinyin_frequency("xxx")
        assert unknown_total == 0

    def test_get_char_pinyin_count(self, json_file_path):
        """Occurrence count of a character-pinyin pair; zero if unmatched."""
        engine = self._loaded_engine(json_file_path)

        paired = engine.get_char_pinyin_count("中", "zhong")
        assert paired == 73927282

        # A pinyin that does not exist for this character.
        unpaired = engine.get_char_pinyin_count("中", "xxx")
        assert unpaired == 0

    def test_batch_query_by_ids(self, json_file_path):
        """Batch lookup by ID, including one ID that does not exist."""
        engine = self._loaded_engine(json_file_path)

        by_id = engine.batch_query_by_ids([8, 9, 10000000])
        assert len(by_id) == 3
        assert by_id[9].char == "为"

    def test_search_chars_by_prefix(self, json_file_path):
        """Search characters by prefix, for a known and an unknown prefix."""
        engine = self._loaded_engine(json_file_path)

        found = engine.search_chars_by_prefix("中")
        assert len(found) == 1
        assert found[0] == ("中", 73927282)

        # A prefix that does not exist in the data.
        nothing = engine.search_chars_by_prefix("X")
        assert nothing == []

    def test_get_statistics(self, json_file_path):
        """Statistics summary reported after loading the data."""
        engine = self._loaded_engine(json_file_path)

        summary = engine.get_statistics()
        assert summary["status"] == "loaded"
        assert summary["total_pairs"] == 20646
        assert summary["total_characters"] == 18240
        assert summary["top_chars"][0] == ("的", 439524694)

    def test_clear(self, json_file_path):
        """Clearing the engine returns it to the not-loaded state."""
        engine = self._loaded_engine(json_file_path)
        assert engine.is_loaded() is True

        engine.clear()
        assert engine.is_loaded() is False
        assert engine.get_statistics()["status"] == "not_loaded"

    def test_batch_get_char_pinyin_info(self, json_file_path):
        """Batch pair lookup agrees with the single-pair lookup."""
        engine = self._loaded_engine(json_file_path)
        assert engine.is_loaded() is True

        wanted = (("我", "wo"), ("你", "ni"), ("他", "ta"))
        info_by_pair = engine.batch_get_char_pinyin_info(list(wanted))
        for char, pinyin in wanted:
            assert info_by_pair[(char, pinyin)] == engine.get_char_info_by_char_pinyin(char, pinyin)