#!/usr/bin/env python3
"""
Missing-character supplement tool.

Step 1: find-missing — scan the preprocessed data, find the label IDs that
        never appear, and write them to a JSON file.
Step 2: generate-template — build a JSONL placeholder file from that JSON so
        the user can manually fill in real text containing the missing
        characters.
Step 3: preprocess-supplement — preprocess the filled-in JSONL text into .npz
        shards written to a separate directory.

Usage:
    python -m model.supplement_missing find-missing \
        --preprocessed-dir ./preprocessed/train \
        --output missing_chars.json

    python -m model.supplement_missing generate-template \
        --missing-chars missing_chars.json \
        --output supplement_texts.jsonl

    python -m model.supplement_missing preprocess-supplement \
        --missing-chars missing_chars.json \
        --supplement-texts supplement_texts.jsonl \
        --output-dir ./preprocessed/supplement \
        --num-samples 100000
"""

import argparse
import json
from pathlib import Path
from typing import Set

import numpy as np
import torch
from loguru import logger
from rich.console import Console
from rich.table import Table
from torch.utils.data import DataLoader
from tqdm import tqdm

from .dataset import PinyinInputDataset
from .preprocess import collect_samples
from .query import QueryEngine
from .trainer import preprocess_collate_fn, worker_init_fn


def scan_labels(preprocessed_dir: Path) -> Set[int]:
    """Collect every label ID stored in the ``shard_*.npz`` files of a directory.

    Args:
        preprocessed_dir: Directory searched (non-recursively) for files
            matching ``shard_*.npz``. Each shard must hold a ``labels`` array.

    Returns:
        The set of distinct label IDs found across all shards, as Python
        ints. Empty (after a logged warning) when no shard file matches.
    """
    appeared: Set[int] = set()
    shard_files = sorted(preprocessed_dir.glob("shard_*.npz"))
    if not shard_files:
        logger.warning(f"未找到 .npz 分片文件: {preprocessed_dir}")
        return appeared

    for shard_path in tqdm(shard_files, desc="扫描分片", unit="shard"):
        # np.load returns an NpzFile that keeps the archive open; the context
        # manager closes it deterministically instead of relying on GC.
        with np.load(shard_path) as data:
            labels = data["labels"].astype(np.int64)
        # np.unique flattens its input, so a trailing singleton axis does not
        # need to be squeezed away first.
        appeared.update(np.unique(labels).tolist())
    return appeared


def cmd_find_missing(args: argparse.Namespace) -> None:
    """Run ``find-missing``: report vocabulary IDs absent from the shards.

    Reads ``metadata.json`` from ``args.preprocessed_dir``, scans the shards
    with :func:`scan_labels`, subtracts the seen IDs from the IDs known to
    ``QueryEngine``, writes the result as JSON to ``args.output`` and prints a
    summary plus a table of the missing entries. Prints an error and returns
    early when the directory or its ``metadata.json`` does not exist.

    Args:
        args: Parsed CLI arguments providing ``preprocessed_dir`` and
            ``output``.
    """
    console = Console()

    preprocessed_dir = Path(args.preprocessed_dir)
    if not preprocessed_dir.exists():
        console.print(f"[bold red]目录不存在: {preprocessed_dir}[/bold red]")
        return

    metadata_path = preprocessed_dir / "metadata.json"
    if not metadata_path.exists():
        console.print(
            f"[bold red]未找到 metadata.json: {metadata_path}[/bold red]"
        )
        return

    with open(metadata_path, "r", encoding="utf-8") as f:
        metadata = json.load(f)

    console.print(
        f"[bold cyan]预处理数据: {metadata['num_samples']:,} 样本, {metadata['num_shards']} 分片[/bold cyan]"
    )

    console.print("[bold cyan]扫描 labels...[/bold cyan]")
    appeared = scan_labels(preprocessed_dir)

    console.print("[bold cyan]加载 QueryEngine...[/bold cyan]")
    query_engine = QueryEngine()
    query_engine.load()

    all_ids = set(query_engine._id_to_info)
    missing_ids = all_ids - appeared

    missing_chars = []
    for mid in sorted(missing_ids):
        # ID 0 is skipped here; the summary below reports it as EOS.
        if mid == 0:
            continue
        info = query_engine.query_by_id(mid)
        if info is not None:
            missing_chars.append(
                {
                    "id": info.id,
                    "char": info.char,
                    "pinyin": info.pinyin,
                    "count": info.count,
                }
            )

    result = {
        "missing_count": len(missing_chars),
        "missing_chars": missing_chars,
    }

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False, indent=2)

    console.print("\n[bold green]=== 扫描完成 ===[/bold green]")
    console.print(f"词表大小: {len(all_ids):,} (含 EOS)")
    console.print(f"已出现标签: {len(appeared):,}")
    console.print(
        f"[bold red]缺失标签: {len(missing_ids):,}[/bold red] (其中非 EOS: {len(missing_chars)})"
    )

    if missing_chars:
        table = Table(
            title=f"缺失字符 (共 {len(missing_chars)} 个)",
            show_header=True,
            header_style="bold magenta",
        )
        table.add_column("ID", style="cyan", width=8)
        table.add_column("字符", style="yellow", width=6)
        table.add_column("拼音", style="green", width=12)
        table.add_column("语料频次", style="red", width=12)

        for entry in missing_chars:
            table.add_row(
                str(entry["id"]),
                entry["char"],
                entry["pinyin"],
                f"{entry['count']:,}",
            )
        console.print(table)

    console.print(f"\n已输出到: {output_path}")


def cmd_generate_template(args: argparse.Namespace) -> None:
    """Run ``generate-template``: write a JSONL file of placeholder texts.

    For every entry under ``missing_chars`` in the JSON file at
    ``args.missing_chars``, writes ``args.num_entries`` lines of the form
    ``{"text": ...}`` to ``args.output``. Each placeholder names the
    character the user is expected to cover. Prints a message and returns
    early when the JSON file does not exist or lists no missing characters.

    Args:
        args: Parsed CLI arguments providing ``missing_chars``, ``output``
            and ``num_entries``.
    """
    console = Console()

    missing_path = Path(args.missing_chars)
    if not missing_path.exists():
        console.print(f"[bold red]文件不存在: {missing_path}[/bold red]")
        return

    with open(missing_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    missing_chars = data.get("missing_chars", [])
    if not missing_chars:
        console.print("[bold green]没有缺失字符,无需生成模板[/bold green]")
        return

    num_entries = args.num_entries
    total_lines = len(missing_chars) * num_entries

    console.print(f"[bold cyan]缺失字符数: {len(missing_chars)}[/bold cyan]")
    console.print(f"[bold cyan]每字符模板数: {num_entries}[/bold cyan]")
    console.print(f"[bold cyan]总模板行数: {total_lines}[/bold cyan]")

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, "w", encoding="utf-8") as f:
        for entry in missing_chars:
            for i in range(num_entries):
                line = json.dumps(
                    {"text": f"请在这里输入包含「{entry['char']}」字的第{i + 1}条文本"},
                    ensure_ascii=False,
                )
                f.write(line + "\n")

    console.print(f"[bold green]模板已生成: {output_path}[/bold green]")
    console.print(
        f"共 {total_lines} 条({len(missing_chars)} 字符 × {num_entries} 条/字符),"
        f"请手动编辑该文件,将占位文本替换为包含对应字符的真实文本。"
    )


