refactor(trainer): optimize checkpoint saving logic to support periodic overwrite saves

songsenand 2026-04-05 07:54:16 +08:00
parent 3e529d805f
commit 310a926c98
1 changed file with 13 additions and 7 deletions


@@ -311,15 +311,22 @@ class Trainer:
         return {"eval_loss": avg_loss, "eval_accuracy": accuracy}
 
-    def save_checkpoint(self, filename: str, is_best: bool = False):
+    def save_checkpoint(
+        self, filename: str, is_best: bool = False, is_periodic: bool = False
+    ):
         """
         Save a checkpoint
 
         Args:
             filename: checkpoint file name
             is_best: whether this is the best model
+            is_periodic: whether this is a periodic save; it overwrites the previous periodic checkpoint
         """
 
-        checkpoint_path = self.checkpoint_dir / filename
+        # For periodic saves, use a fixed file name so the previous periodic checkpoint is overwritten
+        if is_periodic:
+            checkpoint_path = self.checkpoint_dir / "latest_checkpoint.pt"
+        else:
+            checkpoint_path = self.checkpoint_dir / filename
 
         checkpoint = {
             "step": self.current_step,
@@ -506,9 +513,8 @@ class Trainer:
                 # Update the best model
                 if eval_metrics["eval_loss"] < self.best_eval_loss:
                     self.best_eval_loss = eval_metrics["eval_loss"]
-                    self.save_checkpoint(
-                        f"step_{global_step + 1}.pt", is_best=True
-                    )
+                    # Only save best_model; do not create an extra checkpoint file
+                    self.save_checkpoint("best_model.pt", is_best=True)
 
                 # Log to TensorBoard
                 self._log_to_tensorboard(log_metrics, global_step)
@@ -535,9 +541,9 @@ class Trainer:
                     accumulated_accuracy = 0.0
                     accumulation_counter = 0
 
-                # Periodically save a checkpoint
+                # Periodically save a checkpoint (overwrites the previous periodic checkpoint)
                 if (global_step + 1) % self.save_frequency == 0:
-                    self.save_checkpoint(f"step_{global_step}.pt")
+                    self.save_checkpoint("latest_checkpoint.pt", is_periodic=True)
 
                 # Update step count
                 global_step += 1
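Inside the training loop, the two call sites now divide the work roughly as in the sketch below. Only `save_frequency`, the metric names, and the two `save_checkpoint` calls appear in this diff; the `eval_frequency` trigger and the `self.evaluate()` call are assumptions about the surrounding evaluation cadence.

```python
# Loop-body sketch (runs inside a Trainer method, hence the self.* references):
if (global_step + 1) % self.eval_frequency == 0:          # assumed evaluation trigger
    eval_metrics = self.evaluate()                         # assumed to return eval_loss/eval_accuracy
    if eval_metrics["eval_loss"] < self.best_eval_loss:
        self.best_eval_loss = eval_metrics["eval_loss"]
        # Best model always goes to one stable file instead of step_{N}.pt.
        self.save_checkpoint("best_model.pt", is_best=True)

if (global_step + 1) % self.save_frequency == 0:
    # Rolling checkpoint: each periodic save overwrites latest_checkpoint.pt.
    self.save_checkpoint("latest_checkpoint.pt", is_periodic=True)

global_step += 1
```

The net effect is that the checkpoint directory holds at most `best_model.pt` plus a single `latest_checkpoint.pt`, rather than an ever-growing list of `step_{N}.pt` files.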