refactor(trainer): 优化检查点保存逻辑以支持定期覆盖保存
This commit is contained in:
parent
3e529d805f
commit
310a926c98
|
|
@ -311,14 +311,21 @@ class Trainer:
|
||||||
|
|
||||||
return {"eval_loss": avg_loss, "eval_accuracy": accuracy}
|
return {"eval_loss": avg_loss, "eval_accuracy": accuracy}
|
||||||
|
|
||||||
def save_checkpoint(self, filename: str, is_best: bool = False):
|
def save_checkpoint(
|
||||||
|
self, filename: str, is_best: bool = False, is_periodic: bool = False
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
保存检查点
|
保存检查点
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
filename: 检查点文件名
|
filename: 检查点文件名
|
||||||
is_best: 是否是最佳模型
|
is_best: 是否是最佳模型
|
||||||
|
is_periodic: 是否是定期保存的检查点(会覆盖之前的定期检查点)
|
||||||
"""
|
"""
|
||||||
|
# 如果是定期保存,使用固定的文件名来覆盖之前的
|
||||||
|
if is_periodic:
|
||||||
|
checkpoint_path = self.checkpoint_dir / "latest_checkpoint.pt"
|
||||||
|
else:
|
||||||
checkpoint_path = self.checkpoint_dir / filename
|
checkpoint_path = self.checkpoint_dir / filename
|
||||||
|
|
||||||
checkpoint = {
|
checkpoint = {
|
||||||
|
|
@ -506,9 +513,8 @@ class Trainer:
|
||||||
# 更新最佳模型
|
# 更新最佳模型
|
||||||
if eval_metrics["eval_loss"] < self.best_eval_loss:
|
if eval_metrics["eval_loss"] < self.best_eval_loss:
|
||||||
self.best_eval_loss = eval_metrics["eval_loss"]
|
self.best_eval_loss = eval_metrics["eval_loss"]
|
||||||
self.save_checkpoint(
|
# 只保存best_model,不创建额外的checkpoint文件
|
||||||
f"step_{global_step + 1}.pt", is_best=True
|
self.save_checkpoint("best_model.pt", is_best=True)
|
||||||
)
|
|
||||||
|
|
||||||
# 记录到TensorBoard
|
# 记录到TensorBoard
|
||||||
self._log_to_tensorboard(log_metrics, global_step)
|
self._log_to_tensorboard(log_metrics, global_step)
|
||||||
|
|
@ -535,9 +541,9 @@ class Trainer:
|
||||||
accumulated_accuracy = 0.0
|
accumulated_accuracy = 0.0
|
||||||
accumulation_counter = 0
|
accumulation_counter = 0
|
||||||
|
|
||||||
# 定期保存检查点
|
# 定期保存检查点(覆盖之前的定期检查点)
|
||||||
if (global_step + 1) % self.save_frequency == 0:
|
if (global_step + 1) % self.save_frequency == 0:
|
||||||
self.save_checkpoint(f"step_{global_step}.pt")
|
self.save_checkpoint("latest_checkpoint.pt", is_periodic=True)
|
||||||
|
|
||||||
# 更新步数
|
# 更新步数
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue