From 66c2f78dda91841f2dd84f61d6e7790e54dec488 Mon Sep 17 00:00:00 2001
From: songsenand
Date: Thu, 26 Feb 2026 01:19:17 +0800
Subject: [PATCH] fix(model): remove gradient NaN check and step the optimizer
 directly

scaler.step() already skips the wrapped optimizer.step() when
scaler.unscale_() has flagged inf/NaN gradients, and scaler.update()
then lowers the loss scale, so the manual per-parameter torch.isnan()
scan duplicated that check and cost an extra pass over all parameters
on every step.
---
 src/trainer/model.py | 13 ++-----------
 1 file changed, 2 insertions(+), 11 deletions(-)

diff --git a/src/trainer/model.py b/src/trainer/model.py
index b48ad4c..829ce68 100644
--- a/src/trainer/model.py
+++ b/src/trainer/model.py
@@ -496,17 +496,8 @@ class MoEModel(nn.Module):
         scaler.unscale_(optimizer)
         torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
 
-        has_nan = False
-        for p in self.parameters():
-            if p.grad is not None and torch.isnan(p.grad).any():
-                has_nan = True
-                break
-
-        if not has_nan:
-            scaler.step(optimizer)
-            scaler.update()
-        else:
-            logger.warning("NaN detected, skipping step.")
+        scaler.step(optimizer)
+        scaler.update()
         optimizer.zero_grad()
         batch_loss_sum += loss.item() * grad_accum_steps
         if global_step % eval_frequency == 0:
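
For reference, a minimal self-contained sketch of the AMP update pattern
the patched loop now relies on. Everything below (model, data,
hyperparameters) is an illustrative stand-in rather than code from
src/trainer/model.py; it only demonstrates that GradScaler.step() skips
the wrapped optimizer step when unscale_() flagged inf/NaN gradients,
and that update() then lowers the loss scale.

import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(16, 4).to(device)            # stand-in for MoEModel
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))
clip_grad_norm = 1.0                           # illustrative value

x = torch.randn(8, 16, device=device)
target = torch.randn(8, 4, device=device)
with torch.autocast(device_type=device):
    loss = nn.functional.mse_loss(model(x), target)
scaler.scale(loss).backward()

# unscale_ divides the gradients by the current loss scale and records
# whether any of them came out inf/NaN.
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)

# step() consults that record: if gradients overflowed it silently skips
# optimizer.step(); update() then shrinks the loss scale so the next
# iteration is less likely to overflow. This subsumes the removed
# hand-rolled NaN scan (and also catches inf, which the scan did not).
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()

One behavioral difference: the skip is now silent, since the
logger.warning on skipped steps is gone. If that logging is still
wanted, comparing scaler.get_scale() before and after update() is one
way to detect a skipped step.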