diff --git a/src/trainer/model.py b/src/trainer/model.py index bfdc1ff..362e6cc 100644 --- a/src/trainer/model.py +++ b/src/trainer/model.py @@ -579,6 +579,8 @@ class MoEModel(nn.Module): ): avg_loss = batch_loss_sum / eval_frequency acc, eval_loss = self.model_eval(eval_dataloader, criterion) + if global_step == 0: + avg_loss = eval_loss super().train() if monitor is not None: monitor.add_step(