修复损失计算方式,使用NLLLoss替代原始criterion
This commit is contained in:
parent
f89635b201
commit
9fad2bf1d4
|
|
@ -349,7 +349,8 @@ class MoEModel(nn.Module):
|
||||||
|
|
||||||
# 前向传播
|
# 前向传播
|
||||||
logits = self(input_ids, attention_mask, pg)
|
logits = self(input_ids, attention_mask, pg)
|
||||||
loss = criterion(logits, labels)
|
log_probs = torch.log(probs + 1e-12)
|
||||||
|
loss = nn.NLLLoss()(log_probs, labels)
|
||||||
total_loss += loss.item() * labels.size(0)
|
total_loss += loss.item() * labels.size(0)
|
||||||
|
|
||||||
# 计算准确率
|
# 计算准确率
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue