From 7143896f4d1e855784297151ae3c34e043dc506b Mon Sep 17 00:00:00 2001 From: songsenand Date: Sun, 5 Apr 2026 22:16:48 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E5=90=AF=E7=94=A8=20TensorFloat32=20?= =?UTF-8?q?=E5=8A=A0=E9=80=9F=E7=9F=A9=E9=98=B5=E4=B9=98=E6=B3=95=E5=B9=B6?= =?UTF-8?q?=E8=A7=A3=E5=86=B3=20UserWarning?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/model/trainer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/model/trainer.py b/src/model/trainer.py index f2a8582..a411788 100644 --- a/src/model/trainer.py +++ b/src/model/trainer.py @@ -878,6 +878,10 @@ def train( """ torch.multiprocessing.set_sharing_strategy("file_system") + # 启用 TensorFloat32 加速矩阵乘法 (解决 UserWarning 并提升性能) + if torch.cuda.is_available(): + torch.set_float32_matmul_precision("high") + # 设置随机种子 torch.manual_seed(seed) if torch.cuda.is_available():