feat: 添加汉字拼音统计工具,支持多线程处理与多种格式导出

This commit is contained in:
songsenand 2026-02-02 07:17:29 +08:00
parent 395c02b913
commit c22313748c
8 changed files with 866 additions and 0 deletions

5
.gitignore vendored
View File

@ -210,3 +210,8 @@ cython_debug/
*.out *.out
*.app *.app
uv.lock
*.json
*.log
marimo/

1
.python-version Normal file
View File

@ -0,0 +1 @@
3.13

21
pyproject.toml Normal file
View File

@ -0,0 +1,21 @@
[project]
name = "suinput"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
"datasets>=4.5.0",
"loguru>=0.7.3",
"modelscope>=1.34.0",
"msgpack>=1.1.2",
"pypinyin>=0.55.0",
"rich>=14.3.1",
"typer>=0.21.1",
]
[dependency-groups]
dev = [
"pytest>=9.0.2",
]

0
src/suinput/__init__.py Normal file
View File

52
src/suinput/__main__.py Normal file
View File

@ -0,0 +1,52 @@
from .counter import counter_dataset
from multiprocessing import cpu_count
import typer
from rich.console import Console
from rich.table import Table
app = typer.Typer()
console = Console()
@app.command()
def app_counter_dataset(
dataset_path: str = typer.Option(..., help="数据集路径"),
num_workers: int = typer.Option(cpu_count()-1, help="工作线程数"),
output_path: str = typer.Option(None, help="输出文件路径(可选)"),
format: str = typer.Option("msgpack", help="输出格式json/msgpack/pickle"),
compress: bool = typer.Option(False, help="是否压缩输出文件"),
max_queue_size: int = typer.Option(5000, help="队列最大容量"),
max_lines: int = typer.Option(10000, help="处理的最大行数"),
):
"""
统计数据集中汉字和拼音的频率并导出结果
"""
console.print("[bold green]🚀 开始处理数据集...[/bold green]")
# 显示参数配置
table = Table(title="📊 参数配置")
table.add_column("参数", style="cyan")
table.add_column("", style="magenta")
table.add_row("数据集路径", dataset_path)
table.add_row("工作线程数", str(num_workers))
table.add_row("输出路径", output_path or "默认输出")
table.add_row("输出格式", format)
table.add_row("是否压缩", str(compress))
table.add_row("队列容量", str(max_queue_size))
table.add_row("最大行数", str(max_lines))
console.print(table)
# 调用处理函数
counter_dataset(
dataset_path=dataset_path,
num_workers=num_workers,
output_path=output_path,
format=format,
compress=compress,
max_queue_size=max_queue_size,
max_lines=max_lines,
)
console.print("[bold green]✅ 数据处理完成![/bold green]")
app()

28
src/suinput/char_info.py Normal file
View File

@ -0,0 +1,28 @@
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, Any
@dataclass
class CharInfo:
"""字符信息数据结构"""
id: int = 0
char: str = ""
pinyin: str = ""
count: int = 0
@dataclass
class PinyinCharPairsCounter:
"""拼音字符对计数器"""
timestamp: str = ""
total_characters: int = 0
total_pinyins: int = 0
valid_input_character_count: int = 0 # 新增字段
pairs: Dict[int, CharInfo] = field(default_factory=dict)
chars: Dict[str, int] = field(default_factory=dict)
pinyins: Dict[str, int] = field(default_factory=dict)
metadata: Dict[str, Any] = field(default_factory=dict)
def __post_init__(self):
self.timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

345
src/suinput/counter.py Normal file
View File

