refactor(trainer): 优化检查点保存逻辑以支持定期覆盖保存

This commit is contained in:
songsenand 2026-04-05 07:54:16 +08:00
parent 3e529d805f
commit 310a926c98
1 changed files with 13 additions and 7 deletions

View File

@ -311,14 +311,21 @@ class Trainer:
return {"eval_loss": avg_loss, "eval_accuracy": accuracy}
def save_checkpoint(self, filename: str, is_best: bool = False):
def save_checkpoint(
self, filename: str, is_best: bool = False, is_periodic: bool = False
):
"""
保存检查点
Args:
filename: 检查点文件名
is_best: 是否是最佳模型
is_periodic: 是否是定期保存的检查点,会覆盖之前的定期检查点
"""
# 如果是定期保存,使用固定的文件名来覆盖之前的
if is_periodic:
checkpoint_path = self.checkpoint_dir / "latest_checkpoint.pt"
else:
checkpoint_path = self.checkpoint_dir / filename
checkpoint = {
@ -506,9 +513,8 @@ class Trainer:
# 更新最佳模型
if eval_metrics["eval_loss"] < self.best_eval_loss:
self.best_eval_loss = eval_metrics["eval_loss"]
self.save_checkpoint(
f"step_{global_step + 1}.pt", is_best=True
)
# 只保存best_model,不创建额外的checkpoint文件
self.save_checkpoint("best_model.pt", is_best=True)
# 记录到TensorBoard
self._log_to_tensorboard(log_metrics, global_step)
@ -535,9 +541,9 @@ class Trainer:
accumulated_accuracy = 0.0
accumulation_counter = 0
# 定期保存检查点
# 定期保存检查点(覆盖之前的定期检查点)
if (global_step + 1) % self.save_frequency == 0:
self.save_checkpoint(f"step_{global_step}.pt")
self.save_checkpoint("latest_checkpoint.pt", is_periodic=True)
# 更新步数
global_step += 1