feat(data-preprocess): 预处理数据预打乱以提升训练效率
This commit is contained in:
parent
0862b5b8fc
commit
722912f296
|
|
@ -0,0 +1,229 @@
|
||||||
|
# 预处理数据预打乱方案
|
||||||
|
|
||||||
|
## 目标
|
||||||
|
|
||||||
|
用 CPU 机时预打乱数据,让训练时直接用 `shuffle=False` 顺序读取,消除跨分片缓存抖动和 CPU 利用率低的问题。
|
||||||
|
|
||||||
|
## 改动清单
|
||||||
|
|
||||||
|
### 1. `src/model/subsample.py` — 默认开启输出打乱
|
||||||
|
|
||||||
|
**函数签名变更:**
|
||||||
|
```python
|
||||||
|
def pass2_subsample(
|
||||||
|
...,
|
||||||
|
shuffle: bool = True, # 新增
|
||||||
|
seed: int = 42, # 新增
|
||||||
|
) -> Tuple[int, int]:
|
||||||
|
```
|
||||||
|
|
||||||
|
**改动点:**
|
||||||
|
- `rng = np.random.RandomState()` → `rng = np.random.RandomState(seed)`
|
||||||
|
- 新增 `shuffle_rng = np.random.RandomState(seed + 1)` 用于输出打乱
|
||||||
|
- 在两次 `np.savez_compressed` 写入前(line ~251 和 line ~273),插入:
|
||||||
|
```python
|
||||||
|
if shuffle and train_buf_count > 1:
|
||||||
|
perm = shuffle_rng.permutation(train_buf_count)
|
||||||
|
for f in FIELDS:
|
||||||
|
merged[f] = merged[f][perm]
|
||||||
|
```
|
||||||
|
- main() 中新增参数:
|
||||||
|
```python
|
||||||
|
parser.add_argument("--no-shuffle", action="store_false", dest="shuffle",
|
||||||
|
help="禁用输出分片内部打乱")
|
||||||
|
parser.add_argument("--seed", type=int, default=42,
|
||||||
|
help="随机种子(用于选择+打乱)")
|
||||||
|
```
|
||||||
|
- 调用 `pass2_subsample` 时传入 `shuffle=args.shuffle, seed=args.seed`
|
||||||
|
|
||||||
|
**metadata.json 新增字段:**
|
||||||
|
```json
|
||||||
|
"pre_shuffled": true,
|
||||||
|
"seed": 42
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 2. 新建 `src/model/shuffle_npz.py` — 平衡+打乱已有数据
|
||||||
|
|
||||||
|
处理流程:
|
||||||
|
|
||||||
|
```
|
||||||
|
Phase 1: npz → 逐字段打乱 → .npy(mmap 友好)
|
||||||
|
输入: 19个不平衡 .npz 分片(100M样本,~10GB 压缩)
|
||||||
|
输出: 114个临时 .npy 文件(6字段 × 19分片,~144GB 未压缩)
|
||||||
|
内存峰值: ~10GB(单字段 int16 最大 ~5GB + permuted copy ~5GB)
|
||||||
|
耗时: ~15-20分钟(解压+写入)
|
||||||
|
|
||||||
|
Phase 2: .npy → 平衡分配 → .npz
|
||||||
|
输入: Phase 1 的 .npy 文件(mmap 模式)
|
||||||
|
输出: 100个平衡 .npz 分片(每片约100万样本,各片大小 ≈ 总样本数 / 分片数,~100MB/片压缩后)
|
||||||
|
每个输出分片 = 从19个源各取比例份额 → concatenate → shuffle → save
|
||||||
|
内存峰值: ~3GB(1个输出缓冲 + mmap pages)
|
||||||
|
耗时: ~10-15分钟(mmap读取+压缩写入)
|
||||||
|
|
||||||
|
总计: ~25-35分钟(Phase 1 + Phase 2),峰值内存 ~10GB
|
||||||
|
```
|
||||||
|
|
||||||
|
**磁盘需求:**
|
||||||
|
- 临时 .npy 文件(Phase 1→2 中间产物):~144GB
|
||||||
|
- 最终输出 .npz:~10GB
|
||||||
|
- 临时文件在 Phase 2 完成后自动删除
|
||||||
|
|
||||||
|
**用法:**
|
||||||
|
```bash
|
||||||
|
python -m src.model.shuffle_npz \
|
||||||
|
--input-dir /home/songsenand/DataSet/SubPro \
|
||||||
|
--output-dir /home/songsenand/DataSet/SubPro_Shuffled \
|
||||||
|
--shard-size 1000000 \
|
||||||
|
--seed 42
|
||||||
|
```
|
||||||
|
|
||||||
|
**关键实现:**
|
||||||
|
|
||||||
|
Phase 1 — 逐字段加载+打乱:
|
||||||
|
```python
|
||||||
|
for src_idx in range(num_shards):
|
||||||
|
data = np.load(shard_path) # lazy NpzFile
|
||||||
|
n = shard_sizes[src_idx]
|
||||||
|
perm = rng.permutation(n)
|
||||||
|
|
||||||
|
for field in FIELDS:
|
||||||
|
arr = data[field].copy() # ~5GB peak (input_ids)
|
||||||
|
shuffled = arr[perm] # ~5GB temp
|
||||||
|
np.save(temp_dir / field / f"shard_{src_idx:06d}.npy", shuffled)
|
||||||
|
del arr, shuffled
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
data.close()
|
||||||
|
```
|
||||||
|
|
||||||
|
Phase 2 — 平衡分配:
|
||||||
|
```python
|
||||||
|
# 打开所有源 mmap(读模式,零内存)
|
||||||
|
src_mmaps = [
|
||||||
|
{f: np.load(temp_dir / f / f"shard_{i:06d}.npy", mmap_mode='r') for f in FIELDS}
|
||||||
|
for i in range(num_shards)
|
||||||
|
]
|
||||||
|
|
||||||
|
for out_j in range(num_output_shards):
|
||||||
|
buffers = {f: [] for f in FIELDS}
|
||||||
|
|
||||||
|
for src_i in range(num_shards):
|
||||||
|
s = shard_sizes[src_i]
|
||||||
|
start = (out_j * s) // num_output_shards
|
||||||
|
end = ((out_j + 1) * s) // num_output_shards
|
||||||
|
if start >= end:
|
||||||
|
continue
|
||||||
|
for f in FIELDS:
|
||||||
|
chunk = src_mmaps[src_i][f][start:end].copy()
|
||||||
|
buffers[f].append(chunk)
|
||||||
|
|
||||||
|
output = {f: np.concatenate(buffers[f]) for f in FIELDS}
|
||||||
|
# 额外打乱(跨源混合)
|
||||||
|
perm = rng.permutation(len(output[FIELDS[0]]))
|
||||||
|
for f in FIELDS:
|
||||||
|
output[f] = output[f][perm]
|
||||||
|
|
||||||
|
np.savez_compressed(output_dir / f"shard_{out_j:06d}.npz", **output)
|
||||||
|
del output, buffers, perm
|
||||||
|
gc.collect()
|
||||||
|
```
|
||||||
|
|
||||||
|
**输出 metadata.json:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"num_samples": 99998406,
|
||||||
|
"max_seq_length": 128,
|
||||||
|
"dtype": "int16",
|
||||||
|
"fields": [...],
|
||||||
|
"shard_size": 1000000,
|
||||||
|
"num_shards": 100,
|
||||||
|
"shard_sizes": [1000000, ..., 998406],
|
||||||
|
"pre_shuffled": true,
|
||||||
|
"seed": 42
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**eval 目录处理:** 如果 `--input-dir/eval/` 存在,直接复制到 `--output-dir/eval/`(eval 数据量小,不需要打乱)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 3. `src/model/trainer.py` — 预处理数据禁用 shuffle
|
||||||
|
|
||||||
|
**改动点(train 函数,line ~1258-1272):**
|
||||||
|
|
||||||
|
```python
|
||||||
|
if is_train_preprocessed:
|
||||||
|
train_dataset = PreProcessedDataset(train_data_path, max_cache_shards=1)
|
||||||
|
# pre_shuffled 数据不需要 DataLoader 的 RandomSampler
|
||||||
|
shuffle_train = not train_dataset.metadata.get("pre_shuffled", False)
|
||||||
|
total_steps = (len(train_dataset) // batch_size) * num_epochs
|
||||||
|
# 支持 max_iter_length 限制总步数
|
||||||
|
if max_iter_length > 0:
|
||||||
|
max_steps_per_epoch = max_iter_length // batch_size
|
||||||
|
total_steps = min(total_steps, max_steps_per_epoch * num_epochs)
|
||||||
|
train_num_workers = min(num_workers, 1)
|
||||||
|
train_dataloader = create_dataloader(
|
||||||
|
dataset=train_dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_workers=train_num_workers,
|
||||||
|
pin_memory=torch.cuda.is_available(),
|
||||||
|
shuffle=shuffle_train, # 预打乱数据不 shuffle
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**eval DataLoader(line ~1295-1303):**
|
||||||
|
|
||||||
|
```python
|
||||||
|
if is_eval_preprocessed:
|
||||||
|
eval_dataset = PreProcessedDataset(eval_data_path, max_cache_shards=1)
|
||||||
|
eval_dataloader = create_dataloader(
|
||||||
|
dataset=eval_dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_workers=0, # eval 数据小,单进程足够
|
||||||
|
pin_memory=torch.cuda.is_available(),
|
||||||
|
shuffle=False, # eval 不需要打乱
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**`create_dataloader` 函数(line ~1076-1114):** 无需改动,`shuffle` 参数已透传。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 4. `src/model/preprocessed_dataset.py`
|
||||||
|
|
||||||
|
现有代码无需修改。`PreProcessedDataset` 已经可以正确处理 `shuffle=False` 的情况(PyTorch 的 `SequentialSampler` 会按 0..N-1 顺序读取)。
|
||||||
|
|
||||||
|
`metadata["pre_shuffled"]` 字段由 subsample.py 和 shuffle_npz.py 在写入 metadata.json 时添加,训练代码读取判断即可。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 执行顺序
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Step 1: 打乱并平衡已有的 100M 数据集
|
||||||
|
python -m src.model.shuffle_npz \
|
||||||
|
--input-dir /home/songsenand/DataSet/SubPro \
|
||||||
|
--output-dir /home/songsenand/DataSet/SubPro_Shuffled \
|
||||||
|
--shard-size 1000000 \
|
||||||
|
--seed 42
|
||||||
|
|
||||||
|
# Step 2: 用新数据训练(数据已打乱,顺序读取即可)
|
||||||
|
uv run train-model train \
|
||||||
|
--train-data-path /home/songsenand/DataSet/SubPro_Shuffled/train \
|
||||||
|
--eval-data-path /home/songsenand/DataSet/SubPro_Shuffled/eval \
|
||||||
|
-b 16 \
|
||||||
|
-o ~/tmp \
|
||||||
|
--eval-frequency 20 \
|
||||||
|
--save-frequency 40
|
||||||
|
```
|
||||||
|
|
||||||
|
## 预期效果
|
||||||
|
|
||||||
|
| 改动前 | 改动后 |
|
||||||
|
|---|---|
|
||||||
|
| 每 batch 跨 13 个 shard | 顺序读,1 个 shard 在缓存中 |
|
||||||
|
| 每 batch 数据加载 2-3 分钟 | ~0.1-0.5 秒(纯 mmap/memory) |
|
||||||
|
| CPU 利用率 10% | 正常(训练计算是瓶颈) |
|
||||||
|
| 内存 40GB+ | <20GB(单 shard 1M 样本 ≈ 1.4GB) |
|
||||||
|
|
@ -351,7 +351,7 @@ def _sparse_moe_dispatch(x_flat, experts, topk_indices, topk_weights, num_expert
|
||||||
idx, k_idx = mask.nonzero(as_tuple=True)
|
idx, k_idx = mask.nonzero(as_tuple=True)
|
||||||
if idx.numel() > 0:
|
if idx.numel() > 0:
|
||||||
w = topk_weights[idx, k_idx].unsqueeze(-1)
|
w = topk_weights[idx, k_idx].unsqueeze(-1)
|
||||||
output.index_add_(0, idx, w * experts[e](x_flat[idx]))
|
output.index_add_(0, idx, (w * experts[e](x_flat[idx])).to(output.dtype))
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -411,7 +411,11 @@ class MoELayer(nn.Module):
|
||||||
idx, k_idx = mask.nonzero(as_tuple=True)
|
idx, k_idx = mask.nonzero(as_tuple=True)
|
||||||
if idx.numel() > 0:
|
if idx.numel() > 0:
|
||||||
w = topk_weights[idx, k_idx].unsqueeze(-1)
|
w = topk_weights[idx, k_idx].unsqueeze(-1)
|
||||||
out_flat.index_add_(0, idx, w * self.experts[e](x_flat[idx]))
|
out_flat.index_add_(
|
||||||
|
0,
|
||||||
|
idx,
|
||||||
|
(w * self.experts[e](x_flat[idx])).to(out_flat.dtype),
|
||||||
|
)
|
||||||
|
|
||||||
elif self.moe_mode == "sparse_allow_graph":
|
elif self.moe_mode == "sparse_allow_graph":
|
||||||
out_flat = _sparse_moe_dispatch(
|
out_flat = _sparse_moe_dispatch(
|
||||||
|
|
|
||||||
|
|
@ -231,7 +231,7 @@ def main():
|
||||||
console.print(f"[bold cyan]数据集: {len(dataset):,} 个样本[/bold cyan]")
|
console.print(f"[bold cyan]数据集: {len(dataset):,} 个样本[/bold cyan]")
|
||||||
if dataset._is_sharded:
|
if dataset._is_sharded:
|
||||||
console.print(
|
console.print(
|
||||||
f" 分片数: {dataset._num_shards}, 每分片: {dataset._shard_size:,} 样本"
|
f" 分片数: {dataset._num_shards}, 每分片: {min(dataset._shard_sizes):,} - {max(dataset._shard_sizes):,} 样本"
|
||||||
)
|
)
|
||||||
console.print()
|
console.print()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,9 @@
|
||||||
import argparse
|
import argparse
|
||||||
import gc
|
import gc
|
||||||
import json
|
import json
|
||||||
|
import struct
|
||||||
|
import time
|
||||||
|
import zipfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
|
|
@ -132,6 +135,10 @@ def collect_samples(
|
||||||
actual_count = min(total, num_samples)
|
actual_count = min(total, num_samples)
|
||||||
num_shards = shard_idx
|
num_shards = shard_idx
|
||||||
|
|
||||||
|
shard_sizes = [shard_size] * (num_shards - 1)
|
||||||
|
if num_shards > 0:
|
||||||
|
shard_sizes.append(actual_count - sum(shard_sizes))
|
||||||
|
|
||||||
metadata = {
|
metadata = {
|
||||||
"num_samples": actual_count,
|
"num_samples": actual_count,
|
||||||
"max_seq_length": max_seq_length,
|
"max_seq_length": max_seq_length,
|
||||||
|
|
@ -139,6 +146,7 @@ def collect_samples(
|
||||||
"fields": FIELDS,
|
"fields": FIELDS,
|
||||||
"shard_size": shard_size,
|
"shard_size": shard_size,
|
||||||
"num_shards": num_shards,
|
"num_shards": num_shards,
|
||||||
|
"shard_sizes": shard_sizes,
|
||||||
}
|
}
|
||||||
with open(split_dir / "metadata.json", "w", encoding="utf-8") as fp:
|
with open(split_dir / "metadata.json", "w", encoding="utf-8") as fp:
|
||||||
json.dump(metadata, fp, indent=2, ensure_ascii=False)
|
json.dump(metadata, fp, indent=2, ensure_ascii=False)
|
||||||
|
|
@ -329,7 +337,97 @@ def main():
|
||||||
console.print(f"{split}/: {total_size / (1024**3):.2f} GB (compressed)")
|
console.print(f"{split}/: {total_size / (1024**3):.2f} GB (compressed)")
|
||||||
|
|
||||||
|
|
||||||
app = main
|
def resplit_shards(input_dir: str, output_dir: str, target_size: int = 1_000_000):
    """
    Split oversized ``shard_*.npz`` files into chunks of at most ``target_size`` samples.

    Every input shard is decompressed exactly once (all of its fields are held
    in memory together), cut into consecutive chunks, and each chunk is written
    as a new compressed shard with all arrays cast to int16. The input
    ``metadata.json`` is copied to the output directory with ``num_samples``,
    ``num_shards`` and ``shard_sizes`` recomputed and the ``shard_size`` key
    removed.

    Peak memory is roughly one fully decompressed input shard plus one output
    chunk, so run this on a machine with plenty of RAM and copy the resulting
    small shards to the GPU server afterwards.

    Args:
        input_dir: Directory containing ``metadata.json`` and ``shard_*.npz``.
        output_dir: Directory that receives the re-split shards and metadata.
            Must differ from ``input_dir``.
        target_size: Maximum number of samples per output shard (> 0).

    Raises:
        ValueError: If ``target_size`` is not positive, or if the input and
            output directories are the same (output shards would overwrite
            inputs that are still being read).
        FileNotFoundError: If ``metadata.json`` is missing from ``input_dir``.
    """
    import gc

    from tqdm import tqdm

    if target_size <= 0:
        raise ValueError(f"target_size must be positive, got {target_size}")

    input_path = Path(input_dir)
    output_path = Path(output_dir)

    metadata_path = input_path / "metadata.json"
    if not metadata_path.exists():
        raise FileNotFoundError(f"metadata.json not found in {input_dir}")

    if input_path.resolve() == output_path.resolve():
        raise ValueError(
            "input_dir and output_dir must be different directories: "
            f"{input_dir}"
        )

    output_path.mkdir(parents=True, exist_ok=True)

    # The metadata is written as UTF-8 with ensure_ascii=False, so read it the
    # same way instead of relying on the locale default.
    with open(metadata_path, encoding="utf-8") as f:
        metadata = json.load(f)

    shard_files = sorted(input_path.glob("shard_*.npz"))
    console = Console()
    console.print(
        f"[bold]Resplit {len(shard_files)} shards → {target_size:,} samples/chunk[/bold]"
    )

    global_shard_idx = 0
    all_shard_sizes: List[int] = []
    total_samples = 0

    for sf in tqdm(shard_files, desc="Shards"):
        # NpzFile decompresses a member on every ``data[name]`` access, so pull
        # each field out once per shard rather than once per output chunk.
        with np.load(str(sf)) as data:
            arrays: Dict[str, np.ndarray] = {name: data[name] for name in data.files}

        if not arrays:
            continue

        # The sample count is taken from the first stored array.
        n_samples = next(iter(arrays.values())).shape[0]

        for start in range(0, n_samples, target_size):
            end = min(start + target_size, n_samples)
            # astype() already returns a fresh array, no extra copy needed.
            chunk_data: Dict[str, np.ndarray] = {
                name: arr[start:end].astype(np.int16) for name, arr in arrays.items()
            }

            out_file = output_path / f"shard_{global_shard_idx:06d}.npz"
            np.savez_compressed(out_file, **chunk_data)

            chunk_size = end - start
            all_shard_sizes.append(chunk_size)
            total_samples += chunk_size
            global_shard_idx += 1

            del chunk_data
            gc.collect()

        del arrays
        gc.collect()

    metadata["num_samples"] = total_samples
    metadata["num_shards"] = global_shard_idx
    metadata["shard_sizes"] = all_shard_sizes
    # Output shards are no longer uniform, so a single shard_size is dropped.
    metadata.pop("shard_size", None)
    with open(output_path / "metadata.json", "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2, ensure_ascii=False)

    console.print(
        f"[green]Done: {total_samples:,} samples in {global_shard_idx} shards[/green]"
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _dispatch():
    """
    Route the command line to the right entry point.

    When the first CLI argument is the literal ``resplit``, the remaining
    arguments are parsed for the shard re-splitting tool and forwarded to
    ``resplit_shards``. Any other invocation falls through to ``main``.
    """
    import sys

    cli_args = sys.argv[1:]
    if not cli_args or cli_args[0] != "resplit":
        main()
        return

    import argparse as ap

    resplit_parser = ap.ArgumentParser(description="拆分过大的 .npz 分片")
    resplit_parser.add_argument(
        "--input-dir",
        required=True,
        help="输入目录含 metadata.json 和 shard_*.npz",
    )
    resplit_parser.add_argument("--output-dir", required=True, help="输出目录")
    resplit_parser.add_argument(
        "--target-size",
        type=int,
        default=1_000_000,
        help="目标分片大小(样本数)",
    )

    # Everything after the "resplit" sub-command belongs to this parser.
    opts = resplit_parser.parse_args(cli_args[1:])
    resplit_shards(opts.input_dir, opts.output_dir, opts.target_size)
|
||||||
|
|
||||||
|
|
||||||
|
app = _dispatch
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
_dispatch()
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@
|
||||||
GPU 服务器仅需存放压缩后的 .npz 文件,无需解压到硬盘。
|
GPU 服务器仅需存放压缩后的 .npz 文件,无需解压到硬盘。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
import gc
|
import gc
|
||||||
import json
|
import json
|
||||||
import struct
|
import struct
|
||||||
|
|
@ -47,6 +48,13 @@ def _read_shard_size(npz_path: Path) -> int:
|
||||||
return header["shape"][0]
|
return header["shape"][0]
|
||||||
|
|
||||||
|
|
||||||
|
def _build_offsets(sizes: List[int]) -> List[int]:
|
||||||
|
offsets = [0]
|
||||||
|
for s in sizes:
|
||||||
|
offsets.append(offsets[-1] + s)
|
||||||
|
return offsets
|
||||||
|
|
||||||
|
|
||||||
def is_preprocessed_data(path: str) -> bool:
|
def is_preprocessed_data(path: str) -> bool:
|
||||||
"""判断路径是否为预处理数据目录"""
|
"""判断路径是否为预处理数据目录"""
|
||||||
p = Path(path)
|
p = Path(path)
|
||||||
|
|
@ -94,38 +102,48 @@ class PreProcessedDataset(Dataset):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, data_dir: str, max_cache_shards: int = 1):
|
def __init__(self, data_dir: str, max_cache_shards: int = 1):
|
||||||
|
t_start = time.perf_counter()
|
||||||
self.data_dir = Path(data_dir)
|
self.data_dir = Path(data_dir)
|
||||||
|
|
||||||
with open(self.data_dir / "metadata.json", "r", encoding="utf-8") as f:
|
with open(self.data_dir / "metadata.json", "r", encoding="utf-8") as f:
|
||||||
self.metadata = json.load(f)
|
self.metadata = json.load(f)
|
||||||
|
|
||||||
self.max_seq_length = self.metadata["max_seq_length"]
|
self.max_seq_length = self.metadata["max_seq_length"]
|
||||||
self._shard_size: Optional[int] = self.metadata.get("shard_size")
|
|
||||||
|
|
||||||
if self._shard_size is not None:
|
shard_files = sorted(self.data_dir.glob("shard_*.npz"))
|
||||||
shard_files = sorted(self.data_dir.glob("shard_*.npz"))
|
if shard_files:
|
||||||
self._num_shards = len(shard_files)
|
self._num_shards = len(shard_files)
|
||||||
if self._num_shards == 0:
|
if (
|
||||||
raise FileNotFoundError(
|
"shard_sizes" in self.metadata
|
||||||
f"No shard_*.npz files found in {self.data_dir}"
|
and len(self.metadata["shard_sizes"]) == self._num_shards
|
||||||
|
):
|
||||||
|
self._shard_sizes = self.metadata["shard_sizes"]
|
||||||
|
logger.info(
|
||||||
|
f"Using shard_sizes from metadata.json ({self._num_shards} shards)"
|
||||||
)
|
)
|
||||||
self._shard_sizes: List[int] = [_read_shard_size(sf) for sf in shard_files]
|
else:
|
||||||
self._shard_offsets = [0]
|
t_scan = time.perf_counter()
|
||||||
for s in self._shard_sizes:
|
self._shard_sizes = [_read_shard_size(sf) for sf in shard_files]
|
||||||
self._shard_offsets.append(self._shard_offsets[-1] + s)
|
logger.info(
|
||||||
|
f"Scanned {self._num_shards} shard headers in "
|
||||||
|
f"{time.perf_counter() - t_scan:.2f}s"
|
||||||
|
)
|
||||||
|
self._shard_offsets = _build_offsets(self._shard_sizes)
|
||||||
self.num_samples = self._shard_offsets[-1]
|
self.num_samples = self._shard_offsets[-1]
|
||||||
self._is_sharded = True
|
self._is_sharded = True
|
||||||
self._cache = _ShardCache(max_size=max_cache_shards)
|
self._cache = _ShardCache(max_size=max_cache_shards)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Loaded sharded dataset: {self.num_samples:,} samples, "
|
f"Loaded sharded dataset: {self.num_samples:,} samples, "
|
||||||
f"{self._num_shards} shards"
|
f"{self._num_shards} shards, "
|
||||||
|
f"init in {time.perf_counter() - t_start:.2f}s"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.num_samples = self.metadata["num_samples"]
|
self.num_samples = self.metadata["num_samples"]
|
||||||
self._is_sharded = False
|
self._is_sharded = False
|
||||||
self._load_single_files()
|
self._load_single_files()
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Loaded single-file dataset: {self.num_samples:,} samples (mmap)"
|
f"Loaded single-file dataset: {self.num_samples:,} samples (mmap), "
|
||||||
|
f"init in {time.perf_counter() - t_start:.2f}s"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _load_single_files(self):
|
def _load_single_files(self):
|
||||||
|
|
@ -157,6 +175,10 @@ class PreProcessedDataset(Dataset):
|
||||||
f"Index {idx} out of range for dataset with {self.num_samples} samples"
|
f"Index {idx} out of range for dataset with {self.num_samples} samples"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not hasattr(self, "_first_access_logged"):
|
||||||
|
self._first_access_logged = True
|
||||||
|
logger.info("First __getitem__ call (initial shard load may be slow)")
|
||||||
|
|
||||||
if self._is_sharded:
|
if self._is_sharded:
|
||||||
lo, hi = 0, len(self._shard_offsets) - 1
|
lo, hi = 0, len(self._shard_offsets) - 1
|
||||||
while lo < hi:
|
while lo < hi:
|
||||||
|
|
@ -168,7 +190,7 @@ class PreProcessedDataset(Dataset):
|
||||||
shard_idx = lo - 1
|
shard_idx = lo - 1
|
||||||
local_idx = idx - self._shard_offsets[shard_idx]
|
local_idx = idx - self._shard_offsets[shard_idx]
|
||||||
shard_data = self._cache.get(shard_idx, self._load_shard)
|
shard_data = self._cache.get(shard_idx, self._load_shard)
|
||||||
return {
|
result = {
|
||||||
"input_ids": torch.from_numpy(
|
"input_ids": torch.from_numpy(
|
||||||
shard_data["input_ids"][local_idx].astype(np.int64)
|
shard_data["input_ids"][local_idx].astype(np.int64)
|
||||||
),
|
),
|
||||||
|
|
@ -189,7 +211,7 @@ class PreProcessedDataset(Dataset):
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
return {
|
result = {
|
||||||
"input_ids": torch.from_numpy(self.input_ids[idx].astype(np.int64)),
|
"input_ids": torch.from_numpy(self.input_ids[idx].astype(np.int64)),
|
||||||
"token_type_ids": torch.from_numpy(
|
"token_type_ids": torch.from_numpy(
|
||||||
self.token_type_ids[idx].astype(np.int64)
|
self.token_type_ids[idx].astype(np.int64)
|
||||||
|
|
@ -203,6 +225,7 @@ class PreProcessedDataset(Dataset):
|
||||||
),
|
),
|
||||||
"pinyin_ids": torch.from_numpy(self.pinyin_ids[idx].astype(np.int64)),
|
"pinyin_ids": torch.from_numpy(self.pinyin_ids[idx].astype(np.int64)),
|
||||||
}
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def preprocessed_collate_fn(batch):
|
def preprocessed_collate_fn(batch):
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,318 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
打乱并平衡预处理 .npz 分片。
|
||||||
|
|
||||||
|
两阶段处理:
|
||||||
|
Phase 1: 逐分片内部打乱 → 写入临时 .npy(mmap 友好)
|
||||||
|
Phase 2: 从临时 .npy 按比例分配到平衡的输出 .npz 分片
|
||||||
|
|
||||||
|
用法:
|
||||||
|
python -m src.model.shuffle_npz \
|
||||||
|
--input-dir /path/to/subsampled \
|
||||||
|
--output-dir /path/to/shuffled \
|
||||||
|
--shard-size 1000000 \
|
||||||
|
--seed 42
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import gc
|
||||||
|
import json
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from loguru import logger
|
||||||
|
from rich.console import Console
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# Names of the arrays read from every input shard and written to every
# temporary .npy directory and output shard by the functions in this module.
FIELDS = [
    "input_ids",
    "token_type_ids",
    "attention_mask",
    "labels",
    "history_slot_ids",
    "pinyin_ids",
]
|
||||||
|
|
||||||
|
|
||||||
|
def _phase1_shuffle_to_npy(
    input_dir: Path,
    split: str,
    temp_dir: Path,
    shard_sizes: List[int],
    seed: int,
):
    """
    Phase 1: shuffle each input shard internally and dump it as .npy files.

    One permutation is drawn per shard and applied to every field, so rows stay
    aligned across fields. Each field of each shard is written to its own file:

        temp_dir/<field>/shard_<idx>.npy

    Only one field is held in memory at a time (the decompressed array plus its
    permuted copy).

    Args:
        input_dir: Dataset root; shards are read from ``input_dir / split``.
        split: Name of the split sub-directory (e.g. ``"train"``).
        temp_dir: Directory that receives the per-field .npy files.
        shard_sizes: Number of samples in each input shard, in shard order.
        seed: Seed for the permutation RNG.

    Raises:
        ValueError: If ``metadata.json`` reports a shard count different from
            ``len(shard_sizes)``, or if an array's length differs from the
            recorded shard size (which would otherwise drop samples silently
            or fail with an opaque IndexError).
    """
    metadata_path = input_dir / split / "metadata.json"
    with open(metadata_path, encoding="utf-8") as f:
        metadata = json.load(f)

    num_shards = metadata["num_shards"]
    if num_shards != len(shard_sizes):
        raise ValueError(
            f"{metadata_path}: num_shards={num_shards} does not match "
            f"{len(shard_sizes)} entries in shard_sizes"
        )

    rng = np.random.RandomState(seed)

    for field in FIELDS:
        (temp_dir / field).mkdir(parents=True, exist_ok=True)

    with tqdm(
        total=num_shards,
        desc=f"Phase 1: shuffling {split}",
        unit="shard",
    ) as pbar:
        for src_idx, n in enumerate(shard_sizes):
            shard_path = input_dir / split / f"shard_{src_idx:06d}.npz"
            perm = rng.permutation(n)

            # The context manager closes the archive even if a field fails.
            with np.load(shard_path) as data:
                for field in FIELDS:
                    # NpzFile returns a freshly decompressed array on each
                    # access, so an additional .copy() would only double memory.
                    arr = data[field]
                    if arr.shape[0] != n:
                        raise ValueError(
                            f"{shard_path}: field '{field}' has {arr.shape[0]} "
                            f"rows but shard_sizes[{src_idx}] is {n}"
                        )
                    shuffled = arr[perm]
                    np.save(temp_dir / field / f"shard_{src_idx:06d}.npy", shuffled)
                    del arr, shuffled
                    gc.collect()

            pbar.update(1)
|
||||||
|
|
||||||
|
|
||||||
|
def _phase2_rebalance_to_npz(
    temp_dir: Path,
    output_dir: Path,
    split: str,
    shard_sizes: List[int],
    target_shard_size: int,
    max_seq_length: int,
    seed: int,
) -> List[int]:
    """
    Phase 2: redistribute the memory-mapped .npy data into balanced .npz shards.

    The number of output shards is ``ceil(total / target_shard_size)``. Output
    shard ``j`` takes the slice ``[j*s // K, (j+1)*s // K)`` from every source
    shard of size ``s`` (``K`` = number of output shards), concatenates the
    slices, shuffles the result with one permutation shared by all fields, and
    writes it with ``np.savez_compressed``.

    Existing ``shard_*.npz`` files in the output split directory are removed
    first so leftovers from an earlier run cannot be picked up as data.

    Args:
        temp_dir: Directory holding ``<field>/shard_<idx>.npy`` from phase 1.
        output_dir: Dataset root; shards are written to ``output_dir / split``.
        split: Name of the split sub-directory.
        shard_sizes: Number of samples in each source shard, in shard order.
        target_shard_size: Desired number of samples per output shard (> 0).
        max_seq_length: Not used by this function; kept for interface
            compatibility.
        seed: Base seed; the RNG here is seeded with ``seed + 2``.

    Returns:
        The number of samples written to each output shard, in file order.

    Raises:
        ValueError: If ``target_shard_size`` is not positive.
    """
    if target_shard_size <= 0:
        raise ValueError(
            f"target_shard_size must be positive, got {target_shard_size}"
        )

    output_split_dir = output_dir / split
    output_split_dir.mkdir(parents=True, exist_ok=True)

    # Drop shards left over from a previous run so they cannot pollute the
    # freshly written dataset.
    for stale_shard in sorted(output_split_dir.glob("shard_*.npz")):
        stale_shard.unlink()

    total_samples = sum(shard_sizes)
    num_output_shards = (total_samples + target_shard_size - 1) // target_shard_size
    rng = np.random.RandomState(seed + 2)

    logger.info(
        f"Phase 2: distributing {total_samples:,} samples into "
        f"{num_output_shards} output shards (~{target_shard_size:,} each)"
    )

    # Open every source .npy read-only via mmap; rows are read on slicing.
    src_mmaps: List[Dict[str, np.ndarray]] = [
        {
            field: np.load(
                temp_dir / field / f"shard_{src_idx:06d}.npy", mmap_mode="r"
            )
            for field in FIELDS
        }
        for src_idx in range(len(shard_sizes))
    ]

    output_shard_sizes: List[int] = []

    try:
        with tqdm(
            total=num_output_shards,
            desc=f"Phase 2: writing {split}",
            unit="shard",
        ) as pbar:
            for out_j in range(num_output_shards):
                buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}

                for src_i, s in enumerate(shard_sizes):
                    start = (out_j * s) // num_output_shards
                    end = ((out_j + 1) * s) // num_output_shards
                    if start >= end:
                        continue

                    for field in FIELDS:
                        chunk = src_mmaps[src_i][field][start:end].copy()
                        buffers[field].append(chunk)

                # With very small source shards an output slot can receive no
                # rows at all; skip it instead of indexing an empty list.
                if not buffers[FIELDS[0]]:
                    pbar.update(1)
                    continue

                output = {}
                for field in FIELDS:
                    output[field] = (
                        np.concatenate(buffers[field])
                        if len(buffers[field]) > 1
                        else buffers[field][0]
                    )

                out_count = len(output[FIELDS[0]])
                if out_count > 1:
                    # Extra shuffle so rows from different sources are mixed.
                    perm = rng.permutation(out_count)
                    for field in FIELDS:
                        output[field] = output[field][perm]

                # Number files consecutively so skipped slots leave no gaps;
                # equals out_j whenever nothing was skipped.
                out_idx = len(output_shard_sizes)
                np.savez_compressed(
                    output_split_dir / f"shard_{out_idx:06d}.npz", **output
                )
                output_shard_sizes.append(out_count)

                del output, buffers
                gc.collect()
                pbar.update(1)
    finally:
        # Release all mmaps, also when writing failed part-way.
        src_mmaps.clear()
        gc.collect()

    return output_shard_sizes
|
||||||
|
|
||||||
|
|
||||||
|
def _copy_eval(input_dir: Path, output_dir: Path):
    """
    Copy ``input_dir/eval`` to ``output_dir/eval`` unchanged.

    Does nothing when the source directory does not exist. An existing
    destination is removed first so the result is an exact copy.

    Args:
        input_dir: Dataset root that may contain an ``eval`` sub-directory.
        output_dir: Dataset root that receives the copy.
    """
    src_eval = input_dir / "eval"
    dst_eval = output_dir / "eval"
    if not src_eval.exists():
        return

    # If both paths point at the same directory, removing the destination
    # would delete the source before it could be copied.
    if src_eval.resolve() == dst_eval.resolve():
        logger.info(f"Eval data already in place at {src_eval}, skipping copy")
        return

    logger.info(f"Copying eval data from {src_eval}")
    if dst_eval.exists():
        shutil.rmtree(dst_eval)
    shutil.copytree(src_eval, dst_eval)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """
    Command-line entry point: shuffle and rebalance a preprocessed dataset.

    Reads ``<input-dir>/train/metadata.json``, runs phase 1 (per-shard shuffle
    into temporary .npy files) and phase 2 (redistribution into balanced .npz
    shards), writes the new ``train/metadata.json`` with ``pre_shuffled`` set,
    copies the ``eval`` directory, and finally prints a size summary.

    Temporary files live in ``<temp-dir>/shuffle_npz_temp`` and are removed
    when the run ends, whether it succeeded or failed.

    Raises:
        SystemExit: With status 1 when the train metadata is missing, lacks
            ``shard_sizes``, or input and output directories coincide; with
            status 2 (via argparse) when ``--shard-size`` is not positive.
    """
    console = Console()

    parser = argparse.ArgumentParser(description="打乱并平衡预处理 .npz 分片")
    parser.add_argument(
        "--input-dir", type=str, required=True, help="输入目录(含 train/ 和 eval/)"
    )
    parser.add_argument("--output-dir", type=str, required=True, help="输出目录")
    parser.add_argument(
        "--shard-size",
        type=int,
        default=1_000_000,
        help="输出分片大小(样本数),默认 100 万",
    )
    parser.add_argument("--seed", type=int, default=42, help="随机种子")
    parser.add_argument(
        "--temp-dir",
        type=str,
        default="/home/songsenand/tmp",
        help="临时文件的父目录;会在其下创建并在结束时删除 shuffle_npz_temp/",
    )

    args = parser.parse_args()
    if args.shard_size <= 0:
        parser.error("--shard-size 必须为正整数")

    input_dir = Path(args.input_dir)
    output_dir = Path(args.output_dir)

    # Writing into the input directory would overwrite the source shards and
    # delete the eval data before it is copied.
    if input_dir.resolve() == output_dir.resolve():
        console.print("[red]错误: --input-dir 与 --output-dir 不能相同[/red]")
        raise SystemExit(1)

    train_meta_path = input_dir / "train" / "metadata.json"
    if not train_meta_path.exists():
        console.print(f"[red]错误: {train_meta_path} 不存在[/red]")
        # Non-zero status so shell pipelines stop instead of training on
        # missing data.
        raise SystemExit(1)

    with open(train_meta_path, encoding="utf-8") as f:
        train_meta = json.load(f)

    max_seq_length = train_meta["max_seq_length"]
    shard_sizes: List[int] = train_meta.get("shard_sizes") or []
    if not shard_sizes:
        console.print(
            f"[red]错误: {train_meta_path} 缺少 shard_sizes 字段或其为空[/red]"
        )
        raise SystemExit(1)
    total_samples = sum(shard_sizes)

    output_dir.mkdir(parents=True, exist_ok=True)

    console.print("[bold cyan]=== 打乱并平衡预处理数据 ===[/bold cyan]")
    console.print(f"输入目录: {input_dir}")
    console.print(f"输出目录: {output_dir}")
    console.print(f"总样本数: {total_samples:,}")
    console.print(f"源分片数: {len(shard_sizes)}")
    console.print(
        f"目标分片大小: {args.shard_size:,} "
        f"(约 {(total_samples + args.shard_size - 1) // args.shard_size} 个分片)"
    )
    console.print(f"随机种子: {args.seed}")
    console.print()

    # Always work in a dedicated sub-directory so that wiping it never touches
    # unrelated files in the user-supplied parent directory.
    temp_dir = Path(args.temp_dir) / "shuffle_npz_temp"
    if temp_dir.exists():
        shutil.rmtree(temp_dir)
    temp_dir.mkdir(parents=True, exist_ok=True)
    temp_train_dir = temp_dir / "train"
    temp_train_dir.mkdir(parents=True, exist_ok=True)

    try:
        # -- Phase 1 --
        console.print("[bold]Phase 1: 逐分片打乱 → 临时 .npy[/bold]")
        _phase1_shuffle_to_npy(
            input_dir=input_dir,
            split="train",
            temp_dir=temp_train_dir,
            shard_sizes=shard_sizes,
            seed=args.seed,
        )

        # -- Phase 2 --
        console.print("[bold]Phase 2: 按比例分配到平衡输出分片[/bold]")
        output_shard_sizes = _phase2_rebalance_to_npz(
            temp_dir=temp_train_dir,
            output_dir=output_dir,
            split="train",
            shard_sizes=shard_sizes,
            target_shard_size=args.shard_size,
            max_seq_length=max_seq_length,
            seed=args.seed,
        )

        # -- Write metadata --
        train_output_meta = {
            "num_samples": sum(output_shard_sizes),
            "max_seq_length": max_seq_length,
            "dtype": "int16",
            "fields": FIELDS,
            "shard_size": args.shard_size,
            "num_shards": len(output_shard_sizes),
            "shard_sizes": output_shard_sizes,
            "pre_shuffled": True,
            "seed": args.seed,
        }
        out_train_dir = output_dir / "train"
        with open(out_train_dir / "metadata.json", "w", encoding="utf-8") as f:
            json.dump(train_output_meta, f, indent=2, ensure_ascii=False)

        # -- Copy eval --
        _copy_eval(input_dir, output_dir)

    finally:
        # Remove the temporary files on success and on failure alike.
        console.print("[dim]清理临时文件...[/dim]")
        if temp_dir.exists():
            shutil.rmtree(temp_dir)

    # -- Summary --
    console.print()
    console.print("[bold green]=== 完成 ===[/bold green]")
    console.print(
        f"train/: {sum(output_shard_sizes):,} 样本, "
        f"{len(output_shard_sizes)} 分片, "
        f"pre_shuffled=True"
    )

    for sdir_name in ["train", "eval"]:
        sdir = output_dir / sdir_name
        if not sdir.exists():
            continue
        total_size = sum(
            f.stat().st_size for f in sdir.iterdir() if f.suffix == ".npz"
        )
        meta_path = sdir / "metadata.json"
        if meta_path.exists():
            with open(meta_path, encoding="utf-8") as mf:
                meta = json.load(mf)
        else:
            meta = {}
        console.print(
            f"  {sdir_name}/: {total_size / (1024**3):.2f} GB, "
            f"{meta.get('num_shards', '?')} shards"
        )
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -128,6 +128,8 @@ def pass2_subsample(
|
||||||
eval_map: Dict[int, List[int]],
|
eval_map: Dict[int, List[int]],
|
||||||
shard_sizes: List[int],
|
shard_sizes: List[int],
|
||||||
shard_size: int = 5_000_000,
|
shard_size: int = 5_000_000,
|
||||||
|
shuffle: bool = True,
|
||||||
|
seed: int = 42,
|
||||||
) -> Tuple[int, int]:
|
) -> Tuple[int, int]:
|
||||||
"""
|
"""
|
||||||
第2遍扫描:读取全部字段,抽取评估样本 + 按精确保留配额子采样训练样本。
|
第2遍扫描:读取全部字段,抽取评估样本 + 按精确保留配额子采样训练样本。
|
||||||
|
|
@ -147,12 +149,19 @@ def pass2_subsample(
|
||||||
output_train_dir.mkdir(parents=True, exist_ok=True)
|
output_train_dir.mkdir(parents=True, exist_ok=True)
|
||||||
output_eval_dir.mkdir(parents=True, exist_ok=True)
|
output_eval_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
rng = np.random.RandomState()
|
# 清理旧输出分片(避免残留文件污染)
|
||||||
|
for old_shard in sorted(output_train_dir.glob("shard_*.npz")):
|
||||||
|
old_shard.unlink()
|
||||||
|
logger.debug(f"Removed old shard: {old_shard.name}")
|
||||||
|
|
||||||
|
rng = np.random.RandomState(seed)
|
||||||
|
shuffle_rng = np.random.RandomState(seed + 1)
|
||||||
remaining = dict(quotas) # {label_id: remaining_to_keep}
|
remaining = dict(quotas) # {label_id: remaining_to_keep}
|
||||||
|
|
||||||
train_buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}
|
train_buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}
|
||||||
train_buf_count = 0
|
train_buf_count = 0
|
||||||
train_shard_idx = 0
|
train_shard_idx = 0
|
||||||
|
train_shard_sizes: List[int] = []
|
||||||
total_train_kept = 0
|
total_train_kept = 0
|
||||||
|
|
||||||
eval_buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}
|
eval_buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}
|
||||||
|
|
@ -247,9 +256,14 @@ def pass2_subsample(
|
||||||
merged = {}
|
merged = {}
|
||||||
for f in FIELDS:
|
for f in FIELDS:
|
||||||
merged[f] = np.concatenate(train_buffers[f], axis=0)
|
merged[f] = np.concatenate(train_buffers[f], axis=0)
|
||||||
|
if shuffle and train_buf_count > 1:
|
||||||
|
perm = shuffle_rng.permutation(train_buf_count)
|
||||||
|
for f in FIELDS:
|
||||||
|
merged[f] = merged[f][perm]
|
||||||
np.savez_compressed(
|
np.savez_compressed(
|
||||||
output_train_dir / f"shard_{train_shard_idx:06d}.npz", **merged
|
output_train_dir / f"shard_{train_shard_idx:06d}.npz", **merged
|
||||||
)
|
)
|
||||||
|
train_shard_sizes.append(train_buf_count)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Saved train shard {train_shard_idx}: {train_buf_count} samples"
|
f"Saved train shard {train_shard_idx}: {train_buf_count} samples"
|
||||||
)
|
)
|
||||||
|
|
@ -268,9 +282,14 @@ def pass2_subsample(
|
||||||
merged = {}
|
merged = {}
|
||||||
for f in FIELDS:
|
for f in FIELDS:
|
||||||
merged[f] = np.concatenate(train_buffers[f], axis=0)
|
merged[f] = np.concatenate(train_buffers[f], axis=0)
|
||||||
|
if shuffle and train_buf_count > 1:
|
||||||
|
perm = shuffle_rng.permutation(train_buf_count)
|
||||||
|
for f in FIELDS:
|
||||||
|
merged[f] = merged[f][perm]
|
||||||
np.savez_compressed(
|
np.savez_compressed(
|
||||||
output_train_dir / f"shard_{train_shard_idx:06d}.npz", **merged
|
output_train_dir / f"shard_{train_shard_idx:06d}.npz", **merged
|
||||||
)
|
)
|
||||||
|
train_shard_sizes.append(train_buf_count)
|
||||||
logger.debug(f"Saved train shard {train_shard_idx}: {train_buf_count} samples")
|
logger.debug(f"Saved train shard {train_shard_idx}: {train_buf_count} samples")
|
||||||
train_shard_idx += 1
|
train_shard_idx += 1
|
||||||
|
|
||||||
|
|
@ -291,6 +310,9 @@ def pass2_subsample(
|
||||||
"fields": FIELDS,
|
"fields": FIELDS,
|
||||||
"shard_size": shard_size,
|
"shard_size": shard_size,
|
||||||
"num_shards": train_shard_idx,
|
"num_shards": train_shard_idx,
|
||||||
|
"shard_sizes": train_shard_sizes,
|
||||||
|
"pre_shuffled": shuffle,
|
||||||
|
"seed": seed,
|
||||||
}
|
}
|
||||||
with open(output_train_dir / "metadata.json", "w", encoding="utf-8") as f:
|
with open(output_train_dir / "metadata.json", "w", encoding="utf-8") as f:
|
||||||
json.dump(train_metadata, f, indent=2, ensure_ascii=False)
|
json.dump(train_metadata, f, indent=2, ensure_ascii=False)
|
||||||
|
|
@ -334,6 +356,15 @@ def main():
|
||||||
help="输出分片大小(样本数)",
|
help="输出分片大小(样本数)",
|
||||||
)
|
)
|
||||||
parser.add_argument("--num-eval", type=int, default=2560, help="评估集样本数")
|
parser.add_argument("--num-eval", type=int, default=2560, help="评估集样本数")
|
||||||
|
parser.add_argument(
|
||||||
|
"--no-shuffle",
|
||||||
|
action="store_false",
|
||||||
|
dest="shuffle",
|
||||||
|
help="禁用输出分片内部打乱",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed", type=int, default=42, help="随机种子(用于标签选择 + 输出打乱)"
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
input_dir = Path(args.input_dir)
|
input_dir = Path(args.input_dir)
|
||||||
|
|
@ -346,6 +377,8 @@ def main():
|
||||||
console.print(f"每 ID 封顶: {args.cap_per_label:,}")
|
console.print(f"每 ID 封顶: {args.cap_per_label:,}")
|
||||||
console.print(f"目标训练集: {args.target_total:,}")
|
console.print(f"目标训练集: {args.target_total:,}")
|
||||||
console.print(f"评估集: {args.num_eval}")
|
console.print(f"评估集: {args.num_eval}")
|
||||||
|
console.print(f"输出分片打乱: {'是' if args.shuffle else '否'}")
|
||||||
|
console.print(f"随机种子: {args.seed}")
|
||||||
console.print()
|
console.print()
|
||||||
|
|
||||||
# ── 第 1 遍:统计 ──
|
# ── 第 1 遍:统计 ──
|
||||||
|
|
@ -382,8 +415,8 @@ def main():
|
||||||
console.print()
|
console.print()
|
||||||
|
|
||||||
# ── 抽取评估集位置 ──
|
# ── 抽取评估集位置 ──
|
||||||
rng = np.random.RandomState()
|
eval_rng = np.random.RandomState(args.seed + 100)
|
||||||
eval_positions = rng.choice(total_samples, size=args.num_eval, replace=False)
|
eval_positions = eval_rng.choice(total_samples, size=args.num_eval, replace=False)
|
||||||
eval_positions.sort()
|
eval_positions.sort()
|
||||||
eval_map = _global_to_shard(eval_positions, shard_sizes)
|
eval_map = _global_to_shard(eval_positions, shard_sizes)
|
||||||
console.print(f"评估集: {args.num_eval} 个位置已分配到 {len(eval_map)} 个分片中")
|
console.print(f"评估集: {args.num_eval} 个位置已分配到 {len(eval_map)} 个分片中")
|
||||||
|
|
@ -399,6 +432,8 @@ def main():
|
||||||
eval_map,
|
eval_map,
|
||||||
shard_sizes,
|
shard_sizes,
|
||||||
args.shard_size,
|
args.shard_size,
|
||||||
|
shuffle=args.shuffle,
|
||||||
|
seed=args.seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
# ── 输出总结 ──
|
# ── 输出总结 ──
|
||||||
|
|
|
||||||
|
|
@ -632,7 +632,6 @@ class Trainer:
|
||||||
|
|
||||||
def _create_progress(self) -> Progress:
|
def _create_progress(self) -> Progress:
|
||||||
return Progress(
|
return Progress(
|
||||||
SpinnerColumn(),
|
|
||||||
TextColumn("[progress.description]{task.description}"),
|
TextColumn("[progress.description]{task.description}"),
|
||||||
BarColumn(),
|
BarColumn(),
|
||||||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||||
|
|
@ -725,11 +724,7 @@ class Trainer:
|
||||||
progress.reset(
|
progress.reset(
|
||||||
batch_task,
|
batch_task,
|
||||||
total=steps_per_epoch,
|
total=steps_per_epoch,
|
||||||
description=f"[green]Batch Epoch {epoch + 1}/{self.num_epochs}",
|
description=f"[green]Epoch {epoch + 1} Step 0/{steps_per_epoch}",
|
||||||
)
|
|
||||||
progress.update(
|
|
||||||
epoch_task,
|
|
||||||
description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs}",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
epoch_step = 0
|
epoch_step = 0
|
||||||
|
|
@ -762,7 +757,8 @@ class Trainer:
|
||||||
progress.update(
|
progress.update(
|
||||||
batch_task,
|
batch_task,
|
||||||
advance=1,
|
advance=1,
|
||||||
description=f"[green]Batch Epoch {epoch + 1}/{self.num_epochs}"
|
description=f"[green]Epoch {epoch + 1} "
|
||||||
|
f"Step {epoch_step + 1}/{steps_per_epoch}"
|
||||||
f" | Loss: {loss:.4f}"
|
f" | Loss: {loss:.4f}"
|
||||||
f" | LR: {current_lr:.2e}",
|
f" | LR: {current_lr:.2e}",
|
||||||
)
|
)
|
||||||
|
|
@ -816,7 +812,7 @@ class Trainer:
|
||||||
f"Eval Acc: {eval_metrics['eval_accuracy']:.4f}"
|
f"Eval Acc: {eval_metrics['eval_accuracy']:.4f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
progress.console.log(log_text)
|
progress.log(log_text)
|
||||||
|
|
||||||
# 重置累积指标
|
# 重置累积指标
|
||||||
accumulated_loss = 0.0
|
accumulated_loss = 0.0
|
||||||
|
|
@ -1260,16 +1256,31 @@ def train(
|
||||||
is_eval_preprocessed = is_preprocessed_data(eval_data_path)
|
is_eval_preprocessed = is_preprocessed_data(eval_data_path)
|
||||||
|
|
||||||
if is_train_preprocessed:
|
if is_train_preprocessed:
|
||||||
train_dataset = PreProcessedDataset(train_data_path, max_cache_shards=1)
|
train_dataset = PreProcessedDataset(train_data_path, max_cache_shards=2)
|
||||||
total_steps = (len(train_dataset) // batch_size) * num_epochs
|
pre_shuffled = train_dataset.metadata.get("pre_shuffled", False)
|
||||||
|
# 预打乱数据不需要 DataLoader 的 RandomSampler(避免跨分片解压抖动)
|
||||||
|
shuffle_train = not pre_shuffled
|
||||||
|
if max_iter_length > 0:
|
||||||
|
max_samples_per_epoch = max_iter_length
|
||||||
|
capped_samples = min(len(train_dataset), max_samples_per_epoch)
|
||||||
|
else:
|
||||||
|
capped_samples = len(train_dataset)
|
||||||
|
total_steps = (capped_samples // batch_size) * num_epochs
|
||||||
|
train_num_workers = min(num_workers, 1)
|
||||||
|
logger.info(
|
||||||
|
f"Preprocessed dataset: {len(train_dataset):,} samples, "
|
||||||
|
f"shuffle={shuffle_train}, pre_shuffled={pre_shuffled}, "
|
||||||
|
f"workers={train_num_workers}, steps={total_steps:,}"
|
||||||
|
)
|
||||||
train_dataloader = create_dataloader(
|
train_dataloader = create_dataloader(
|
||||||
dataset=train_dataset,
|
dataset=train_dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
num_workers=num_workers,
|
num_workers=train_num_workers,
|
||||||
pin_memory=torch.cuda.is_available(),
|
pin_memory=torch.cuda.is_available(),
|
||||||
shuffle=True,
|
shuffle=shuffle_train,
|
||||||
)
|
)
|
||||||
config_table.add_row("数据", "训练数据类型", "预处理数据")
|
config_table.add_row("数据", "训练数据类型", "预处理数据")
|
||||||
|
config_table.add_row("数据", "预打乱", str(pre_shuffled))
|
||||||
else:
|
else:
|
||||||
train_dataset = PinyinInputDataset(
|
train_dataset = PinyinInputDataset(
|
||||||
data_path=train_data_path,
|
data_path=train_data_path,
|
||||||
|
|
@ -1296,7 +1307,7 @@ def train(
|
||||||
eval_dataloader = create_dataloader(
|
eval_dataloader = create_dataloader(
|
||||||
dataset=eval_dataset,
|
dataset=eval_dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
num_workers=2,
|
num_workers=0,
|
||||||
pin_memory=torch.cuda.is_available(),
|
pin_memory=torch.cuda.is_available(),
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue