444 lines
14 KiB
Python
444 lines
14 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
缺失字符补充工具
|
||
|
||
步骤 1: find-missing — 扫描已预处理数据,找出从未出现的 label ID,输出 JSON
|
||
步骤 2: generate-template — 根据 JSON 生成 JSONL 占位文件,供用户手动填入包含缺失字的真实文本
|
||
步骤 3: preprocess-supplement — 将填好的 JSONL 文本预处理为 .npz 分片,输出到独立目录
|
||
|
||
用法:
|
||
python -m model.supplement_missing find-missing \
|
||
--preprocessed-dir ./preprocessed/train \
|
||
--output missing_chars.json
|
||
|
||
python -m model.supplement_missing generate-template \
|
||
--missing-chars missing_chars.json \
|
||
--output supplement_texts.jsonl
|
||
|
||
python -m model.supplement_missing preprocess-supplement \
|
||
--missing-chars missing_chars.json \
|
||
--supplement-texts supplement_texts.jsonl \
|
||
--output-dir ./preprocessed/supplement \
|
||
--num-samples 100000
|
||
"""
|
||
|
||
import argparse
|
||
import json
|
||
from pathlib import Path
|
||
from typing import Set
|
||
|
||
import numpy as np
|
||
import torch
|
||
from loguru import logger
|
||
from rich.console import Console
|
||
from rich.table import Table
|
||
from torch.utils.data import DataLoader
|
||
from tqdm import tqdm
|
||
|
||
from .dataset import PinyinInputDataset
|
||
from .preprocess import collect_samples
|
||
from .query import QueryEngine
|
||
from .trainer import preprocess_collate_fn, worker_init_fn
|
||
|
||
|
||
def scan_labels(preprocessed_dir: Path) -> Set[int]:
    """Collect every label ID that appears in the preprocessed ``.npz`` shards.

    Args:
        preprocessed_dir: Directory holding the ``shard_*.npz`` files; each
            shard must contain a ``labels`` array.

    Returns:
        The set of all distinct label values found across all shards, as
        Python ints. When no shard file is found, a warning is logged and an
        empty set is returned.
    """
    appeared: Set[int] = set()

    shard_files = sorted(preprocessed_dir.glob("shard_*.npz"))
    if not shard_files:
        logger.warning(f"未找到 .npz 分片文件: {preprocessed_dir}")
        return appeared

    for shard_path in tqdm(shard_files, desc="扫描分片", unit="shard"):
        # Use NpzFile as a context manager so the underlying zip handle is
        # closed deterministically instead of relying on `del` + refcounting.
        with np.load(shard_path) as data:
            labels = data["labels"]
        # np.unique flattens its input (axis=None), so a trailing singleton
        # axis needs no squeeze. Casting only the small array of unique values
        # avoids a full-size int64 copy of every shard's labels.
        unique_ids = np.unique(labels).astype(np.int64)
        appeared.update(unique_ids.tolist())

    return appeared
|
||
|
||
|
||
def cmd_find_missing(args):
    """Sub-command ``find-missing``: list vocabulary IDs that never occur as labels.

    Reads ``metadata.json`` and every ``shard_*.npz`` under
    ``args.preprocessed_dir``, then compares the label IDs seen there with all
    IDs known to ``QueryEngine``. The unseen entries are written to
    ``args.output`` as JSON and shown in a summary table. ID 0 (reported as
    EOS) is excluded, as are IDs for which ``query_by_id`` returns ``None``.

    Args:
        args: Parsed CLI namespace with ``preprocessed_dir`` and ``output``.
    """
    console = Console()
    data_dir = Path(args.preprocessed_dir)

    if not data_dir.exists():
        console.print(f"[bold red]目录不存在: {data_dir}[/bold red]")
        return

    meta_file = data_dir / "metadata.json"
    if not meta_file.exists():
        console.print(f"[bold red]未找到 metadata.json: {meta_file}[/bold red]")
        return

    with open(meta_file, "r", encoding="utf-8") as fh:
        meta = json.load(fh)
    console.print(
        f"[bold cyan]预处理数据: {meta['num_samples']:,} 样本, {meta['num_shards']} 分片[/bold cyan]"
    )

    console.print("[bold cyan]扫描 labels...[/bold cyan]")
    seen_ids = scan_labels(data_dir)

    console.print("[bold cyan]加载 QueryEngine...[/bold cyan]")
    engine = QueryEngine()
    engine.load()

    vocab_ids = set(engine._id_to_info.keys())
    absent_ids = vocab_ids - seen_ids

    # Resolve absent IDs in ascending order, skipping ID 0 (EOS) and any ID
    # the engine cannot resolve.
    resolved = (
        engine.query_by_id(label_id)
        for label_id in sorted(absent_ids)
        if label_id != 0
    )
    records = [
        {
            "id": info.id,
            "char": info.char,
            "pinyin": info.pinyin,
            "count": info.count,
        }
        for info in resolved
        if info is not None
    ]

    out_file = Path(args.output)
    out_file.parent.mkdir(parents=True, exist_ok=True)
    with open(out_file, "w", encoding="utf-8") as fh:
        json.dump(
            {"missing_count": len(records), "missing_chars": records},
            fh,
            ensure_ascii=False,
            indent=2,
        )

    console.print("\n[bold green]=== 扫描完成 ===[/bold green]")
    console.print(f"词表大小: {len(vocab_ids):,} (含 EOS)")
    console.print(f"已出现标签: {len(seen_ids):,}")
    console.print(
        f"[bold red]缺失标签: {len(absent_ids):,}[/bold red] (其中非 EOS: {len(records)})"
    )

    if records:
        summary = Table(
            title=f"缺失字符 (共 {len(records)} 个)",
            show_header=True,
            header_style="bold magenta",
        )
        column_specs = (
            ("ID", "cyan", 8),
            ("字符", "yellow", 6),
            ("拼音", "green", 12),
            ("语料频次", "red", 12),
        )
        for heading, colour, col_width in column_specs:
            summary.add_column(heading, style=colour, width=col_width)
        for rec in records:
            summary.add_row(
                str(rec["id"]),
                rec["char"],
                rec["pinyin"],
                f"{rec['count']:,}",
            )
        console.print(summary)

    console.print(f"\n已输出到: {out_file}")
|
||
|
||
|
||
def cmd_generate_template(args):
    """Sub-command ``generate-template``: write a JSONL file of placeholder texts.

    Reads the JSON produced by ``find-missing`` from ``args.missing_chars``.
    For each missing character it writes ``args.num_entries`` records of the
    form ``{"text": "<placeholder>"}`` to ``args.output``. The user is expected
    to replace each placeholder by hand with real text containing that
    character.

    Args:
        args: Parsed CLI namespace with ``missing_chars``, ``output`` and
            ``num_entries``.
    """
    console = Console()

    missing_path = Path(args.missing_chars)
    if not missing_path.exists():
        console.print(f"[bold red]文件不存在: {missing_path}[/bold red]")
        return

    with open(missing_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    missing_chars = data.get("missing_chars", [])
    if not missing_chars:
        console.print("[bold green]没有缺失字符,无需生成模板[/bold green]")
        return

    num_entries = args.num_entries
    # A non-positive count used to print a zero/negative total and then report
    # success after writing an empty template file; reject it up front.
    if num_entries < 1:
        console.print(
            f"[bold red]--num-entries 必须为正整数,当前为: {num_entries}[/bold red]"
        )
        return

    total_lines = len(missing_chars) * num_entries

    console.print(f"[bold cyan]缺失字符数: {len(missing_chars)}[/bold cyan]")
    console.print(f"[bold cyan]每字符模板数: {num_entries}[/bold cyan]")
    console.print(f"[bold cyan]总模板行数: {total_lines}[/bold cyan]")

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, "w", encoding="utf-8") as f:
        for entry in missing_chars:
            for i in range(num_entries):
                # ensure_ascii=False keeps the CJK characters readable for manual editing.
                line = json.dumps(
                    {"text": f"请在这里输入包含「{entry['char']}」字的第{i + 1}条文本"},
                    ensure_ascii=False,
                )
                f.write(line + "\n")

    console.print(f"[bold green]模板已生成: {output_path}[/bold green]")
    console.print(
        f"共 {total_lines} 条({len(missing_chars)} 字符 × {num_entries} 条/字符),"
        f"请手动编辑该文件,将占位文本替换为包含对应字符的真实文本。"
    )
|
||
|
||
|
||
def _count_placeholder_lines(path: Path) -> int:
    """Return how many JSONL records in *path* still hold a ``generate-template`` placeholder.

    A record counts when it is a JSON object whose ``text`` value starts with
    the fixed prefix that ``generate-template`` writes. Blank lines and lines
    that are not valid JSON are skipped and left for the dataset loader to
    report.
    """
    placeholder_prefix = "请在这里输入包含「"
    hits = 0
    # errors="replace": this is a best-effort pre-check and must not add a new
    # failure mode for files that the dataset loader would handle (or reject) itself.
    with open(path, "r", encoding="utf-8", errors="replace") as fh:
        for raw_line in fh:
            raw_line = raw_line.strip()
            if not raw_line:
                continue
            try:
                record = json.loads(raw_line)
            except json.JSONDecodeError:
                continue
            if isinstance(record, dict) and str(record.get("text", "")).startswith(
                placeholder_prefix
            ):
                hits += 1
    return hits


def cmd_preprocess_supplement(args):
    """Sub-command ``preprocess-supplement``: turn filled-in JSONL text into ``.npz`` shards.

    Steps:
      1. Load the missing characters from ``args.missing_chars``. The target
         label set is their IDs plus ID 0 (EOS).
      2. Check that ``args.supplement_texts`` exists. Warn if it still contains
         unedited ``generate-template`` placeholder lines, because those would
         be turned into samples like any other text.
      3. Parse the weight options. Build a ``PinyinInputDataset`` over the
         JSONL, passing the set as ``target_labels`` (presumably so only
         samples with those labels are kept — TODO confirm in the dataset).
      4. Write up to ``args.num_samples`` samples into ``args.output_dir`` via
         ``collect_samples``.

    Args:
        args: Parsed CLI namespace with ``missing_chars``, ``supplement_texts``,
            ``output_dir``, ``num_samples``, ``batch_size``, ``num_workers``,
            ``max_seq_length``, ``seed``, ``shard_size``, ``py_style_weight``
            and ``length_weights``.
    """
    console = Console()

    # Load the missing characters.
    missing_path = Path(args.missing_chars)
    if not missing_path.exists():
        console.print(f"[bold red]文件不存在: {missing_path}[/bold red]")
        return

    with open(missing_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    missing_chars = data.get("missing_chars", [])
    if not missing_chars:
        console.print("[bold green]没有缺失字符,无需处理[/bold green]")
        return

    target_labels = {entry["id"] for entry in missing_chars}
    target_labels.add(0)  # also keep EOS

    # Validate the supplement text source before the expensive setup below.
    texts_spec = args.supplement_texts
    texts_path = Path(texts_spec)
    # The value is forwarded as `data_files`, which the dataset loader may
    # accept as a glob pattern; only fail fast for a plain path that is absent.
    looks_like_pattern = any(ch in texts_spec for ch in "*?[")
    if not texts_path.exists() and not looks_like_pattern:
        console.print(f"[bold red]文件不存在: {texts_path}[/bold red]")
        return
    if texts_path.is_file():
        pending = _count_placeholder_lines(texts_path)
        if pending:
            console.print(
                f"[bold yellow]警告[/bold yellow]: {texts_path} 中仍有 {pending} 行未替换的模板占位文本,"
                "这些占位文本也会被当作语料生成样本"
            )

    # Parse the weight options; malformed specs ("9,,1", "1-10") used to
    # surface as a raw traceback.
    try:
        py_style_weight = tuple(int(x) for x in args.py_style_weight.split(","))
        length_weights = {
            int(k): int(v)
            for k, v in (item.split(":") for item in args.length_weights.split(","))
        }
    except ValueError as err:
        logger.error(f"权重参数格式错误: {err}")
        return

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Iteration cap handed to the dataset; the 5x headroom over the target is
    # presumably meant to absorb samples dropped by target-label filtering —
    # TODO confirm.
    max_iter = args.num_samples * 5
    num_workers = args.num_workers

    console.print("[bold cyan]=== 补充数据预处理 ===[/bold cyan]")
    console.print(f"补充文本: {args.supplement_texts}")
    console.print(f"缺失字符数: {len(missing_chars)}")
    console.print(f"目标样本: {args.num_samples:,}")
    console.print(f"输出目录: {output_dir}")
    console.print(f"Worker 数: {num_workers}")
    console.print()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    console.print("[bold cyan]创建补充数据集...[/bold cyan]")
    dataset = PinyinInputDataset(
        data_path="json",
        max_workers=num_workers,
        max_iter_length=max_iter,
        max_seq_length=args.max_seq_length,
        text_field="text",
        py_style_weight=py_style_weight,
        shuffle_buffer_size=100,
        length_weights=length_weights,
        data_kwargs={
            "data_files": args.supplement_texts,
            "streaming": False,
        },
        target_labels=target_labels,
    )

    dataloader_kwargs = {
        "batch_size": args.batch_size,
        "num_workers": num_workers,
        "pin_memory": False,
        "worker_init_fn": worker_init_fn,
        "collate_fn": preprocess_collate_fn(args.max_seq_length),
    }
    # prefetch_factor / persistent_workers are only valid with worker processes.
    if num_workers > 0:
        dataloader_kwargs["prefetch_factor"] = 2
        dataloader_kwargs["persistent_workers"] = True

    dataloader = DataLoader(dataset, **dataloader_kwargs)

    logger.info("开始收集补充数据...")
    count = collect_samples(
        dataloader,
        args.num_samples,
        output_dir,
        "supplement",
        args.max_seq_length,
        args.shard_size,
    )

    if count < args.num_samples:
        logger.warning(f"补充样本不足: 目标 {args.num_samples}, 实际 {count}")

    console.print("\n[bold green]=== 补充预处理完成 ===[/bold green]")
    console.print(f"生成样本: {count:,}")
    console.print(f"输出目录: {output_dir}")

    # NOTE(review): this sums every .npz in output_dir, including files left
    # over from earlier runs into the same directory.
    total_size = sum(
        f.stat().st_size for f in output_dir.iterdir() if f.suffix == ".npz"
    )
    console.print(f"总大小: {total_size / (1024**3):.2f} GB (compressed)")
    console.print()
    console.print(
        "[bold yellow]提示[/bold yellow]: 请检查补充数据质量,清洗无误后手动将 shard_*.npz 合并到 train/ 目录并更新 metadata.json"
    )
|
||
|
||
|
||
def main():
    """Entry point of the ``supplement_missing`` command-line tool.

    Builds an ``argparse`` parser with three sub-commands (``find-missing``,
    ``generate-template``, ``preprocess-supplement``) and parses the command
    line. The namespace is then passed to the matching ``cmd_*`` function.
    When no sub-command is given, the top-level help is printed and the
    function returns.
    """
    usage_epilog = """
