diff --git a/src/trainer/model.py b/src/trainer/model.py
index 088d0fa..9f7b64e 100644
--- a/src/trainer/model.py
+++ b/src/trainer/model.py
@@ -351,12 +351,12 @@ class MoEModel(nn.Module):
         monitor: TrainingMonitor = None,
         criterion=nn.CrossEntropyLoss(),
         optimizer=None,
-        scheduler=None,
         num_epochs=1,
         eval_frequency=500,
         grad_accum_steps=1,
         clip_grad_norm=1.0,
         mixed_precision=False,
+        lr=1e-4,
         lr_schedule=None,  # New: optional custom learning-rate schedule callback
     ):
         """
@@ -379,10 +379,7 @@ class MoEModel(nn.Module):
         # Default optimizer
         if optimizer is None:
-            optimizer = optim.AdamW(self.parameters(), lr=1e-4)  # initial learning rate 1e-4
-            created_optimizer = True
-        else:
-            created_optimizer = False  # optimizer passed in by the caller; do not override its learning rate
+            optimizer = optim.AdamW(self.parameters(), lr=lr)  # initial learning rate from the lr argument (default 1e-4)

         # Mixed-precision gradient scaler
         scaler = amp.GradScaler(enabled=mixed_precision)

@@ -398,15 +395,7 @@ class MoEModel(nn.Module):
                 processed_batches += 1

-                # ---------- Learning-rate scheduling (only with the default optimizer and no custom schedule) ----------
-                if created_optimizer and lr_schedule is None:
-                    if processed_batches <= 8000:
-                        new_lr = 1e-4
-                    else:
-                        new_lr = 6e-6
-                    # Set the same learning rate for all parameter groups
-                    for param_group in optimizer.param_groups:
-                        param_group["lr"] = new_lr
-                elif lr_schedule is not None:
+                # ---------- Learning-rate scheduling (only when a custom schedule is provided) ----------
+                if lr_schedule is not None:
                     # Invoke the user-supplied schedule callback
                     lr_schedule(processed_batches, optimizer)
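
The new `lr_schedule` hook is called every step with `(processed_batches, optimizer)`, so callers who depended on the removed built-in step schedule can reproduce it themselves. A minimal sketch, assuming only the callback signature visible in the diff; the training method's name and remaining arguments are not shown here and are hypothetical:

```python
# Reproduces the removed hard-coded schedule via the new lr_schedule hook:
# 1e-4 for the first 8000 batches, 6e-6 afterwards.
def step_schedule(processed_batches, optimizer):
    new_lr = 1e-4 if processed_batches <= 8000 else 6e-6
    # Apply the same learning rate to every parameter group.
    for param_group in optimizer.param_groups:
        param_group["lr"] = new_lr

# Hypothetical call site (the training method's name is not shown in this diff):
# model.train_model(..., lr=1e-4, lr_schedule=step_schedule)
```

Passing `lr_schedule` overrides nothing else: the `lr` argument still sets the optimizer's starting rate when no optimizer is supplied, and the callback is then free to mutate `optimizer.param_groups` on every batch.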