diff --git a/src/trainer/model.py b/src/trainer/model.py index 496deb2..29a5f42 100644 --- a/src/trainer/model.py +++ b/src/trainer/model.py @@ -349,7 +349,8 @@ class MoEModel(nn.Module): # 前向传播 logits = self(input_ids, attention_mask, pg) - loss = criterion(logits, labels) + log_probs = torch.log(probs + 1e-12) + loss = nn.NLLLoss()(log_probs, labels) total_loss += loss.item() * labels.size(0) # 计算准确率