From 7ac44a273171d767b101f5cc641997573ed3d35d Mon Sep 17 00:00:00 2001 From: songsenand Date: Wed, 8 Apr 2026 06:37:47 +0800 Subject: [PATCH] =?UTF-8?q?refactor(trainer):=20=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E4=B8=A4=E9=98=B6=E6=AE=B5=E8=AE=AD=E7=BB=83=E5=99=A8=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=E7=BB=93=E6=9E=84=E5=92=8C=E6=B3=A8=E9=87=8A=E6=A0=BC?= =?UTF-8?q?=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- big_expert.py | 8 +- src/model/model.py | 14 +- src/model/trainer.py | 755 +++++++++++++++++++++---------------------- 3 files changed, 369 insertions(+), 408 deletions(-) diff --git a/big_expert.py b/big_expert.py index 7837e93..021c2cd 100644 --- a/big_expert.py +++ b/big_expert.py @@ -23,13 +23,7 @@ class BigExpert(InputMethodEngine): if compile: self.forward = torch.compile( self.forward, - # mode="reduce-overhead", + mode="reduce-overhead", fullgraph=False, dynamic=False, - options={ - "epilogue_fusion": True, - "max_autotune": True, - "triton.cudagraphs": True, - "reorder_for_compute_comm_overlap": False, - }, ) diff --git a/src/model/model.py b/src/model/model.py index f834843..082ad32 100644 --- a/src/model/model.py +++ b/src/model/model.py @@ -82,15 +82,15 @@ class InputMethodEngine(nn.Module): if compile: self.forward = torch.compile( self.forward, - # mode="reduce-overhead", + mode="reduce-overhead", fullgraph=False, dynamic=False, - options={ - "epilogue_fusion": True, - "max_autotune": True, - "triton.cudagraphs": True, - "reorder_for_compute_comm_overlap": False, - }, + # options={ + # "epilogue_fusion": True, + # "max_autotune": True, + # "triton.cudagraphs": True, + # "reorder_for_compute_comm_overlap": False, + # }, ) def forward( diff --git a/src/model/trainer.py b/src/model/trainer.py index e5b6dab..7b384d6 100644 --- a/src/model/trainer.py +++ b/src/model/trainer.py @@ -45,7 +45,7 @@ class Trainer: - CrossEntropyLoss损失函数(支持weight和label_smoothing) - Rich终端美化输出 """ - + training_status_data: List[Dict[str, Any]] def __init__( @@ -696,25 +696,27 @@ def load_expanded_model( ) -> nn.Module: """ 加载预训练基础模型并创建扩容后的新模型,冻结匹配的层。 - + Args: base_model_path: 预训练基础模型检查点路径 new_model_spec: 新模型规格,格式 "module:ClassName",如 "new_model:NewModel" device: 设备 **model_kwargs: 传递给新模型构造函数的参数 - + Returns: 扩容后的新模型,匹配的层已冻结 """ import importlib import sys - + # 解析新模型规格 if ":" not in new_model_spec: - raise ValueError(f"Invalid model spec format: {new_model_spec}. Expected format: 'module:ClassName'") - + raise ValueError( + f"Invalid model spec format: {new_model_spec}. Expected format: 'module:ClassName'" + ) + module_name, class_name = new_model_spec.split(":", 1) - + # 导入模块(支持任意路径) module = None try: @@ -726,6 +728,7 @@ def load_expanded_model( # 将模块名转换为可能的文件路径 module_path = module_name.replace(".", "/") + ".py" import importlib.util + spec = importlib.util.spec_from_file_location(module_name, module_path) if spec is None or spec.loader is None: raise ImportError(f"Cannot find module or loader: {module_name}") @@ -734,43 +737,47 @@ def load_expanded_model( except Exception as e: # 尝试在当前目录下查找 import os + if os.path.exists(module_name + ".py"): - spec = importlib.util.spec_from_file_location(module_name, module_name + ".py") + spec = importlib.util.spec_from_file_location( + module_name, module_name + ".py" + ) if spec is None or spec.loader is None: raise ImportError(f"Cannot load module from file: {module_name}.py") module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) # type: ignore else: raise ImportError(f"Failed to import module '{module_name}': {e}") - + if module is None: raise ImportError(f"Module '{module_name}' could not be imported") - + # 获取模型类 model_class = getattr(module, class_name) - + # 检查模型类是否是 InputMethodEngine 的子类 from .model import InputMethodEngine + if not issubclass(model_class, InputMethodEngine): raise TypeError( f"Model class {class_name} must be a subclass of InputMethodEngine. " f"Got {model_class.__name__} instead." ) - + # 创建新模型 new_model = model_class(**model_kwargs) new_model.to(device) - + # 加载预训练权重 checkpoint = torch.load(base_model_path, map_location=device) if "model_state_dict" in checkpoint: pretrained_state_dict = checkpoint["model_state_dict"] else: pretrained_state_dict = checkpoint - + # 获取新模型的状态字典 new_state_dict = new_model.state_dict() - + # 冻结匹配的层 frozen_layers = [] for key in new_state_dict.keys(): @@ -778,354 +785,23 @@ def load_expanded_model( if new_state_dict[key].shape == pretrained_state_dict[key].shape: new_state_dict[key] = pretrained_state_dict[key].to(device) frozen_layers.append(key) - + # 加载更新后的状态字典 new_model.load_state_dict(new_state_dict) - + # 设置参数 requires_grad for name, param in new_model.named_parameters(): if name in frozen_layers: param.requires_grad = False - + logger.info(f"Loaded expanded model with {len(frozen_layers)} frozen layers") - logger.info(f"Frozen layers: {frozen_layers[:10]}{'...' if len(frozen_layers) > 10 else ''}") - + logger.info( + f"Frozen layers: {frozen_layers[:10]}{'...' if len(frozen_layers) > 10 else ''}" + ) + return new_model -class TwoStageTrainer(Trainer): - """ - 两阶段训练器:先冻结匹配层训练,然后全量微调。 - """ - - def __init__( - self, - model: nn.Module, - train_dataloader: DataLoader, - eval_dataloader: DataLoader, - total_steps: int, - output_dir: str = "./output", - num_epochs: int = 10, - learning_rate: float = 1e-4, - min_learning_rate: float = 1e-6, - weight_decay: float = 0.1, - warmup_ratio: float = 0.1, - label_smoothing: float = 0.15, - loss_weight: Optional[torch.Tensor] = None, - grad_accum_steps: int = 1, - clip_grad_norm: float = 1.0, - eval_frequency: int = 500, - save_frequency: int = 10000, - mixed_precision: bool = True, - device: Optional[torch.device] = None, - status_file: str = "training_status.json", - use_tensorboard: bool = True, - # 两阶段训练特有参数 - frozen_patience: int = 10, - frozen_lr: Optional[float] = None, - full_lr: Optional[float] = None, - frozen_scheduler: str = "cosine", - full_scheduler: str = "cosine", - ): - """ - 初始化两阶段训练器 - - Args: - frozen_patience: 冻结阶段验证损失连续不下降的epoch数,触发切换到全量微调 - frozen_lr: 冻结阶段学习率,如果为None则使用learning_rate - full_lr: 全量微调阶段学习率,如果为None则使用learning_rate - frozen_scheduler: 冻结阶段学习率调度器类型,"cosine"或"plateau" - full_scheduler: 全量微调阶段学习率调度器类型,"cosine"或"plateau" - """ - super().__init__( - model=model, - train_dataloader=train_dataloader, - eval_dataloader=eval_dataloader, - total_steps=total_steps, - output_dir=output_dir, - num_epochs=num_epochs, - learning_rate=learning_rate, - min_learning_rate=min_learning_rate, - weight_decay=weight_decay, - warmup_ratio=warmup_ratio, - label_smoothing=label_smoothing, - loss_weight=loss_weight, - grad_accum_steps=grad_accum_steps, - clip_grad_norm=clip_grad_norm, - eval_frequency=eval_frequency, - save_frequency=save_frequency, - mixed_precision=mixed_precision, - device=device, - status_file=status_file, - use_tensorboard=use_tensorboard, - ) - - # 两阶段训练参数 - self.frozen_patience = frozen_patience - self.frozen_lr = frozen_lr if frozen_lr is not None else learning_rate - self.full_lr = full_lr if full_lr is not None else learning_rate - self.frozen_scheduler = frozen_scheduler - self.full_scheduler = full_scheduler - - # 训练状态 - self.current_stage = "frozen" # "frozen" 或 "full" - self.frozen_best_loss = float("inf") - self.frozen_patience_counter = 0 - - logger.info(f"TwoStageTrainer initialized with frozen_patience={frozen_patience}") - logger.info(f"Stage: {self.current_stage}, Frozen LR: {self.frozen_lr:.2e}, Full LR: {self.full_lr:.2e}") - - # 覆盖父类的学习率调度器为冻结阶段调度器 - self.lr_scheduler = self._create_stage_lr_scheduler("frozen") - - def _create_stage_lr_scheduler(self, stage: str) -> Callable[[int], float]: - """创建阶段特定的学习率调度函数""" - if stage == "frozen": - base_lr = self.frozen_lr - scheduler_type = self.frozen_scheduler - else: - base_lr = self.full_lr - scheduler_type = self.full_scheduler - - # 捕获局部变量以避免闭包中的self引用问题 - warmup_steps = self.warmup_steps - total_steps = self.total_steps - min_learning_rate = self.min_learning_rate - - def lr_scheduler(step: int) -> float: - if step < warmup_steps: - # 线性预热 - return base_lr * (step / warmup_steps) - else: - if scheduler_type == "cosine": - # 余弦退火 - progress = (step - warmup_steps) / ( - total_steps - warmup_steps - ) - cosine_decay = 0.5 * (1 + math.cos(math.pi * progress)) - decayed_lr = ( - min_learning_rate - + (base_lr - min_learning_rate) * cosine_decay - ) - return decayed_lr - elif scheduler_type == "plateau": - # 保持恒定学习率(plateau调度需要在训练循环中实现) - return base_lr - else: - raise ValueError(f"Unknown scheduler type: {scheduler_type}") - - return lr_scheduler - - def _switch_to_full_stage(self): - """切换到全量微调阶段""" - if self.current_stage == "full": - return - - logger.info("Switching to full fine-tuning stage") - self.current_stage = "full" - - # 解冻所有参数 - for param in self.model.parameters(): - param.requires_grad = True - - # 更新学习率调度器 - self.learning_rate = self.full_lr - self.lr_scheduler = self._create_stage_lr_scheduler("full") - - # 重置优化器 - self.optimizer = optim.AdamW( - self.model.parameters(), - lr=self.full_lr, - weight_decay=self.weight_decay, - betas=(0.9, 0.999), - eps=1e-8, - ) - - # 重置训练状态 - self.frozen_best_loss = float("inf") - self.frozen_patience_counter = 0 - - logger.info(f"All layers unfrozen, using full LR: {self.full_lr:.2e}") - - def _update_stage_after_eval(self, eval_loss: float): - """根据评估结果更新训练阶段""" - if self.current_stage == "frozen": - # 检查是否应该切换到全量微调 - if eval_loss < self.frozen_best_loss: - self.frozen_best_loss = eval_loss - self.frozen_patience_counter = 0 - logger.info(f"Frozen stage new best loss: {eval_loss:.4f}") - else: - self.frozen_patience_counter += 1 - logger.info(f"Frozen stage patience counter: {self.frozen_patience_counter}/{self.frozen_patience}") - - # 如果达到耐心值,切换到全量微调 - if self.frozen_patience_counter >= self.frozen_patience: - self._switch_to_full_stage() - - def train( - self, resume_from: Optional[str] = None, reset_training_state: bool = False - ): - """ - 两阶段训练循环 - - Args: - resume_from: 从哪个检查点恢复训练(可选) - reset_training_state: 是否重置训练状态(只加载模型权重,从头开始训练) - """ - # 如果提供了检查点,则恢复训练 - if resume_from is not None: - self.load_checkpoint(resume_from, reset_training_state=reset_training_state) - - # 打印训练信息 - self._print_training_info() - - # 初始化训练状态 - global_step = self.current_step - accumulated_loss = 0.0 - accumulated_accuracy = 0.0 - accumulation_counter = 0 - - # 创建进度条 - with self._create_progress_bar() as progress: - epoch_task = progress.add_task( - f"[cyan]Epoch {self.current_epoch + 1}/{self.num_epochs} (Stage: {self.current_stage})", - total=self.total_steps, - ) - - # 训练循环 - for epoch in range(self.current_epoch, self.num_epochs): - self.current_epoch = epoch - - for batch_idx, batch in enumerate(self.train_dataloader): - # 更新学习率 - current_lr = self._update_learning_rate() - - # 训练步骤 - loss, metrics = self.train_step(batch) - - # 累积指标 - accumulated_loss += loss - accumulated_accuracy += metrics.get("accuracy", 0.0) - accumulation_counter += 1 - - # 梯度累积:每grad_accum_steps步更新一次参数 - if (global_step + 1) % self.grad_accum_steps == 0: - # 梯度裁剪 - self.scaler.unscale_(self.optimizer) - torch.nn.utils.clip_grad_norm_( - self.model.parameters(), self.clip_grad_norm - ) - - # 更新参数 - self.scaler.step(self.optimizer) - self.scaler.update() - self.optimizer.zero_grad() - - # 更新进度条 - progress.update( - epoch_task, - advance=1, - description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs} (Stage: {self.current_stage}) | " - f"Step {global_step}/{self.total_steps} | " - f"Loss: {loss:.4f} | " - f"LR: {current_lr:.2e}", - ) - - # 定期评估和记录 - if (global_step + 1) % self.eval_frequency == 0: - # 计算平均指标 - avg_loss = accumulated_loss / accumulation_counter - avg_accuracy = accumulated_accuracy / accumulation_counter - - # 评估模型 - eval_metrics = self.evaluate() - - # 准备日志指标 - log_metrics = { - "train/loss": avg_loss, - "train/accuracy": avg_accuracy, - "train/learning_rate": current_lr, - "train/stage": 0.0 if self.current_stage == "frozen" else 1.0, - } - - if eval_metrics: - log_metrics.update( - { - "eval/loss": eval_metrics["eval_loss"], - "eval/accuracy": eval_metrics["eval_accuracy"], - } - ) - - # 更新最佳模型(全局) - if eval_metrics["eval_loss"] < self.best_eval_loss: - self.best_eval_loss = eval_metrics["eval_loss"] - # 只保存best_model,不创建额外的checkpoint文件 - self.save_checkpoint("best_model.pt", is_best=True) - - # 更新训练阶段 - self._update_stage_after_eval(eval_metrics["eval_loss"]) - - # 记录到TensorBoard - self._log_to_tensorboard(log_metrics, global_step) - - # 打印日志 - log_text = ( - f"[Epoch {epoch + 1}/{self.num_epochs}] " - f"[Stage: {self.current_stage}] " - f"[Step {global_step}/{self.total_steps}] " - f"Train Loss: {avg_loss:.4f} | " - f"Train Acc: {avg_accuracy:.4f} | " - f"LR: {current_lr:.2e}" - ) - - if eval_metrics: - log_text += ( - f" | Eval Loss: {eval_metrics['eval_loss']:.4f} | " - f"Eval Acc: {eval_metrics['eval_accuracy']:.4f}" - ) - - progress.console.log(log_text) - - # 重置累积指标 - accumulated_loss = 0.0 - accumulated_accuracy = 0.0 - accumulation_counter = 0 - - # 定期保存检查点(覆盖之前的定期检查点) - if (global_step + 1) % self.save_frequency == 0: - self.save_checkpoint("latest_checkpoint.pt", is_periodic=True) - - # 更新步数 - global_step += 1 - self.current_step = global_step - - # 检查是否达到总步数 - if global_step >= self.total_steps: - progress.update(epoch_task, completed=self.total_steps) - break - - # 重置进度条 - progress.reset(epoch_task) - - # 每个epoch结束后保存检查点 - self.save_checkpoint(f"epoch_{epoch + 1}.pt") - - # 检查是否达到总步数 - if global_step >= self.total_steps: - break - - # 训练完成 - logger.info("Two-stage training completed!") - - # 保存最终模型 - self.save_checkpoint("final_model.pt") - - # 关闭TensorBoard写入器 - if self.writer is not None: - self.writer.close() - - def worker_init_fn(worker_id: int) -> None: """ 初始化每个DataLoader worker的随机种子,确保可复现性 @@ -1537,7 +1213,10 @@ def expand_and_train( ..., "--base-model-path", help="预训练基础模型检查点路径" ), new_model_spec: str = typer.Option( - ..., "--new-model-spec", "-m", help="新模型规格,格式:模块名:类名,如 'model:InputMethodEngine'。支持任意路径,自定义模型类必须是 InputMethodEngine 的子类" + ..., + "--new-model-spec", + "-m", + help="新模型规格,格式:模块名:类名,如 'model:InputMethodEngine'。支持任意路径,自定义模型类必须是 InputMethodEngine 的子类", ), vocab_size: int = typer.Option(10019, "--vocab-size", help="词汇表大小"), pinyin_vocab_size: int = typer.Option( @@ -1555,19 +1234,19 @@ def expand_and_train( use_pinyin: bool = typer.Option(False, "--use-pinyin", help="是否使用拼音特征"), # 两阶段训练参数 frozen_patience: int = typer.Option( - 10, "--frozen-patience", help="冻结阶段验证损失连续不下降的epoch数,触发切换到全量微调" - ), - frozen_lr: float = typer.Option( - 1e-3, "--frozen-lr", help="冻结阶段学习率" - ), - full_lr: float = typer.Option( - 1e-4, "--full-lr", help="全量微调阶段学习率" + 10, + "--frozen-patience", + help="冻结阶段验证损失连续不下降的epoch数,触发切换到全量微调", ), + frozen_lr: float = typer.Option(1e-3, "--frozen-lr", help="冻结阶段学习率"), + full_lr: float = typer.Option(1e-4, "--full-lr", help="全量微调阶段学习率"), frozen_scheduler: str = typer.Option( "cosine", "--frozen-scheduler", help="冻结阶段学习率调度器类型:cosine或plateau" ), full_scheduler: str = typer.Option( - "cosine", "--full-scheduler", help="全量微调阶段学习率调度器类型:cosine或plateau" + "cosine", + "--full-scheduler", + help="全量微调阶段学习率调度器类型:cosine或plateau", ), # 训练参数 batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"), @@ -1607,9 +1286,6 @@ def expand_and_train( help="是否开启 torch.compile 优化(需 PyTorch 2.0+)", ), ): - """ - 模型扩容两阶段训练:先冻结匹配层训练,然后全量微调 - """ torch.multiprocessing.set_sharing_strategy("file_system") # 启用 TensorFloat32 加速矩阵乘法 (解决 UserWarning 并提升性能) @@ -1625,7 +1301,9 @@ def expand_and_train( # 打印配置信息 console.print( - Panel.fit("[bold cyan]模型扩容两阶段训练配置[/bold cyan]", border_style="cyan") + Panel.fit( + "[bold cyan]模型扩容第一阶段训练配置[/bold cyan]", border_style="cyan" + ) ) config_table = Table(show_header=True, header_style="bold magenta") @@ -1652,13 +1330,8 @@ def expand_and_train( config_table.add_row("模型", "使用拼音", str(use_pinyin)) config_table.add_row("模型", "编译优化", str(compile)) - config_table.add_row("两阶段训练", "冻结阶段耐心值", str(frozen_patience)) - config_table.add_row("两阶段训练", "冻结阶段学习率", f"{frozen_lr:.2e}") - config_table.add_row("两阶段训练", "全量阶段学习率", f"{full_lr:.2e}") - config_table.add_row("两阶段训练", "冻结阶段调度器", frozen_scheduler) - config_table.add_row("两阶段训练", "全量阶段调度器", full_scheduler) - config_table.add_row("训练", "训练轮数", str(num_epochs)) + config_table.add_row("训练", "学习率", f"{learning_rate:.2e}") config_table.add_row("训练", "最小学习率", f"{min_learning_rate:.2e}") config_table.add_row("训练", "权重衰减", str(weight_decay)) config_table.add_row("训练", "热身比例", str(warmup_ratio)) @@ -1689,14 +1362,10 @@ def expand_and_train( "num_experts": num_experts, "max_seq_len": max_seq_len, "use_pinyin": use_pinyin, - "frozen_patience": frozen_patience, - "frozen_lr": frozen_lr, - "full_lr": full_lr, - "frozen_scheduler": frozen_scheduler, - "full_scheduler": full_scheduler, "batch_size": batch_size, "num_workers": num_workers, "num_epochs": num_epochs, + "learning_rate": learning_rate, "min_learning_rate": min_learning_rate, "weight_decay": weight_decay, "warmup_ratio": warmup_ratio, @@ -1765,7 +1434,7 @@ def expand_and_train( # 创建扩容模型 console.print("[bold cyan]正在创建扩容模型...[/bold cyan]") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - + model_kwargs = { "vocab_size": vocab_size, "pinyin_vocab_size": pinyin_vocab_size, @@ -1777,7 +1446,7 @@ def expand_and_train( "max_seq_len": max_seq_len, "compile": compile, } - + model = load_expanded_model( base_model_path=base_model_path, new_model_spec=new_model_spec, @@ -1788,24 +1457,24 @@ def expand_and_train( console.print( f"[green]✓ 扩容模型创建完成,参数量: {sum(p.numel() for p in model.parameters()):,}[/green]" ) - + # 统计冻结参数比例 total_params = sum(p.numel() for p in model.parameters()) frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad) console.print( - f"[green]✓ 冻结参数: {frozen_params:,}/{total_params:,} ({frozen_params/total_params*100:.1f}%)[/green]" + f"[green]✓ 冻结参数: {frozen_params:,}/{total_params:,} ({frozen_params / total_params * 100:.1f}%)[/green]" ) - # 创建两阶段训练器 - console.print("[bold cyan]正在创建两阶段训练器...[/bold cyan]") - trainer = TwoStageTrainer( + # 创建训练器(使用普通 Trainer,只进行第一阶段冻结训练) + console.print("[bold cyan]正在创建训练器...[/bold cyan]") + trainer = Trainer( model=model, train_dataloader=train_dataloader, eval_dataloader=eval_dataloader, total_steps=int(max_iter_length * num_epochs / batch_size), output_dir=output_dir, num_epochs=num_epochs, - learning_rate=frozen_lr, # 初始学习率(会被阶段特定LR覆盖) + learning_rate=learning_rate, min_learning_rate=min_learning_rate, weight_decay=weight_decay, warmup_ratio=warmup_ratio, @@ -1817,18 +1486,12 @@ def expand_and_train( mixed_precision=mixed_precision, use_tensorboard=use_tensorboard, status_file="training_status.json", - # 两阶段训练特有参数 - frozen_patience=frozen_patience, - frozen_lr=frozen_lr, - full_lr=full_lr, - frozen_scheduler=frozen_scheduler, - full_scheduler=full_scheduler, ) - console.print("[green]✓ 两阶段训练器创建完成[/green]") + console.print("[green]✓ 训练器创建完成[/green]") # 开始训练 - console.print("\n[bold cyan]开始两阶段训练...[/bold cyan]") + console.print("\n[bold cyan]开始扩容模型第一阶段训练...[/bold cyan]") console.print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") try: trainer.train( @@ -1838,9 +1501,313 @@ def expand_and_train( console.print("[bold green]训练被终止[/bold green]") trainer.save_checkpoint("interrupted_model.pt") - console.print("[bold green]✓ 两阶段训练完成![/bold green]") + # 保存扩容信息供第二阶段使用 + expansion_info = { + "stage1_checkpoint_path": str(output_path / "checkpoints" / "best_model.pt"), + "model_spec": new_model_spec, + "model_kwargs": model_kwargs, + "train_data_path": train_data_path, + "eval_data_path": eval_data_path, + "output_dir": output_dir, + "batch_size": batch_size, + "max_iter_length": max_iter_length, + "max_seq_len": max_seq_len, + "num_workers": num_workers, + } + + expansion_info_file = output_path / "expansion_info.json" + with open(expansion_info_file, "w", encoding="utf-8") as f: + json.dump(expansion_info, f, indent=2, ensure_ascii=False) + + logger.info(f"Expansion info saved to {expansion_info_file}") + + console.print("[bold green]✓ 第一阶段训练完成![/bold green]") console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") console.print(f"模型和日志保存在: {output_dir}") + console.print(f"[bold cyan]扩容信息已保存到: {expansion_info_file}[/bold cyan]") + console.print( + "[yellow]请手动检查模型后,使用 expand-finetune 命令进行第二阶段全量微调[/yellow]" + ) + + +@app.command() +def expand_finetune( + expand_config: str = typer.Option( + ..., + "--expand-config", + "-c", + help="新模型类规格,格式:模块名:类名,如 'big_expert:BigExpert'", + ), + stage1_info: str = typer.Option( + ..., "--stage1-info", "-i", help="第一阶段保存的 expansion_info.json 路径" + ), + # 可选覆盖参数 + checkpoint: Optional[str] = typer.Option( + None, "--checkpoint", help="第一阶段模型检查点路径(覆盖 JSON 文件中的路径)" + ), + output_dir: Optional[str] = typer.Option( + None, "--output-dir", "-o", help="输出目录(覆盖 JSON 文件中的目录)" + ), + train_data_path: Optional[str] = typer.Option( + None, "--train-data-path", "-t", help="训练数据路径(覆盖 JSON 文件)" + ), + eval_data_path: Optional[str] = typer.Option( + None, "--eval-data-path", "-e", help="评估数据路径(覆盖 JSON 文件)" + ), + batch_size: Optional[int] = typer.Option( + None, "--batch-size", "-b", help="批次大小(覆盖 JSON 文件)" + ), + num_epochs: Optional[int] = typer.Option( + None, "--num-epochs", help="训练轮数(覆盖 JSON 文件)" + ), + learning_rate: Optional[float] = typer.Option( + None, "--learning-rate", "-lr", help="学习率" + ), + min_learning_rate: Optional[float] = typer.Option( + None, "--min-learning-rate", help="最小学习率" + ), + weight_decay: Optional[float] = typer.Option( + None, "--weight-decay", help="权重衰减" + ), + warmup_ratio: Optional[float] = typer.Option( + None, "--warmup-ratio", help="热身步数比例" + ), + label_smoothing: Optional[float] = typer.Option( + None, "--label-smoothing", help="标签平滑参数" + ), + grad_accum_steps: Optional[int] = typer.Option( + None, "--grad-accum-steps", help="梯度累积步数" + ), + clip_grad_norm: Optional[float] = typer.Option( + None, "--clip-grad-norm", help="梯度裁剪范数" + ), + eval_frequency: Optional[int] = typer.Option( + None, "--eval-frequency", help="评估频率" + ), + save_frequency: Optional[int] = typer.Option( + None, "--save-frequency", help="保存频率" + ), + max_iter_length: Optional[int] = typer.Option( + None, "--max-iter-length", help="数据集大小(覆盖 JSON 文件)" + ), + max_seq_len: Optional[int] = typer.Option( + None, "--max-seq-len", help="最大序列长度(覆盖 JSON 文件)" + ), + num_workers: Optional[int] = typer.Option( + None, "--num-workers", help="数据加载worker数量" + ), + mixed_precision: bool = typer.Option( + True, "--mixed-precision/--no-mixed-precision", help="是否使用混合精度训练" + ), + use_tensorboard: bool = typer.Option( + True, "--tensorboard/--no-tensorboard", help="是否使用TensorBoard" + ), + resume_from: Optional[str] = typer.Option( + None, "--resume-from", help="从检查点恢复训练" + ), + reset_training_state: bool = typer.Option( + False, "--reset-training-state", help="重置训练状态" + ), + seed: int = typer.Option(42, "--seed", help="随机种子"), + compile: Optional[bool] = typer.Option( + None, "--compile/--no-compile", help="是否开启 torch.compile 优化" + ), +): + """ + 模型扩容第二阶段训练:读取第一阶段的 expansion_info.json,加载扩容模型进行全量微调。 + 命令行参数优先级高于 JSON 文件中的配置。 + """ + torch.multiprocessing.set_sharing_strategy("file_system") + + if torch.cuda.is_available(): + torch.set_float32_matmul_precision("high") + + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + console = Console() + + # 加载第一阶段信息 + stage1_info_path = Path(stage1_info) + if not stage1_info_path.exists(): + console.print( + f"[bold red]错误: 找不到第一阶段信息文件 {stage1_info}[/bold red]" + ) + raise typer.Exit(1) + + with open(stage1_info_path, "r", encoding="utf-8") as f: + info = json.load(f) + + # 命令行参数优先级高于 JSON 文件 + final_checkpoint = checkpoint or info["stage1_checkpoint_path"] + final_output_dir = output_dir or info["output_dir"] + final_train_data_path = train_data_path or info["train_data_path"] + final_eval_data_path = eval_data_path or info["eval_data_path"] + final_batch_size = batch_size if batch_size is not None else info["batch_size"] + final_num_epochs = ( + num_epochs if num_epochs is not None else info.get("num_epochs", 10) + ) + final_max_iter_length = ( + max_iter_length if max_iter_length is not None else info["max_iter_length"] + ) + final_max_seq_len = max_seq_len if max_seq_len is not None else info["max_seq_len"] + final_num_workers = ( + num_workers if num_workers is not None else info.get("num_workers", 2) + ) + + # 训练参数(有默认值,不覆盖则使用默认) + final_learning_rate = learning_rate if learning_rate is not None else 1e-4 + final_min_learning_rate = ( + min_learning_rate if min_learning_rate is not None else 1e-9 + ) + final_weight_decay = weight_decay if weight_decay is not None else 0.1 + final_warmup_ratio = warmup_ratio if warmup_ratio is not None else 0.1 + final_label_smoothing = label_smoothing if label_smoothing is not None else 0.15 + final_grad_accum_steps = grad_accum_steps if grad_accum_steps is not None else 1 + final_clip_grad_norm = clip_grad_norm if clip_grad_norm is not None else 1.0 + final_eval_frequency = eval_frequency if eval_frequency is not None else 500 + final_save_frequency = save_frequency if save_frequency is not None else 1000 + + # 模型参数从 JSON 获取 + model_kwargs = info["model_kwargs"] + if compile is not None: + model_kwargs["compile"] = compile + + console.print( + Panel.fit( + "[bold cyan]模型扩容第二阶段训练配置[/bold cyan]", border_style="cyan" + ) + ) + + config_table = Table(show_header=True, header_style="bold magenta") + config_table.add_column("Category", style="cyan") + config_table.add_column("Parameter", style="green") + config_table.add_column("Value", style="yellow") + + config_table.add_row("数据", "第一阶段信息文件", str(stage1_info_path)) + config_table.add_row("数据", "训练数据路径", final_train_data_path) + config_table.add_row("数据", "评估数据路径", final_eval_data_path) + config_table.add_row("数据", "输出目录", final_output_dir) + config_table.add_row("数据", "批次大小", str(final_batch_size)) + config_table.add_row("数据", "Worker数量", str(final_num_workers)) + + config_table.add_row("模型", "新模型规格", expand_config) + config_table.add_row("模型", "检查点路径", final_checkpoint) + for k, v in model_kwargs.items(): + config_table.add_row("模型", k, str(v)) + + config_table.add_row("训练", "训练轮数", str(final_num_epochs)) + config_table.add_row("训练", "学习率", f"{final_learning_rate:.2e}") + config_table.add_row("训练", "最小学习率", f"{final_min_learning_rate:.2e}") + config_table.add_row("训练", "权重衰减", str(final_weight_decay)) + config_table.add_row("训练", "热身比例", str(final_warmup_ratio)) + config_table.add_row("训练", "标签平滑", str(final_label_smoothing)) + config_table.add_row("训练", "梯度累积", str(final_grad_accum_steps)) + config_table.add_row("训练", "梯度裁剪", str(final_clip_grad_norm)) + config_table.add_row("训练", "混合精度", str(mixed_precision)) + + console.print(config_table) + + output_path = Path(final_output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + console.print("[bold cyan]正在创建数据加载器...[/bold cyan]") + + train_dataset = PinyinInputDataset( + data_path=final_train_data_path, + max_workers=-1, + max_iter_length=final_max_iter_length, + max_seq_length=final_max_seq_len, + text_field="text", + py_style_weight=(9, 2, 1), + shuffle_buffer_size=5000, + length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, + ) + + train_dataloader = create_dataloader( + dataset=train_dataset, + batch_size=final_batch_size, + num_workers=final_num_workers, + pin_memory=torch.cuda.is_available(), + max_iter_length=final_max_iter_length, + ) + + eval_dataset = PinyinInputDataset( + data_path=final_eval_data_path, + max_workers=-1, + max_iter_length=final_batch_size * 64, + max_seq_length=final_max_seq_len, + text_field="text", + py_style_weight=(9, 2, 1), + shuffle_buffer_size=50000, + length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, + ) + + eval_dataloader = create_dataloader( + dataset=eval_dataset, + batch_size=final_batch_size, + num_workers=1, + pin_memory=torch.cuda.is_available(), + max_iter_length=final_batch_size * 64, + ) + + console.print("[bold cyan]正在加载扩容模型(全量微调模式)...[/bold cyan]") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + model = load_expanded_model( + base_model_path=final_checkpoint, + new_model_spec=expand_config, + device=device, + **model_kwargs, + ) + + # 全量微调:解冻所有参数 + for param in model.parameters(): + param.requires_grad = True + + console.print( + f"[green]✓ 模型加载完成,参数量: {sum(p.numel() for p in model.parameters()):,}[/green]" + ) + console.print("[green]✓ 所有参数已解冻,进入全量微调模式[/green]") + + console.print("[bold cyan]正在创建训练器...[/bold cyan]") + trainer = Trainer( + model=model, + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + total_steps=int(final_max_iter_length * final_num_epochs / final_batch_size), + output_dir=final_output_dir, + num_epochs=final_num_epochs, + learning_rate=final_learning_rate, + min_learning_rate=final_min_learning_rate, + weight_decay=final_weight_decay, + warmup_ratio=final_warmup_ratio, + label_smoothing=final_label_smoothing, + grad_accum_steps=final_grad_accum_steps, + clip_grad_norm=final_clip_grad_norm, + eval_frequency=final_eval_frequency, + save_frequency=final_save_frequency, + mixed_precision=mixed_precision, + use_tensorboard=use_tensorboard, + status_file="training_status_finetune.json", + ) + + console.print("[green]✓ 训练器创建完成[/green]") + + console.print("\n[bold cyan]开始第二阶段全量微调...[/bold cyan]") + console.print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + try: + trainer.train( + resume_from=resume_from, reset_training_state=reset_training_state + ) + except KeyboardInterrupt: + console.print("[bold green]训练被终止[/bold green]") + trainer.save_checkpoint("interrupted_model.pt") + + console.print("[bold green]✓ 第二阶段全量微调完成![/bold green]") + console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + console.print(f"模型和日志保存在: {final_output_dir}") if __name__ == "__main__":