修复损失计算方式,使用NLLLoss替代原始criterion
This commit is contained in:
parent
f89635b201
commit
9fad2bf1d4
|
|
@ -349,7 +349,8 @@ class MoEModel(nn.Module):
|
|||
|
||||
# 前向传播
|
||||
logits = self(input_ids, attention_mask, pg)
|
||||
loss = criterion(logits, labels)
|
||||
log_probs = torch.log(probs + 1e-12)
|
||||
loss = nn.NLLLoss()(log_probs, labels)
|
||||
total_loss += loss.item() * labels.size(0)
|
||||
|
||||
# 计算准确率
|
||||
|
|
|
|||
Loading…
Reference in New Issue