SUimeModelTraner/src/model/supplement_missing.py

444 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
缺失字符补充工具
步骤 1: find-missing — 扫描已预处理数据,找出从未出现的 label ID,输出 JSON
步骤 2: generate-template — 根据 JSON 生成 JSONL 占位文件,供用户手动填入包含缺失字的真实文本
步骤 3: preprocess-supplement — 将填好的 JSONL 文本预处理为 .npz 分片,输出到独立目录
用法:
python -m model.supplement_missing find-missing \
--preprocessed-dir ./preprocessed/train \
--output missing_chars.json
python -m model.supplement_missing generate-template \
--missing-chars missing_chars.json \
--output supplement_texts.jsonl
python -m model.supplement_missing preprocess-supplement \
--missing-chars missing_chars.json \
--supplement-texts supplement_texts.jsonl \
--output-dir ./preprocessed/supplement \
--num-samples 100000
"""
import argparse
import json
from pathlib import Path
from typing import Set
import numpy as np
import torch
from loguru import logger
from rich.console import Console
from rich.table import Table
from torch.utils.data import DataLoader
from tqdm import tqdm
from .dataset import PinyinInputDataset
from .preprocess import collect_samples
from .query import QueryEngine
from .trainer import preprocess_collate_fn, worker_init_fn
def scan_labels(preprocessed_dir: Path) -> Set[int]:
    """Collect every label ID that occurs in the ``shard_*.npz`` files of a directory.

    Args:
        preprocessed_dir: Directory holding ``shard_*.npz`` files, each of which
            must contain a ``labels`` array. A ``str`` path is accepted too.

    Returns:
        The set of distinct label IDs as Python ``int`` values. Empty when no
        shard file matches, in which case a warning is logged.
    """
    appeared: Set[int] = set()
    shard_files = sorted(Path(preprocessed_dir).glob("shard_*.npz"))
    if not shard_files:
        logger.warning(f"未找到 .npz 分片文件: {preprocessed_dir}")
        return appeared
    for shard_path in tqdm(shard_files, desc="扫描分片", unit="shard"):
        # NpzFile keeps the underlying archive open; the context manager closes
        # it deterministically instead of relying on garbage collection.
        with np.load(shard_path) as data:
            labels = data["labels"]
        # np.unique flattens its input, so a trailing singleton axis needs no
        # special handling; deduplicate first, then convert to Python ints.
        appeared.update(np.unique(labels).astype(np.int64).tolist())
    return appeared
def cmd_find_missing(args):
    """Find vocabulary IDs that never occur as a label in the preprocessed shards.

    Reads ``metadata.json`` from ``args.preprocessed_dir`` for a summary line,
    collects the label IDs present in its ``shard_*.npz`` files, subtracts them
    from the ID set held by ``QueryEngine`` and writes the missing entries
    (skipping ID 0, which this tool treats as EOS) to ``args.output`` as::

        {"missing_count": N,
         "missing_chars": [{"id": ..., "char": ..., "pinyin": ..., "count": ...}, ...]}

    Input problems (missing directory, missing metadata, no labels scanned)
    are reported on the console and the function returns without writing.

    Args:
        args: argparse namespace with ``preprocessed_dir`` and ``output``.
    """
    console = Console()
    preprocessed_dir = Path(args.preprocessed_dir)
    if not preprocessed_dir.exists():
        console.print(f"[bold red]目录不存在: {preprocessed_dir}[/bold red]")
        return
    metadata_path = preprocessed_dir / "metadata.json"
    if not metadata_path.exists():
        console.print(f"[bold red]未找到 metadata.json: {metadata_path}[/bold red]")
        return
    with open(metadata_path, "r", encoding="utf-8") as f:
        metadata = json.load(f)
    console.print(
        f"[bold cyan]预处理数据: {metadata['num_samples']:,} 样本, {metadata['num_shards']} 分片[/bold cyan]"
    )

    console.print("[bold cyan]扫描 labels...[/bold cyan]")
    appeared = scan_labels(preprocessed_dir)
    if not appeared:
        # With nothing scanned, every vocabulary entry would be reported as
        # missing; stop instead of writing a misleading JSON file.
        console.print(
            f"[bold red]未在分片中扫描到任何标签, 已中止: {preprocessed_dir}[/bold red]"
        )
        return

    console.print("[bold cyan]加载 QueryEngine...[/bold cyan]")
    query_engine = QueryEngine()
    query_engine.load()
    # NOTE(review): relies on QueryEngine's private ``_id_to_info`` mapping;
    # switch to a public accessor if one exists.
    all_ids = set(query_engine._id_to_info.keys())
    missing_ids = all_ids - appeared

    missing_chars = []
    for mid in sorted(missing_ids):
        if mid == 0:  # ID 0 is EOS, not a character that needs supplementing
            continue
        info = query_engine.query_by_id(mid)
        if info is None:
            continue
        missing_chars.append(
            {
                "id": info.id,
                "char": info.char,
                "pinyin": info.pinyin,
                "count": info.count,
            }
        )

    result = {
        "missing_count": len(missing_chars),
        "missing_chars": missing_chars,
    }
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False, indent=2)

    console.print("\n[bold green]=== 扫描完成 ===[/bold green]")
    console.print(f"词表大小: {len(all_ids):,} (含 EOS)")
    console.print(f"已出现标签: {len(appeared):,}")
    console.print(
        f"[bold red]缺失标签: {len(missing_ids):,}[/bold red] (其中非 EOS: {len(missing_chars)})"
    )
    if missing_chars:
        table = Table(
            title=f"缺失字符 (共 {len(missing_chars)} 个)",
            show_header=True,
            header_style="bold magenta",
        )
        table.add_column("ID", style="cyan", width=8)
        table.add_column("字符", style="yellow", width=6)
        table.add_column("拼音", style="green", width=12)
        table.add_column("语料频次", style="red", width=12)
        for entry in missing_chars:
            table.add_row(
                str(entry["id"]),
                entry["char"],
                entry["pinyin"],
                f"{entry['count']:,}",
            )
        console.print(table)
    console.print(f"\n已输出到: {output_path}")
def cmd_generate_template(args):
    """Write a JSONL file of placeholder lines for the user to replace by hand.

    For each entry in the ``missing_chars`` list of the JSON written by
    ``find-missing`` (path ``args.missing_chars``), ``args.num_entries`` lines of
    the form ``{"text": "<placeholder naming the character>"}`` are written to
    ``args.output``. Nothing is written when the input file is absent or
    lists no missing characters.

    Args:
        args: argparse namespace with ``missing_chars``, ``output`` and
            ``num_entries``.
    """
    console = Console()
    source = Path(args.missing_chars)
    if not source.exists():
        console.print(f"[bold red]文件不存在: {source}[/bold red]")
        return
    with open(source, "r", encoding="utf-8") as handle:
        payload = json.load(handle)
    chars = payload.get("missing_chars", [])
    if not chars:
        console.print("[bold green]没有缺失字符,无需生成模板[/bold green]")
        return

    per_char = args.num_entries
    char_total = len(chars)
    line_total = char_total * per_char
    for message in (
        f"[bold cyan]缺失字符数: {char_total}[/bold cyan]",
        f"[bold cyan]每字符模板数: {per_char}[/bold cyan]",
        f"[bold cyan]总模板行数: {line_total}[/bold cyan]",
    ):
        console.print(message)

    destination = Path(args.output)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with open(destination, "w", encoding="utf-8") as out:
        # One JSON object per line; ordinals are 1-based within each character.
        out.writelines(
            json.dumps(
                {"text": f"请在这里输入包含「{item['char']}」字的第{ordinal}条文本"},
                ensure_ascii=False,
            )
            + "\n"
            for item in chars
            for ordinal in range(1, per_char + 1)
        )

    console.print(f"[bold green]模板已生成: {destination}[/bold green]")
    console.print(
        f"{line_total} 条({char_total} 字符 × {per_char} 条/字符),"
        f"请手动编辑该文件,将占位文本替换为包含对应字符的真实文本。"
    )
def _parse_py_style_weight(spec: str) -> tuple:
    """Parse a comma-separated integer list such as ``"9,2,1"`` into a tuple of ints.

    Blank items (e.g. from a trailing comma) are skipped.

    Raises:
        ValueError: if an item is not an integer.
    """
    return tuple(int(item) for item in spec.split(",") if item.strip())


def _parse_length_weights(spec: str) -> dict:
    """Parse ``"length:weight,..."`` such as ``"1:10,2:50"`` into ``{1: 10, 2: 50}``.

    Blank items (e.g. from a trailing comma) are skipped; a repeated length
    keeps the last weight.

    Raises:
        ValueError: if an item lacks ``:`` or either side is not an integer.
    """
    weights = {}
    for item in spec.split(","):
        if not item.strip():
            continue
        length, sep, weight = item.partition(":")
        if not sep:
            raise ValueError(f"期望 '长度:权重' 格式, 实际为 {item!r}")
        weights[int(length)] = int(weight)
    return weights


def cmd_preprocess_supplement(args):
    """Preprocess the filled-in supplement JSONL into ``.npz`` shards in a separate directory.

    Builds a ``PinyinInputDataset`` over ``args.supplement_texts`` restricted to
    the missing label IDs plus ID 0 (EOS), wraps it in a DataLoader and passes
    it to ``collect_samples``, targeting ``args.num_samples`` samples written to
    ``args.output_dir``. Refuses to run when the output directory already
    contains ``.npz`` files, so existing data is not overwritten. Input problems
    are reported on the console and the function returns early.

    Args:
        args: argparse namespace from the ``preprocess-supplement`` subcommand.
    """
    console = Console()

    # Load the missing characters produced by find-missing.
    missing_path = Path(args.missing_chars)
    if not missing_path.exists():
        console.print(f"[bold red]文件不存在: {missing_path}[/bold red]")
        return
    with open(missing_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    missing_chars = data.get("missing_chars", [])
    if not missing_chars:
        console.print("[bold green]没有缺失字符,无需处理[/bold green]")
        return
    target_labels = {entry["id"] for entry in missing_chars}
    target_labels.add(0)  # also keep EOS (ID 0)

    # Parse weight arguments; report malformed input instead of a traceback.
    try:
        py_style_weight = _parse_py_style_weight(args.py_style_weight)
        length_weights = _parse_length_weights(args.length_weights)
    except ValueError as exc:
        console.print(f"[bold red]权重参数格式错误: {exc}[/bold red]")
        return

    output_dir = Path(args.output_dir)
    # The CLI promises a separate directory that never overwrites existing
    # data; collect_samples presumably reuses shard file names, so refuse a
    # directory that already holds .npz files.
    existing = sorted(output_dir.glob("*.npz")) if output_dir.exists() else []
    if existing:
        console.print(
            f"[bold red]输出目录已包含 {len(existing)} 个 .npz 文件, 为避免覆盖请指定新的目录: {output_dir}[/bold red]"
        )
        return
    output_dir.mkdir(parents=True, exist_ok=True)

    max_iter = args.num_samples * 5
    num_workers = args.num_workers
    console.print("[bold cyan]=== 补充数据预处理 ===[/bold cyan]")
    console.print(f"补充文本: {args.supplement_texts}")
    console.print(f"缺失字符数: {len(missing_chars)}")
    console.print(f"目标样本: {args.num_samples:,}")
    console.print(f"输出目录: {output_dir}")
    console.print(f"Worker 数: {num_workers}")
    console.print()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    console.print("[bold cyan]创建补充数据集...[/bold cyan]")
    dataset = PinyinInputDataset(
        data_path="json",
        max_workers=num_workers,
        max_iter_length=max_iter,
        max_seq_length=args.max_seq_length,
        text_field="text",
        py_style_weight=py_style_weight,
        shuffle_buffer_size=100,
        length_weights=length_weights,
        data_kwargs={
            "data_files": args.supplement_texts,
            "streaming": False,
        },
        target_labels=target_labels,
    )
    dataloader_kwargs = {
        "batch_size": args.batch_size,
        "num_workers": num_workers,
        "pin_memory": False,
        "worker_init_fn": worker_init_fn,
        "collate_fn": preprocess_collate_fn(args.max_seq_length),
    }
    if num_workers > 0:
        # prefetch_factor / persistent_workers are only valid with workers.
        dataloader_kwargs["prefetch_factor"] = 2
        dataloader_kwargs["persistent_workers"] = True
    dataloader = DataLoader(dataset, **dataloader_kwargs)

    logger.info("开始收集补充数据...")
    count = collect_samples(
        dataloader,
        args.num_samples,
        output_dir,
        "supplement",
        args.max_seq_length,
        args.shard_size,
    )
    if count < args.num_samples:
        logger.warning(f"补充样本不足: 目标 {args.num_samples}, 实际 {count}")

    console.print("\n[bold green]=== 补充预处理完成 ===[/bold green]")
    console.print(f"生成样本: {count:,}")
    console.print(f"输出目录: {output_dir}")
    total_size = sum(
        f.stat().st_size for f in output_dir.iterdir() if f.suffix == ".npz"
    )
    console.print(f"总大小: {total_size / (1024**3):.2f} GB (compressed)")
    console.print()
    console.print(
        "[bold yellow]提示[/bold yellow]: 请检查补充数据质量,清洗无误后手动将 shard_*.npz 合并到 train/ 目录并更新 metadata.json"
    )
def _int_at_least(minimum: int):
    """Build an argparse ``type=`` callback accepting integers >= ``minimum``.

    Rejections raise ``argparse.ArgumentTypeError`` so argparse prints the
    message and exits with status 2.
    """

    def parse(value: str) -> int:
        try:
            number = int(value)
        except ValueError:
            raise argparse.ArgumentTypeError(f"无效的整数: {value!r}") from None
        if number < minimum:
            raise argparse.ArgumentTypeError(f"必须 >= {minimum}, 实际为 {number}")
        return number

    return parse


def main(argv=None):
    """Parse command-line arguments and run the selected subcommand.

    Args:
        argv: Argument list to parse. ``None`` (the default) lets argparse read
            ``sys.argv[1:]``, exactly as a call without arguments always did.

    Without a subcommand the help text is printed and nothing else happens.
    Out-of-range counts (e.g. ``--num-samples 0``) are rejected by argparse,
    which exits with status 2.
    """
    positive_int = _int_at_least(1)
    non_negative_int = _int_at_least(0)

    parser = argparse.ArgumentParser(
        description="缺失字符补充工具",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
子命令:
  find-missing            扫描已预处理数据,找出从未出现的 label ID
  generate-template       根据缺失字符 JSON 生成 JSONL 占位文件
  preprocess-supplement   将填好的 JSONL 预处理为 .npz 分片(独立目录)

示例:
  python -m model.supplement_missing find-missing \\
      --preprocessed-dir ./preprocessed/train \\
      --output missing_chars.json
  python -m model.supplement_missing generate-template \\
      --missing-chars missing_chars.json \\
      --output supplement_texts.jsonl
  python -m model.supplement_missing preprocess-supplement \\
      --missing-chars missing_chars.json \\
      --supplement-texts supplement_texts.jsonl \\
      --output-dir ./preprocessed/supplement \\
      --num-samples 100000
""",
    )
    subparsers = parser.add_subparsers(dest="command", help="子命令")

    # find-missing
    p_find = subparsers.add_parser("find-missing", help="扫描预处理数据,找出缺失标签")
    p_find.add_argument(
        "--preprocessed-dir",
        type=str,
        required=True,
        help="预处理数据目录(包含 shard_*.npz 和 metadata.json)",
    )
    p_find.add_argument(
        "--output",
        type=str,
        default="missing_chars.json",
        help="输出 JSON 文件路径(默认: missing_chars.json)",
    )

    # generate-template
    p_gen = subparsers.add_parser("generate-template", help="生成补充文本模板")
    p_gen.add_argument(
        "--missing-chars",
        type=str,
        required=True,
        help="缺失字符 JSON 文件路径(由 find-missing 生成)",
    )
    p_gen.add_argument(
        "--output",
        type=str,
        default="supplement_texts.jsonl",
        help="输出 JSONL 文件路径(默认: supplement_texts.jsonl)",
    )
    p_gen.add_argument(
        "--num-entries",
        type=positive_int,
        default=3,
        help="每个缺失字符生成的模板条数(默认: 3)",
    )

    # preprocess-supplement
    p_pre = subparsers.add_parser(
        "preprocess-supplement", help="将 JSONL 预处理为 .npz 分片"
    )
    p_pre.add_argument(
        "--missing-chars",
        type=str,
        required=True,
        help="缺失字符 JSON 文件路径(由 find-missing 生成)",
    )
    p_pre.add_argument(
        "--supplement-texts",
        type=str,
        required=True,
        help="已填写的补充文本 JSONL 文件路径",
    )
    p_pre.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="输出目录(独立目录,不会覆盖已有数据)",
    )
    p_pre.add_argument(
        "--num-samples",
        type=positive_int,
        required=True,
        help="目标样本数量",
    )
    p_pre.add_argument(
        "--batch-size",
        type=positive_int,
        default=128,
        help="批大小(默认: 128)",
    )
    p_pre.add_argument(
        "--num-workers",
        type=non_negative_int,
        default=0,
        help="DataLoader worker 数量。本地 JSONL 小文件建议 0(默认: 0)",
    )
    p_pre.add_argument(
        "--max-seq-length",
        type=positive_int,
        default=128,
        help="最大序列长度(默认: 128)",
    )
    p_pre.add_argument(
        "--seed",
        type=int,
        default=42,
        help="随机种子(默认: 42)",
    )
    p_pre.add_argument(
        "--shard-size",
        type=positive_int,
        default=5_000_000,
        help="分片大小(样本数),控制内存峰值(默认: 500万)",
    )
    p_pre.add_argument(
        "--py-style-weight",
        type=str,
        default="9,2,1",
        help="拼音风格权重(逗号分隔,默认: 9,2,1)",
    )
    p_pre.add_argument(
        "--length-weights",
        type=str,
        default="1:10,2:50,3:50,4:40,5:15,6:10,7:5,8:2",
        help="词长权重(默认: 1:10,2:50,3:50,4:40,5:15,6:10,7:5,8:2)",
    )

    args = parser.parse_args(argv)
    if args.command is None:
        parser.print_help()
        return

    # Built after parsing so that help/argument errors never touch the handlers.
    handlers = {
        "find-missing": cmd_find_missing,
        "generate-template": cmd_generate_template,
        "preprocess-supplement": cmd_preprocess_supplement,
    }
    handlers[args.command](args)
# Public alias of main(); presumably referenced by an external entry point
# (NOTE(review): verify before removing).
app = main

if __name__ == "__main__":
    app()