fix: 启用 TensorFloat32 加速矩阵乘法并解决 UserWarning

This commit is contained in:
songsenand 2026-04-05 22:16:48 +08:00
parent 59bb29e4fd
commit 7143896f4d
1 changed files with 4 additions and 0 deletions

View File

@ -878,6 +878,10 @@ def train(
"""
torch.multiprocessing.set_sharing_strategy("file_system")
# 启用 TensorFloat32 加速矩阵乘法 (解决 UserWarning 并提升性能)
if torch.cuda.is_available():
torch.set_float32_matmul_precision("high")
# 设置随机种子
torch.manual_seed(seed)
if torch.cuda.is_available():