refactor(trainer): 优化进度条逻辑与训练循环结构
This commit is contained in:
parent
d0f1534086
commit
27beb7f0b1
|
|
@ -10,9 +10,11 @@ GPU 服务器仅需存放压缩后的 .npz 文件,无需解压到硬盘。
|
|||
|
||||
import gc
|
||||
import json
|
||||
import struct
|
||||
import zipfile
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
|
@ -29,6 +31,22 @@ FIELDS = [
|
|||
]
|
||||
|
||||
|
||||
def _read_shard_size(npz_path: Path) -> int:
    """Return the first-axis length of the first array stored in an ``.npz`` shard.

    Only the NPY header of the alphabetically first archive member is read,
    so the array data itself is never decompressed into memory.

    Args:
        npz_path: Path to the ``.npz`` (zip) archive.

    Returns:
        ``shape[0]`` of the first member array.

    Raises:
        ValueError: If the archive is empty, the member is not an NPY file,
            the header is truncated or malformed, or the array is 0-dimensional.
    """
    # Local import: the surrounding module is not known to import ``ast``.
    import ast

    with zipfile.ZipFile(npz_path, "r") as archive:
        names = sorted(archive.namelist())
        if not names:
            raise ValueError(f"Empty npz archive: {npz_path}")

        with archive.open(names[0]) as member:
            if member.read(6) != b"\x93NUMPY":
                raise ValueError(f"Not a numpy file: {npz_path}")

            try:
                major, _minor = struct.unpack("<BB", member.read(2))
                # Format 1.0 stores the header length as uint16; 2.0+ as uint32.
                if major >= 2:
                    (header_len,) = struct.unpack("<I", member.read(4))
                else:
                    (header_len,) = struct.unpack("<H", member.read(2))
            except struct.error as exc:
                raise ValueError(f"Truncated numpy header: {npz_path}") from exc

            raw_header = member.read(header_len)

    # Format 3.0 headers are UTF-8; earlier versions are latin-1.
    header_text = raw_header.decode("utf-8" if major >= 3 else "latin1")

    # The header is a Python dict literal. Parse it with literal_eval rather
    # than eval() so a crafted file cannot execute arbitrary code.
    try:
        header = ast.literal_eval(header_text)
    except (ValueError, SyntaxError) as exc:
        raise ValueError(f"Malformed numpy header: {npz_path}") from exc

    shape = header.get("shape") if isinstance(header, dict) else None
    if not isinstance(shape, tuple) or not shape:
        raise ValueError(f"Array in {npz_path} has no leading dimension")
    return int(shape[0])
|
||||
|
||||
|
||||
def is_preprocessed_data(path: str) -> bool:
|
||||
"""判断路径是否为预处理数据目录"""
|
||||
p = Path(path)
|
||||
|
|
@ -91,12 +109,16 @@ class PreProcessedDataset(Dataset):
|
|||
raise FileNotFoundError(
|
||||
f"No shard_*.npz files found in {self.data_dir}"
|
||||
)
|
||||
self.num_samples = self._num_shards * self._shard_size
|
||||
self._shard_sizes: List[int] = [_read_shard_size(sf) for sf in shard_files]
|
||||
self._shard_offsets = [0]
|
||||
for s in self._shard_sizes:
|
||||
self._shard_offsets.append(self._shard_offsets[-1] + s)
|
||||
self.num_samples = self._shard_offsets[-1]
|
||||
self._is_sharded = True
|
||||
self._cache = _ShardCache(max_size=max_cache_shards)
|
||||
logger.info(
|
||||
f"Loaded sharded dataset: {self.num_samples:,} samples, "
|
||||
f"{self._num_shards} shards, shard_size={self._shard_size:,}"
|
||||
f"{self._num_shards} shards"
|
||||
)
|
||||
else:
|
||||
self.num_samples = self.metadata["num_samples"]
|
||||
|
|
@ -139,8 +161,15 @@ class PreProcessedDataset(Dataset):
|
|||
)
|
||||
|
||||
if self._is_sharded:
|
||||
shard_idx = idx // self._shard_size
|
||||
local_idx = idx % self._shard_size
|
||||
lo, hi = 0, len(self._shard_offsets) - 1
|
||||
while lo < hi:
|
||||
mid = (lo + hi) // 2
|
||||
if self._shard_offsets[mid] <= idx:
|
||||
lo = mid + 1
|
||||
else:
|
||||
hi = mid
|
||||
shard_idx = lo - 1
|
||||
local_idx = idx - self._shard_offsets[shard_idx]
|
||||
shard_data = self._cache.get(shard_idx, self._load_shard)
|
||||
return {
|
||||
"input_ids": torch.from_numpy(
|
||||
|
|
|
|||
|
|
@ -630,14 +630,15 @@ class Trainer:
|
|||
except Exception as e:
|
||||
logger.error(f"Failed to write training status: {e}")
|
||||
|
||||
def _create_progress_bar(self) -> Progress:
|
||||
"""创建Rich进度条"""
|
||||
def _create_progress(self) -> Progress:
|
||||
return Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||
TextColumn("|"),
|
||||
TimeElapsedColumn(),
|
||||
TextColumn("|"),
|
||||
TimeRemainingColumn(),
|
||||
console=self.console,
|
||||
expand=True,
|
||||
|
|
@ -704,20 +705,34 @@ class Trainer:
|
|||
accumulated_accuracy = 0.0
|
||||
accumulation_counter = 0
|
||||
|
||||
# 创建进度条
|
||||
with self._create_progress_bar() as progress:
|
||||
# 计算每个 epoch 的步数
|
||||
steps_per_epoch = max(1, self.total_steps // self.num_epochs)
|
||||
|
||||
# 创建双进度条:epoch 级 + epoch 内 batch 级
|
||||
with self._create_progress() as progress:
|
||||
epoch_task = progress.add_task(
|
||||
f"[cyan]Epoch {self.current_epoch + 1}/{self.num_epochs}",
|
||||
total=self.total_steps,
|
||||
f"[cyan]Epoch {self.current_epoch + 1}/{self.num_epochs}",
|
||||
total=self.num_epochs,
|
||||
completed=self.current_epoch,
|
||||
)
|
||||
batch_task = progress.add_task(
|
||||
"[green]Batch",
|
||||
total=steps_per_epoch,
|
||||
)
|
||||
|
||||
# 训练循环
|
||||
for epoch in range(self.current_epoch, self.num_epochs):
|
||||
self.current_epoch = epoch
|
||||
progress.reset(
|
||||
batch_task,
|
||||
total=steps_per_epoch,
|
||||
description=f"[green]Batch Epoch {epoch + 1}/{self.num_epochs}",
|
||||
)
|
||||
progress.update(
|
||||
epoch_task, description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs}"
|
||||
epoch_task,
|
||||
description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs}",
|
||||
)
|
||||
|
||||
epoch_step = 0
|
||||
for batch_idx, batch in enumerate(self.train_dataloader):
|
||||
# 更新学习率
|
||||
current_lr = self._update_learning_rate()
|
||||
|
|
@ -743,16 +758,17 @@ class Trainer:
|
|||
self.scaler.update()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
# 更新进度条
|
||||
# 更新 batch 进度条 (每 step 推进一次)
|
||||
progress.update(
|
||||
epoch_task,
|
||||
batch_task,
|
||||
advance=1,
|
||||
description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs} | "
|
||||
f"Step {global_step}/{self.total_steps} | "
|
||||
f"Loss: {loss:.4f} | "
|
||||
f"LR: {current_lr:.2e}",
|
||||
description=f"[green]Batch Epoch {epoch + 1}/{self.num_epochs}"
|
||||
f" | Loss: {loss:.4f}"
|
||||
f" | LR: {current_lr:.2e}",
|
||||
)
|
||||
|
||||
epoch_step += 1
|
||||
|
||||
# 定期评估和记录
|
||||
if (global_step + 1) % self.eval_frequency == 0:
|
||||
# 计算平均指标
|
||||
|
|
@ -780,7 +796,6 @@ class Trainer:
|
|||
# 更新最佳模型
|
||||
if eval_metrics["eval_loss"] < self.best_eval_loss:
|
||||
self.best_eval_loss = eval_metrics["eval_loss"]
|
||||
# 只保存best_model,不创建额外的checkpoint文件
|
||||
self.save_checkpoint("best_model.pt", is_best=True)
|
||||
|
||||
# 记录到TensorBoard
|
||||
|
|
@ -818,12 +833,16 @@ class Trainer:
|
|||
|
||||
# 检查是否达到总步数
|
||||
if global_step >= self.total_steps:
|
||||
progress.update(epoch_task, completed=self.total_steps)
|
||||
break
|
||||
|
||||
# 进度条不重置,显示整体训练进度
|
||||
# epoch 内步数检查
|
||||
if epoch_step >= steps_per_epoch:
|
||||
break
|
||||
|
||||
# 每个 epoch 结束后保存检查点(循环覆盖,只保留最后 3 个)
|
||||
# epoch 完成
|
||||
progress.update(epoch_task, advance=1)
|
||||
|
||||
# 每个 epoch 结束后保存检查点
|
||||
self.save_epoch_checkpoint(epoch + 1)
|
||||
|
||||
# 检查是否达到总步数
|
||||
|
|
|
|||
Loading…
Reference in New Issue