#!/usr/bin/env python3
"""
Quality inspection script for preprocessed data.

Features:
1. Compute the distribution of ``labels`` (occurrence count, ratio,
   most/least frequent labels, number of labels that never appear).
2. Randomly sample records, decode them into human-readable text and
   export them to a CSV file.

Usage:
    python -m model.inspect_preprocessed --data-dir /path/to/preprocessed/train
    python -m model.inspect_preprocessed --data-dir /path/to/preprocessed/train --num-samples 50 --output samples.csv
"""

import argparse
import csv
import json
import random
from collections import Counter
from itertools import takewhile
from pathlib import Path

import numpy as np
from loguru import logger
from rich.console import Console
from rich.markup import escape
from rich.table import Table
from tqdm import tqdm

from .char_info import CharInfo
from .dataset import CHAR_TO_ID
from .preprocessed_dataset import PreProcessedDataset
from .query import QueryEngine

# Reverse lookup of CHAR_TO_ID: pinyin character id -> character.
ID_TO_CHAR = {v: k for k, v in CHAR_TO_ID.items()}

# Number of decoded samples echoed to the console after the CSV export.
_PREVIEW_ROWS = 5

# Inclusive occurrence-count buckets used by the frequency overview table.
_FREQUENCY_BINS = [
    (1, 10),
    (11, 100),
    (101, 1000),
    (1001, 10000),
    (10001, 100000),
    (100001, 1000000),
    (1000001, float("inf")),
]

_CSV_HEADERS = [
    "index",
    "pinyin",
    "label_char",
    "label_id",
    "label_count",
    "context",
    "suffix",
    "history",
    "full_tokens",
]


def _to_python(value):
    """Return ``value.tolist()`` when the object offers it, else ``value`` unchanged."""
    return value.tolist() if hasattr(value, "tolist") else value


def _format_label(info, label_id) -> str:
    """Render a query result as ``char(pinyin)``.

    When ``info`` is ``None`` the id is still shown (``EOS`` for id 0, which
    ``main`` treats as the EOS label, otherwise ``UNK(id)``) so that unknown
    ids are never rendered as an empty string.
    """
    if info is not None:
        return f"{info.char}({info.pinyin})"
    return "EOS" if label_id == 0 else f"UNK({label_id})"


def decode_pinyin_ids(pinyin_ids: list) -> str:
    """Decode ``pinyin_ids`` back into a pinyin string.

    Decoding stops at the first ``0`` id; ids missing from ``ID_TO_CHAR``
    are rendered as ``"?"``.
    """
    return "".join(
        ID_TO_CHAR.get(pid, "?") for pid in takewhile(lambda pid: pid != 0, pinyin_ids)
    )


def decode_history(history_ids: list, query_engine: QueryEngine) -> str:
    """Decode ``history_slot_ids`` into text, slots joined by ``" | "``.

    A slot id of ``0`` yields an empty entry. Ids that ``query_engine``
    cannot resolve are shown as ``UNK(id)`` so they stay distinguishable
    from empty slots.
    """
    parts = []
    for hid in history_ids:
        if hid == 0:
            parts.append("")
        else:
            parts.append(_format_label(query_engine.query_by_id(hid), hid))
    return " | ".join(parts)


def analyze_labels(dataset: PreProcessedDataset, max_shards: int = 0):
    """Count label occurrences over the dataset, showing a progress bar.

    Args:
        dataset: Preprocessed dataset; sharded datasets are read shard by
            shard from ``dataset.data_dir / "shard_XXXXXX.npz"``.
        max_shards: Maximum number of shards to read; ``0`` (or any
            non-positive value) reads all of them.

    Returns:
        ``(counter, total)`` where ``counter`` maps label id (``int``) to its
        occurrence count (``int``) and ``total`` is the number of labels read.
    """
    logger.info("正在统计 labels 分布...")
    counter = Counter()
    total = 0

    num_shards = dataset._num_shards if dataset._is_sharded else 1
    effective_shards = min(num_shards, max_shards) if max_shards > 0 else num_shards

    for shard_idx in tqdm(range(effective_shards), desc="统计 labels", unit="shard"):
        if dataset._is_sharded:
            # Read only the "labels" array and close the archive right away
            # instead of materialising every array stored in the shard.
            shard_path = dataset.data_dir / f"shard_{shard_idx:06d}.npz"
            with np.load(shard_path) as shard:
                labels = shard["labels"].astype(np.int64)
        else:
            labels = dataset.labels[:].astype(np.int64)

        unique, counts = np.unique(labels, return_counts=True)
        # tolist() yields plain Python ints, keeping numpy scalars out of the Counter.
        counter.update(dict(zip(unique.tolist(), counts.tolist())))
        total += len(labels)

    return counter, total


def decode_sample(sample: dict, tokenizer, query_engine: QueryEngine) -> dict:
    """Decode one sample into a human-readable dict.

    Args:
        sample: Mapping with the keys ``input_ids``, ``token_type_ids``,
            ``labels``, ``history_slot_ids`` and ``pinyin_ids``; values may be
            plain Python objects or objects exposing ``tolist()`` / ``item()``.
        tokenizer: Object providing ``decode(ids, skip_special_tokens=...)``.
        query_engine: Engine used to resolve label and history ids.

    Returns:
        Dict with the keys ``context``, ``suffix``, ``pinyin``, ``label_id``,
        ``label_char``, ``label_count``, ``history`` and ``full_tokens``.
    """
    input_ids = _to_python(sample["input_ids"])
    token_type_ids = _to_python(sample["token_type_ids"])
    history_ids = _to_python(sample["history_slot_ids"])
    pinyin_ids = _to_python(sample["pinyin_ids"])
    labels = sample["labels"]
    if hasattr(labels, "item"):
        labels = labels.item()

    # Full token text, special tokens included.
    token_text = tokenizer.decode(input_ids, skip_special_tokens=False)

    # Sentence B starts at the first position whose token type is 1.
    sep_start = next(
        (pos for pos, tt in enumerate(token_type_ids) if tt == 1), None
    )
    if sep_start is None:
        sent_a_ids = input_ids
        sent_b_ids = []
    else:
        sent_a_ids = [
            tid
            for tid, tt in zip(input_ids[:sep_start], token_type_ids[:sep_start])
            if tt == 0
        ]
        sent_b_ids = [tid for tid, tt in zip(input_ids, token_type_ids) if tt == 1]

    context_text = tokenizer.decode(sent_a_ids, skip_special_tokens=True)
    suffix_text = (
        tokenizer.decode(sent_b_ids, skip_special_tokens=True) if sent_b_ids else ""
    )

    label_info = query_engine.query_by_id(labels)

    return {
        "context": context_text,
        "suffix": suffix_text,
        "pinyin": decode_pinyin_ids(pinyin_ids),
        "label_id": labels,
        "label_char": _format_label(label_info, labels),
        "label_count": label_info.count if label_info else 0,
        "history": decode_history(history_ids, query_engine),
        "full_tokens": token_text,
    }


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for this script."""
    parser = argparse.ArgumentParser(description="预处理数据质量分析")
    parser.add_argument(
        "--data-dir",
        type=str,
        required=True,
        help="预处理数据目录(train/ 或 eval/)",
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=50,
        help="随机抽样的样本数量(默认50)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="CSV 输出文件路径(默认: DATA_DIR/samples.csv)",
    )
    parser.add_argument(
        "--max-shards",
        type=int,
        default=0,
        help="统计 labels 时最多读取的分片数(0=全部)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="随机种子",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=30,
        help="显示出现次数最多和最少的标签数量",
    )
    return parser


def _print_label_table(console, title, rows, total, query_engine, pct_spec):
    """Print one ranked label table; ``pct_spec`` is the percentage format spec."""
    table = Table(title=title, show_header=True, header_style="bold magenta")
    table.add_column("排名", style="cyan", width=6)
    table.add_column("ID", style="green", width=8)
    table.add_column("字符(拼音)", style="yellow", width=20)
    table.add_column("频次", style="red", width=12)
    table.add_column("占比", style="blue", width=10)
    for rank, (label_id, count) in enumerate(rows, 1):
        label_str = _format_label(query_engine.query_by_id(label_id), label_id)
        pct = count / total * 100
        table.add_row(
            str(rank),
            str(label_id),
            label_str,
            f"{count:,}",
            f"{format(pct, pct_spec)}%",
        )
    console.print(table)


def _report_label_distribution(console, counter, total, query_engine, top_k):
    """Print the label summary, the top/bottom tables and the frequency overview."""
    # Vocabulary size as reported by the query engine; id 0 (EOS) is counted
    # separately, hence the "+ 1" below.
    vocab_size = len(query_engine._id_to_info)
    full_vocab = vocab_size + 1
    missing_ids = set(range(full_vocab)) - set(counter)

    eos_count = counter.get(0, 0)
    # Guard the ratio: an empty dataset must not crash the report.
    eos_pct = eos_count / total * 100 if total else 0.0

    console.print(f"\n总样本数: {total:,}")
    console.print(f"词表大小: {full_vocab:,} (含 EOS)")
    console.print(f"唯一标签数: {len(counter):,}")
    console.print(f"EOS (id=0) 出现次数: {eos_count:,} ({eos_pct:.2f}%)")
    console.print(
        f"[bold red]未出现的标签数: {len(missing_ids):,} / {full_vocab:,} ({len(missing_ids) / full_vocab * 100:.2f}%)[/bold red]"
    )

    ranked = counter.most_common()
    most_common = ranked[:top_k] if top_k > 0 else []
    # Always list the rarest label first, also when fewer than top_k labels exist.
    least_common = ranked[::-1][:top_k] if top_k > 0 else []

    _print_label_table(
        console,
        f"出现次数最多的 {top_k} 个标签",
        most_common,
        total,
        query_engine,
        ".3f",
    )
    _print_label_table(
        console,
        f"出现次数最少的 {min(top_k, len(counter))} 个标签",
        least_common,
        total,
        query_engine,
        ".6f",
    )

    # Frequency distribution overview.
    table_dist = Table(
        title="频次分布概览", show_header=True, header_style="bold magenta"
    )
    table_dist.add_column("频次区间", style="cyan")
    table_dist.add_column("标签数", style="green")
    table_dist.add_column("占总标签数比例", style="yellow")

    for lo, hi in _FREQUENCY_BINS:
        count_in_bin = sum(1 for c in counter.values() if lo <= c <= hi)
        if count_in_bin > 0:
            hi_str = str(int(hi)) if hi != float("inf") else "∞"
            table_dist.add_row(
                f"{lo}-{hi_str}",
                f"{count_in_bin:,}",
                f"{count_in_bin / len(counter) * 100:.1f}%",
            )

    # Labels that never appear get their own row.
    if missing_ids:
        table_dist.add_row(
            "未出现",
            f"{len(missing_ids):,}",
            f"{len(missing_ids) / full_vocab * 100:.1f}%",
        )
    console.print(table_dist)


def _export_samples(console, dataset, tokenizer, query_engine, num_samples, output):
    """Decode ``num_samples`` random samples into ``output`` (CSV) and preview a few."""
    # Clamp to [0, len(dataset)] so random.sample never receives an invalid size.
    num_samples = max(0, min(num_samples, len(dataset)))
    console.print(
        f"\n[bold yellow]====== 随机抽样还原 ({num_samples} 个样本) → {output} ======[/bold yellow]"
    )

    indices = random.sample(range(len(dataset)), num_samples)

    csv_path = Path(output)
    csv_path.parent.mkdir(parents=True, exist_ok=True)

    # Keep the first few decoded rows in memory for the console preview
    # instead of re-reading the file that was just written.
    preview = []
    with open(csv_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(_CSV_HEADERS)
        for idx in tqdm(indices, desc="解码样本", unit="sample"):
            decoded = decode_sample(dataset[idx], tokenizer, query_engine)
            writer.writerow(
                [
                    idx,
                    decoded["pinyin"],
                    decoded["label_char"],
                    decoded["label_id"],
                    decoded["label_count"],
                    decoded["context"],
                    decoded["suffix"],
                    decoded["history"],
                    decoded["full_tokens"],
                ]
            )
            if len(preview) < _PREVIEW_ROWS:
                preview.append(decoded)

    console.print(
        f"[bold green]✓ 已导出 {num_samples} 个样本到 {csv_path}[/bold green]"
    )

    # Summary of the first few samples.
    console.print(
        f"\n[bold cyan]前 {min(_PREVIEW_ROWS, num_samples)} 个样本概览:[/bold cyan]"
    )
    for i, row in enumerate(preview):
        # Decoded text is arbitrary user data: escape it so that square
        # brackets are not interpreted as rich console markup.
        pinyin = escape(str(row["pinyin"]))
        label_char = escape(str(row["label_char"]))
        context = escape(str(row["context"])[:50])
        console.print(
            f" [{i + 1}] 拼音={pinyin} 目标={label_char} "
            f"上下文={context}..."
        )


def main():
    """Entry point: parse arguments, report label statistics, export samples."""
    console = Console()
    args = _build_arg_parser().parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)

    if args.output is None:
        args.output = str(Path(args.data_dir) / "samples.csv")

    # Load the dataset.
    logger.info(f"加载数据集: {args.data_dir}")
    dataset = PreProcessedDataset(args.data_dir)
    console.print(f"[bold cyan]数据集: {len(dataset):,} 个样本[/bold cyan]")
    if dataset._is_sharded:
        console.print(
            f" 分片数: {dataset._num_shards}, 每分片: {dataset._shard_size:,} 样本"
        )
    console.print()

    # Load the QueryEngine.
    logger.info("加载 QueryEngine...")
    query_engine = QueryEngine()
    query_engine.load()

    # Load the tokenizer (imported lazily, only needed when the script runs).
    logger.info("加载 Tokenizer...")
    from importlib.resources import files as pkg_files

    from modelscope import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        Path(str(pkg_files(__package__))) / "assets" / "tokenizer"
    )

    # ====== 1. Label distribution analysis ======
    console.print("[bold yellow]====== Labels 分布分析 ======[/bold yellow]")
    counter, total = analyze_labels(dataset, max_shards=args.max_shards)
    _report_label_distribution(console, counter, total, query_engine, args.top_k)

    # ====== 2. Random sampling -> CSV ======
    _export_samples(
        console, dataset, tokenizer, query_engine, args.num_samples, args.output
    )

    console.print("\n[bold green]分析完成[/bold green]")


app = main

if __name__ == "__main__":
    main()