feat: 添加汉字拼音统计工具,支持多线程处理与多种格式导出
This commit is contained in:
parent
395c02b913
commit
c22313748c
|
|
@ -210,3 +210,8 @@ cython_debug/
|
|||
*.out
|
||||
*.app
|
||||
|
||||
uv.lock
|
||||
|
||||
*.json
|
||||
*.log
|
||||
marimo/
|
||||
|
|
@ -0,0 +1 @@
|
|||
3.13
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
[project]
|
||||
name = "suinput"
|
||||
version = "0.1.0"
|
||||
description = "Chinese character and pinyin frequency statistics tool with multiprocess counting and multi-format export"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.13"
|
||||
dependencies = [
|
||||
"datasets>=4.5.0",
|
||||
"loguru>=0.7.3",
|
||||
"modelscope>=1.34.0",
|
||||
"msgpack>=1.1.2",
|
||||
"pypinyin>=0.55.0",
|
||||
"rich>=14.3.1",
|
||||
"typer>=0.21.1",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"pytest>=9.0.2",
|
||||
]
|
||||
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
from .counter import counter_dataset
|
||||
from multiprocessing import cpu_count
|
||||
|
||||
import typer
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
app = typer.Typer()
|
||||
console = Console()
|
||||
|
||||
@app.command()
def app_counter_dataset(
    dataset_path: str = typer.Option(..., help="数据集路径"),
    num_workers: int = typer.Option(cpu_count() - 1, help="工作线程数"),
    output_path: str = typer.Option(None, help="输出文件路径(可选)"),
    format: str = typer.Option("msgpack", help="输出格式:json/msgpack/pickle"),
    compress: bool = typer.Option(False, help="是否压缩输出文件"),
    max_queue_size: int = typer.Option(5000, help="队列最大容量"),
    max_lines: int = typer.Option(10000, help="处理的最大行数"),
):
    """
    Count the frequencies of Chinese characters and pinyins in a dataset
    and export the result.
    """
    console.print("[bold green]🚀 开始处理数据集...[/bold green]")

    # Echo the effective configuration before starting the run.
    table = Table(title="📊 参数配置")
    table.add_column("参数", style="cyan")
    table.add_column("值", style="magenta")
    table.add_row("数据集路径", dataset_path)
    table.add_row("工作线程数", str(num_workers))
    table.add_row("输出路径", output_path or "默认输出")
    table.add_row("输出格式", format)
    table.add_row("是否压缩", str(compress))
    table.add_row("队列容量", str(max_queue_size))
    table.add_row("最大行数", str(max_lines))
    console.print(table)

    # Delegate the actual counting/export work to the pipeline.
    counter_dataset(
        dataset_path=dataset_path,
        num_workers=num_workers,
        output_path=output_path,
        format=format,
        compress=compress,
        max_queue_size=max_queue_size,
        max_lines=max_lines,
    )

    console.print("[bold green]✅ 数据处理完成![/bold green]")


# BUG FIX: app() previously ran unconditionally at import time, so merely
# importing this module (e.g. `from .main import app` in tests or a launcher)
# started the CLI. Guard the entry point instead.
if __name__ == "__main__":
    app()
|
@ -0,0 +1,28 @@
|
|||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any
|
||||
|
||||
@dataclass
class CharInfo:
    """A single (character, pinyin) record with its observed frequency."""

    id: int = 0        # 1-based rank id assigned at export time
    char: str = ""     # the Chinese character
    pinyin: str = ""   # the pinyin reading paired with the character
    count: int = 0     # number of occurrences of this (char, pinyin) pair
||||
@dataclass
class PinyinCharPairsCounter:
    """Aggregated (character, pinyin) pair statistics snapshot."""

    timestamp: str = ""                 # creation time, "%Y-%m-%d %H:%M:%S"
    total_characters: int = 0           # number of distinct characters seen
    total_pinyins: int = 0              # number of distinct pinyins seen
    valid_input_character_count: int = 0  # sum of counts over all (char, pinyin) pairs
    pairs: Dict[int, CharInfo] = field(default_factory=dict)    # rank id -> CharInfo
    chars: Dict[str, int] = field(default_factory=dict)         # character -> frequency
    pinyins: Dict[str, int] = field(default_factory=dict)       # pinyin -> frequency
    metadata: Dict[str, Any] = field(default_factory=dict)      # export settings etc.

    def __post_init__(self):
        # BUG FIX: only stamp the creation time when the caller did not
        # supply one. Unconditionally overwriting it destroyed the original
        # timestamp whenever a snapshot was re-hydrated from disk (the loader
        # passes the stored timestamp through the constructor).
        if not self.timestamp:
            self.timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
@ -0,0 +1,345 @@
|
|||
import os
|
||||
import json
|
||||
import pickle
|
||||
import msgpack
|
||||
from multiprocessing import Process, Lock, Queue, current_process, cpu_count
|
||||
from collections import Counter
|
||||
from dataclasses import asdict
|
||||
import time
|
||||
import queue
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
from pypinyin import lazy_pinyin
|
||||
from loguru import logger
|
||||
|
||||
from tqdm import trange
|
||||
|
||||
from .char_info import PinyinCharPairsCounter, CharInfo
|
||||
|
||||
|
||||
|
||||
class PinyinCharStatistics:
    """Streaming counter of Chinese characters, pinyins and (char, pinyin) pairs.

    In single-process mode texts are counted inline. In multi-process mode
    texts are pushed onto a task queue, counted by daemon worker processes,
    and each worker ships its local counters back once via a result queue.
    """

    def __init__(
        self, is_multithreaded=False, num_workers: int = 8, max_queue_size: int = 5000
    ):
        """Initialize the processor.

        Args:
            is_multithreaded: spawn worker processes when True.
            num_workers: number of worker processes to start.
            max_queue_size: upper bound of the task queue.
        """
        self.max_queue_size = max_queue_size
        self.num_workers = num_workers
        self.is_multithreaded = is_multithreaded
        self._is_shutting_down = False

        # Main-process counters: used directly in single-process mode,
        # merged into from worker results in multi-process mode.
        self._char_counter = Counter()
        self._pinyin_counter = Counter()
        self.pinyin_char_counter = Counter()

        # Multiprocessing plumbing.
        if is_multithreaded:
            self.task_queue = Queue(maxsize=max_queue_size)
            self.result_queue = Queue()  # workers push their local counters here
            self.lock = Lock()
            self.workers = []
            self.running = True
            self._start_workers()

    def _start_workers(self):
        """Spawn ``num_workers`` daemon worker processes."""
        for _ in range(self.num_workers):
            worker = Process(
                target=self._worker_loop,
                args=(self.task_queue, self.result_queue),
                daemon=True,
            )
            worker.start()
            self.workers.append(worker)

    @staticmethod
    def _worker_loop(task_queue: Queue, result_queue: Queue):
        """Worker main loop: count texts until a ``None`` sentinel arrives."""
        # Each worker keeps private counters and ships them back exactly once,
        # right before exiting.
        local_char_counter = Counter()
        local_pinyin_counter = Counter()
        local_pair_counter = Counter()

        should_exit = False
        while not should_exit:
            try:
                text = task_queue.get(timeout=1)
                if text is None:  # shutdown sentinel
                    should_exit = True
                    continue
                # BUG FIX: an empty string would make the ratio below raise
                # ZeroDivisionError on every such task.
                if not text:
                    continue
                # Skip texts that are not predominantly Chinese.
                chinese_chars = [c for c in text if "\u4e00" <= c <= "\u9fff"]
                if len(chinese_chars) / len(text) < 0.6:
                    continue

                # Pair each character with its pinyin; characters pypinyin
                # cannot convert are passed through unchanged by ``errors``
                # and then dropped by the ``char != pinyin`` filter.
                pinyin_list = lazy_pinyin(text, errors=lambda x: list(x))
                pairs = [
                    (char, pinyin)
                    for char, pinyin in zip(text, pinyin_list)
                    if char != pinyin
                ]

                local_char_counter.update(chinese_chars)
                local_pinyin_counter.update([p for _, p in pairs])
                local_pair_counter.update(pairs)
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"处理任务时出错: {e}")

        # Ship the local counters back to the main process (plain dicts
        # pickle cheaply across the queue).
        if local_char_counter or local_pair_counter:
            result_queue.put({
                'char_counter': dict(local_char_counter),
                'pinyin_counter': dict(local_pinyin_counter),
                'pair_counter': dict(local_pair_counter),
            })

    def _collect_results(self):
        """Signal workers to stop, merge their results, and reap them.

        Idempotent: both export() and shutdown() may call this; the second
        call is a no-op because the worker list is cleared at the end.
        """
        if not self.workers:
            return

        # 1. One sentinel per worker so every process sees a shutdown signal.
        for _ in range(len(self.workers)):
            self.task_queue.put(None)

        # 2. Give the workers a moment to notice the sentinel.
        time.sleep(1)

        # 3. Collect one result dict per worker, retrying on timeouts.
        results_collected = 0
        max_retries = len(self.workers) * 3
        for retry in range(max_retries):
            try:
                if results_collected >= len(self.workers):
                    break

                result = self.result_queue.get(timeout=15)

                if result:
                    # Merge the worker's plain-dict counters.
                    self._char_counter.update(result['char_counter'])
                    self._pinyin_counter.update(result['pinyin_counter'])

                    # pair_counter keys normally arrive as tuples (pickled
                    # through the queue) but may be their string repr if the
                    # payload went through a text serializer.
                    for pairs, count in result['pair_counter'].items():
                        if isinstance(pairs, str):
                            try:
                                # Parse "('char', 'pinyin')" back to a tuple.
                                content = pairs[1:-1]
                                if ", " in content:
                                    char, pinyin = content.split(", ", 1)
                                else:
                                    char, pinyin = content.split(",", 1)
                                char = char.strip("'\"")
                                pinyin = pinyin.strip("'\"")
                                self.pinyin_char_counter.update({(char, pinyin): count})
                            except Exception:
                                logger.warning(f"无法解析pair: {pairs}")
                        elif isinstance(pairs, tuple):
                            self.pinyin_char_counter.update({pairs: count})
                        else:
                            logger.warning(f"无法处理pair: {pairs}")

                results_collected += 1
            except queue.Empty:
                if retry < max_retries - 1:
                    logger.warning(f"第 {retry + 1} 次等待结果超时,继续尝试...")
                    continue
                else:
                    logger.error(f"等待结果超时,已尝试 {max_retries} 次")
                    break

        # Reap the worker processes; force-terminate stragglers.
        alive_count = 0
        for worker in self.workers:
            worker.join(timeout=5)
            if worker.is_alive():
                logger.warning(f"工作进程 {worker.pid} 仍在运行")
                try:
                    # BUG FIX: psutil is optional diagnostics and is not a
                    # declared dependency -- a missing install must not crash
                    # the shutdown path, so the import lives inside the try.
                    import psutil
                    p = psutil.Process(worker.pid)
                    logger.warning(f"进程状态: {p.status()}, CPU: {p.cpu_percent()}")
                except Exception:
                    pass
                worker.terminate()
                alive_count += 1

        if alive_count > 0:
            logger.warning(f"强制终止了 {alive_count} 个工作进程")
        else:
            logger.info("所有工作进程正常结束")

        self.workers.clear()

    def _process_single_text(self, text: str):
        """Count one text inline (single-process mode)."""
        # BUG FIX: guard empty input -- the ratio below divided by len(text).
        if not text:
            return
        chinese_chars = [c for c in text if self._is_chinese_char(c)]
        if len(chinese_chars) / len(text) < 0.6:
            return

        # Pair characters with pinyins; unconvertible characters pass
        # through and are removed by the ``char != pinyin`` filter.
        pinyin_list = lazy_pinyin(text, errors=lambda x: list(x))
        pairs = [
            (char, pinyin) for char, pinyin in zip(text, pinyin_list) if char != pinyin
        ]

        self.pinyin_char_counter.update(pairs)
        self._char_counter.update(chinese_chars)
        self._pinyin_counter.update([p for _, p in pairs])

    def _is_chinese_char(self, char: str) -> bool:
        """Return True if ``char`` lies in the CJK Unified Ideographs block."""
        return "\u4e00" <= char <= "\u9fff"

    def process_text(self, text: str):
        """Submit a text for counting (queued in multi-process mode).

        Raises:
            RuntimeError: if the processor is shutting down or already closed.
        """
        if self._is_shutting_down:
            raise RuntimeError("处理器正在关闭,无法接受新任务")

        if self.is_multithreaded:
            if not self.running:
                raise RuntimeError("处理器已关闭,无法接受新任务")
            self.task_queue.put(text)
            # Simple back-pressure: slow the producer once the queue is
            # more than 80% full.
            if self.task_queue.qsize() > self.max_queue_size * 0.8:
                time.sleep(0.02)
        else:
            self._process_single_text(text)

    def shutdown(self, wait: bool = True):
        """Shut the processor down; optionally wait for and merge worker results."""
        if self._is_shutting_down:
            return

        self._is_shutting_down = True
        self.running = False

        if wait and self.is_multithreaded:
            self._collect_results()

    def export(
        self, output_path: str = None, format: str = "msgpack", compress: bool = False
    ):
        """Export the accumulated statistics.

        Args:
            output_path: destination file; when falsy, a trimmed summary is
                printed instead and nothing is written.
            format: one of "json", "msgpack" or "pickle".
            compress: gzip-compress the serialized payload when True.

        Returns:
            The metadata dict stored in the snapshot.

        Raises:
            ValueError: if ``format`` is not supported.
        """
        logger.info("开始导出快照...")

        # In multi-process mode make sure worker results have been merged
        # (no-op when already collected).
        if self.is_multithreaded:
            self._collect_results()

        counter_data = PinyinCharPairsCounter()
        # NOTE: these are counts of *distinct* characters / pinyins.
        counter_data.total_characters = len(self._char_counter)
        counter_data.total_pinyins = len(self._pinyin_counter)
        counter_data.valid_input_character_count = sum(self.pinyin_char_counter.values())

        # Rank pairs by frequency; ids are assigned from 1 in rank order.
        pc_counter_sorted = sorted(
            self.pinyin_char_counter.items(), key=lambda x: x[1], reverse=True
        )
        char_info_list = [
            CharInfo(id=idx + 1, char=char, pinyin=pinyin, count=count)
            for idx, ((char, pinyin), count) in enumerate(pc_counter_sorted)
        ]

        counter_data.pairs = {info.id: info for info in char_info_list}
        counter_data.chars = dict(self._char_counter)
        counter_data.pinyins = dict(self._pinyin_counter)
        counter_data.metadata = {
            "format": format,
            "compressed": compress,
            "pair_count": len(counter_data.pairs),
        }

        if not output_path:
            # No destination: print a summary without the bulky tables.
            dict_counter_data = asdict(counter_data)
            dict_counter_data.pop("pairs")
            dict_counter_data.pop("chars")
            dict_counter_data.pop("pinyins")
            print(dict_counter_data)
            return counter_data.metadata

        # Serialize in the requested format.
        if format == "json":
            data_str = json.dumps(asdict(counter_data), ensure_ascii=False, indent=2)
            data = data_str.encode("utf-8")
        elif format == "msgpack":
            data = msgpack.packb(asdict(counter_data), use_bin_type=True)
        elif format == "pickle":
            data = pickle.dumps(counter_data)
        else:
            raise ValueError(f"不支持的格式: {format}")

        # Optional gzip compression.
        if compress:
            import gzip
            data = gzip.compress(data)

        with open(output_path, "wb") as f:
            f.write(data)

        logger.info(f"快照导出完成: {output_path} (格式: {format}, 压缩: {compress})")
        return counter_data.metadata
