diff --git a/src/model/trainer.py b/src/model/trainer.py index 43507c0..61e77f8 100644 --- a/src/model/trainer.py +++ b/src/model/trainer.py @@ -865,8 +865,11 @@ def train( # 开始训练 console.print("\n[bold cyan]开始训练...[/bold cyan]") console.print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") - - trainer.train(resume_from=resume_from) + try: + trainer.train(resume_from=resume_from) + except KeyboardInterrupt: + console.print("[bold green]训练被终止[/bold green]") + trainer.save_checkpoint("interrupted_model.pt") console.print("[bold green]✓ 训练完成![/bold green]") console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")