feat(data-preprocess): 预处理数据预打乱以提升训练效率
This commit is contained in:
parent
0862b5b8fc
commit
722912f296
|
|
@ -0,0 +1,229 @@
|
|||
# 预处理数据预打乱方案
|
||||
|
||||
## 目标
|
||||
|
||||
用 CPU 机时预打乱数据,让训练时直接用 `shuffle=False` 顺序读取,消除跨分片缓存抖动和 CPU 利用率低的问题。
|
||||
|
||||
## 改动清单
|
||||
|
||||
### 1. `src/model/subsample.py` — 默认开启输出打乱
|
||||
|
||||
**函数签名变更:**
|
||||
```python
|
||||
def pass2_subsample(
|
||||
...,
|
||||
shuffle: bool = True, # 新增
|
||||
seed: int = 42, # 新增
|
||||
) -> Tuple[int, int]:
|
||||
```
|
||||
|
||||
**改动点:**
|
||||
- `rng = np.random.RandomState()` → `rng = np.random.RandomState(seed)`
|
||||
- 新增 `shuffle_rng = np.random.RandomState(seed + 1)` 用于输出打乱
|
||||
- 在两次 `np.savez_compressed` 写入前(line ~251 和 line ~273),插入:
|
||||
```python
|
||||
if shuffle and train_buf_count > 1:
|
||||
perm = shuffle_rng.permutation(train_buf_count)
|
||||
for f in FIELDS:
|
||||
merged[f] = merged[f][perm]
|
||||
```
|
||||
- main() 中新增参数:
|
||||
```python
|
||||
parser.add_argument("--no-shuffle", action="store_false", dest="shuffle",
|
||||
help="禁用输出分片内部打乱")
|
||||
parser.add_argument("--seed", type=int, default=42,
|
||||
help="随机种子(用于选择+打乱)")
|
||||
```
|
||||
- 调用 `pass2_subsample` 时传入 `shuffle=args.shuffle, seed=args.seed`
|
||||
|
||||
**metadata.json 新增字段:**
|
||||
```json
|
||||
"pre_shuffled": true,
|
||||
"seed": 42
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 2. 新建 `src/model/shuffle_npz.py` — 平衡+打乱已有数据
|
||||
|
||||
处理流程:
|
||||
|
||||
```
|
||||
Phase 1: npz → 逐字段打乱 → .npy(mmap 友好)
|
||||
输入: 19个不平衡 .npz 分片(100M样本,~10GB 压缩)
|
||||
输出: 114个临时 .npy 文件(6字段 × 19分片,~144GB 未压缩)
|
||||
内存峰值: ~10GB(单字段 int16 最大 ~5GB + permuted copy ~5GB)
|
||||
耗时: ~15-20分钟(解压+写入)
|
||||
|
||||
Phase 2: .npy → 平衡分配 → .npz
|
||||
输入: Phase 1 的 .npy 文件(mmap 模式)
|
||||
输出: 100个平衡 .npz 分片(每片约100万样本,各分片大小近似相等,~100MB/片压缩后)
|
||||
每个输出分片 = 从19个源各取比例份额 → concatenate → shuffle → save
|
||||
内存峰值: ~3GB(1个输出缓冲 + mmap pages)
|
||||
耗时: ~10-15分钟(mmap读取+压缩写入)
|
||||
|
||||
总计: ~25-35分钟(两阶段之和),峰值内存 ~10GB
|
||||
```
|
||||
|
||||
**磁盘需求:**
|
||||
- 临时 .npy 文件(Phase 1→2 中间产物):~144GB
|
||||
- 最终输出 .npz:~10GB
|
||||
- 临时文件在 Phase 2 完成后自动删除
|
||||
|
||||
**用法:**
|
||||
```bash
|
||||
python -m src.model.shuffle_npz \
|
||||
--input-dir /home/songsenand/DataSet/SubPro \
|
||||
--output-dir /home/songsenand/DataSet/SubPro_Shuffled \
|
||||
--shard-size 1000000 \
|
||||
--seed 42
|
||||
```
|
||||
|
||||
**关键实现:**
|
||||
|
||||
Phase 1 — 逐字段加载+打乱:
|
||||
```python
|
||||
for src_idx in range(num_shards):
|
||||
data = np.load(shard_path) # lazy NpzFile
|
||||
n = shard_sizes[src_idx]
|
||||
perm = rng.permutation(n)
|
||||
|
||||
for field in FIELDS:
|
||||
arr = data[field].copy() # ~5GB peak (input_ids)
|
||||
shuffled = arr[perm] # ~5GB temp
|
||||
np.save(temp_dir / field / f"shard_{src_idx:06d}.npy", shuffled)
|
||||
del arr, shuffled
|
||||
gc.collect()
|
||||
|
||||
data.close()
|
||||
```
|
||||
|
||||
Phase 2 — 平衡分配:
|
||||
```python
|
||||
# 打开所有源 mmap(读模式,零内存)
|
||||
src_mmaps = [
|
||||
{f: np.load(temp_dir / f / f"shard_{i:06d}.npy", mmap_mode='r') for f in FIELDS}
|
||||
for i in range(num_shards)
|
||||
]
|
||||
|
||||
for out_j in range(num_output_shards):
|
||||
buffers = {f: [] for f in FIELDS}
|
||||
|
||||
for src_i in range(num_shards):
|
||||
s = shard_sizes[src_i]
|
||||
start = (out_j * s) // num_output_shards
|
||||
end = ((out_j + 1) * s) // num_output_shards
|
||||
if start >= end:
|
||||
continue
|
||||
for f in FIELDS:
|
||||
chunk = src_mmaps[src_i][f][start:end].copy()
|
||||
buffers[f].append(chunk)
|
||||
|
||||
output = {f: np.concatenate(buffers[f]) for f in FIELDS}
|
||||
# 额外打乱(跨源混合)
|
||||
perm = rng.permutation(len(output[FIELDS[0]]))
|
||||
for f in FIELDS:
|
||||
output[f] = output[f][perm]
|
||||
|
||||
np.savez_compressed(output_dir / f"shard_{out_j:06d}.npz", **output)
|
||||
del output, buffers, perm
|
||||
gc.collect()
|
||||
```
|
||||
|
||||
**输出 metadata.json:**
|
||||
```json
|
||||
{
|
||||
"num_samples": 99998406,
|
||||
"max_seq_length": 128,
|
||||
"dtype": "int16",
|
||||
"fields": [...],
|
||||
"shard_size": 1000000,
|
||||
"num_shards": 100,
|
||||
"shard_sizes": [1000000, ..., 998406],
|
||||
"pre_shuffled": true,
|
||||
"seed": 42
|
||||
}
|
||||
```
|
||||
|
||||
**eval 目录处理:** 如果 `--input-dir/eval/` 存在,直接复制到 `--output-dir/eval/`(eval 数据量小,不需要打乱)
|
||||
|
||||
---
|
||||
|
||||
### 3. `src/model/trainer.py` — 预处理数据禁用 shuffle
|
||||
|
||||
**改动点(train 函数,line ~1258-1272):**
|
||||
|
||||
```python
|
||||
if is_train_preprocessed:
|
||||
train_dataset = PreProcessedDataset(train_data_path, max_cache_shards=1)
|
||||
# pre_shuffled 数据不需要 DataLoader 的 RandomSampler
|
||||
shuffle_train = not train_dataset.metadata.get("pre_shuffled", False)
|
||||
total_steps = (len(train_dataset) // batch_size) * num_epochs
|
||||
# 支持 max_iter_length 限制总步数
|
||||
if max_iter_length > 0:
|
||||
max_steps_per_epoch = max_iter_length // batch_size
|
||||
total_steps = min(total_steps, max_steps_per_epoch * num_epochs)
|
||||
train_num_workers = min(num_workers, 1)
|
||||
train_dataloader = create_dataloader(
|
||||
dataset=train_dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=train_num_workers,
|
||||
pin_memory=torch.cuda.is_available(),
|
||||
shuffle=shuffle_train, # 预打乱数据不 shuffle
|
||||
)
|
||||
```
|
||||
|
||||
**eval DataLoader(line ~1295-1303):**
|
||||
|
||||
```python
|
||||
if is_eval_preprocessed:
|
||||
eval_dataset = PreProcessedDataset(eval_data_path, max_cache_shards=1)
|
||||
eval_dataloader = create_dataloader(
|
||||
dataset=eval_dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=0, # eval 数据小,单进程足够
|
||||
pin_memory=torch.cuda.is_available(),
|
||||
shuffle=False, # eval 不需要打乱
|
||||
)
|
||||
```
|
||||
|
||||
**`create_dataloader` 函数(line ~1076-1114):** 无需改动,`shuffle` 参数已透传。
|
||||
|
||||
---
|
||||
|
||||
### 4. `src/model/preprocessed_dataset.py`
|
||||
|
||||
现有代码无需修改。`PreProcessedDataset` 已经可以正确处理 `shuffle=False` 的情况(PyTorch 的 `SequentialSampler` 会按 0..N-1 顺序读取)。
|
||||
|
||||
`metadata["pre_shuffled"]` 字段由 subsample.py 和 shuffle_npz.py 在写入 metadata.json 时添加,训练代码读取判断即可。
|
||||
|
||||
---
|
||||
|
||||
## 执行顺序
|
||||
|
||||
```bash
|
||||
# Step 1: 打乱并平衡已有的 100M 数据集
|
||||
python -m src.model.shuffle_npz \
|
||||
--input-dir /home/songsenand/DataSet/SubPro \
|
||||
--output-dir /home/songsenand/DataSet/SubPro_Shuffled \
|
||||
--shard-size 1000000 \
|
||||
--seed 42
|
||||
|
||||
# Step 2: 用新数据训练(数据已打乱,顺序读取即可)
|
||||
uv run train-model train \
|
||||
--train-data-path /home/songsenand/DataSet/SubPro_Shuffled/train \
|
||||
--eval-data-path /home/songsenand/DataSet/SubPro_Shuffled/eval \
|
||||
-b 16 \
|
||||
-o ~/tmp \
|
||||
--eval-frequency 20 \
|
||||
--save-frequency 40
|
||||
```
|
||||
|
||||
## 预期效果
|
||||
|
||||
| 改动前 | 改动后 |
|
||||
|---|---|
|
||||
| 每 batch 跨 13 个 shard | 顺序读,1 个 shard 在缓存中 |
|
||||
| 每 batch 数据加载 2-3 分钟 | ~0.1-0.5 秒(纯 mmap/memory) |
|
||||
| CPU 利用率 10% | 正常(训练计算是瓶颈) |
|
||||
| 内存 40GB+ | <20GB(单 shard 1M 样本 ≈ 1.4GB) |
|
||||
|
|
@ -351,7 +351,7 @@ def _sparse_moe_dispatch(x_flat, experts, topk_indices, topk_weights, num_expert
|
|||
idx, k_idx = mask.nonzero(as_tuple=True)
|
||||
if idx.numel() > 0:
|
||||
w = topk_weights[idx, k_idx].unsqueeze(-1)
|
||||
output.index_add_(0, idx, w * experts[e](x_flat[idx]))
|
||||
output.index_add_(0, idx, (w * experts[e](x_flat[idx])).to(output.dtype))
|
||||
return output
|
||||
|
||||
|
||||
|
|
@ -411,7 +411,11 @@ class MoELayer(nn.Module):
|
|||
idx, k_idx = mask.nonzero(as_tuple=True)
|
||||
if idx.numel() > 0:
|
||||
w = topk_weights[idx, k_idx].unsqueeze(-1)
|
||||
out_flat.index_add_(0, idx, w * self.experts[e](x_flat[idx]))
|
||||
out_flat.index_add_(
|
||||
0,
|
||||
idx,
|
||||
(w * self.experts[e](x_flat[idx])).to(out_flat.dtype),
|
||||
)
|
||||
|
||||
elif self.moe_mode == "sparse_allow_graph":
|
||||
out_flat = _sparse_moe_dispatch(
|
||||
|
|
|
|||
|
|
@ -231,7 +231,7 @@ def main():
|
|||
console.print(f"[bold cyan]数据集: {len(dataset):,} 个样本[/bold cyan]")
|
||||
if dataset._is_sharded:
|
||||
console.print(
|
||||
f" 分片数: {dataset._num_shards}, 每分片: {dataset._shard_size:,} 样本"
|
||||
f" 分片数: {dataset._num_shards}, 每分片: {min(dataset._shard_sizes):,} - {max(dataset._shard_sizes):,} 样本"
|
||||
)
|
||||
console.print()
|
||||
|
||||
|
|
|
|||
|
|
@ -29,6 +29,9 @@
|
|||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import struct
|
||||
import time
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
|
|
@ -132,6 +135,10 @@ def collect_samples(
|
|||
actual_count = min(total, num_samples)
|
||||
num_shards = shard_idx
|
||||
|
||||
shard_sizes = [shard_size] * (num_shards - 1)
|
||||
if num_shards > 0:
|
||||
shard_sizes.append(actual_count - sum(shard_sizes))
|
||||
|
||||
metadata = {
|
||||
"num_samples": actual_count,
|
||||
"max_seq_length": max_seq_length,
|
||||
|
|
@ -139,6 +146,7 @@ def collect_samples(
|
|||
"fields": FIELDS,
|
||||
"shard_size": shard_size,
|
||||
"num_shards": num_shards,
|
||||
"shard_sizes": shard_sizes,
|
||||
}
|
||||
with open(split_dir / "metadata.json", "w", encoding="utf-8") as fp:
|
||||
json.dump(metadata, fp, indent=2, ensure_ascii=False)
|
||||
|
|
@ -329,7 +337,97 @@ def main():
|
|||
console.print(f"{split}/: {total_size / (1024**3):.2f} GB (compressed)")
|
||||
|
||||
|
||||
app = main
|
||||
def resplit_shards(input_dir: str, output_dir: str, target_size: int = 1_000_000):
    """
    Split oversized ``shard_*.npz`` files into chunks of at most ``target_size`` samples.

    Shards in ``input_dir`` are processed in sorted filename order. Each one is
    cut into consecutive chunks along axis 0, every field is cast to int16, and
    each chunk is written to ``output_dir`` as ``shard_NNNNNN.npz`` using one
    running index across all inputs. ``metadata.json`` is copied over with
    ``num_samples``, ``num_shards`` and ``shard_sizes`` refreshed and the
    ``shard_size`` key removed.

    Peak memory is one fully decompressed input shard (all fields) plus one
    output chunk.

    Args:
        input_dir: Directory containing ``metadata.json`` and ``shard_*.npz``.
        output_dir: Destination directory; created if missing.
        target_size: Maximum number of samples per output shard.

    Raises:
        ValueError: If ``target_size`` is not positive, or if ``input_dir`` and
            ``output_dir`` resolve to the same directory.
        FileNotFoundError: If ``input_dir`` has no ``metadata.json``.
    """
    import gc

    from tqdm import tqdm

    if target_size <= 0:
        raise ValueError(f"target_size must be positive, got {target_size}")

    input_path = Path(input_dir)
    output_path = Path(output_dir)
    # Output shards reuse the shard_NNNNNN.npz naming scheme, so running in
    # place would overwrite an input archive while it is still being read.
    if input_path.resolve() == output_path.resolve():
        raise ValueError("input_dir and output_dir must be different directories")
    output_path.mkdir(parents=True, exist_ok=True)

    metadata_path = input_path / "metadata.json"
    if not metadata_path.exists():
        raise FileNotFoundError(f"metadata.json not found in {input_dir}")

    with open(metadata_path, encoding="utf-8") as f:
        metadata = json.load(f)

    shard_files = sorted(input_path.glob("shard_*.npz"))
    console = Console()
    console.print(
        f"[bold]Resplit {len(shard_files)} shards → {target_size:,} samples/chunk[/bold]"
    )

    global_shard_idx = 0
    all_shard_sizes: List[int] = []
    total_samples = 0

    for sf in tqdm(shard_files, desc="Shards"):
        # An NpzFile decompresses a member on every __getitem__, so each field
        # is materialised exactly once per shard instead of once per chunk.
        with np.load(str(sf)) as data:
            arrays: Dict[str, np.ndarray] = {name: data[name] for name in data.files}
            n_samples = arrays[data.files[0]].shape[0]

        for start in range(0, n_samples, target_size):
            end = min(start + target_size, n_samples)
            # astype() already returns a new array; no extra copy is needed.
            chunk_data: Dict[str, np.ndarray] = {
                name: arr[start:end].astype(np.int16) for name, arr in arrays.items()
            }

            out_file = output_path / f"shard_{global_shard_idx:06d}.npz"
            np.savez_compressed(out_file, **chunk_data)

            chunk_size = end - start
            all_shard_sizes.append(chunk_size)
            total_samples += chunk_size
            global_shard_idx += 1
            del chunk_data

        # Release the decompressed shard before loading the next one.
        del arrays
        gc.collect()

    metadata["num_samples"] = total_samples
    metadata["num_shards"] = global_shard_idx
    metadata["shard_sizes"] = all_shard_sizes
    # Chunks are no longer of one uniform size, so the single-value key goes.
    metadata.pop("shard_size", None)
    with open(output_path / "metadata.json", "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2, ensure_ascii=False)

    console.print(
        f"[green]Done: {total_samples:,} samples in {global_shard_idx} shards[/green]"
    )
|
||||
|
||||
|
||||
def _dispatch():
    """
    Command-line entry point.

    When the first command-line argument is ``resplit``, the remaining
    arguments are parsed and handed to ``resplit_shards``. Any other
    invocation is passed straight through to ``main``.
    """
    import sys

    cli_args = sys.argv[1:]
    if not cli_args or cli_args[0] != "resplit":
        main()
        return

    import argparse as ap

    resplit_parser = ap.ArgumentParser(description="拆分过大的 .npz 分片")
    resplit_parser.add_argument(
        "--input-dir", required=True, help="输入目录含 metadata.json 和 shard_*.npz"
    )
    resplit_parser.add_argument("--output-dir", required=True, help="输出目录")
    resplit_parser.add_argument(
        "--target-size", type=int, default=1_000_000, help="目标分片大小(样本数)"
    )
    # Everything after the sub-command name belongs to the resplit parser.
    opts = resplit_parser.parse_args(cli_args[1:])
    resplit_shards(opts.input_dir, opts.output_dir, opts.target_size)
|
||||
|
||||
|
||||
app = _dispatch
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
_dispatch()
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
GPU 服务器仅需存放压缩后的 .npz 文件,无需解压到硬盘。
|
||||
"""
|
||||
|
||||
import time
|
||||
import gc
|
||||
import json
|
||||
import struct
|
||||
|
|
@ -47,6 +48,13 @@ def _read_shard_size(npz_path: Path) -> int:
|
|||
return header["shape"][0]
|
||||
|
||||
|
||||
def _build_offsets(sizes: List[int]) -> List[int]:
|
||||
offsets = [0]
|
||||
for s in sizes:
|
||||
offsets.append(offsets[-1] + s)
|
||||
return offsets
|
||||
|
||||
|
||||
def is_preprocessed_data(path: str) -> bool:
|
||||
"""判断路径是否为预处理数据目录"""
|
||||
p = Path(path)
|
||||
|
|
@ -94,38 +102,48 @@ class PreProcessedDataset(Dataset):
|
|||
"""
|
||||
|
||||
def __init__(self, data_dir: str, max_cache_shards: int = 1):
|
||||
t_start = time.perf_counter()
|
||||
self.data_dir = Path(data_dir)
|
||||
|
||||
with open(self.data_dir / "metadata.json", "r", encoding="utf-8") as f:
|
||||
self.metadata = json.load(f)
|
||||
|
||||
self.max_seq_length = self.metadata["max_seq_length"]
|
||||
self._shard_size: Optional[int] = self.metadata.get("shard_size")
|
||||
|
||||
if self._shard_size is not None:
|
||||
shard_files = sorted(self.data_dir.glob("shard_*.npz"))
|
||||
shard_files = sorted(self.data_dir.glob("shard_*.npz"))
|
||||
if shard_files:
|
||||
self._num_shards = len(shard_files)
|
||||
if self._num_shards == 0:
|
||||
raise FileNotFoundError(
|
||||
f"No shard_*.npz files found in {self.data_dir}"
|
||||
if (
|
||||
"shard_sizes" in self.metadata
|
||||
and len(self.metadata["shard_sizes"]) == self._num_shards
|
||||
):
|
||||
self._shard_sizes = self.metadata["shard_sizes"]
|
||||
logger.info(
|
||||
f"Using shard_sizes from metadata.json ({self._num_shards} shards)"
|
||||
)
|
||||
self._shard_sizes: List[int] = [_read_shard_size(sf) for sf in shard_files]
|
||||
self._shard_offsets = [0]
|
||||
for s in self._shard_sizes:
|
||||
self._shard_offsets.append(self._shard_offsets[-1] + s)
|
||||
else:
|
||||
t_scan = time.perf_counter()
|
||||
self._shard_sizes = [_read_shard_size(sf) for sf in shard_files]
|
||||
logger.info(
|
||||
f"Scanned {self._num_shards} shard headers in "
|
||||
f"{time.perf_counter() - t_scan:.2f}s"
|
||||
)
|
||||
self._shard_offsets = _build_offsets(self._shard_sizes)
|
||||
self.num_samples = self._shard_offsets[-1]
|
||||
self._is_sharded = True
|
||||
self._cache = _ShardCache(max_size=max_cache_shards)
|
||||
logger.info(
|
||||
f"Loaded sharded dataset: {self.num_samples:,} samples, "
|
||||
f"{self._num_shards} shards"
|
||||
f"{self._num_shards} shards, "
|
||||
f"init in {time.perf_counter() - t_start:.2f}s"
|
||||
)
|
||||
else:
|
||||
self.num_samples = self.metadata["num_samples"]
|
||||
self._is_sharded = False
|
||||
self._load_single_files()
|
||||
logger.info(
|
||||
f"Loaded single-file dataset: {self.num_samples:,} samples (mmap)"
|
||||
f"Loaded single-file dataset: {self.num_samples:,} samples (mmap), "
|
||||
f"init in {time.perf_counter() - t_start:.2f}s"
|
||||
)
|
||||
|
||||
def _load_single_files(self):
|
||||
|
|
@ -157,6 +175,10 @@ class PreProcessedDataset(Dataset):
|
|||
f"Index {idx} out of range for dataset with {self.num_samples} samples"
|
||||
)
|
||||
|
||||
if not hasattr(self, "_first_access_logged"):
|
||||
self._first_access_logged = True
|
||||
logger.info("First __getitem__ call (initial shard load may be slow)")
|
||||
|
||||
if self._is_sharded:
|
||||
lo, hi = 0, len(self._shard_offsets) - 1
|
||||
while lo < hi:
|
||||
|
|
@ -168,7 +190,7 @@ class PreProcessedDataset(Dataset):
|
|||
shard_idx = lo - 1
|
||||
local_idx = idx - self._shard_offsets[shard_idx]
|
||||
shard_data = self._cache.get(shard_idx, self._load_shard)
|
||||
return {
|
||||
result = {
|
||||
"input_ids": torch.from_numpy(
|
||||
shard_data["input_ids"][local_idx].astype(np.int64)
|
||||
),
|
||||
|
|
@ -189,7 +211,7 @@ class PreProcessedDataset(Dataset):
|
|||
),
|
||||
}
|
||||
else:
|
||||
return {
|
||||
result = {
|
||||
"input_ids": torch.from_numpy(self.input_ids[idx].astype(np.int64)),
|
||||
"token_type_ids": torch.from_numpy(
|
||||
self.token_type_ids[idx].astype(np.int64)
|
||||
|
|
@ -203,6 +225,7 @@ class PreProcessedDataset(Dataset):
|
|||
),
|
||||
"pinyin_ids": torch.from_numpy(self.pinyin_ids[idx].astype(np.int64)),
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
def preprocessed_collate_fn(batch):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,318 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
打乱并平衡预处理 .npz 分片。
|
||||
|
||||
两阶段处理:
|
||||
Phase 1: 逐分片内部打乱 → 写入临时 .npy(mmap 友好)
|
||||
Phase 2: 从临时 .npy 按比例分配到平衡的输出 .npz 分片
|
||||
|
||||
用法:
|
||||
python -m src.model.shuffle_npz \
|
||||
--input-dir /path/to/subsampled \
|
||||
--output-dir /path/to/shuffled \
|
||||
--shard-size 1000000 \
|
||||
--seed 42
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
from rich.console import Console
|
||||
from tqdm import tqdm
|
||||
|
||||
FIELDS = [
|
||||
"input_ids",
|
||||
"token_type_ids",
|
||||
"attention_mask",
|
||||
"labels",
|
||||
"history_slot_ids",
|
||||
"pinyin_ids",
|
||||
]
|
||||
|
||||
|
||||
def _phase1_shuffle_to_npy(
    input_dir: Path,
    split: str,
    temp_dir: Path,
    shard_sizes: List[int],
    seed: int,
):
    """
    Phase 1: shuffle every source shard internally and dump it as .npy files.

    For each shard one permutation is drawn and applied to every field in
    ``FIELDS``, so rows stay aligned across fields. Each field is written to
    its own file at ``temp_dir/<field>/shard_<idx>.npy``.

    Only one field of one shard is decompressed at a time, together with its
    permuted copy.

    Args:
        input_dir: Dataset root; shards are read from ``input_dir/split``.
        split: Sub-directory name (e.g. ``"train"``).
        temp_dir: Directory that receives the per-field .npy files.
        shard_sizes: Sample count of each source shard, in shard order.
        seed: Seed for the permutation generator.

    Raises:
        ValueError: If ``shard_sizes`` disagrees with ``num_shards`` in the
            metadata, or with the actual length of a stored array.
    """
    metadata_path = input_dir / split / "metadata.json"
    with open(metadata_path, encoding="utf-8") as f:
        metadata = json.load(f)

    num_shards = metadata["num_shards"]
    if num_shards != len(shard_sizes):
        raise ValueError(
            f"metadata num_shards={num_shards} does not match "
            f"len(shard_sizes)={len(shard_sizes)}"
        )
    rng = np.random.RandomState(seed)

    for field in FIELDS:
        (temp_dir / field).mkdir(parents=True, exist_ok=True)

    pbar = tqdm(
        total=num_shards,
        desc=f"Phase 1: shuffling {split}",
        unit="shard",
    )
    try:
        for src_idx, n in enumerate(shard_sizes):
            shard_path = input_dir / split / f"shard_{src_idx:06d}.npz"
            # One permutation per shard, shared by all fields.
            perm = rng.permutation(n)

            with np.load(shard_path) as data:
                for field in FIELDS:
                    # Indexing the NpzFile already yields a fresh in-memory
                    # array, so no additional copy is taken.
                    arr = data[field]
                    if arr.shape[0] != n:
                        # A too-small n would silently drop samples.
                        raise ValueError(
                            f"{shard_path.name}[{field}] has {arr.shape[0]} "
                            f"samples, metadata says {n}"
                        )
                    np.save(temp_dir / field / f"shard_{src_idx:06d}.npy", arr[perm])
                    del arr

            gc.collect()
            pbar.update(1)
    finally:
        pbar.close()
|
||||
|
||||
|
||||
def _phase2_rebalance_to_npz(
    temp_dir: Path,
    output_dir: Path,
    split: str,
    shard_sizes: List[int],
    target_shard_size: int,
    max_seq_length: int,
    seed: int,
) -> List[int]:
    """
    Phase 2: redistribute the memory-mapped .npy files into balanced .npz shards.

    The number of output shards is ``ceil(total / target_shard_size)``. Output
    shard ``j`` takes the ``j``-th proportional slice of every source shard;
    the slices are concatenated, shuffled once more, and saved with
    ``np.savez_compressed`` to ``output_dir/split/shard_NNNNNN.npz``.

    Existing ``shard_*.npz`` files in the output split directory are removed
    first, so leftovers from an earlier run cannot be picked up as data.

    Args:
        temp_dir: Directory holding ``<field>/shard_<idx>.npy`` from Phase 1.
        output_dir: Dataset output root.
        split: Sub-directory name (e.g. ``"train"``).
        shard_sizes: Sample count of each source shard, in shard order.
        target_shard_size: Desired samples per output shard (upper bound used
            to derive the shard count).
        max_seq_length: Accepted for interface compatibility; not used here.
        seed: Base seed; the generator is seeded with ``seed + 2``.

    Returns:
        Sample counts of the output shards that were written, in order.

    Raises:
        ValueError: If ``target_shard_size`` is not positive.
    """
    if target_shard_size <= 0:
        raise ValueError(
            f"target_shard_size must be positive, got {target_shard_size}"
        )

    output_split_dir = output_dir / split
    output_split_dir.mkdir(parents=True, exist_ok=True)

    # Drop shards from a previous run so a smaller result does not leave
    # stale files next to the new ones.
    for stale_shard in sorted(output_split_dir.glob("shard_*.npz")):
        stale_shard.unlink()

    total_samples = sum(shard_sizes)
    num_output_shards = (total_samples + target_shard_size - 1) // target_shard_size
    rng = np.random.RandomState(seed + 2)

    logger.info(
        f"Phase 2: distributing {total_samples:,} samples into "
        f"{num_output_shards} output shards (~{target_shard_size:,} each)"
    )

    # Memory-map every source file; slices are only read when concatenated.
    src_mmaps: List[Dict[str, np.ndarray]] = [
        {
            field: np.load(
                temp_dir / field / f"shard_{src_idx:06d}.npy", mmap_mode="r"
            )
            for field in FIELDS
        }
        for src_idx in range(len(shard_sizes))
    ]

    output_shard_sizes: List[int] = []

    pbar = tqdm(
        total=num_output_shards,
        desc=f"Phase 2: writing {split}",
        unit="shard",
    )
    try:
        for out_j in range(num_output_shards):
            buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}

            for src_i, s in enumerate(shard_sizes):
                # Integer proportional split: the slices over all out_j tile
                # [0, s) exactly, so every sample is used once.
                start = (out_j * s) // num_output_shards
                end = ((out_j + 1) * s) // num_output_shards
                if start >= end:
                    continue
                for field in FIELDS:
                    buffers[field].append(src_mmaps[src_i][field][start:end])

            if not buffers[FIELDS[0]]:
                # Tiny inputs can leave a slot with no contribution from any
                # source; there is nothing to write for it.
                pbar.update(1)
                continue

            # concatenate() materialises the mapped slices in memory.
            output = {field: np.concatenate(buffers[field]) for field in FIELDS}

            out_count = len(output[FIELDS[0]])
            if out_count > 1:
                # Mix rows coming from different sources.
                perm = rng.permutation(out_count)
                for field in FIELDS:
                    output[field] = output[field][perm]

            # Number files by what was actually written to keep them gap-free.
            shard_name = f"shard_{len(output_shard_sizes):06d}.npz"
            np.savez_compressed(output_split_dir / shard_name, **output)
            output_shard_sizes.append(out_count)

            del output, buffers
            gc.collect()
            pbar.update(1)
    finally:
        pbar.close()
        # Drop the memory-map references even when a write fails.
        src_mmaps.clear()
        gc.collect()

    return output_shard_sizes
|
||||
|
||||
|
||||
def _copy_eval(input_dir: Path, output_dir: Path):
|
||||
"""复制 eval 目录(数据量小,无需打乱)。"""
|
||||
src_eval = input_dir / "eval"
|
||||
dst_eval = output_dir / "eval"
|
||||
if not src_eval.exists():
|
||||
return
|
||||
logger.info(f"Copying eval data from {src_eval}")
|
||||
if dst_eval.exists():
|
||||
shutil.rmtree(dst_eval)
|
||||
shutil.copytree(src_eval, dst_eval)
|
||||
|
||||
|
||||
def main():
    """
    CLI entry point: shuffle and rebalance a preprocessed ``train`` split.

    Steps:
      1. Read ``<input-dir>/train/metadata.json``.
      2. Phase 1 writes per-shard shuffled .npy files to a temporary directory.
      3. Phase 2 writes balanced .npz shards to ``<output-dir>/train``.
      4. Write the output ``metadata.json`` with ``pre_shuffled`` set to true.
      5. Copy ``eval/`` across unchanged.

    The temporary directory is a uniquely named sub-directory of ``--temp-dir``
    (or of the output directory when the option is omitted). It is removed
    when processing ends, whether it succeeded or failed.
    """
    import tempfile

    console = Console()

    parser = argparse.ArgumentParser(description="打乱并平衡预处理 .npz 分片")
    parser.add_argument(
        "--input-dir", type=str, required=True, help="输入目录(含 train/ 和 eval/)"
    )
    parser.add_argument("--output-dir", type=str, required=True, help="输出目录")
    parser.add_argument(
        "--shard-size",
        type=int,
        default=1_000_000,
        help="输出分片大小(样本数),默认 100 万",
    )
    parser.add_argument("--seed", type=int, default=42, help="随机种子")
    parser.add_argument(
        "--temp-dir",
        type=str,
        default=None,
        help="临时 .npy 文件的父目录(默认使用输出目录)",
    )

    args = parser.parse_args()
    if args.shard_size <= 0:
        # A non-positive size would otherwise surface as ZeroDivisionError.
        parser.error("--shard-size 必须为正整数")

    input_dir = Path(args.input_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    train_meta_path = input_dir / "train" / "metadata.json"
    if not train_meta_path.exists():
        console.print(f"[red]错误: {train_meta_path} 不存在[/red]")
        return

    with open(train_meta_path, encoding="utf-8") as f:
        train_meta = json.load(f)

    max_seq_length = train_meta["max_seq_length"]
    shard_sizes: List[int] = train_meta.get("shard_sizes") or []
    if not shard_sizes:
        # Metadata without per-shard sizes cannot drive the two phases.
        console.print(f"[red]错误: {train_meta_path} 缺少 shard_sizes 字段[/red]")
        return
    total_samples = sum(shard_sizes)

    console.print("[bold cyan]=== 打乱并平衡预处理数据 ===[/bold cyan]")
    console.print(f"输入目录: {input_dir}")
    console.print(f"输出目录: {output_dir}")
    console.print(f"总样本数: {total_samples:,}")
    console.print(f"源分片数: {len(shard_sizes)}")
    console.print(
        f"目标分片大小: {args.shard_size:,} "
        f"(约 {(total_samples + args.shard_size - 1) // args.shard_size} 个分片)"
    )
    console.print(f"随机种子: {args.seed}")
    console.print()

    # A fresh, uniquely named work directory: no user-specific path is baked
    # in and no pre-existing directory is ever deleted.
    temp_parent = Path(args.temp_dir) if args.temp_dir else output_dir
    temp_parent.mkdir(parents=True, exist_ok=True)
    temp_dir = Path(tempfile.mkdtemp(prefix="shuffle_npz_", dir=temp_parent))
    temp_train_dir = temp_dir / "train"
    temp_train_dir.mkdir(parents=True, exist_ok=True)

    try:
        # ── Phase 1 ──
        console.print("[bold]Phase 1: 逐分片打乱 → 临时 .npy[/bold]")
        _phase1_shuffle_to_npy(
            input_dir=input_dir,
            split="train",
            temp_dir=temp_train_dir,
            shard_sizes=shard_sizes,
            seed=args.seed,
        )

        # ── Phase 2 ──
        console.print("[bold]Phase 2: 按比例分配到平衡输出分片[/bold]")
        output_shard_sizes = _phase2_rebalance_to_npz(
            temp_dir=temp_train_dir,
            output_dir=output_dir,
            split="train",
            shard_sizes=shard_sizes,
            target_shard_size=args.shard_size,
            max_seq_length=max_seq_length,
            seed=args.seed,
        )

        # ── Write metadata ──
        train_output_meta = {
            "num_samples": sum(output_shard_sizes),
            "max_seq_length": max_seq_length,
            # Arrays are copied without conversion, so carry the input's
            # declared dtype forward and fall back to the previous constant.
            "dtype": train_meta.get("dtype", "int16"),
            "fields": FIELDS,
            "shard_size": args.shard_size,
            "num_shards": len(output_shard_sizes),
            "shard_sizes": output_shard_sizes,
            "pre_shuffled": True,
            "seed": args.seed,
        }
        out_train_dir = output_dir / "train"
        with open(out_train_dir / "metadata.json", "w", encoding="utf-8") as f:
            json.dump(train_output_meta, f, indent=2, ensure_ascii=False)

        # ── Copy eval ──
        _copy_eval(input_dir, output_dir)

    finally:
        # Remove the temporary files on success and on failure alike.
        console.print("[dim]清理临时文件...[/dim]")
        if temp_dir.exists():
            shutil.rmtree(temp_dir)

    # ── Summary ──
    console.print()
    console.print("[bold green]=== 完成 ===[/bold green]")
    console.print(
        f"train/: {sum(output_shard_sizes):,} 样本, "
        f"{len(output_shard_sizes)} 分片, "
        f"pre_shuffled=True"
    )

    for sdir_name in ["train", "eval"]:
        sdir = output_dir / sdir_name
        if sdir.exists():
            total_size = sum(
                f.stat().st_size for f in sdir.iterdir() if f.suffix == ".npz"
            )
            meta_path = sdir / "metadata.json"
            if meta_path.exists():
                with open(meta_path, encoding="utf-8") as mf:
                    meta = json.load(mf)
            else:
                meta = {}
            console.print(
                f" {sdir_name}/: {total_size / (1024**3):.2f} GB, "
                f"{meta.get('num_shards', '?')} shards"
            )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -128,6 +128,8 @@ def pass2_subsample(
|
|||
eval_map: Dict[int, List[int]],
|
||||
shard_sizes: List[int],
|
||||
shard_size: int = 5_000_000,
|
||||
shuffle: bool = True,
|
||||
seed: int = 42,
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
第2遍扫描:读取全部字段,抽取评估样本 + 按精确保留配额子采样训练样本。
|
||||
|
|
@ -147,12 +149,19 @@ def pass2_subsample(
|
|||
output_train_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_eval_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
rng = np.random.RandomState()
|
||||
# 清理旧输出分片(避免残留文件污染)
|
||||
for old_shard in sorted(output_train_dir.glob("shard_*.npz")):
|
||||
old_shard.unlink()
|
||||
logger.debug(f"Removed old shard: {old_shard.name}")
|
||||
|
||||
rng = np.random.RandomState(seed)
|
||||
shuffle_rng = np.random.RandomState(seed + 1)
|
||||
remaining = dict(quotas) # {label_id: remaining_to_keep}
|
||||
|
||||
train_buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}
|
||||
train_buf_count = 0
|
||||
train_shard_idx = 0
|
||||
train_shard_sizes: List[int] = []
|
||||
total_train_kept = 0
|
||||
|
||||
eval_buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}
|
||||
|
|
@ -247,9 +256,14 @@ def pass2_subsample(
|
|||
merged = {}
|
||||
for f in FIELDS:
|
||||
merged[f] = np.concatenate(train_buffers[f], axis=0)
|
||||
if shuffle and train_buf_count > 1:
|
||||
perm = shuffle_rng.permutation(train_buf_count)
|
||||
for f in FIELDS:
|
||||
merged[f] = merged[f][perm]
|
||||
np.savez_compressed(
|
||||
output_train_dir / f"shard_{train_shard_idx:06d}.npz", **merged
|
||||
)
|
||||
train_shard_sizes.append(train_buf_count)
|
||||
logger.debug(
|
||||
f"Saved train shard {train_shard_idx}: {train_buf_count} samples"
|
||||
)
|
||||
|
|
@ -268,9 +282,14 @@ def pass2_subsample(
|
|||
merged = {}
|
||||
for f in FIELDS:
|
||||
merged[f] = np.concatenate(train_buffers[f], axis=0)
|
||||
if shuffle and train_buf_count > 1:
|
||||
perm = shuffle_rng.permutation(train_buf_count)
|
||||
for f in FIELDS:
|
||||
merged[f] = merged[f][perm]
|
||||
np.savez_compressed(
|
||||
output_train_dir / f"shard_{train_shard_idx:06d}.npz", **merged
|
||||
)
|
||||
train_shard_sizes.append(train_buf_count)
|
||||
logger.debug(f"Saved train shard {train_shard_idx}: {train_buf_count} samples")
|
||||
train_shard_idx += 1
|
||||
|
||||
|
|
@ -291,6 +310,9 @@ def pass2_subsample(
|
|||
"fields": FIELDS,
|
||||
"shard_size": shard_size,
|
||||
"num_shards": train_shard_idx,
|
||||
"shard_sizes": train_shard_sizes,
|
||||
"pre_shuffled": shuffle,
|
||||
"seed": seed,
|
||||
}
|
||||
with open(output_train_dir / "metadata.json", "w", encoding="utf-8") as f:
|
||||
json.dump(train_metadata, f, indent=2, ensure_ascii=False)
|
||||
|
|
@ -334,6 +356,15 @@ def main():
|
|||
help="输出分片大小(样本数)",
|
||||
)
|
||||
parser.add_argument("--num-eval", type=int, default=2560, help="评估集样本数")
|
||||
parser.add_argument(
|
||||
"--no-shuffle",
|
||||
action="store_false",
|
||||
dest="shuffle",
|
||||
help="禁用输出分片内部打乱",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed", type=int, default=42, help="随机种子(用于标签选择 + 输出打乱)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
input_dir = Path(args.input_dir)
|
||||
|
|
@ -346,6 +377,8 @@ def main():
|
|||
console.print(f"每 ID 封顶: {args.cap_per_label:,}")
|
||||
console.print(f"目标训练集: {args.target_total:,}")
|
||||
console.print(f"评估集: {args.num_eval}")
|
||||
console.print(f"输出分片打乱: {'是' if args.shuffle else '否'}")
|
||||
console.print(f"随机种子: {args.seed}")
|
||||
console.print()
|
||||
|
||||
# ── 第 1 遍:统计 ──
|
||||
|
|
@ -382,8 +415,8 @@ def main():
|
|||
console.print()
|
||||
|
||||
# ── 抽取评估集位置 ──
|
||||
rng = np.random.RandomState()
|
||||
eval_positions = rng.choice(total_samples, size=args.num_eval, replace=False)
|
||||
eval_rng = np.random.RandomState(args.seed + 100)
|
||||
eval_positions = eval_rng.choice(total_samples, size=args.num_eval, replace=False)
|
||||
eval_positions.sort()
|
||||
eval_map = _global_to_shard(eval_positions, shard_sizes)
|
||||
console.print(f"评估集: {args.num_eval} 个位置已分配到 {len(eval_map)} 个分片中")
|
||||
|
|
@ -399,6 +432,8 @@ def main():
|
|||
eval_map,
|
||||
shard_sizes,
|
||||
args.shard_size,
|
||||
shuffle=args.shuffle,
|
||||
seed=args.seed,
|
||||
)
|
||||
|
||||
# ── 输出总结 ──
|
||||
|
|
|
|||
|
|
@ -632,7 +632,6 @@ class Trainer:
|
|||
|
||||
def _create_progress(self) -> Progress:
|
||||
return Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||
|
|
@ -725,11 +724,7 @@ class Trainer:
|
|||
progress.reset(
|
||||
batch_task,
|
||||
total=steps_per_epoch,
|
||||
description=f"[green]Batch Epoch {epoch + 1}/{self.num_epochs}",
|
||||
)
|
||||
progress.update(
|
||||
epoch_task,
|
||||
description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs}",
|
||||
description=f"[green]Epoch {epoch + 1} Step 0/{steps_per_epoch}",
|
||||
)
|
||||
|
||||
epoch_step = 0
|
||||
|
|
@ -762,7 +757,8 @@ class Trainer:
|
|||
progress.update(
|
||||
batch_task,
|
||||
advance=1,
|
||||
description=f"[green]Batch Epoch {epoch + 1}/{self.num_epochs}"
|
||||
description=f"[green]Epoch {epoch + 1} "
|
||||
f"Step {epoch_step + 1}/{steps_per_epoch}"
|
||||
f" | Loss: {loss:.4f}"
|
||||
f" | LR: {current_lr:.2e}",
|
||||
)
|
||||
|
|
@ -816,7 +812,7 @@ class Trainer:
|
|||
f"Eval Acc: {eval_metrics['eval_accuracy']:.4f}"
|
||||
)
|
||||
|
||||
progress.console.log(log_text)
|
||||
progress.log(log_text)
|
||||
|
||||
# 重置累积指标
|
||||
accumulated_loss = 0.0
|
||||
|
|
@ -1260,16 +1256,31 @@ def train(
|
|||
is_eval_preprocessed = is_preprocessed_data(eval_data_path)
|
||||
|
||||
if is_train_preprocessed:
|
||||
train_dataset = PreProcessedDataset(train_data_path, max_cache_shards=1)
|
||||
total_steps = (len(train_dataset) // batch_size) * num_epochs
|
||||
train_dataset = PreProcessedDataset(train_data_path, max_cache_shards=2)
|
||||
pre_shuffled = train_dataset.metadata.get("pre_shuffled", False)
|
||||
# 预打乱数据不需要 DataLoader 的 RandomSampler(避免跨分片解压抖动)
|
||||
shuffle_train = not pre_shuffled
|
||||
if max_iter_length > 0:
|
||||
max_samples_per_epoch = max_iter_length
|
||||
capped_samples = min(len(train_dataset), max_samples_per_epoch)
|
||||
else:
|
||||
capped_samples = len(train_dataset)
|
||||
total_steps = (capped_samples // batch_size) * num_epochs
|
||||
train_num_workers = min(num_workers, 1)
|
||||
logger.info(
|
||||
f"Preprocessed dataset: {len(train_dataset):,} samples, "
|
||||
f"shuffle={shuffle_train}, pre_shuffled={pre_shuffled}, "
|
||||
f"workers={train_num_workers}, steps={total_steps:,}"
|
||||
)
|
||||
train_dataloader = create_dataloader(
|
||||
dataset=train_dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
num_workers=train_num_workers,
|
||||
pin_memory=torch.cuda.is_available(),
|
||||
shuffle=True,
|
||||
shuffle=shuffle_train,
|
||||
)
|
||||
config_table.add_row("数据", "训练数据类型", "预处理数据")
|
||||
config_table.add_row("数据", "预打乱", str(pre_shuffled))
|
||||
else:
|
||||
train_dataset = PinyinInputDataset(
|
||||
data_path=train_data_path,
|
||||
|
|
@ -1296,7 +1307,7 @@ def train(
|
|||
eval_dataloader = create_dataloader(
|
||||
dataset=eval_dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=2,
|
||||
num_workers=0,
|
||||
pin_memory=torch.cuda.is_available(),
|
||||
shuffle=False,
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue