From 310a926c98dee7135d4999237ed37411af3cfbe3 Mon Sep 17 00:00:00 2001 From: songsenand Date: Sun, 5 Apr 2026 07:54:16 +0800 Subject: [PATCH] =?UTF-8?q?refactor(trainer):=20=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E6=A3=80=E6=9F=A5=E7=82=B9=E4=BF=9D=E5=AD=98=E9=80=BB=E8=BE=91?= =?UTF-8?q?=E4=BB=A5=E6=94=AF=E6=8C=81=E5=AE=9A=E6=9C=9F=E8=A6=86=E7=9B=96?= =?UTF-8?q?=E4=BF=9D=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/model/trainer.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/model/trainer.py b/src/model/trainer.py index fca8c64..aa6eb05 100644 --- a/src/model/trainer.py +++ b/src/model/trainer.py @@ -311,15 +311,22 @@ class Trainer: return {"eval_loss": avg_loss, "eval_accuracy": accuracy} - def save_checkpoint(self, filename: str, is_best: bool = False): + def save_checkpoint( + self, filename: str, is_best: bool = False, is_periodic: bool = False + ): """ 保存检查点 Args: filename: 检查点文件名 is_best: 是否是最佳模型 + is_periodic: 是否是定期保存的检查点(会覆盖之前的定期检查点) """ - checkpoint_path = self.checkpoint_dir / filename + # 如果是定期保存,使用固定的文件名来覆盖之前的 + if is_periodic: + checkpoint_path = self.checkpoint_dir / "latest_checkpoint.pt" + else: + checkpoint_path = self.checkpoint_dir / filename checkpoint = { "step": self.current_step, @@ -506,9 +513,8 @@ class Trainer: # 更新最佳模型 if eval_metrics["eval_loss"] < self.best_eval_loss: self.best_eval_loss = eval_metrics["eval_loss"] - self.save_checkpoint( - f"step_{global_step + 1}.pt", is_best=True - ) + # 只保存best_model,不创建额外的checkpoint文件 + self.save_checkpoint("best_model.pt", is_best=True) # 记录到TensorBoard self._log_to_tensorboard(log_metrics, global_step) @@ -535,9 +541,9 @@ class Trainer: accumulated_accuracy = 0.0 accumulation_counter = 0 - # 定期保存检查点 + # 定期保存检查点(覆盖之前的定期检查点) if (global_step + 1) % self.save_frequency == 0: - self.save_checkpoint(f"step_{global_step}.pt") + self.save_checkpoint("latest_checkpoint.pt", is_periodic=True) # 更新步数 global_step += 1