@ -0,0 +1,345 @@
import os
import json
import pickle
import msgpack
from multiprocessing import Process, Lock, Queue, current_process, cpu_count
from collections import Counter
from dataclasses import asdict
import time
import queue
from datasets import load_dataset
from pypinyin import lazy_pinyin
from loguru import logger
from tqdm import trange
from .char_info import PinyinCharPairsCounter, CharInfo
class PinyinCharStatistics:
def __init__(
self, is_multithreaded=False, num_workers: int = 8, max_queue_size: int = 5000
):
"""
初始化处理器
"""
self.max_queue_size = max_queue_size
self.num_workers = num_workers
self.is_multithreaded = is_multithreaded
self._is_shutting_down = False
# 主进程中的计数器(用于单线程模式)
self._char_counter = Counter()
self._pinyin_counter = Counter()
self.pinyin_char_counter = Counter()
# 多进程相关
if is_multithreaded:
# manager = Manager()
# 创建任务队列和结果队列
self.task_queue = Queue(maxsize=max_queue_size)
self.result_queue = Queue() # 用于收集工作进程的结果
self.lock = Lock()
self.workers = []
self.running = True
# 启动工作进程
self._start_workers()
def _start_workers(self):
"""启动工作进程"""
for i in range(self.num_workers):
worker = Process(
target=self._worker_loop,
args=(self.task_queue, self.result_queue),
daemon=True
)
worker.start()
self.workers.append(worker)
@staticmethod
def _worker_loop(task_queue: Queue, result_queue: Queue):
"""工作进程主循环"""
# 每个进程有自己的本地计数器
local_char_counter = Counter()
local_pinyin_counter = Counter()
local_pair_counter = Counter()
should_exit = False
while not should_exit:
try:
text = task_queue.get(timeout=1)
if text is None: # 终止信号
should_exit = True
continue # 继续循环,但会退出
# 处理文本
chinese_chars = [c for c in text if "\u4e00" <= c <= "\u9fff"]
if len(chinese_chars) / len(text) < 0.6:
continue
# 获取拼音
pinyin_list = lazy_pinyin(text, errors=lambda x: list(x))
pairs = [
(char, pinyin) for char, pinyin in zip(text, pinyin_list)
if char != pinyin
]
# 更新本地计数器
local_char_counter.update(chinese_chars)
local_pinyin_counter.update([p for _, p in pairs])
local_pair_counter.update(pairs)
except queue.Empty:
if should_exit:
break
continue
except Exception as e:
logger.error(f"处理任务时出错: {e}")
if local_char_counter or local_pair_counter:
result_queue.put({
'char_counter': dict(local_char_counter),
'pinyin_counter': dict(local_pinyin_counter),
'pair_counter': dict(local_pair_counter)
})
def _collect_results(self):
"""收集所有工作进程的结果"""
# 1. 先发送终止信号
for i in range(len(self.workers)):
self.task_queue.put(None)
# 2. 给工作进程时间处理
time.sleep(1) # 等待1秒让进程收到信号
# 3. 收集结果时增加重试机制
results_collected = 0
max_retries = len(self.workers) * 3 # 最多尝试3轮
for retry in range(max_retries):
try:
# 非阻塞检查
if results_collected >= len(self.workers):
break
result = self.result_queue.get(timeout=15) # 增加超时时间
if result:
# 合并到主进程计数器
self._char_counter.update(result['char_counter'])
self._pinyin_counter.update(result['pinyin_counter'])
# 处理pair_counter需要从字符串转换回元组
for pairs, count in result['pair_counter'].items():
# 将字符串 "(char, pinyin)" 转换回元组
if isinstance(pairs, str):
try:
# 简单解析:去除括号,分割字符串
content = pairs[1:-1]
# 按逗号分割,最多分割一次
if ", " in content:
char, pinyin = content.split(", ", 1)
else:
char, pinyin = content.split(",", 1)
# 去除可能的引号
char = char.strip("'\"")
pinyin = pinyin.strip("'\"")
self.pinyin_char_counter.update({(char, pinyin): count})
except Exception:
logger.warning(f"无法解析pair: {pairs}")
elif isinstance(pairs, tuple):
# 直接作为key使用
self.pinyin_char_counter.update({pairs: count})
else:
logger.warning(f"无法处理pair: {pairs}")
results_collected += 1
except queue.Empty:
if retry < max_retries - 1:
logger.warning(f"{retry + 1} 次等待结果超时,继续尝试...")
continue
else:
logger.error(f"等待结果超时,已尝试 {max_retries}")
break
alive_count = 0
for worker in self.workers:
worker.join(timeout=5) # 增加等待时间
if worker.is_alive():
logger.warning(f"工作进程 {worker.pid} 仍在运行")
# 检查进程是否真的在忙
import psutil
try:
p = psutil.Process(worker.pid)
# 检查CPU使用率或状态
logger.warning(f"进程状态: {p.status()}, CPU: {p.cpu_percent()}")
except Exception:
pass
worker.terminate()
alive_count += 1
if alive_count > 0:
logger.warning(f"强制终止了 {alive_count} 个工作进程")
else:
logger.info("所有工作进程正常结束")
self.workers.clear()
def _process_single_text(self, text: str):
"""单线程处理单个文本(线程安全)"""
chinese_chars = [c for c in text if self._is_chinese_char(c)]
if len(chinese_chars) / len(text) < 0.6:
return
# 获取拼音
pinyin_list = lazy_pinyin(text, errors=lambda x: list(x))
pairs = [
(char, pinyin) for char, pinyin in zip(text, pinyin_list) if char != pinyin
]
# 更新计数器
self.pinyin_char_counter.update(pairs)
self._char_counter.update(chinese_chars)
self._pinyin_counter.update([p for _, p in pairs])
def _is_chinese_char(self, char: str) -> bool:
"""判断字符是否为中文汉字"""
return "\u4e00" <= char <= "\u9fff"
def process_text(self, text: str):
"""
将文本加入任务队列由工作进程异步处理
"""
if self._is_shutting_down:
raise RuntimeError("处理器正在关闭,无法接受新任务")
if self.is_multithreaded:
if not self.running:
raise RuntimeError("处理器已关闭,无法接受新任务")
self.task_queue.put(text)
if self.task_queue.qsize() > self.max_queue_size * 0.8:
time.sleep(0.02)
else:
self._process_single_text(text)
def shutdown(self, wait: bool = True):
"""
关闭处理器等待所有任务完成可选
"""
if self._is_shutting_down:
return
self._is_shutting_down = True
self.running = False
if wait and self.is_multithreaded:
# 收集结果
self._collect_results()
def export(
self, output_path: str = None, format: str = "msgpack", compress: bool = False
):
"""
导出统计信息
"""
logger.info("开始导出快照...")
# 如果是多进程模式,确保已收集所有结果
if self.is_multithreaded:
self._collect_results()
# 构建 PinyinCharPairsCounter 对象
counter_data = PinyinCharPairsCounter()
counter_data.total_characters = len(self._char_counter)
counter_data.total_pinyins = len(self._pinyin_counter)
counter_data.valid_input_character_count = sum(self.pinyin_char_counter.values())
# 构造 CharInfo 列表
char_info_list = []
pc_counter_sorted = sorted(
self.pinyin_char_counter.items(), key=lambda x: x[1], reverse=True
)
for idx, ((char, pinyin), count) in enumerate(pc_counter_sorted):
char_info = CharInfo(id=idx + 1, char=char, pinyin=pinyin, count=count)
char_info_list.append(char_info)
counter_data.pairs = {info.id: info for info in char_info_list}
counter_data.chars = dict(self._char_counter)
counter_data.pinyins = dict(self._pinyin_counter)
counter_data.metadata = {
"format": format,
"compressed": compress,
"pair_count": len(counter_data.pairs),
}
if not output_path:
dict_counter_data = asdict(counter_data)
dict_counter_data.pop("pairs")
dict_counter_data.pop("chars")
dict_counter_data.pop("pinyins")
print(dict_counter_data)
return counter_data.metadata
# 序列化数据
if format == "json":
data_str = json.dumps(asdict(counter_data), ensure_ascii=False, indent=2)
data = data_str.encode("utf-8")
elif format == "msgpack":
data = msgpack.packb(asdict(counter_data), use_bin_type=True)
elif format == "pickle":
data = pickle.dumps(counter_data)
else:
raise ValueError(f"不支持的格式: {format}")
# 可选压缩
if compress:
import gzip
data = gzip.compress(data)
# 写入文件
with open(output_path, "wb") as f:
f.write(data)
logger.info(f"快照导出完成: {output_path} (格式: {format}, 压缩: {compress})")
return counter_data.metadata
def counter_dataset(
dataset_path: str,
num_workers: int = cpu_count(),
output_path: str = None,
format: str = "msgpack",
compress: bool = True,
max_queue_size: int = 50000,
max_lines: int = 10000,
):
"""
处理数据集的核心逻辑
"""
dataset = load_dataset(dataset_path, split="train", streaming=True)
iter_data = iter(dataset)
# 使用多进程
processor = PinyinCharStatistics(
is_multithreaded=True,
num_workers=4,
max_queue_size=max_queue_size
)
# 使用 tqdm.trange 显示进度条
for i in trange(max_lines, desc="正在处理数据"):
try:
line = next(iter_data)
text = line.get("text", "")
if text:
processor.process_text(text)
except StopIteration:
print("⚠️ 数据集已提前结束。")
break
# 导出前会收集所有结果
processor.export(output_path, format=format, compress=compress)
if not processor._is_shutting_down:
processor.shutdown()

