fix(trainer): 调整模型保存频率以避免频繁写盘
This commit is contained in:
parent
369424be28
commit
3e529d805f
|
|
@ -17,7 +17,7 @@ dependencies = [
|
||||||
"requests>=2.32.5",
|
"requests>=2.32.5",
|
||||||
"rich>=14.3.1",
|
"rich>=14.3.1",
|
||||||
"tensorboard>=2.20.0",
|
"tensorboard>=2.20.0",
|
||||||
"torch>=2.11.0",
|
"torch>=2.10.0",
|
||||||
"transformers==5.1.0",
|
"transformers==5.1.0",
|
||||||
"typer>=0.21.1",
|
"typer>=0.21.1",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -64,7 +64,7 @@ class Trainer:
|
||||||
grad_accum_steps: int = 1,
|
grad_accum_steps: int = 1,
|
||||||
clip_grad_norm: float = 1.0,
|
clip_grad_norm: float = 1.0,
|
||||||
eval_frequency: int = 500,
|
eval_frequency: int = 500,
|
||||||
save_frequency: int = 1000,
|
save_frequency: int = 10000,
|
||||||
mixed_precision: bool = True,
|
mixed_precision: bool = True,
|
||||||
device: Optional[torch.device] = None,
|
device: Optional[torch.device] = None,
|
||||||
use_tensorboard: bool = True,
|
use_tensorboard: bool = True,
|
||||||
|
|
@ -507,7 +507,7 @@ class Trainer:
|
||||||
if eval_metrics["eval_loss"] < self.best_eval_loss:
|
if eval_metrics["eval_loss"] < self.best_eval_loss:
|
||||||
self.best_eval_loss = eval_metrics["eval_loss"]
|
self.best_eval_loss = eval_metrics["eval_loss"]
|
||||||
self.save_checkpoint(
|
self.save_checkpoint(
|
||||||
f"step_{global_step}.pt", is_best=True
|
f"step_{global_step + 1}.pt", is_best=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# 记录到TensorBoard
|
# 记录到TensorBoard
|
||||||
|
|
@ -661,7 +661,7 @@ def train(
|
||||||
grad_accum_steps: int = typer.Option(1, "--grad-accum-steps", help="梯度累积步数"),
|
grad_accum_steps: int = typer.Option(1, "--grad-accum-steps", help="梯度累积步数"),
|
||||||
clip_grad_norm: float = typer.Option(1.0, "--clip-grad-norm", help="梯度裁剪范数"),
|
clip_grad_norm: float = typer.Option(1.0, "--clip-grad-norm", help="梯度裁剪范数"),
|
||||||
eval_frequency: int = typer.Option(500, "--eval-frequency", help="评估频率"),
|
eval_frequency: int = typer.Option(500, "--eval-frequency", help="评估频率"),
|
||||||
save_frequency: int = typer.Option(1000, "--save-frequency", help="保存频率"),
|
save_frequency: int = typer.Option(10000, "--save-frequency", help="保存频率"),
|
||||||
# 其他参数
|
# 其他参数
|
||||||
mixed_precision: bool = typer.Option(
|
mixed_precision: bool = typer.Option(
|
||||||
True, "--mixed-precision/--no-mixed-precision", help="是否使用混合精度训练"
|
True, "--mixed-precision/--no-mixed-precision", help="是否使用混合精度训练"
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue