refactor(trainer): 优化两阶段训练器代码结构和注释格式

This commit is contained in:
songsenand 2026-04-08 06:37:47 +08:00
parent 5dda0e6f85
commit 7ac44a2731
3 changed files with 369 additions and 408 deletions

View File

@ -23,13 +23,7 @@ class BigExpert(InputMethodEngine):
if compile: if compile:
self.forward = torch.compile( self.forward = torch.compile(
self.forward, self.forward,
# mode="reduce-overhead", mode="reduce-overhead",
fullgraph=False, fullgraph=False,
dynamic=False, dynamic=False,
options={
"epilogue_fusion": True,
"max_autotune": True,
"triton.cudagraphs": True,
"reorder_for_compute_comm_overlap": False,
},
) )

View File

@ -82,15 +82,15 @@ class InputMethodEngine(nn.Module):
if compile: if compile:
self.forward = torch.compile( self.forward = torch.compile(
self.forward, self.forward,
# mode="reduce-overhead", mode="reduce-overhead",
fullgraph=False, fullgraph=False,
dynamic=False, dynamic=False,
options={ # options={
"epilogue_fusion": True, # "epilogue_fusion": True,
"max_autotune": True, # "max_autotune": True,
"triton.cudagraphs": True, # "triton.cudagraphs": True,
"reorder_for_compute_comm_overlap": False, # "reorder_for_compute_comm_overlap": False,
}, # },
) )
def forward( def forward(

View File

@ -45,7 +45,7 @@ class Trainer:
- CrossEntropyLoss损失函数支持weight和label_smoothing - CrossEntropyLoss损失函数支持weight和label_smoothing
- Rich终端美化输出 - Rich终端美化输出
""" """
training_status_data: List[Dict[str, Any]] training_status_data: List[Dict[str, Any]]
def __init__( def __init__(
@ -696,25 +696,27 @@ def load_expanded_model(
) -> nn.Module: ) -> nn.Module:
""" """
加载预训练基础模型并创建扩容后的新模型冻结匹配的层 加载预训练基础模型并创建扩容后的新模型冻结匹配的层
Args: Args:
base_model_path: 预训练基础模型检查点路径 base_model_path: 预训练基础模型检查点路径
new_model_spec: 新模型规格格式 "module:ClassName" "new_model:NewModel" new_model_spec: 新模型规格格式 "module:ClassName" "new_model:NewModel"
device: 设备 device: 设备
**model_kwargs: 传递给新模型构造函数的参数 **model_kwargs: 传递给新模型构造函数的参数
Returns: Returns:
扩容后的新模型匹配的层已冻结 扩容后的新模型匹配的层已冻结
""" """
import importlib import importlib
import sys import sys
# 解析新模型规格 # 解析新模型规格
if ":" not in new_model_spec: if ":" not in new_model_spec:
raise ValueError(f"Invalid model spec format: {new_model_spec}. Expected format: 'module:ClassName'") raise ValueError(
f"Invalid model spec format: {new_model_spec}. Expected format: 'module:ClassName'"
)
module_name, class_name = new_model_spec.split(":", 1) module_name, class_name = new_model_spec.split(":", 1)
# 导入模块(支持任意路径) # 导入模块(支持任意路径)
module = None module = None
try: try:
@ -726,6 +728,7 @@ def load_expanded_model(
# 将模块名转换为可能的文件路径 # 将模块名转换为可能的文件路径
module_path = module_name.replace(".", "/") + ".py" module_path = module_name.replace(".", "/") + ".py"
import importlib.util import importlib.util
spec = importlib.util.spec_from_file_location(module_name, module_path) spec = importlib.util.spec_from_file_location(module_name, module_path)
if spec is None or spec.loader is None: if spec is None or spec.loader is None:
raise ImportError(f"Cannot find module or loader: {module_name}") raise ImportError(f"Cannot find module or loader: {module_name}")
@ -734,43 +737,47 @@ def load_expanded_model(
except Exception as e: except Exception as e:
# 尝试在当前目录下查找 # 尝试在当前目录下查找
import os import os
if os.path.exists(module_name + ".py"): if os.path.exists(module_name + ".py"):
spec = importlib.util.spec_from_file_location(module_name, module_name + ".py") spec = importlib.util.spec_from_file_location(
module_name, module_name + ".py"
)
if spec is None or spec.loader is None: if spec is None or spec.loader is None:
raise ImportError(f"Cannot load module from file: {module_name}.py") raise ImportError(f"Cannot load module from file: {module_name}.py")
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore spec.loader.exec_module(module) # type: ignore
else: else:
raise ImportError(f"Failed to import module '{module_name}': {e}") raise ImportError(f"Failed to import module '{module_name}': {e}")
if module is None: if module is None:
raise ImportError(f"Module '{module_name}' could not be imported") raise ImportError(f"Module '{module_name}' could not be imported")
# 获取模型类 # 获取模型类
model_class = getattr(module, class_name) model_class = getattr(module, class_name)
# 检查模型类是否是 InputMethodEngine 的子类 # 检查模型类是否是 InputMethodEngine 的子类
from .model import InputMethodEngine from .model import InputMethodEngine
if not issubclass(model_class, InputMethodEngine): if not issubclass(model_class, InputMethodEngine):
raise TypeError( raise TypeError(
f"Model class {class_name} must be a subclass of InputMethodEngine. " f"Model class {class_name} must be a subclass of InputMethodEngine. "
f"Got {model_class.__name__} instead." f"Got {model_class.__name__} instead."
) )
# 创建新模型 # 创建新模型
new_model = model_class(**model_kwargs) new_model = model_class(**model_kwargs)
new_model.to(device) new_model.to(device)
# 加载预训练权重 # 加载预训练权重
checkpoint = torch.load(base_model_path, map_location=device) checkpoint = torch.load(base_model_path, map_location=device)
if "model_state_dict" in checkpoint: if "model_state_dict" in checkpoint:
pretrained_state_dict = checkpoint["model_state_dict"] pretrained_state_dict = checkpoint["model_state_dict"]
else: else:
pretrained_state_dict = checkpoint pretrained_state_dict = checkpoint
# 获取新模型的状态字典 # 获取新模型的状态字典
new_state_dict = new_model.state_dict() new_state_dict = new_model.state_dict()
# 冻结匹配的层 # 冻结匹配的层
frozen_layers = [] frozen_layers = []
for key in new_state_dict.keys(): for key in new_state_dict.keys():
@ -778,354 +785,23 @@ def load_expanded_model(
if new_state_dict[key].shape == pretrained_state_dict[key].shape: if new_state_dict[key].shape == pretrained_state_dict[key].shape:
new_state_dict[key] = pretrained_state_dict[key].to(device) new_state_dict[key] = pretrained_state_dict[key].to(device)
frozen_layers.append(key) frozen_layers.append(key)
# 加载更新后的状态字典 # 加载更新后的状态字典
new_model.load_state_dict(new_state_dict) new_model.load_state_dict(new_state_dict)
# 设置参数 requires_grad # 设置参数 requires_grad
for name, param in new_model.named_parameters(): for name, param in new_model.named_parameters():
if name in frozen_layers: if name in frozen_layers:
param.requires_grad = False param.requires_grad = False
logger.info(f"Loaded expanded model with {len(frozen_layers)} frozen layers") logger.info(f"Loaded expanded model with {len(frozen_layers)} frozen layers")
logger.info(f"Frozen layers: {frozen_layers[:10]}{'...' if len(frozen_layers) > 10 else ''}") logger.info(
f"Frozen layers: {frozen_layers[:10]}{'...' if len(frozen_layers) > 10 else ''}"
)
return new_model return new_model
class TwoStageTrainer(Trainer):
"""
两阶段训练器先冻结匹配层训练然后全量微调
"""
def __init__(
self,
model: nn.Module,
train_dataloader: DataLoader,
eval_dataloader: DataLoader,
total_steps: int,
output_dir: str = "./output",
num_epochs: int = 10,
learning_rate: float = 1e-4,
min_learning_rate: float = 1e-6,
weight_decay: float = 0.1,
warmup_ratio: float = 0.1,
label_smoothing: float = 0.15,
loss_weight: Optional[torch.Tensor] = None,
grad_accum_steps: int = 1,
clip_grad_norm: float = 1.0,
eval_frequency: int = 500,
save_frequency: int = 10000,
mixed_precision: bool = True,
device: Optional[torch.device] = None,
status_file: str = "training_status.json",
use_tensorboard: bool = True,
# 两阶段训练特有参数
frozen_patience: int = 10,
frozen_lr: Optional[float] = None,
full_lr: Optional[float] = None,
frozen_scheduler: str = "cosine",
full_scheduler: str = "cosine",
):
"""
初始化两阶段训练器
Args:
frozen_patience: 冻结阶段验证损失连续不下降的epoch数触发切换到全量微调
frozen_lr: 冻结阶段学习率如果为None则使用learning_rate
full_lr: 全量微调阶段学习率如果为None则使用learning_rate
frozen_scheduler: 冻结阶段学习率调度器类型"cosine""plateau"
full_scheduler: 全量微调阶段学习率调度器类型"cosine""plateau"
"""
super().__init__(
model=model,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
total_steps=total_steps,
output_dir=output_dir,
num_epochs=num_epochs,
learning_rate=learning_rate,
min_learning_rate=min_learning_rate,
weight_decay=weight_decay,
warmup_ratio=warmup_ratio,
label_smoothing=label_smoothing,
loss_weight=loss_weight,
grad_accum_steps=grad_accum_steps,
clip_grad_norm=clip_grad_norm,
eval_frequency=eval_frequency,
save_frequency=save_frequency,
mixed_precision=mixed_precision,
device=device,
status_file=status_file,
use_tensorboard=use_tensorboard,
)
# 两阶段训练参数
self.frozen_patience = frozen_patience
self.frozen_lr = frozen_lr if frozen_lr is not None else learning_rate
self.full_lr = full_lr if full_lr is not None else learning_rate
self.frozen_scheduler = frozen_scheduler
self.full_scheduler = full_scheduler
# 训练状态
self.current_stage = "frozen" # "frozen" 或 "full"
self.frozen_best_loss = float("inf")
self.frozen_patience_counter = 0
logger.info(f"TwoStageTrainer initialized with frozen_patience={frozen_patience}")
logger.info(f"Stage: {self.current_stage}, Frozen LR: {self.frozen_lr:.2e}, Full LR: {self.full_lr:.2e}")
# 覆盖父类的学习率调度器为冻结阶段调度器
self.lr_scheduler = self._create_stage_lr_scheduler("frozen")
def _create_stage_lr_scheduler(self, stage: str) -> Callable[[int], float]:
"""创建阶段特定的学习率调度函数"""
if stage == "frozen":
base_lr = self.frozen_lr
scheduler_type = self.frozen_scheduler
else:
base_lr = self.full_lr
scheduler_type = self.full_scheduler
# 捕获局部变量以避免闭包中的self引用问题
warmup_steps = self.warmup_steps
total_steps = self.total_steps
min_learning_rate = self.min_learning_rate
def lr_scheduler(step: int) -> float:
if step < warmup_steps:
# 线性预热
return base_lr * (step / warmup_steps)
else:
if scheduler_type == "cosine":
# 余弦退火
progress = (step - warmup_steps) / (
total_steps - warmup_steps
)
cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
decayed_lr = (
min_learning_rate
+ (base_lr - min_learning_rate) * cosine_decay
)
return decayed_lr
elif scheduler_type == "plateau":
# 保持恒定学习率plateau调度需要在训练循环中实现
return base_lr
else:
raise ValueError(f"Unknown scheduler type: {scheduler_type}")
return lr_scheduler
def _switch_to_full_stage(self):
"""切换到全量微调阶段"""
if self.current_stage == "full":
return
logger.info("Switching to full fine-tuning stage")
self.current_stage = "full"
# 解冻所有参数
for param in self.model.parameters():
param.requires_grad = True
# 更新学习率调度器
self.learning_rate = self.full_lr
self.lr_scheduler = self._create_stage_lr_scheduler("full")
# 重置优化器
self.optimizer = optim.AdamW(
self.model.parameters(),
lr=self.full_lr,
weight_decay=self.weight_decay,
betas=(0.9, 0.999),
eps=1e-8,
)
# 重置训练状态
self.frozen_best_loss = float("inf")
self.frozen_patience_counter = 0
logger.info(f"All layers unfrozen, using full LR: {self.full_lr:.2e}")
def _update_stage_after_eval(self, eval_loss: float):
"""根据评估结果更新训练阶段"""
if self.current_stage == "frozen":
# 检查是否应该切换到全量微调
if eval_loss < self.frozen_best_loss:
self.frozen_best_loss = eval_loss
self.frozen_patience_counter = 0
logger.info(f"Frozen stage new best loss: {eval_loss:.4f}")
else:
self.frozen_patience_counter += 1
logger.info(f"Frozen stage patience counter: {self.frozen_patience_counter}/{self.frozen_patience}")
# 如果达到耐心值,切换到全量微调
if self.frozen_patience_counter >= self.frozen_patience:
self._switch_to_full_stage()
def train(
self, resume_from: Optional[str] = None, reset_training_state: bool = False
):
"""
两阶段训练循环
Args:
resume_from: 从哪个检查点恢复训练可选
reset_training_state: 是否重置训练状态只加载模型权重从头开始训练
"""
# 如果提供了检查点,则恢复训练
if resume_from is not None:
self.load_checkpoint(resume_from, reset_training_state=reset_training_state)
# 打印训练信息
self._print_training_info()
# 初始化训练状态
global_step = self.current_step
accumulated_loss = 0.0
accumulated_accuracy = 0.0
accumulation_counter = 0
# 创建进度条
with self._create_progress_bar() as progress:
epoch_task = progress.add_task(
f"[cyan]Epoch {self.current_epoch + 1}/{self.num_epochs} (Stage: {self.current_stage})",
total=self.total_steps,
)
# 训练循环
for epoch in range(self.current_epoch, self.num_epochs):
self.current_epoch = epoch
for batch_idx, batch in enumerate(self.train_dataloader):
# 更新学习率
current_lr = self._update_learning_rate()
# 训练步骤
loss, metrics = self.train_step(batch)
# 累积指标
accumulated_loss += loss
accumulated_accuracy += metrics.get("accuracy", 0.0)
accumulation_counter += 1
# 梯度累积每grad_accum_steps步更新一次参数
if (global_step + 1) % self.grad_accum_steps == 0:
# 梯度裁剪
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(
self.model.parameters(), self.clip_grad_norm
)
# 更新参数
self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.zero_grad()
# 更新进度条
progress.update(
epoch_task,
advance=1,
description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs} (Stage: {self.current_stage}) | "
f"Step {global_step}/{self.total_steps} | "
f"Loss: {loss:.4f} | "
f"LR: {current_lr:.2e}",
)
# 定期评估和记录
if (global_step + 1) % self.eval_frequency == 0:
# 计算平均指标
avg_loss = accumulated_loss / accumulation_counter
avg_accuracy = accumulated_accuracy / accumulation_counter
# 评估模型
eval_metrics = self.evaluate()
# 准备日志指标
log_metrics = {
"train/loss": avg_loss,
"train/accuracy": avg_accuracy,
"train/learning_rate": current_lr,
"train/stage": 0.0 if self.current_stage == "frozen" else 1.0,
}
if eval_metrics:
log_metrics.update(
{
"eval/loss": eval_metrics["eval_loss"],
"eval/accuracy": eval_metrics["eval_accuracy"],
}
)
# 更新最佳模型(全局)
if eval_metrics["eval_loss"] < self.best_eval_loss:
self.best_eval_loss = eval_metrics["eval_loss"]
# 只保存best_model不创建额外的checkpoint文件
self.save_checkpoint("best_model.pt", is_best=True)
# 更新训练阶段
self._update_stage_after_eval(eval_metrics["eval_loss"])
# 记录到TensorBoard
self._log_to_tensorboard(log_metrics, global_step)
# 打印日志
log_text = (
f"[Epoch {epoch + 1}/{self.num_epochs}] "
f"[Stage: {self.current_stage}] "
f"[Step {global_step}/{self.total_steps}] "
f"Train Loss: {avg_loss:.4f} | "
f"Train Acc: {avg_accuracy:.4f} | "
f"LR: {current_lr:.2e}"
)
if eval_metrics:
log_text += (
f" | Eval Loss: {eval_metrics['eval_loss']:.4f} | "
f"Eval Acc: {eval_metrics['eval_accuracy']:.4f}"
)
progress.console.log(log_text)
# 重置累积指标
accumulated_loss = 0.0
accumulated_accuracy = 0.0
accumulation_counter = 0
# 定期保存检查点(覆盖之前的定期检查点)
if (global_step + 1) % self.save_frequency == 0:
self.save_checkpoint("latest_checkpoint.pt", is_periodic=True)
# 更新步数
global_step += 1
self.current_step = global_step
# 检查是否达到总步数
if global_step >= self.total_steps:
progress.update(epoch_task, completed=self.total_steps)
break
# 重置进度条
progress.reset(epoch_task)
# 每个epoch结束后保存检查点
self.save_checkpoint(f"epoch_{epoch + 1}.pt")
# 检查是否达到总步数
if global_step >= self.total_steps:
break
# 训练完成
logger.info("Two-stage training completed!")
# 保存最终模型
self.save_checkpoint("final_model.pt")
# 关闭TensorBoard写入器
if self.writer is not None:
self.writer.close()
def worker_init_fn(worker_id: int) -> None: def worker_init_fn(worker_id: int) -> None:
""" """
初始化每个DataLoader worker的随机种子确保可复现性 初始化每个DataLoader worker的随机种子确保可复现性
@ -1537,7 +1213,10 @@ def expand_and_train(
..., "--base-model-path", help="预训练基础模型检查点路径" ..., "--base-model-path", help="预训练基础模型检查点路径"
), ),
new_model_spec: str = typer.Option( new_model_spec: str = typer.Option(
..., "--new-model-spec", "-m", help="新模型规格,格式:模块名:类名,如 'model:InputMethodEngine'。支持任意路径,自定义模型类必须是 InputMethodEngine 的子类" ...,
"--new-model-spec",
"-m",
help="新模型规格,格式:模块名:类名,如 'model:InputMethodEngine'。支持任意路径,自定义模型类必须是 InputMethodEngine 的子类",
), ),
vocab_size: int = typer.Option(10019, "--vocab-size", help="词汇表大小"), vocab_size: int = typer.Option(10019, "--vocab-size", help="词汇表大小"),
pinyin_vocab_size: int = typer.Option( pinyin_vocab_size: int = typer.Option(
@ -1555,19 +1234,19 @@ def expand_and_train(
use_pinyin: bool = typer.Option(False, "--use-pinyin", help="是否使用拼音特征"), use_pinyin: bool = typer.Option(False, "--use-pinyin", help="是否使用拼音特征"),
# 两阶段训练参数 # 两阶段训练参数
frozen_patience: int = typer.Option( frozen_patience: int = typer.Option(
10, "--frozen-patience", help="冻结阶段验证损失连续不下降的epoch数触发切换到全量微调" 10,
), "--frozen-patience",
frozen_lr: float = typer.Option( help="冻结阶段验证损失连续不下降的epoch数触发切换到全量微调",
1e-3, "--frozen-lr", help="冻结阶段学习率"
),
full_lr: float = typer.Option(
1e-4, "--full-lr", help="全量微调阶段学习率"
), ),
frozen_lr: float = typer.Option(1e-3, "--frozen-lr", help="冻结阶段学习率"),
full_lr: float = typer.Option(1e-4, "--full-lr", help="全量微调阶段学习率"),
frozen_scheduler: str = typer.Option( frozen_scheduler: str = typer.Option(
"cosine", "--frozen-scheduler", help="冻结阶段学习率调度器类型cosine或plateau" "cosine", "--frozen-scheduler", help="冻结阶段学习率调度器类型cosine或plateau"
), ),
full_scheduler: str = typer.Option( full_scheduler: str = typer.Option(
"cosine", "--full-scheduler", help="全量微调阶段学习率调度器类型cosine或plateau" "cosine",
"--full-scheduler",
help="全量微调阶段学习率调度器类型cosine或plateau",
), ),
# 训练参数 # 训练参数
batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"), batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"),
@ -1607,9 +1286,6 @@ def expand_and_train(
help="是否开启 torch.compile 优化(需 PyTorch 2.0+", help="是否开启 torch.compile 优化(需 PyTorch 2.0+",
), ),
): ):
"""
模型扩容两阶段训练先冻结匹配层训练然后全量微调
"""
torch.multiprocessing.set_sharing_strategy("file_system") torch.multiprocessing.set_sharing_strategy("file_system")
# 启用 TensorFloat32 加速矩阵乘法 (解决 UserWarning 并提升性能) # 启用 TensorFloat32 加速矩阵乘法 (解决 UserWarning 并提升性能)
@ -1625,7 +1301,9 @@ def expand_and_train(
# 打印配置信息 # 打印配置信息
console.print( console.print(
Panel.fit("[bold cyan]模型扩容两阶段训练配置[/bold cyan]", border_style="cyan") Panel.fit(
"[bold cyan]模型扩容第一阶段训练配置[/bold cyan]", border_style="cyan"
)
) )
config_table = Table(show_header=True, header_style="bold magenta") config_table = Table(show_header=True, header_style="bold magenta")
@ -1652,13 +1330,8 @@ def expand_and_train(
config_table.add_row("模型", "使用拼音", str(use_pinyin)) config_table.add_row("模型", "使用拼音", str(use_pinyin))
config_table.add_row("模型", "编译优化", str(compile)) config_table.add_row("模型", "编译优化", str(compile))
config_table.add_row("两阶段训练", "冻结阶段耐心值", str(frozen_patience))
config_table.add_row("两阶段训练", "冻结阶段学习率", f"{frozen_lr:.2e}")
config_table.add_row("两阶段训练", "全量阶段学习率", f"{full_lr:.2e}")
config_table.add_row("两阶段训练", "冻结阶段调度器", frozen_scheduler)
config_table.add_row("两阶段训练", "全量阶段调度器", full_scheduler)
config_table.add_row("训练", "训练轮数", str(num_epochs)) config_table.add_row("训练", "训练轮数", str(num_epochs))
config_table.add_row("训练", "学习率", f"{learning_rate:.2e}")
config_table.add_row("训练", "最小学习率", f"{min_learning_rate:.2e}") config_table.add_row("训练", "最小学习率", f"{min_learning_rate:.2e}")
config_table.add_row("训练", "权重衰减", str(weight_decay)) config_table.add_row("训练", "权重衰减", str(weight_decay))
config_table.add_row("训练", "热身比例", str(warmup_ratio)) config_table.add_row("训练", "热身比例", str(warmup_ratio))
@ -1689,14 +1362,10 @@ def expand_and_train(
"num_experts": num_experts, "num_experts": num_experts,
"max_seq_len": max_seq_len, "max_seq_len": max_seq_len,
"use_pinyin": use_pinyin, "use_pinyin": use_pinyin,
"frozen_patience": frozen_patience,
"frozen_lr": frozen_lr,
"full_lr": full_lr,
"frozen_scheduler": frozen_scheduler,
"full_scheduler": full_scheduler,
"batch_size": batch_size, "batch_size": batch_size,
"num_workers": num_workers, "num_workers": num_workers,
"num_epochs": num_epochs, "num_epochs": num_epochs,
"learning_rate": learning_rate,
"min_learning_rate": min_learning_rate, "min_learning_rate": min_learning_rate,
"weight_decay": weight_decay, "weight_decay": weight_decay,
"warmup_ratio": warmup_ratio, "warmup_ratio": warmup_ratio,
@ -1765,7 +1434,7 @@ def expand_and_train(
# 创建扩容模型 # 创建扩容模型
console.print("[bold cyan]正在创建扩容模型...[/bold cyan]") console.print("[bold cyan]正在创建扩容模型...[/bold cyan]")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_kwargs = { model_kwargs = {
"vocab_size": vocab_size, "vocab_size": vocab_size,
"pinyin_vocab_size": pinyin_vocab_size, "pinyin_vocab_size": pinyin_vocab_size,
@ -1777,7 +1446,7 @@ def expand_and_train(
"max_seq_len": max_seq_len, "max_seq_len": max_seq_len,
"compile": compile, "compile": compile,
} }
model = load_expanded_model( model = load_expanded_model(
base_model_path=base_model_path, base_model_path=base_model_path,
new_model_spec=new_model_spec, new_model_spec=new_model_spec,
@ -1788,24 +1457,24 @@ def expand_and_train(
console.print( console.print(
f"[green]✓ 扩容模型创建完成,参数量: {sum(p.numel() for p in model.parameters()):,}[/green]" f"[green]✓ 扩容模型创建完成,参数量: {sum(p.numel() for p in model.parameters()):,}[/green]"
) )
# 统计冻结参数比例 # 统计冻结参数比例
total_params = sum(p.numel() for p in model.parameters()) total_params = sum(p.numel() for p in model.parameters())
frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad) frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
console.print( console.print(
f"[green]✓ 冻结参数: {frozen_params:,}/{total_params:,} ({frozen_params/total_params*100:.1f}%)[/green]" f"[green]✓ 冻结参数: {frozen_params:,}/{total_params:,} ({frozen_params / total_params * 100:.1f}%)[/green]"
) )
# 创建两阶段训练器 # 创建训练器(使用普通 Trainer只进行第一阶段冻结训练
console.print("[bold cyan]正在创建两阶段训练器...[/bold cyan]") console.print("[bold cyan]正在创建训练器...[/bold cyan]")
trainer = TwoStageTrainer( trainer = Trainer(
model=model, model=model,
train_dataloader=train_dataloader, train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader, eval_dataloader=eval_dataloader,
total_steps=int(max_iter_length * num_epochs / batch_size), total_steps=int(max_iter_length * num_epochs / batch_size),
output_dir=output_dir, output_dir=output_dir,
num_epochs=num_epochs, num_epochs=num_epochs,
learning_rate=frozen_lr, # 初始学习率会被阶段特定LR覆盖 learning_rate=learning_rate,
min_learning_rate=min_learning_rate, min_learning_rate=min_learning_rate,
weight_decay=weight_decay, weight_decay=weight_decay,
warmup_ratio=warmup_ratio, warmup_ratio=warmup_ratio,
@ -1817,18 +1486,12 @@ def expand_and_train(
mixed_precision=mixed_precision, mixed_precision=mixed_precision,
use_tensorboard=use_tensorboard, use_tensorboard=use_tensorboard,
status_file="training_status.json", status_file="training_status.json",
# 两阶段训练特有参数
frozen_patience=frozen_patience,
frozen_lr=frozen_lr,
full_lr=full_lr,
frozen_scheduler=frozen_scheduler,
full_scheduler=full_scheduler,
) )
console.print("[green]✓ 两阶段训练器创建完成[/green]") console.print("[green]✓ 训练器创建完成[/green]")
# 开始训练 # 开始训练
console.print("\n[bold cyan]开始阶段训练...[/bold cyan]") console.print("\n[bold cyan]开始扩容模型第一阶段训练...[/bold cyan]")
console.print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") console.print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
try: try:
trainer.train( trainer.train(
@ -1838,9 +1501,313 @@ def expand_and_train(
console.print("[bold green]训练被终止[/bold green]") console.print("[bold green]训练被终止[/bold green]")
trainer.save_checkpoint("interrupted_model.pt") trainer.save_checkpoint("interrupted_model.pt")
console.print("[bold green]✓ 两阶段训练完成![/bold green]") # 保存扩容信息供第二阶段使用
expansion_info = {
"stage1_checkpoint_path": str(output_path / "checkpoints" / "best_model.pt"),
"model_spec": new_model_spec,
"model_kwargs": model_kwargs,
"train_data_path": train_data_path,
"eval_data_path": eval_data_path,
"output_dir": output_dir,
"batch_size": batch_size,
"max_iter_length": max_iter_length,
"max_seq_len": max_seq_len,
"num_workers": num_workers,
}
expansion_info_file = output_path / "expansion_info.json"
with open(expansion_info_file, "w", encoding="utf-8") as f:
json.dump(expansion_info, f, indent=2, ensure_ascii=False)
logger.info(f"Expansion info saved to {expansion_info_file}")
console.print("[bold green]✓ 第一阶段训练完成![/bold green]")
console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
console.print(f"模型和日志保存在: {output_dir}") console.print(f"模型和日志保存在: {output_dir}")
console.print(f"[bold cyan]扩容信息已保存到: {expansion_info_file}[/bold cyan]")
console.print(
"[yellow]请手动检查模型后,使用 expand-finetune 命令进行第二阶段全量微调[/yellow]"
)
@app.command()
def expand_finetune(
expand_config: str = typer.Option(
...,
"--expand-config",
"-c",
help="新模型类规格,格式:模块名:类名,如 'big_expert:BigExpert'",
),
stage1_info: str = typer.Option(
..., "--stage1-info", "-i", help="第一阶段保存的 expansion_info.json 路径"
),
# 可选覆盖参数
checkpoint: Optional[str] = typer.Option(
None, "--checkpoint", help="第一阶段模型检查点路径(覆盖 JSON 文件中的路径)"
),
output_dir: Optional[str] = typer.Option(
None, "--output-dir", "-o", help="输出目录(覆盖 JSON 文件中的目录)"
),
train_data_path: Optional[str] = typer.Option(
None, "--train-data-path", "-t", help="训练数据路径(覆盖 JSON 文件)"
),
eval_data_path: Optional[str] = typer.Option(
None, "--eval-data-path", "-e", help="评估数据路径(覆盖 JSON 文件)"
),
batch_size: Optional[int] = typer.Option(
None, "--batch-size", "-b", help="批次大小(覆盖 JSON 文件)"
),
num_epochs: Optional[int] = typer.Option(
None, "--num-epochs", help="训练轮数(覆盖 JSON 文件)"
),
learning_rate: Optional[float] = typer.Option(
None, "--learning-rate", "-lr", help="学习率"
),
min_learning_rate: Optional[float] = typer.Option(
None, "--min-learning-rate", help="最小学习率"
),
weight_decay: Optional[float] = typer.Option(
None, "--weight-decay", help="权重衰减"
),
warmup_ratio: Optional[float] = typer.Option(
None, "--warmup-ratio", help="热身步数比例"
),
label_smoothing: Optional[float] = typer.Option(
None, "--label-smoothing", help="标签平滑参数"
),
grad_accum_steps: Optional[int] = typer.Option(
None, "--grad-accum-steps", help="梯度累积步数"
),
clip_grad_norm: Optional[float] = typer.Option(
None, "--clip-grad-norm", help="梯度裁剪范数"
),
eval_frequency: Optional[int] = typer.Option(
None, "--eval-frequency", help="评估频率"
),
save_frequency: Optional[int] = typer.Option(
None, "--save-frequency", help="保存频率"
),
max_iter_length: Optional[int] = typer.Option(
None, "--max-iter-length", help="数据集大小(覆盖 JSON 文件)"
),
max_seq_len: Optional[int] = typer.Option(
None, "--max-seq-len", help="最大序列长度(覆盖 JSON 文件)"
),
num_workers: Optional[int] = typer.Option(
None, "--num-workers", help="数据加载worker数量"
),
mixed_precision: bool = typer.Option(
True, "--mixed-precision/--no-mixed-precision", help="是否使用混合精度训练"
),
use_tensorboard: bool = typer.Option(
True, "--tensorboard/--no-tensorboard", help="是否使用TensorBoard"
),
resume_from: Optional[str] = typer.Option(
None, "--resume-from", help="从检查点恢复训练"
),
reset_training_state: bool = typer.Option(
False, "--reset-training-state", help="重置训练状态"
),
seed: int = typer.Option(42, "--seed", help="随机种子"),
compile: Optional[bool] = typer.Option(
None, "--compile/--no-compile", help="是否开启 torch.compile 优化"
),
):
"""
模型扩容第二阶段训练读取第一阶段的 expansion_info.json加载扩容模型进行全量微调
命令行参数优先级高于 JSON 文件中的配置
"""
torch.multiprocessing.set_sharing_strategy("file_system")
if torch.cuda.is_available():
torch.set_float32_matmul_precision("high")
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
console = Console()
# 加载第一阶段信息
stage1_info_path = Path(stage1_info)
if not stage1_info_path.exists():
console.print(
f"[bold red]错误: 找不到第一阶段信息文件 {stage1_info}[/bold red]"
)
raise typer.Exit(1)
with open(stage1_info_path, "r", encoding="utf-8") as f:
info = json.load(f)
# 命令行参数优先级高于 JSON 文件
final_checkpoint = checkpoint or info["stage1_checkpoint_path"]
final_output_dir = output_dir or info["output_dir"]
final_train_data_path = train_data_path or info["train_data_path"]
final_eval_data_path = eval_data_path or info["eval_data_path"]
final_batch_size = batch_size if batch_size is not None else info["batch_size"]
final_num_epochs = (
num_epochs if num_epochs is not None else info.get("num_epochs", 10)
)
final_max_iter_length = (
max_iter_length if max_iter_length is not None else info["max_iter_length"]
)
final_max_seq_len = max_seq_len if max_seq_len is not None else info["max_seq_len"]
final_num_workers = (
num_workers if num_workers is not None else info.get("num_workers", 2)
)
# 训练参数(有默认值,不覆盖则使用默认)
final_learning_rate = learning_rate if learning_rate is not None else 1e-4
final_min_learning_rate = (
min_learning_rate if min_learning_rate is not None else 1e-9
)
final_weight_decay = weight_decay if weight_decay is not None else 0.1
final_warmup_ratio = warmup_ratio if warmup_ratio is not None else 0.1
final_label_smoothing = label_smoothing if label_smoothing is not None else 0.15
final_grad_accum_steps = grad_accum_steps if grad_accum_steps is not None else 1
final_clip_grad_norm = clip_grad_norm if clip_grad_norm is not None else 1.0
final_eval_frequency = eval_frequency if eval_frequency is not None else 500
final_save_frequency = save_frequency if save_frequency is not None else 1000
# 模型参数从 JSON 获取
model_kwargs = info["model_kwargs"]
if compile is not None:
model_kwargs["compile"] = compile
console.print(
Panel.fit(
"[bold cyan]模型扩容第二阶段训练配置[/bold cyan]", border_style="cyan"
)
)
config_table = Table(show_header=True, header_style="bold magenta")
config_table.add_column("Category", style="cyan")
config_table.add_column("Parameter", style="green")
config_table.add_column("Value", style="yellow")
config_table.add_row("数据", "第一阶段信息文件", str(stage1_info_path))
config_table.add_row("数据", "训练数据路径", final_train_data_path)
config_table.add_row("数据", "评估数据路径", final_eval_data_path)
config_table.add_row("数据", "输出目录", final_output_dir)
config_table.add_row("数据", "批次大小", str(final_batch_size))
config_table.add_row("数据", "Worker数量", str(final_num_workers))
config_table.add_row("模型", "新模型规格", expand_config)
config_table.add_row("模型", "检查点路径", final_checkpoint)
for k, v in model_kwargs.items():
config_table.add_row("模型", k, str(v))
config_table.add_row("训练", "训练轮数", str(final_num_epochs))
config_table.add_row("训练", "学习率", f"{final_learning_rate:.2e}")
config_table.add_row("训练", "最小学习率", f"{final_min_learning_rate:.2e}")
config_table.add_row("训练", "权重衰减", str(final_weight_decay))
config_table.add_row("训练", "热身比例", str(final_warmup_ratio))
config_table.add_row("训练", "标签平滑", str(final_label_smoothing))
config_table.add_row("训练", "梯度累积", str(final_grad_accum_steps))
config_table.add_row("训练", "梯度裁剪", str(final_clip_grad_norm))
config_table.add_row("训练", "混合精度", str(mixed_precision))
console.print(config_table)
output_path = Path(final_output_dir)
output_path.mkdir(parents=True, exist_ok=True)
console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")
train_dataset = PinyinInputDataset(
data_path=final_train_data_path,
max_workers=-1,
max_iter_length=final_max_iter_length,
max_seq_length=final_max_seq_len,
text_field="text",
py_style_weight=(9, 2, 1),
shuffle_buffer_size=5000,
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
)
train_dataloader = create_dataloader(
dataset=train_dataset,
batch_size=final_batch_size,
num_workers=final_num_workers,
pin_memory=torch.cuda.is_available(),
max_iter_length=final_max_iter_length,
)
eval_dataset = PinyinInputDataset(
data_path=final_eval_data_path,
max_workers=-1,
max_iter_length=final_batch_size * 64,
max_seq_length=final_max_seq_len,
text_field="text",
py_style_weight=(9, 2, 1),
shuffle_buffer_size=50000,
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
)
eval_dataloader = create_dataloader(
dataset=eval_dataset,
batch_size=final_batch_size,
num_workers=1,
pin_memory=torch.cuda.is_available(),
max_iter_length=final_batch_size * 64,
)
console.print("[bold cyan]正在加载扩容模型(全量微调模式)...[/bold cyan]")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_expanded_model(
base_model_path=final_checkpoint,
new_model_spec=expand_config,
device=device,
**model_kwargs,
)
# 全量微调:解冻所有参数
for param in model.parameters():
param.requires_grad = True
console.print(
f"[green]✓ 模型加载完成,参数量: {sum(p.numel() for p in model.parameters()):,}[/green]"
)
console.print("[green]✓ 所有参数已解冻,进入全量微调模式[/green]")
console.print("[bold cyan]正在创建训练器...[/bold cyan]")
trainer = Trainer(
model=model,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
total_steps=int(final_max_iter_length * final_num_epochs / final_batch_size),
output_dir=final_output_dir,
num_epochs=final_num_epochs,
learning_rate=final_learning_rate,
min_learning_rate=final_min_learning_rate,
weight_decay=final_weight_decay,
warmup_ratio=final_warmup_ratio,
label_smoothing=final_label_smoothing,
grad_accum_steps=final_grad_accum_steps,
clip_grad_norm=final_clip_grad_norm,
eval_frequency=final_eval_frequency,
save_frequency=final_save_frequency,
mixed_precision=mixed_precision,
use_tensorboard=use_tensorboard,
status_file="training_status_finetune.json",
)
console.print("[green]✓ 训练器创建完成[/green]")
console.print("\n[bold cyan]开始第二阶段全量微调...[/bold cyan]")
console.print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
try:
trainer.train(
resume_from=resume_from, reset_training_state=reset_training_state
)
except KeyboardInterrupt:
console.print("[bold green]训练被终止[/bold green]")
trainer.save_checkpoint("interrupted_model.pt")
console.print("[bold green]✓ 第二阶段全量微调完成![/bold green]")
console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
console.print(f"模型和日志保存在: {final_output_dir}")
if __name__ == "__main__": if __name__ == "__main__":