#!/usr/bin/env python3
|
||
"""
|
||
预处理脚本:将 PinyinInputDataset 的输出转换为分片压缩 .npz 文件。
|
||
|
||
采用分片流式写入,内存峰值固定为 shard_size 级别,不随总样本数增长。
|
||
每个分片使用 np.savez_compressed 保存(zlib 压缩),GPU 服务器无需解压到硬盘。
|
||
|
||
用法:
|
||
python -m model.preprocess \
|
||
--train-data-path "some/hf_dataset" \
|
||
--eval-data-path "some/hf_dataset" \
|
||
--output-dir ./preprocessed \
|
||
--num-train-samples 5000000 \
|
||
--num-eval-samples 8192
|
||
|
||
生成目录结构:
|
||
output_dir/
|
||
train/
|
||
metadata.json
|
||
shard_000.npz (5M样本, 6个字段, zlib压缩)
|
||
shard_001.npz
|
||
...
|
||
eval/
|
||
metadata.json
|
||
shard_000.npz
|
||
...
|
||
"""
|
||
|
||
import argparse
|
||
import gc
|
||
import json
|
||
import struct
|
||
import time
|
||
import zipfile
|
||
from pathlib import Path
|
||
from typing import Dict, List
|
||
|
||
import numpy as np
|
||
import torch
|
||
from loguru import logger
|
||
from rich.console import Console
|
||
from torch.utils.data import DataLoader
|
||
from tqdm import tqdm
|
||
|
||
from .dataset import PinyinInputDataset
|
||
from .trainer import preprocess_collate_fn, worker_init_fn
|
||
|
||
# Names of the per-sample arrays taken from each batch and stored in every shard.
# The order here is the order in which fields are extracted and saved.
FIELDS = [
    "input_ids", "token_type_ids", "attention_mask",
    "labels", "history_slot_ids", "pinyin_ids",
]
|
||
|
||
|
||
def _extract_batch(batch: dict, take: int) -> Dict[str, np.ndarray]:
    """Convert the first ``take`` samples of a DataLoader batch to int16 arrays.

    Args:
        batch: Mapping from field name to a tensor that supports slicing and
            ``.numpy()``; it must contain every name listed in ``FIELDS``.
        take: Number of leading samples to keep from each field.

    Returns:
        A dict with one ``np.int16`` array per entry of ``FIELDS``.  A
        ``labels`` array whose last dimension has size 1 is squeezed along
        that dimension.

    Raises:
        ValueError: If a field contains a value that does not fit in int16.
            ``astype`` would otherwise wrap such values around silently and
            the corruption would only surface when the shards are consumed.
    """
    int16_info = np.iinfo(np.int16)
    result = {}
    for f in FIELDS:
        raw = batch[f][:take].numpy()
        # Only dtypes that cannot be cast safely need an explicit range check.
        if raw.size and not np.can_cast(raw.dtype, np.int16):
            lowest, highest = raw.min(), raw.max()
            if lowest < int16_info.min or highest > int16_info.max:
                raise ValueError(
                    f"Field '{f}' has values in [{lowest}, {highest}], "
                    f"which do not fit in int16"
                )
        arr = raw.astype(np.int16)
        if f == "labels" and arr.ndim > 1 and arr.shape[-1] == 1:
            arr = arr.squeeze(-1)
        result[f] = arr
    return result
|
||
|
||
|
||
def _write_shard(
    split_dir: Path, shard_idx: int, buffers: Dict[str, List[np.ndarray]]
) -> Path:
    """Concatenate the buffered chunks of every field and save them as one compressed shard."""
    shard_path = split_dir / f"shard_{shard_idx:06d}.npz"
    merged = {name: np.concatenate(chunks, axis=0) for name, chunks in buffers.items()}
    np.savez_compressed(shard_path, **merged)
    return shard_path