子命令:
find-missing 扫描已预处理数据,找出从未出现的 label ID
generate-template 根据缺失字符 JSON 生成 JSONL 占位文件
preprocess-supplement 将填好的 JSONL 预处理为 .npz 分片(独立目录)

示例:
python -m model.supplement_missing find-missing \\
--preprocessed-dir ./preprocessed/train \\
--output missing_chars.json

python -m model.supplement_missing generate-template \\
--missing-chars missing_chars.json \\
--output supplement_texts.jsonl

python -m model.supplement_missing preprocess-supplement \\
--missing-chars missing_chars.json \\
--supplement-texts supplement_texts.jsonl \\
--output-dir ./preprocessed/supplement \\
--num-samples 100000
"""
    parser = argparse.ArgumentParser(
        description="缺失字符补充工具",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_epilog,
    )
    subparsers = parser.add_subparsers(dest="command", help="子命令")

    # Shared by generate-template and preprocess-supplement; add_argument only
    # reads the kwargs, so reusing the same dict is safe.
    missing_chars_opt = (
        "--missing-chars",
        {
            "type": str,
            "required": True,
            "help": "缺失字符 JSON 文件路径(由 find-missing 生成)",
        },
    )

    # (sub-command, its help text, [(flag, add_argument kwargs), ...]).
    # Options are registered in list order, which fixes their order in --help.
    command_table = [
        (
            "find-missing",
            "扫描预处理数据,找出缺失标签",
            [
                (
                    "--preprocessed-dir",
                    {
                        "type": str,
                        "required": True,
                        "help": "预处理数据目录(包含 shard_*.npz 和 metadata.json)",
                    },
                ),
                (
                    "--output",
                    {
                        "type": str,
                        "default": "missing_chars.json",
                        "help": "输出 JSON 文件路径(默认: missing_chars.json)",
                    },
                ),
            ],
        ),
        (
            "generate-template",
            "生成补充文本模板",
            [
                missing_chars_opt,
                (
                    "--output",
                    {
                        "type": str,
                        "default": "supplement_texts.jsonl",
                        "help": "输出 JSONL 文件路径(默认: supplement_texts.jsonl)",
                    },
                ),
                (
                    "--num-entries",
                    {
                        "type": int,
                        "default": 3,
                        "help": "每个缺失字符生成的模板条数(默认: 3)",
                    },
                ),
            ],
        ),
        (
            "preprocess-supplement",
            "将 JSONL 预处理为 .npz 分片",
            [
                missing_chars_opt,
                (
                    "--supplement-texts",
                    {
                        "type": str,
                        "required": True,
                        "help": "已填写的补充文本 JSONL 文件路径",
                    },
                ),
                (
                    "--output-dir",
                    {
                        "type": str,
                        "required": True,
                        "help": "输出目录(独立目录,不会覆盖已有数据)",
                    },
                ),
                (
                    "--num-samples",
                    {"type": int, "required": True, "help": "目标样本数量"},
                ),
                (
                    "--batch-size",
                    {"type": int, "default": 128, "help": "批大小(默认: 128)"},
                ),
                (
                    "--num-workers",
                    {
                        "type": int,
                        "default": 0,
                        "help": "DataLoader worker 数量。本地 JSONL 小文件建议 0(默认: 0)",
                    },
                ),
                (
                    "--max-seq-length",
                    {"type": int, "default": 128, "help": "最大序列长度(默认: 128)"},
                ),
                (
                    "--seed",
                    {"type": int, "default": 42, "help": "随机种子(默认: 42)"},
                ),
                (
                    "--shard-size",
                    {
                        "type": int,
                        "default": 5_000_000,
                        "help": "分片大小(样本数),控制内存峰值(默认: 500万)",
                    },
                ),
                (
                    "--py-style-weight",
                    {
                        "type": str,
                        "default": "9,2,1",
                        "help": "拼音风格权重(逗号分隔,默认: 9,2,1)",
                    },
                ),
                (
                    "--length-weights",
                    {
                        "type": str,
                        "default": "1:10,2:50,3:50,4:40,5:15,6:10,7:5,8:2",
                        "help": "词长权重(默认: 1:10,2:50,3:50,4:40,5:15,6:10,7:5,8:2)",
                    },
                ),
            ],
        ),
    ]

    for command_name, command_help, option_specs in command_table:
        sub_parser = subparsers.add_parser(command_name, help=command_help)
        for flag, option_kwargs in option_specs:
            sub_parser.add_argument(flag, **option_kwargs)

    args = parser.parse_args()

    if args.command is None:
        parser.print_help()
        return

    # Handlers are looked up only after a sub-command was chosen.
    handlers = {
        "find-missing": cmd_find_missing,
        "generate-template": cmd_generate_template,
        "preprocess-supplement": cmd_preprocess_supplement,
    }
    handler = handlers.get(args.command)
    if handler is not None:
        handler(args)
|
||
|
||
|
||
# Module-level alias of main(); presumably referenced by an external
# entry-point configuration — TODO confirm.
app = main

if __name__ == "__main__":
    app()
|