refactor(trainer): 优化检查点保存逻辑以支持定期覆盖保存
This commit is contained in:
parent
3e529d805f
commit
310a926c98
|
|
@ -311,15 +311,22 @@ class Trainer:
|
|||
|
||||
return {"eval_loss": avg_loss, "eval_accuracy": accuracy}
|
||||
|
||||
def save_checkpoint(self, filename: str, is_best: bool = False):
|
||||
def save_checkpoint(
|
||||
self, filename: str, is_best: bool = False, is_periodic: bool = False
|
||||
):
|
||||
"""
|
||||
保存检查点
|
||||
|
||||
Args:
|
||||
filename: 检查点文件名
|
||||
is_best: 是否是最佳模型
|
||||
is_periodic: 是否是定期保存的检查点(会覆盖之前的定期检查点)
|
||||
"""
|
||||
checkpoint_path = self.checkpoint_dir / filename
|
||||
# 如果是定期保存,使用固定的文件名来覆盖之前的
|
||||
if is_periodic:
|
||||
checkpoint_path = self.checkpoint_dir / "latest_checkpoint.pt"
|
||||
else:
|
||||
checkpoint_path = self.checkpoint_dir / filename
|
||||
|
||||
checkpoint = {
|
||||
"step": self.current_step,
|
||||
|
|
@ -506,9 +513,8 @@ class Trainer:
|
|||
# 更新最佳模型
|
||||
if eval_metrics["eval_loss"] < self.best_eval_loss:
|
||||
self.best_eval_loss = eval_metrics["eval_loss"]
|
||||
self.save_checkpoint(
|
||||
f"step_{global_step + 1}.pt", is_best=True
|
||||
)
|
||||
# 只保存best_model,不创建额外的checkpoint文件
|
||||
self.save_checkpoint("best_model.pt", is_best=True)
|
||||
|
||||
# 记录到TensorBoard
|
||||
self._log_to_tensorboard(log_metrics, global_step)
|
||||
|
|
@ -535,9 +541,9 @@ class Trainer:
|
|||
accumulated_accuracy = 0.0
|
||||
accumulation_counter = 0
|
||||
|
||||
# 定期保存检查点
|
||||
# 定期保存检查点(覆盖之前的定期检查点)
|
||||
if (global_step + 1) % self.save_frequency == 0:
|
||||
self.save_checkpoint(f"step_{global_step}.pt")
|
||||
self.save_checkpoint("latest_checkpoint.pt", is_periodic=True)
|
||||
|
||||
# 更新步数
|
||||
global_step += 1
|
||||
|
|
|
|||
Loading…
Reference in New Issue