SUimeModelTraner/src/model/preprocess.py

434 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
预处理脚本:将 PinyinInputDataset 的输出转换为分片压缩 .npz 文件。
采用分片流式写入,内存峰值固定为 shard_size 级别,不随总样本数增长。
每个分片使用 np.savez_compressed 保存(zlib 压缩),GPU 服务器无需解压到硬盘。
用法:
python -m model.preprocess \
--train-data-path "some/hf_dataset" \
--eval-data-path "some/hf_dataset" \
--output-dir ./preprocessed \
--num-train-samples 5000000 \
--num-eval-samples 8192
生成目录结构:
output_dir/
train/
metadata.json
shard_000000.npz (5M样本, 6个字段, zlib压缩)
shard_000001.npz
...
eval/
metadata.json
shard_000000.npz
...
"""
import argparse
import gc
import json
import struct
import time
import zipfile
from pathlib import Path
from typing import Dict, List
import numpy as np
import torch
from loguru import logger
from rich.console import Console
from torch.utils.data import DataLoader
from tqdm import tqdm
from .dataset import PinyinInputDataset
from .trainer import preprocess_collate_fn, worker_init_fn
# Names of the per-sample fields read from every DataLoader batch. Each name
# becomes one array inside every shard .npz file, and the list itself is
# recorded under "fields" in metadata.json.
FIELDS = [
    "input_ids",
    "token_type_ids",
    "attention_mask",
    "labels",
    "history_slot_ids",
    "pinyin_ids",
]
def _extract_batch(batch: dict, take: int) -> Dict[str, np.ndarray]:
    """Convert the first ``take`` samples of a batch into int16 NumPy arrays.

    Args:
        batch: Mapping from each name in ``FIELDS`` to a tensor whose leading
            axis indexes samples (it is sliced with ``[:take]`` and converted
            with ``.numpy()``).
        take: Number of leading samples to keep from the batch.

    Returns:
        A dict with one int16 array per name in ``FIELDS``. For ``"labels"``,
        a trailing axis of length 1 is squeezed away.

    Note:
        The cast uses ``astype(np.int16)``, so values outside the int16 range
        are not rejected here.
    """

    def _to_int16(name: str) -> np.ndarray:
        values = batch[name][:take].numpy().astype(np.int16)
        has_unit_last_axis = values.ndim > 1 and values.shape[-1] == 1
        if name == "labels" and has_unit_last_axis:
            # Store labels as (n,) instead of (n, 1).
            return values.squeeze(-1)
        return values

    return {name: _to_int16(name) for name in FIELDS}
def _save_shard(
    split_dir: Path,
    split_name: str,
    shard_idx: int,
    buffers: Dict[str, List[np.ndarray]],
    n_samples: int,
) -> Path:
    """Concatenate the buffered batches of every field and write one compressed shard.

    Returns the path of the ``shard_{idx:06d}.npz`` file that was written.
    """
    merged = {f: np.concatenate(buffers[f], axis=0) for f in FIELDS}
    shard_path = split_dir / f"shard_{shard_idx:06d}.npz"
    np.savez_compressed(shard_path, **merged)
    logger.debug(f"Saved {split_name} shard {shard_idx}: {n_samples} samples")
    return shard_path


def collect_samples(
    dataloader: DataLoader,
    num_samples: int,
    output_dir: Path,
    split_name: str,
    max_seq_length: int = 128,
    shard_size: int = 5_000_000,
) -> int:
    """Stream samples from ``dataloader`` into compressed ``.npz`` shards.

    Batches are buffered in memory and flushed to
    ``output_dir/split_name/shard_{idx:06d}.npz`` as soon as the buffered
    sample count reaches ``shard_size``. Because whole batches are buffered,
    a shard may hold up to one batch minus one sample more than
    ``shard_size``; the real per-shard counts are what gets recorded in
    ``metadata.json`` under ``"shard_sizes"``.

    Memory is bounded by roughly one shard of buffered data plus the
    concatenated copy made while flushing (the original author estimates
    about 578 bytes per sample).

    Args:
        dataloader: Yields dict batches containing every name in ``FIELDS``.
        num_samples: Maximum number of samples to collect.
        output_dir: Root output directory; the split gets its own subdirectory.
        split_name: Subdirectory name, also used in log messages.
        max_seq_length: Only recorded in the metadata file.
        shard_size: Sample count that triggers a shard flush.

    Returns:
        The number of samples actually written (may be less than
        ``num_samples`` if the dataloader is exhausted first).

    Note:
        Shard files already present in the split directory are not removed.
    """
    split_dir = output_dir / split_name
    split_dir.mkdir(parents=True, exist_ok=True)

    shard_buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}
    shard_count = 0  # samples currently buffered for the next shard
    shard_paths: List[Path] = []  # files written by this call, in order
    shard_sizes: List[int] = []  # real sample count of each written shard
    total = 0

    # The context manager closes the progress bar even if a batch raises.
    with tqdm(
        total=num_samples, desc=f"Processing {split_name}", unit="samples"
    ) as pbar:
        for batch in dataloader:
            remaining = num_samples - total
            if remaining <= 0:
                break
            take = min(batch["input_ids"].size(0), remaining)
            extracted = _extract_batch(batch, take)
            for f in FIELDS:
                shard_buffers[f].append(extracted[f])
            shard_count += take
            total += take
            pbar.update(take)

            if shard_count >= shard_size:
                shard_paths.append(
                    _save_shard(
                        split_dir,
                        split_name,
                        len(shard_paths),
                        shard_buffers,
                        shard_count,
                    )
                )
                # Record what was really written; it can exceed shard_size.
                shard_sizes.append(shard_count)
                shard_buffers = {f: [] for f in FIELDS}
                shard_count = 0
                gc.collect()

            if total >= num_samples:
                break

        # Write the final, partially filled shard.
        if shard_count > 0:
            shard_paths.append(
                _save_shard(
                    split_dir,
                    split_name,
                    len(shard_paths),
                    shard_buffers,
                    shard_count,
                )
            )
            shard_sizes.append(shard_count)

    actual_count = total
    num_shards = len(shard_sizes)
    metadata = {
        "num_samples": actual_count,
        "max_seq_length": max_seq_length,
        "dtype": "int16",
        "fields": FIELDS,
        "shard_size": shard_size,
        "num_shards": num_shards,
        "shard_sizes": shard_sizes,
    }
    with open(split_dir / "metadata.json", "w", encoding="utf-8") as fp:
        json.dump(metadata, fp, indent=2, ensure_ascii=False)

    # Only count the files written by this call, not leftovers from earlier runs.
    total_size = sum(p.stat().st_size for p in shard_paths)
    logger.info(
        f"{split_name}: {actual_count} samples in {num_shards} shards, "
        f"{total_size / (1024**3):.2f} GB (compressed)"
    )
    return actual_count
def _build_dataloader(
    data_path: str,
    num_workers: int,
    max_iter_length: int,
    args: argparse.Namespace,
    py_style_weight: tuple,
    length_weights: Dict[int, int],
) -> DataLoader:
    """Create a PinyinInputDataset for ``data_path`` and wrap it in a DataLoader."""
    dataset = PinyinInputDataset(
        data_path=data_path,
        max_workers=num_workers,
        max_iter_length=max_iter_length,
        max_seq_length=args.max_seq_length,
        text_field="text",
        py_style_weight=py_style_weight,
        shuffle_buffer_size=args.shuffle_buffer_size,
        length_weights=length_weights,
    )
    # prefetch_factor / persistent_workers are only valid with worker
    # processes; passing prefetch_factor together with num_workers=0 makes
    # DataLoader raise ValueError, so add them only when workers are used.
    worker_kwargs = {}
    if num_workers > 0:
        worker_kwargs.update(prefetch_factor=2, persistent_workers=True)
    return DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=num_workers,
        pin_memory=False,
        worker_init_fn=worker_init_fn,
        collate_fn=preprocess_collate_fn(args.max_seq_length),
        **worker_kwargs,
    )


def main():
    """Command-line entry point: preprocess the train and eval splits into shards.

    Parses the CLI arguments, seeds torch and NumPy, builds one dataset and
    DataLoader per split, streams both through ``collect_samples`` and prints
    a summary of the compressed output sizes.
    """
    console = Console()
    parser = argparse.ArgumentParser(description="预处理数据集为分片压缩npz文件")
    parser.add_argument(
        "--train-data-path",
        type=str,
        required=True,
        help="训练数据集路径HuggingFace格式",
    )
    parser.add_argument(
        "--eval-data-path",
        type=str,
        required=True,
        help="评估数据集路径HuggingFace格式",
    )
    parser.add_argument("--output-dir", type=str, required=True, help="输出目录")
    parser.add_argument(
        "--num-train-samples", type=int, required=True, help="训练集样本数量"
    )
    parser.add_argument(
        "--num-eval-samples", type=int, required=True, help="评估集样本数量"
    )
    parser.add_argument("--batch-size", type=int, default=128, help="批大小")
    parser.add_argument(
        "--num-workers", type=int, default=2, help="DataLoader worker数量"
    )
    parser.add_argument("--max-seq-length", type=int, default=128, help="最大序列长度")
    parser.add_argument("--seed", type=int, default=42, help="随机种子")
    parser.add_argument(
        "--shard-size",
        type=int,
        default=5_000_000,
        help="分片大小样本数控制内存峰值默认500万约2.9GB/分片未压缩)",
    )
    parser.add_argument(
        "--py-style-weight",
        type=str,
        default="9,2,1",
        help="拼音风格权重(逗号分隔)",
    )
    # NOTE(review): this flag used to be parsed but ignored, with both datasets
    # hard-coding shuffle_buffer_size=100. The flag is now honoured, and the
    # default is set to 100 so a run without the flag behaves exactly as
    # before. Confirm whether the previously advertised default of 2000000
    # was the intended value.
    parser.add_argument(
        "--shuffle-buffer-size",
        type=int,
        default=100,
        help="数据集shuffle缓冲区大小",
    )
    parser.add_argument(
        "--length-weights",
        type=str,
        default="1:10,2:50,3:50,4:40,5:15,6:10,7:5,8:2",
        help="词长权重",
    )
    args = parser.parse_args()
    if args.shard_size <= 0:
        parser.error("--shard-size 必须为正整数")

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # "9,2,1" -> (9, 2, 1)
    py_style_weight = tuple(int(x) for x in args.py_style_weight.split(","))
    # "1:10,2:50" -> {1: 10, 2: 50}
    length_weights = {
        int(k): int(v)
        for k, v in (item.split(":") for item in args.length_weights.split(","))
    }

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # The datasets are allowed to iterate up to 5x the requested sample count.
    train_max_iter = args.num_train_samples * 5
    eval_max_iter = args.num_eval_samples * 5
    # 578 is the author's per-sample byte estimate used for the size hint.
    shard_mem_gb = args.shard_size * 578 / (1024**3)

    console.print("[bold cyan]=== 数据预处理 ===[/bold cyan]")
    console.print(f"训练集目标: {args.num_train_samples:,} 样本")
    console.print(f"评估集目标: {args.num_eval_samples:,} 样本")
    console.print(f"输出目录: {output_dir}")
    console.print("数据类型: int16")
    console.print(
        f"分片大小: {args.shard_size:,} 样本 (约 {shard_mem_gb:.1f} GB/分片 未压缩)"
    )
    console.print()

    num_train_workers = args.num_workers
    num_eval_workers = max(1, args.num_workers // 2)

    console.print("[bold cyan]创建训练数据集...[/bold cyan]")
    train_dataloader = _build_dataloader(
        args.train_data_path,
        num_train_workers,
        train_max_iter,
        args,
        py_style_weight,
        length_weights,
    )
    console.print("[bold cyan]创建评估数据集...[/bold cyan]")
    eval_dataloader = _build_dataloader(
        args.eval_data_path,
        num_eval_workers,
        eval_max_iter,
        args,
        py_style_weight,
        length_weights,
    )

    logger.info("开始收集训练数据...")
    train_count = collect_samples(
        train_dataloader,
        args.num_train_samples,
        output_dir,
        "train",
        args.max_seq_length,
        args.shard_size,
    )
    if train_count < args.num_train_samples:
        logger.warning(
            f"训练集样本不足: 目标 {args.num_train_samples}, 实际 {train_count}"
        )
    # Drop the training loader so its persistent workers can shut down before
    # the eval split is processed.
    del train_dataloader
    gc.collect()

    logger.info("开始收集评估数据...")
    eval_count = collect_samples(
        eval_dataloader,
        args.num_eval_samples,
        output_dir,
        "eval",
        args.max_seq_length,
        args.shard_size,
    )
    if eval_count < args.num_eval_samples:
        logger.warning(
            f"评估集样本不足: 目标 {args.num_eval_samples}, 实际 {eval_count}"
        )

    console.print("\n[bold green]=== 预处理完成 ===[/bold green]")
    console.print(f"训练集: {train_count:,} 样本")
    console.print(f"评估集: {eval_count:,} 样本")
    console.print(f"输出目录: {output_dir}")
    for split in ["train", "eval"]:
        split_dir = output_dir / split
        if split_dir.exists():
            total_size = sum(
                f.stat().st_size for f in split_dir.iterdir() if f.suffix == ".npz"
            )
            console.print(f"{split}/: {total_size / (1024**3):.2f} GB (compressed)")
def resplit_shards(input_dir: str, output_dir: str, target_size: int = 1_000_000):
    """Split oversized ``.npz`` shards into chunks of at most ``target_size`` samples.

    Every ``shard_*.npz`` in ``input_dir`` is loaded once (all fields
    decompressed into memory), cut along the first axis and written to
    ``output_dir`` as consecutively numbered ``shard_{idx:06d}.npz`` files.
    A new ``metadata.json`` is written with updated ``num_samples``,
    ``num_shards`` and ``shard_sizes``; the ``shard_size`` key is dropped
    because the chunks are no longer uniform.

    Peak memory is about one fully decompressed input shard (the original
    author estimates ~21 GB for a 20M-sample shard), so run this on a machine
    with plenty of RAM and copy the small shards to the GPU server afterwards.

    Args:
        input_dir: Directory containing ``metadata.json`` and ``shard_*.npz``.
        output_dir: Directory for the re-split shards; must differ from
            ``input_dir``.
        target_size: Maximum number of samples per output shard.

    Raises:
        ValueError: If ``target_size`` is not positive or both directories
            resolve to the same path.
        FileNotFoundError: If ``input_dir`` has no ``metadata.json``.
    """
    if target_size <= 0:
        raise ValueError(f"target_size must be a positive integer, got {target_size}")

    input_path = Path(input_dir)
    output_path = Path(output_dir)
    # Output shards reuse the input naming scheme, so writing into the input
    # directory would overwrite shards that are still being read.
    if input_path.resolve() == output_path.resolve():
        raise ValueError("output_dir must be different from input_dir")

    output_path.mkdir(parents=True, exist_ok=True)
    metadata_path = input_path / "metadata.json"
    if not metadata_path.exists():
        raise FileNotFoundError(f"metadata.json not found in {input_dir}")
    with open(metadata_path, encoding="utf-8") as f:
        metadata = json.load(f)

    shard_files = sorted(input_path.glob("shard_*.npz"))
    console = Console()
    console.print(
        f"[bold]Resplit {len(shard_files)} shards → {target_size:,} samples/chunk[/bold]"
    )

    all_shard_sizes: List[int] = []
    total_samples = 0
    for sf in tqdm(shard_files, desc="Shards"):
        # An .npz archive is read lazily: each ``data[name]`` access
        # decompresses the whole array again. Load every field exactly once
        # per shard instead of once per chunk, and close the archive right away.
        with np.load(str(sf)) as data:
            arrays: Dict[str, np.ndarray] = {name: data[name] for name in data.files}
        if not arrays:
            continue

        n_samples = next(iter(arrays.values())).shape[0]
        for start in range(0, n_samples, target_size):
            end = min(start + target_size, n_samples)
            # copy=False avoids duplicating data that is already int16.
            chunk_data = {
                name: arr[start:end].astype(np.int16, copy=False)
                for name, arr in arrays.items()
            }
            out_file = output_path / f"shard_{len(all_shard_sizes):06d}.npz"
            np.savez_compressed(out_file, **chunk_data)
            all_shard_sizes.append(end - start)
            total_samples += end - start
            del chunk_data

        # Release the decompressed shard before loading the next one.
        del arrays
        gc.collect()

    metadata["num_samples"] = total_samples
    metadata["num_shards"] = len(all_shard_sizes)
    metadata["shard_sizes"] = all_shard_sizes
    metadata.pop("shard_size", None)
    with open(output_path / "metadata.json", "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2, ensure_ascii=False)

    console.print(
        f"[green]Done: {total_samples:,} samples in {len(all_shard_sizes)} shards[/green]"
    )
def _dispatch():
    """Route the command line to the ``resplit`` sub-command or to ``main``.

    When the first command-line argument is the literal ``resplit``, the
    remaining arguments are parsed for ``resplit_shards``; any other
    invocation is handed to ``main`` unchanged.
    """
    import sys

    cli_args = sys.argv[1:]
    if cli_args[:1] != ["resplit"]:
        main()
        return

    resplit_parser = argparse.ArgumentParser(description="拆分过大的 .npz 分片")
    resplit_parser.add_argument(
        "--input-dir", required=True, help="输入目录含 metadata.json 和 shard_*.npz"
    )
    resplit_parser.add_argument("--output-dir", required=True, help="输出目录")
    resplit_parser.add_argument(
        "--target-size", type=int, default=1_000_000, help="目标分片大小(样本数)"
    )
    # Everything after the "resplit" keyword belongs to the sub-command.
    opts = resplit_parser.parse_args(cli_args[1:])
    resplit_shards(opts.input_dir, opts.output_dir, opts.target_size)
# Module-level alias for the dispatcher; presumably referenced by an external
# entry point such as a console script — TODO confirm against packaging config.
app = _dispatch

# Run the dispatcher when this module is executed directly.
if __name__ == "__main__":
    _dispatch()