#!/usr/bin/env python3
"""
Preprocessing script: convert the output of PinyinInputDataset into sharded,
compressed .npz files.

Samples are streamed and written shard by shard, so peak memory is bounded by
the shard size instead of growing with the total number of samples. Each shard
is saved with np.savez_compressed (zlib compression), so the GPU server does
not need to decompress it to disk.

Usage:
    python -m model.preprocess \
        --train-data-path "some/hf_dataset" \
        --eval-data-path "some/hf_dataset" \
        --output-dir ./preprocessed \
        --num-train-samples 5000000 \
        --num-eval-samples 8192

    python -m model.preprocess resplit \
        --input-dir ./preprocessed/train \
        --output-dir ./preprocessed/train_small \
        --target-size 1000000

Generated directory layout:
    output_dir/
        train/
            metadata.json
            shard_000000.npz  (about shard_size samples, 6 fields, zlib)
            shard_000001.npz
            ...
        eval/
            metadata.json
            shard_000000.npz
            ...
"""

import argparse
import gc
import json
import struct
import sys
import time
import zipfile
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import torch
from loguru import logger
from rich.console import Console
from torch.utils.data import DataLoader
from tqdm import tqdm

from .dataset import PinyinInputDataset
from .trainer import preprocess_collate_fn, worker_init_fn

FIELDS = [
    "input_ids",
    "token_type_ids",
    "attention_mask",
    "labels",
    "history_slot_ids",
    "pinyin_ids",
]

# Representable range of the on-disk dtype; used to reject values that the
# int16 cast would otherwise wrap around silently.
_INT16_INFO = np.iinfo(np.int16)


def _extract_batch(batch: dict, take: int) -> Dict[str, np.ndarray]:
    """Extract the first ``take`` samples of a DataLoader batch as int16 arrays.

    Args:
        batch: Mapping from each name in ``FIELDS`` to a tensor whose first
            dimension indexes samples.
        take: Number of leading samples to keep.

    Returns:
        A dict with one int16 NumPy array per field. A trailing singleton
        dimension on ``labels`` is squeezed away.

    Raises:
        ValueError: If a field holds a value outside the int16 range, because
            the cast would otherwise corrupt it silently.
    """
    result = {}
    for name in FIELDS:
        values = batch[name][:take].numpy()
        if values.size and (
            values.min() < _INT16_INFO.min or values.max() > _INT16_INFO.max
        ):
            raise ValueError(
                f"Field '{name}' contains values outside the int16 range "
                f"[{_INT16_INFO.min}, {_INT16_INFO.max}]"
            )
        arr = values.astype(np.int16)
        if name == "labels" and arr.ndim > 1 and arr.shape[-1] == 1:
            arr = arr.squeeze(-1)
        result[name] = arr
    return result


def _write_shard(
    split_dir: Path,
    split_name: str,
    shard_idx: int,
    buffers: Dict[str, List[np.ndarray]],
    count: int,
) -> None:
    """Concatenate the buffered batches and save them as one compressed shard."""
    merged = {name: np.concatenate(buffers[name], axis=0) for name in FIELDS}
    np.savez_compressed(split_dir / f"shard_{shard_idx:06d}.npz", **merged)
    logger.debug(f"Saved {split_name} shard {shard_idx}: {count} samples")


def _npz_bytes(directory: Path) -> int:
    """Return the total size in bytes of all ``.npz`` files in ``directory``."""
    return sum(p.stat().st_size for p in directory.iterdir() if p.suffix == ".npz")


def collect_samples(
    dataloader: DataLoader,
    num_samples: int,
    output_dir: Path,
    split_name: str,
    max_seq_length: int = 128,
    shard_size: int = 5_000_000,
) -> int:
    """Stream samples from ``dataloader`` into compressed .npz shards.

    Batches are buffered until at least ``shard_size`` samples have
    accumulated, then written as one shard. Batches are never split across
    shards, so a shard may hold slightly more than ``shard_size`` samples; the
    real size of every shard is recorded in ``metadata.json`` under
    ``shard_sizes``.

    Peak memory is roughly one shard of buffered batches plus the transient
    concatenated copy made while the shard is being written.

    Args:
        dataloader: Source of batches; each batch maps the names in ``FIELDS``
            to tensors.
        num_samples: Maximum number of samples to collect.
        output_dir: Root output directory; shards go to
            ``output_dir / split_name``.
        split_name: Name of the split, used as sub-directory and in logs.
        max_seq_length: Recorded in the metadata only.
        shard_size: Nominal number of samples per shard.

    Returns:
        The number of samples actually written, which is smaller than
        ``num_samples`` when the dataloader is exhausted early.
    """
    split_dir = output_dir / split_name
    split_dir.mkdir(parents=True, exist_ok=True)

    shard_buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}
    shard_sizes: List[int] = []  # actual sample count of every written shard
    shard_count = 0
    total = 0

    with tqdm(
        total=num_samples, desc=f"Processing {split_name}", unit="samples"
    ) as pbar:
        for batch in dataloader:
            remaining = num_samples - total
            if remaining <= 0:
                break

            take = min(batch["input_ids"].size(0), remaining)
            extracted = _extract_batch(batch, take)
            for f in FIELDS:
                shard_buffers[f].append(extracted[f])

            shard_count += take
            total += take
            pbar.update(take)

            if shard_count >= shard_size:
                _write_shard(
                    split_dir, split_name, len(shard_sizes), shard_buffers, shard_count
                )
                shard_sizes.append(shard_count)
                shard_buffers = {f: [] for f in FIELDS}
                shard_count = 0
                gc.collect()

            if total >= num_samples:
                break

    # Write the final, partially filled shard.
    if shard_count > 0:
        _write_shard(
            split_dir, split_name, len(shard_sizes), shard_buffers, shard_count
        )
        shard_sizes.append(shard_count)

    num_shards = len(shard_sizes)
    metadata = {
        "num_samples": total,
        "max_seq_length": max_seq_length,
        "dtype": "int16",
        "fields": FIELDS,
        "shard_size": shard_size,
        "num_shards": num_shards,
        "shard_sizes": shard_sizes,
    }
    with open(split_dir / "metadata.json", "w", encoding="utf-8") as fp:
        json.dump(metadata, fp, indent=2, ensure_ascii=False)

    total_size = _npz_bytes(split_dir)
    logger.info(
        f"{split_name}: {total} samples in {num_shards} shards, "
        f"{total_size / (1024**3):.2f} GB (compressed)"
    )
    return total


def _build_dataloader(
    data_path: str,
    num_workers: int,
    max_iter_length: int,
    max_seq_length: int,
    batch_size: int,
    py_style_weight: Tuple[int, ...],
    length_weights: Dict[int, int],
) -> DataLoader:
    """Create a PinyinInputDataset and wrap it in a preprocessing DataLoader."""
    dataset = PinyinInputDataset(
        data_path=data_path,
        max_workers=num_workers,
        max_iter_length=max_iter_length,
        max_seq_length=max_seq_length,
        text_field="text",
        py_style_weight=py_style_weight,
        # NOTE(review): the --shuffle-buffer-size CLI option is parsed in
        # main() but never forwarded; this hard-coded value is what the
        # datasets actually use. Confirm which one is intended.
        shuffle_buffer_size=100,
        length_weights=length_weights,
    )

    # prefetch_factor / persistent_workers are only accepted by DataLoader
    # when worker processes are used, so pass them only in that case.
    worker_kwargs = {}
    if num_workers > 0:
        worker_kwargs["prefetch_factor"] = 2
        worker_kwargs["persistent_workers"] = True

    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=False,
        worker_init_fn=worker_init_fn,
        collate_fn=preprocess_collate_fn(max_seq_length),
        **worker_kwargs,
    )


