diff --git a/src/model/trainer.py b/src/model/trainer.py index ae30f11..63baa58 100644 --- a/src/model/trainer.py +++ b/src/model/trainer.py @@ -812,7 +812,8 @@ def train( prefetch_factor=64, # 每个worker预取64个batch persistent_workers=True, # 保持worker进程存活 ) - + if torch.cuda.is_available(): + console.print("[red]GPU不可用,退回使用CPU[/red]") console.print("[green]✓ 数据加载器创建完成[/green]") console.print(f" 训练批次大小: {batch_size}") console.print(f" 评估批次大小: {batch_size}")