修复损失计算方式,使用NLLLoss替代原始criterion

This commit is contained in:
songsenand 2026-02-14 17:07:20 +08:00
parent f89635b201
commit 9fad2bf1d4
1 changed files with 2 additions and 1 deletions

View File

@ -349,7 +349,8 @@ class MoEModel(nn.Module):
# 前向传播 # 前向传播
logits = self(input_ids, attention_mask, pg) logits = self(input_ids, attention_mask, pg)
loss = criterion(logits, labels) log_probs = torch.log(probs + 1e-12)
loss = nn.NLLLoss()(log_probs, labels)
total_loss += loss.item() * labels.size(0) total_loss += loss.item() * labels.size(0)
# 计算准确率 # 计算准确率