将 global_step 自增从 optimizer.zero_grad() 之后移至循环末尾
This commit is contained in:
parent
558d7f9fc9
commit
4560a9ed06
|
|
@ -570,7 +570,6 @@ class MoEModel(nn.Module):
|
|||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
global_step += 1
|
||||
original_loss = loss.item() * grad_accum_steps
|
||||
batch_loss_sum += original_loss
|
||||
# 周期性评估(与原代码相同)
|
||||
|
|
@ -592,6 +591,7 @@ class MoEModel(nn.Module):
|
|||
batch_loss_sum = 0.0
|
||||
if processed_batches >= stop_batch:
|
||||
break
|
||||
global_step += 1
|
||||
|
||||
def load_from_state_dict(self, state_dict_path: Union[str, Path]):
|
||||
state_dict = torch.load(
|
||||
|
|
|
|||
Loading…
Reference in New Issue