402 lines
12 KiB
Python
402 lines
12 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
预处理数据质量分析脚本
|
||
|
||
功能:
|
||
1. 统计 labels 的分布(出现次数、比例,最大/最小,未出现标签数)
|
||
2. 随机抽样还原为人类可读文本,导出为 CSV 文件
|
||
|
||
用法:
|
||
python -m model.inspect_preprocessed --data-dir /path/to/preprocessed/train
|
||
python -m model.inspect_preprocessed --data-dir /path/to/preprocessed/train --num-samples 50 --output samples.csv
|
||
"""
|
||
|
||
import argparse
|
||
import csv
|
||
import json
|
||
import random
|
||
from collections import Counter
|
||
from pathlib import Path
|
||
|
||
import numpy as np
|
||
from loguru import logger
|
||
from rich.console import Console
|
||
from rich.table import Table
|
||
from tqdm import tqdm
|
||
|
||
from .char_info import CharInfo
|
||
from .dataset import CHAR_TO_ID
|
||
from .preprocessed_dataset import PreProcessedDataset
|
||
from .query import QueryEngine
|
||
|
||
# Reverse lookup table (id -> character), obtained by inverting CHAR_TO_ID.
ID_TO_CHAR = {char_id: char for char, char_id in CHAR_TO_ID.items()}
|
||
|
||
|
||
def decode_pinyin_ids(pinyin_ids: list) -> str:
    """Turn a sequence of pinyin character ids back into a string.

    Decoding stops at the first id equal to 0; everything from that position
    onwards is ignored. Ids that are missing from ``ID_TO_CHAR`` are rendered
    as ``"?"``.
    """
    ids = list(pinyin_ids)
    # Locate the terminator (first 0), or use the whole sequence if absent.
    end = ids.index(0) if 0 in ids else len(ids)
    return "".join(ID_TO_CHAR.get(pid, "?") for pid in ids[:end])
|
||
|
||
|
||
def decode_history(history_ids: list, query_engine: QueryEngine) -> str:
    """Render ``history_slot_ids`` as readable text.

    Each slot is rendered as ``<PAD>`` when its id is 0, as ``char(pinyin)``
    when ``query_engine.query_by_id`` returns an entry, and as ``<ID:n>``
    when the lookup returns ``None``. Slots are joined with ``" | "``.
    """

    def render(slot_id) -> str:
        # Guard clauses: padding first, then unknown ids, then the normal case.
        if slot_id == 0:
            return "<PAD>"
        entry = query_engine.query_by_id(slot_id)
        if entry is None:
            return f"<ID:{slot_id}>"
        return f"{entry.char}({entry.pinyin})"

    return " | ".join(render(slot_id) for slot_id in history_ids)
|
||
|
||
|
||
def analyze_labels(dataset: PreProcessedDataset, max_shards: int = 0):
    """Count how often each label id occurs in the dataset.

    For a sharded dataset the ``labels`` array of each ``shard_XXXXXX.npz``
    file under ``dataset.data_dir`` is read in turn; otherwise
    ``dataset.labels`` is counted in a single pass. Progress is reported with
    a tqdm progress bar (one tick per shard).

    Args:
        dataset: Preprocessed dataset to scan.
        max_shards: Upper bound on the number of shards to read; values
            ``<= 0`` mean "read every shard".

    Returns:
        A ``(counter, total)`` tuple. ``counter`` maps each label id to its
        number of occurrences (plain Python ``int`` keys and values) and
        ``total`` is the number of labels that were counted.
    """
    logger.info("正在统计 labels 分布...")
    counter = Counter()
    total = 0

    num_shards = dataset._num_shards if dataset._is_sharded else 1
    effective_shards = min(num_shards, max_shards) if max_shards > 0 else num_shards

    for shard_idx in tqdm(range(effective_shards), desc="统计 labels", unit="shard"):
        if dataset._is_sharded:
            shard_path = dataset.data_dir / f"shard_{shard_idx:06d}.npz"
            # An NpzFile reads member arrays lazily on key access, so only
            # "labels" is loaded here. The context manager closes the archive
            # instead of leaving one open file handle per shard.
            with np.load(shard_path) as shard:
                labels = shard["labels"].astype(np.int64)
        else:
            labels = dataset.labels[:].astype(np.int64)

        unique, counts = np.unique(labels, return_counts=True)
        # tolist() converts NumPy scalars so the Counter holds plain ints.
        counter.update(dict(zip(unique.tolist(), counts.tolist())))
        total += len(labels)

    return counter, total
|
||
|
||
|
||
def decode_sample(sample: dict, tokenizer, query_engine: QueryEngine) -> dict:
    """Convert one preprocessed sample into human-readable fields.

    The returned dict has the keys ``context`` and ``suffix`` (decoded text of
    the token-type-0 tokens before the first token-type-1 position, and of all
    token-type-1 tokens, both with special tokens skipped), ``pinyin``,
    ``label_id``, ``label_char``, ``label_count``, ``history`` and
    ``full_tokens`` (the whole input decoded with special tokens kept).
    """

    def as_list(value):
        # Array-like values expose ``tolist``; anything else is used as-is.
        return value.tolist() if hasattr(value, "tolist") else value

    input_ids = as_list(sample["input_ids"])
    token_type_ids = as_list(sample["token_type_ids"])
    raw_label = sample["labels"]
    label_id = raw_label.item() if hasattr(raw_label, "item") else raw_label
    history_ids = as_list(sample["history_slot_ids"])
    pinyin_ids = as_list(sample["pinyin_ids"])

    # Full token text, special tokens included.
    full_text = tokenizer.decode(input_ids, skip_special_tokens=False)

    # Sentence B begins at the first position whose token type is 1.
    first_b = next(
        (pos for pos, type_id in enumerate(token_type_ids) if type_id == 1),
        None,
    )
    if first_b is None:
        sent_a_ids, sent_b_ids = input_ids, []
    else:
        sent_a_ids = [
            token
            for token, type_id in zip(input_ids[:first_b], token_type_ids[:first_b])
            if type_id == 0
        ]
        sent_b_ids = [
            token for token, type_id in zip(input_ids, token_type_ids) if type_id == 1
        ]

    context_text = tokenizer.decode(sent_a_ids, skip_special_tokens=True)
    suffix_text = ""
    if sent_b_ids:
        suffix_text = tokenizer.decode(sent_b_ids, skip_special_tokens=True)

    pinyin_str = decode_pinyin_ids(pinyin_ids)
    label_info = query_engine.query_by_id(label_id)
    history_str = decode_history(history_ids, query_engine)

    # Fall back to a placeholder when the label id is unknown to the engine.
    if label_info:
        label_char = f"{label_info.char}({label_info.pinyin})"
        label_count = label_info.count
    else:
        label_char = f"<ID:{label_id}>"
        label_count = 0

    return {
        "context": context_text,
        "suffix": suffix_text,
        "pinyin": pinyin_str,
        "label_id": label_id,
        "label_char": label_char,
        "label_count": label_count,
        "history": history_str,
        "full_tokens": full_text,
    }
|
||
|
||
|
||
def _percent(part, whole) -> float:
    """Return ``part / whole`` as a percentage, or 0.0 when ``whole`` is 0."""
    return part / whole * 100 if whole else 0.0


def _build_label_table(title, ranked_labels, total, query_engine, pct_spec):
    """Build a rich table with rank, id, text, count and share for each label.

    ``ranked_labels`` is an iterable of ``(label_id, count)`` pairs in display
    order; ``pct_spec`` is the format spec used for the percentage column.
    """
    table = Table(title=title, show_header=True, header_style="bold magenta")
    table.add_column("排名", style="cyan", width=6)
    table.add_column("ID", style="green", width=8)
    table.add_column("字符(拼音)", style="yellow", width=20)
    table.add_column("频次", style="red", width=12)
    table.add_column("占比", style="blue", width=10)

    for rank, (label_id, count) in enumerate(ranked_labels, 1):
        info = query_engine.query_by_id(label_id)
        label_str = f"{info.char}({info.pinyin})" if info else f"<ID:{label_id}>"
        table.add_row(
            str(rank),
            str(label_id),
            label_str,
            f"{count:,}",
            f"{_percent(count, total):{pct_spec}}%",
        )
    return table


def main():
    """Run the preprocessed-data quality report from the command line.

    Steps:
      1. Load the dataset, the ``QueryEngine`` and the packaged tokenizer.
      2. Print label statistics: totals, EOS share, labels that never appear,
         the most/least frequent labels and a frequency histogram.
      3. Decode a random sample of examples, write them to a CSV file and
         print a short preview of the first few.
    """
    console = Console()

    parser = argparse.ArgumentParser(description="预处理数据质量分析")
    parser.add_argument(
        "--data-dir",
        type=str,
        required=True,
        help="预处理数据目录(train/ 或 eval/)",
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=50,
        help="随机抽样的样本数量(默认50)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="CSV 输出文件路径(默认: <data-dir>/samples.csv)",
    )
    parser.add_argument(
        "--max-shards",
        type=int,
        default=0,
        help="统计 labels 时最多读取的分片数(0=全部)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="随机种子",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=30,
        help="显示出现次数最多和最少的标签数量",
    )

    args = parser.parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)

    if args.output is None:
        args.output = str(Path(args.data_dir) / "samples.csv")

    # Load the dataset.
    logger.info(f"加载数据集: {args.data_dir}")
    dataset = PreProcessedDataset(args.data_dir)
    console.print(f"[bold cyan]数据集: {len(dataset):,} 个样本[/bold cyan]")
    if dataset._is_sharded:
        console.print(
            f" 分片数: {dataset._num_shards}, 每分片: {dataset._shard_size:,} 样本"
        )
    console.print()

    # Load the QueryEngine.
    logger.info("加载 QueryEngine...")
    query_engine = QueryEngine()
    query_engine.load()

    # Load the tokenizer shipped inside the package assets.
    logger.info("加载 Tokenizer...")
    from importlib.resources import files as pkg_files

    from modelscope import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        Path(str(pkg_files(__package__))) / "assets" / "tokenizer"
    )

    # ====== 1. Label distribution analysis ======
    console.print("[bold yellow]====== Labels 分布分析 ======[/bold yellow]")
    counter, total = analyze_labels(dataset, max_shards=args.max_shards)

    # Vocabulary size as reported by the query engine; excludes EOS (id=0).
    vocab_size = len(query_engine._id_to_info)

    appeared_ids = set(counter.keys())
    all_ids = set(range(0, vocab_size + 1))  # +1 to include id=0 (EOS)
    # NOTE: when --max-shards limits the scan, this only reflects the shards
    # that were actually read.
    missing_ids = all_ids - appeared_ids

    eos_count = counter.get(0, 0)
    console.print(f"\n总样本数: {total:,}")
    console.print(f"词表大小: {vocab_size + 1:,} (含 EOS)")
    console.print(f"唯一标签数: {len(counter):,}")
    # _percent() avoids a ZeroDivisionError when the dataset is empty.
    console.print(
        f"EOS (id=0) 出现次数: {eos_count:,} ({_percent(eos_count, total):.2f}%)"
    )
    console.print(
        f"[bold red]未出现的标签数: {len(missing_ids):,} / {vocab_size + 1:,} ({_percent(len(missing_ids), vocab_size + 1):.2f}%)[/bold red]"
    )

    most_common = counter.most_common(args.top_k)
    # Always list the rarest label first, including when fewer than top_k
    # distinct labels exist.
    least_common = counter.most_common()[::-1][: args.top_k]

    # Table of the most frequent labels.
    console.print(
        _build_label_table(
            f"出现次数最多的 {args.top_k} 个标签",
            most_common,
            total,
            query_engine,
            ".3f",
        )
    )

    # Table of the least frequent labels.
    console.print(
        _build_label_table(
            f"出现次数最少的 {min(args.top_k, len(counter))} 个标签",
            least_common,
            total,
            query_engine,
            ".6f",
        )
    )

    # Overview of how many labels fall into each frequency range.
    table_dist = Table(
        title="频次分布概览", show_header=True, header_style="bold magenta"
    )
    table_dist.add_column("频次区间", style="cyan")
    table_dist.add_column("标签数", style="green")
    table_dist.add_column("占总标签数比例", style="yellow")

    bins = [
        (1, 10),
        (11, 100),
        (101, 1000),
        (1001, 10000),
        (10001, 100000),
        (100001, 1000000),
        (1000001, float("inf")),
    ]
    for lo, hi in bins:
        count_in_bin = sum(1 for c in counter.values() if lo <= c <= hi)
        if count_in_bin > 0:
            hi_str = str(int(hi)) if hi != float("inf") else "∞"
            table_dist.add_row(
                f"{lo}-{hi_str}",
                f"{count_in_bin:,}",
                f"{_percent(count_in_bin, len(counter)):.1f}%",
            )
    # Labels that never appear (share is relative to the full vocabulary).
    if len(missing_ids) > 0:
        table_dist.add_row(
            "未出现",
            f"{len(missing_ids):,}",
            f"{_percent(len(missing_ids), vocab_size + 1):.1f}%",
        )
    console.print(table_dist)

    # ====== 2. Random sampling, decoded and exported to CSV ======
    # Clamp to [0, len(dataset)] so a negative --num-samples cannot make
    # random.sample raise ValueError.
    num_samples = max(0, min(args.num_samples, len(dataset)))
    console.print(
        f"\n[bold yellow]====== 随机抽样还原 ({num_samples} 个样本) → {args.output} ======[/bold yellow]"
    )

    indices = random.sample(range(len(dataset)), num_samples)

    csv_path = Path(args.output)
    csv_path.parent.mkdir(parents=True, exist_ok=True)

    csv_headers = [
        "index",
        "pinyin",
        "label_char",
        "label_id",
        "label_count",
        "context",
        "suffix",
        "history",
        "full_tokens",
    ]

    preview_limit = 5
    # The first few decoded samples are kept in memory so the preview below
    # does not need to re-read and re-parse the CSV file.
    preview = []

    with open(csv_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(csv_headers)

        for idx in tqdm(indices, desc="解码样本", unit="sample"):
            decoded = decode_sample(dataset[idx], tokenizer, query_engine)
            writer.writerow(
                [
                    idx,
                    decoded["pinyin"],
                    decoded["label_char"],
                    decoded["label_id"],
                    decoded["label_count"],
                    decoded["context"],
                    decoded["suffix"],
                    decoded["history"],
                    decoded["full_tokens"],
                ]
            )
            if len(preview) < preview_limit:
                preview.append(decoded)

    console.print(
        f"[bold green]✓ 已导出 {num_samples} 个样本到 {csv_path}[/bold green]"
    )

    # Print a short summary of the first few samples.
    console.print(f"\n[bold cyan]前 {min(5, num_samples)} 个样本概览:[/bold cyan]")
    for position, decoded in enumerate(preview, 1):
        # markup=False: decoded text is arbitrary data and may contain
        # square brackets that rich would otherwise parse as markup tags.
        console.print(
            f" [{position}] 拼音={decoded['pinyin']} 目标={decoded['label_char']} "
            f"上下文={decoded['context'][:50]}...",
            markup=False,
        )

    console.print("\n[bold green]分析完成[/bold green]")
|
||
|
||
|
||
# Alias for the CLI entry point.
# NOTE(review): presumably referenced by an external entry-point declaration; confirm before renaming.
app = main


if __name__ == "__main__":
    # Run the analysis when the module is executed directly.
    main()
|