From 0e3418798e70122cbaf322a4102990ffa5bbaad2 Mon Sep 17 00:00:00 2001
From: songsenand
Date: Fri, 13 Feb 2026 12:58:09 +0800
Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E8=87=AA=E5=AE=9A=E4=B9=89?=
 =?UTF-8?q?=E5=AD=A6=E4=B9=A0=E7=8E=87=E8=B0=83=E5=BA=A6=E6=94=AF=E6=8C=81?=
 =?UTF-8?q?=E5=B9=B6=E4=BC=98=E5=8C=96=E9=BB=98=E8=AE=A4=E4=BC=98=E5=8C=96?=
 =?UTF-8?q?=E5=99=A8=E9=85=8D=E7=BD=AE?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/trainer/model.py | 17 +++--------------
 1 file changed, 3 insertions(+), 14 deletions(-)

diff --git a/src/trainer/model.py b/src/trainer/model.py
index 088d0fa..9f7b64e 100644
--- a/src/trainer/model.py
+++ b/src/trainer/model.py
@@ -351,12 +351,12 @@ class MoEModel(nn.Module):
         monitor: TrainingMonitor = None,
         criterion=nn.CrossEntropyLoss(),
         optimizer=None,
-        scheduler=None,
         num_epochs=1,
         eval_frequency=500,
         grad_accum_steps=1,
         clip_grad_norm=1.0,
         mixed_precision=False,
+        lr=1e-4,
         lr_schedule=None,  # 新增:可选的自定义学习率调度函数
     ):
         """
@@ -379,10 +379,7 @@ class MoEModel(nn.Module):
 
         # 默认优化器
         if optimizer is None:
-            optimizer = optim.AdamW(self.parameters(), lr=1e-4)  # 初始学习率 1e-4
-            created_optimizer = True
-        else:
-            created_optimizer = False  # 用户传入优化器,不自动覆盖学习率
+            optimizer = optim.AdamW(self.parameters(), lr=lr)  # 学习率由 lr 参数指定,默认 1e-4
 
         # 混合精度缩放器
         scaler = amp.GradScaler(enabled=mixed_precision)
@@ -398,15 +395,7 @@ class MoEModel(nn.Module):
             processed_batches += 1
 
             # ---------- 学习率调度(仅当使用默认优化器且未传入自定义调度函数时)----------
-            if created_optimizer and lr_schedule is None:
-                if processed_batches <= 8000:
-                    new_lr = 1e-4
-                else:
-                    new_lr = 6e-6
-                # 为所有参数组统一设置学习率
-                for param_group in optimizer.param_groups:
-                    param_group["lr"] = new_lr
-            elif lr_schedule is not None:
+            if lr_schedule is not None:
                 # 调用用户自定义的调度函数
                 lr_schedule(processed_batches, optimizer)
 