feat(data-preprocess): 预处理数据预打乱以提升训练效率

This commit is contained in:
songsenand 2026-05-15 13:49:41 +08:00
parent 0862b5b8fc
commit 722912f296
8 changed files with 753 additions and 35 deletions

View File

@ -0,0 +1,229 @@
# 预处理数据预打乱方案
## 目标
用 CPU 机时预打乱数据,让训练时直接用 `shuffle=False` 顺序读取,消除跨分片缓存抖动和 CPU 利用率低的问题。
## 改动清单
### 1. `src/model/subsample.py` — 默认开启输出打乱
**函数签名变更:**
```python
def pass2_subsample(
...,
shuffle: bool = True, # 新增
seed: int = 42, # 新增
) -> Tuple[int, int]:
```
**改动点:**
- `rng = np.random.RandomState()` → `rng = np.random.RandomState(seed)`
- 新增 `shuffle_rng = np.random.RandomState(seed + 1)` 用于输出打乱
- 在两次 `np.savez_compressed` 写入前(line ~251 和 line ~273)插入:
```python
if shuffle and train_buf_count > 1:
perm = shuffle_rng.permutation(train_buf_count)
for f in FIELDS:
merged[f] = merged[f][perm]
```
- main() 中新增参数:
```python
parser.add_argument("--no-shuffle", action="store_false", dest="shuffle",
help="禁用输出分片内部打乱")
parser.add_argument("--seed", type=int, default=42,
help="随机种子(用于选择+打乱)")
```
- 调用 `pass2_subsample` 时传入 `shuffle=args.shuffle, seed=args.seed`
**metadata.json 新增字段:**
```json
"pre_shuffled": true,
"seed": 42
```
---
### 2. 新建 `src/model/shuffle_npz.py` — 平衡+打乱已有数据
处理流程:
```
Phase 1: npz → 逐字段打乱 → .npy(mmap 友好)
  输入: 19个不平衡 .npz 分片(100M样本,~10GB 压缩)
  输出: 114个临时 .npy 文件(6字段 × 19分片,~144GB 未压缩)
  内存峰值: ~10GB(单字段 int16 最大 ~5GB + permuted copy ~5GB)
  耗时: ~15-20分钟(解压+写入)
Phase 2: .npy → 平衡分配 → .npz
  输入: Phase 1 的 .npy 文件(mmap 模式)
  输出: 100个平衡 .npz 分片(每片约100万样本,~100MB/片压缩后)
  每个输出分片 = 从19个源各取比例份额 → concatenate → shuffle → save
  内存峰值: ~3GB(1个输出缓冲 + mmap pages)
  耗时: ~10-15分钟(mmap读取+压缩写入)
总计: ~30-40分钟,峰值内存 ~10GB
```
**磁盘需求:**
- 临时 .npy 文件(Phase 1→2 中间产物):~144GB
- 最终输出 .npz~10GB
- 临时文件在 Phase 2 完成后自动删除
**用法:**
```bash
python -m src.model.shuffle_npz \
--input-dir /home/songsenand/DataSet/SubPro \
--output-dir /home/songsenand/DataSet/SubPro_Shuffled \
--shard-size 1000000 \
--seed 42
```
**关键实现:**
Phase 1 — 逐字段加载+打乱:
```python
for src_idx in range(num_shards):
data = np.load(shard_path) # lazy NpzFile
n = shard_sizes[src_idx]
perm = rng.permutation(n)
for field in FIELDS:
arr = data[field].copy() # ~5GB peak (input_ids)
shuffled = arr[perm] # ~5GB temp
np.save(temp_dir / field / f"shard_{src_idx:06d}.npy", shuffled)
del arr, shuffled
gc.collect()
data.close()
```
Phase 2 — 平衡分配:
```python
# 打开所有源 mmap(读模式,零内存)
src_mmaps = [
{f: np.load(temp_dir / f / f"shard_{i:06d}.npy", mmap_mode='r') for f in FIELDS}
for i in range(num_shards)
]
for out_j in range(num_output_shards):
buffers = {f: [] for f in FIELDS}
for src_i in range(num_shards):
s = shard_sizes[src_i]
start = (out_j * s) // num_output_shards
end = ((out_j + 1) * s) // num_output_shards
if start >= end:
continue
for f in FIELDS:
chunk = src_mmaps[src_i][f][start:end].copy()
buffers[f].append(chunk)
output = {f: np.concatenate(buffers[f]) for f in FIELDS}
# 额外打乱(跨源混合)
perm = rng.permutation(len(output[FIELDS[0]]))
for f in FIELDS:
output[f] = output[f][perm]
np.savez_compressed(output_dir / f"shard_{out_j:06d}.npz", **output)
del output, buffers, perm
gc.collect()
```
**输出 metadata.json**
```json
{
"num_samples": 99998406,
"max_seq_length": 128,
"dtype": "int16",
"fields": [...],
"shard_size": 1000000,
"num_shards": 100,
"shard_sizes": [1000000, ..., 998406],
"pre_shuffled": true,
"seed": 42
}
```
**eval 目录处理:** 如果 `--input-dir/eval/` 存在,直接复制到 `--output-dir/eval/`(eval 数据量小,不需要打乱)
---
### 3. `src/model/trainer.py` — 预处理数据禁用 shuffle
**改动点(train 函数,line ~1258-1272):**
```python
if is_train_preprocessed:
train_dataset = PreProcessedDataset(train_data_path, max_cache_shards=1)
# pre_shuffled 数据不需要 DataLoader 的 RandomSampler
shuffle_train = not train_dataset.metadata.get("pre_shuffled", False)
total_steps = (len(train_dataset) // batch_size) * num_epochs
# 支持 max_iter_length 限制总步数
if max_iter_length > 0:
max_steps_per_epoch = max_iter_length // batch_size
total_steps = min(total_steps, max_steps_per_epoch * num_epochs)
train_num_workers = min(num_workers, 1)
train_dataloader = create_dataloader(
dataset=train_dataset,
batch_size=batch_size,
num_workers=train_num_workers,
pin_memory=torch.cuda.is_available(),
shuffle=shuffle_train, # 预打乱数据不 shuffle
)
```
**eval DataLoader(line ~1295-1303):**
```python
if is_eval_preprocessed:
eval_dataset = PreProcessedDataset(eval_data_path, max_cache_shards=1)
eval_dataloader = create_dataloader(
dataset=eval_dataset,
batch_size=batch_size,
num_workers=0, # eval 数据小,单进程足够
pin_memory=torch.cuda.is_available(),
shuffle=False, # eval 不需要打乱
)
```
**`create_dataloader` 函数(line ~1076-1114)** 无需改动,`shuffle` 参数已透传。
---
### 4. `src/model/preprocessed_dataset.py`
现有代码无需修改。`PreProcessedDataset` 已经可以正确处理 `shuffle=False` 的情况(PyTorch 的 `SequentialSampler` 会按 0..N-1 顺序读取)。
`metadata["pre_shuffled"]` 字段由 subsample.py 和 shuffle_npz.py 在写入 metadata.json 时添加,训练代码读取判断即可。
---
## 执行顺序
```bash
# Step 1: 打乱并平衡已有的 100M 数据集
python -m src.model.shuffle_npz \
--input-dir /home/songsenand/DataSet/SubPro \
--output-dir /home/songsenand/DataSet/SubPro_Shuffled \
--shard-size 1000000 \
--seed 42
# Step 2: 用新数据训练(数据已打乱,顺序读取即可)
uv run train-model train \
--train-data-path /home/songsenand/DataSet/SubPro_Shuffled/train \
--eval-data-path /home/songsenand/DataSet/SubPro_Shuffled/eval \
-b 16 \
-o ~/tmp \
--eval-frequency 20 \
--save-frequency 40
```
## 预期效果
| 改动前 | 改动后 |
|---|---|
| 每 batch 跨 13 个 shard | 顺序读1 个 shard 在缓存中 |
| 每 batch 数据加载 2-3 分钟 | ~0.1-0.5 秒(纯 mmap/memory) |
| CPU 利用率 10% | 正常(训练计算是瓶颈) |
| 内存 40GB+ | <20GB(每个 shard 1M 样本,约 1.4GB) |

View File

@ -351,7 +351,7 @@ def _sparse_moe_dispatch(x_flat, experts, topk_indices, topk_weights, num_expert
idx, k_idx = mask.nonzero(as_tuple=True)
if idx.numel() > 0:
w = topk_weights[idx, k_idx].unsqueeze(-1)
output.index_add_(0, idx, w * experts[e](x_flat[idx]))
output.index_add_(0, idx, (w * experts[e](x_flat[idx])).to(output.dtype))
return output
@ -411,7 +411,11 @@ class MoELayer(nn.Module):
idx, k_idx = mask.nonzero(as_tuple=True)
if idx.numel() > 0:
w = topk_weights[idx, k_idx].unsqueeze(-1)
out_flat.index_add_(0, idx, w * self.experts[e](x_flat[idx]))
out_flat.index_add_(
0,
idx,
(w * self.experts[e](x_flat[idx])).to(out_flat.dtype),
)
elif self.moe_mode == "sparse_allow_graph":
out_flat = _sparse_moe_dispatch(

View File

@ -231,7 +231,7 @@ def main():
console.print(f"[bold cyan]数据集: {len(dataset):,} 个样本[/bold cyan]")
if dataset._is_sharded:
console.print(
f" 分片数: {dataset._num_shards}, 每分片: {dataset._shard_size:,} 样本"
f" 分片数: {dataset._num_shards}, 每分片: {min(dataset._shard_sizes):,} - {max(dataset._shard_sizes):,} 样本"
)
console.print()

View File

@ -29,6 +29,9 @@
import argparse
import gc
import json
import struct
import time
import zipfile
from pathlib import Path
from typing import Dict, List
@ -132,6 +135,10 @@ def collect_samples(
actual_count = min(total, num_samples)
num_shards = shard_idx
shard_sizes = [shard_size] * (num_shards - 1)
if num_shards > 0:
shard_sizes.append(actual_count - sum(shard_sizes))
metadata = {
"num_samples": actual_count,
"max_seq_length": max_seq_length,
@ -139,6 +146,7 @@ def collect_samples(
"fields": FIELDS,
"shard_size": shard_size,
"num_shards": num_shards,
"shard_sizes": shard_sizes,
}
with open(split_dir / "metadata.json", "w", encoding="utf-8") as fp:
json.dump(metadata, fp, indent=2, ensure_ascii=False)
@ -329,7 +337,97 @@ def main():
console.print(f"{split}/: {total_size / (1024**3):.2f} GB (compressed)")
app = main
def resplit_shards(input_dir: str, output_dir: str, target_size: int = 1_000_000):
    """
    Split oversized .npz shards into chunks of at most ``target_size`` samples.

    Reads ``metadata.json`` and every ``shard_*.npz`` in ``input_dir`` (sorted
    by file name), writes consecutively numbered ``shard_NNNNNN.npz`` files to
    ``output_dir`` and stores an updated ``metadata.json`` next to them, with
    ``num_samples``, ``num_shards`` and ``shard_sizes`` recomputed and the
    ``shard_size`` key removed. Every field is cast to int16 on output.

    Peak memory is one source shard with all of its fields decompressed plus
    one output chunk (original author's estimate: ~21GB for a 20M-sample
    shard plus ~1.5GB per chunk), so run this on a machine with plenty of RAM
    and copy the small shards to the GPU server afterwards.

    Args:
        input_dir: Directory holding ``metadata.json`` and ``shard_*.npz``.
        output_dir: Destination directory; created if missing.
        target_size: Maximum number of samples per output shard.

    Raises:
        ValueError: If ``target_size`` is not positive.
        FileNotFoundError: If ``input_dir`` has no ``metadata.json``.
    """
    import gc

    from tqdm import tqdm

    if target_size <= 0:
        raise ValueError(f"target_size must be positive, got {target_size}")

    input_path = Path(input_dir)
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    metadata_path = input_path / "metadata.json"
    if not metadata_path.exists():
        raise FileNotFoundError(f"metadata.json not found in {input_dir}")
    # The writers in this project emit UTF-8 with ensure_ascii=False, so read
    # it back with the same encoding instead of the platform default.
    with open(metadata_path, encoding="utf-8") as f:
        metadata = json.load(f)

    shard_files = sorted(input_path.glob("shard_*.npz"))
    console = Console()
    console.print(
        f"[bold]Resplit {len(shard_files)} shards → {target_size:,} samples/chunk[/bold]"
    )

    global_shard_idx = 0
    all_shard_sizes: List[int] = []
    total_samples = 0

    for sf in tqdm(shard_files, desc="Shards"):
        # NpzFile decompresses a whole member on every __getitem__, so each
        # field is materialised exactly once per shard rather than once per
        # output chunk.
        with np.load(str(sf)) as data:
            arrays: Dict[str, np.ndarray] = {name: data[name] for name in data.files}
        if not arrays:
            # An archive without any member has nothing to split.
            continue

        n_samples = next(iter(arrays.values())).shape[0]
        for start in range(0, n_samples, target_size):
            end = min(start + target_size, n_samples)
            # astype() already returns a fresh array, no extra copy needed.
            chunk_data: Dict[str, np.ndarray] = {
                name: arr[start:end].astype(np.int16) for name, arr in arrays.items()
            }
            out_file = output_path / f"shard_{global_shard_idx:06d}.npz"
            np.savez_compressed(out_file, **chunk_data)

            chunk_size = end - start
            all_shard_sizes.append(chunk_size)
            total_samples += chunk_size
            del chunk_data
            gc.collect()
            global_shard_idx += 1

        # Release the decompressed shard before loading the next one.
        del arrays
        gc.collect()

    metadata["num_samples"] = total_samples
    metadata["num_shards"] = global_shard_idx
    metadata["shard_sizes"] = all_shard_sizes
    # Output shards no longer share the uniform size the old key described.
    metadata.pop("shard_size", None)
    with open(output_path / "metadata.json", "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2, ensure_ascii=False)

    console.print(
        f"[green]Done: {total_samples:,} samples in {global_shard_idx} shards[/green]"
    )
def _dispatch():
    """
    Command-line entry point that routes to the right sub-command.

    When the first CLI argument is the literal ``resplit``, the remaining
    arguments are parsed and handed to ``resplit_shards``; any other
    invocation falls through to ``main()``.
    """
    import sys

    # Guard clause: anything but "resplit" is the regular preprocessing run.
    if sys.argv[1:2] != ["resplit"]:
        main()
        return

    import argparse as ap

    resplit_parser = ap.ArgumentParser(description="拆分过大的 .npz 分片")
    resplit_parser.add_argument(
        "--input-dir",
        required=True,
        help="输入目录含 metadata.json 和 shard_*.npz",
    )
    resplit_parser.add_argument("--output-dir", required=True, help="输出目录")
    resplit_parser.add_argument(
        "--target-size",
        type=int,
        default=1_000_000,
        help="目标分片大小(样本数)",
    )
    # Skip the program name and the "resplit" token itself.
    opts = resplit_parser.parse_args(sys.argv[2:])
    resplit_shards(opts.input_dir, opts.output_dir, opts.target_size)
app = _dispatch
if __name__ == "__main__":
main()
_dispatch()

View File

@ -8,6 +8,7 @@
GPU 服务器仅需存放压缩后的 .npz 文件无需解压到硬盘
"""
import time
import gc
import json
import struct
@ -47,6 +48,13 @@ def _read_shard_size(npz_path: Path) -> int:
return header["shape"][0]
def _build_offsets(sizes: List[int]) -> List[int]:
offsets = [0]
for s in sizes:
offsets.append(offsets[-1] + s)
return offsets
def is_preprocessed_data(path: str) -> bool:
"""判断路径是否为预处理数据目录"""
p = Path(path)
@ -94,38 +102,48 @@ class PreProcessedDataset(Dataset):
"""
def __init__(self, data_dir: str, max_cache_shards: int = 1):
t_start = time.perf_counter()
self.data_dir = Path(data_dir)
with open(self.data_dir / "metadata.json", "r", encoding="utf-8") as f:
self.metadata = json.load(f)
self.max_seq_length = self.metadata["max_seq_length"]
self._shard_size: Optional[int] = self.metadata.get("shard_size")
if self._shard_size is not None:
shard_files = sorted(self.data_dir.glob("shard_*.npz"))
shard_files = sorted(self.data_dir.glob("shard_*.npz"))
if shard_files:
self._num_shards = len(shard_files)
if self._num_shards == 0:
raise FileNotFoundError(
f"No shard_*.npz files found in {self.data_dir}"
if (
"shard_sizes" in self.metadata
and len(self.metadata["shard_sizes"]) == self._num_shards
):
self._shard_sizes = self.metadata["shard_sizes"]
logger.info(
f"Using shard_sizes from metadata.json ({self._num_shards} shards)"
)
self._shard_sizes: List[int] = [_read_shard_size(sf) for sf in shard_files]
self._shard_offsets = [0]
for s in self._shard_sizes:
self._shard_offsets.append(self._shard_offsets[-1] + s)
else:
t_scan = time.perf_counter()
self._shard_sizes = [_read_shard_size(sf) for sf in shard_files]
logger.info(
f"Scanned {self._num_shards} shard headers in "
f"{time.perf_counter() - t_scan:.2f}s"
)
self._shard_offsets = _build_offsets(self._shard_sizes)
self.num_samples = self._shard_offsets[-1]
self._is_sharded = True
self._cache = _ShardCache(max_size=max_cache_shards)
logger.info(
f"Loaded sharded dataset: {self.num_samples:,} samples, "
f"{self._num_shards} shards"
f"{self._num_shards} shards, "
f"init in {time.perf_counter() - t_start:.2f}s"
)
else:
self.num_samples = self.metadata["num_samples"]
self._is_sharded = False
self._load_single_files()
logger.info(
f"Loaded single-file dataset: {self.num_samples:,} samples (mmap)"
f"Loaded single-file dataset: {self.num_samples:,} samples (mmap), "
f"init in {time.perf_counter() - t_start:.2f}s"
)
def _load_single_files(self):
@ -157,6 +175,10 @@ class PreProcessedDataset(Dataset):
f"Index {idx} out of range for dataset with {self.num_samples} samples"
)
if not hasattr(self, "_first_access_logged"):
self._first_access_logged = True
logger.info("First __getitem__ call (initial shard load may be slow)")
if self._is_sharded:
lo, hi = 0, len(self._shard_offsets) - 1
while lo < hi:
@ -168,7 +190,7 @@ class PreProcessedDataset(Dataset):
shard_idx = lo - 1
local_idx = idx - self._shard_offsets[shard_idx]
shard_data = self._cache.get(shard_idx, self._load_shard)
return {
result = {
"input_ids": torch.from_numpy(
shard_data["input_ids"][local_idx].astype(np.int64)
),
@ -189,7 +211,7 @@ class PreProcessedDataset(Dataset):
),
}
else:
return {
result = {
"input_ids": torch.from_numpy(self.input_ids[idx].astype(np.int64)),
"token_type_ids": torch.from_numpy(
self.token_type_ids[idx].astype(np.int64)
@ -203,6 +225,7 @@ class PreProcessedDataset(Dataset):
),
"pinyin_ids": torch.from_numpy(self.pinyin_ids[idx].astype(np.int64)),
}
return result
def preprocessed_collate_fn(batch):

318
src/model/shuffle_npz.py Normal file
View File

@ -0,0 +1,318 @@
#!/usr/bin/env python3
"""
打乱并平衡预处理 .npz 分片
两阶段处理
Phase 1: 逐分片内部打乱 → 写入临时 .npy(mmap 友好)
Phase 2: 从临时 .npy 按比例分配到平衡的输出 .npz 分片
用法
python -m src.model.shuffle_npz \
--input-dir /path/to/subsampled \
--output-dir /path/to/shuffled \
--shard-size 1000000 \
--seed 42
"""
import argparse
import gc
import json
import shutil
from pathlib import Path
from typing import Dict, List
import numpy as np
from loguru import logger
from rich.console import Console
from tqdm import tqdm
FIELDS = [
"input_ids",
"token_type_ids",
"attention_mask",
"labels",
"history_slot_ids",
"pinyin_ids",
]
def _phase1_shuffle_to_npy(
    input_dir: Path,
    split: str,
    temp_dir: Path,
    shard_sizes: List[int],
    seed: int,
):
    """
    Phase 1: shuffle every source shard internally and dump it as .npy files.

    For each source shard one permutation is drawn and applied to every field,
    so rows stay aligned across fields. Each (field, shard) pair is written to
    its own file::

        temp_dir/<field>/shard_<idx>.npy

    Only one field of one shard is held in memory at a time, together with its
    permuted copy (original author's estimate: ~5GB + ~5GB for the largest
    field).

    Args:
        input_dir: Root directory containing ``<split>/metadata.json`` and
            ``<split>/shard_NNNNNN.npz``.
        split: Sub-directory name of the split to process.
        temp_dir: Directory receiving the per-field .npy files.
        shard_sizes: Number of samples in each source shard, in shard order.
        seed: Seed of the RNG that draws the per-shard permutations.

    Raises:
        ValueError: If ``shard_sizes`` disagrees with ``num_shards`` in the
            metadata or with the actual length of a stored array. Without
            this check a too-small recorded size would silently drop samples.
    """
    metadata_path = input_dir / split / "metadata.json"
    with open(metadata_path, encoding="utf-8") as f:
        metadata = json.load(f)

    num_shards = metadata["num_shards"]
    if num_shards != len(shard_sizes):
        raise ValueError(
            f"metadata num_shards={num_shards} does not match "
            f"len(shard_sizes)={len(shard_sizes)} for split '{split}'"
        )

    rng = np.random.RandomState(seed)

    for field in FIELDS:
        (temp_dir / field).mkdir(parents=True, exist_ok=True)

    # Context managers close the progress bar and the archive on errors too.
    with tqdm(
        total=num_shards,
        desc=f"Phase 1: shuffling {split}",
        unit="shard",
    ) as pbar:
        for src_idx, n in enumerate(shard_sizes):
            shard_path = input_dir / split / f"shard_{src_idx:06d}.npz"
            # One permutation per shard, reused for every field.
            perm = rng.permutation(n)

            with np.load(shard_path) as data:
                for field in FIELDS:
                    # NpzFile returns a freshly decompressed array on each
                    # access, so an additional .copy() would only double the
                    # peak memory.
                    arr = data[field]
                    if arr.shape[0] != n:
                        raise ValueError(
                            f"{shard_path.name}: field '{field}' has "
                            f"{arr.shape[0]} samples, metadata says {n}"
                        )
                    shuffled = arr[perm]
                    np.save(temp_dir / field / f"shard_{src_idx:06d}.npy", shuffled)
                    del arr, shuffled
                    gc.collect()

            pbar.update(1)
def _phase2_rebalance_to_npz(
    temp_dir: Path,
    output_dir: Path,
    split: str,
    shard_sizes: List[int],
    target_shard_size: int,
    max_seq_length: int,
    seed: int,
) -> List[int]:
    """
    Phase 2: redistribute the phase-1 .npy files into balanced .npz shards.

    The number of output shards is ``ceil(total / target_shard_size)``. Output
    shard ``j`` takes rows ``[j*s // N, (j+1)*s // N)`` from every source shard
    of size ``s`` (``N`` = number of output shards), concatenates them,
    shuffles the result once more to mix the sources, and saves it compressed
    to ``output_dir/<split>/shard_NNNNNN.npz``.

    Args:
        temp_dir: Directory with ``<field>/shard_NNNNNN.npy`` from phase 1.
        output_dir: Root output directory; ``<split>/`` is created inside it.
        split: Name of the split sub-directory.
        shard_sizes: Number of samples in each source shard, in shard order.
        target_shard_size: Desired number of samples per output shard.
        max_seq_length: Currently unused; kept for interface compatibility.
        seed: Base seed; the mixing RNG is seeded with ``seed + 2``.

    Returns:
        The sample count of every shard actually written, in file order.

    Raises:
        ValueError: If ``target_shard_size`` is not positive.
    """
    if target_shard_size <= 0:
        raise ValueError(
            f"target_shard_size must be positive, got {target_shard_size}"
        )

    output_split_dir = output_dir / split
    output_split_dir.mkdir(parents=True, exist_ok=True)

    total_samples = sum(shard_sizes)
    num_output_shards = (total_samples + target_shard_size - 1) // target_shard_size

    rng = np.random.RandomState(seed + 2)

    logger.info(
        f"Phase 2: distributing {total_samples:,} samples into "
        f"{num_output_shards} output shards (~{target_shard_size:,} each)"
    )

    # Memory-map every source file read-only; rows are copied out on demand.
    src_mmaps: List[Dict[str, np.ndarray]] = []
    for src_idx in range(len(shard_sizes)):
        shard_mmap = {}
        for field in FIELDS:
            npy_path = temp_dir / field / f"shard_{src_idx:06d}.npy"
            shard_mmap[field] = np.load(npy_path, mmap_mode="r")
        src_mmaps.append(shard_mmap)

    output_shard_sizes: List[int] = []

    with tqdm(
        total=num_output_shards,
        desc=f"Phase 2: writing {split}",
        unit="shard",
    ) as pbar:
        for out_j in range(num_output_shards):
            buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}

            for src_i, s in enumerate(shard_sizes):
                start = (out_j * s) // num_output_shards
                end = ((out_j + 1) * s) // num_output_shards
                if start >= end:
                    continue
                for field in FIELDS:
                    chunk = src_mmaps[src_i][field][start:end].copy()
                    buffers[field].append(chunk)

            if not buffers[FIELDS[0]]:
                # With many tiny source shards the integer split can leave a
                # slot without any rows; writing nothing keeps the file
                # numbering contiguous instead of failing on an empty list.
                pbar.update(1)
                continue

            output = {
                field: chunks[0] if len(chunks) == 1 else np.concatenate(chunks)
                for field, chunks in buffers.items()
            }

            out_count = len(output[FIELDS[0]])
            if out_count > 1:
                # Same permutation for every field keeps rows aligned.
                perm = rng.permutation(out_count)
                for field in FIELDS:
                    output[field] = output[field][perm]

            # Number files by what has been written so far, so skipped slots
            # never create gaps in the shard_NNNNNN sequence.
            shard_no = len(output_shard_sizes)
            np.savez_compressed(
                output_split_dir / f"shard_{shard_no:06d}.npz", **output
            )
            output_shard_sizes.append(out_count)

            del output, buffers
            gc.collect()
            pbar.update(1)

    # Drop the mmap references so the temporary files can be removed.
    src_mmaps.clear()
    gc.collect()

    return output_shard_sizes
def _copy_eval(input_dir: Path, output_dir: Path):
"""复制 eval 目录(数据量小,无需打乱)。"""
src_eval = input_dir / "eval"
dst_eval = output_dir / "eval"
if not src_eval.exists():
return
logger.info(f"Copying eval data from {src_eval}")
if dst_eval.exists():
shutil.rmtree(dst_eval)
shutil.copytree(src_eval, dst_eval)
def main():
    """
    CLI entry point: shuffle and rebalance a preprocessed dataset.

    Reads ``<input-dir>/train/metadata.json``, runs phase 1 (per-shard shuffle
    into temporary .npy files) and phase 2 (rebalance into .npz shards), writes
    ``<output-dir>/train/metadata.json`` with ``pre_shuffled`` set, copies the
    eval split, and prints a summary. The temporary files are removed even if
    a phase fails.

    The temporary files go into a fresh, uniquely named directory below
    ``--temp-dir`` (default: the output directory), so the tool does not
    depend on a machine-specific path and never deletes a directory it did
    not create.
    """
    import tempfile

    console = Console()

    parser = argparse.ArgumentParser(description="打乱并平衡预处理 .npz 分片")
    parser.add_argument(
        "--input-dir", type=str, required=True, help="输入目录(含 train/ 和 eval/"
    )
    parser.add_argument("--output-dir", type=str, required=True, help="输出目录")
    parser.add_argument(
        "--shard-size",
        type=int,
        default=1_000_000,
        help="输出分片大小(样本数),默认 100 万",
    )
    parser.add_argument("--seed", type=int, default=42, help="随机种子")
    parser.add_argument(
        "--temp-dir",
        type=str,
        default=None,
        help="临时 .npy 文件的父目录(默认使用 --output-dir)",
    )
    args = parser.parse_args()

    if args.shard_size <= 0:
        parser.error("--shard-size must be a positive integer")

    input_dir = Path(args.input_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    train_meta_path = input_dir / "train" / "metadata.json"
    if not train_meta_path.exists():
        console.print(f"[red]错误: {train_meta_path} 不存在[/red]")
        return

    with open(train_meta_path, encoding="utf-8") as f:
        train_meta = json.load(f)

    max_seq_length = train_meta["max_seq_length"]
    if "shard_sizes" not in train_meta:
        # Metadata written before per-shard sizes were recorded cannot be
        # rebalanced; report it instead of failing with a bare KeyError.
        console.print(
            f"[red]错误: {train_meta_path} 缺少 shard_sizes 字段[/red]"
        )
        return
    shard_sizes: List[int] = train_meta["shard_sizes"]
    total_samples = sum(shard_sizes)

    console.print("[bold cyan]=== 打乱并平衡预处理数据 ===[/bold cyan]")
    console.print(f"输入目录: {input_dir}")
    console.print(f"输出目录: {output_dir}")
    console.print(f"总样本数: {total_samples:,}")
    console.print(f"源分片数: {len(shard_sizes)}")
    console.print(
        f"目标分片大小: {args.shard_size:,} "
        f"(约 {(total_samples + args.shard_size - 1) // args.shard_size} 个分片)"
    )
    console.print(f"随机种子: {args.seed}")
    console.print()

    # Unique scratch directory: safe for concurrent runs and portable across
    # machines.
    temp_root = Path(args.temp_dir) if args.temp_dir else output_dir
    temp_root.mkdir(parents=True, exist_ok=True)
    temp_dir = Path(tempfile.mkdtemp(prefix="shuffle_npz_temp_", dir=temp_root))

    temp_train_dir = temp_dir / "train"
    temp_train_dir.mkdir(parents=True, exist_ok=True)

    try:
        # ── Phase 1 ──
        console.print("[bold]Phase 1: 逐分片打乱 → 临时 .npy[/bold]")
        _phase1_shuffle_to_npy(
            input_dir=input_dir,
            split="train",
            temp_dir=temp_train_dir,
            shard_sizes=shard_sizes,
            seed=args.seed,
        )

        # ── Phase 2 ──
        console.print("[bold]Phase 2: 按比例分配到平衡输出分片[/bold]")
        output_shard_sizes = _phase2_rebalance_to_npz(
            temp_dir=temp_train_dir,
            output_dir=output_dir,
            split="train",
            shard_sizes=shard_sizes,
            target_shard_size=args.shard_size,
            max_seq_length=max_seq_length,
            seed=args.seed,
        )

        # ── Write metadata ──
        train_output_meta = {
            "num_samples": sum(output_shard_sizes),
            "max_seq_length": max_seq_length,
            "dtype": "int16",
            "fields": FIELDS,
            "shard_size": args.shard_size,
            "num_shards": len(output_shard_sizes),
            "shard_sizes": output_shard_sizes,
            "pre_shuffled": True,
            "seed": args.seed,
        }
        out_train_dir = output_dir / "train"
        with open(out_train_dir / "metadata.json", "w", encoding="utf-8") as f:
            json.dump(train_output_meta, f, indent=2, ensure_ascii=False)

        # ── Copy eval ──
        _copy_eval(input_dir, output_dir)

    finally:
        # Remove the scratch files whether or not the run succeeded.
        console.print("[dim]清理临时文件...[/dim]")
        if temp_dir.exists():
            shutil.rmtree(temp_dir)

    # ── Summary ──
    console.print()
    console.print("[bold green]=== 完成 ===[/bold green]")
    console.print(
        f"train/: {sum(output_shard_sizes):,} 样本, "
        f"{len(output_shard_sizes)} 分片, "
        f"pre_shuffled=True"
    )

    for sdir_name in ["train", "eval"]:
        sdir = output_dir / sdir_name
        if sdir.exists():
            total_size = sum(
                f.stat().st_size for f in sdir.iterdir() if f.suffix == ".npz"
            )
            meta_path = sdir / "metadata.json"
            if meta_path.exists():
                with open(meta_path, encoding="utf-8") as mf:
                    meta = json.load(mf)
            else:
                meta = {}
            console.print(
                f" {sdir_name}/: {total_size / (1024**3):.2f} GB, "
                f"{meta.get('num_shards', '?')} shards"
            )
if __name__ == "__main__":
main()

View File

@ -128,6 +128,8 @@ def pass2_subsample(
eval_map: Dict[int, List[int]],
shard_sizes: List[int],
shard_size: int = 5_000_000,
shuffle: bool = True,
seed: int = 42,
) -> Tuple[int, int]:
"""
第2遍扫描读取全部字段抽取评估样本 + 按精确保留配额子采样训练样本
@ -147,12 +149,19 @@ def pass2_subsample(
output_train_dir.mkdir(parents=True, exist_ok=True)
output_eval_dir.mkdir(parents=True, exist_ok=True)
rng = np.random.RandomState()
# 清理旧输出分片(避免残留文件污染)
for old_shard in sorted(output_train_dir.glob("shard_*.npz")):
old_shard.unlink()
logger.debug(f"Removed old shard: {old_shard.name}")
rng = np.random.RandomState(seed)
shuffle_rng = np.random.RandomState(seed + 1)
remaining = dict(quotas) # {label_id: remaining_to_keep}
train_buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}
train_buf_count = 0
train_shard_idx = 0
train_shard_sizes: List[int] = []
total_train_kept = 0
eval_buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}
@ -247,9 +256,14 @@ def pass2_subsample(
merged = {}
for f in FIELDS:
merged[f] = np.concatenate(train_buffers[f], axis=0)
if shuffle and train_buf_count > 1:
perm = shuffle_rng.permutation(train_buf_count)
for f in FIELDS:
merged[f] = merged[f][perm]
np.savez_compressed(
output_train_dir / f"shard_{train_shard_idx:06d}.npz", **merged
)
train_shard_sizes.append(train_buf_count)
logger.debug(
f"Saved train shard {train_shard_idx}: {train_buf_count} samples"
)
@ -268,9 +282,14 @@ def pass2_subsample(
merged = {}
for f in FIELDS:
merged[f] = np.concatenate(train_buffers[f], axis=0)
if shuffle and train_buf_count > 1:
perm = shuffle_rng.permutation(train_buf_count)
for f in FIELDS:
merged[f] = merged[f][perm]
np.savez_compressed(
output_train_dir / f"shard_{train_shard_idx:06d}.npz", **merged
)
train_shard_sizes.append(train_buf_count)
logger.debug(f"Saved train shard {train_shard_idx}: {train_buf_count} samples")
train_shard_idx += 1
@ -291,6 +310,9 @@ def pass2_subsample(
"fields": FIELDS,
"shard_size": shard_size,
"num_shards": train_shard_idx,
"shard_sizes": train_shard_sizes,
"pre_shuffled": shuffle,
"seed": seed,
}
with open(output_train_dir / "metadata.json", "w", encoding="utf-8") as f:
json.dump(train_metadata, f, indent=2, ensure_ascii=False)
@ -334,6 +356,15 @@ def main():
help="输出分片大小(样本数)",
)
parser.add_argument("--num-eval", type=int, default=2560, help="评估集样本数")
parser.add_argument(
"--no-shuffle",
action="store_false",
dest="shuffle",
help="禁用输出分片内部打乱",
)
parser.add_argument(
"--seed", type=int, default=42, help="随机种子(用于标签选择 + 输出打乱)"
)
args = parser.parse_args()
input_dir = Path(args.input_dir)
@ -346,6 +377,8 @@ def main():
console.print(f"每 ID 封顶: {args.cap_per_label:,}")
console.print(f"目标训练集: {args.target_total:,}")
console.print(f"评估集: {args.num_eval}")
console.print(f"输出分片打乱: {'' if args.shuffle else ''}")
console.print(f"随机种子: {args.seed}")
console.print()
# ── 第 1 遍:统计 ──
@ -382,8 +415,8 @@ def main():
console.print()
# ── 抽取评估集位置 ──
rng = np.random.RandomState()
eval_positions = rng.choice(total_samples, size=args.num_eval, replace=False)
eval_rng = np.random.RandomState(args.seed + 100)
eval_positions = eval_rng.choice(total_samples, size=args.num_eval, replace=False)
eval_positions.sort()
eval_map = _global_to_shard(eval_positions, shard_sizes)
console.print(f"评估集: {args.num_eval} 个位置已分配到 {len(eval_map)} 个分片中")
@ -399,6 +432,8 @@ def main():
eval_map,
shard_sizes,
args.shard_size,
shuffle=args.shuffle,
seed=args.seed,
)
# ── 输出总结 ──

View File

@ -632,7 +632,6 @@ class Trainer:
def _create_progress(self) -> Progress:
return Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
@ -725,11 +724,7 @@ class Trainer:
progress.reset(
batch_task,
total=steps_per_epoch,
description=f"[green]Batch Epoch {epoch + 1}/{self.num_epochs}",
)
progress.update(
epoch_task,
description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs}",
description=f"[green]Epoch {epoch + 1} Step 0/{steps_per_epoch}",
)
epoch_step = 0
@ -762,7 +757,8 @@ class Trainer:
progress.update(
batch_task,
advance=1,
description=f"[green]Batch Epoch {epoch + 1}/{self.num_epochs}"
description=f"[green]Epoch {epoch + 1} "
f"Step {epoch_step + 1}/{steps_per_epoch}"
f" | Loss: {loss:.4f}"
f" | LR: {current_lr:.2e}",
)
@ -816,7 +812,7 @@ class Trainer:
f"Eval Acc: {eval_metrics['eval_accuracy']:.4f}"
)
progress.console.log(log_text)
progress.log(log_text)
# 重置累积指标
accumulated_loss = 0.0
@ -1260,16 +1256,31 @@ def train(
is_eval_preprocessed = is_preprocessed_data(eval_data_path)
if is_train_preprocessed:
train_dataset = PreProcessedDataset(train_data_path, max_cache_shards=1)
total_steps = (len(train_dataset) // batch_size) * num_epochs
train_dataset = PreProcessedDataset(train_data_path, max_cache_shards=2)
pre_shuffled = train_dataset.metadata.get("pre_shuffled", False)
# 预打乱数据不需要 DataLoader 的 RandomSampler避免跨分片解压抖动
shuffle_train = not pre_shuffled
if max_iter_length > 0:
max_samples_per_epoch = max_iter_length
capped_samples = min(len(train_dataset), max_samples_per_epoch)
else:
capped_samples = len(train_dataset)
total_steps = (capped_samples // batch_size) * num_epochs
train_num_workers = min(num_workers, 1)
logger.info(
f"Preprocessed dataset: {len(train_dataset):,} samples, "
f"shuffle={shuffle_train}, pre_shuffled={pre_shuffled}, "
f"workers={train_num_workers}, steps={total_steps:,}"
)
train_dataloader = create_dataloader(
dataset=train_dataset,
batch_size=batch_size,
num_workers=num_workers,
num_workers=train_num_workers,
pin_memory=torch.cuda.is_available(),
shuffle=True,
shuffle=shuffle_train,
)
config_table.add_row("数据", "训练数据类型", "预处理数据")
config_table.add_row("数据", "预打乱", str(pre_shuffled))
else:
train_dataset = PinyinInputDataset(
data_path=train_data_path,
@ -1296,7 +1307,7 @@ def train(
eval_dataloader = create_dataloader(
dataset=eval_dataset,
batch_size=batch_size,
num_workers=2,
num_workers=0,
pin_memory=torch.cuda.is_available(),
shuffle=False,
)