fix(model): 移除梯度 NaN 检查,直接执行优化器步骤
This commit is contained in:
parent
b0a4ce9ac8
commit
66c2f78dda
|
|
@ -496,17 +496,8 @@ class MoEModel(nn.Module):
|
||||||
scaler.unscale_(optimizer)
|
scaler.unscale_(optimizer)
|
||||||
torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
|
torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
|
||||||
|
|
||||||
has_nan = False
|
|
||||||
for p in self.parameters():
|
|
||||||
if p.grad is not None and torch.isnan(p.grad).any():
|
|
||||||
has_nan = True
|
|
||||||
break
|
|
||||||
|
|
||||||
if not has_nan:
|
|
||||||
scaler.step(optimizer)
|
scaler.step(optimizer)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
else:
|
|
||||||
logger.warning("NaN detected, skipping step.")
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
batch_loss_sum += loss.item() * grad_accum_steps
|
batch_loss_sum += loss.item() * grad_accum_steps
|
||||||
if global_step % eval_frequency == 0:
|
if global_step % eval_frequency == 0:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue