fix: 启用 TensorFloat32 加速矩阵乘法并解决 UserWarning
This commit is contained in:
parent
59bb29e4fd
commit
7143896f4d
|
|
@ -878,6 +878,10 @@ def train(
|
||||||
"""
|
"""
|
||||||
torch.multiprocessing.set_sharing_strategy("file_system")
|
torch.multiprocessing.set_sharing_strategy("file_system")
|
||||||
|
|
||||||
|
# 启用 TensorFloat32 加速矩阵乘法 (解决 UserWarning 并提升性能)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.set_float32_matmul_precision("high")
|
||||||
|
|
||||||
# 设置随机种子
|
# 设置随机种子
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue