添加自定义学习率调度支持并优化默认优化器配置
This commit is contained in:
parent
335540d8c2
commit
0e3418798e
|
|
@ -351,12 +351,12 @@ class MoEModel(nn.Module):
|
||||||
monitor: TrainingMonitor = None,
|
monitor: TrainingMonitor = None,
|
||||||
criterion=nn.CrossEntropyLoss(),
|
criterion=nn.CrossEntropyLoss(),
|
||||||
optimizer=None,
|
optimizer=None,
|
||||||
scheduler=None,
|
|
||||||
num_epochs=1,
|
num_epochs=1,
|
||||||
eval_frequency=500,
|
eval_frequency=500,
|
||||||
grad_accum_steps=1,
|
grad_accum_steps=1,
|
||||||
clip_grad_norm=1.0,
|
clip_grad_norm=1.0,
|
||||||
mixed_precision=False,
|
mixed_precision=False,
|
||||||
|
lr=1e-4,
|
||||||
lr_schedule=None, # 新增:可选的自定义学习率调度函数
|
lr_schedule=None, # 新增:可选的自定义学习率调度函数
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|
@ -379,10 +379,7 @@ class MoEModel(nn.Module):
|
||||||
|
|
||||||
# 默认优化器
|
# 默认优化器
|
||||||
if optimizer is None:
|
if optimizer is None:
|
||||||
optimizer = optim.AdamW(self.parameters(), lr=1e-4) # 初始学习率 1e-4
|
optimizer = optim.AdamW(self.parameters(), lr=lr) # 初始学习率 1e-4
|
||||||
created_optimizer = True
|
|
||||||
else:
|
|
||||||
created_optimizer = False # 用户传入优化器,不自动覆盖学习率
|
|
||||||
|
|
||||||
# 混合精度缩放器
|
# 混合精度缩放器
|
||||||
scaler = amp.GradScaler(enabled=mixed_precision)
|
scaler = amp.GradScaler(enabled=mixed_precision)
|
||||||
|
|
@ -398,15 +395,7 @@ class MoEModel(nn.Module):
|
||||||
processed_batches += 1
|
processed_batches += 1
|
||||||
|
|
||||||
# ---------- 学习率调度(仅当使用默认优化器且未传入自定义调度函数时)----------
|
# ---------- 学习率调度(仅当传入自定义调度函数 lr_schedule 时调用)----------
|
||||||
if created_optimizer and lr_schedule is None:
|
if lr_schedule is not None:
|
||||||
if processed_batches <= 8000:
|
|
||||||
new_lr = 1e-4
|
|
||||||
else:
|
|
||||||
new_lr = 6e-6
|
|
||||||
# 为所有参数组统一设置学习率
|
|
||||||
for param_group in optimizer.param_groups:
|
|
||||||
param_group["lr"] = new_lr
|
|
||||||
elif lr_schedule is not None:
|
|
||||||
# 调用用户自定义的调度函数
|
# 调用用户自定义的调度函数
|
||||||
lr_schedule(processed_batches, optimizer)
|
lr_schedule(processed_batches, optimizer)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue