"""Count character, pinyin and (character, pinyin) pair frequencies in text."""

import gzip
import json
import pickle
import queue
import time
from collections import Counter
from dataclasses import asdict
from multiprocessing import Lock, Process, Queue, cpu_count
from typing import Optional

import msgpack
from datasets import load_dataset
from loguru import logger
from pypinyin import lazy_pinyin
from tqdm import trange

from .char_info import CharInfo, PinyinCharPairsCounter


class PinyinCharStatistics:
    """Accumulate character, pinyin and (character, pinyin) pair counts.

    In single-process mode every text is analysed immediately in the calling
    process. In multi-process mode (``is_multithreaded=True``; despite the
    name, worker *processes* are used) texts are pushed onto a bounded task
    queue, each worker keeps private counters, and the counters are merged
    into this object when results are collected by ``export``/``shutdown``.
    """

    # A text is only counted when at least this share of its characters lies
    # in the U+4E00..U+9FFF range.
    _MIN_CHINESE_RATIO = 0.6
    # Seconds to wait for a single worker result before retrying.
    _RESULT_TIMEOUT = 15
    # Seconds to wait for a worker to exit before terminating it.
    _JOIN_TIMEOUT = 5

    def __init__(
        self, is_multithreaded=False, num_workers: int = 8, max_queue_size: int = 5000
    ):
        """Initialise the counters and, in multi-process mode, start workers.

        Args:
            is_multithreaded: When true, texts are processed by worker
                processes instead of in the calling process.
            num_workers: Number of worker processes to start.
            max_queue_size: Capacity of the task queue; also used as the
                reference for the back-pressure check in ``process_text``.
        """
        self.max_queue_size = max_queue_size
        self.num_workers = num_workers
        self.is_multithreaded = is_multithreaded
        self._is_shutting_down = False
        # Defined in both modes so ``shutdown`` does not create it lazily.
        self.running = True

        # Counters owned by the main process; worker results are merged here.
        self._char_counter = Counter()
        self._pinyin_counter = Counter()
        self.pinyin_char_counter = Counter()

        if is_multithreaded:
            self.task_queue = Queue(maxsize=max_queue_size)
            self.result_queue = Queue()  # one result dict per worker
            self.lock = Lock()
            self.workers = []
            self._start_workers()

    def _start_workers(self):
        """Start ``num_workers`` daemon processes running ``_worker_loop``."""
        for _ in range(self.num_workers):
            worker = Process(
                target=self._worker_loop,
                args=(self.task_queue, self.result_queue),
                daemon=True,
            )
            worker.start()
            self.workers.append(worker)

    @staticmethod
    def _analyze_text(text: str):
        """Extract the countable items from one text.

        Returns:
            ``(chinese_chars, pairs)`` where ``chinese_chars`` lists every
            character in U+4E00..U+9FFF and ``pairs`` lists
            ``(char, pinyin)`` tuples, or ``None`` when the text is empty or
            less than ``_MIN_CHINESE_RATIO`` of it is Chinese.
        """
        if not text:
            # Guards the ratio below against division by zero.
            return None

        chinese_chars = [c for c in text if "\u4e00" <= c <= "\u9fff"]
        if len(chinese_chars) / len(text) < PinyinCharStatistics._MIN_CHINESE_RATIO:
            return None

        # NOTE(review): the errors callback presumably expands unconvertible
        # runs into single characters so that the result stays aligned with
        # ``text`` for the zip below - confirm against pypinyin.
        pinyin_list = lazy_pinyin(text, errors=lambda x: list(x))
        # Entries that came back unchanged (char == pinyin) are dropped.
        pairs = [
            (char, pinyin)
            for char, pinyin in zip(text, pinyin_list)
            if char != pinyin
        ]
        return chinese_chars, pairs

    @staticmethod
    def _worker_loop(task_queue: Queue, result_queue: Queue):
        """Consume texts until a ``None`` sentinel arrives, then report counts.

        Exactly one result dict is put on ``result_queue`` per worker, even
        when nothing was counted, so the collector can rely on receiving one
        result for every worker it started.
        """
        local_char_counter = Counter()
        local_pinyin_counter = Counter()
        local_pair_counter = Counter()

        while True:
            try:
                text = task_queue.get(timeout=1)
                if text is None:  # stop sentinel
                    break

                analysis = PinyinCharStatistics._analyze_text(text)
                if analysis is None:
                    continue
                chinese_chars, pairs = analysis

                local_char_counter.update(chinese_chars)
                local_pinyin_counter.update(p for _, p in pairs)
                local_pair_counter.update(pairs)
            except queue.Empty:
                continue
            except Exception as e:
                # A bad item must not take the worker down.
                logger.error(f"处理任务时出错: {e}")

        result_queue.put({
            'char_counter': dict(local_char_counter),
            'pinyin_counter': dict(local_pinyin_counter),
            'pair_counter': dict(local_pair_counter),
        })

    @staticmethod
    def _parse_pair_key(raw: str):
        """Parse a ``"(char, pinyin)"`` string into a tuple, or return None."""
        content = raw[1:-1]  # drop the surrounding brackets
        separator = ", " if ", " in content else ","
        char, found, pinyin = content.partition(separator)
        if not found:
            return None
        return char.strip("'\""), pinyin.strip("'\"")

    def _merge_pair_counts(self, pair_counts):
        """Add one worker's pair counts to ``pinyin_char_counter``.

        Keys are normally tuples; string keys of the form
        ``"(char, pinyin)"`` are still accepted and parsed.
        """
        for pairs, count in pair_counts.items():
            if isinstance(pairs, tuple):
                key = pairs
            elif isinstance(pairs, str):
                key = self._parse_pair_key(pairs)
                if key is None:
                    logger.warning(f"无法解析pair: {pairs}")
                    continue
            else:
                logger.warning(f"无法处理pair: {pairs}")
                continue
            self.pinyin_char_counter[key] += count

    @staticmethod
    def _log_process_state(pid):
        """Log status and CPU usage of a process if psutil is installed."""
        try:
            import psutil  # optional dependency, diagnostics only
        except ImportError:
            return
        try:
            p = psutil.Process(pid)
            logger.warning(f"进程状态: {p.status()}, CPU: {p.cpu_percent()}")
        except Exception as e:
            # Best-effort diagnostics: never let this abort the shutdown.
            logger.debug(f"无法获取进程 {pid} 的状态: {e}")

    def _collect_results(self):
        """Stop the workers and merge their counters into this object.

        Safe to call more than once: after the first call there are no
        workers left and the method returns immediately. Once it has run the
        processor no longer accepts tasks in multi-process mode.
        """
        if not self.workers:
            return

        # No worker will be left to consume tasks after this point.
        self.running = False

        # One stop sentinel per worker.
        for _ in self.workers:
            self.task_queue.put(None)

        expected = len(self.workers)
        results_collected = 0
        max_retries = expected * 3
        for retry in range(max_retries):
            if results_collected >= expected:
                break
            try:
                result = self.result_queue.get(timeout=self._RESULT_TIMEOUT)
            except queue.Empty:
                if not any(worker.is_alive() for worker in self.workers):
                    # Nobody is left to produce the missing results.
                    logger.error(
                        f"所有工作进程已退出,仍缺少 "
                        f"{expected - results_collected} 份结果"
                    )
                    break
                if retry < max_retries - 1:
                    logger.warning(f"第 {retry + 1} 次等待结果超时,继续尝试...")
                    continue
                logger.error(f"等待结果超时,已尝试 {max_retries} 次")
                break

            self._char_counter.update(result['char_counter'])
            self._pinyin_counter.update(result['pinyin_counter'])
            self._merge_pair_counts(result['pair_counter'])
            results_collected += 1

        alive_count = 0
        for worker in self.workers:
            worker.join(timeout=self._JOIN_TIMEOUT)
            if worker.is_alive():
                logger.warning(f"工作进程 {worker.pid} 仍在运行")
                self._log_process_state(worker.pid)
                worker.terminate()
                alive_count += 1

        if alive_count > 0:
            logger.warning(f"强制终止了 {alive_count} 个工作进程")
        else:
            logger.info("所有工作进程正常结束")

        self.workers.clear()

    def _process_single_text(self, text: str):
        """Analyse one text in the calling process and update the counters."""
        analysis = self._analyze_text(text)
        if analysis is None:
            return
        chinese_chars, pairs = analysis

        self.pinyin_char_counter.update(pairs)
        self._char_counter.update(chinese_chars)
        self._pinyin_counter.update(p for _, p in pairs)

    def _is_chinese_char(self, char: str) -> bool:
        """Return True when ``char`` lies in the U+4E00..U+9FFF range."""
        return "\u4e00" <= char <= "\u9fff"

    def process_text(self, text: str):
        """Count one text, either directly or via the worker processes.

        Empty (falsy) input is ignored; in particular ``None`` is never
        enqueued because workers treat it as their stop signal.

        Raises:
            RuntimeError: If the processor is shutting down or, in
                multi-process mode, has already been closed.
        """
        if self._is_shutting_down:
            raise RuntimeError("处理器正在关闭,无法接受新任务")
        if self.is_multithreaded and not self.running:
            raise RuntimeError("处理器已关闭,无法接受新任务")
        if not text:
            return

        if not self.is_multithreaded:
            self._process_single_text(text)
            return

        self.task_queue.put(text)
        try:
            backlog = self.task_queue.qsize()
        except NotImplementedError:
            # qsize() is unavailable on some platforms (e.g. macOS); the
            # bounded queue still throttles the producer via put().
            return
        # Light back-pressure once the queue is more than 80% full.
        if backlog > self.max_queue_size * 0.8:
            time.sleep(0.02)

    def shutdown(self, wait: bool = True):
        """Close the processor; repeated calls are no-ops.

        Args:
            wait: In multi-process mode, stop the workers and merge their
                results before returning.
        """
        if self._is_shutting_down:
            return

        self._is_shutting_down = True
        self.running = False

        if wait and self.is_multithreaded:
            self._collect_results()

    def export(
        self,
        output_path: Optional[str] = None,
        format: str = "msgpack",
        compress: bool = False,
    ):
        """Build the statistics snapshot and optionally write it to a file.

        In multi-process mode the workers are stopped and their results
        merged first, so no further texts can be processed afterwards.

        Args:
            output_path: Destination file. When empty, a summary without the
                ``pairs``/``chars``/``pinyins`` tables is printed instead.
            format: ``"json"``, ``"msgpack"`` or ``"pickle"``.
            compress: Gzip the serialised bytes before writing.

        Returns:
            The metadata dict stored on the snapshot.

        Raises:
            ValueError: If ``output_path`` is given and ``format`` is not
                one of the supported values.
        """
        logger.info("开始导出快照...")

        if self.is_multithreaded:
            self._collect_results()

        counter_data = PinyinCharPairsCounter()
        # Both totals are numbers of distinct keys, not occurrence sums.
        counter_data.total_characters = len(self._char_counter)
        counter_data.total_pinyins = len(self._pinyin_counter)
        counter_data.valid_input_character_count = sum(
            self.pinyin_char_counter.values()
        )

        # Ids are assigned in descending order of pair frequency.
        char_info_list = [
            CharInfo(id=idx, char=char, pinyin=pinyin, count=count)
            for idx, ((char, pinyin), count) in enumerate(
                self.pinyin_char_counter.most_common()
            )
        ]
        counter_data.pairs = {info.id: info for info in char_info_list}
        counter_data.chars = dict(self._char_counter)
        counter_data.pinyins = dict(self._pinyin_counter)

        counter_data.metadata = {
            "format": format,
            "compressed": compress,
            "pair_count": len(counter_data.pairs),
        }

        if not output_path:
            dict_counter_data = asdict(counter_data)
            for bulky_field in ("pairs", "chars", "pinyins"):
                dict_counter_data.pop(bulky_field)
            print(dict_counter_data)
            return counter_data.metadata

        if format == "json":
            data_str = json.dumps(asdict(counter_data), ensure_ascii=False, indent=2)
            data = data_str.encode("utf-8")
        elif format == "msgpack":
            data = msgpack.packb(asdict(counter_data), use_bin_type=True)
        elif format == "pickle":
            data = pickle.dumps(counter_data)
        else:
            raise ValueError(f"不支持的格式: {format}")

        if compress:
            data = gzip.compress(data)

        with open(output_path, "wb") as f:
            f.write(data)

        logger.info(f"快照导出完成: {output_path} (格式: {format}, 压缩: {compress})")
        return counter_data.metadata


def counter_dataset(
    dataset_path: str,
    num_workers: int = cpu_count(),
    output_path: Optional[str] = None,
    format: str = "msgpack",
    compress: bool = True,
    max_queue_size: int = 50000,
    max_lines: int = 10000,
):
    """Count pinyin/character statistics over a streamed dataset and export.

    Args:
        dataset_path: Passed to ``load_dataset``; the ``train`` split is
            streamed.
        num_workers: Number of worker processes to use.
        output_path: Forwarded to ``PinyinCharStatistics.export``.
        format: Forwarded to ``PinyinCharStatistics.export``.
        compress: Forwarded to ``PinyinCharStatistics.export``.
        max_queue_size: Capacity of the processor's task queue.
        max_lines: Maximum number of records to read from the dataset.
    """
    dataset = load_dataset(dataset_path, split="train", streaming=True)
    iter_data = iter(dataset)

    processor = PinyinCharStatistics(
        is_multithreaded=True,
        num_workers=num_workers,
        max_queue_size=max_queue_size,
    )

    try:
        for _ in trange(max_lines, desc="正在处理数据"):
            try:
                line = next(iter_data)
            except StopIteration:
                print("⚠️ 数据集已提前结束。")
                break
            # assumes each record is a mapping with a "text" field - TODO confirm
            text = line.get("text", "")
            if text:
                processor.process_text(text)

        # export() stops the workers and merges their results first.
        processor.export(output_path, format=format, compress=compress)
    finally:
        # Also runs on errors so worker processes are not left behind;
        # shutdown() is a no-op when already shut down.
        processor.shutdown()