diff --git a/src/trainer/model.py b/src/trainer/model.py index d7d8f32..06e42fe 100644 --- a/src/trainer/model.py +++ b/src/trainer/model.py @@ -510,6 +510,8 @@ class MoEModel(nn.Module): self.to(self.device) if loss_weight: loss_weight = 1 / torch.sqrt(torch.tensor(loss_weight)) + loss_weight = loss_weight / loss_weight.mean() + loss_weight = torch.clamp(loss_weight, min=0.01, max=1.0) self.loss_weight = loss_weight.to(self.device) criterion.weight = self.loss_weight @@ -529,7 +531,7 @@ class MoEModel(nn.Module): optimizer.zero_grad() for epoch in range(num_epochs): - for batch_idx, batch in enumerate(tqdm(train_dataloader, total=1e6)): + for batch_idx, batch in enumerate(tqdm(train_dataloader, total=stop_batch)): # ---------- 更新 batch 计数器 ---------- processed_batches += 1 @@ -582,6 +584,8 @@ class MoEModel(nn.Module): f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc:.4f}, eval_loss: {eval_loss:.4f}" ) batch_loss_sum = 0.0 + if processed_batches >= stop_batch: + break def load_from_state_dict(self, state_dict_path: Union[str, Path]): state_dict = torch.load(