Compare commits
No commits in common. "5b1c6fcb2b0b3556d6032fb7be082ba45f991d68" and "ea54c7da393c68c1e214541216fba47a410d1b96" have entirely different histories.
5b1c6fcb2b
...
ea54c7da39
|
|
@ -1,9 +1,8 @@
|
|||
import json
|
||||
import os
|
||||
import random
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
|
@ -13,8 +12,8 @@ from modelscope import AutoTokenizer
|
|||
from pypinyin import lazy_pinyin
|
||||
from torch.utils.data import DataLoader, IterableDataset
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
class PinyinInputDataset(IterableDataset):
|
||||
"""
|
||||
|
|
@ -47,7 +46,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
repeat_end_freq: int = 10000, # 开始重复的阈值
|
||||
max_drop_prob: float = 0.8, # 最大丢弃概率
|
||||
max_repeat_expect: float = 50.0, # 最大重复期望
|
||||
py_group_json_file: Optional[Dict[str, int]] = None,
|
||||
py_group_json_file: Path = Path(__file__).parent.parent.parent / "py_group.json",
|
||||
):
|
||||
"""
|
||||
初始化数据集
|
||||
|
|
@ -98,32 +97,9 @@ class PinyinInputDataset(IterableDataset):
|
|||
# 加载数据集
|
||||
self.dataset = load_dataset(data_dir, split="train", streaming=True)
|
||||
|
||||
# 加载拼音分组
|
||||
self.pg_groups = {
|
||||
"y": 0,
|
||||
"k": 0,
|
||||
"e": 0,
|
||||
"l": 1,
|
||||
"w": 1,
|
||||
"f": 1,
|
||||
"q": 2,
|
||||
"a": 2,
|
||||
"s": 2,
|
||||
"x": 3,
|
||||
"b": 3,
|
||||
"r": 3,
|
||||
"o": 4,
|
||||
"m": 4,
|
||||
"z": 4,
|
||||
"g": 5,
|
||||
"n": 5,
|
||||
"c": 5,
|
||||
"t": 6,
|
||||
"p": 6,
|
||||
"d": 6,
|
||||
"j": 7,
|
||||
"h": 7,
|
||||
}
|
||||
with open(py_group_json_file, "r") as f:
|
||||
self.py_groups = json.load(f)
|
||||
|
||||
|
||||
def get_next_chinese_chars(
|
||||
self,
|
||||
|
|
@ -414,8 +390,8 @@ class PinyinInputDataset(IterableDataset):
|
|||
char_info = char_info_map.get((char, pinyin))
|
||||
if not char_info:
|
||||
continue
|
||||
|
||||
logger.info(f"获取字符信息: {char_info}")
|
||||
|
||||
logger.info(f'获取字符信息: {char_info}')
|
||||
# 削峰填谷调整
|
||||
adjust_factor = self.adjust_frequency(char_info["freq"])
|
||||
if adjust_factor <= 0:
|
||||
|
|
@ -426,6 +402,8 @@ class PinyinInputDataset(IterableDataset):
|
|||
|
||||
# 拼音处理
|
||||
processed_pinyin = self.process_pinyin_sequence(next_pinyins)
|
||||
if not processed_pinyin:
|
||||
continue
|
||||
|
||||
# Tokenize
|
||||
hint = self.tokenizer(
|
||||
|
|
@ -445,9 +423,6 @@ class PinyinInputDataset(IterableDataset):
|
|||
"char_id": torch.tensor([char_info["id"]]),
|
||||
"char": char,
|
||||
"freq": char_info["freq"],
|
||||
"pg": torch.tensor(
|
||||
self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8
|
||||
),
|
||||
}
|
||||
|
||||
# 根据调整因子重复样本
|
||||
|
|
@ -461,7 +436,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
if not self.shuffle:
|
||||
yield from batch_samples
|
||||
return
|
||||
|
||||
|
||||
# 使用numpy批量操作代替random.shuffle
|
||||
if batch_samples:
|
||||
indices = np.random.permutation(len(batch_samples))
|
||||
|
|
@ -495,15 +470,14 @@ class PinyinInputDataset(IterableDataset):
|
|||
pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
|
||||
# 批量收集需要查询的字符信息
|
||||
chinese_positions = [
|
||||
i
|
||||
for i, char in enumerate(text)
|
||||
i for i, char in enumerate(text)
|
||||
if self.query_engine.is_chinese_char(char)
|
||||
]
|
||||
|
||||
|
||||
# 批量收集需要查询的字符信息
|
||||
char_pinyin_batch = []
|
||||
char_positions = []
|
||||
|
||||
|
||||
for i in chinese_positions: # 只遍历中文字符位置
|
||||
char = text[i]
|
||||
py = pinyin_list[i]
|
||||
|
|
@ -544,6 +518,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
)
|
||||
yield from self._shuffle_and_yield(batch_samples)
|
||||
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
由于是流式数据集,无法预先知道长度
|
||||
|
|
@ -616,7 +591,7 @@ if __name__ == "__main__":
|
|||
|
||||
# 初始化查询引擎
|
||||
query_engine = QueryEngine()
|
||||
query_engine.load()
|
||||
query_engine.load("./pinyin_char_pairs_info.json")
|
||||
|
||||
# 创建数据集
|
||||
dataset = PinyinInputDataset(
|
||||
|
|
@ -650,7 +625,7 @@ if __name__ == "__main__":
|
|||
break
|
||||
return
|
||||
|
||||
|
||||
|
||||
cProfile.run('profile_func(dataloader)')
|
||||
|
||||
"""
|
||||
|
|
@ -670,3 +645,4 @@ if __name__ == "__main__":
|
|||
"""
|
||||
except StopIteration:
|
||||
print("数据集为空")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,14 +1,11 @@
|
|||
# file name: query_engine.py
|
||||
import gzip
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import msgpack
|
||||
import gzip
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
import time
|
||||
import os
|
||||
|
||||
from .char_info import CharInfo, PinyinCharPairsCounter
|
||||
|
||||
|
|
@ -16,7 +13,7 @@ from .char_info import CharInfo, PinyinCharPairsCounter
|
|||
class QueryEngine:
|
||||
"""
|
||||
高效拼音-字符查询引擎
|
||||
|
||||
|
||||
特性:
|
||||
1. O(1)时间复杂度的ID查询
|
||||
2. O(1)时间复杂度的字符查询
|
||||
|
|
@ -24,22 +21,22 @@ class QueryEngine:
|
|||
4. 内存友好,构建高效索引
|
||||
5. 支持批量查询和前缀搜索
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, min_count: int = 109):
|
||||
"""初始化查询引擎"""
|
||||
self._counter_data: Optional[PinyinCharPairsCounter] = None
|
||||
|
||||
|
||||
# 核心索引 - 提供O(1)查询
|
||||
self._id_to_info: Dict[int, CharInfo] = {} # ID -> CharInfo
|
||||
self._char_to_ids: Dict[str, List[int]] = {} # 字符 -> ID列表
|
||||
self._pinyin_to_ids: Dict[str, List[int]] = {} # 拼音 -> ID列表
|
||||
self._id_to_info: Dict[int, CharInfo] = {} # ID -> CharInfo
|
||||
self._char_to_ids: Dict[str, List[int]] = {} # 字符 -> ID列表
|
||||
self._pinyin_to_ids: Dict[str, List[int]] = {} # 拼音 -> ID列表
|
||||
self._char_pinyin_to_ids: Dict[Tuple[str, str], int] = {}
|
||||
|
||||
|
||||
# 辅助索引 - 快速获取详细信息
|
||||
self._char_freq: Dict[str, int] = {} # 字符总频率
|
||||
self._pinyin_freq: Dict[str, int] = {} # 拼音总频率
|
||||
self._char_freq: Dict[str, int] = {} # 字符总频率
|
||||
self._pinyin_freq: Dict[str, int] = {} # 拼音总频率
|
||||
self._char_pinyin_map: Dict[Tuple[str, str], int] = {} # (字符, 拼音) -> count
|
||||
|
||||
|
||||
# 统计信息
|
||||
self._loaded = False
|
||||
self._total_pairs = 0
|
||||
|
|
@ -47,96 +44,91 @@ class QueryEngine:
|
|||
self._index_time = 0.0
|
||||
|
||||
self.min_count = min_count
|
||||
|
||||
def load(
|
||||
self,
|
||||
file_path: Union[str, Path] = (
|
||||
files(__package__) / "data" / "pinyin_char_statistics.json"
|
||||
),
|
||||
) -> Dict[str, Any]:
|
||||
|
||||
def load(self, file_path: str) -> Dict[str, Any]:
|
||||
"""
|
||||
加载统计结果文件
|
||||
|
||||
|
||||
Args:
|
||||
file_path: 文件路径,文件支持msgpack/pickle/json格式,自动检测压缩
|
||||
|
||||
file_path: 文件路径,支持msgpack/pickle/json格式,自动检测压缩
|
||||
|
||||
Returns:
|
||||
元数据字典
|
||||
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: 文件不存在
|
||||
ValueError: 文件格式不支持
|
||||
"""
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
# 读取并解析文件
|
||||
self._counter_data = self.parse_file(file_path)
|
||||
|
||||
|
||||
# 构建索引
|
||||
self._build_indices()
|
||||
|
||||
|
||||
self._load_time = time.time() - start_time
|
||||
self._loaded = True
|
||||
|
||||
|
||||
return self._counter_data.metadata
|
||||
|
||||
def parse_file(self, file_path: Union[str, Path]) -> PinyinCharPairsCounter:
|
||||
|
||||
def parse_file(self, file_path: str) -> PinyinCharPairsCounter:
|
||||
"""解析文件,支持多种格式"""
|
||||
with open(file_path, "rb") as f:
|
||||
with open(file_path, 'rb') as f:
|
||||
data = f.read()
|
||||
|
||||
|
||||
# 尝试解压
|
||||
try:
|
||||
data = gzip.decompress(data)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# 尝试不同格式
|
||||
for parser_name, parser in [
|
||||
("msgpack", self._parse_msgpack),
|
||||
("pickle", self._parse_pickle),
|
||||
("json", self._parse_json),
|
||||
('msgpack', self._parse_msgpack),
|
||||
('pickle', self._parse_pickle),
|
||||
('json', self._parse_json)
|
||||
]:
|
||||
try:
|
||||
return parser(data)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
||||
raise ValueError("无法解析文件格式")
|
||||
|
||||
|
||||
def _parse_msgpack(self, data: bytes) -> PinyinCharPairsCounter:
|
||||
"""解析msgpack格式"""
|
||||
data_dict = msgpack.unpackb(data, raw=False)
|
||||
return self._dict_to_counter(data_dict)
|
||||
|
||||
|
||||
def _parse_pickle(self, data: bytes) -> PinyinCharPairsCounter:
|
||||
"""解析pickle格式"""
|
||||
return pickle.loads(data)
|
||||
|
||||
|
||||
def _parse_json(self, data: bytes) -> PinyinCharPairsCounter:
|
||||
"""解析json格式"""
|
||||
data_str = data.decode("utf-8")
|
||||
data_str = data.decode('utf-8')
|
||||
data_dict = json.loads(data_str)
|
||||
return self._dict_to_counter(data_dict)
|
||||
|
||||
|
||||
def _dict_to_counter(self, data_dict: Dict) -> PinyinCharPairsCounter:
|
||||
"""字典转PinyinCharPairsCounter"""
|
||||
# 转换CharInfo字典
|
||||
pairs_dict = {}
|
||||
if "pairs" in data_dict and data_dict["pairs"]:
|
||||
for id_str, info_dict in data_dict["pairs"].items():
|
||||
if 'pairs' in data_dict and data_dict['pairs']:
|
||||
for id_str, info_dict in data_dict['pairs'].items():
|
||||
pairs_dict[int(id_str)] = CharInfo(**info_dict)
|
||||
data_dict["pairs"] = pairs_dict
|
||||
|
||||
data_dict['pairs'] = pairs_dict
|
||||
|
||||
return PinyinCharPairsCounter(**data_dict)
|
||||
|
||||
|
||||
def _build_indices(self):
|
||||
"""构建所有查询索引"""
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
# 重置索引
|
||||
self._id_to_info.clear()
|
||||
self._char_to_ids.clear()
|
||||
|
|
@ -145,13 +137,13 @@ class QueryEngine:
|
|||
self._char_freq.clear()
|
||||
self._pinyin_freq.clear()
|
||||
self._char_pinyin_map.clear()
|
||||
|
||||
|
||||
# 复制频率数据
|
||||
if self._counter_data.chars:
|
||||
self._char_freq = self._counter_data.chars.copy()
|
||||
if self._counter_data.pinyins:
|
||||
self._pinyin_freq = self._counter_data.pinyins.copy()
|
||||
|
||||
|
||||
# 构建核心索引
|
||||
for char_info in self._counter_data.pairs.values():
|
||||
if char_info.count < self.min_count:
|
||||
|
|
@ -159,169 +151,165 @@ class QueryEngine:
|
|||
char = char_info.char
|
||||
pinyin = char_info.pinyin
|
||||
char_info_id = char_info.id
|
||||
|
||||
|
||||
# ID索引
|
||||
self._id_to_info[char_info_id] = char_info
|
||||
|
||||
|
||||
# 字符索引
|
||||
if char not in self._char_to_ids:
|
||||
self._char_to_ids[char] = []
|
||||
self._char_to_ids[char].append(char_info_id)
|
||||
|
||||
|
||||
# 拼音索引
|
||||
if pinyin not in self._pinyin_to_ids:
|
||||
self._pinyin_to_ids[pinyin] = []
|
||||
self._pinyin_to_ids[pinyin].append(char_info_id)
|
||||
|
||||
|
||||
# 字符-拼音映射
|
||||
self._char_pinyin_map[(char, pinyin)] = char_info.count
|
||||
self._char_pinyin_to_ids[(char, pinyin)] = char_info_id
|
||||
|
||||
|
||||
self._total_pairs = len(self._id_to_info)
|
||||
self._index_time = time.time() - start_time
|
||||
|
||||
|
||||
def query_by_id(self, id: int) -> Optional[CharInfo]:
|
||||
"""
|
||||
通过ID查询字符信息 - O(1)时间复杂度
|
||||
|
||||
|
||||
Args:
|
||||
id: 记录ID
|
||||
|
||||
|
||||
Returns:
|
||||
CharInfo对象,不存在则返回None
|
||||
"""
|
||||
if not self._loaded:
|
||||
raise RuntimeError("数据未加载,请先调用load()方法")
|
||||
|
||||
|
||||
return self._id_to_info.get(id)
|
||||
|
||||
|
||||
def query_by_char(self, char: str, limit: int = 0) -> List[Tuple[int, str, int]]:
|
||||
"""
|
||||
通过字符查询拼音信息 - O(1) + O(k)时间复杂度,k为结果数
|
||||
|
||||
|
||||
Args:
|
||||
char: 汉字字符
|
||||
limit: 返回结果数量限制,0表示返回所有
|
||||
|
||||
|
||||
Returns:
|
||||
列表,每个元素为(id, 拼音, 次数),按次数降序排序
|
||||
"""
|
||||
if not self._loaded:
|
||||
raise RuntimeError("数据未加载,请先调用load()方法")
|
||||
|
||||
|
||||
if char not in self._char_to_ids:
|
||||
return []
|
||||
|
||||
|
||||
# 获取所有相关ID
|
||||
ids = self._char_to_ids[char]
|
||||
|
||||
|
||||
# 构建结果并排序
|
||||
results = []
|
||||
for char_info_id in ids:
|
||||
char_info = self._id_to_info[char_info_id]
|
||||
results.append((char_info_id, char_info.pinyin, char_info.count))
|
||||
|
||||
|
||||
# 按次数降序排序
|
||||
results.sort(key=lambda x: x[2], reverse=True)
|
||||
|
||||
|
||||
# 应用限制
|
||||
if limit > 0 and len(results) > limit:
|
||||
results = results[:limit]
|
||||
|
||||
|
||||
return results
|
||||
|
||||
def query_by_pinyin(
|
||||
self, pinyin: str, limit: int = 0
|
||||
) -> List[Tuple[int, str, int]]:
|
||||
|
||||
def query_by_pinyin(self, pinyin: str, limit: int = 0) -> List[Tuple[int, str, int]]:
|
||||
"""
|
||||
通过拼音查询字符信息 - O(1) + O(k)时间复杂度
|
||||
|
||||
|
||||
Args:
|
||||
pinyin: 拼音字符串
|
||||
limit: 返回结果数量限制,0表示返回所有
|
||||
|
||||
|
||||
Returns:
|
||||
列表,每个元素为(id, 字符, 次数),按次数降序排序
|
||||
"""
|
||||
if not self._loaded:
|
||||
raise RuntimeError("数据未加载,请先调用load()方法")
|
||||
|
||||
|
||||
if pinyin not in self._pinyin_to_ids:
|
||||
return []
|
||||
|
||||
|
||||
# 获取所有相关ID
|
||||
ids = self._pinyin_to_ids[pinyin]
|
||||
|
||||
|
||||
# 构建结果并排序
|
||||
results = []
|
||||
for char_info_id in ids:
|
||||
char_info = self._id_to_info[char_info_id]
|
||||
results.append((char_info_id, char_info.char, char_info.count))
|
||||
|
||||
|
||||
# 按次数降序排序
|
||||
results.sort(key=lambda x: x[2], reverse=True)
|
||||
|
||||
|
||||
# 应用限制
|
||||
if limit > 0 and len(results) > limit:
|
||||
results = results[:limit]
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def get_char_frequency(self, char: str) -> int:
|
||||
"""
|
||||
获取字符的总出现频率(所有拼音变体之和) - O(1)时间复杂度
|
||||
|
||||
|
||||
Args:
|
||||
char: 汉字字符
|
||||
|
||||
|
||||
Returns:
|
||||
总出现次数
|
||||
"""
|
||||
if not self._loaded:
|
||||
raise RuntimeError("数据未加载,请先调用load()方法")
|
||||
|
||||
|
||||
return self._char_freq.get(char, 0)
|
||||
|
||||
|
||||
def get_pinyin_frequency(self, pinyin: str) -> int:
|
||||
"""
|
||||
获取拼音的总出现频率(所有字符之和) - O(1)时间复杂度
|
||||
|
||||
|
||||
Args:
|
||||
pinyin: 拼音字符串
|
||||
|
||||
|
||||
Returns:
|
||||
总出现次数
|
||||
"""
|
||||
if not self._loaded:
|
||||
raise RuntimeError("数据未加载,请先调用load()方法")
|
||||
|
||||
|
||||
return self._pinyin_freq.get(pinyin, 0)
|
||||
|
||||
|
||||
def get_char_pinyin_count(self, char: str, pinyin: str) -> int:
|
||||
"""
|
||||
获取特定字符-拼音对的出现次数 - O(1)时间复杂度
|
||||
|
||||
|
||||
Args:
|
||||
char: 汉字字符
|
||||
pinyin: 拼音字符串
|
||||
|
||||
|
||||
Returns:
|
||||
出现次数
|
||||
"""
|
||||
if not self._loaded:
|
||||
raise RuntimeError("数据未加载,请先调用load()方法")
|
||||
|
||||
|
||||
return self._char_pinyin_map.get((char, pinyin), 0)
|
||||
|
||||
def get_char_info_by_char_pinyin(
|
||||
self, char: str, pinyin: str
|
||||
) -> Optional[CharInfo]:
|
||||
def get_char_info_by_char_pinyin(self, char: str, pinyin: str) -> Optional[CharInfo]:
|
||||
"""获取特定字符-拼音对对应的ID和频率 - O(1)时间复杂度
|
||||
|
||||
|
||||
Args:
|
||||
char: 汉字字符
|
||||
pinyin: 拼音字符串
|
||||
|
||||
|
||||
Returns:
|
||||
ID和频率
|
||||
"""
|
||||
|
|
@ -331,14 +319,12 @@ class QueryEngine:
|
|||
char_info_id = self._char_pinyin_to_ids.get((char, pinyin), None)
|
||||
return self.query_by_id(char_info_id)
|
||||
|
||||
def batch_get_char_pinyin_info(
|
||||
self, pairs: List[Tuple[str, str]]
|
||||
) -> Dict[Tuple[str, str], CharInfo]:
|
||||
def batch_get_char_pinyin_info(self, pairs: List[Tuple[str, str]]) -> Dict[Tuple[str, str], CharInfo]:
|
||||
"""批量获取汉字-拼音信息
|
||||
|
||||
|
||||
Args:
|
||||
pairs: 汉字-拼音列表
|
||||
|
||||
|
||||
Returns:
|
||||
字典,key为汉字-拼音对,value为CharInfo对象(不存在则为None)
|
||||
"""
|
||||
|
|
@ -353,72 +339,69 @@ class QueryEngine:
|
|||
else:
|
||||
result[pair] = None
|
||||
return result
|
||||
|
||||
|
||||
def batch_query_by_ids(self, ids: List[int]) -> Dict[int, Optional[CharInfo]]:
|
||||
"""
|
||||
批量ID查询 - O(n)时间复杂度
|
||||
|
||||
|
||||
Args:
|
||||
ids: ID列表
|
||||
|
||||
|
||||
Returns:
|
||||
字典,key为ID,value为CharInfo对象(不存在则为None)
|
||||
"""
|
||||
if not self._loaded:
|
||||
raise RuntimeError("数据未加载,请先调用load()方法")
|
||||
|
||||
|
||||
results = {}
|
||||
for id_value in ids:
|
||||
results[id_value] = self._id_to_info.get(id_value, None)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
def batch_query_by_chars(
|
||||
self, chars: List[str], limit_per_char: int = 0
|
||||
) -> Dict[str, List[Tuple[int, str, int]]]:
|
||||
def batch_query_by_chars(self, chars: List[str], limit_per_char: int = 0) -> Dict[str, List[Tuple[int, str, int]]]:
|
||||
"""
|
||||
批量字符查询
|
||||
|
||||
|
||||
Args:
|
||||
chars: 字符列表
|
||||
limit_per_char: 每个字符的结果数量限制
|
||||
|
||||
|
||||
Returns:
|
||||
字典,key为字符,value为查询结果列表
|
||||
"""
|
||||
if not self._loaded:
|
||||
raise RuntimeError("数据未加载,请先调用load()方法")
|
||||
|
||||
|
||||
results = {}
|
||||
for char in chars:
|
||||
results[char] = self.query_by_char(char, limit_per_char)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
def search_chars_by_prefix(
|
||||
self, prefix: str, limit: int = 20
|
||||
) -> List[Tuple[str, int]]:
|
||||
|
||||
def search_chars_by_prefix(self, prefix: str, limit: int = 20) -> List[Tuple[str, int]]:
|
||||
"""
|
||||
根据字符前缀搜索 - O(n)时间复杂度,n为字符总数
|
||||
|
||||
|
||||
Args:
|
||||
prefix: 字符前缀
|
||||
limit: 返回结果数量限制
|
||||
|
||||
|
||||
Returns:
|
||||
列表,每个元素为(字符, 总频率),按频率降序排序
|
||||
"""
|
||||
if not self._loaded:
|
||||
raise RuntimeError("数据未加载,请先调用load()方法")
|
||||
|
||||
|
||||
matches = []
|
||||
for char, freq in self._char_freq.items():
|
||||
if char.startswith(prefix):
|
||||
matches.append((char, freq))
|
||||
|
||||
|
||||
# 按频率降序排序
|
||||
matches.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
|
||||
return matches[:limit] if limit > 0 else matches
|
||||
|
||||
def is_chinese_char(self, char: str) -> bool:
|
||||
|
|
@ -426,27 +409,31 @@ class QueryEngine:
|
|||
判断是否是汉字
|
||||
"""
|
||||
if not self.is_loaded():
|
||||
raise ValueError("请先调用 load() 方法加载数据")
|
||||
raise ValueError('请先调用 load() 方法加载数据')
|
||||
return char in self._char_to_ids
|
||||
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取系统统计信息
|
||||
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
if not self._loaded:
|
||||
return {"status": "not_loaded"}
|
||||
|
||||
top_chars = sorted(self._char_freq.items(), key=lambda x: x[1], reverse=True)[
|
||||
:10
|
||||
]
|
||||
|
||||
top_pinyins = sorted(
|
||||
self._pinyin_freq.items(), key=lambda x: x[1], reverse=True
|
||||
|
||||
top_chars = sorted(
|
||||
self._char_freq.items(),
|
||||
key=lambda x: x[1],
|
||||
reverse=True
|
||||
)[:10]
|
||||
|
||||
|
||||
top_pinyins = sorted(
|
||||
self._pinyin_freq.items(),
|
||||
key=lambda x: x[1],
|
||||
reverse=True
|
||||
)[:10]
|
||||
|
||||
return {
|
||||
"status": "loaded",
|
||||
"timestamp": self._counter_data.timestamp,
|
||||
|
|
@ -458,13 +445,13 @@ class QueryEngine:
|
|||
"index_time_seconds": self._index_time,
|
||||
"top_chars": top_chars,
|
||||
"top_pinyins": top_pinyins,
|
||||
"metadata": self._counter_data.metadata,
|
||||
"metadata": self._counter_data.metadata
|
||||
}
|
||||
|
||||
|
||||
def is_loaded(self) -> bool:
|
||||
"""检查数据是否已加载"""
|
||||
return self._loaded
|
||||
|
||||
|
||||
def clear(self):
|
||||
"""清除所有数据和索引,释放内存"""
|
||||
self._counter_data = None
|
||||
|
|
@ -478,4 +465,4 @@ class QueryEngine:
|
|||
self._loaded = False
|
||||
self._total_pairs = 0
|
||||
self._load_time = 0.0
|
||||
self._index_time = 0.0
|
||||
self._index_time = 0.0
|
||||
Loading…
Reference in New Issue