修复损失权重计算逻辑,修正平方根次数以提升稳定性

This commit is contained in:
songsenand 2026-02-15 23:01:45 +08:00
parent cd25349d90
commit ab2dbc378b
1 changed files with 1 additions and 3 deletions

View File

@ -509,9 +509,7 @@ class MoEModel(nn.Module):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.to(self.device)
if loss_weight:
loss_weight = 1 / torch.sqrt(torch.tensor(loss_weight))
loss_weight = loss_weight / loss_weight.mean()
loss_weight = torch.clamp(loss_weight, min=0.01, max=1.0)
loss_weight = 1 / torch.sqrt(torch.sqrt(torch.tensor(loss_weight)))
self.loss_weight = loss_weight.to(self.device)
criterion.weight = self.loss_weight