def collect_samples(
    dataloader: DataLoader,
    num_samples: int,
    output_dir: Path,
    split_name: str,
    max_seq_length: int = 128,
    shard_size: int = 5_000_000,
) -> int:
    """Stream samples from ``dataloader`` into compressed ``.npz`` shards.

    Samples are buffered in memory and written to
    ``output_dir/split_name/shard_NNNNNN.npz`` every time exactly
    ``shard_size`` samples have accumulated; a final, smaller shard holds the
    remainder.  A ``metadata.json`` describing the shards is written next to
    them.

    Peak memory is roughly ``shard_size`` times the per-sample size (estimated
    at about 578 bytes per sample, i.e. about 2.9 GB for the default 5M).

    Args:
        dataloader: Yields batches (dicts of tensors) containing every field
            in ``FIELDS``.
        num_samples: Maximum number of samples to collect.
        output_dir: Root output directory; the split gets its own sub-directory.
        split_name: Name of the split, used as sub-directory and in log output.
        max_seq_length: Recorded in the metadata only.
        shard_size: Number of samples per shard.

    Returns:
        The number of samples actually written, which is smaller than
        ``num_samples`` if the dataloader is exhausted early.

    Raises:
        ValueError: If ``shard_size`` is not positive.
    """
    if shard_size <= 0:
        raise ValueError(f"shard_size must be a positive integer, got {shard_size}")

    split_dir = output_dir / split_name
    split_dir.mkdir(parents=True, exist_ok=True)

    shard_buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}
    shard_count = 0
    shard_paths: List[Path] = []
    shard_sizes: List[int] = []
    total = 0

    def flush_shard() -> None:
        """Persist the buffered samples as the next shard and reset the buffers."""
        nonlocal shard_count
        shard_idx = len(shard_paths)
        shard_paths.append(_write_shard(split_dir, shard_idx, shard_buffers))
        shard_sizes.append(shard_count)
        logger.debug(f"Saved {split_name} shard {shard_idx}: {shard_count} samples")
        for chunks in shard_buffers.values():
            chunks.clear()
        shard_count = 0
        gc.collect()

    # The context manager closes the progress bar even if a batch fails.
    with tqdm(total=num_samples, desc=f"Processing {split_name}", unit="samples") as pbar:
        for batch in dataloader:
            remaining = num_samples - total
            if remaining <= 0:
                break
            take = min(batch["input_ids"].size(0), remaining)
            extracted = _extract_batch(batch, take)

            # Split the batch at shard boundaries so that every shard except
            # the last holds exactly shard_size samples; otherwise the sizes
            # recorded in the metadata would not match the files on disk.
            offset = 0
            while offset < take:
                step = min(shard_size - shard_count, take - offset)
                for f in FIELDS:
                    shard_buffers[f].append(extracted[f][offset:offset + step])
                shard_count += step
                offset += step
                if shard_count == shard_size:
                    flush_shard()

            total += take
            pbar.update(take)

            # Stop before requesting a batch that would not be used.
            if total >= num_samples:
                break

    # Write the final, partially filled shard.
    if shard_count > 0:
        flush_shard()

    num_shards = len(shard_paths)
    metadata = {
        "num_samples": total,
        "max_seq_length": max_seq_length,
        "dtype": "int16",
        "fields": FIELDS,
        "shard_size": shard_size,
        "num_shards": num_shards,
        "shard_sizes": shard_sizes,
    }
    with open(split_dir / "metadata.json", "w", encoding="utf-8") as fp:
        json.dump(metadata, fp, indent=2, ensure_ascii=False)

    # Shards left over from an earlier, larger run are not covered by the new
    # metadata; point them out instead of silently counting them.
    written = set(shard_paths)
    stale = [p for p in split_dir.glob("shard_*.npz") if p not in written]
    if stale:
        logger.warning(
            f"{split_name}: {len(stale)} shard file(s) in {split_dir} were not written "
            f"by this run and are not described by metadata.json"
        )

    total_size = sum(p.stat().st_size for p in shard_paths)
    logger.info(
        f"{split_name}: {total} samples in {num_shards} shards, "
        f"{total_size / (1024**3):.2f} GB (compressed)"
    )

    return total
