fix(trainer): 调整模型保存频率以避免频繁写盘

This commit is contained in:
songsenand 2026-04-05 01:35:53 +08:00
parent 369424be28
commit 3e529d805f
2 changed files with 4 additions and 4 deletions

View File

@@ -17,7 +17,7 @@ dependencies = [
"requests>=2.32.5",
"rich>=14.3.1",
"tensorboard>=2.20.0",
-"torch>=2.11.0",
+"torch>=2.10.0",
"transformers==5.1.0",
"typer>=0.21.1",
]

View File

@@ -64,7 +64,7 @@ class Trainer:
grad_accum_steps: int = 1,
clip_grad_norm: float = 1.0,
eval_frequency: int = 500,
-save_frequency: int = 1000,
+save_frequency: int = 10000,
mixed_precision: bool = True,
device: Optional[torch.device] = None,
use_tensorboard: bool = True,
@@ -507,7 +507,7 @@ class Trainer:
if eval_metrics["eval_loss"] < self.best_eval_loss:
self.best_eval_loss = eval_metrics["eval_loss"]
self.save_checkpoint(
-f"step_{global_step}.pt", is_best=True
+f"step_{global_step + 1}.pt", is_best=True
)
# 记录到TensorBoard
@@ -661,7 +661,7 @@ def train(
grad_accum_steps: int = typer.Option(1, "--grad-accum-steps", help="梯度累积步数"),
clip_grad_norm: float = typer.Option(1.0, "--clip-grad-norm", help="梯度裁剪范数"),
eval_frequency: int = typer.Option(500, "--eval-frequency", help="评估频率"),
-save_frequency: int = typer.Option(1000, "--save-frequency", help="保存频率"),
+save_frequency: int = typer.Option(10000, "--save-frequency", help="保存频率"),
# 其他参数
mixed_precision: bool = typer.Option(
True, "--mixed-precision/--no-mixed-precision", help="是否使用混合精度训练"