diff --git a/.opencode/plans/precompute_shuffle.md b/.opencode/plans/precompute_shuffle.md new file mode 100644 index 0000000..fac3d92 --- /dev/null +++ b/.opencode/plans/precompute_shuffle.md @@ -0,0 +1,229 @@ +# 预处理数据预打乱方案 + +## 目标 + +用 CPU 机时预打乱数据,让训练时直接用 `shuffle=False` 顺序读取,消除跨分片缓存抖动和 CPU 利用率低的问题。 + +## 改动清单 + +### 1. `src/model/subsample.py` — 默认开启输出打乱 + +**函数签名变更:** +```python +def pass2_subsample( + ..., + shuffle: bool = True, # 新增 + seed: int = 42, # 新增 +) -> Tuple[int, int]: +``` + +**改动点:** +- `rng = np.random.RandomState()` → `rng = np.random.RandomState(seed)` +- 新增 `shuffle_rng = np.random.RandomState(seed + 1)` 用于输出打乱 +- 在两次 `np.savez_compressed` 写入前(line ~251 和 line ~273),插入: +```python +if shuffle and train_buf_count > 1: + perm = shuffle_rng.permutation(train_buf_count) + for f in FIELDS: + merged[f] = merged[f][perm] +``` +- main() 中新增参数: +```python +parser.add_argument("--no-shuffle", action="store_false", dest="shuffle", + help="禁用输出分片内部打乱") +parser.add_argument("--seed", type=int, default=42, + help="随机种子(用于选择+打乱)") +``` +- 调用 `pass2_subsample` 时传入 `shuffle=args.shuffle, seed=args.seed` + +**metadata.json 新增字段:** +```json +"pre_shuffled": true, +"seed": 42 +``` + +--- + +### 2. 
新建 `src/model/shuffle_npz.py` — 平衡+打乱已有数据 + +处理流程: + +``` +Phase 1: npz → 逐字段打乱 → .npy(mmap 友好) + 输入: 19个不平衡 .npz 分片(100M样本,~10GB 压缩) + 输出: 114个临时 .npy 文件(6字段 × 19分片,~144GB 未压缩) + 内存峰值: ~10GB(单字段 int16 最大 ~5GB + permuted copy ~5GB) + 耗时: ~15-20分钟(解压+写入) + +Phase 2: .npy → 平衡分配 → .npz + 输入: Phase 1 的 .npy 文件(mmap 模式) + 输出: 100个平衡 .npz 分片(每片100万样本,~100MB/片压缩后) + 每个输出分片 = 从19个源各取比例份额 → concatenate → shuffle → save + 内存峰值: ~3GB(1个输出缓冲 + mmap pages) + 耗时: ~10-15分钟(mmap读取+压缩写入) + +总计: ~30-40分钟,峰值内存 ~10GB +``` + +**磁盘需求:** +- 临时 .npy 文件(Phase 1→2 中间产物):~144GB +- 最终输出 .npz:~10GB +- 临时文件在 Phase 2 完成后自动删除 + +**用法:** +```bash +python -m src.model.shuffle_npz \ + --input-dir /home/songsenand/DataSet/SubPro \ + --output-dir /home/songsenand/DataSet/SubPro_Shuffled \ + --shard-size 1000000 \ + --seed 42 +``` + +**关键实现:** + +Phase 1 — 逐字段加载+打乱: +```python +for src_idx in range(num_shards): + data = np.load(shard_path) # lazy NpzFile + n = shard_sizes[src_idx] + perm = rng.permutation(n) + + for field in FIELDS: + arr = data[field].copy() # ~5GB peak (input_ids) + shuffled = arr[perm] # ~5GB temp + np.save(temp_dir / field / f"shard_{src_idx:06d}.npy", shuffled) + del arr, shuffled + gc.collect() + + data.close() +``` + +Phase 2 — 平衡分配: +```python +# 打开所有源 mmap(读模式,零内存) +src_mmaps = [ + {f: np.load(temp_dir / f / f"shard_{i:06d}.npy", mmap_mode='r') for f in FIELDS} + for i in range(num_shards) +] + +for out_j in range(num_output_shards): + buffers = {f: [] for f in FIELDS} + + for src_i in range(num_shards): + s = shard_sizes[src_i] + start = (out_j * s) // num_output_shards + end = ((out_j + 1) * s) // num_output_shards + if start >= end: + continue + for f in FIELDS: + chunk = src_mmaps[src_i][f][start:end].copy() + buffers[f].append(chunk) + + output = {f: np.concatenate(buffers[f]) for f in FIELDS} + # 额外打乱(跨源混合) + perm = rng.permutation(len(output[FIELDS[0]])) + for f in FIELDS: + output[f] = output[f][perm] + + np.savez_compressed(output_dir / f"shard_{out_j:06d}.npz", 
**output) + del output, buffers, perm + gc.collect() +``` + +**输出 metadata.json:** +```json +{ + "num_samples": 99998406, + "max_seq_length": 128, + "dtype": "int16", + "fields": [...], + "shard_size": 1000000, + "num_shards": 100, + "shard_sizes": [1000000, ..., 998406], + "pre_shuffled": true, + "seed": 42 +} +``` + +**eval 目录处理:** 如果 `--input-dir/eval/` 存在,直接复制到 `--output-dir/eval/`(eval 数据量小,不需要打乱) + +--- + +### 3. `src/model/trainer.py` — 预处理数据禁用 shuffle + +**改动点(train 函数,line ~1258-1272):** + +```python +if is_train_preprocessed: + train_dataset = PreProcessedDataset(train_data_path, max_cache_shards=1) + # pre_shuffled 数据不需要 DataLoader 的 RandomSampler + shuffle_train = not train_dataset.metadata.get("pre_shuffled", False) + total_steps = (len(train_dataset) // batch_size) * num_epochs + # 支持 max_iter_length 限制总步数 + if max_iter_length > 0: + max_steps_per_epoch = max_iter_length // batch_size + total_steps = min(total_steps, max_steps_per_epoch * num_epochs) + train_num_workers = min(num_workers, 1) + train_dataloader = create_dataloader( + dataset=train_dataset, + batch_size=batch_size, + num_workers=train_num_workers, + pin_memory=torch.cuda.is_available(), + shuffle=shuffle_train, # 预打乱数据不 shuffle + ) +``` + +**eval DataLoader(line ~1295-1303):** + +```python +if is_eval_preprocessed: + eval_dataset = PreProcessedDataset(eval_data_path, max_cache_shards=1) + eval_dataloader = create_dataloader( + dataset=eval_dataset, + batch_size=batch_size, + num_workers=0, # eval 数据小,单进程足够 + pin_memory=torch.cuda.is_available(), + shuffle=False, # eval 不需要打乱 + ) +``` + +**`create_dataloader` 函数(line ~1076-1114):** 无需改动,`shuffle` 参数已透传。 + +--- + +### 4. 
`src/model/preprocessed_dataset.py` + +现有代码无需修改。`PreProcessedDataset` 已经可以正确处理 `shuffle=False` 的情况(PyTorch 的 `SequentialSampler` 会按 0..N-1 顺序读取)。 + +`metadata["pre_shuffled"]` 字段由 subsample.py 和 shuffle_npz.py 在写入 metadata.json 时添加,训练代码读取判断即可。 + +--- + +## 执行顺序 + +```bash +# Step 1: 打乱并平衡已有的 100M 数据集 +python -m src.model.shuffle_npz \ + --input-dir /home/songsenand/DataSet/SubPro \ + --output-dir /home/songsenand/DataSet/SubPro_Shuffled \ + --shard-size 1000000 \ + --seed 42 + +# Step 2: 用新数据训练(数据已打乱,顺序读取即可) +uv run train-model train \ + --train-data-path /home/songsenand/DataSet/SubPro_Shuffled/train \ + --eval-data-path /home/songsenand/DataSet/SubPro_Shuffled/eval \ + -b 16 \ + -o ~/tmp \ + --eval-frequency 20 \ + --save-frequency 40 +``` + +## 预期效果 + +| 改动前 | 改动后 | +|---|---| +| 每 batch 跨 13 个 shard | 顺序读,1 个 shard 在缓存中 | +| 每 batch 数据加载 2-3 分钟 | ~0.1-0.5 秒(纯 mmap/memory) | +| CPU 利用率 10% | 正常(训练计算是瓶颈) | +| 内存 40GB+ | <20GB(单 shard 1M 样本 ≈ 1.4GB) | diff --git a/src/model/components.py b/src/model/components.py index b5b67bf..8cbbbf9 100644 --- a/src/model/components.py +++ b/src/model/components.py @@ -351,7 +351,7 @@ def _sparse_moe_dispatch(x_flat, experts, topk_indices, topk_weights, num_expert idx, k_idx = mask.nonzero(as_tuple=True) if idx.numel() > 0: w = topk_weights[idx, k_idx].unsqueeze(-1) - output.index_add_(0, idx, w * experts[e](x_flat[idx])) + output.index_add_(0, idx, (w * experts[e](x_flat[idx])).to(output.dtype)) return output @@ -411,7 +411,11 @@ class MoELayer(nn.Module): idx, k_idx = mask.nonzero(as_tuple=True) if idx.numel() > 0: w = topk_weights[idx, k_idx].unsqueeze(-1) - out_flat.index_add_(0, idx, w * self.experts[e](x_flat[idx])) + out_flat.index_add_( + 0, + idx, + (w * self.experts[e](x_flat[idx])).to(out_flat.dtype), + ) elif self.moe_mode == "sparse_allow_graph": out_flat = _sparse_moe_dispatch( diff --git a/src/model/inspect_preprocessed.py b/src/model/inspect_preprocessed.py index 8c0cb5f..6f8992c 100644 --- 
def resplit_shards(input_dir: str, output_dir: str, target_size: int = 1_000_000):
    """
    Split oversized ``shard_*.npz`` files into chunks of at most ``target_size`` samples.

    Every field of a source shard is decompressed exactly once and then sliced
    in memory, so peak memory is roughly one fully decompressed source shard
    (all fields). Run this on a machine with plenty of RAM and copy the small
    output shards to the training host afterwards.

    Args:
        input_dir: Directory containing ``metadata.json`` and ``shard_*.npz``.
        output_dir: Directory that receives the re-split shards and a new
            ``metadata.json``. Must differ from ``input_dir``.
        target_size: Maximum number of samples per output shard (>= 1).

    Raises:
        ValueError: If ``target_size`` is smaller than 1 or both directories
            resolve to the same path.
        FileNotFoundError: If ``metadata.json`` is missing from ``input_dir``.
    """
    from tqdm import tqdm

    if target_size < 1:
        raise ValueError(f"target_size must be >= 1, got {target_size}")

    input_path = Path(input_dir)
    output_path = Path(output_dir)
    # Output shards are renumbered from 0, so writing into the input directory
    # would overwrite source shards that have not been read yet.
    if input_path.resolve() == output_path.resolve():
        raise ValueError("output_dir must be different from input_dir")
    output_path.mkdir(parents=True, exist_ok=True)

    metadata_path = input_path / "metadata.json"
    if not metadata_path.exists():
        raise FileNotFoundError(f"metadata.json not found in {input_dir}")

    with open(metadata_path, encoding="utf-8") as f:
        metadata = json.load(f)

    shard_files = sorted(input_path.glob("shard_*.npz"))
    console = Console()
    console.print(
        f"[bold]Resplit {len(shard_files)} shards → {target_size:,} samples/chunk[/bold]"
    )

    # One entry per written output shard; its length doubles as the next index.
    all_shard_sizes: List[int] = []

    for sf in tqdm(shard_files, desc="Shards"):
        # NpzFile is lazy: every ``data[field]`` access re-reads and
        # re-decompresses the whole member, so materialise each field once
        # instead of once per output chunk.
        with np.load(str(sf)) as data:
            arrays: Dict[str, np.ndarray] = {name: data[name] for name in data.files}

        n_samples = len(next(iter(arrays.values()))) if arrays else 0

        for start in range(0, n_samples, target_size):
            end = min(start + target_size, n_samples)
            # copy=False avoids duplicating data that is already int16.
            chunk_data: Dict[str, np.ndarray] = {
                name: arr[start:end].astype(np.int16, copy=False)
                for name, arr in arrays.items()
            }

            out_file = output_path / f"shard_{len(all_shard_sizes):06d}.npz"
            np.savez_compressed(out_file, **chunk_data)
            all_shard_sizes.append(end - start)
            del chunk_data

        # Release the decompressed shard before loading the next one.
        del arrays
        gc.collect()

    total_samples = sum(all_shard_sizes)
    metadata["num_samples"] = total_samples
    metadata["num_shards"] = len(all_shard_sizes)
    metadata["shard_sizes"] = all_shard_sizes
    # Output shards are no longer uniform, so a single shard_size is meaningless.
    metadata.pop("shard_size", None)
    with open(output_path / "metadata.json", "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2, ensure_ascii=False)

    console.print(
        f"[green]Done: {total_samples:,} samples in {len(all_shard_sizes)} shards[/green]"
    )


def _dispatch():
    """Entry point: run the ``resplit`` sub-command if requested, else ``main()``."""
    import sys

    if len(sys.argv) > 1 and sys.argv[1] == "resplit":
        p = argparse.ArgumentParser(description="拆分过大的 .npz 分片")
        p.add_argument(
            "--input-dir", required=True, help="输入目录含 metadata.json 和 shard_*.npz"
        )
        p.add_argument("--output-dir", required=True, help="输出目录")
        p.add_argument(
            "--target-size", type=int, default=1_000_000, help="目标分片大小(样本数)"
        )
        # Skip the program name and the "resplit" token itself.
        args = p.parse_args(sys.argv[2:])
        resplit_shards(args.input_dir, args.output_dir, args.target_size)
    else:
        main()
a/src/model/preprocessed_dataset.py b/src/model/preprocessed_dataset.py index 5f1eb51..5d3771d 100644 --- a/src/model/preprocessed_dataset.py +++ b/src/model/preprocessed_dataset.py @@ -8,6 +8,7 @@ GPU 服务器仅需存放压缩后的 .npz 文件,无需解压到硬盘。 """ +import time import gc import json import struct @@ -47,6 +48,13 @@ def _read_shard_size(npz_path: Path) -> int: return header["shape"][0] +def _build_offsets(sizes: List[int]) -> List[int]: + offsets = [0] + for s in sizes: + offsets.append(offsets[-1] + s) + return offsets + + def is_preprocessed_data(path: str) -> bool: """判断路径是否为预处理数据目录""" p = Path(path) @@ -94,38 +102,48 @@ class PreProcessedDataset(Dataset): """ def __init__(self, data_dir: str, max_cache_shards: int = 1): + t_start = time.perf_counter() self.data_dir = Path(data_dir) with open(self.data_dir / "metadata.json", "r", encoding="utf-8") as f: self.metadata = json.load(f) self.max_seq_length = self.metadata["max_seq_length"] - self._shard_size: Optional[int] = self.metadata.get("shard_size") - if self._shard_size is not None: - shard_files = sorted(self.data_dir.glob("shard_*.npz")) + shard_files = sorted(self.data_dir.glob("shard_*.npz")) + if shard_files: self._num_shards = len(shard_files) - if self._num_shards == 0: - raise FileNotFoundError( - f"No shard_*.npz files found in {self.data_dir}" + if ( + "shard_sizes" in self.metadata + and len(self.metadata["shard_sizes"]) == self._num_shards + ): + self._shard_sizes = self.metadata["shard_sizes"] + logger.info( + f"Using shard_sizes from metadata.json ({self._num_shards} shards)" ) - self._shard_sizes: List[int] = [_read_shard_size(sf) for sf in shard_files] - self._shard_offsets = [0] - for s in self._shard_sizes: - self._shard_offsets.append(self._shard_offsets[-1] + s) + else: + t_scan = time.perf_counter() + self._shard_sizes = [_read_shard_size(sf) for sf in shard_files] + logger.info( + f"Scanned {self._num_shards} shard headers in " + f"{time.perf_counter() - t_scan:.2f}s" + ) + self._shard_offsets = 
_build_offsets(self._shard_sizes) self.num_samples = self._shard_offsets[-1] self._is_sharded = True self._cache = _ShardCache(max_size=max_cache_shards) logger.info( f"Loaded sharded dataset: {self.num_samples:,} samples, " - f"{self._num_shards} shards" + f"{self._num_shards} shards, " + f"init in {time.perf_counter() - t_start:.2f}s" ) else: self.num_samples = self.metadata["num_samples"] self._is_sharded = False self._load_single_files() logger.info( - f"Loaded single-file dataset: {self.num_samples:,} samples (mmap)" + f"Loaded single-file dataset: {self.num_samples:,} samples (mmap), " + f"init in {time.perf_counter() - t_start:.2f}s" ) def _load_single_files(self): @@ -157,6 +175,10 @@ class PreProcessedDataset(Dataset): f"Index {idx} out of range for dataset with {self.num_samples} samples" ) + if not hasattr(self, "_first_access_logged"): + self._first_access_logged = True + logger.info("First __getitem__ call (initial shard load may be slow)") + if self._is_sharded: lo, hi = 0, len(self._shard_offsets) - 1 while lo < hi: @@ -168,7 +190,7 @@ class PreProcessedDataset(Dataset): shard_idx = lo - 1 local_idx = idx - self._shard_offsets[shard_idx] shard_data = self._cache.get(shard_idx, self._load_shard) - return { + result = { "input_ids": torch.from_numpy( shard_data["input_ids"][local_idx].astype(np.int64) ), @@ -189,7 +211,7 @@ class PreProcessedDataset(Dataset): ), } else: - return { + result = { "input_ids": torch.from_numpy(self.input_ids[idx].astype(np.int64)), "token_type_ids": torch.from_numpy( self.token_type_ids[idx].astype(np.int64) @@ -203,6 +225,7 @@ class PreProcessedDataset(Dataset): ), "pinyin_ids": torch.from_numpy(self.pinyin_ids[idx].astype(np.int64)), } + return result def preprocessed_collate_fn(batch): diff --git a/src/model/shuffle_npz.py b/src/model/shuffle_npz.py new file mode 100644 index 0000000..9c1be11 --- /dev/null +++ b/src/model/shuffle_npz.py @@ -0,0 +1,318 @@ +#!/usr/bin/env python3 +""" +打乱并平衡预处理 .npz 分片。 + +两阶段处理: + Phase 
def _phase1_shuffle_to_npy(
    input_dir: Path,
    split: str,
    temp_dir: Path,
    shard_sizes: List[int],
    seed: int,
):
    """
    Phase 1: shuffle every source shard internally and dump it as ``.npy`` files.

    Each field of each shard is written to its own file at
    ``temp_dir/<field>/shard_<idx>.npy`` so that phase 2 can memory-map it.
    One permutation is drawn per shard and applied to all fields, which keeps
    the fields of a sample aligned. Only one field is held in memory at a time
    (the loaded array plus its permuted copy).

    Args:
        input_dir: Dataset root containing ``<split>/shard_*.npz``.
        split: Sub-directory name of the split to process.
        temp_dir: Directory that receives the per-field ``.npy`` files.
        shard_sizes: Sample count of each source shard, in shard order.
        seed: Seed of the RNG that draws the per-shard permutations.

    Raises:
        ValueError: If a field's length disagrees with ``shard_sizes``.
    """
    rng = np.random.RandomState(seed)

    for field in FIELDS:
        (temp_dir / field).mkdir(parents=True, exist_ok=True)

    # Iterate over shard_sizes (not metadata["num_shards"]) so phase 1 and
    # phase 2 always agree on the set of source shards.
    with tqdm(
        total=len(shard_sizes),
        desc=f"Phase 1: shuffling {split}",
        unit="shard",
    ) as pbar:
        for src_idx, n in enumerate(shard_sizes):
            shard_path = input_dir / split / f"shard_{src_idx:06d}.npz"
            perm = rng.permutation(n)

            with np.load(shard_path) as data:
                for field in FIELDS:
                    # NpzFile returns a freshly decompressed array per access,
                    # so no defensive copy is needed before indexing.
                    arr = data[field]
                    if arr.shape[0] != n:
                        raise ValueError(
                            f"{shard_path.name}: field '{field}' has "
                            f"{arr.shape[0]} samples, metadata says {n}"
                        )
                    np.save(temp_dir / field / f"shard_{src_idx:06d}.npy", arr[perm])
                    del arr
                    gc.collect()

            pbar.update(1)


def _phase2_rebalance_to_npz(
    temp_dir: Path,
    output_dir: Path,
    split: str,
    shard_sizes: List[int],
    target_shard_size: int,
    max_seq_length: int,
    seed: int,
) -> List[int]:
    """
    Phase 2: spread the memory-mapped ``.npy`` shards over balanced ``.npz`` shards.

    Output shard ``j`` receives the slice ``[j*s//N, (j+1)*s//N)`` of every
    source shard of size ``s`` (``N`` = number of output shards), so each
    source sample lands in exactly one output shard. The pieces are
    concatenated, shuffled once more to mix the sources, and saved compressed.
    Only one output shard is buffered in memory at a time.

    Args:
        temp_dir: Directory holding the phase-1 ``<field>/shard_<idx>.npy`` files.
        output_dir: Dataset root; shards are written to ``output_dir/<split>``.
        split: Sub-directory name of the split to write.
        shard_sizes: Sample count of each source shard, in shard order.
        target_shard_size: Approximate number of samples per output shard (>= 1).
        max_seq_length: Unused here; kept for interface compatibility.
        seed: Base seed; the mixing RNG uses ``seed + 2``.

    Returns:
        The sample count of every output shard actually written, in file order.

    Raises:
        ValueError: If ``target_shard_size`` is smaller than 1.
    """
    if target_shard_size < 1:
        raise ValueError(f"target_shard_size must be >= 1, got {target_shard_size}")

    output_split_dir = output_dir / split
    output_split_dir.mkdir(parents=True, exist_ok=True)

    total_samples = sum(shard_sizes)
    num_output_shards = (total_samples + target_shard_size - 1) // target_shard_size
    rng = np.random.RandomState(seed + 2)

    logger.info(
        f"Phase 2: distributing {total_samples:,} samples into "
        f"{num_output_shards} output shards (~{target_shard_size:,} each)"
    )

    # Memory-map every source file read-only; data is paged in on slicing.
    src_mmaps: List[Dict[str, np.ndarray]] = [
        {
            field: np.load(
                temp_dir / field / f"shard_{src_idx:06d}.npy", mmap_mode="r"
            )
            for field in FIELDS
        }
        for src_idx in range(len(shard_sizes))
    ]

    output_shard_sizes: List[int] = []

    with tqdm(
        total=num_output_shards,
        desc=f"Phase 2: writing {split}",
        unit="shard",
    ) as pbar:
        for out_j in range(num_output_shards):
            buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}

            for s, shard_mmap in zip(shard_sizes, src_mmaps):
                start = (out_j * s) // num_output_shards
                end = ((out_j + 1) * s) // num_output_shards
                if start >= end:
                    continue

                for field in FIELDS:
                    buffers[field].append(shard_mmap[field][start:end].copy())

            # With many tiny source shards no source may contribute to this
            # slot; skip it instead of failing on an empty buffer list.
            if not buffers[FIELDS[0]]:
                pbar.update(1)
                continue

            output = {
                field: (
                    np.concatenate(chunks) if len(chunks) > 1 else chunks[0]
                )
                for field, chunks in buffers.items()
            }

            out_count = len(output[FIELDS[0]])
            if out_count > 1:
                # Same permutation for every field keeps samples aligned.
                perm = rng.permutation(out_count)
                for field in FIELDS:
                    output[field] = output[field][perm]

            # Number files by shards written so far so indices stay contiguous
            # even when an empty slot was skipped.
            out_idx = len(output_shard_sizes)
            np.savez_compressed(output_split_dir / f"shard_{out_idx:06d}.npz", **output)
            output_shard_sizes.append(out_count)

            del output, buffers
            gc.collect()
            pbar.update(1)

    # Drop the memory maps so the temp files can be removed.
    src_mmaps.clear()
    gc.collect()

    return output_shard_sizes


def _copy_eval(input_dir: Path, output_dir: Path):
    """Copy ``input_dir/eval`` to ``output_dir/eval`` unchanged, replacing any old copy.

    Does nothing when the input has no ``eval`` directory.
    """
    src_eval = input_dir / "eval"
    dst_eval = output_dir / "eval"
    if not src_eval.exists():
        return
    logger.info(f"Copying eval data from {src_eval}")
    if dst_eval.exists():
        shutil.rmtree(dst_eval)
    shutil.copytree(src_eval, dst_eval)


def main():
    """CLI entry point: shuffle and rebalance the ``train`` split, copy ``eval``.

    Reads ``<input-dir>/train/metadata.json``, runs phase 1 and phase 2 through
    a private temporary directory, writes the new ``metadata.json`` (marked
    ``pre_shuffled``), copies the eval split and prints a size summary. The
    temporary directory is always removed, even on failure.
    """
    import tempfile

    console = Console()

    parser = argparse.ArgumentParser(description="打乱并平衡预处理 .npz 分片")
    parser.add_argument(
        "--input-dir", type=str, required=True, help="输入目录(含 train/ 和 eval/)"
    )
    parser.add_argument("--output-dir", type=str, required=True, help="输出目录")
    parser.add_argument(
        "--shard-size",
        type=int,
        default=1_000_000,
        help="输出分片大小(样本数),默认 100 万",
    )
    parser.add_argument("--seed", type=int, default=42, help="随机种子")
    parser.add_argument(
        "--temp-dir",
        type=str,
        default=None,
        help="临时 .npy 文件的存放目录(未压缩,占用空间较大),默认为输出目录",
    )

    args = parser.parse_args()
    if args.shard_size < 1:
        parser.error("--shard-size 必须 >= 1")

    input_dir = Path(args.input_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    train_meta_path = input_dir / "train" / "metadata.json"
    if not train_meta_path.exists():
        console.print(f"[red]错误: {train_meta_path} 不存在[/red]")
        return

    with open(train_meta_path, encoding="utf-8") as f:
        train_meta = json.load(f)

    # Metadata written before shard_sizes was introduced lacks this field.
    if "shard_sizes" not in train_meta:
        console.print(f"[red]错误: {train_meta_path} 缺少 shard_sizes 字段[/red]")
        return

    max_seq_length = train_meta["max_seq_length"]
    shard_sizes: List[int] = train_meta["shard_sizes"]
    total_samples = sum(shard_sizes)

    console.print("[bold cyan]=== 打乱并平衡预处理数据 ===[/bold cyan]")
    console.print(f"输入目录: {input_dir}")
    console.print(f"输出目录: {output_dir}")
    console.print(f"总样本数: {total_samples:,}")
    console.print(f"源分片数: {len(shard_sizes)}")
    console.print(
        f"目标分片大小: {args.shard_size:,} "
        f"(约 {(total_samples + args.shard_size - 1) // args.shard_size} 个分片)"
    )
    console.print(f"随机种子: {args.seed}")
    console.print()

    # A unique directory per run: nothing is hard-coded to one user's home
    # and concurrent runs cannot delete each other's intermediate files.
    temp_parent = Path(args.temp_dir) if args.temp_dir else output_dir
    temp_parent.mkdir(parents=True, exist_ok=True)
    temp_dir = Path(tempfile.mkdtemp(prefix="shuffle_npz_temp_", dir=temp_parent))
    temp_train_dir = temp_dir / "train"
    temp_train_dir.mkdir(parents=True, exist_ok=True)

    try:
        # Phase 1: per-shard shuffle into temporary .npy files.
        console.print("[bold]Phase 1: 逐分片打乱 → 临时 .npy[/bold]")
        _phase1_shuffle_to_npy(
            input_dir=input_dir,
            split="train",
            temp_dir=temp_train_dir,
            shard_sizes=shard_sizes,
            seed=args.seed,
        )

        # Phase 2: proportional redistribution into balanced output shards.
        console.print("[bold]Phase 2: 按比例分配到平衡输出分片[/bold]")
        output_shard_sizes = _phase2_rebalance_to_npz(
            temp_dir=temp_train_dir,
            output_dir=output_dir,
            split="train",
            shard_sizes=shard_sizes,
            target_shard_size=args.shard_size,
            max_seq_length=max_seq_length,
            seed=args.seed,
        )

        # Write the metadata describing the shuffled training split.
        train_output_meta = {
            "num_samples": sum(output_shard_sizes),
            "max_seq_length": max_seq_length,
            "dtype": "int16",
            "fields": FIELDS,
            "shard_size": args.shard_size,
            "num_shards": len(output_shard_sizes),
            "shard_sizes": output_shard_sizes,
            "pre_shuffled": True,
            "seed": args.seed,
        }
        out_train_dir = output_dir / "train"
        with open(out_train_dir / "metadata.json", "w", encoding="utf-8") as f:
            json.dump(train_output_meta, f, indent=2, ensure_ascii=False)

        # The eval split is copied verbatim.
        _copy_eval(input_dir, output_dir)

    finally:
        # Always remove the large intermediate files, even on failure.
        console.print("[dim]清理临时文件...[/dim]")
        if temp_dir.exists():
            shutil.rmtree(temp_dir)

    # Summary of what was written.
    console.print()
    console.print("[bold green]=== 完成 ===[/bold green]")
    console.print(
        f"train/: {sum(output_shard_sizes):,} 样本, "
        f"{len(output_shard_sizes)} 分片, "
        f"pre_shuffled=True"
    )

    for sdir_name in ["train", "eval"]:
        sdir = output_dir / sdir_name
        if sdir.exists():
            total_size = sum(
                f.stat().st_size for f in sdir.iterdir() if f.suffix == ".npz"
            )
            meta_path = sdir / "metadata.json"
            if meta_path.exists():
                with open(meta_path, encoding="utf-8") as mf:
                    meta = json.load(mf)
            else:
                meta = {}
            console.print(
                f" {sdir_name}/: {total_size / (1024**3):.2f} GB, "
                f"{meta.get('num_shards', '?')} shards"
            )
int = 42, ) -> Tuple[int, int]: """ 第2遍扫描:读取全部字段,抽取评估样本 + 按精确保留配额子采样训练样本。 @@ -147,12 +149,19 @@ def pass2_subsample( output_train_dir.mkdir(parents=True, exist_ok=True) output_eval_dir.mkdir(parents=True, exist_ok=True) - rng = np.random.RandomState() + # 清理旧输出分片(避免残留文件污染) + for old_shard in sorted(output_train_dir.glob("shard_*.npz")): + old_shard.unlink() + logger.debug(f"Removed old shard: {old_shard.name}") + + rng = np.random.RandomState(seed) + shuffle_rng = np.random.RandomState(seed + 1) remaining = dict(quotas) # {label_id: remaining_to_keep} train_buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS} train_buf_count = 0 train_shard_idx = 0 + train_shard_sizes: List[int] = [] total_train_kept = 0 eval_buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS} @@ -247,9 +256,14 @@ def pass2_subsample( merged = {} for f in FIELDS: merged[f] = np.concatenate(train_buffers[f], axis=0) + if shuffle and train_buf_count > 1: + perm = shuffle_rng.permutation(train_buf_count) + for f in FIELDS: + merged[f] = merged[f][perm] np.savez_compressed( output_train_dir / f"shard_{train_shard_idx:06d}.npz", **merged ) + train_shard_sizes.append(train_buf_count) logger.debug( f"Saved train shard {train_shard_idx}: {train_buf_count} samples" ) @@ -268,9 +282,14 @@ def pass2_subsample( merged = {} for f in FIELDS: merged[f] = np.concatenate(train_buffers[f], axis=0) + if shuffle and train_buf_count > 1: + perm = shuffle_rng.permutation(train_buf_count) + for f in FIELDS: + merged[f] = merged[f][perm] np.savez_compressed( output_train_dir / f"shard_{train_shard_idx:06d}.npz", **merged ) + train_shard_sizes.append(train_buf_count) logger.debug(f"Saved train shard {train_shard_idx}: {train_buf_count} samples") train_shard_idx += 1 @@ -291,6 +310,9 @@ def pass2_subsample( "fields": FIELDS, "shard_size": shard_size, "num_shards": train_shard_idx, + "shard_sizes": train_shard_sizes, + "pre_shuffled": shuffle, + "seed": seed, } with open(output_train_dir / 
"metadata.json", "w", encoding="utf-8") as f: json.dump(train_metadata, f, indent=2, ensure_ascii=False) @@ -334,6 +356,15 @@ def main(): help="输出分片大小(样本数)", ) parser.add_argument("--num-eval", type=int, default=2560, help="评估集样本数") + parser.add_argument( + "--no-shuffle", + action="store_false", + dest="shuffle", + help="禁用输出分片内部打乱", + ) + parser.add_argument( + "--seed", type=int, default=42, help="随机种子(用于标签选择 + 输出打乱)" + ) args = parser.parse_args() input_dir = Path(args.input_dir) @@ -346,6 +377,8 @@ def main(): console.print(f"每 ID 封顶: {args.cap_per_label:,}") console.print(f"目标训练集: {args.target_total:,}") console.print(f"评估集: {args.num_eval}") + console.print(f"输出分片打乱: {'是' if args.shuffle else '否'}") + console.print(f"随机种子: {args.seed}") console.print() # ── 第 1 遍:统计 ── @@ -382,8 +415,8 @@ def main(): console.print() # ── 抽取评估集位置 ── - rng = np.random.RandomState() - eval_positions = rng.choice(total_samples, size=args.num_eval, replace=False) + eval_rng = np.random.RandomState(args.seed + 100) + eval_positions = eval_rng.choice(total_samples, size=args.num_eval, replace=False) eval_positions.sort() eval_map = _global_to_shard(eval_positions, shard_sizes) console.print(f"评估集: {args.num_eval} 个位置已分配到 {len(eval_map)} 个分片中") @@ -399,6 +432,8 @@ def main(): eval_map, shard_sizes, args.shard_size, + shuffle=args.shuffle, + seed=args.seed, ) # ── 输出总结 ── diff --git a/src/model/trainer.py b/src/model/trainer.py index ffe65b7..cbea273 100644 --- a/src/model/trainer.py +++ b/src/model/trainer.py @@ -632,7 +632,6 @@ class Trainer: def _create_progress(self) -> Progress: return Progress( - SpinnerColumn(), TextColumn("[progress.description]{task.description}"), BarColumn(), TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), @@ -725,11 +724,7 @@ class Trainer: progress.reset( batch_task, total=steps_per_epoch, - description=f"[green]Batch Epoch {epoch + 1}/{self.num_epochs}", - ) - progress.update( - epoch_task, - description=f"[cyan]Epoch {epoch + 
1}/{self.num_epochs}", + description=f"[green]Epoch {epoch + 1} Step 0/{steps_per_epoch}", ) epoch_step = 0 @@ -762,7 +757,8 @@ class Trainer: progress.update( batch_task, advance=1, - description=f"[green]Batch Epoch {epoch + 1}/{self.num_epochs}" + description=f"[green]Epoch {epoch + 1} " + f"Step {epoch_step + 1}/{steps_per_epoch}" f" | Loss: {loss:.4f}" f" | LR: {current_lr:.2e}", ) @@ -816,7 +812,7 @@ class Trainer: f"Eval Acc: {eval_metrics['eval_accuracy']:.4f}" ) - progress.console.log(log_text) + progress.log(log_text) # 重置累积指标 accumulated_loss = 0.0 @@ -1260,16 +1256,31 @@ def train( is_eval_preprocessed = is_preprocessed_data(eval_data_path) if is_train_preprocessed: - train_dataset = PreProcessedDataset(train_data_path, max_cache_shards=1) - total_steps = (len(train_dataset) // batch_size) * num_epochs + train_dataset = PreProcessedDataset(train_data_path, max_cache_shards=2) + pre_shuffled = train_dataset.metadata.get("pre_shuffled", False) + # 预打乱数据不需要 DataLoader 的 RandomSampler(避免跨分片解压抖动) + shuffle_train = not pre_shuffled + if max_iter_length > 0: + max_samples_per_epoch = max_iter_length + capped_samples = min(len(train_dataset), max_samples_per_epoch) + else: + capped_samples = len(train_dataset) + total_steps = (capped_samples // batch_size) * num_epochs + train_num_workers = min(num_workers, 1) + logger.info( + f"Preprocessed dataset: {len(train_dataset):,} samples, " + f"shuffle={shuffle_train}, pre_shuffled={pre_shuffled}, " + f"workers={train_num_workers}, steps={total_steps:,}" + ) train_dataloader = create_dataloader( dataset=train_dataset, batch_size=batch_size, - num_workers=num_workers, + num_workers=train_num_workers, pin_memory=torch.cuda.is_available(), - shuffle=True, + shuffle=shuffle_train, ) config_table.add_row("数据", "训练数据类型", "预处理数据") + config_table.add_row("数据", "预打乱", str(pre_shuffled)) else: train_dataset = PinyinInputDataset( data_path=train_data_path, @@ -1296,7 +1307,7 @@ def train( eval_dataloader = create_dataloader( 
dataset=eval_dataset, batch_size=batch_size, - num_workers=2, + num_workers=0, pin_memory=torch.cuda.is_available(), shuffle=False, )