def cmd_preprocess_supplement(args: argparse.Namespace) -> None:
    """Run ``preprocess-supplement``: turn the filled-in JSONL into shards.

    Builds a ``PinyinInputDataset`` over ``args.supplement_texts`` with
    ``target_labels`` set to the missing IDs plus ID 0, wraps it in a
    ``DataLoader`` and hands it to ``collect_samples``, which writes into
    ``args.output_dir``. Logs a warning when fewer than ``args.num_samples``
    samples were collected, then prints the sample count and the total size
    of the ``.npz`` files in the output directory. Prints a message and
    returns early when the missing-characters JSON does not exist or lists
    no missing characters.

    Args:
        args: Parsed CLI arguments of the ``preprocess-supplement``
            sub-command.
    """
    console = Console()

    # Load the missing characters produced by find-missing.
    missing_path = Path(args.missing_chars)
    if not missing_path.exists():
        console.print(f"[bold red]文件不存在: {missing_path}[/bold red]")
        return

    with open(missing_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    missing_chars = data.get("missing_chars", [])
    if not missing_chars:
        console.print("[bold green]没有缺失字符,无需处理[/bold green]")
        return

    target_labels = {entry["id"] for entry in missing_chars}
    target_labels.add(0)  # include EOS

    # Parse the string-encoded options: "9,2,1" -> (9, 2, 1) and
    # "1:10,2:50" -> {1: 10, 2: 50}.
    py_style_weight = tuple(int(x) for x in args.py_style_weight.split(","))
    length_weights = {
        int(k): int(v)
        for k, v in (item.split(":") for item in args.length_weights.split(","))
    }

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    max_iter = args.num_samples * 5
    num_workers = args.num_workers

    console.print("[bold cyan]=== 补充数据预处理 ===[/bold cyan]")
    console.print(f"补充文本: {args.supplement_texts}")
    console.print(f"缺失字符数: {len(missing_chars)}")
    console.print(f"目标样本: {args.num_samples:,}")
    console.print(f"输出目录: {output_dir}")
    console.print(f"Worker 数: {num_workers}")
    console.print()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    console.print("[bold cyan]创建补充数据集...[/bold cyan]")
    dataset = PinyinInputDataset(
        data_path="json",
        max_workers=num_workers,
        max_iter_length=max_iter,
        max_seq_length=args.max_seq_length,
        text_field="text",
        py_style_weight=py_style_weight,
        shuffle_buffer_size=100,
        length_weights=length_weights,
        data_kwargs={
            "data_files": args.supplement_texts,
            "streaming": False,
        },
        target_labels=target_labels,
    )

    dataloader_kwargs = {
        "batch_size": args.batch_size,
        "num_workers": num_workers,
        "pin_memory": False,
        "worker_init_fn": worker_init_fn,
        "collate_fn": preprocess_collate_fn(args.max_seq_length),
    }
    # DataLoader only accepts these two options when worker processes are used.
    if num_workers > 0:
        dataloader_kwargs["prefetch_factor"] = 2
        dataloader_kwargs["persistent_workers"] = True

    dataloader = DataLoader(dataset, **dataloader_kwargs)

    logger.info("开始收集补充数据...")
    count = collect_samples(
        dataloader,
        args.num_samples,
        output_dir,
        "supplement",
        args.max_seq_length,
        args.shard_size,
    )

    if count < args.num_samples:
        logger.warning(f"补充样本不足: 目标 {args.num_samples}, 实际 {count}")

    console.print("\n[bold green]=== 补充预处理完成 ===[/bold green]")
    console.print(f"生成样本: {count:,}")
    console.print(f"输出目录: {output_dir}")

    total_size = sum(
        f.stat().st_size for f in output_dir.iterdir() if f.suffix == ".npz"
    )
    console.print(f"总大小: {total_size / (1024**3):.2f} GB (compressed)")
    console.print()
    console.print(
        "[bold yellow]提示[/bold yellow]: 请检查补充数据质量,清洗无误后手动将 shard_*.npz 合并到 train/ 目录并更新 metadata.json"
    )


def main() -> None:
    """Parse the command line and dispatch to the selected sub-command.

    Prints the help text and returns when no sub-command is given.
    """
    parser = argparse.ArgumentParser(
        description="缺失字符补充工具",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
子命令:
  find-missing            扫描已预处理数据,找出从未出现的 label ID
  generate-template       根据缺失字符 JSON 生成 JSONL 占位文件
  preprocess-supplement   将填好的 JSONL 预处理为 .npz 分片(独立目录)

示例:
  python -m model.supplement_missing find-missing \\
      --preprocessed-dir ./preprocessed/train \\
      --output missing_chars.json

  python -m model.supplement_missing generate-template \\
      --missing-chars missing_chars.json \\
      --output supplement_texts.jsonl

  python -m model.supplement_missing preprocess-supplement \\
      --missing-chars missing_chars.json \\
      --supplement-texts supplement_texts.jsonl \\
      --output-dir ./preprocessed/supplement \\
      --num-samples 100000
""",
    )
    subparsers = parser.add_subparsers(dest="command", help="子命令")

    # find-missing
    p_find = subparsers.add_parser("find-missing", help="扫描预处理数据,找出缺失标签")
    p_find.add_argument(
        "--preprocessed-dir",
        type=str,
        required=True,
        help="预处理数据目录(包含 shard_*.npz 和 metadata.json)",
    )
    p_find.add_argument(
        "--output",
        type=str,
        default="missing_chars.json",
        help="输出 JSON 文件路径(默认: missing_chars.json)",
    )

    # generate-template
    p_gen = subparsers.add_parser("generate-template", help="生成补充文本模板")
    p_gen.add_argument(
        "--missing-chars",
        type=str,
        required=True,
        help="缺失字符 JSON 文件路径(由 find-missing 生成)",
    )
    p_gen.add_argument(
        "--output",
        type=str,
        default="supplement_texts.jsonl",
        help="输出 JSONL 文件路径(默认: supplement_texts.jsonl)",
    )
    p_gen.add_argument(
        "--num-entries",
        type=int,
        default=3,
        help="每个缺失字符生成的模板条数(默认: 3)",
    )

    # preprocess-supplement
    p_pre = subparsers.add_parser(
        "preprocess-supplement", help="将 JSONL 预处理为 .npz 分片"
    )
    p_pre.add_argument(
        "--missing-chars",
        type=str,
        required=True,
        help="缺失字符 JSON 文件路径(由 find-missing 生成)",
    )
    p_pre.add_argument(
        "--supplement-texts",
        type=str,
        required=True,
        help="已填写的补充文本 JSONL 文件路径",
    )
    p_pre.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="输出目录(独立目录,不会覆盖已有数据)",
    )
    p_pre.add_argument(
        "--num-samples",
        type=int,
        required=True,
        help="目标样本数量",
    )
    p_pre.add_argument(
        "--batch-size",
        type=int,
        default=128,
        help="批大小(默认: 128)",
    )
    p_pre.add_argument(
        "--num-workers",
        type=int,
        default=0,
        help="DataLoader worker 数量。本地 JSONL 小文件建议 0(默认: 0)",
    )
    p_pre.add_argument(
        "--max-seq-length",
        type=int,
        default=128,
        help="最大序列长度(默认: 128)",
    )
    p_pre.add_argument(
        "--seed",
        type=int,
        default=42,
        help="随机种子(默认: 42)",
    )
    p_pre.add_argument(
        "--shard-size",
        type=int,
        default=5_000_000,
        help="分片大小(样本数),控制内存峰值(默认: 500万)",
    )
    p_pre.add_argument(
        "--py-style-weight",
        type=str,
        default="9,2,1",
        help="拼音风格权重(逗号分隔,默认: 9,2,1)",
    )
    p_pre.add_argument(
        "--length-weights",
        type=str,
        default="1:10,2:50,3:50,4:40,5:15,6:10,7:5,8:2",
        help="词长权重(默认: 1:10,2:50,3:50,4:40,5:15,6:10,7:5,8:2)",
    )

    args = parser.parse_args()

    if args.command is None:
        parser.print_help()
        return

    # argparse only accepts the registered sub-command names, so the lookup
    # cannot miss once a command is present.
    handlers = {
        "find-missing": cmd_find_missing,
        "generate-template": cmd_generate_template,
        "preprocess-supplement": cmd_preprocess_supplement,
    }
    handlers[args.command](args)


app = main

if __name__ == "__main__":
    main()