Add custom learning-rate schedule support and improve the default optimizer configuration

songsenand 2026-02-13 12:58:09 +08:00
parent 335540d8c2
commit 0e3418798e
1 changed file with 3 additions and 14 deletions


@@ -351,12 +351,12 @@ class MoEModel(nn.Module):
monitor: TrainingMonitor = None,
criterion=nn.CrossEntropyLoss(),
optimizer=None,
scheduler=None,
num_epochs=1,
eval_frequency=500,
grad_accum_steps=1,
clip_grad_norm=1.0,
mixed_precision=False,
lr=1e-4,
lr_schedule=None, # New: optional custom learning-rate schedule function
):
"""
@@ -379,10 +379,7 @@ class MoEModel(nn.Module):
# Default optimizer
if optimizer is None:
optimizer = optim.AdamW(self.parameters(), lr=1e-4) # initial learning rate 1e-4
created_optimizer = True
else:
created_optimizer = False # user supplied the optimizer; do not override its learning rate
optimizer = optim.AdamW(self.parameters(), lr=lr) # initial learning rate taken from the lr argument (default 1e-4)
# Mixed-precision gradient scaler
scaler = amp.GradScaler(enabled=mixed_precision)
@@ -398,15 +395,7 @@ class MoEModel(nn.Module):
processed_batches += 1
# ---------- Learning-rate schedule (only when using the default optimizer and no custom schedule function is passed) ----------
if created_optimizer and lr_schedule is None:
if processed_batches <= 8000:
new_lr = 1e-4
else:
new_lr = 6e-6
# Set the same learning rate for all parameter groups
for param_group in optimizer.param_groups:
param_group["lr"] = new_lr
elif lr_schedule is not None:
if lr_schedule is not None:
# Call the user-supplied schedule function
lr_schedule(processed_batches, optimizer)
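
The new lr_schedule hook is invoked as lr_schedule(processed_batches, optimizer), so the stepped schedule that this commit removes (1e-4 for the first 8000 processed batches, then 6e-6) can still be reproduced from user code. A minimal sketch of such a callback; the name step_lr_schedule and the call site shown in the comment are illustrative assumptions, since the training method's name is not visible in this diff:

def step_lr_schedule(processed_batches, optimizer):
    # Same behaviour as the removed built-in schedule:
    # 1e-4 for the first 8000 processed batches, then 6e-6 afterwards.
    new_lr = 1e-4 if processed_batches <= 8000 else 6e-6
    for param_group in optimizer.param_groups:
        param_group["lr"] = new_lr

# Hypothetical usage (replace the method name with the actual training entry point):
# model.<train_method>(..., lr=1e-4, lr_schedule=step_lr_schedule)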