From 0d529c0c8939f1ca9d2bb18ab28bd6e144b954e6 Mon Sep 17 00:00:00 2001 From: songsenand Date: Sun, 15 Feb 2026 01:48:37 +0800 Subject: [PATCH] =?UTF-8?q?=E8=B0=83=E6=95=B4=E6=8D=9F=E5=A4=B1=E6=9D=83?= =?UTF-8?q?=E9=87=8D=E8=AE=A1=E7=AE=97=E5=B9=B6=E4=BC=98=E5=8C=96=E8=AE=AD?= =?UTF-8?q?=E7=BB=83=E5=BE=AA=E7=8E=AF=E7=BB=88=E6=AD=A2=E6=9D=A1=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/trainer/model.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/trainer/model.py b/src/trainer/model.py index d7d8f32..06e42fe 100644 --- a/src/trainer/model.py +++ b/src/trainer/model.py @@ -510,6 +510,8 @@ class MoEModel(nn.Module): self.to(self.device) if loss_weight: loss_weight = 1 / torch.sqrt(torch.tensor(loss_weight)) + loss_weight = loss_weight / loss_weight.mean() + loss_weight = torch.clamp(loss_weight, min=0.01, max=1.0) self.loss_weight = loss_weight.to(self.device) criterion.weight = self.loss_weight @@ -529,7 +531,7 @@ class MoEModel(nn.Module): optimizer.zero_grad() for epoch in range(num_epochs): - for batch_idx, batch in enumerate(tqdm(train_dataloader, total=1e6)): + for batch_idx, batch in enumerate(tqdm(train_dataloader, total=stop_batch)): # ---------- 更新 batch 计数器 ---------- processed_batches += 1 @@ -582,6 +584,8 @@ class MoEModel(nn.Module): f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc:.4f}, eval_loss: {eval_loss:.4f}" ) batch_loss_sum = 0.0 + if processed_batches >= stop_batch: + break def load_from_state_dict(self, state_dict_path: Union[str, Path]): state_dict = torch.load(