414
src/suinput/query.py Normal file
View File

@ -0,0 +1,414 @@
# file name: query_engine.py
import json
import pickle
import msgpack
import gzip
from typing import Dict, List, Optional, Tuple, Any
import time
import os
from .char_info import CharInfo, PinyinCharPairsCounter
class QueryEngine:
"""
高效拼音-字符查询引擎
特性:
1. O(1)时间复杂度的ID查询
2. O(1)时间复杂度的字符查询
3. O(1)时间复杂度的拼音查询
4. 内存友好构建高效索引
5. 支持批量查询和前缀搜索
"""
def __init__(self):
"""初始化查询引擎"""
self._counter_data: Optional[PinyinCharPairsCounter] = None
# 核心索引 - 提供O(1)查询
self._id_to_info: Dict[int, CharInfo] = {} # ID -> CharInfo
self._char_to_ids: Dict[str, List[int]] = {} # 字符 -> ID列表
self._pinyin_to_ids: Dict[str, List[int]] = {} # 拼音 -> ID列表
# 辅助索引 - 快速获取详细信息
self._char_freq: Dict[str, int] = {} # 字符总频率
self._pinyin_freq: Dict[str, int] = {} # 拼音总频率
self._char_pinyin_map: Dict[Tuple[str, str], int] = {} # (字符, 拼音) -> count
# 统计信息
self._loaded = False
self._total_pairs = 0
self._load_time = 0.0
self._index_time = 0.0
def load(self, file_path: str) -> Dict[str, Any]:
"""
加载统计结果文件
Args:
file_path: 文件路径支持msgpack/pickle/json格式自动检测压缩
Returns:
元数据字典
Raises:
FileNotFoundError: 文件不存在
ValueError: 文件格式不支持
"""
if not os.path.exists(file_path):
raise FileNotFoundError(f"文件不存在: {file_path}")
start_time = time.time()
# 读取并解析文件
self._counter_data = self._parse_file(file_path)
# 构建索引
self._build_indices()
self._load_time = time.time() - start_time
self._loaded = True
return self._counter_data.metadata
def _parse_file(self, file_path: str) -> PinyinCharPairsCounter:
"""解析文件,支持多种格式"""
with open(file_path, 'rb') as f:
data = f.read()
# 尝试解压
try:
data = gzip.decompress(data)
except Exception:
pass
# 尝试不同格式
for parser_name, parser in [
('msgpack', self._parse_msgpack),
('pickle', self._parse_pickle),
('json', self._parse_json)
]:
try:
return parser(data)
except Exception:
continue
raise ValueError("无法解析文件格式")
def _parse_msgpack(self, data: bytes) -> PinyinCharPairsCounter:
"""解析msgpack格式"""
data_dict = msgpack.unpackb(data, raw=False)
return self._dict_to_counter(data_dict)
def _parse_pickle(self, data: bytes) -> PinyinCharPairsCounter:
"""解析pickle格式"""
return pickle.loads(data)
def _parse_json(self, data: bytes) -> PinyinCharPairsCounter:
"""解析json格式"""
data_str = data.decode('utf-8')
data_dict = json.loads(data_str)
return self._dict_to_counter(data_dict)
def _dict_to_counter(self, data_dict: Dict) -> PinyinCharPairsCounter:
"""字典转PinyinCharPairsCounter"""
# 转换CharInfo字典
pairs_dict = {}
if 'pairs' in data_dict and data_dict['pairs']:
for id_str, info_dict in data_dict['pairs'].items():
pairs_dict[int(id_str)] = CharInfo(**info_dict)
data_dict['pairs'] = pairs_dict
return PinyinCharPairsCounter(**data_dict)
def _build_indices(self):
"""构建所有查询索引"""
start_time = time.time()
# 重置索引
self._id_to_info.clear()
self._char_to_ids.clear()
self._pinyin_to_ids.clear()
self._char_freq.clear()
self._pinyin_freq.clear()
self._char_pinyin_map.clear()
# 复制频率数据
if self._counter_data.chars:
self._char_freq = self._counter_data.chars.copy()
if self._counter_data.pinyins:
self._pinyin_freq = self._counter_data.pinyins.copy()
# 构建核心索引
for char_info in self._counter_data.pairs.values():
char = char_info.char
pinyin = char_info.pinyin
char_info_id = char_info.id
# ID索引
self._id_to_info[char_info_id] = char_info
# 字符索引
if char not in self._char_to_ids:
self._char_to_ids[char] = []
self._char_to_ids[char].append(char_info_id)
# 拼音索引
if pinyin not in self._pinyin_to_ids:
self._pinyin_to_ids[pinyin] = []
self._pinyin_to_ids[pinyin].append(char_info_id)
# 字符-拼音映射
self._char_pinyin_map[(char, pinyin)] = char_info.count
self._total_pairs = len(self._id_to_info)
self._index_time = time.time() - start_time
def query_by_id(self, id: int) -> Optional[CharInfo]:
"""
通过ID查询字符信息 - O(1)时间复杂度
Args:
id: 记录ID
Returns:
CharInfo对象不存在则返回None
"""
if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法")
return self._id_to_info.get(id)
def query_by_char(self, char: str, limit: int = 0) -> List[Tuple[int, str, int]]:
"""
通过字符查询拼音信息 - O(1) + O(k)时间复杂度k为结果数
Args:
char: 汉字字符
limit: 返回结果数量限制0表示返回所有
Returns:
列表每个元素为(id, 拼音, 次数)按次数降序排序
"""
if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法")
if char not in self._char_to_ids:
return []
# 获取所有相关ID
ids = self._char_to_ids[char]
# 构建结果并排序
results = []
for char_info_id in ids:
char_info = self._id_to_info[char_info_id]
results.append((char_info_id, char_info.pinyin, char_info.count))
# 按次数降序排序
results.sort(key=lambda x: x[2], reverse=True)
# 应用限制
if limit > 0 and len(results) > limit:
results = results[:limit]
return results
def query_by_pinyin(self, pinyin: str, limit: int = 0) -> List[Tuple[int, str, int]]:
"""
通过拼音查询字符信息 - O(1) + O(k)时间复杂度
Args:
pinyin: 拼音字符串
limit: 返回结果数量限制0表示返回所有
Returns:
列表每个元素为(id, 字符, 次数)按次数降序排序
"""
if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法")
if pinyin not in self._pinyin_to_ids:
return []
# 获取所有相关ID
ids = self._pinyin_to_ids[pinyin]
# 构建结果并排序
results = []
for char_info_id in ids:
char_info = self._id_to_info[char_info_id]
results.append((char_info_id, char_info.char, char_info.count))
# 按次数降序排序
results.sort(key=lambda x: x[2], reverse=True)
# 应用限制
if limit > 0 and len(results) > limit:
results = results[:limit]
return results
def get_char_frequency(self, char: str) -> int:
"""
获取字符的总出现频率所有拼音变体之和 - O(1)时间复杂度
Args:
char: 汉字字符
Returns:
总出现次数
"""
if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法")
return self._char_freq.get(char, 0)
def get_pinyin_frequency(self, pinyin: str) -> int:
"""
获取拼音的总出现频率所有字符之和 - O(1)时间复杂度
Args:
pinyin: 拼音字符串
Returns:
总出现次数
"""
if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法")
return self._pinyin_freq.get(pinyin, 0)
def get_char_pinyin_count(self, char: str, pinyin: str) -> int:
"""
获取特定字符-拼音对的出现次数 - O(1)时间复杂度
Args:
char: 汉字字符
pinyin: 拼音字符串
Returns:
出现次数
"""
if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法")
return self._char_pinyin_map.get((char, pinyin), 0)
def batch_query_by_ids(self, ids: List[int]) -> Dict[int, Optional[CharInfo]]:
"""
批量ID查询 - O(n)时间复杂度
Args:
ids: ID列表
Returns:
字典key为IDvalue为CharInfo对象不存在则为None
"""
if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法")
results = {}
for id_value in ids:
results[id_value] = self._id_to_info.get(id_value)
return results
def batch_query_by_chars(self, chars: List[str], limit_per_char: int = 0) -> Dict[str, List[Tuple[int, str, int]]]:
"""
批量字符查询
Args:
chars: 字符列表
limit_per_char: 每个字符的结果数量限制
Returns:
字典key为字符value为查询结果列表
"""
if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法")
results = {}
for char in chars:
results[char] = self.query_by_char(char, limit_per_char)
return results
def search_chars_by_prefix(self, prefix: str, limit: int = 20) -> List[Tuple[str, int]]:
"""
根据字符前缀搜索 - O(n)时间复杂度n为字符总数
Args:
prefix: 字符前缀
limit: 返回结果数量限制
Returns:
列表每个元素为(字符, 总频率)按频率降序排序
"""
if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法")
matches = []
for char, freq in self._char_freq.items():
if char.startswith(prefix):
matches.append((char, freq))
# 按频率降序排序
matches.sort(key=lambda x: x[1], reverse=True)
return matches[:limit] if limit > 0 else matches
def get_statistics(self) -> Dict[str, Any]:
"""
获取系统统计信息
Returns:
统计信息字典
"""
if not self._loaded:
return {"status": "not_loaded"}
top_chars = sorted(
self._char_freq.items(),
key=lambda x: x[1],
reverse=True
)[:10]
top_pinyins = sorted(
self._pinyin_freq.items(),
key=lambda x: x[1],
reverse=True
)[:10]
return {
"status": "loaded",
"timestamp": self._counter_data.timestamp,
"total_pairs": self._total_pairs,
"total_characters": len(self._char_freq),
"total_pinyins": len(self._pinyin_freq),
"valid_input_character_count": self._counter_data.valid_input_character_count,
"load_time_seconds": self._load_time,
"index_time_seconds": self._index_time,
"top_chars": top_chars,
"top_pinyins": top_pinyins,
"metadata": self._counter_data.metadata
}
def is_loaded(self) -> bool:
"""检查数据是否已加载"""
return self._loaded
def clear(self):
"""清除所有数据和索引,释放内存"""
self._counter_data = None
self._id_to_info.clear()
self._char_to_ids.clear()
self._pinyin_to_ids.clear()
self._char_freq.clear()
self._pinyin_freq.clear()
self._char_pinyin_map.clear()
self._loaded = False
self._total_pairs = 0
self._load_time = 0.0
self._index_time = 0.0