修复周期性评估时平均损失计算错误

This commit is contained in:
songsenand 2026-02-13 11:29:22 +08:00
parent 92b12ef703
commit 02f851205f
1 changed file with 1 addition and 2 deletions

View File

@@ -437,13 +437,12 @@ class MoEModel(nn.Module):
global_step += 1 global_step += 1
original_loss = loss.item() * grad_accum_steps original_loss = loss.item() * grad_accum_steps
batch_loss_sum += original_loss batch_loss_sum += original_loss
# 周期性评估(与原代码相同) # 周期性评估(与原代码相同)
if ( if (
eval_dataloader is not None eval_dataloader is not None
and global_step % eval_frequency == 0 and global_step % eval_frequency == 0
): ):
avg_loss = batch_loss_sum / global_step avg_loss = batch_loss_sum / eval_frequency
acc, _ = self.model_eval(eval_dataloader, criterion) acc, _ = self.model_eval(eval_dataloader, criterion)
super().train() super().train()
if monitor is not None: if monitor is not None: