refactor(trainer): 优化两阶段训练器代码结构和注释格式
This commit is contained in:
parent
5dda0e6f85
commit
7ac44a2731
|
|
@ -23,13 +23,7 @@ class BigExpert(InputMethodEngine):
|
|||
if compile:
|
||||
self.forward = torch.compile(
|
||||
self.forward,
|
||||
# mode="reduce-overhead",
|
||||
mode="reduce-overhead",
|
||||
fullgraph=False,
|
||||
dynamic=False,
|
||||
options={
|
||||
"epilogue_fusion": True,
|
||||
"max_autotune": True,
|
||||
"triton.cudagraphs": True,
|
||||
"reorder_for_compute_comm_overlap": False,
|
||||
},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -82,15 +82,15 @@ class InputMethodEngine(nn.Module):
|
|||
if compile:
|
||||
self.forward = torch.compile(
|
||||
self.forward,
|
||||
# mode="reduce-overhead",
|
||||
mode="reduce-overhead",
|
||||
fullgraph=False,
|
||||
dynamic=False,
|
||||
options={
|
||||
"epilogue_fusion": True,
|
||||
"max_autotune": True,
|
||||
"triton.cudagraphs": True,
|
||||
"reorder_for_compute_comm_overlap": False,
|
||||
},
|
||||
# options={
|
||||
# "epilogue_fusion": True,
|
||||
# "max_autotune": True,
|
||||
# "triton.cudagraphs": True,
|
||||
# "reorder_for_compute_comm_overlap": False,
|
||||
# },
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ class Trainer:
|
|||
- CrossEntropyLoss损失函数(支持weight和label_smoothing)
|
||||
- Rich终端美化输出
|
||||
"""
|
||||
|
||||
|
||||
training_status_data: List[Dict[str, Any]]
|
||||
|
||||
def __init__(
|
||||
|
|
@ -696,25 +696,27 @@ def load_expanded_model(
|
|||
) -> nn.Module:
|
||||
"""
|
||||
加载预训练基础模型并创建扩容后的新模型,冻结匹配的层。
|
||||
|
||||
|
||||
Args:
|
||||
base_model_path: 预训练基础模型检查点路径
|
||||
new_model_spec: 新模型规格,格式 "module:ClassName",如 "new_model:NewModel"
|
||||
device: 设备
|
||||
**model_kwargs: 传递给新模型构造函数的参数
|
||||
|
||||
|
||||
Returns:
|
||||
扩容后的新模型,匹配的层已冻结
|
||||
"""
|
||||
import importlib
|
||||
import sys
|
||||
|
||||
|
||||
# 解析新模型规格
|
||||
if ":" not in new_model_spec:
|
||||
raise ValueError(f"Invalid model spec format: {new_model_spec}. Expected format: 'module:ClassName'")
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid model spec format: {new_model_spec}. Expected format: 'module:ClassName'"
|
||||
)
|
||||
|
||||
module_name, class_name = new_model_spec.split(":", 1)
|
||||
|
||||
|
||||
# 导入模块(支持任意路径)
|
||||
module = None
|
||||
try:
|
||||
|
|
@ -726,6 +728,7 @@ def load_expanded_model(
|
|||
# 将模块名转换为可能的文件路径
|
||||
module_path = module_name.replace(".", "/") + ".py"
|
||||
import importlib.util
|
||||
|
||||
spec = importlib.util.spec_from_file_location(module_name, module_path)
|
||||
if spec is None or spec.loader is None:
|
||||
raise ImportError(f"Cannot find module or loader: {module_name}")
|
||||
|
|
@ -734,43 +737,47 @@ def load_expanded_model(
|
|||
except Exception as e:
|
||||
# 尝试在当前目录下查找
|
||||
import os
|
||||
|
||||
if os.path.exists(module_name + ".py"):
|
||||
spec = importlib.util.spec_from_file_location(module_name, module_name + ".py")
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
module_name, module_name + ".py"
|
||||
)
|
||||
if spec is None or spec.loader is None:
|
||||
raise ImportError(f"Cannot load module from file: {module_name}.py")
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module) # type: ignore
|
||||
else:
|
||||
raise ImportError(f"Failed to import module '{module_name}': {e}")
|
||||
|
||||
|
||||
if module is None:
|
||||
raise ImportError(f"Module '{module_name}' could not be imported")
|
||||
|
||||
|
||||
# 获取模型类
|
||||
model_class = getattr(module, class_name)
|
||||
|
||||
|
||||
# 检查模型类是否是 InputMethodEngine 的子类
|
||||
from .model import InputMethodEngine
|
||||
|
||||
if not issubclass(model_class, InputMethodEngine):
|
||||
raise TypeError(
|
||||
f"Model class {class_name} must be a subclass of InputMethodEngine. "
|
||||
f"Got {model_class.__name__} instead."
|
||||
)
|
||||
|
||||
|
||||
# 创建新模型
|
||||
new_model = model_class(**model_kwargs)
|
||||
new_model.to(device)
|
||||
|
||||
|
||||
# 加载预训练权重
|
||||
checkpoint = torch.load(base_model_path, map_location=device)
|
||||
if "model_state_dict" in checkpoint:
|
||||
pretrained_state_dict = checkpoint["model_state_dict"]
|
||||
else:
|
||||
pretrained_state_dict = checkpoint
|
||||
|
||||
|
||||
# 获取新模型的状态字典
|
||||
new_state_dict = new_model.state_dict()
|
||||
|
||||
|
||||
# 冻结匹配的层
|
||||
frozen_layers = []
|
||||
for key in new_state_dict.keys():
|
||||
|
|
@ -778,354 +785,23 @@ def load_expanded_model(
|
|||
if new_state_dict[key].shape == pretrained_state_dict[key].shape:
|
||||
new_state_dict[key] = pretrained_state_dict[key].to(device)
|
||||
frozen_layers.append(key)
|
||||
|
||||
|
||||
# 加载更新后的状态字典
|
||||
new_model.load_state_dict(new_state_dict)
|
||||
|
||||
|
||||
# 设置参数 requires_grad
|
||||
for name, param in new_model.named_parameters():
|
||||
if name in frozen_layers:
|
||||
param.requires_grad = False
|
||||
|
||||
|
||||
logger.info(f"Loaded expanded model with {len(frozen_layers)} frozen layers")
|
||||
logger.info(f"Frozen layers: {frozen_layers[:10]}{'...' if len(frozen_layers) > 10 else ''}")
|
||||
|
||||
logger.info(
|
||||
f"Frozen layers: {frozen_layers[:10]}{'...' if len(frozen_layers) > 10 else ''}"
|
||||
)
|
||||
|
||||
return new_model
|
||||
|
||||
|
||||
class TwoStageTrainer(Trainer):
    """Two-stage trainer: train with the matched layers frozen, then fine-tune everything.

    Training starts in the ``"frozen"`` stage. After every evaluation the
    evaluation loss is compared with the best loss seen in the frozen stage;
    once it has failed to improve ``frozen_patience`` evaluations in a row the
    trainer unfreezes every parameter, rebuilds the optimizer and continues in
    the ``"full"`` stage.
    """

    # Scheduler names accepted for both the frozen and the full stage.
    _SCHEDULER_TYPES = ("cosine", "plateau")

    def __init__(
        self,
        model: nn.Module,
        train_dataloader: DataLoader,
        eval_dataloader: DataLoader,
        total_steps: int,
        output_dir: str = "./output",
        num_epochs: int = 10,
        learning_rate: float = 1e-4,
        min_learning_rate: float = 1e-6,
        weight_decay: float = 0.1,
        warmup_ratio: float = 0.1,
        label_smoothing: float = 0.15,
        loss_weight: Optional[torch.Tensor] = None,
        grad_accum_steps: int = 1,
        clip_grad_norm: float = 1.0,
        eval_frequency: int = 500,
        save_frequency: int = 10000,
        mixed_precision: bool = True,
        device: Optional[torch.device] = None,
        status_file: str = "training_status.json",
        use_tensorboard: bool = True,
        # Parameters specific to two-stage training.
        frozen_patience: int = 10,
        frozen_lr: Optional[float] = None,
        full_lr: Optional[float] = None,
        frozen_scheduler: str = "cosine",
        full_scheduler: str = "cosine",
    ):
        """Initialise the two-stage trainer.

        All parameters up to ``use_tensorboard`` are forwarded unchanged to
        ``Trainer.__init__``.

        Args:
            frozen_patience: Number of consecutive evaluations (one every
                ``eval_frequency`` steps) without a new best evaluation loss
                after which the trainer switches to full fine-tuning.
            frozen_lr: Learning rate for the frozen stage; falls back to
                ``learning_rate`` when ``None``.
            full_lr: Learning rate for the full fine-tuning stage; falls back
                to ``learning_rate`` when ``None``.
            frozen_scheduler: Schedule for the frozen stage, ``"cosine"`` or
                ``"plateau"``.
            full_scheduler: Schedule for the full stage, ``"cosine"`` or
                ``"plateau"``.

        Raises:
            ValueError: If either scheduler name is not recognised.
        """
        # Fail fast on a misspelled scheduler instead of crashing in the
        # middle of training, once warmup ends or the stage switches.
        for scheduler_type in (frozen_scheduler, full_scheduler):
            if scheduler_type not in self._SCHEDULER_TYPES:
                raise ValueError(f"Unknown scheduler type: {scheduler_type}")

        super().__init__(
            model=model,
            train_dataloader=train_dataloader,
            eval_dataloader=eval_dataloader,
            total_steps=total_steps,
            output_dir=output_dir,
            num_epochs=num_epochs,
            learning_rate=learning_rate,
            min_learning_rate=min_learning_rate,
            weight_decay=weight_decay,
            warmup_ratio=warmup_ratio,
            label_smoothing=label_smoothing,
            loss_weight=loss_weight,
            grad_accum_steps=grad_accum_steps,
            clip_grad_norm=clip_grad_norm,
            eval_frequency=eval_frequency,
            save_frequency=save_frequency,
            mixed_precision=mixed_precision,
            device=device,
            status_file=status_file,
            use_tensorboard=use_tensorboard,
        )

        # Two-stage configuration.
        self.frozen_patience = frozen_patience
        self.frozen_lr = frozen_lr if frozen_lr is not None else learning_rate
        self.full_lr = full_lr if full_lr is not None else learning_rate
        self.frozen_scheduler = frozen_scheduler
        self.full_scheduler = full_scheduler

        # Stage-tracking state.
        self.current_stage = "frozen"  # either "frozen" or "full"
        self.frozen_best_loss = float("inf")
        self.frozen_patience_counter = 0

        logger.info(f"TwoStageTrainer initialized with frozen_patience={frozen_patience}")
        logger.info(
            f"Stage: {self.current_stage}, Frozen LR: {self.frozen_lr:.2e}, Full LR: {self.full_lr:.2e}"
        )

        # Replace the schedule installed by the parent with the frozen-stage one.
        self.lr_scheduler = self._create_stage_lr_scheduler("frozen")

    def _create_stage_lr_scheduler(self, stage: str) -> Callable[[int], float]:
        """Build the step -> learning-rate function for ``stage``.

        Args:
            stage: ``"frozen"`` selects the frozen-stage rate and schedule;
                any other value selects the full-stage ones.

        Returns:
            A function mapping a step index to a learning rate: linear warmup
            for the first ``self.warmup_steps`` steps, then either a cosine
            decay towards ``self.min_learning_rate`` or a constant rate
            (``"plateau"``).

        Raises:
            ValueError: If the scheduler name of the stage is not recognised.
        """
        if stage == "frozen":
            base_lr, scheduler_type = self.frozen_lr, self.frozen_scheduler
        else:
            base_lr, scheduler_type = self.full_lr, self.full_scheduler

        if scheduler_type not in self._SCHEDULER_TYPES:
            raise ValueError(f"Unknown scheduler type: {scheduler_type}")

        # Copy the values into locals so the closure does not read `self`.
        warmup_steps = self.warmup_steps
        min_learning_rate = self.min_learning_rate
        # At least one decay step, so total_steps == warmup_steps cannot
        # trigger a division by zero.
        decay_steps = max(1, self.total_steps - warmup_steps)

        def lr_scheduler(step: int) -> float:
            if step < warmup_steps:
                # Linear warmup.
                return base_lr * (step / warmup_steps)

            if scheduler_type == "plateau":
                # Constant rate; actual plateau handling would have to live in
                # the training loop.
                return base_lr

            # Cosine annealing. Progress is clamped so that steps beyond
            # total_steps stay at the minimum instead of rising again.
            progress = min(1.0, (step - warmup_steps) / decay_steps)
            cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
            return min_learning_rate + (base_lr - min_learning_rate) * cosine_decay

        return lr_scheduler

    def _switch_to_full_stage(self):
        """Switch from the frozen stage to full fine-tuning (no-op if already there)."""
        if self.current_stage == "full":
            return

        logger.info("Switching to full fine-tuning stage")
        self.current_stage = "full"

        # Unfreeze every parameter.
        for param in self.model.parameters():
            param.requires_grad = True

        # Install the full-stage learning rate and schedule.
        self.learning_rate = self.full_lr
        self.lr_scheduler = self._create_stage_lr_scheduler("full")

        # Rebuild the optimizer over all parameters; this discards the
        # optimizer state accumulated during the frozen stage.
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=self.full_lr,
            weight_decay=self.weight_decay,
            betas=(0.9, 0.999),
            eps=1e-8,
        )

        # Reset the patience bookkeeping.
        self.frozen_best_loss = float("inf")
        self.frozen_patience_counter = 0

        logger.info(f"All layers unfrozen, using full LR: {self.full_lr:.2e}")

    def _update_stage_after_eval(self, eval_loss: float):
        """Update the patience counter from ``eval_loss`` and switch stage when exhausted.

        Does nothing once the trainer is in the ``"full"`` stage.
        """
        if self.current_stage != "frozen":
            return

        if eval_loss < self.frozen_best_loss:
            self.frozen_best_loss = eval_loss
            self.frozen_patience_counter = 0
            logger.info(f"Frozen stage new best loss: {eval_loss:.4f}")
            return

        self.frozen_patience_counter += 1
        logger.info(
            f"Frozen stage patience counter: {self.frozen_patience_counter}/{self.frozen_patience}"
        )

        # Patience exhausted: move on to full fine-tuning.
        if self.frozen_patience_counter >= self.frozen_patience:
            self._switch_to_full_stage()

    def train(
        self, resume_from: Optional[str] = None, reset_training_state: bool = False
    ):
        """Run the two-stage training loop.

        Args:
            resume_from: Optional checkpoint to load before training.
            reset_training_state: Forwarded to ``load_checkpoint``; per the
                original note, only the model weights are loaded and training
                restarts from scratch.
        """
        # NOTE(review): the source this block was recovered from had lost its
        # indentation; the nesting below was reconstructed from the statement
        # order and the original comments. Confirm against the repository.
        if resume_from is not None:
            self.load_checkpoint(resume_from, reset_training_state=reset_training_state)

        self._print_training_info()

        global_step = self.current_step
        accumulated_loss = 0.0
        accumulated_accuracy = 0.0
        accumulation_counter = 0

        with self._create_progress_bar() as progress:
            epoch_task = progress.add_task(
                f"[cyan]Epoch {self.current_epoch + 1}/{self.num_epochs} (Stage: {self.current_stage})",
                total=self.total_steps,
            )

            for epoch in range(self.current_epoch, self.num_epochs):
                self.current_epoch = epoch

                for batch in self.train_dataloader:
                    current_lr = self._update_learning_rate()

                    loss, metrics = self.train_step(batch)

                    # Running sums for the averages reported at evaluation time.
                    accumulated_loss += loss
                    accumulated_accuracy += metrics.get("accuracy", 0.0)
                    accumulation_counter += 1

                    # Gradient accumulation: apply an optimizer step only every
                    # grad_accum_steps steps.
                    if (global_step + 1) % self.grad_accum_steps == 0:
                        # Unscale before clipping so the norm is measured on
                        # the real gradients.
                        self.scaler.unscale_(self.optimizer)
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.clip_grad_norm
                        )

                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                        self.optimizer.zero_grad()

                    progress.update(
                        epoch_task,
                        advance=1,
                        description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs} (Stage: {self.current_stage}) | "
                        f"Step {global_step}/{self.total_steps} | "
                        f"Loss: {loss:.4f} | "
                        f"LR: {current_lr:.2e}",
                    )

                    # Periodic evaluation and logging.
                    if (global_step + 1) % self.eval_frequency == 0:
                        avg_loss = accumulated_loss / accumulation_counter
                        avg_accuracy = accumulated_accuracy / accumulation_counter

                        eval_metrics = self.evaluate()

                        log_metrics = {
                            "train/loss": avg_loss,
                            "train/accuracy": avg_accuracy,
                            "train/learning_rate": current_lr,
                            "train/stage": 0.0 if self.current_stage == "frozen" else 1.0,
                        }

                        if eval_metrics:
                            log_metrics.update(
                                {
                                    "eval/loss": eval_metrics["eval_loss"],
                                    "eval/accuracy": eval_metrics["eval_accuracy"],
                                }
                            )

                            # Track the globally best model; only best_model.pt
                            # is written, no extra checkpoint file.
                            if eval_metrics["eval_loss"] < self.best_eval_loss:
                                self.best_eval_loss = eval_metrics["eval_loss"]
                                self.save_checkpoint("best_model.pt", is_best=True)

                            # Kept inside the guard: an empty evaluation result
                            # has no "eval_loss" to compare against.
                            self._update_stage_after_eval(eval_metrics["eval_loss"])

                        self._log_to_tensorboard(log_metrics, global_step)

                        log_text = (
                            f"[Epoch {epoch + 1}/{self.num_epochs}] "
                            f"[Stage: {self.current_stage}] "
                            f"[Step {global_step}/{self.total_steps}] "
                            f"Train Loss: {avg_loss:.4f} | "
                            f"Train Acc: {avg_accuracy:.4f} | "
                            f"LR: {current_lr:.2e}"
                        )

                        if eval_metrics:
                            log_text += (
                                f" | Eval Loss: {eval_metrics['eval_loss']:.4f} | "
                                f"Eval Acc: {eval_metrics['eval_accuracy']:.4f}"
                            )

                        progress.console.log(log_text)

                        # Start a fresh averaging window.
                        accumulated_loss = 0.0
                        accumulated_accuracy = 0.0
                        accumulation_counter = 0

                    # Periodic checkpoint; overwrites the previous periodic one.
                    if (global_step + 1) % self.save_frequency == 0:
                        self.save_checkpoint("latest_checkpoint.pt", is_periodic=True)

                    global_step += 1
                    self.current_step = global_step

                    # Stop once the step budget is used up.
                    if global_step >= self.total_steps:
                        progress.update(epoch_task, completed=self.total_steps)
                        break

                progress.reset(epoch_task)

                # Checkpoint at the end of every epoch.
                self.save_checkpoint(f"epoch_{epoch + 1}.pt")

                if global_step >= self.total_steps:
                    break

        logger.info("Two-stage training completed!")

        self.save_checkpoint("final_model.pt")

        # Close the TensorBoard writer if one was created.
        if self.writer is not None:
            self.writer.close()
|
||||
|
||||
|
||||
def worker_init_fn(worker_id: int) -> None:
|
||||
"""
|
||||
初始化每个DataLoader worker的随机种子,确保可复现性
|
||||
|
|
@ -1537,7 +1213,10 @@ def expand_and_train(
|
|||
..., "--base-model-path", help="预训练基础模型检查点路径"
|
||||
),
|
||||
new_model_spec: str = typer.Option(
|
||||
..., "--new-model-spec", "-m", help="新模型规格,格式:模块名:类名,如 'model:InputMethodEngine'。支持任意路径,自定义模型类必须是 InputMethodEngine 的子类"
|
||||
...,
|
||||
"--new-model-spec",
|
||||
"-m",
|
||||
help="新模型规格,格式:模块名:类名,如 'model:InputMethodEngine'。支持任意路径,自定义模型类必须是 InputMethodEngine 的子类",
|
||||
),
|
||||
vocab_size: int = typer.Option(10019, "--vocab-size", help="词汇表大小"),
|
||||
pinyin_vocab_size: int = typer.Option(
|
||||
|
|
@ -1555,19 +1234,19 @@ def expand_and_train(
|
|||
use_pinyin: bool = typer.Option(False, "--use-pinyin", help="是否使用拼音特征"),
|
||||
# 两阶段训练参数
|
||||
frozen_patience: int = typer.Option(
|
||||
10, "--frozen-patience", help="冻结阶段验证损失连续不下降的epoch数,触发切换到全量微调"
|
||||
),
|
||||
frozen_lr: float = typer.Option(
|
||||
1e-3, "--frozen-lr", help="冻结阶段学习率"
|
||||
),
|
||||
full_lr: float = typer.Option(
|
||||
1e-4, "--full-lr", help="全量微调阶段学习率"
|
||||
10,
|
||||
"--frozen-patience",
|
||||
help="冻结阶段验证损失连续不下降的epoch数,触发切换到全量微调",
|
||||
),
|
||||
frozen_lr: float = typer.Option(1e-3, "--frozen-lr", help="冻结阶段学习率"),
|
||||
full_lr: float = typer.Option(1e-4, "--full-lr", help="全量微调阶段学习率"),
|
||||
frozen_scheduler: str = typer.Option(
|
||||
"cosine", "--frozen-scheduler", help="冻结阶段学习率调度器类型:cosine或plateau"
|
||||
),
|
||||
full_scheduler: str = typer.Option(
|
||||
"cosine", "--full-scheduler", help="全量微调阶段学习率调度器类型:cosine或plateau"
|
||||
"cosine",
|
||||
"--full-scheduler",
|
||||
help="全量微调阶段学习率调度器类型:cosine或plateau",
|
||||
),
|
||||
# 训练参数
|
||||
batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"),
|
||||
|
|
@ -1607,9 +1286,6 @@ def expand_and_train(
|
|||
help="是否开启 torch.compile 优化(需 PyTorch 2.0+)",
|
||||
),
|
||||
):
|
||||
"""
|
||||
模型扩容两阶段训练:先冻结匹配层训练,然后全量微调
|
||||
"""
|
||||
torch.multiprocessing.set_sharing_strategy("file_system")
|
||||
|
||||
# 启用 TensorFloat32 加速矩阵乘法 (解决 UserWarning 并提升性能)
|
||||
|
|
@ -1625,7 +1301,9 @@ def expand_and_train(
|
|||
|
||||
# 打印配置信息
|
||||
console.print(
|
||||
Panel.fit("[bold cyan]模型扩容两阶段训练配置[/bold cyan]", border_style="cyan")
|
||||
Panel.fit(
|
||||
"[bold cyan]模型扩容第一阶段训练配置[/bold cyan]", border_style="cyan"
|
||||
)
|
||||
)
|
||||
|
||||
config_table = Table(show_header=True, header_style="bold magenta")
|
||||
|
|
@ -1652,13 +1330,8 @@ def expand_and_train(
|
|||
config_table.add_row("模型", "使用拼音", str(use_pinyin))
|
||||
config_table.add_row("模型", "编译优化", str(compile))
|
||||
|
||||
config_table.add_row("两阶段训练", "冻结阶段耐心值", str(frozen_patience))
|
||||
config_table.add_row("两阶段训练", "冻结阶段学习率", f"{frozen_lr:.2e}")
|
||||
config_table.add_row("两阶段训练", "全量阶段学习率", f"{full_lr:.2e}")
|
||||
config_table.add_row("两阶段训练", "冻结阶段调度器", frozen_scheduler)
|
||||
config_table.add_row("两阶段训练", "全量阶段调度器", full_scheduler)
|
||||
|
||||
config_table.add_row("训练", "训练轮数", str(num_epochs))
|
||||
config_table.add_row("训练", "学习率", f"{learning_rate:.2e}")
|
||||
config_table.add_row("训练", "最小学习率", f"{min_learning_rate:.2e}")
|
||||
config_table.add_row("训练", "权重衰减", str(weight_decay))
|
||||
config_table.add_row("训练", "热身比例", str(warmup_ratio))
|
||||
|
|
@ -1689,14 +1362,10 @@ def expand_and_train(
|
|||
"num_experts": num_experts,
|
||||
"max_seq_len": max_seq_len,
|
||||
"use_pinyin": use_pinyin,
|
||||
"frozen_patience": frozen_patience,
|
||||
"frozen_lr": frozen_lr,
|
||||
"full_lr": full_lr,
|
||||
"frozen_scheduler": frozen_scheduler,
|
||||
"full_scheduler": full_scheduler,
|
||||
"batch_size": batch_size,
|
||||
"num_workers": num_workers,
|
||||
"num_epochs": num_epochs,
|
||||
"learning_rate": learning_rate,
|
||||
"min_learning_rate": min_learning_rate,
|
||||
"weight_decay": weight_decay,
|
||||
"warmup_ratio": warmup_ratio,
|
||||
|
|
@ -1765,7 +1434,7 @@ def expand_and_train(
|
|||
# 创建扩容模型
|
||||
console.print("[bold cyan]正在创建扩容模型...[/bold cyan]")
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
model_kwargs = {
|
||||
"vocab_size": vocab_size,
|
||||
"pinyin_vocab_size": pinyin_vocab_size,
|
||||
|
|
@ -1777,7 +1446,7 @@ def expand_and_train(
|
|||
"max_seq_len": max_seq_len,
|
||||
"compile": compile,
|
||||
}
|
||||
|
||||
|
||||
model = load_expanded_model(
|
||||
base_model_path=base_model_path,
|
||||
new_model_spec=new_model_spec,
|
||||
|
|
@ -1788,24 +1457,24 @@ def expand_and_train(
|
|||
console.print(
|
||||
f"[green]✓ 扩容模型创建完成,参数量: {sum(p.numel() for p in model.parameters()):,}[/green]"
|
||||
)
|
||||
|
||||
|
||||
# 统计冻结参数比例
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
|
||||
console.print(
|
||||
f"[green]✓ 冻结参数: {frozen_params:,}/{total_params:,} ({frozen_params/total_params*100:.1f}%)[/green]"
|
||||
f"[green]✓ 冻结参数: {frozen_params:,}/{total_params:,} ({frozen_params / total_params * 100:.1f}%)[/green]"
|
||||
)
|
||||
|
||||
# 创建两阶段训练器
|
||||
console.print("[bold cyan]正在创建两阶段训练器...[/bold cyan]")
|
||||
trainer = TwoStageTrainer(
|
||||
# 创建训练器(使用普通 Trainer,只进行第一阶段冻结训练)
|
||||
console.print("[bold cyan]正在创建训练器...[/bold cyan]")
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
train_dataloader=train_dataloader,
|
||||
eval_dataloader=eval_dataloader,
|
||||
total_steps=int(max_iter_length * num_epochs / batch_size),
|
||||
output_dir=output_dir,
|
||||
num_epochs=num_epochs,
|
||||
learning_rate=frozen_lr, # 初始学习率(会被阶段特定LR覆盖)
|
||||
learning_rate=learning_rate,
|
||||
min_learning_rate=min_learning_rate,
|
||||
weight_decay=weight_decay,
|
||||
warmup_ratio=warmup_ratio,
|
||||
|
|
@ -1817,18 +1486,12 @@ def expand_and_train(
|
|||
mixed_precision=mixed_precision,
|
||||
use_tensorboard=use_tensorboard,
|
||||
status_file="training_status.json",
|
||||
# 两阶段训练特有参数
|
||||
frozen_patience=frozen_patience,
|
||||
frozen_lr=frozen_lr,
|
||||
full_lr=full_lr,
|
||||
frozen_scheduler=frozen_scheduler,
|
||||
full_scheduler=full_scheduler,
|
||||
)
|
||||
|
||||
console.print("[green]✓ 两阶段训练器创建完成[/green]")
|
||||
console.print("[green]✓ 训练器创建完成[/green]")
|
||||
|
||||
# 开始训练
|
||||
console.print("\n[bold cyan]开始两阶段训练...[/bold cyan]")
|
||||
console.print("\n[bold cyan]开始扩容模型第一阶段训练...[/bold cyan]")
|
||||
console.print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
try:
|
||||
trainer.train(
|
||||
|
|
@ -1838,9 +1501,313 @@ def expand_and_train(
|
|||
console.print("[bold green]训练被终止[/bold green]")
|
||||
trainer.save_checkpoint("interrupted_model.pt")
|
||||
|
||||
console.print("[bold green]✓ 两阶段训练完成![/bold green]")
|
||||
# 保存扩容信息供第二阶段使用
|
||||
expansion_info = {
|
||||
"stage1_checkpoint_path": str(output_path / "checkpoints" / "best_model.pt"),
|
||||
"model_spec": new_model_spec,
|
||||
"model_kwargs": model_kwargs,
|
||||
"train_data_path": train_data_path,
|
||||
"eval_data_path": eval_data_path,
|
||||
"output_dir": output_dir,
|
||||
"batch_size": batch_size,
|
||||
"max_iter_length": max_iter_length,
|
||||
"max_seq_len": max_seq_len,
|
||||
"num_workers": num_workers,
|
||||
}
|
||||
|
||||
expansion_info_file = output_path / "expansion_info.json"
|
||||
with open(expansion_info_file, "w", encoding="utf-8") as f:
|
||||
json.dump(expansion_info, f, indent=2, ensure_ascii=False)
|
||||
|
||||
logger.info(f"Expansion info saved to {expansion_info_file}")
|
||||
|
||||
console.print("[bold green]✓ 第一阶段训练完成![/bold green]")
|
||||
console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
console.print(f"模型和日志保存在: {output_dir}")
|
||||
console.print(f"[bold cyan]扩容信息已保存到: {expansion_info_file}[/bold cyan]")
|
||||
console.print(
|
||||
"[yellow]请手动检查模型后,使用 expand-finetune 命令进行第二阶段全量微调[/yellow]"
|
||||
)
|
||||
|
||||
|
||||
@app.command()
def expand_finetune(
    expand_config: str = typer.Option(
        ...,
        "--expand-config",
        "-c",
        help="新模型类规格,格式:模块名:类名,如 'big_expert:BigExpert'",
    ),
    stage1_info: str = typer.Option(
        ..., "--stage1-info", "-i", help="第一阶段保存的 expansion_info.json 路径"
    ),
    # Optional overrides; None means "use the value from the JSON file / the default".
    checkpoint: Optional[str] = typer.Option(
        None, "--checkpoint", help="第一阶段模型检查点路径(覆盖 JSON 文件中的路径)"
    ),
    output_dir: Optional[str] = typer.Option(
        None, "--output-dir", "-o", help="输出目录(覆盖 JSON 文件中的目录)"
    ),
    train_data_path: Optional[str] = typer.Option(
        None, "--train-data-path", "-t", help="训练数据路径(覆盖 JSON 文件)"
    ),
    eval_data_path: Optional[str] = typer.Option(
        None, "--eval-data-path", "-e", help="评估数据路径(覆盖 JSON 文件)"
    ),
    batch_size: Optional[int] = typer.Option(
        None, "--batch-size", "-b", help="批次大小(覆盖 JSON 文件)"
    ),
    num_epochs: Optional[int] = typer.Option(
        None, "--num-epochs", help="训练轮数(覆盖 JSON 文件)"
    ),
    learning_rate: Optional[float] = typer.Option(
        None, "--learning-rate", "-lr", help="学习率"
    ),
    min_learning_rate: Optional[float] = typer.Option(
        None, "--min-learning-rate", help="最小学习率"
    ),
    weight_decay: Optional[float] = typer.Option(
        None, "--weight-decay", help="权重衰减"
    ),
    warmup_ratio: Optional[float] = typer.Option(
        None, "--warmup-ratio", help="热身步数比例"
    ),
    label_smoothing: Optional[float] = typer.Option(
        None, "--label-smoothing", help="标签平滑参数"
    ),
    grad_accum_steps: Optional[int] = typer.Option(
        None, "--grad-accum-steps", help="梯度累积步数"
    ),
    clip_grad_norm: Optional[float] = typer.Option(
        None, "--clip-grad-norm", help="梯度裁剪范数"
    ),
    eval_frequency: Optional[int] = typer.Option(
        None, "--eval-frequency", help="评估频率"
    ),
    save_frequency: Optional[int] = typer.Option(
        None, "--save-frequency", help="保存频率"
    ),
    max_iter_length: Optional[int] = typer.Option(
        None, "--max-iter-length", help="数据集大小(覆盖 JSON 文件)"
    ),
    max_seq_len: Optional[int] = typer.Option(
        None, "--max-seq-len", help="最大序列长度(覆盖 JSON 文件)"
    ),
    num_workers: Optional[int] = typer.Option(
        None, "--num-workers", help="数据加载worker数量"
    ),
    mixed_precision: bool = typer.Option(
        True, "--mixed-precision/--no-mixed-precision", help="是否使用混合精度训练"
    ),
    use_tensorboard: bool = typer.Option(
        True, "--tensorboard/--no-tensorboard", help="是否使用TensorBoard"
    ),
    resume_from: Optional[str] = typer.Option(
        None, "--resume-from", help="从检查点恢复训练"
    ),
    reset_training_state: bool = typer.Option(
        False, "--reset-training-state", help="重置训练状态"
    ),
    seed: int = typer.Option(42, "--seed", help="随机种子"),
    compile: Optional[bool] = typer.Option(
        None, "--compile/--no-compile", help="是否开启 torch.compile 优化"
    ),
):
    """
    模型扩容第二阶段训练:读取第一阶段的 expansion_info.json,加载扩容模型进行全量微调。
    命令行参数优先级高于 JSON 文件中的配置。
    """
    # NOTE: the docstring above is presumably rendered by Typer as this
    # command's --help text, so it stays in the CLI's own language.
    # In English: stage two of model expansion. Reads the expansion_info.json
    # written by stage one, loads the expanded model and fine-tunes all of its
    # parameters. Command-line options take precedence over the JSON values.
    # Exits with code 1 when the stage-one info file or the checkpoint is missing.
    torch.multiprocessing.set_sharing_strategy("file_system")

    if torch.cuda.is_available():
        torch.set_float32_matmul_precision("high")

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    console = Console()

    # Load the information saved by stage one.
    stage1_info_path = Path(stage1_info)
    if not stage1_info_path.exists():
        console.print(
            f"[bold red]错误: 找不到第一阶段信息文件 {stage1_info}[/bold red]"
        )
        raise typer.Exit(1)

    with open(stage1_info_path, "r", encoding="utf-8") as f:
        info = json.load(f)

    # Command-line values win over the JSON file. The info[...] lookups stay
    # inside conditional expressions so a key is only required when it is
    # not overridden.
    final_checkpoint = checkpoint or info["stage1_checkpoint_path"]
    final_output_dir = output_dir or info["output_dir"]
    final_train_data_path = train_data_path or info["train_data_path"]
    final_eval_data_path = eval_data_path or info["eval_data_path"]
    final_batch_size = batch_size if batch_size is not None else info["batch_size"]
    final_num_epochs = (
        num_epochs if num_epochs is not None else info.get("num_epochs", 10)
    )
    final_max_iter_length = (
        max_iter_length if max_iter_length is not None else info["max_iter_length"]
    )
    final_max_seq_len = max_seq_len if max_seq_len is not None else info["max_seq_len"]
    final_num_workers = (
        num_workers if num_workers is not None else info.get("num_workers", 2)
    )

    # Fail fast: without this check a missing checkpoint is only noticed after
    # both datasets and dataloaders have been built, as a raw torch.load error.
    if not Path(final_checkpoint).exists():
        console.print(
            f"[bold red]错误: 找不到第一阶段模型检查点 {final_checkpoint}[/bold red]"
        )
        raise typer.Exit(1)

    # Training hyper-parameters: not stored in the JSON file, so fall back to
    # fixed defaults when not given on the command line.
    final_learning_rate = learning_rate if learning_rate is not None else 1e-4
    final_min_learning_rate = (
        min_learning_rate if min_learning_rate is not None else 1e-9
    )
    final_weight_decay = weight_decay if weight_decay is not None else 0.1
    final_warmup_ratio = warmup_ratio if warmup_ratio is not None else 0.1
    final_label_smoothing = label_smoothing if label_smoothing is not None else 0.15
    final_grad_accum_steps = grad_accum_steps if grad_accum_steps is not None else 1
    final_clip_grad_norm = clip_grad_norm if clip_grad_norm is not None else 1.0
    final_eval_frequency = eval_frequency if eval_frequency is not None else 500
    final_save_frequency = save_frequency if save_frequency is not None else 1000

    # Model constructor arguments come from the JSON file; only `compile`
    # can be overridden here.
    model_kwargs = info["model_kwargs"]
    if compile is not None:
        model_kwargs["compile"] = compile

    console.print(
        Panel.fit(
            "[bold cyan]模型扩容第二阶段训练配置[/bold cyan]", border_style="cyan"
        )
    )

    config_table = Table(show_header=True, header_style="bold magenta")
    config_table.add_column("Category", style="cyan")
    config_table.add_column("Parameter", style="green")
    config_table.add_column("Value", style="yellow")

    config_table.add_row("数据", "第一阶段信息文件", str(stage1_info_path))
    config_table.add_row("数据", "训练数据路径", final_train_data_path)
    config_table.add_row("数据", "评估数据路径", final_eval_data_path)
    config_table.add_row("数据", "输出目录", final_output_dir)
    config_table.add_row("数据", "批次大小", str(final_batch_size))
    config_table.add_row("数据", "Worker数量", str(final_num_workers))

    config_table.add_row("模型", "新模型规格", expand_config)
    config_table.add_row("模型", "检查点路径", final_checkpoint)
    for k, v in model_kwargs.items():
        config_table.add_row("模型", k, str(v))

    config_table.add_row("训练", "训练轮数", str(final_num_epochs))
    config_table.add_row("训练", "学习率", f"{final_learning_rate:.2e}")
    config_table.add_row("训练", "最小学习率", f"{final_min_learning_rate:.2e}")
    config_table.add_row("训练", "权重衰减", str(final_weight_decay))
    config_table.add_row("训练", "热身比例", str(final_warmup_ratio))
    config_table.add_row("训练", "标签平滑", str(final_label_smoothing))
    config_table.add_row("训练", "梯度累积", str(final_grad_accum_steps))
    config_table.add_row("训练", "梯度裁剪", str(final_clip_grad_norm))
    config_table.add_row("训练", "混合精度", str(mixed_precision))

    console.print(config_table)

    output_path = Path(final_output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")

    train_dataset = PinyinInputDataset(
        data_path=final_train_data_path,
        max_workers=-1,
        max_iter_length=final_max_iter_length,
        max_seq_length=final_max_seq_len,
        text_field="text",
        py_style_weight=(9, 2, 1),
        shuffle_buffer_size=5000,
        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
    )

    train_dataloader = create_dataloader(
        dataset=train_dataset,
        batch_size=final_batch_size,
        num_workers=final_num_workers,
        pin_memory=torch.cuda.is_available(),
        max_iter_length=final_max_iter_length,
    )

    # The evaluation set is capped at 64 batches' worth of samples.
    eval_iter_length = final_batch_size * 64

    eval_dataset = PinyinInputDataset(
        data_path=final_eval_data_path,
        max_workers=-1,
        max_iter_length=eval_iter_length,
        max_seq_length=final_max_seq_len,
        text_field="text",
        py_style_weight=(9, 2, 1),
        shuffle_buffer_size=50000,
        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
    )

    eval_dataloader = create_dataloader(
        dataset=eval_dataset,
        batch_size=final_batch_size,
        num_workers=1,
        pin_memory=torch.cuda.is_available(),
        max_iter_length=eval_iter_length,
    )

    console.print("[bold cyan]正在加载扩容模型(全量微调模式)...[/bold cyan]")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = load_expanded_model(
        base_model_path=final_checkpoint,
        new_model_spec=expand_config,
        device=device,
        **model_kwargs,
    )

    # Full fine-tuning: every parameter is made trainable again after loading.
    for param in model.parameters():
        param.requires_grad = True

    console.print(
        f"[green]✓ 模型加载完成,参数量: {sum(p.numel() for p in model.parameters()):,}[/green]"
    )
    console.print("[green]✓ 所有参数已解冻,进入全量微调模式[/green]")

    console.print("[bold cyan]正在创建训练器...[/bold cyan]")
    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        total_steps=int(final_max_iter_length * final_num_epochs / final_batch_size),
        output_dir=final_output_dir,
        num_epochs=final_num_epochs,
        learning_rate=final_learning_rate,
        min_learning_rate=final_min_learning_rate,
        weight_decay=final_weight_decay,
        warmup_ratio=final_warmup_ratio,
        label_smoothing=final_label_smoothing,
        grad_accum_steps=final_grad_accum_steps,
        clip_grad_norm=final_clip_grad_norm,
        eval_frequency=final_eval_frequency,
        save_frequency=final_save_frequency,
        mixed_precision=mixed_precision,
        use_tensorboard=use_tensorboard,
        status_file="training_status_finetune.json",
    )

    console.print("[green]✓ 训练器创建完成[/green]")

    console.print("\n[bold cyan]开始第二阶段全量微调...[/bold cyan]")
    console.print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    try:
        trainer.train(
            resume_from=resume_from, reset_training_state=reset_training_state
        )
    except KeyboardInterrupt:
        # Ctrl-C: keep what has been trained so far.
        console.print("[bold green]训练被终止[/bold green]")
        trainer.save_checkpoint("interrupted_model.pt")

    # NOTE(review): this banner is also printed after an interruption, as in
    # the stage-one command; confirm whether that is intended.
    console.print("[bold green]✓ 第二阶段全量微调完成![/bold green]")
    console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    console.print(f"模型和日志保存在: {final_output_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in New Issue