将 global_step 的自增从优化器更新之后移至循环末尾
This commit is contained in:
parent
558d7f9fc9
commit
4560a9ed06
|
|
@ -570,7 +570,6 @@ class MoEModel(nn.Module):
|
||||||
scaler.step(optimizer)
|
scaler.step(optimizer)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
global_step += 1
|
|
||||||
original_loss = loss.item() * grad_accum_steps
|
original_loss = loss.item() * grad_accum_steps
|
||||||
batch_loss_sum += original_loss
|
batch_loss_sum += original_loss
|
||||||
# 周期性评估(与原代码相同)
|
# 周期性评估(与原代码相同)
|
||||||
|
|
@ -592,6 +591,7 @@ class MoEModel(nn.Module):
|
||||||
batch_loss_sum = 0.0
|
batch_loss_sum = 0.0
|
||||||
if processed_batches >= stop_batch:
|
if processed_batches >= stop_batch:
|
||||||
break
|
break
|
||||||
|
global_step += 1
|
||||||
|
|
||||||
def load_from_state_dict(self, state_dict_path: Union[str, Path]):
|
def load_from_state_dict(self, state_dict_path: Union[str, Path]):
|
||||||
state_dict = torch.load(
|
state_dict = torch.load(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue