diff --git a/src/model/trainer.py b/src/model/trainer.py index fca8c64..aa6eb05 100644 --- a/src/model/trainer.py +++ b/src/model/trainer.py @@ -311,15 +311,22 @@ class Trainer: return {"eval_loss": avg_loss, "eval_accuracy": accuracy} - def save_checkpoint(self, filename: str, is_best: bool = False): + def save_checkpoint( + self, filename: str, is_best: bool = False, is_periodic: bool = False + ): """ 保存检查点 Args: filename: 检查点文件名 is_best: 是否是最佳模型 + is_periodic: 是否是定期保存的检查点(会覆盖之前的定期检查点) """ - checkpoint_path = self.checkpoint_dir / filename + # 如果是定期保存,使用固定的文件名来覆盖之前的 + if is_periodic: + checkpoint_path = self.checkpoint_dir / "latest_checkpoint.pt" + else: + checkpoint_path = self.checkpoint_dir / filename checkpoint = { "step": self.current_step, @@ -506,9 +513,8 @@ class Trainer: # 更新最佳模型 if eval_metrics["eval_loss"] < self.best_eval_loss: self.best_eval_loss = eval_metrics["eval_loss"] - self.save_checkpoint( - f"step_{global_step + 1}.pt", is_best=True - ) + # 只保存best_model,不创建额外的checkpoint文件 + self.save_checkpoint("best_model.pt", is_best=True) # 记录到TensorBoard self._log_to_tensorboard(log_metrics, global_step) @@ -535,9 +541,9 @@ class Trainer: accumulated_accuracy = 0.0 accumulation_counter = 0 - # 定期保存检查点 + # 定期保存检查点(覆盖之前的定期检查点) if (global_step + 1) % self.save_frequency == 0: - self.save_checkpoint(f"step_{global_step}.pt") + self.save_checkpoint("latest_checkpoint.pt", is_periodic=True) # 更新步数 global_step += 1