diff --git a/src/model/trainer.py b/src/model/trainer.py index f2a8582..a411788 100644 --- a/src/model/trainer.py +++ b/src/model/trainer.py @@ -878,6 +878,10 @@ def train( """ torch.multiprocessing.set_sharing_strategy("file_system") + # 启用 TensorFloat32 加速矩阵乘法 (解决 UserWarning 并提升性能) + if torch.cuda.is_available(): + torch.set_float32_matmul_precision("high") + # 设置随机种子 torch.manual_seed(seed) if torch.cuda.is_available():