344 lines
12 KiB
Python
344 lines
12 KiB
Python
import json
|
||
import pickle
|
||
import msgpack
|
||
from multiprocessing import Process, Lock, Queue, cpu_count
|
||
from collections import Counter
|
||
from dataclasses import asdict
|
||
import time
|
||
import queue
|
||
|
||
from datasets import load_dataset
|
||
|
||
from pypinyin import lazy_pinyin
|
||
from loguru import logger
|
||
|
||
from tqdm import trange
|
||
|
||
from .char_info import PinyinCharPairsCounter, CharInfo
|
||
|
||
|
||
|
||
class PinyinCharStatistics:
    """Collect character, pinyin and (character, pinyin) pair frequencies from text.

    Single-process mode (``is_multithreaded=False``): every text is analysed
    synchronously in the calling process.

    Multi-process mode (``is_multithreaded=True``): texts are put on a bounded
    task queue and analysed by ``num_workers`` worker *processes*. Each worker
    keeps local counters and sends them back once it receives a ``None``
    sentinel; ``_collect_results`` merges them into this object's counters.
    """

    # Texts whose share of CJK Unified Ideographs is below this ratio are skipped.
    MIN_CHINESE_RATIO = 0.6
    # Inclusive bounds of the CJK Unified Ideographs block.
    _CJK_START = "\u4e00"
    _CJK_END = "\u9fff"

    def __init__(
        self, is_multithreaded=False, num_workers: int = 8, max_queue_size: int = 5000
    ):
        """Initialise the processor.

        Args:
            is_multithreaded: if True, start ``num_workers`` worker processes
                (``multiprocessing.Process``, despite the parameter name) and
                analyse texts asynchronously.
            num_workers: number of worker processes (multi-process mode only).
            max_queue_size: capacity of the task queue; ``process_text``
                blocks while the queue is full.
        """
        self.max_queue_size = max_queue_size
        self.num_workers = num_workers
        self.is_multithreaded = is_multithreaded
        self._is_shutting_down = False

        # Main-process counters. In multi-process mode they only receive data
        # when the workers' partial results are merged.
        self._char_counter = Counter()
        self._pinyin_counter = Counter()
        self.pinyin_char_counter = Counter()

        if is_multithreaded:
            self.task_queue = Queue(maxsize=max_queue_size)
            self.result_queue = Queue()  # one partial-result dict per worker
            self.lock = Lock()
            self.workers = []
            self.running = True
            self._start_workers()

    def _start_workers(self):
        """Start ``self.num_workers`` daemon worker processes."""
        for _ in range(self.num_workers):
            worker = Process(
                target=self._worker_loop,
                args=(self.task_queue, self.result_queue),
                daemon=True,
            )
            worker.start()
            self.workers.append(worker)

    @staticmethod
    def _analyze_text(text: str):
        """Return ``(chinese_chars, pairs)`` for *text*, or ``None`` to skip it.

        A text is skipped when it is empty (which previously raised
        ZeroDivisionError) or when fewer than ``MIN_CHINESE_RATIO`` of its
        characters are CJK Unified Ideographs. ``pairs`` holds the
        ``(char, pinyin)`` tuples from ``lazy_pinyin`` whose pinyin differs
        from the character itself.
        """
        if not text:
            return None
        start, end = PinyinCharStatistics._CJK_START, PinyinCharStatistics._CJK_END
        chinese_chars = [c for c in text if start <= c <= end]
        if len(chinese_chars) / len(text) < PinyinCharStatistics.MIN_CHINESE_RATIO:
            return None
        # The errors callback returns unconvertible runs character by character,
        # which is what lets zip() pair each character with its pinyin
        # (assumes one pinyin item per CJK character -- confirm with pypinyin docs).
        pinyin_list = lazy_pinyin(text, errors=lambda x: list(x))
        pairs = [
            (char, pinyin) for char, pinyin in zip(text, pinyin_list) if char != pinyin
        ]
        return chinese_chars, pairs

    @staticmethod
    def _worker_loop(task_queue: Queue, result_queue: Queue):
        """Worker-process main loop.

        Pulls texts from *task_queue* until a ``None`` sentinel arrives, counts
        them into process-local counters, then puts exactly one result dict on
        *result_queue*. The dict is sent even when the counters are empty,
        because the collector waits for one result per worker.
        """
        local_char_counter = Counter()
        local_pinyin_counter = Counter()
        local_pair_counter = Counter()

        while True:
            try:
                text = task_queue.get(timeout=1)
            except queue.Empty:
                continue
            if text is None:  # termination sentinel
                break
            try:
                analysis = PinyinCharStatistics._analyze_text(text)
                if analysis is None:
                    continue
                chinese_chars, pairs = analysis
                local_char_counter.update(chinese_chars)
                local_pinyin_counter.update(p for _, p in pairs)
                local_pair_counter.update(pairs)
            except Exception as e:
                # One bad text must not kill the worker and lose its counts.
                logger.error(f"处理任务时出错: {e}")

        result_queue.put({
            'char_counter': dict(local_char_counter),
            'pinyin_counter': dict(local_pinyin_counter),
            'pair_counter': dict(local_pair_counter),
        })

    def _merge_result(self, result):
        """Merge one worker result dict into the main-process counters."""
        if not result:
            return
        self._char_counter.update(result['char_counter'])
        self._pinyin_counter.update(result['pinyin_counter'])
        # The dict is pickled through the multiprocessing queue, so the
        # (char, pinyin) tuple keys arrive as tuples and merge directly.
        for pair, count in result['pair_counter'].items():
            if isinstance(pair, tuple):
                self.pinyin_char_counter[pair] += count
            else:
                logger.warning(f"无法处理pair: {pair}")

    @staticmethod
    def _log_process_status(pid):
        """Best-effort diagnostic log of a process's state (needs optional psutil)."""
        try:
            import psutil  # optional dependency; missing package is tolerated
            p = psutil.Process(pid)
            logger.warning(f"进程状态: {p.status()}, CPU: {p.cpu_percent()}")
        except Exception:
            # Diagnostics only: never let them block worker termination.
            pass

    def _collect_results(self):
        """Stop the workers and merge their partial counters into this object.

        Sends one ``None`` sentinel per worker. Because the queue is FIFO, each
        sentinel comes after all previously queued texts. It then waits for one
        result per worker, joins the workers, and terminates any still alive.
        Afterwards ``self.running`` is False, so ``process_text`` rejects new
        texts instead of queueing them for workers that no longer exist.
        Calling it again is a no-op.
        """
        self.running = False
        expected = len(self.workers)
        if expected == 0:
            return

        # 1. One termination sentinel per worker.
        for _ in range(expected):
            self.task_queue.put(None)

        # 2. Wait for one result per worker (bounded number of 15 s waits).
        results_collected = 0
        max_retries = expected * 3
        for retry in range(max_retries):
            if results_collected >= expected:
                break
            try:
                result = self.result_queue.get(timeout=15)
            except queue.Empty:
                # A worker flushes its result before exiting, so once all
                # workers are gone nothing more can arrive.
                if not any(worker.is_alive() for worker in self.workers):
                    logger.error("所有工作进程已退出,但仍有结果未收到")
                    break
                if retry < max_retries - 1:
                    logger.warning(f"第 {retry + 1} 次等待结果超时,继续尝试...")
                else:
                    logger.error(f"等待结果超时,已尝试 {max_retries} 次")
                continue
            self._merge_result(result)
            results_collected += 1

        # 3. Join the workers; terminate (and reap) any that did not exit in time.
        alive_count = 0
        for worker in self.workers:
            worker.join(timeout=5)
            if worker.is_alive():
                logger.warning(f"工作进程 {worker.pid} 仍在运行")
                self._log_process_status(worker.pid)
                worker.terminate()
                worker.join(timeout=1)
                alive_count += 1

        if alive_count > 0:
            logger.warning(f"强制终止了 {alive_count} 个工作进程")
        else:
            logger.info("所有工作进程正常结束")

        self.workers.clear()

    def _process_single_text(self, text: str):
        """Analyse *text* in the calling process and update this object's counters.

        Not synchronised: the counters are updated without a lock, so do not
        call this concurrently from several threads.
        """
        analysis = self._analyze_text(text)
        if analysis is None:
            return
        chinese_chars, pairs = analysis
        self.pinyin_char_counter.update(pairs)
        self._char_counter.update(chinese_chars)
        self._pinyin_counter.update(p for _, p in pairs)

    def _is_chinese_char(self, char: str) -> bool:
        """Return True if *char* is in the CJK Unified Ideographs block (U+4E00..U+9FFF)."""
        return self._CJK_START <= char <= self._CJK_END

    def process_text(self, text: str):
        """Submit *text* for analysis.

        In multi-process mode the text is enqueued for the workers, and this
        blocks while the task queue is full. In single-process mode it is
        analysed immediately.

        Raises:
            RuntimeError: if the processor is shutting down, or (multi-process
                mode) its workers were already stopped, e.g. by ``export``.
        """
        if self._is_shutting_down:
            raise RuntimeError("处理器正在关闭,无法接受新任务")

        if not self.is_multithreaded:
            self._process_single_text(text)
            return

        if not self.running:
            raise RuntimeError("处理器已关闭,无法接受新任务")
        self.task_queue.put(text)
        # Light back-pressure when the queue is nearly full. qsize() is not
        # implemented on some platforms (e.g. macOS); there we rely on put()
        # blocking once the queue is completely full.
        try:
            nearly_full = self.task_queue.qsize() > self.max_queue_size * 0.8
        except NotImplementedError:
            nearly_full = False
        if nearly_full:
            time.sleep(0.02)

    def shutdown(self, wait: bool = True):
        """Stop accepting new texts and optionally collect the workers' results.

        Idempotent: later calls return immediately.

        Args:
            wait: in multi-process mode, if True, let the workers finish the
                queued texts and merge their counters. If False, the (daemon)
                workers are left running and their counts are never merged.
        """
        if self._is_shutting_down:
            return

        self._is_shutting_down = True
        self.running = False

        if wait and self.is_multithreaded:
            self._collect_results()

    def export(
        self, output_path: str = None, format: str = "msgpack", compress: bool = False
    ):
        """Build a ``PinyinCharPairsCounter`` snapshot and optionally write it out.

        In multi-process mode this first calls ``_collect_results``, which stops
        the worker processes. Afterwards the processor accepts no more texts.

        Args:
            output_path: destination file. If falsy, a summary without the
                ``pairs``/``chars``/``pinyins`` tables is printed instead.
            format: ``"msgpack"``, ``"json"`` or ``"pickle"``.
            compress: gzip the serialized bytes before writing.

        Returns:
            The snapshot's metadata dict (format, compressed flag, pair count).

        Raises:
            ValueError: if *format* is unsupported (only checked when writing).
        """
        logger.info("开始导出快照...")

        if self.is_multithreaded:
            self._collect_results()

        counter_data = PinyinCharPairsCounter()
        counter_data.total_characters = len(self._char_counter)  # distinct chars
        counter_data.total_pinyins = len(self._pinyin_counter)  # distinct pinyin
        counter_data.valid_input_character_count = sum(self.pinyin_char_counter.values())

        # Pairs ordered by descending count; each id is the pair's rank.
        ranked_pairs = sorted(
            self.pinyin_char_counter.items(), key=lambda item: item[1], reverse=True
        )
        counter_data.pairs = {
            idx: CharInfo(id=idx, char=char, pinyin=pinyin, count=count)
            for idx, ((char, pinyin), count) in enumerate(ranked_pairs)
        }
        counter_data.chars = dict(self._char_counter)
        counter_data.pinyins = dict(self._pinyin_counter)
        counter_data.metadata = {
            "format": format,
            "compressed": compress,
            "pair_count": len(counter_data.pairs),
        }

        if not output_path:
            summary = asdict(counter_data)
            for key in ("pairs", "chars", "pinyins"):
                summary.pop(key)
            print(summary)
            return counter_data.metadata

        if format == "json":
            data_str = json.dumps(asdict(counter_data), ensure_ascii=False, indent=2)
            data = data_str.encode("utf-8")
        elif format == "msgpack":
            data = msgpack.packb(asdict(counter_data), use_bin_type=True)
        elif format == "pickle":
            data = pickle.dumps(counter_data)
        else:
            raise ValueError(f"不支持的格式: {format}")

        if compress:
            import gzip
            data = gzip.compress(data)

        with open(output_path, "wb") as f:
            f.write(data)

        logger.info(f"快照导出完成: {output_path} (格式: {format}, 压缩: {compress})")
        return counter_data.metadata
