SUimeModelTraner/src/model/inspect_preprocessed.py

402 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
预处理数据质量分析脚本
功能:
1. 统计 labels 的分布(出现次数、比例,最大/最小,未出现标签数)
2. 随机抽样还原为人类可读文本,导出为 CSV 文件
用法:
python -m model.inspect_preprocessed --data-dir /path/to/preprocessed/train
python -m model.inspect_preprocessed --data-dir /path/to/preprocessed/train --num-samples 50 --output samples.csv
"""
import argparse
import csv
import json
import random
from collections import Counter
from pathlib import Path
import numpy as np
from loguru import logger
from rich.console import Console
from rich.table import Table
from tqdm import tqdm
from .char_info import CharInfo
from .dataset import CHAR_TO_ID
from .preprocessed_dataset import PreProcessedDataset
from .query import QueryEngine
# Reverse lookup table (ID -> character), the inverse of CHAR_TO_ID.
# If several characters shared one ID, the last one iterated would win.
ID_TO_CHAR = {char_id: char for char, char_id in CHAR_TO_ID.items()}
def decode_pinyin_ids(pinyin_ids: list) -> str:
    """Decode a sequence of pinyin character IDs back into a pinyin string.

    Decoding stops at the first ID equal to 0 (treated as the end / padding
    marker). IDs that are not present in ``ID_TO_CHAR`` are rendered as "?".

    Args:
        pinyin_ids: Iterable of integer character IDs.

    Returns:
        The decoded pinyin string (empty if the first ID is 0).
    """

    def _chars_until_pad(ids):
        for char_id in ids:
            if char_id == 0:
                return
            yield ID_TO_CHAR.get(char_id, "?")

    return "".join(_chars_until_pad(pinyin_ids))
def decode_history(history_ids: list, query_engine: QueryEngine) -> str:
    """Render history slot IDs as readable text joined by " | ".

    Each slot becomes:
    * "<PAD>" when its ID is 0,
    * "char(pinyin)" when ``query_engine.query_by_id`` returns an entry,
    * "<ID:n>" when the lookup returns None.

    Args:
        history_ids: Iterable of history slot IDs.
        query_engine: Object whose ``query_by_id`` resolves an ID to an entry
            exposing ``char`` and ``pinyin`` (or None).

    Returns:
        The rendered slots joined with " | ".
    """

    def _render_slot(slot_id) -> str:
        if slot_id == 0:
            return "<PAD>"
        entry = query_engine.query_by_id(slot_id)
        if entry is None:
            return f"<ID:{slot_id}>"
        return f"{entry.char}({entry.pinyin})"

    return " | ".join(_render_slot(slot_id) for slot_id in history_ids)
def analyze_labels(dataset: PreProcessedDataset, max_shards: int = 0):
    """Count how often each label ID occurs in a preprocessed dataset.

    A tqdm progress bar is shown, one step per shard.

    Args:
        dataset: The preprocessed dataset. For sharded datasets, shards are
            read from ``dataset.data_dir / "shard_XXXXXX.npz"``. Otherwise the
            whole ``dataset.labels`` array is read in one step.
        max_shards: For sharded datasets, read at most this many shards.
            0 (or any non-positive value) means read all shards.

    Returns:
        ``(counter, total)``.
        ``counter`` maps each label ID (int) to its occurrence count (int).
        ``total`` is the number of label entries read.
    """
    logger.info("正在统计 labels 分布...")
    counter: Counter = Counter()
    total = 0
    num_shards = dataset._num_shards if dataset._is_sharded else 1
    effective_shards = min(num_shards, max_shards) if max_shards > 0 else num_shards
    for shard_idx in tqdm(range(effective_shards), desc="统计 labels", unit="shard"):
        if dataset._is_sharded:
            shard_path = dataset.data_dir / f"shard_{shard_idx:06d}.npz"
            # Only "labels" is needed.
            # Indexing the NpzFile lazily reads just that array, instead of
            # decompressing every array as dict(np.load(...)) does.
            # The context manager closes the file handle immediately.
            with np.load(shard_path) as shard:
                labels = shard["labels"].astype(np.int64)
        else:
            labels = dataset.labels[:].astype(np.int64)
        unique, counts = np.unique(labels, return_counts=True)
        # tolist() yields plain Python ints, so the Counter never holds numpy scalars.
        counter.update(dict(zip(unique.tolist(), counts.tolist())))
        # np.unique counts every element, so use .size to keep total consistent with the counter.
        total += int(labels.size)
    return counter, total
def decode_sample(sample: dict, tokenizer, query_engine: QueryEngine) -> dict:
    """Turn one preprocessed sample into a dict of human-readable fields.

    Array-like fields (anything exposing ``tolist`` / ``item``) are first
    converted to plain Python values. The input IDs are then split using
    ``token_type_ids``:
    * segment A is the prefix before the first position whose type is 1,
      keeping only type-0 tokens (or all input IDs if no type-1 token exists);
    * segment B is every type-1 token.

    Args:
        sample: Mapping with the keys ``input_ids``, ``token_type_ids``,
            ``labels``, ``history_slot_ids`` and ``pinyin_ids``.
        tokenizer: Object exposing ``decode(ids, skip_special_tokens=...)``.
        query_engine: Resolves the label and history IDs to entries.

    Returns:
        Dict with the keys ``context``, ``suffix``, ``pinyin``, ``label_id``,
        ``label_char``, ``label_count``, ``history`` and ``full_tokens``.
    """

    def _plain(key, converter):
        # Call e.g. value.tolist() / value.item() when the value supports it.
        value = sample[key]
        return getattr(value, converter)() if hasattr(value, converter) else value

    input_ids = _plain("input_ids", "tolist")
    token_type_ids = _plain("token_type_ids", "tolist")
    label_id = _plain("labels", "item")
    history_ids = _plain("history_slot_ids", "tolist")
    pinyin_ids = _plain("pinyin_ids", "tolist")

    # Full token text, special tokens included.
    full_tokens = tokenizer.decode(input_ids, skip_special_tokens=False)

    # Position of the first segment-B token, or None when there is no segment B.
    first_b = next((pos for pos, tt in enumerate(token_type_ids) if tt == 1), None)
    if first_b is None:
        segment_a, segment_b = input_ids, []
    else:
        segment_a = [
            tok
            for tok, tt in zip(input_ids[:first_b], token_type_ids[:first_b])
            if tt == 0
        ]
        segment_b = [tok for tok, tt in zip(input_ids, token_type_ids) if tt == 1]

    context_text = tokenizer.decode(segment_a, skip_special_tokens=True)
    if segment_b:
        suffix_text = tokenizer.decode(segment_b, skip_special_tokens=True)
    else:
        suffix_text = ""

    pinyin_str = decode_pinyin_ids(pinyin_ids)
    label_info = query_engine.query_by_id(label_id)
    history_str = decode_history(history_ids, query_engine)

    if label_info:
        label_char = f"{label_info.char}({label_info.pinyin})"
        label_count = label_info.count
    else:
        label_char = f"<ID:{label_id}>"
        label_count = 0

    return {
        "context": context_text,
        "suffix": suffix_text,
        "pinyin": pinyin_str,
        "label_id": label_id,
        "label_char": label_char,
        "label_count": label_count,
        "history": history_str,
        "full_tokens": full_tokens,
    }
def main():
console = Console()
parser = argparse.ArgumentParser(description="预处理数据质量分析")
parser.add_argument(
"--data-dir",
type=str,
required=True,
help="预处理数据目录train/ 或 eval/",
)
parser.add_argument(
"--num-samples",
type=int,
default=50,
help="随机抽样的样本数量默认50",
)
parser.add_argument(
"--output",
type=str,
default=None,
help="CSV 输出文件路径(默认: <data-dir>/samples.csv",
)
parser.add_argument(
"--max-shards",
type=int,
default=0,
help="统计 labels 时最多读取的分片数0=全部)",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="随机种子",
)
parser.add_argument(
"--top-k",
type=int,
default=30,
help="显示出现次数最多和最少的标签数量",
)
args = parser.parse_args()
random.seed(args.seed)
np.random.seed(args.seed)
if args.output is None:
args.output = str(Path(args.data_dir) / "samples.csv")
# 加载数据集
logger.info(f"加载数据集: {args.data_dir}")
dataset = PreProcessedDataset(args.data_dir)
console.print(f"[bold cyan]数据集: {len(dataset):,} 个样本[/bold cyan]")
if dataset._is_sharded:
console.print(
f" 分片数: {dataset._num_shards}, 每分片: {dataset._shard_size:,} 样本"
)
console.print()
# 加载 QueryEngine
logger.info("加载 QueryEngine...")
query_engine = QueryEngine()
query_engine.load()
# 加载 Tokenizer
logger.info("加载 Tokenizer...")
from importlib.resources import files as pkg_files
from modelscope import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
Path(str(pkg_files(__package__))) / "assets" / "tokenizer"
)
# ====== 1. Labels 分布分析 ======
console.print("[bold yellow]====== Labels 分布分析 ======[/bold yellow]")
counter, total = analyze_labels(dataset, max_shards=args.max_shards)
# 获取词表总大小
vocab_size = len(query_engine._id_to_info) # 不含 EOS (id=0)
appeared_ids = set(counter.keys())
all_ids = set(range(0, vocab_size + 1)) # +1 包含 id=0 (EOS)
missing_ids = all_ids - appeared_ids
console.print(f"\n总样本数: {total:,}")
console.print(f"词表大小: {vocab_size + 1:,} (含 EOS)")
console.print(f"唯一标签数: {len(counter):,}")
console.print(
f"EOS (id=0) 出现次数: {counter.get(0, 0):,} ({counter.get(0, 0) / total * 100:.2f}%)"
)
console.print(
f"[bold red]未出现的标签数: {len(missing_ids):,} / {vocab_size + 1:,} ({len(missing_ids) / (vocab_size + 1) * 100:.2f}%)[/bold red]"
)
most_common = counter.most_common(args.top_k)
least_common = (
counter.most_common()[: -args.top_k - 1 : -1]
if len(counter) > args.top_k
else counter.most_common()
)
# 最多标签表
table_top = Table(
title=f"出现次数最多的 {args.top_k} 个标签",
show_header=True,
header_style="bold magenta",
)
table_top.add_column("排名", style="cyan", width=6)
table_top.add_column("ID", style="green", width=8)
table_top.add_column("字符(拼音)", style="yellow", width=20)
table_top.add_column("频次", style="red", width=12)
table_top.add_column("占比", style="blue", width=10)
for rank, (label_id, count) in enumerate(most_common, 1):
info = query_engine.query_by_id(label_id)
label_str = f"{info.char}({info.pinyin})" if info else f"<ID:{label_id}>"
pct = count / total * 100
table_top.add_row(
str(rank), str(label_id), label_str, f"{count:,}", f"{pct:.3f}%"
)
console.print(table_top)
# 最少标签表
table_bottom = Table(
title=f"出现次数最少的 {min(args.top_k, len(counter))} 个标签",
show_header=True,
header_style="bold magenta",
)
table_bottom.add_column("排名", style="cyan", width=6)
table_bottom.add_column("ID", style="green", width=8)
table_bottom.add_column("字符(拼音)", style="yellow", width=20)
table_bottom.add_column("频次", style="red", width=12)
table_bottom.add_column("占比", style="blue", width=10)
for rank, (label_id, count) in enumerate(least_common, 1):
info = query_engine.query_by_id(label_id)
label_str = f"{info.char}({info.pinyin})" if info else f"<ID:{label_id}>"
pct = count / total * 100
table_bottom.add_row(
str(rank), str(label_id), label_str, f"{count:,}", f"{pct:.6f}%"
)
console.print(table_bottom)
# 频次分布概览
table_dist = Table(
title="频次分布概览", show_header=True, header_style="bold magenta"
)
table_dist.add_column("频次区间", style="cyan")
table_dist.add_column("标签数", style="green")
table_dist.add_column("占总标签数比例", style="yellow")
bins = [
(1, 10),
(11, 100),
(101, 1000),
(1001, 10000),
(10001, 100000),
(100001, 1000000),
(1000001, float("inf")),
]
for lo, hi in bins:
count_in_bin = sum(1 for c in counter.values() if lo <= c <= hi)
if count_in_bin > 0:
hi_str = str(int(hi)) if hi != float("inf") else ""
table_dist.add_row(
f"{lo}-{hi_str}",
f"{count_in_bin:,}",
f"{count_in_bin / len(counter) * 100:.1f}%",
)
# 未出现
if len(missing_ids) > 0:
table_dist.add_row(
"未出现",
f"{len(missing_ids):,}",
f"{len(missing_ids) / (vocab_size + 1) * 100:.1f}%",
)
console.print(table_dist)
# ====== 2. 随机抽样还原 → CSV ======
num_samples = min(args.num_samples, len(dataset))
console.print(
f"\n[bold yellow]====== 随机抽样还原 ({num_samples} 个样本) → {args.output} ======[/bold yellow]"
)
indices = random.sample(range(len(dataset)), num_samples)
csv_path = Path(args.output)
csv_path.parent.mkdir(parents=True, exist_ok=True)
csv_headers = [
"index",
"pinyin",
"label_char",
"label_id",
"label_count",
"context",
"suffix",
"history",
"full_tokens",
]
with open(csv_path, "w", encoding="utf-8", newline="") as f:
writer = csv.writer(f)
writer.writerow(csv_headers)
for i, idx in enumerate(tqdm(indices, desc="解码样本", unit="sample")):
sample = dataset[idx]
decoded = decode_sample(sample, tokenizer, query_engine)
writer.writerow(
[
idx,
decoded["pinyin"],
decoded["label_char"],
decoded["label_id"],
decoded["label_count"],
decoded["context"],
decoded["suffix"],
decoded["history"],
decoded["full_tokens"],
]
)
console.print(
f"[bold green]✓ 已导出 {num_samples} 个样本到 {csv_path}[/bold green]"
)
# 打印前 5 个样本的概要
console.print(f"\n[bold cyan]前 {min(5, num_samples)} 个样本概览:[/bold cyan]")
with open(csv_path, "r", encoding="utf-8") as f:
reader = csv.DictReader(f)
for i, row in enumerate(reader):
if i >= 5:
break
console.print(
f" [{i + 1}] 拼音={row['pinyin']} 目标={row['label_char']} "
f"上下文={row['context'][:50]}..."
)
console.print("\n[bold green]分析完成[/bold green]")
app = main
if __name__ == "__main__":
main()