diff --git a/src/trainer/model.py b/src/trainer/model.py index 06e42fe..a29f5bb 100644 --- a/src/trainer/model.py +++ b/src/trainer/model.py @@ -509,9 +509,7 @@ class MoEModel(nn.Module): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.to(self.device) if loss_weight: - loss_weight = 1 / torch.sqrt(torch.tensor(loss_weight)) - loss_weight = loss_weight / loss_weight.mean() - loss_weight = torch.clamp(loss_weight, min=0.01, max=1.0) + loss_weight = 1 / torch.sqrt(torch.sqrt(torch.tensor(loss_weight))) self.loss_weight = loss_weight.to(self.device) criterion.weight = self.loss_weight