移除 global_step 自增逻辑并调整至循环末尾

This commit is contained in:
songsenand 2026-02-20 23:28:35 +08:00
parent 558d7f9fc9
commit 4560a9ed06
1 changed file with 1 addition and 1 deletion

View File

@@ -570,7 +570,6 @@ class MoEModel(nn.Module):
  scaler.step(optimizer)
  scaler.update()
  optimizer.zero_grad()
- global_step += 1
  original_loss = loss.item() * grad_accum_steps
  batch_loss_sum += original_loss
  # 周期性评估(与原代码相同)
@@ -592,6 +591,7 @@ class MoEModel(nn.Module):
  batch_loss_sum = 0.0
  if processed_batches >= stop_batch:
      break
+ global_step += 1
  def load_from_state_dict(self, state_dict_path: Union[str, Path]):
      state_dict = torch.load(