From 3e529d805fb50da43c531511156f572fd3549e12 Mon Sep 17 00:00:00 2001 From: songsenand Date: Sun, 5 Apr 2026 01:35:53 +0800 Subject: [PATCH] =?UTF-8?q?fix(trainer):=20=E8=B0=83=E6=95=B4=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E4=BF=9D=E5=AD=98=E9=A2=91=E7=8E=87=E4=BB=A5=E9=81=BF?= =?UTF-8?q?=E5=85=8D=E9=A2=91=E7=B9=81=E5=86=99=E7=9B=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 2 +- src/model/trainer.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c11e041..ac9c52b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ dependencies = [ "requests>=2.32.5", "rich>=14.3.1", "tensorboard>=2.20.0", - "torch>=2.11.0", + "torch>=2.10.0", "transformers==5.1.0", "typer>=0.21.1", ] diff --git a/src/model/trainer.py b/src/model/trainer.py index 8ae627a..fca8c64 100644 --- a/src/model/trainer.py +++ b/src/model/trainer.py @@ -64,7 +64,7 @@ class Trainer: grad_accum_steps: int = 1, clip_grad_norm: float = 1.0, eval_frequency: int = 500, - save_frequency: int = 1000, + save_frequency: int = 10000, mixed_precision: bool = True, device: Optional[torch.device] = None, use_tensorboard: bool = True, @@ -507,7 +507,7 @@ class Trainer: if eval_metrics["eval_loss"] < self.best_eval_loss: self.best_eval_loss = eval_metrics["eval_loss"] self.save_checkpoint( - f"step_{global_step}.pt", is_best=True + f"step_{global_step + 1}.pt", is_best=True ) # 记录到TensorBoard @@ -661,7 +661,7 @@ def train( grad_accum_steps: int = typer.Option(1, "--grad-accum-steps", help="梯度累积步数"), clip_grad_norm: float = typer.Option(1.0, "--clip-grad-norm", help="梯度裁剪范数"), eval_frequency: int = typer.Option(500, "--eval-frequency", help="评估频率"), - save_frequency: int = typer.Option(1000, "--save-frequency", help="保存频率"), + save_frequency: int = typer.Option(10000, "--save-frequency", help="保存频率"), # 其他参数 mixed_precision: bool = typer.Option( True, "--mixed-precision/--no-mixed-precision", help="是否使用混合精度训练"