调整损失权重计算并优化训练循环终止条件

This commit is contained in:
songsenand 2026-02-15 01:48:37 +08:00
parent 94b44e6f71
commit 0d529c0c89
1 changed files with 5 additions and 1 deletions

View File

@ -510,6 +510,8 @@ class MoEModel(nn.Module):
self.to(self.device)
if loss_weight:
loss_weight = 1 / torch.sqrt(torch.tensor(loss_weight))
loss_weight = loss_weight / loss_weight.mean()
loss_weight = torch.clamp(loss_weight, min=0.01, max=1.0)
self.loss_weight = loss_weight.to(self.device)
criterion.weight = self.loss_weight
@ -529,7 +531,7 @@ class MoEModel(nn.Module):
optimizer.zero_grad()
for epoch in range(num_epochs):
for batch_idx, batch in enumerate(tqdm(train_dataloader, total=1e6)):
for batch_idx, batch in enumerate(tqdm(train_dataloader, total=stop_batch)):
# ---------- 更新 batch 计数器 ----------
processed_batches += 1
@ -582,6 +584,8 @@ class MoEModel(nn.Module):
f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc:.4f}, eval_loss: {eval_loss:.4f}"
)
batch_loss_sum = 0.0
if processed_batches >= stop_batch:
break
def load_from_state_dict(self, state_dict_path: Union[str, Path]):
state_dict = torch.load(