调整损失权重计算并优化训练循环终止条件
This commit is contained in:
parent
94b44e6f71
commit
0d529c0c89
|
|
@@ -510,6 +510,8 @@ class MoEModel(nn.Module):
|
|||
self.to(self.device)
|
||||
if loss_weight:
|
||||
loss_weight = 1 / torch.sqrt(torch.tensor(loss_weight))
|
||||
loss_weight = loss_weight / loss_weight.mean()
|
||||
loss_weight = torch.clamp(loss_weight, min=0.01, max=1.0)
|
||||
self.loss_weight = loss_weight.to(self.device)
|
||||
criterion.weight = self.loss_weight
|
||||
|
||||
|
|
@@ -529,7 +531,7 @@ class MoEModel(nn.Module):
|
|||
optimizer.zero_grad()
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
- for batch_idx, batch in enumerate(tqdm(train_dataloader, total=1e6)):
|
||||
+ for batch_idx, batch in enumerate(tqdm(train_dataloader, total=stop_batch)):
|
||||
# ---------- 更新 batch 计数器 ----------
|
||||
processed_batches += 1
|
||||
|
||||
|
|
@@ -582,6 +584,8 @@ class MoEModel(nn.Module):
|
|||
f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc:.4f}, eval_loss: {eval_loss:.4f}"
|
||||
)
|
||||
batch_loss_sum = 0.0
|
||||
if processed_batches >= stop_batch:
|
||||
break
|
||||
|
||||
def load_from_state_dict(self, state_dict_path: Union[str, Path]):
|
||||
state_dict = torch.load(
|
||||
|
|
|
|||
Loading…
Reference in New Issue