|
||
|
||
|
||
def _build_dataloader(
    data_path: str,
    num_workers: int,
    max_iter_length: int,
    args: argparse.Namespace,
    py_style_weight: tuple,
    length_weights: dict,
) -> DataLoader:
    """Create a PinyinInputDataset for ``data_path`` and wrap it in a DataLoader."""
    dataset = PinyinInputDataset(
        data_path=data_path,
        max_workers=num_workers,
        max_iter_length=max_iter_length,
        max_seq_length=args.max_seq_length,
        text_field="text",
        py_style_weight=py_style_weight,
        shuffle_buffer_size=args.shuffle_buffer_size,
        length_weights=length_weights,
    )

    loader_kwargs = {
        "batch_size": args.batch_size,
        "num_workers": num_workers,
        "pin_memory": False,
        "worker_init_fn": worker_init_fn,
        "collate_fn": preprocess_collate_fn(args.max_seq_length),
    }
    # prefetch_factor / persistent_workers are multiprocessing-only options;
    # DataLoader rejects an explicit prefetch_factor when num_workers == 0.
    if num_workers > 0:
        loader_kwargs["prefetch_factor"] = 2
        loader_kwargs["persistent_workers"] = True
    return DataLoader(dataset, **loader_kwargs)


def main():
    """Command-line entry point: preprocess the train and eval splits into shards.

    Parses the command line, seeds torch and NumPy, builds one dataset and
    DataLoader per split, and writes each split with ``collect_samples``.
    A warning is logged when a split yields fewer samples than requested, and
    a size summary is printed at the end.
    """
    console = Console()

    parser = argparse.ArgumentParser(description="预处理数据集为分片压缩npz文件")
    parser.add_argument(
        "--train-data-path",
        type=str,
        required=True,
        help="训练数据集路径(HuggingFace格式)",
    )
    parser.add_argument(
        "--eval-data-path",
        type=str,
        required=True,
        help="评估数据集路径(HuggingFace格式)",
    )
    parser.add_argument("--output-dir", type=str, required=True, help="输出目录")
    parser.add_argument(
        "--num-train-samples", type=int, required=True, help="训练集样本数量"
    )
    parser.add_argument(
        "--num-eval-samples", type=int, required=True, help="评估集样本数量"
    )
    parser.add_argument("--batch-size", type=int, default=128, help="批大小")
    parser.add_argument(
        "--num-workers", type=int, default=2, help="DataLoader worker数量"
    )
    parser.add_argument("--max-seq-length", type=int, default=128, help="最大序列长度")
    parser.add_argument("--seed", type=int, default=42, help="随机种子")
    parser.add_argument(
        "--shard-size",
        type=int,
        default=5_000_000,
        help="分片大小(样本数),控制内存峰值(默认500万,约2.9GB/分片未压缩)",
    )
    parser.add_argument(
        "--py-style-weight",
        type=str,
        default="9,2,1",
        help="拼音风格权重(逗号分隔)",
    )
    # This option used to be parsed but ignored: both datasets were created
    # with a hard-coded buffer of 100. It is now honored, and the default is
    # that former hard-coded value so runs that omit the flag are unchanged.
    parser.add_argument(
        "--shuffle-buffer-size",
        type=int,
        default=100,
        help="数据集shuffle缓冲区大小",
    )
    parser.add_argument(
        "--length-weights",
        type=str,
        default="1:10,2:50,3:50,4:40,5:15,6:10,7:5,8:2",
        help="词长权重",
    )

    args = parser.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # "9,2,1" -> (9, 2, 1)
    py_style_weight = tuple(int(x) for x in args.py_style_weight.split(","))
    # "1:10,2:50" -> {1: 10, 2: 50}
    length_weights = {
        int(k): int(v)
        for k, v in (item.split(":") for item in args.length_weights.split(","))
    }

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # The datasets are capped at five times the requested sample count
    # (presumably headroom for samples the dataset drops; confirm against
    # PinyinInputDataset).
    train_max_iter = args.num_train_samples * 5
    eval_max_iter = args.num_eval_samples * 5

    # 578 bytes per sample is the estimate used throughout this script.
    shard_mem_gb = args.shard_size * 578 / (1024**3)
    console.print("[bold cyan]=== 数据预处理 ===[/bold cyan]")
    console.print(f"训练集目标: {args.num_train_samples:,} 样本")
    console.print(f"评估集目标: {args.num_eval_samples:,} 样本")
    console.print(f"输出目录: {output_dir}")
    console.print("数据类型: int16")
    console.print(
        f"分片大小: {args.shard_size:,} 样本 (约 {shard_mem_gb:.1f} GB/分片 未压缩)"
    )
    console.print()

    num_train_workers = args.num_workers
    num_eval_workers = max(1, args.num_workers // 2)

    console.print("[bold cyan]创建训练数据集...[/bold cyan]")
    train_dataloader = _build_dataloader(
        args.train_data_path,
        num_train_workers,
        train_max_iter,
        args,
        py_style_weight,
        length_weights,
    )

    console.print("[bold cyan]创建评估数据集...[/bold cyan]")
    eval_dataloader = _build_dataloader(
        args.eval_data_path,
        num_eval_workers,
        eval_max_iter,
        args,
        py_style_weight,
        length_weights,
    )

    logger.info("开始收集训练数据...")
    train_count = collect_samples(
        train_dataloader,
        args.num_train_samples,
        output_dir,
        "train",
        args.max_seq_length,
        args.shard_size,
    )

    if train_count < args.num_train_samples:
        logger.warning(
            f"训练集样本不足: 目标 {args.num_train_samples}, 实际 {train_count}"
        )

    logger.info("开始收集评估数据...")
    eval_count = collect_samples(
        eval_dataloader,
        args.num_eval_samples,
        output_dir,
        "eval",
        args.max_seq_length,
        args.shard_size,
    )

    if eval_count < args.num_eval_samples:
        logger.warning(
            f"评估集样本不足: 目标 {args.num_eval_samples}, 实际 {eval_count}"
        )

    console.print("\n[bold green]=== 预处理完成 ===[/bold green]")
    console.print(f"训练集: {train_count:,} 样本")
    console.print(f"评估集: {eval_count:,} 样本")
    console.print(f"输出目录: {output_dir}")

    # Report the compressed on-disk size of each split.
    for split in ["train", "eval"]:
        split_dir = output_dir / split
        if split_dir.exists():
            total_size = sum(
                f.stat().st_size for f in split_dir.iterdir() if f.suffix == ".npz"
            )
            console.print(f"{split}/: {total_size / (1024**3):.2f} GB (compressed)")
|
||
|
||
|
||
def resplit_shards(input_dir: str, output_dir: str, target_size: int = 1_000_000):
    """Split oversized ``.npz`` shards into chunks of at most ``target_size`` samples.

    Every ``shard_*.npz`` file in ``input_dir`` is read in sorted order and
    rewritten to ``output_dir`` as consecutively numbered shards; chunks never
    span two input shards.  The input ``metadata.json`` is copied with
    ``num_samples``, ``num_shards`` and ``shard_sizes`` updated and
    ``shard_size`` removed, since the output shards are not uniformly sized.

    Each input shard is decompressed once and held in memory while it is
    split, so peak memory is all fields of one input shard plus one chunk
    (the original estimate: about 21 GB for a 20M-sample shard plus about
    1.5 GB per chunk).  Run this on a CPU machine with plenty of memory and
    copy the small shards to the GPU server afterwards.

    Args:
        input_dir: Directory containing ``metadata.json`` and ``shard_*.npz``.
        output_dir: Directory for the new shards; created if missing.
        target_size: Maximum number of samples per output shard.

    Raises:
        ValueError: If ``target_size`` is not positive, or if ``output_dir``
            is the same directory as ``input_dir``.
        FileNotFoundError: If ``input_dir`` has no ``metadata.json``.
    """
    if target_size <= 0:
        raise ValueError(f"target_size must be a positive integer, got {target_size}")

    input_path = Path(input_dir)
    output_path = Path(output_dir)

    # Output shards reuse the shard_NNNNNN.npz naming scheme, so writing into
    # the input directory would overwrite shards that have not been read yet.
    if input_path.resolve() == output_path.resolve():
        raise ValueError("output_dir must be different from input_dir")

    metadata_path = input_path / "metadata.json"
    if not metadata_path.exists():
        raise FileNotFoundError(f"metadata.json not found in {input_dir}")

    output_path.mkdir(parents=True, exist_ok=True)

    with open(metadata_path, encoding="utf-8") as f:
        metadata = json.load(f)

    shard_files = sorted(input_path.glob("shard_*.npz"))
    console = Console()
    console.print(
        f"[bold]Resplit {len(shard_files)} shards → {target_size:,} samples/chunk[/bold]"
    )

    global_shard_idx = 0
    all_shard_sizes = []
    total_samples = 0

    for sf in tqdm(shard_files, desc="Shards"):
        # An NpzFile decompresses a member every time it is indexed, so read
        # each field exactly once instead of once per chunk.
        with np.load(str(sf)) as data:
            arrays = {name: data[name] for name in data.files}
        if not arrays:
            continue

        n_samples = next(iter(arrays.values())).shape[0]

        for start in range(0, n_samples, target_size):
            end = min(start + target_size, n_samples)
            chunk_data = {
                name: arr[start:end].astype(np.int16, copy=False)
                for name, arr in arrays.items()
            }

            out_file = output_path / f"shard_{global_shard_idx:06d}.npz"
            np.savez_compressed(out_file, **chunk_data)

            chunk_size = end - start
            all_shard_sizes.append(chunk_size)
            total_samples += chunk_size
            global_shard_idx += 1
            del chunk_data

        # Release the shard before the next one is loaded.
        del arrays
        gc.collect()

    metadata["num_samples"] = total_samples
    metadata["num_shards"] = global_shard_idx
    metadata["shard_sizes"] = all_shard_sizes
    metadata.pop("shard_size", None)
    with open(output_path / "metadata.json", "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2, ensure_ascii=False)

    console.print(
        f"[green]Done: {total_samples:,} samples in {global_shard_idx} shards[/green]"
    )
|
||
|
||
|
||
def _dispatch():
    """Route the command line to the ``resplit`` sub-command or to ``main``.

    When the first command-line argument is ``resplit``, the remaining
    arguments are parsed and passed to ``resplit_shards``; in every other
    case ``main`` handles the command line itself.
    """
    import sys

    cli_args = sys.argv[1:]
    if not cli_args or cli_args[0] != "resplit":
        main()
        return

    resplit_parser = argparse.ArgumentParser(description="拆分过大的 .npz 分片")
    resplit_parser.add_argument(
        "--input-dir", required=True, help="输入目录含 metadata.json 和 shard_*.npz"
    )
    resplit_parser.add_argument("--output-dir", required=True, help="输出目录")
    resplit_parser.add_argument(
        "--target-size", type=int, default=1_000_000, help="目标分片大小(样本数)"
    )

    # Everything after the sub-command name belongs to the resplit parser.
    options = resplit_parser.parse_args(cli_args[1:])
    resplit_shards(options.input_dir, options.output_dir, options.target_size)
|
||
|
||
|
||
# Module-level alias for the dispatcher (presumably referenced as an entry
# point from packaging metadata; confirm before renaming).
app = _dispatch

# Run the dispatcher when the module is executed as a script.
if __name__ == "__main__":
    app()
|