|
||||
def counter_dataset(
    dataset_path: str,
    num_workers: int = cpu_count(),
    output_path: str = None,
    format: str = "msgpack",
    compress: bool = True,
    max_queue_size: int = 50000,
    max_lines: int = 10000,
):
    """Stream a dataset and count character/pinyin frequencies.

    Args:
        dataset_path: dataset name/path, loaded as a streaming "train" split.
        num_workers: number of worker processes.
        output_path: destination file; None prints a summary instead.
        format: output format ("json"/"msgpack"/"pickle").
        compress: gzip-compress the output when True.
        max_queue_size: task-queue capacity.
        max_lines: maximum number of records to process.
    """
    dataset = load_dataset(dataset_path, split="train", streaming=True)
    iter_data = iter(dataset)

    # Multi-process statistics collector.
    # BUG FIX: the caller-supplied num_workers was previously ignored
    # (hard-coded to 4 workers).
    processor = PinyinCharStatistics(
        is_multithreaded=True,
        num_workers=num_workers,
        max_queue_size=max_queue_size,
    )

    # Progress bar over at most max_lines records.
    for _ in trange(max_lines, desc="正在处理数据"):
        try:
            line = next(iter_data)
            text = line.get("text", "")
            if text:
                processor.process_text(text)
        except StopIteration:
            print("⚠️ 数据集已提前结束。")
            break

    # export() merges the worker results itself; shutdown() afterwards is a
    # cheap cleanup and is internally guarded against double invocation, so
    # no need to poke the private _is_shutting_down flag here.
    processor.export(output_path, format=format, compress=compress)
    processor.shutdown()
|
|
@ -0,0 +1,414 @@
|
|||
# file name: query_engine.py
|
||||
import json
|
||||
import pickle
|
||||
import msgpack
|
||||
import gzip
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
import time
|
||||
import os
|
||||
|
||||
from .char_info import CharInfo, PinyinCharPairsCounter
|
||||
|
||||
|
||||
class QueryEngine:
    """
    Efficient pinyin-character query engine.

    Features:
    1. O(1) lookup by record id
    2. O(1) lookup by character
    3. O(1) lookup by pinyin
    4. Memory-friendly index construction
    5. Batch queries and prefix search
    """

    def __init__(self):
        """Initialize an empty engine; call load() before querying."""
        self._counter_data: Optional[PinyinCharPairsCounter] = None

        # Core indices -- O(1) lookups.
        self._id_to_info: Dict[int, CharInfo] = {}      # id -> CharInfo
        self._char_to_ids: Dict[str, List[int]] = {}    # character -> id list
        self._pinyin_to_ids: Dict[str, List[int]] = {}  # pinyin -> id list

        # Secondary indices -- aggregate frequencies.
        self._char_freq: Dict[str, int] = {}            # character -> total frequency
        self._pinyin_freq: Dict[str, int] = {}          # pinyin -> total frequency
        self._char_pinyin_map: Dict[Tuple[str, str], int] = {}  # (char, pinyin) -> count

        # Bookkeeping.
        self._loaded = False
        self._total_pairs = 0
        self._load_time = 0.0
        self._index_time = 0.0

    def _ensure_loaded(self):
        """Raise if no snapshot has been loaded yet (shared guard)."""
        if not self._loaded:
            raise RuntimeError("数据未加载,请先调用load()方法")

    def load(self, file_path: str) -> Dict[str, Any]:
        """
        Load a statistics snapshot and build the query indices.

        Args:
            file_path: path to a msgpack/pickle/json file; gzip compression
                is detected automatically.

        Returns:
            The snapshot's metadata dict.

        Raises:
            FileNotFoundError: the file does not exist.
            ValueError: no supported format could parse the file.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"文件不存在: {file_path}")

        start_time = time.time()

        # Read + deserialize, then build all indices.
        self._counter_data = self._parse_file(file_path)
        self._build_indices()

        self._load_time = time.time() - start_time
        self._loaded = True

        return self._counter_data.metadata

    def _parse_file(self, file_path: str) -> PinyinCharPairsCounter:
        """Read a snapshot file and try each supported format in turn."""
        with open(file_path, 'rb') as f:
            data = f.read()

        # Transparently handle gzip-compressed snapshots.
        try:
            data = gzip.decompress(data)
        except Exception:
            pass  # not gzipped -- keep the raw bytes

        # SECURITY NOTE: the pickle branch executes arbitrary code embedded
        # in the file. Only load snapshot files from trusted sources.
        for parser_name, parser in [
            ('msgpack', self._parse_msgpack),
            ('pickle', self._parse_pickle),
            ('json', self._parse_json),
        ]:
            try:
                return parser(data)
            except Exception:
                continue

        raise ValueError("无法解析文件格式")

    def _parse_msgpack(self, data: bytes) -> PinyinCharPairsCounter:
        """Deserialize a msgpack snapshot."""
        data_dict = msgpack.unpackb(data, raw=False)
        return self._dict_to_counter(data_dict)

    def _parse_pickle(self, data: bytes) -> PinyinCharPairsCounter:
        """Deserialize a pickle snapshot (trusted input only)."""
        return pickle.loads(data)

    def _parse_json(self, data: bytes) -> PinyinCharPairsCounter:
        """Deserialize a JSON snapshot."""
        data_str = data.decode('utf-8')
        data_dict = json.loads(data_str)
        return self._dict_to_counter(data_dict)

    def _dict_to_counter(self, data_dict: Dict) -> PinyinCharPairsCounter:
        """Rebuild a PinyinCharPairsCounter (and its CharInfo values) from a dict."""
        # msgpack/json serialize the pairs dict with stringified keys and
        # plain-dict values; restore int keys and CharInfo instances.
        pairs_dict = {}
        if 'pairs' in data_dict and data_dict['pairs']:
            for id_str, info_dict in data_dict['pairs'].items():
                pairs_dict[int(id_str)] = CharInfo(**info_dict)
        data_dict['pairs'] = pairs_dict

        return PinyinCharPairsCounter(**data_dict)

    def _build_indices(self):
        """Build all lookup indices from the loaded counter data."""
        start_time = time.time()

        # Reset every index.
        self._id_to_info.clear()
        self._char_to_ids.clear()
        self._pinyin_to_ids.clear()
        self._char_freq.clear()
        self._pinyin_freq.clear()
        self._char_pinyin_map.clear()

        # Copy the aggregate frequency tables.
        if self._counter_data.chars:
            self._char_freq = self._counter_data.chars.copy()
        if self._counter_data.pinyins:
            self._pinyin_freq = self._counter_data.pinyins.copy()

        # Build the core indices in one pass over the pairs.
        for char_info in self._counter_data.pairs.values():
            char = char_info.char
            pinyin = char_info.pinyin
            char_info_id = char_info.id

            self._id_to_info[char_info_id] = char_info
            self._char_to_ids.setdefault(char, []).append(char_info_id)
            self._pinyin_to_ids.setdefault(pinyin, []).append(char_info_id)
            self._char_pinyin_map[(char, pinyin)] = char_info.count

        self._total_pairs = len(self._id_to_info)
        self._index_time = time.time() - start_time

    def _ranked_results(self, ids: List[int], attr: str, limit: int) -> List[Tuple[int, str, int]]:
        """Build (id, <attr>, count) rows for ids, sorted by count descending.

        Shared by query_by_char (attr="pinyin") and query_by_pinyin
        (attr="char") -- their bodies were previously copy-pasted.
        """
        results = []
        for char_info_id in ids:
            char_info = self._id_to_info[char_info_id]
            results.append((char_info_id, getattr(char_info, attr), char_info.count))

        results.sort(key=lambda x: x[2], reverse=True)

        if limit > 0 and len(results) > limit:
            results = results[:limit]

        return results

    def query_by_id(self, id: int) -> Optional[CharInfo]:
        """
        Look a record up by id -- O(1).

        Args:
            id: record id. (Parameter name kept for callers using keywords,
                even though it shadows the builtin.)

        Returns:
            The CharInfo, or None if no such id exists.
        """
        self._ensure_loaded()
        return self._id_to_info.get(id)

    def query_by_char(self, char: str, limit: int = 0) -> List[Tuple[int, str, int]]:
        """
        Look pinyin readings up by character -- O(1) + O(k), k = result count.

        Args:
            char: Chinese character.
            limit: maximum number of results; 0 means all.

        Returns:
            List of (id, pinyin, count), sorted by count descending.
        """
        self._ensure_loaded()

        if char not in self._char_to_ids:
            return []

        return self._ranked_results(self._char_to_ids[char], "pinyin", limit)

    def query_by_pinyin(self, pinyin: str, limit: int = 0) -> List[Tuple[int, str, int]]:
        """
        Look characters up by pinyin -- O(1) + O(k).

        Args:
            pinyin: pinyin string.
            limit: maximum number of results; 0 means all.

        Returns:
            List of (id, char, count), sorted by count descending.
        """
        self._ensure_loaded()

        if pinyin not in self._pinyin_to_ids:
            return []

        return self._ranked_results(self._pinyin_to_ids[pinyin], "char", limit)

    def get_char_frequency(self, char: str) -> int:
        """Total frequency of a character across all its pinyins -- O(1)."""
        self._ensure_loaded()
        return self._char_freq.get(char, 0)

    def get_pinyin_frequency(self, pinyin: str) -> int:
        """Total frequency of a pinyin across all its characters -- O(1)."""
        self._ensure_loaded()
        return self._pinyin_freq.get(pinyin, 0)

    def get_char_pinyin_count(self, char: str, pinyin: str) -> int:
        """Occurrence count of one specific (char, pinyin) pair -- O(1)."""
        self._ensure_loaded()
        return self._char_pinyin_map.get((char, pinyin), 0)

    def batch_query_by_ids(self, ids: List[int]) -> Dict[int, Optional[CharInfo]]:
        """
        Batch id lookup -- O(n).

        Returns:
            Dict keyed by id; value is the CharInfo or None when missing.
        """
        self._ensure_loaded()
        return {id_value: self._id_to_info.get(id_value) for id_value in ids}

    def batch_query_by_chars(self, chars: List[str], limit_per_char: int = 0) -> Dict[str, List[Tuple[int, str, int]]]:
        """
        Batch character lookup.

        Args:
            chars: characters to query.
            limit_per_char: per-character result limit (0 = all).

        Returns:
            Dict keyed by character; value is that character's result list.
        """
        self._ensure_loaded()
        return {char: self.query_by_char(char, limit_per_char) for char in chars}

    def search_chars_by_prefix(self, prefix: str, limit: int = 20) -> List[Tuple[str, int]]:
        """
        Linear prefix search over all known characters -- O(n).

        Args:
            prefix: character prefix to match.
            limit: maximum number of results (0 = all).

        Returns:
            List of (char, total frequency), sorted by frequency descending.
        """
        self._ensure_loaded()

        matches = [
            (char, freq)
            for char, freq in self._char_freq.items()
            if char.startswith(prefix)
        ]
        matches.sort(key=lambda x: x[1], reverse=True)

        return matches[:limit] if limit > 0 else matches

    def get_statistics(self) -> Dict[str, Any]:
        """
        Summarize the loaded snapshot and index-build timings.

        Returns:
            Statistics dict; {"status": "not_loaded"} when nothing is loaded.
        """
        if not self._loaded:
            return {"status": "not_loaded"}

        top_chars = sorted(
            self._char_freq.items(), key=lambda x: x[1], reverse=True
        )[:10]
        top_pinyins = sorted(
            self._pinyin_freq.items(), key=lambda x: x[1], reverse=True
        )[:10]

        return {
            "status": "loaded",
            "timestamp": self._counter_data.timestamp,
            "total_pairs": self._total_pairs,
            "total_characters": len(self._char_freq),
            "total_pinyins": len(self._pinyin_freq),
            "valid_input_character_count": self._counter_data.valid_input_character_count,
            "load_time_seconds": self._load_time,
            "index_time_seconds": self._index_time,
            "top_chars": top_chars,
            "top_pinyins": top_pinyins,
            "metadata": self._counter_data.metadata,
        }

    def is_loaded(self) -> bool:
        """Return True once load() has completed successfully."""
        return self._loaded

    def clear(self):
        """Drop all data and indices, returning the engine to its empty state."""
        self._counter_data = None
        self._id_to_info.clear()
        self._char_to_ids.clear()
        self._pinyin_to_ids.clear()
        self._char_freq.clear()
        self._pinyin_freq.clear()
        self._char_pinyin_map.clear()
        self._loaded = False
        self._total_pairs = 0
        self._load_time = 0.0
        self._index_time = 0.0
||||
Loading…
Reference in New Issue