diff --git a/src/model/preprocessed_dataset.py b/src/model/preprocessed_dataset.py index 4320bb0..13faf36 100644 --- a/src/model/preprocessed_dataset.py +++ b/src/model/preprocessed_dataset.py @@ -10,9 +10,11 @@ GPU 服务器仅需存放压缩后的 .npz 文件,无需解压到硬盘。 import gc import json +import struct +import zipfile from collections import OrderedDict from pathlib import Path -from typing import Dict, Optional +from typing import Dict, List, Optional import numpy as np import torch @@ -29,6 +31,22 @@ FIELDS = [ ] +def _read_shard_size(npz_path: Path) -> int: + with zipfile.ZipFile(npz_path, "r") as z: + first = sorted(z.namelist())[0] + with z.open(first) as f: + magic = f.read(6) + if magic != b"\x93NUMPY": + raise ValueError(f"Not a numpy file: {npz_path}") + ver = struct.unpack("= 2: + header_len = struct.unpack(" bool: """判断路径是否为预处理数据目录""" p = Path(path) @@ -91,12 +109,16 @@ class PreProcessedDataset(Dataset): raise FileNotFoundError( f"No shard_*.npz files found in {self.data_dir}" ) - self.num_samples = self._num_shards * self._shard_size + self._shard_sizes: List[int] = [_read_shard_size(sf) for sf in shard_files] + self._shard_offsets = [0] + for s in self._shard_sizes: + self._shard_offsets.append(self._shard_offsets[-1] + s) + self.num_samples = self._shard_offsets[-1] self._is_sharded = True self._cache = _ShardCache(max_size=max_cache_shards) logger.info( f"Loaded sharded dataset: {self.num_samples:,} samples, " - f"{self._num_shards} shards, shard_size={self._shard_size:,}" + f"{self._num_shards} shards" ) else: self.num_samples = self.metadata["num_samples"] @@ -139,8 +161,15 @@ class PreProcessedDataset(Dataset): ) if self._is_sharded: - shard_idx = idx // self._shard_size - local_idx = idx % self._shard_size + lo, hi = 0, len(self._shard_offsets) - 1 + while lo < hi: + mid = (lo + hi) // 2 + if self._shard_offsets[mid] <= idx: + lo = mid + 1 + else: + hi = mid + shard_idx = lo - 1 + local_idx = idx - self._shard_offsets[shard_idx] shard_data = self._cache.get(shard_idx, self._load_shard) return { "input_ids": torch.from_numpy( diff --git a/src/model/trainer.py b/src/model/trainer.py index 6f7f8b9..2fde42e 100644 --- a/src/model/trainer.py +++ b/src/model/trainer.py @@ -630,14 +630,15 @@ class Trainer: except Exception as e: logger.error(f"Failed to write training status: {e}") - def _create_progress_bar(self) -> Progress: - """创建Rich进度条""" + def _create_progress(self) -> Progress: return Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), BarColumn(), TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + TextColumn("|"), TimeElapsedColumn(), + TextColumn("|"), TimeRemainingColumn(), console=self.console, expand=True, @@ -704,20 +705,34 @@ class Trainer: accumulated_accuracy = 0.0 accumulation_counter = 0 - # 创建进度条 - with self._create_progress_bar() as progress: + # 计算每个 epoch 的步数 + steps_per_epoch = max(1, self.total_steps // self.num_epochs) + + # 创建双进度条:epoch 级 + epoch 内 batch 级 + with self._create_progress() as progress: epoch_task = progress.add_task( - f"[cyan]Epoch {self.current_epoch + 1}/{self.num_epochs}", - total=self.total_steps, + f"[cyan]Epoch {self.current_epoch + 1}/{self.num_epochs}", + total=self.num_epochs, + completed=self.current_epoch, + ) + batch_task = progress.add_task( + "[green]Batch", + total=steps_per_epoch, ) - # 训练循环 for epoch in range(self.current_epoch, self.num_epochs): self.current_epoch = epoch + progress.reset( + batch_task, + total=steps_per_epoch, + description=f"[green]Batch Epoch {epoch + 1}/{self.num_epochs}", + ) progress.update( - epoch_task, description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs}" + epoch_task, + description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs}", ) + epoch_step = 0 for batch_idx, batch in enumerate(self.train_dataloader): # 更新学习率 current_lr = self._update_learning_rate() @@ -743,16 +758,17 @@ class Trainer: self.scaler.update() self.optimizer.zero_grad() - # 更新进度条 + # 更新 batch 进度条 (每 step 推进一次) progress.update( - epoch_task, + batch_task, advance=1, - description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs} | " - f"Step {global_step}/{self.total_steps} | " - f"Loss: {loss:.4f} | " - f"LR: {current_lr:.2e}", + description=f"[green]Batch Epoch {epoch + 1}/{self.num_epochs}" + f" | Loss: {loss:.4f}" + f" | LR: {current_lr:.2e}", ) + epoch_step += 1 + # 定期评估和记录 if (global_step + 1) % self.eval_frequency == 0: # 计算平均指标 @@ -780,7 +796,6 @@ class Trainer: # 更新最佳模型 if eval_metrics["eval_loss"] < self.best_eval_loss: self.best_eval_loss = eval_metrics["eval_loss"] - # 只保存best_model,不创建额外的checkpoint文件 self.save_checkpoint("best_model.pt", is_best=True) # 记录到TensorBoard @@ -818,12 +833,16 @@ class Trainer: # 检查是否达到总步数 if global_step >= self.total_steps: - progress.update(epoch_task, completed=self.total_steps) break - # 进度条不重置,显示整体训练进度 + # epoch 内步数检查 + if epoch_step >= steps_per_epoch: + break - # 每个 epoch 结束后保存检查点(循环覆盖,只保留最后 3 个) + # epoch 完成 + progress.update(epoch_task, advance=1) + + # 每个 epoch 结束后保存检查点 self.save_epoch_checkpoint(epoch + 1) # 检查是否达到总步数