refactor(trainer): 优化进度条逻辑与训练循环结构

This commit is contained in:
songsenand 2026-05-11 00:14:17 +08:00
parent d0f1534086
commit 27beb7f0b1
2 changed files with 71 additions and 23 deletions

View File

@ -10,9 +10,11 @@ GPU 服务器仅需存放压缩后的 .npz 文件,无需解压到硬盘。
import gc
import json
import struct
import zipfile
from collections import OrderedDict
from pathlib import Path
from typing import Dict, Optional
from typing import Dict, List, Optional
import numpy as np
import torch
@ -29,6 +31,22 @@ FIELDS = [
]
def _read_shard_size(npz_path: Path) -> int:
with zipfile.ZipFile(npz_path, "r") as z:
first = sorted(z.namelist())[0]
with z.open(first) as f:
magic = f.read(6)
if magic != b"\x93NUMPY":
raise ValueError(f"Not a numpy file: {npz_path}")
ver = struct.unpack("<BB", f.read(2))
if ver[0] >= 2:
header_len = struct.unpack("<I", f.read(4))[0]
else:
header_len = struct.unpack("<H", f.read(2))[0]
header = eval(f.read(header_len))
return header["shape"][0]
def is_preprocessed_data(path: str) -> bool:
"""判断路径是否为预处理数据目录"""
p = Path(path)
@ -91,12 +109,16 @@ class PreProcessedDataset(Dataset):
raise FileNotFoundError(
f"No shard_*.npz files found in {self.data_dir}"
)
self.num_samples = self._num_shards * self._shard_size
self._shard_sizes: List[int] = [_read_shard_size(sf) for sf in shard_files]
self._shard_offsets = [0]
for s in self._shard_sizes:
self._shard_offsets.append(self._shard_offsets[-1] + s)
self.num_samples = self._shard_offsets[-1]
self._is_sharded = True
self._cache = _ShardCache(max_size=max_cache_shards)
logger.info(
f"Loaded sharded dataset: {self.num_samples:,} samples, "
f"{self._num_shards} shards, shard_size={self._shard_size:,}"
f"{self._num_shards} shards"
)
else:
self.num_samples = self.metadata["num_samples"]
@ -139,8 +161,15 @@ class PreProcessedDataset(Dataset):
)
if self._is_sharded:
shard_idx = idx // self._shard_size
local_idx = idx % self._shard_size
lo, hi = 0, len(self._shard_offsets) - 1
while lo < hi:
mid = (lo + hi) // 2
if self._shard_offsets[mid] <= idx:
lo = mid + 1
else:
hi = mid
shard_idx = lo - 1
local_idx = idx - self._shard_offsets[shard_idx]
shard_data = self._cache.get(shard_idx, self._load_shard)
return {
"input_ids": torch.from_numpy(

View File

@ -630,14 +630,15 @@ class Trainer:
except Exception as e:
logger.error(f"Failed to write training status: {e}")
def _create_progress_bar(self) -> Progress:
"""创建Rich进度条"""
def _create_progress(self) -> Progress:
return Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TextColumn("|"),
TimeElapsedColumn(),
TextColumn("|"),
TimeRemainingColumn(),
console=self.console,
expand=True,
@ -704,20 +705,34 @@ class Trainer:
accumulated_accuracy = 0.0
accumulation_counter = 0
# 创建进度条
with self._create_progress_bar() as progress:
# 计算每个 epoch 的步数
steps_per_epoch = max(1, self.total_steps // self.num_epochs)
# 创建双进度条epoch 级 + epoch 内 batch 级
with self._create_progress() as progress:
epoch_task = progress.add_task(
f"[cyan]Epoch {self.current_epoch + 1}/{self.num_epochs}",
total=self.total_steps,
f"[cyan]Epoch {self.current_epoch + 1}/{self.num_epochs}",
total=self.num_epochs,
completed=self.current_epoch,
)
batch_task = progress.add_task(
"[green]Batch",
total=steps_per_epoch,
)
# 训练循环
for epoch in range(self.current_epoch, self.num_epochs):
self.current_epoch = epoch
progress.reset(
batch_task,
total=steps_per_epoch,
description=f"[green]Batch Epoch {epoch + 1}/{self.num_epochs}",
)
progress.update(
epoch_task, description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs}"
epoch_task,
description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs}",
)
epoch_step = 0
for batch_idx, batch in enumerate(self.train_dataloader):
# 更新学习率
current_lr = self._update_learning_rate()
@ -743,16 +758,17 @@ class Trainer:
self.scaler.update()
self.optimizer.zero_grad()
# 更新进度条
# 更新 batch 进度条 (每 step 推进一次)
progress.update(
epoch_task,
batch_task,
advance=1,
description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs} | "
f"Step {global_step}/{self.total_steps} | "
f"Loss: {loss:.4f} | "
f"LR: {current_lr:.2e}",
description=f"[green]Batch Epoch {epoch + 1}/{self.num_epochs}"
f" | Loss: {loss:.4f}"
f" | LR: {current_lr:.2e}",
)
epoch_step += 1
# 定期评估和记录
if (global_step + 1) % self.eval_frequency == 0:
# 计算平均指标
@ -780,7 +796,6 @@ class Trainer:
# 更新最佳模型
if eval_metrics["eval_loss"] < self.best_eval_loss:
self.best_eval_loss = eval_metrics["eval_loss"]
# 只保存best_model不创建额外的checkpoint文件
self.save_checkpoint("best_model.pt", is_best=True)
# 记录到TensorBoard
@ -818,12 +833,16 @@ class Trainer:
# 检查是否达到总步数
if global_step >= self.total_steps:
progress.update(epoch_task, completed=self.total_steps)
break
# 进度条不重置,显示整体训练进度
# epoch 内步数检查
if epoch_step >= steps_per_epoch:
break
# 每个 epoch 结束后保存检查点(循环覆盖,只保留最后 3 个)
# epoch 完成
progress.update(epoch_task, advance=1)
# 每个 epoch 结束后保存检查点
self.save_epoch_checkpoint(epoch + 1)
# 检查是否达到总步数