diff --git a/src/trainer/model.py b/src/trainer/model.py
index b48ad4c..829ce68 100644
--- a/src/trainer/model.py
+++ b/src/trainer/model.py
@@ -496,17 +496,8 @@ class MoEModel(nn.Module):
         scaler.unscale_(optimizer)
         torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
 
-        has_nan = False
-        for p in self.parameters():
-            if p.grad is not None and torch.isnan(p.grad).any():
-                has_nan = True
-                break
-
-        if not has_nan:
-            scaler.step(optimizer)
-            scaler.update()
-        else:
-            logger.warning("NaN detected, skipping step.")
+        scaler.step(optimizer)
+        scaler.update()
         optimizer.zero_grad()
         batch_loss_sum += loss.item() * grad_accum_steps
         if global_step % eval_frequency == 0:
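
The manual torch.isnan() scan removed above duplicates work GradScaler already does: unscale_() records any inf/NaN gradients it finds, step() then silently skips optimizer.step() for that iteration, and update() shrinks the loss scale after a skipped step. Below is a minimal, self-contained sketch of that standard torch.cuda.amp pattern; the model, data, and hyperparameters are illustrative placeholders, not names from this repository.

    import torch
    import torch.nn as nn

    # Illustrative stand-ins; the real trainer has its own model/optimizer.
    device = "cuda"
    model = nn.Linear(16, 1).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    scaler = torch.cuda.amp.GradScaler()
    clip_grad_norm = 1.0

    for _ in range(10):
        inputs = torch.randn(8, 16, device=device)
        targets = torch.randn(8, 1, device=device)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss = nn.functional.mse_loss(model(inputs), targets)
        scaler.scale(loss).backward()

        # Unscale first so clipping sees true gradient magnitudes.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)

        # step() is a no-op for this iteration if unscale_ found inf/NaN
        # gradients, so a manual isnan() scan over parameters is redundant;
        # update() then lowers the loss scale after a skipped step.
        scaler.step(optimizer)
        scaler.update()

One behavioral difference worth noting: GradScaler skips the step silently, so the logger.warning() emitted by the removed code has no direct replacement.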