Add support for custom learning-rate schedules and improve the default optimizer configuration
This commit is contained in:
parent 335540d8c2
commit 0e3418798e
@@ -351,12 +351,12 @@ class MoEModel(nn.Module):
        monitor: TrainingMonitor = None,
        criterion=nn.CrossEntropyLoss(),
        optimizer=None,
        scheduler=None,
        num_epochs=1,
        eval_frequency=500,
        grad_accum_steps=1,
        clip_grad_norm=1.0,
        mixed_precision=False,
        lr=1e-4,
        lr_schedule=None,  # new: optional custom learning-rate schedule function
    ):
        """
@@ -379,10 +379,7 @@ class MoEModel(nn.Module):

        # Default optimizer: only built when the caller did not supply one
        if optimizer is None:
            optimizer = optim.AdamW(self.parameters(), lr=lr)  # initial learning rate taken from the lr argument
            created_optimizer = True
        else:
            created_optimizer = False  # user passed an optimizer; do not override its learning rate

        # Mixed-precision gradient scaler
        scaler = amp.GradScaler(enabled=mixed_precision)
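Because created_optimizer is False whenever the caller supplies an optimizer, the built-in schedule in the next hunk never touches a user-configured learning rate. A sketch of that path, assuming model is an instantiated MoEModel (the values shown are illustrative):

import torch.optim as optim

# A user-supplied optimizer keeps its own learning rate: created_optimizer
# is False, so the default two-phase schedule below is skipped entirely.
user_optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)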
@@ -398,15 +395,7 @@ class MoEModel(nn.Module):
                processed_batches += 1

                # ---------- Learning-rate schedule (only when the default optimizer is in use and no custom schedule was supplied) ----------
                if created_optimizer and lr_schedule is None:
                    # Built-in two-phase default: 1e-4 for the first 8000 batches, then 6e-6
                    if processed_batches <= 8000:
                        new_lr = 1e-4
                    else:
                        new_lr = 6e-6
                    # Apply the same learning rate to every parameter group
                    for param_group in optimizer.param_groups:
                        param_group["lr"] = new_lr
                elif lr_schedule is not None:
                    # Invoke the user-supplied schedule function
                    lr_schedule(processed_batches, optimizer)
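The built-in two-phase schedule above can also be reproduced through the new hook, which makes a convenient starting point for tweaking the 8000-batch boundary or either rate:

def two_phase_schedule(processed_batches, optimizer):
    # Same behaviour as the built-in default: 1e-4 for the first
    # 8000 batches, then a constant 6e-6 afterwards.
    new_lr = 1e-4 if processed_batches <= 8000 else 6e-6
    for param_group in optimizer.param_groups:
        param_group["lr"] = new_lr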