From ab2dbc378bb62341cdcdbb0482c6a2a4bdaf8295 Mon Sep 17 00:00:00 2001 From: songsenand Date: Sun, 15 Feb 2026 23:01:45 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=8D=9F=E5=A4=B1=E6=9D=83?= =?UTF-8?q?=E9=87=8D=E8=AE=A1=E7=AE=97=E9=80=BB=E8=BE=91=EF=BC=8C=E4=BF=AE?= =?UTF-8?q?=E6=AD=A3=E5=B9=B3=E6=96=B9=E6=A0=B9=E6=AC=A1=E6=95=B0=E4=BB=A5?= =?UTF-8?q?=E6=8F=90=E5=8D=87=E7=A8=B3=E5=AE=9A=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/trainer/model.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/trainer/model.py b/src/trainer/model.py index 06e42fe..a29f5bb 100644 --- a/src/trainer/model.py +++ b/src/trainer/model.py @@ -509,9 +509,7 @@ class MoEModel(nn.Module): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.to(self.device) if loss_weight: - loss_weight = 1 / torch.sqrt(torch.tensor(loss_weight)) - loss_weight = loss_weight / loss_weight.mean() - loss_weight = torch.clamp(loss_weight, min=0.01, max=1.0) + loss_weight = 1 / torch.sqrt(torch.sqrt(torch.tensor(loss_weight))) self.loss_weight = loss_weight.to(self.device) criterion.weight = self.loss_weight