From 9fad2bf1d4bd218ad58292f600d202c646428b2d Mon Sep 17 00:00:00 2001 From: songsenand Date: Sat, 14 Feb 2026 17:07:20 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=8D=9F=E5=A4=B1=E8=AE=A1?= =?UTF-8?q?=E7=AE=97=E6=96=B9=E5=BC=8F=EF=BC=8C=E4=BD=BF=E7=94=A8NLLLoss?= =?UTF-8?q?=E6=9B=BF=E4=BB=A3=E5=8E=9F=E5=A7=8Bcriterion?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/trainer/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/trainer/model.py b/src/trainer/model.py index 496deb2..29a5f42 100644 --- a/src/trainer/model.py +++ b/src/trainer/model.py @@ -349,7 +349,8 @@ class MoEModel(nn.Module): # 前向传播 logits = self(input_ids, attention_mask, pg) - loss = criterion(logits, labels) + log_probs = torch.log(probs + 1e-12) + loss = nn.NLLLoss()(log_probs, labels) total_loss += loss.item() * labels.size(0) # 计算准确率