refactor(trainer): 优化进度条逻辑与训练循环结构

This commit is contained in:
songsenand 2026-05-11 00:14:17 +08:00
parent d0f1534086
commit 27beb7f0b1
2 changed files with 71 additions and 23 deletions

View File

@ -10,9 +10,11 @@ GPU 服务器仅需存放压缩后的 .npz 文件,无需解压到硬盘。
import gc import gc
import json import json
import struct
import zipfile
from collections import OrderedDict from collections import OrderedDict
from pathlib import Path from pathlib import Path
from typing import Dict, Optional from typing import Dict, List, Optional
import numpy as np import numpy as np
import torch import torch
@ -29,6 +31,22 @@ FIELDS = [
] ]
def _read_shard_size(npz_path: Path) -> int:
with zipfile.ZipFile(npz_path, "r") as z:
first = sorted(z.namelist())[0]
with z.open(first) as f:
magic = f.read(6)
if magic != b"\x93NUMPY":
raise ValueError(f"Not a numpy file: {npz_path}")
ver = struct.unpack("<BB", f.read(2))
if ver[0] >= 2:
header_len = struct.unpack("<I", f.read(4))[0]
else:
header_len = struct.unpack("<H", f.read(2))[0]
header = eval(f.read(header_len))
return header["shape"][0]
def is_preprocessed_data(path: str) -> bool: def is_preprocessed_data(path: str) -> bool:
"""判断路径是否为预处理数据目录""" """判断路径是否为预处理数据目录"""
p = Path(path) p = Path(path)
@ -91,12 +109,16 @@ class PreProcessedDataset(Dataset):
raise FileNotFoundError( raise FileNotFoundError(
f"No shard_*.npz files found in {self.data_dir}" f"No shard_*.npz files found in {self.data_dir}"
) )
self.num_samples = self._num_shards * self._shard_size self._shard_sizes: List[int] = [_read_shard_size(sf) for sf in shard_files]
self._shard_offsets = [0]
for s in self._shard_sizes:
self._shard_offsets.append(self._shard_offsets[-1] + s)
self.num_samples = self._shard_offsets[-1]
self._is_sharded = True self._is_sharded = True
self._cache = _ShardCache(max_size=max_cache_shards) self._cache = _ShardCache(max_size=max_cache_shards)
logger.info( logger.info(
f"Loaded sharded dataset: {self.num_samples:,} samples, " f"Loaded sharded dataset: {self.num_samples:,} samples, "
f"{self._num_shards} shards, shard_size={self._shard_size:,}" f"{self._num_shards} shards"
) )
else: else:
self.num_samples = self.metadata["num_samples"] self.num_samples = self.metadata["num_samples"]
@ -139,8 +161,15 @@ class PreProcessedDataset(Dataset):
) )
if self._is_sharded: if self._is_sharded:
shard_idx = idx // self._shard_size lo, hi = 0, len(self._shard_offsets) - 1
local_idx = idx % self._shard_size while lo < hi:
mid = (lo + hi) // 2
if self._shard_offsets[mid] <= idx:
lo = mid + 1
else:
hi = mid
shard_idx = lo - 1
local_idx = idx - self._shard_offsets[shard_idx]
shard_data = self._cache.get(shard_idx, self._load_shard) shard_data = self._cache.get(shard_idx, self._load_shard)
return { return {
"input_ids": torch.from_numpy( "input_ids": torch.from_numpy(

View File

@ -630,14 +630,15 @@ class Trainer:
except Exception as e: except Exception as e:
logger.error(f"Failed to write training status: {e}") logger.error(f"Failed to write training status: {e}")
def _create_progress_bar(self) -> Progress: def _create_progress(self) -> Progress:
"""创建Rich进度条"""
return Progress( return Progress(
SpinnerColumn(), SpinnerColumn(),
TextColumn("[progress.description]{task.description}"), TextColumn("[progress.description]{task.description}"),
BarColumn(), BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TextColumn("|"),
TimeElapsedColumn(), TimeElapsedColumn(),
TextColumn("|"),
TimeRemainingColumn(), TimeRemainingColumn(),
console=self.console, console=self.console,
expand=True, expand=True,
@ -704,20 +705,34 @@ class Trainer:
accumulated_accuracy = 0.0 accumulated_accuracy = 0.0
accumulation_counter = 0 accumulation_counter = 0
# 创建进度条 # 计算每个 epoch 的步数
with self._create_progress_bar() as progress: steps_per_epoch = max(1, self.total_steps // self.num_epochs)
# 创建双进度条epoch 级 + epoch 内 batch 级
with self._create_progress() as progress:
epoch_task = progress.add_task( epoch_task = progress.add_task(
f"[cyan]Epoch {self.current_epoch + 1}/{self.num_epochs}", f"[cyan]Epoch {self.current_epoch + 1}/{self.num_epochs}",
total=self.total_steps, total=self.num_epochs,
completed=self.current_epoch,
)
batch_task = progress.add_task(
"[green]Batch",
total=steps_per_epoch,
) )
# 训练循环
for epoch in range(self.current_epoch, self.num_epochs): for epoch in range(self.current_epoch, self.num_epochs):
self.current_epoch = epoch self.current_epoch = epoch
progress.reset(
batch_task,
total=steps_per_epoch,
description=f"[green]Batch Epoch {epoch + 1}/{self.num_epochs}",
)
progress.update( progress.update(
epoch_task, description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs}" epoch_task,
description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs}",
) )
epoch_step = 0
for batch_idx, batch in enumerate(self.train_dataloader): for batch_idx, batch in enumerate(self.train_dataloader):
# 更新学习率 # 更新学习率
current_lr = self._update_learning_rate() current_lr = self._update_learning_rate()
@ -743,16 +758,17 @@ class Trainer:
self.scaler.update() self.scaler.update()
self.optimizer.zero_grad() self.optimizer.zero_grad()
# 更新进度条 # 更新 batch 进度条 (每 step 推进一次)
progress.update( progress.update(
epoch_task, batch_task,
advance=1, advance=1,
description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs} | " description=f"[green]Batch Epoch {epoch + 1}/{self.num_epochs}"
f"Step {global_step}/{self.total_steps} | " f" | Loss: {loss:.4f}"
f"Loss: {loss:.4f} | " f" | LR: {current_lr:.2e}",
f"LR: {current_lr:.2e}",
) )
epoch_step += 1
# 定期评估和记录 # 定期评估和记录
if (global_step + 1) % self.eval_frequency == 0: if (global_step + 1) % self.eval_frequency == 0:
# 计算平均指标 # 计算平均指标
@ -780,7 +796,6 @@ class Trainer:
# 更新最佳模型 # 更新最佳模型
if eval_metrics["eval_loss"] < self.best_eval_loss: if eval_metrics["eval_loss"] < self.best_eval_loss:
self.best_eval_loss = eval_metrics["eval_loss"] self.best_eval_loss = eval_metrics["eval_loss"]
# 只保存best_model不创建额外的checkpoint文件
self.save_checkpoint("best_model.pt", is_best=True) self.save_checkpoint("best_model.pt", is_best=True)
# 记录到TensorBoard # 记录到TensorBoard
@ -818,12 +833,16 @@ class Trainer:
# 检查是否达到总步数 # 检查是否达到总步数
if global_step >= self.total_steps: if global_step >= self.total_steps:
progress.update(epoch_task, completed=self.total_steps)
break break
# 进度条不重置,显示整体训练进度 # epoch 内步数检查
if epoch_step >= steps_per_epoch:
break
# 每个 epoch 结束后保存检查点(循环覆盖,只保留最后 3 个) # epoch 完成
progress.update(epoch_task, advance=1)
# 每个 epoch 结束后保存检查点
self.save_epoch_checkpoint(epoch + 1) self.save_epoch_checkpoint(epoch + 1)
# 检查是否达到总步数 # 检查是否达到总步数