# SUInput/src/tmp_utils/counter.py
#
# NOTE: string literals and comments in this module contain non-ASCII text
# (Chinese log/error messages and an emoji); these characters are intentional.

import json
import pickle
import msgpack
from multiprocessing import Process, Lock, Queue, cpu_count
from collections import Counter
from dataclasses import asdict
import time
import queue
from datasets import load_dataset
from pypinyin import lazy_pinyin
from loguru import logger
from tqdm import trange
from .char_info import PinyinCharPairsCounter, CharInfo
class PinyinCharStatistics:
    """Count Chinese characters, pinyin syllables and (char, pinyin) pairs.

    Texts are analysed either synchronously in the calling process
    (``is_multithreaded=False``) or by a pool of worker *processes*
    (``is_multithreaded=True``; despite the flag's name, ``multiprocessing``
    is used). Each worker keeps local counters and sends them back once,
    after it receives the ``None`` stop sentinel; ``_collect_results`` merges
    them into this object's counters.
    """

    # A text is only counted when at least this fraction of its characters
    # lie in U+4E00..U+9FFF (the same range ``_is_chinese_char`` tests).
    _MIN_CHINESE_RATIO = 0.6

    def __init__(
        self, is_multithreaded=False, num_workers: int = 8, max_queue_size: int = 5000
    ):
        """Initialise the counters and, in multi-process mode, start workers.

        Args:
            is_multithreaded: when True, texts are handed to ``num_workers``
                daemon processes through a bounded task queue; otherwise
                ``process_text`` analyses them synchronously.
            num_workers: number of worker processes (multi-process mode only).
            max_queue_size: capacity of the task queue; ``process_text``
                blocks while it is full.
        """
        self.max_queue_size = max_queue_size
        self.num_workers = num_workers
        self.is_multithreaded = is_multithreaded
        self._is_shutting_down = False
        # Counters owned by this (parent) process. In multi-process mode they
        # are only filled when worker results are merged.
        self._char_counter = Counter()
        self._pinyin_counter = Counter()
        self.pinyin_char_counter = Counter()
        self.workers = []
        self.running = True
        if is_multithreaded:
            self.task_queue = Queue(maxsize=max_queue_size)
            self.result_queue = Queue()  # one summary message per worker
            self.lock = Lock()  # NOTE(review): not used within this class
            self._start_workers()

    def _start_workers(self):
        """Spawn ``num_workers`` daemon processes running ``_worker_loop``."""
        for _ in range(self.num_workers):
            worker = Process(
                target=self._worker_loop,
                args=(self.task_queue, self.result_queue),
                daemon=True,
            )
            worker.start()
            self.workers.append(worker)

    @staticmethod
    def _analyze_text(text):
        """Return ``(chinese_chars, pairs)`` for *text*, or None to skip it.

        ``chinese_chars`` lists the characters of *text* in U+4E00..U+9FFF;
        ``pairs`` holds the ``(char, pinyin)`` tuples from ``lazy_pinyin``
        whose pinyin differs from the character itself. Empty texts and texts
        whose Chinese ratio is below ``_MIN_CHINESE_RATIO`` yield None.
        """
        if not text:
            # Guard: the ratio below would otherwise divide by zero.
            return None
        chinese_chars = [c for c in text if "\u4e00" <= c <= "\u9fff"]
        if len(chinese_chars) / len(text) < PinyinCharStatistics._MIN_CHINESE_RATIO:
            return None
        # Unconvertible runs are split into single characters by the error
        # callback, presumably to keep pinyin_list aligned with text for zip().
        # NOTE(review): alignment relies on pypinyin emitting one item per char.
        pinyin_list = lazy_pinyin(text, errors=lambda x: list(x))
        pairs = [
            (char, pinyin)
            for char, pinyin in zip(text, pinyin_list)
            if char != pinyin
        ]
        return chinese_chars, pairs

    @staticmethod
    def _worker_loop(task_queue: Queue, result_queue: Queue):
        """Worker-process main loop.

        Consumes texts from *task_queue* until a ``None`` sentinel arrives,
        then puts exactly one summary dict (``char_counter``,
        ``pinyin_counter``, ``pair_counter``) on *result_queue*. A summary is
        sent even when nothing was counted, so the collector never has to
        time out waiting for an idle worker.
        """
        local_char_counter = Counter()
        local_pinyin_counter = Counter()
        local_pair_counter = Counter()
        while True:
            try:
                text = task_queue.get(timeout=1)
            except queue.Empty:
                continue
            if text is None:  # stop sentinel
                break
            try:
                analysis = PinyinCharStatistics._analyze_text(text)
            except Exception as e:
                # One bad text must not kill the worker and lose its counts.
                logger.error(f"处理任务时出错: {e}")
                continue
            if analysis is None:
                continue
            chinese_chars, pairs = analysis
            local_char_counter.update(chinese_chars)
            local_pinyin_counter.update(p for _, p in pairs)
            local_pair_counter.update(pairs)
        result_queue.put({
            'char_counter': dict(local_char_counter),
            'pinyin_counter': dict(local_pinyin_counter),
            'pair_counter': dict(local_pair_counter),
        })

    @staticmethod
    def _normalize_pair_key(key):
        """Return *key* as a ``(char, pinyin)`` tuple, or None if unusable.

        Tuples pass through unchanged. A string such as ``"('中', 'zhong')"``
        is parsed back into a tuple. Anything else is logged and rejected.
        """
        if isinstance(key, tuple):
            return key
        if isinstance(key, str):
            content = key[1:-1]  # drop the surrounding parentheses
            separator = ", " if ", " in content else ","
            try:
                char, pinyin = content.split(separator, 1)
            except ValueError:
                logger.warning(f"无法解析pair: {key}")
                return None
            return char.strip("'\""), pinyin.strip("'\"")
        logger.warning(f"无法处理pair: {key}")
        return None

    def _merge_result(self, result):
        """Add one worker summary dict into this object's counters."""
        self._char_counter.update(result.get('char_counter', {}))
        self._pinyin_counter.update(result.get('pinyin_counter', {}))
        for key, count in result.get('pair_counter', {}).items():
            pair = self._normalize_pair_key(key)
            if pair is not None:
                self.pinyin_char_counter[pair] += count

    @staticmethod
    def _log_process_state(pid):
        """Best-effort diagnostics for a worker that did not exit in time."""
        try:
            import psutil  # optional dependency, only needed for diagnostics

            p = psutil.Process(pid)
            logger.warning(f"进程状态: {p.status()}, CPU: {p.cpu_percent()}")
        except Exception:
            # psutil missing or process already gone: diagnostics are optional.
            pass

    def _join_workers(self):
        """Join every worker, terminating (and reaping) any that hang."""
        alive_count = 0
        for worker in self.workers:
            worker.join(timeout=5)
            if not worker.is_alive():
                continue
            logger.warning(f"工作进程 {worker.pid} 仍在运行")
            self._log_process_state(worker.pid)
            worker.terminate()
            worker.join(timeout=1)  # reap the terminated process
            alive_count += 1
        if alive_count > 0:
            logger.warning(f"强制终止了 {alive_count} 个工作进程")
        else:
            logger.info("所有工作进程正常结束")
        self.workers.clear()

    def _collect_results(self):
        """Stop all workers and merge their counters into this object.

        Sends one ``None`` sentinel per worker and then reads summaries from
        the result queue. It makes at most 3 x (number of workers) ``get``
        attempts, each waiting up to 15 s. Finally it joins the workers.
        Results are read before joining, because a process that has put data
        on a queue waits for that data to be flushed before it exits.
        Calling it again after the workers are gone is a no-op.
        """
        if not self.workers:
            return
        worker_count = len(self.workers)
        for _ in range(worker_count):
            self.task_queue.put(None)

        results_collected = 0
        max_retries = worker_count * 3
        for retry in range(max_retries):
            if results_collected >= worker_count:
                break
            try:
                result = self.result_queue.get(timeout=15)
            except queue.Empty:
                if retry < max_retries - 1:
                    logger.warning(f"{retry + 1} 次等待结果超时,继续尝试...")
                    continue
                logger.error(f"等待结果超时,已尝试 {max_retries}")
                break
            if result:
                self._merge_result(result)
            results_collected += 1

        self._join_workers()

    def _process_single_text(self, text: str):
        """Analyse *text* in the calling process and update the counters.

        Not thread-safe: the counters are updated without any locking.
        """
        analysis = self._analyze_text(text)
        if analysis is None:
            return
        chinese_chars, pairs = analysis
        self.pinyin_char_counter.update(pairs)
        self._char_counter.update(chinese_chars)
        self._pinyin_counter.update(p for _, p in pairs)

    def _is_chinese_char(self, char: str) -> bool:
        """Return True if *char* lies in the range U+4E00..U+9FFF."""
        return "\u4e00" <= char <= "\u9fff"

    def process_text(self, text: str):
        """Count one text.

        In multi-process mode the text is queued for the workers (blocking
        while the queue is full). Otherwise it is analysed immediately.

        Raises:
            RuntimeError: if the processor has been shut down or exported.
        """
        if self._is_shutting_down:
            raise RuntimeError("处理器正在关闭,无法接受新任务")
        if not self.is_multithreaded:
            self._process_single_text(text)
            return
        if not self.running:
            raise RuntimeError("处理器已关闭,无法接受新任务")
        self.task_queue.put(text)
        try:
            backlog = self.task_queue.qsize()
        except NotImplementedError:
            # multiprocessing.Queue.qsize() is unavailable on e.g. macOS;
            # put() above already blocks when the queue is full.
            return
        if backlog > self.max_queue_size * 0.8:
            time.sleep(0.02)  # ease off while the workers catch up

    def shutdown(self, wait: bool = True):
        """Stop accepting new texts; optionally collect the workers' results.

        With ``wait=True`` in multi-process mode, the workers are stopped and
        their counts are merged. With ``wait=False`` only the flags are set;
        the workers stay alive so a later ``export()`` can still collect
        their counts. Repeated calls are no-ops.
        """
        if self._is_shutting_down:
            return
        self._is_shutting_down = True
        self.running = False
        if wait and self.is_multithreaded:
            self._collect_results()

    def export(
        self, output_path: str = None, format: str = "msgpack", compress: bool = False
    ):
        """Build a ``PinyinCharPairsCounter`` snapshot and write or print it.

        In multi-process mode the worker results are collected first, and the
        processor is then closed to new texts, since no worker remains to
        process them.

        Args:
            output_path: destination file. If falsy, a summary (without the
                ``pairs``/``chars``/``pinyins`` tables) is printed instead.
            format: ``"json"``, ``"msgpack"`` or ``"pickle"``.
            compress: gzip the serialised bytes before writing.

        Returns:
            The snapshot's metadata dict.

        Raises:
            ValueError: for an unsupported *format*.
        """
        logger.info("开始导出快照...")
        if self.is_multithreaded:
            self._collect_results()
            self.shutdown(wait=False)  # workers are gone: reject new texts

        counter_data = PinyinCharPairsCounter()
        # Both totals count *distinct* keys, not occurrences.
        counter_data.total_characters = len(self._char_counter)
        counter_data.total_pinyins = len(self._pinyin_counter)
        counter_data.valid_input_character_count = sum(self.pinyin_char_counter.values())
        # Pair ids are ranks by descending count (ties keep insertion order).
        counter_data.pairs = {
            idx: CharInfo(id=idx, char=char, pinyin=pinyin, count=count)
            for idx, ((char, pinyin), count) in enumerate(
                self.pinyin_char_counter.most_common()
            )
        }
        counter_data.chars = dict(self._char_counter)
        counter_data.pinyins = dict(self._pinyin_counter)
        counter_data.metadata = {
            "format": format,
            "compressed": compress,
            "pair_count": len(counter_data.pairs),
        }

        if not output_path:
            summary = asdict(counter_data)
            for bulky_field in ("pairs", "chars", "pinyins"):
                summary.pop(bulky_field)
            print(summary)
            return counter_data.metadata

        if format == "json":
            data = json.dumps(
                asdict(counter_data), ensure_ascii=False, indent=2
            ).encode("utf-8")
        elif format == "msgpack":
            data = msgpack.packb(asdict(counter_data), use_bin_type=True)
        elif format == "pickle":
            data = pickle.dumps(counter_data)
        else:
            raise ValueError(f"不支持的格式: {format}")

        if compress:
            import gzip

            data = gzip.compress(data)

        with open(output_path, "wb") as f:
            f.write(data)
        logger.info(f"快照导出完成: {output_path} (格式: {format}, 压缩: {compress})")
        return counter_data.metadata
def counter_dataset(
    dataset_path: str,
    num_workers: int = cpu_count(),
    output_path: str = None,
    format: str = "msgpack",
    compress: bool = True,
    max_queue_size: int = 50000,
    max_lines: int = 10000,
):
    """Count pinyin statistics over the ``train`` split of a dataset.

    Streams at most *max_lines* records from ``load_dataset(dataset_path)``,
    feeds each non-empty ``"text"`` field to a multi-process
    ``PinyinCharStatistics`` and exports the result.

    Args:
        dataset_path: dataset name/path passed to ``load_dataset``.
        num_workers: number of worker processes.
        output_path: export destination; if falsy, a summary is printed.
        format: export format (``"json"``, ``"msgpack"`` or ``"pickle"``).
        compress: gzip the exported bytes.
        max_queue_size: capacity of the processor's task queue.
        max_lines: maximum number of records to read.

    Returns:
        The export metadata dict produced by ``PinyinCharStatistics.export``.
    """
    dataset = load_dataset(dataset_path, split="train", streaming=True)
    iter_data = iter(dataset)
    processor = PinyinCharStatistics(
        is_multithreaded=True,
        num_workers=num_workers,
        max_queue_size=max_queue_size,
    )
    try:
        for _ in trange(max_lines, desc="正在处理数据"):
            try:
                line = next(iter_data)
            except StopIteration:
                print("⚠️ 数据集已提前结束。")
                break
            text = line.get("text", "")
            if text:
                processor.process_text(text)
        # export() collects the workers' results before building the snapshot.
        return processor.export(output_path, format=format, compress=compress)
    finally:
        # Make sure the workers are stopped even if reading or exporting fails.
        if not processor._is_shutting_down:
            processor.shutdown()