def main():
    """Parse CLI arguments and preprocess the train and eval splits."""
    console = Console()

    parser = argparse.ArgumentParser(description="预处理数据集为分片压缩npz文件")
    parser.add_argument(
        "--train-data-path",
        type=str,
        required=True,
        help="训练数据集路径(HuggingFace格式)",
    )
    parser.add_argument(
        "--eval-data-path",
        type=str,
        required=True,
        help="评估数据集路径(HuggingFace格式)",
    )
    parser.add_argument("--output-dir", type=str, required=True, help="输出目录")
    parser.add_argument(
        "--num-train-samples", type=int, required=True, help="训练集样本数量"
    )
    parser.add_argument(
        "--num-eval-samples", type=int, required=True, help="评估集样本数量"
    )
    parser.add_argument("--batch-size", type=int, default=128, help="批大小")
    parser.add_argument(
        "--num-workers", type=int, default=2, help="DataLoader worker数量"
    )
    parser.add_argument("--max-seq-length", type=int, default=128, help="最大序列长度")
    parser.add_argument("--seed", type=int, default=42, help="随机种子")
    parser.add_argument(
        "--shard-size",
        type=int,
        default=5_000_000,
        help="分片大小(样本数),控制内存峰值(默认500万,约2.9GB/分片未压缩)",
    )
    parser.add_argument(
        "--py-style-weight",
        type=str,
        default="9,2,1",
        help="拼音风格权重(逗号分隔)",
    )
    parser.add_argument(
        "--shuffle-buffer-size",
        type=int,
        default=2000000,
        help="数据集shuffle缓冲区大小",
    )
    parser.add_argument(
        "--length-weights",
        type=str,
        default="1:10,2:50,3:50,4:40,5:15,6:10,7:5,8:2",
        help="词长权重",
    )
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    py_style_weight = tuple(int(x) for x in args.py_style_weight.split(","))
    length_weights = {
        int(k): int(v)
        for k, v in (item.split(":") for item in args.length_weights.split(","))
    }

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    train_max_iter = args.num_train_samples * 5
    eval_max_iter = args.num_eval_samples * 5

    # 578 is the approximate per-sample byte count used for the estimate shown
    # to the user.
    shard_mem_gb = args.shard_size * 578 / (1024**3)

    console.print("[bold cyan]=== 数据预处理 ===[/bold cyan]")
    console.print(f"训练集目标: {args.num_train_samples:,} 样本")
    console.print(f"评估集目标: {args.num_eval_samples:,} 样本")
    console.print(f"输出目录: {output_dir}")
    console.print("数据类型: int16")
    console.print(
        f"分片大小: {args.shard_size:,} 样本 (约 {shard_mem_gb:.1f} GB/分片 未压缩)"
    )
    console.print()

    num_train_workers = args.num_workers
    num_eval_workers = max(1, args.num_workers // 2)

    console.print("[bold cyan]创建训练数据集...[/bold cyan]")
    train_dataloader = _build_dataloader(
        args.train_data_path,
        num_train_workers,
        train_max_iter,
        args.max_seq_length,
        args.batch_size,
        py_style_weight,
        length_weights,
    )

    console.print("[bold cyan]创建评估数据集...[/bold cyan]")
    eval_dataloader = _build_dataloader(
        args.eval_data_path,
        num_eval_workers,
        eval_max_iter,
        args.max_seq_length,
        args.batch_size,
        py_style_weight,
        length_weights,
    )

    logger.info("开始收集训练数据...")
    train_count = collect_samples(
        train_dataloader,
        args.num_train_samples,
        output_dir,
        "train",
        args.max_seq_length,
        args.shard_size,
    )
    if train_count < args.num_train_samples:
        logger.warning(
            f"训练集样本不足: 目标 {args.num_train_samples}, 实际 {train_count}"
        )

    logger.info("开始收集评估数据...")
    eval_count = collect_samples(
        eval_dataloader,
        args.num_eval_samples,
        output_dir,
        "eval",
        args.max_seq_length,
        args.shard_size,
    )
    if eval_count < args.num_eval_samples:
        logger.warning(
            f"评估集样本不足: 目标 {args.num_eval_samples}, 实际 {eval_count}"
        )

    console.print("\n[bold green]=== 预处理完成 ===[/bold green]")
    console.print(f"训练集: {train_count:,} 样本")
    console.print(f"评估集: {eval_count:,} 样本")
    console.print(f"输出目录: {output_dir}")

    for split in ["train", "eval"]:
        split_dir = output_dir / split
        if split_dir.exists():
            total_size = _npz_bytes(split_dir)
            console.print(f"{split}/: {total_size / (1024**3):.2f} GB (compressed)")


def resplit_shards(input_dir: str, output_dir: str, target_size: int = 1_000_000):
    """Split oversized .npz shards into chunks of at most ``target_size`` samples.

    Every input shard is loaded into memory once (all of its fields) and then
    sliced into chunks, so peak memory is one full input shard plus one output
    chunk. Run this on a machine with plenty of RAM and copy the resulting
    small shards to the GPU server afterwards.

    The input ``metadata.json`` is copied to the output directory with
    ``num_samples``, ``num_shards`` and ``shard_sizes`` updated and the
    ``shard_size`` key removed.

    Args:
        input_dir: Directory containing ``metadata.json`` and ``shard_*.npz``.
        output_dir: Directory receiving the re-split shards; must differ from
            ``input_dir``.
        target_size: Maximum number of samples per output shard.

    Raises:
        ValueError: If ``target_size`` is not positive, or if ``input_dir`` and
            ``output_dir`` are the same directory (output shards would
            overwrite input shards that have not been read yet).
        FileNotFoundError: If ``metadata.json`` is missing from ``input_dir``.
    """
    if target_size <= 0:
        raise ValueError(f"target_size must be positive, got {target_size}")

    input_path = Path(input_dir)
    output_path = Path(output_dir)
    if input_path.resolve() == output_path.resolve():
        raise ValueError("input_dir and output_dir must be different directories")

    metadata_path = input_path / "metadata.json"
    if not metadata_path.exists():
        raise FileNotFoundError(f"metadata.json not found in {input_dir}")
    with open(metadata_path, encoding="utf-8") as f:
        metadata = json.load(f)

    output_path.mkdir(parents=True, exist_ok=True)
    shard_files = sorted(input_path.glob("shard_*.npz"))

    console = Console()
    console.print(
        f"[bold]Resplit {len(shard_files)} shards → {target_size:,} samples/chunk[/bold]"
    )

    global_shard_idx = 0
    all_shard_sizes: List[int] = []
    total_samples = 0

    for sf in tqdm(shard_files, desc="Shards"):
        # Indexing an NpzFile decompresses the whole member on every access,
        # so materialize each field exactly once per shard.
        with np.load(str(sf)) as data:
            arrays: Dict[str, np.ndarray] = {name: data[name] for name in data.files}
        if not arrays:
            continue

        n_samples = next(iter(arrays.values())).shape[0]
        for start in range(0, n_samples, target_size):
            end = min(start + target_size, n_samples)
            chunk_data = {
                name: arr[start:end].astype(np.int16, copy=False)
                for name, arr in arrays.items()
            }
            out_file = output_path / f"shard_{global_shard_idx:06d}.npz"
            np.savez_compressed(out_file, **chunk_data)

            chunk_size = end - start
            all_shard_sizes.append(chunk_size)
            total_samples += chunk_size
            global_shard_idx += 1
            del chunk_data

        # Release the loaded shard before the next one is read.
        del arrays
        gc.collect()

    metadata["num_samples"] = total_samples
    metadata["num_shards"] = global_shard_idx
    metadata["shard_sizes"] = all_shard_sizes
    metadata.pop("shard_size", None)
    with open(output_path / "metadata.json", "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2, ensure_ascii=False)

    console.print(
        f"[green]Done: {total_samples:,} samples in {global_shard_idx} shards[/green]"
    )


def _dispatch():
    """Run the ``resplit`` sub-command if requested, otherwise run ``main``."""
    if len(sys.argv) > 1 and sys.argv[1] == "resplit":
        p = argparse.ArgumentParser(description="拆分过大的 .npz 分片")
        p.add_argument(
            "--input-dir", required=True, help="输入目录含 metadata.json 和 shard_*.npz"
        )
        p.add_argument("--output-dir", required=True, help="输出目录")
        p.add_argument(
            "--target-size", type=int, default=1_000_000, help="目标分片大小(样本数)"
        )
        args = p.parse_args(sys.argv[2:])
        resplit_shards(args.input_dir, args.output_dir, args.target_size)
    else:
        main()


app = _dispatch

if __name__ == "__main__":
    _dispatch()