修复损失权重计算逻辑，修正平方根次数以提升稳定性
This commit is contained in:
parent
cd25349d90
commit
ab2dbc378b
|
|
@ -509,9 +509,7 @@ class MoEModel(nn.Module):
|
|||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.to(self.device)
|
||||
if loss_weight:
|
||||
loss_weight = 1 / torch.sqrt(torch.tensor(loss_weight))
|
||||
loss_weight = loss_weight / loss_weight.mean()
|
||||
loss_weight = torch.clamp(loss_weight, min=0.01, max=1.0)
|
||||
loss_weight = 1 / torch.sqrt(torch.sqrt(torch.tensor(loss_weight)))
|
||||
self.loss_weight = loss_weight.to(self.device)
|
||||
criterion.weight = self.loss_weight
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue