diff --git a/src/trainer/model.py b/src/trainer/model.py index 54ed892..bfdc1ff 100644 --- a/src/trainer/model.py +++ b/src/trainer/model.py @@ -570,7 +570,6 @@ class MoEModel(nn.Module): scaler.step(optimizer) scaler.update() optimizer.zero_grad() - global_step += 1 original_loss = loss.item() * grad_accum_steps batch_loss_sum += original_loss # 周期性评估(与原代码相同) @@ -592,6 +591,7 @@ class MoEModel(nn.Module): batch_loss_sum = 0.0 if processed_batches >= stop_batch: break + global_step += 1 def load_from_state_dict(self, state_dict_path: Union[str, Path]): state_dict = torch.load(