|
||
|
||
|
||
def counter_dataset(
    dataset_path: str,
    num_workers: int = cpu_count(),
    output_path: str = None,
    format: str = "msgpack",
    compress: bool = True,
    max_queue_size: int = 50000,
    max_lines: int = 10000,
):
    """Count characters/pinyin in the first ``max_lines`` dataset records and export them.

    Streams the ``train`` split of *dataset_path* with ``datasets.load_dataset``.
    Each record's ``"text"`` field is fed to a multi-process
    ``PinyinCharStatistics``, and the merged statistics are then exported.

    Args:
        dataset_path: dataset name or path passed to ``load_dataset``.
        num_workers: number of worker processes.
        output_path: export destination; if falsy, only a summary is printed.
        format: ``"msgpack"``, ``"json"`` or ``"pickle"``.
        compress: gzip the exported bytes.
        max_queue_size: capacity of the processor's task queue.
        max_lines: maximum number of dataset records to read.
    """
    dataset = load_dataset(dataset_path, split="train", streaming=True)
    iter_data = iter(dataset)

    processor = PinyinCharStatistics(
        is_multithreaded=True,
        num_workers=num_workers,  # honour the caller's setting
        max_queue_size=max_queue_size,
    )

    try:
        # tqdm.trange shows a progress bar over at most max_lines records.
        for _ in trange(max_lines, desc="正在处理数据"):
            try:
                line = next(iter_data)
            except StopIteration:
                print("⚠️ 数据集已提前结束。")
                break
            text = line.get("text", "")
            if text:
                processor.process_text(text)

        # export() merges all worker results before writing.
        processor.export(output_path, format=format, compress=compress)
    finally:
        # shutdown() is idempotent; this also stops the workers if anything
        # above raised.
        processor.shutdown()