import json import math import random from datetime import datetime from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch import torch.nn as nn import torch.optim as optim import typer from loguru import logger from rich.console import Console from rich.panel import Panel from rich.progress import ( BarColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn, ) from rich.table import Table from torch import autocast from torch.amp.grad_scaler import GradScaler from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter from .dataset import PinyinInputDataset # 导入模型和数据 from .model import InputMethodEngine class Trainer: """ 输入法模型训练器 实现训练InputMethodEngine模型,支持: - 预热+余弦退火学习率调度 - TensorBoard日志记录 - AdamW优化器(weight_decay=0.1) - 混合精度训练 - CrossEntropyLoss损失函数(支持weight和label_smoothing) - Rich终端美化输出 """ training_status_data: List[Dict[str, Any]] def __init__( self, model: nn.Module, train_dataloader: DataLoader, eval_dataloader: DataLoader, total_steps: int, output_dir: str = "./output", num_epochs: int = 10, learning_rate: float = 1e-4, min_learning_rate: float = 1e-6, weight_decay: float = 0.1, warmup_ratio: float = 0.1, label_smoothing: float = 0.15, loss_weight: Optional[torch.Tensor] = None, grad_accum_steps: int = 1, clip_grad_norm: float = 1.0, eval_frequency: int = 500, save_frequency: int = 10000, mixed_precision: bool = True, device: Optional[torch.device] = None, status_file: str = "training_status.json", use_tensorboard: bool = True, ): """ 初始化训练器 Args: model: 要训练的InputMethodEngine模型 train_dataloader: 训练数据加载器 eval_dataloader: 评估数据加载器(可选) output_dir: 输出目录,用于保存模型和日志 num_epochs: 训练轮数 total_steps: 总训练步数,如果为None则根据epochs计算 learning_rate: 最大学习率(预热后) min_learning_rate: 最小学习率(余弦退火后的最低值) weight_decay: AdamW优化器的权重衰减 warmup_ratio: 热身步数占总步数的比例 label_smoothing: CrossEntropyLoss的标签平滑参数 loss_weight: CrossEntropyLoss的类别权重 grad_accum_steps: 梯度累积步数 clip_grad_norm: 梯度裁剪的最大范数 eval_frequency: 评估频率(步数) save_frequency: 保存检查点频率(步数) mixed_precision: 是否使用混合精度训练 device: 训练设备,如果为None则自动选择 use_tensorboard: 是否使用TensorBoard记录 """ self.model = model self.train_dataloader = train_dataloader self.eval_dataloader = list([i for i in eval_dataloader]) self.output_dir = Path(output_dir) self.num_epochs = num_epochs self.learning_rate = learning_rate self.min_learning_rate = min_learning_rate self.weight_decay = weight_decay self.warmup_ratio = warmup_ratio self.label_smoothing = label_smoothing self.loss_weight = loss_weight self.grad_accum_steps = grad_accum_steps self.clip_grad_norm = clip_grad_norm self.eval_frequency = eval_frequency self.save_frequency = save_frequency self.mixed_precision = mixed_precision self.use_tensorboard = use_tensorboard # 设置设备 logger.info(f"GPU可用: {torch.cuda.is_available()}") if device is None: self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: self.device = device # 移动模型到设备 self.model.to(self.device) # 创建输出目录 self.output_dir.mkdir(parents=True, exist_ok=True) self.checkpoint_dir = self.output_dir / "checkpoints" self.checkpoint_dir.mkdir(exist_ok=True) # 计算总步数 self.total_steps = total_steps self.warmup_steps = int(self.total_steps * warmup_ratio) # 初始化优化器 self.optimizer = optim.AdamW( model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(0.9, 0.999), eps=1e-8, ) # 初始化损失函数 if loss_weight is not None: self.criterion = nn.CrossEntropyLoss( weight=loss_weight.to(self.device), label_smoothing=label_smoothing ) else: self.criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing) # 初始化混合精度训练器 device_type = "cuda" if "cuda" in str(self.device) else "cpu" self.scaler = GradScaler(device_type, enabled=mixed_precision) # 初始化TensorBoard if use_tensorboard: self.writer = SummaryWriter(log_dir=self.output_dir / "tensorboard") else: self.writer = None # 设置状态文件 self.use_tensorboard = use_tensorboard self.status_file = self.output_dir / status_file # 不加载历史数据,直接初始化为空列表以覆盖原有数据 self.training_status_data = [] # 初始化Rich控制台 self.console = Console() # 训练状态 self.current_step = 0 self.current_epoch = 0 self.best_eval_loss = float("inf") # 学习率调度函数 self.lr_scheduler = self._create_lr_scheduler() logger.info(f"Trainer initialized with device: {self.device}") logger.info( f"Total steps: {self.total_steps}, Warmup steps: {self.warmup_steps}" ) logger.info(f"Learning rate: {learning_rate}, Weight decay: {weight_decay}") def _create_lr_scheduler(self) -> Callable[[int], float]: """创建学习率调度函数(预热 + 余弦退火)""" def lr_scheduler(step: int) -> float: if step < self.warmup_steps: # 线性预热 return self.learning_rate * (step / self.warmup_steps) else: # 余弦退火 progress = (step - self.warmup_steps) / ( self.total_steps - self.warmup_steps ) cosine_decay = 0.5 * (1 + math.cos(math.pi * progress)) decayed_lr = ( self.min_learning_rate + (self.learning_rate - self.min_learning_rate) * cosine_decay ) return decayed_lr return lr_scheduler def _get_current_lr(self) -> float: """获取当前学习率""" return self.lr_scheduler(self.current_step) def _update_learning_rate(self): """更新优化器中的学习率""" current_lr = self._get_current_lr() for param_group in self.optimizer.param_groups: param_group["lr"] = current_lr return current_lr def train_step( self, batch: Dict[str, torch.Tensor] ) -> Tuple[float, Dict[str, Any]]: """ 执行单个训练步骤 Args: batch: 包含输入数据的批次 Returns: loss: 损失值 metrics: 训练指标字典 """ self.model.train() # 移动数据到设备 (异步传输以提升 GPU 利用率) input_ids = batch["input_ids"].to(self.device, non_blocking=True) token_type_ids = batch["token_type_ids"].to(self.device, non_blocking=True) attention_mask = batch["attention_mask"].to(self.device, non_blocking=True) history_slot_ids = batch["history_slot_ids"].to(self.device, non_blocking=True) pinyin_ids = batch["pinyin_ids"].to(self.device, non_blocking=True) labels = ( batch["labels"].to(self.device, non_blocking=True).squeeze(-1) ) # [batch_size] # 混合精度训练 with autocast(device_type=self.device.type, enabled=self.mixed_precision): # 前向传播 logits = self.model( input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, pinyin_ids=pinyin_ids, history_slot_ids=history_slot_ids, ) # 计算损失 loss = self.criterion(logits, labels) loss = loss / self.grad_accum_steps # 反向传播 self.scaler.scale(loss).backward() metrics = { "loss": loss.item() * self.grad_accum_steps, "lr": self._get_current_lr(), } # 计算准确率 with torch.no_grad(): preds = torch.argmax(logits, dim=-1) correct = (preds == labels).sum().item() total = labels.size(0) metrics["accuracy"] = correct / total if total > 0 else 0.0 return loss.item() * self.grad_accum_steps, metrics def evaluate(self) -> Dict[str, float]: """ 在评估集上评估模型 Returns: 评估指标字典 """ if self.eval_dataloader is None: return {} self.model.eval() total_loss = 0.0 total_correct = 0 total_samples = 0 with torch.no_grad(): for batch in self.eval_dataloader: # 移动数据到设备 input_ids = batch["input_ids"].to(self.device) token_type_ids = batch["token_type_ids"].to(self.device) attention_mask = batch["attention_mask"].to(self.device) history_slot_ids = batch["history_slot_ids"].to(self.device) pinyin_ids = batch["pinyin_ids"].to(self.device) labels = batch["labels"].to(self.device).squeeze(-1) # 前向传播 logits = self.model( input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, pinyin_ids=pinyin_ids, history_slot_ids=history_slot_ids, ) # 计算损失 loss = self.criterion(logits, labels) total_loss += loss.item() * labels.size(0) # 计算准确率 preds = torch.argmax(logits, dim=-1) correct = (preds == labels).sum().item() total_correct += correct total_samples += labels.size(0) avg_loss = total_loss / total_samples if total_samples > 0 else 0.0 accuracy = total_correct / total_samples if total_samples > 0 else 0.0 return {"eval_loss": avg_loss, "eval_accuracy": accuracy} def save_checkpoint( self, filename: str, is_best: bool = False, is_periodic: bool = False ): """ 保存检查点 Args: filename: 检查点文件名 is_best: 是否是最佳模型 is_periodic: 是否是定期保存的检查点(会覆盖之前的定期检查点) """ # 如果是定期保存,使用固定的文件名来覆盖之前的 if is_periodic: checkpoint_path = self.checkpoint_dir / "latest_checkpoint.pt" else: checkpoint_path = self.checkpoint_dir / filename checkpoint = { "step": self.current_step, "epoch": self.current_epoch, "model_state_dict": self.model.state_dict(), "optimizer_state_dict": self.optimizer.state_dict(), "scaler_state_dict": self.scaler.state_dict(), "best_eval_loss": self.best_eval_loss, "config": { "learning_rate": self.learning_rate, "weight_decay": self.weight_decay, "warmup_ratio": self.warmup_ratio, "label_smoothing": self.label_smoothing, "total_steps": self.total_steps, }, } torch.save(checkpoint, checkpoint_path) logger.info(f"Checkpoint saved to {checkpoint_path}") if is_best: best_path = self.checkpoint_dir / "best_model.pt" torch.save(checkpoint, best_path) logger.info(f"Best model saved to {best_path}") def load_checkpoint( self, checkpoint_path: Union[str, Path], reset_training_state: bool = False ): """ 加载检查点 Args: checkpoint_path: 检查点文件路径 reset_training_state: 是否重置训练状态(只加载模型权重,从头开始训练) """ checkpoint = torch.load(checkpoint_path, map_location=self.device) self.model.load_state_dict(checkpoint["model_state_dict"]) if not reset_training_state: self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) self.scaler.load_state_dict(checkpoint["scaler_state_dict"]) self.current_step = checkpoint["step"] self.current_epoch = checkpoint["epoch"] self.best_eval_loss = checkpoint["best_eval_loss"] logger.info(f"Checkpoint loaded from {checkpoint_path}") logger.info( f"Resuming from step {self.current_step}, epoch {self.current_epoch}" ) else: # 重置训练状态 self.current_step = 0 self.current_epoch = 0 self.best_eval_loss = float("inf") logger.info( f"Checkpoint loaded from {checkpoint_path} (training state reset)" ) logger.info("Training state reset: starting from step 0, epoch 0") def _log_to_tensorboard(self, metrics: Dict[str, float], step: int): """将指标记录到TensorBoard和JSON状态文件""" if self.writer is not None: for key, value in metrics.items(): self.writer.add_scalar(key, value, step) # 同时记录到JSON状态文件 self._write_training_status(metrics, step) def _load_existing_status_data(self) -> List[Dict]: """从文件加载已有的训练状态数据""" try: if self.status_file.exists(): with open(self.status_file, "r", encoding="utf-8") as f: data = json.load(f) if isinstance(data, list): logger.info( f"Loaded {len(data)} existing training status records from {self.status_file}" ) return data else: logger.warning( f"Status file {self.status_file} does not contain a list, starting fresh" ) return [] else: logger.info( f"Status file {self.status_file} does not exist, starting fresh" ) return [] except json.JSONDecodeError: logger.warning( f"Status file {self.status_file} has invalid JSON format, starting fresh" ) return [] except Exception as e: logger.error( f"Failed to load existing status data from {self.status_file}: {e}" ) return [] def _write_training_status(self, metrics: Dict[str, float], step: int): """将训练状态写入JSON文件""" try: # 创建状态记录 status_record = { "step": step, "epoch": self.current_epoch + 1, "timestamp": datetime.now().isoformat(), } # 添加所有指标 for key, value in metrics.items(): status_record[key] = float(value) # 检查是否已存在相同步数的记录(避免重复) existing_indices = [ i for i, record in enumerate(self.training_status_data) if record.get("step") == step ] if existing_indices: # 替换现有记录 for idx in existing_indices: self.training_status_data[idx] = status_record # type: ignore else: # 添加到内存缓存 self.training_status_data.append(status_record) # type: ignore # 限制内存中的数据量,只保留最近1000条记录 if len(self.training_status_data) > 1000: self.training_status_data = self.training_status_data[-1000:] # 确保数据是列表格式 if not isinstance(self.training_status_data, list): logger.warning( f"training_status_data is not a list (type: {type(self.training_status_data).__name__}), converting to list" ) self.training_status_data = ( [self.training_status_data] if self.training_status_data else [] ) # 使用原子写入避免读取不完整JSON # 先写入临时文件,然后原子重命名 temp_file = Path(f"{self.status_file}.tmp") with open(temp_file, "w", encoding="utf-8") as f: json.dump(self.training_status_data, f, indent=2, ensure_ascii=False) # 原子重命名(Unix系统是原子操作) temp_file.rename(self.status_file) except Exception as e: logger.error(f"Failed to write training status: {e}") def _create_progress_bar(self) -> Progress: """创建Rich进度条""" return Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), BarColumn(), TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), TimeElapsedColumn(), TimeRemainingColumn(), console=self.console, expand=True, ) def _print_training_info(self): """打印训练信息""" info_table = Table( title="Training Configuration", show_header=True, header_style="bold magenta", ) info_table.add_column("Parameter", style="cyan") info_table.add_column("Value", style="green") info_table.add_row("Device", str(self.device)) info_table.add_row("Total Steps", str(self.total_steps)) info_table.add_row("Warmup Steps", str(self.warmup_steps)) info_table.add_row("Learning Rate", f"{self.learning_rate:.2e}") info_table.add_row("Min Learning Rate", f"{self.min_learning_rate:.2e}") info_table.add_row("Weight Decay", str(self.weight_decay)) info_table.add_row("Label Smoothing", str(self.label_smoothing)) info_table.add_row("Gradient Accumulation", str(self.grad_accum_steps)) info_table.add_row("Mixed Precision", str(self.mixed_precision)) self.console.print(info_table) def train( self, resume_from: Optional[str] = None, reset_training_state: bool = False ): """ 主训练循环 Args: resume_from: 从哪个检查点恢复训练(可选) reset_training_state: 是否重置训练状态(只加载模型权重,从头开始训练) """ # 如果提供了检查点,则恢复训练 if resume_from is not None: self.load_checkpoint(resume_from, reset_training_state=reset_training_state) # 打印训练信息 self._print_training_info() # 初始化训练状态 global_step = self.current_step accumulated_loss = 0.0 accumulated_accuracy = 0.0 accumulation_counter = 0 # 创建进度条 with self._create_progress_bar() as progress: epoch_task = progress.add_task( f"[cyan]Epoch {self.current_epoch + 1}/{self.num_epochs}", total=self.total_steps, ) # 训练循环 for epoch in range(self.current_epoch, self.num_epochs): self.current_epoch = epoch progress.update( epoch_task, description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs}" ) for batch_idx, batch in enumerate(self.train_dataloader): # 更新学习率 current_lr = self._update_learning_rate() # 训练步骤 loss, metrics = self.train_step(batch) # 累积指标 accumulated_loss += loss accumulated_accuracy += metrics.get("accuracy", 0.0) accumulation_counter += 1 # 梯度累积:每grad_accum_steps步更新一次参数 if (global_step + 1) % self.grad_accum_steps == 0: # 梯度裁剪 self.scaler.unscale_(self.optimizer) torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.clip_grad_norm ) # 更新参数 self.scaler.step(self.optimizer) self.scaler.update() self.optimizer.zero_grad() # 更新进度条 progress.update( epoch_task, advance=1, description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs} | " f"Step {global_step}/{self.total_steps} | " f"Loss: {loss:.4f} | " f"LR: {current_lr:.2e}", ) # 定期评估和记录 if (global_step + 1) % self.eval_frequency == 0: # 计算平均指标 avg_loss = accumulated_loss / accumulation_counter avg_accuracy = accumulated_accuracy / accumulation_counter # 评估模型 eval_metrics = self.evaluate() # 准备日志指标 log_metrics = { "train/loss": avg_loss, "train/accuracy": avg_accuracy, "train/learning_rate": current_lr, } if eval_metrics: log_metrics.update( { "eval/loss": eval_metrics["eval_loss"], "eval/accuracy": eval_metrics["eval_accuracy"], } ) # 更新最佳模型 if eval_metrics["eval_loss"] < self.best_eval_loss: self.best_eval_loss = eval_metrics["eval_loss"] # 只保存best_model,不创建额外的checkpoint文件 self.save_checkpoint("best_model.pt", is_best=True) # 记录到TensorBoard self._log_to_tensorboard(log_metrics, global_step) # 打印日志 log_text = ( f"[Epoch {epoch + 1}/{self.num_epochs}] " f"[Step {global_step}/{self.total_steps}] " f"Train Loss: {avg_loss:.4f} | " f"Train Acc: {avg_accuracy:.4f} | " f"LR: {current_lr:.2e}" ) if eval_metrics: log_text += ( f" | Eval Loss: {eval_metrics['eval_loss']:.4f} | " f"Eval Acc: {eval_metrics['eval_accuracy']:.4f}" ) progress.console.log(log_text) # 重置累积指标 accumulated_loss = 0.0 accumulated_accuracy = 0.0 accumulation_counter = 0 # 定期保存检查点(覆盖之前的定期检查点) if (global_step + 1) % self.save_frequency == 0: self.save_checkpoint("latest_checkpoint.pt", is_periodic=True) # 更新步数 global_step += 1 self.current_step = global_step # 检查是否达到总步数 if global_step >= self.total_steps: progress.update(epoch_task, completed=self.total_steps) break # 重置进度条 progress.reset(epoch_task) # 每个epoch结束后保存检查点 self.save_checkpoint(f"epoch_{epoch + 1}.pt") # 检查是否达到总步数 if global_step >= self.total_steps: break # 训练完成 logger.info("Training completed!") # 保存最终模型 self.save_checkpoint("final_model.pt") # 关闭TensorBoard写入器 if self.writer is not None: self.writer.close() def load_expanded_model( base_model_path: str, new_model_spec: str, device: torch.device, **model_kwargs, ) -> nn.Module: """ 加载预训练基础模型并创建扩容后的新模型,冻结匹配的层。 Args: base_model_path: 预训练基础模型检查点路径 new_model_spec: 新模型规格,格式 "module:ClassName",如 "new_model:NewModel" device: 设备 **model_kwargs: 传递给新模型构造函数的参数 Returns: 扩容后的新模型,匹配的层已冻结 """ import importlib import sys # 解析新模型规格 if ":" not in new_model_spec: raise ValueError( f"Invalid model spec format: {new_model_spec}. Expected format: 'module:ClassName'" ) module_name, class_name = new_model_spec.split(":", 1) # 导入模块(支持任意路径) module = None try: # 尝试直接导入 module = importlib.import_module(module_name) except ImportError: # 如果失败,尝试将其视为文件路径 try: # 将模块名转换为可能的文件路径 module_path = module_name.replace(".", "/") + ".py" import importlib.util spec = importlib.util.spec_from_file_location(module_name, module_path) if spec is None or spec.loader is None: raise ImportError(f"Cannot find module or loader: {module_name}") module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) # type: ignore except Exception as e: # 尝试在当前目录下查找 import os if os.path.exists(module_name + ".py"): spec = importlib.util.spec_from_file_location( module_name, module_name + ".py" ) if spec is None or spec.loader is None: raise ImportError(f"Cannot load module from file: {module_name}.py") module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) # type: ignore else: raise ImportError(f"Failed to import module '{module_name}': {e}") if module is None: raise ImportError(f"Module '{module_name}' could not be imported") # 获取模型类 model_class = getattr(module, class_name) # 检查模型类是否是 InputMethodEngine 的子类 from .model import InputMethodEngine if not issubclass(model_class, InputMethodEngine): raise TypeError( f"Model class {class_name} must be a subclass of InputMethodEngine. " f"Got {model_class.__name__} instead." ) # 创建新模型 new_model = model_class(**model_kwargs) new_model.to(device) # 加载预训练权重 checkpoint = torch.load(base_model_path, map_location=device) if "model_state_dict" in checkpoint: pretrained_state_dict = checkpoint["model_state_dict"] else: pretrained_state_dict = checkpoint # 获取新模型的状态字典 new_state_dict = new_model.state_dict() # 冻结匹配的层 frozen_layers = [] for key in new_state_dict.keys(): if key in pretrained_state_dict: if new_state_dict[key].shape == pretrained_state_dict[key].shape: new_state_dict[key] = pretrained_state_dict[key].to(device) frozen_layers.append(key) # 加载更新后的状态字典 new_model.load_state_dict(new_state_dict) # 设置参数 requires_grad for name, param in new_model.named_parameters(): if name in frozen_layers: param.requires_grad = False logger.info(f"Loaded expanded model with {len(frozen_layers)} frozen layers") logger.info( f"Frozen layers: {frozen_layers[:10]}{'...' if len(frozen_layers) > 10 else ''}" ) return new_model def worker_init_fn(worker_id: int) -> None: """ 初始化每个DataLoader worker的随机种子,确保可复现性 Args: worker_id: worker的ID """ worker_seed = torch.initial_seed() % (2**32) np.random.seed(worker_seed) random.seed(worker_seed) def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]: """ 自定义批处理函数,将多个样本组合成一个batch Args: batch: 样本列表,每个样本是一个字典 Returns: 批处理后的字典,tensor字段已stack,字符串字段保持为列表 """ # 处理tensor字段 - 使用squeeze去除多余的batch维度 input_ids = torch.stack([item["input_ids"].squeeze(0) for item in batch]) token_type_ids = torch.stack([item["token_type_ids"].squeeze(0) for item in batch]) attention_mask = torch.stack([item["attention_mask"].squeeze(0) for item in batch]) labels = torch.stack([item["label"].squeeze(0) for item in batch]) history_slot_ids = torch.stack([item["history_slot_ids"] for item in batch]) pinyin_ids = torch.stack([item["pinyin_ids"] for item in batch]) # 字符串字段保持为列表 prefixes = [item["prefix"] for item in batch] suffixes = [item["suffix"] for item in batch] pinyins = [item["pinyin"] for item in batch] return { "input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask, "labels": labels, "history_slot_ids": history_slot_ids, "prefix": prefixes, "suffix": suffixes, "pinyin": pinyins, "pinyin_ids": pinyin_ids, } # Typer CLI应用 def create_dataloader( dataset: PinyinInputDataset, batch_size: int, num_workers: int = 2, pin_memory: bool = True, shuffle: bool = False, max_iter_length: Optional[int] = None, ) -> Any: """ 创建数据加载器,优先使用DataLoader2,如果不可用则回退到DataLoader。 专门针对流式数据集优化。 Args: dataset: PinyinInputDataset实例 batch_size: 批次大小 num_workers: worker数量(对于流式数据集建议为2) pin_memory: 是否固定内存 shuffle: 是否打乱(流式数据集内部处理打乱) max_iter_length: 最大迭代长度,用于计算总步数 Returns: 数据加载器实例 """ logger.info(f"📊 使用标准DataLoader,worker数量: {num_workers}") dataloader = DataLoader( dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, worker_init_fn=worker_init_fn, collate_fn=collate_fn, prefetch_factor=2, # 减少预取以避免内存问题 persistent_workers=True, shuffle=shuffle, ) return dataloader app = typer.Typer(help="输入法模型训练命令行工具", add_completion=False) @app.command() def train( # 数据参数 train_data_path: str = typer.Option( ..., "--train-data-path", "-t", help="训练数据集路径" ), eval_data_path: str = typer.Option( ..., "--eval-data-path", "-e", help="评估数据集路径" ), output_dir: str = typer.Option("./output", "--output-dir", "-o", help="输出目录"), # 模型参数 vocab_size: int = typer.Option(10019, "--vocab-size", help="词汇表大小"), pinyin_vocab_size: int = typer.Option( 30, "--pinyin-vocab-size", help="拼音词汇表大小" ), max_iter_length: int = typer.Option( 1024 * 1024 * 128, "--max_iter_length", help="数据集大小" ), dim: int = typer.Option(512, "--dim", help="模型维度"), num_slots: int = typer.Option(8, "--num-slots", help="历史槽位数量"), n_layers: int = typer.Option(4, "--n-layers", help="Transformer层数"), n_heads: int = typer.Option(4, "--n-heads", help="注意力头数"), num_experts: int = typer.Option(20, "--num-experts", help="MoE专家数量"), max_seq_len: int = typer.Option(128, "--max-seq-len", help="最大序列长度"), use_pinyin: bool = typer.Option(False, "--use-pinyin", help="是否使用拼音特征"), # 训练参数 batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"), num_epochs: int = typer.Option(10, "--num-epochs", help="训练轮数"), learning_rate: float = typer.Option(1e-5, "--learning-rate", "-lr", help="学习率"), min_learning_rate: float = typer.Option( 1e-9, "--min-learning-rate", help="最小学习率" ), weight_decay: float = typer.Option(0.1, "--weight-decay", help="权重衰减"), warmup_ratio: float = typer.Option(0.1, "--warmup-ratio", help="热身步数比例"), label_smoothing: float = typer.Option( 0.15, "--label-smoothing", help="标签平滑参数" ), grad_accum_steps: int = typer.Option(1, "--grad-accum-steps", help="梯度累积步数"), clip_grad_norm: float = typer.Option(1.0, "--clip-grad-norm", help="梯度裁剪范数"), eval_frequency: int = typer.Option(500, "--eval-frequency", help="评估频率"), save_frequency: int = typer.Option(1000, "--save-frequency", help="保存频率"), # 其他参数 mixed_precision: bool = typer.Option( True, "--mixed-precision/--no-mixed-precision", help="是否使用混合精度训练" ), num_workers: int = typer.Option( 2, "--num-workers", help="数据加载worker数量(流式数据集建议为2)" ), use_tensorboard: bool = typer.Option( True, "--tensorboard/--no-tensorboard", help="是否使用TensorBoard" ), resume_from: Optional[str] = typer.Option( None, "--resume-from", help="从检查点恢复训练" ), reset_training_state: bool = typer.Option( False, "--reset-training-state", help="重置训练状态,只加载模型权重从头开始训练" ), seed: int = typer.Option(42, "--seed", help="随机种子"), compile: bool = typer.Option( False, "--compile/--no-compile", help="是否开启 torch.compile 优化(需 PyTorch 2.0+)", ), ): """ 训练输入法模型 """ torch.multiprocessing.set_sharing_strategy("file_system") # 启用 TensorFloat32 加速矩阵乘法 (解决 UserWarning 并提升性能) if torch.cuda.is_available(): torch.set_float32_matmul_precision("high") # 设置随机种子 torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) console = Console() # 打印配置信息 console.print( Panel.fit("[bold cyan]输入法模型训练配置[/bold cyan]", border_style="cyan") ) config_table = Table(show_header=True, header_style="bold magenta") config_table.add_column("Category", style="cyan") config_table.add_column("Parameter", style="green") config_table.add_column("Value", style="yellow") # 添加配置信息 config_table.add_row("数据", "训练数据路径", train_data_path) config_table.add_row("数据", "评估数据路径", eval_data_path) config_table.add_row("数据", "输出目录", output_dir) config_table.add_row("数据", "批次大小", str(batch_size)) config_table.add_row("数据", "Worker数量", str(num_workers)) config_table.add_row("模型", "词汇表大小", str(vocab_size)) config_table.add_row("模型", "拼音词汇表", str(pinyin_vocab_size)) config_table.add_row("模型", "模型维度", str(dim)) config_table.add_row("模型", "槽位数量", str(num_slots)) config_table.add_row("模型", "Transformer层数", str(n_layers)) config_table.add_row("模型", "注意力头数", str(n_heads)) config_table.add_row("模型", "MoE专家数", str(num_experts)) config_table.add_row("模型", "使用拼音", str(use_pinyin)) config_table.add_row("模型", "编译优化", str(compile)) config_table.add_row("训练", "训练轮数", str(num_epochs)) config_table.add_row("训练", "学习率", f"{learning_rate:.2e}") config_table.add_row("训练", "最小学习率", f"{min_learning_rate:.2e}") config_table.add_row("训练", "权重衰减", str(weight_decay)) config_table.add_row("训练", "热身比例", str(warmup_ratio)) config_table.add_row("训练", "标签平滑", str(label_smoothing)) config_table.add_row("训练", "梯度累积", str(grad_accum_steps)) config_table.add_row("训练", "梯度裁剪", str(clip_grad_norm)) config_table.add_row("训练", "混合精度", str(mixed_precision)) console.print(config_table) # 创建输出目录 output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) # 保存配置 config = { "train_data_path": train_data_path, "eval_data_path": eval_data_path, "output_dir": output_dir, "vocab_size": vocab_size, "pinyin_vocab_size": pinyin_vocab_size, "dim": dim, "num_slots": num_slots, "n_layers": n_layers, "n_heads": n_heads, "num_experts": num_experts, "max_seq_len": max_seq_len, "use_pinyin": use_pinyin, "batch_size": batch_size, "num_workers": num_workers, "num_epochs": num_epochs, "learning_rate": learning_rate, "min_learning_rate": min_learning_rate, "weight_decay": weight_decay, "warmup_ratio": warmup_ratio, "label_smoothing": label_smoothing, "grad_accum_steps": grad_accum_steps, "clip_grad_norm": clip_grad_norm, "eval_frequency": eval_frequency, "save_frequency": save_frequency, "mixed_precision": mixed_precision, "use_tensorboard": use_tensorboard, "seed": seed, "max_iter_length": max_iter_length, "compile": compile, } config_file = output_path / "training_config.json" with open(config_file, "w", encoding="utf-8") as f: json.dump(config, f, indent=2, ensure_ascii=False) logger.info(f"Configuration saved to {config_file}") # 创建数据加载器 console.print("[bold cyan]正在创建数据加载器...[/bold cyan]") # 训练数据集 train_dataset = PinyinInputDataset( data_path=train_data_path, max_workers=-1, # 自动选择worker数量 max_iter_length=max_iter_length, max_seq_length=max_seq_len, text_field="text", py_style_weight=(9, 2, 1), shuffle_buffer_size=5000, length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, ) # 训练数据加载器 # 注意:PinyinInputDataset是IterableDataset,所以不能使用shuffle参数 # 多worker配置:每个worker处理数据集的一个分片,由dataset.__iter__中的shard处理 train_dataloader = create_dataloader( dataset=train_dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=torch.cuda.is_available(), max_iter_length=max_iter_length, ) # 评估数据集(使用相同的设置,但可以调整参数) eval_dataset = PinyinInputDataset( data_path=eval_data_path, max_workers=-1, max_iter_length=batch_size * 64, # 评估集较小 max_seq_length=max_seq_len, text_field="text", py_style_weight=(9, 2, 1), shuffle_buffer_size=50000, length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, ) eval_dataloader = create_dataloader( dataset=eval_dataset, batch_size=batch_size, num_workers=1, # 评估使用较少的worker pin_memory=torch.cuda.is_available(), max_iter_length=batch_size * 64, ) console.print("[bold cyan]正在创建模型...[/bold cyan]") model = InputMethodEngine( vocab_size=vocab_size, pinyin_vocab_size=pinyin_vocab_size, dim=dim, num_slots=num_slots, n_layers=n_layers, n_heads=n_heads, num_experts=num_experts, max_seq_len=max_seq_len, compile=compile, ) console.print( f"[green]✓ 模型创建完成,参数量: {sum(p.numel() for p in model.parameters()):,}[/green]" ) # 创建训练器 console.print("[bold cyan]正在创建训练器...[/bold cyan]") trainer = Trainer( model=model, train_dataloader=train_dataloader, eval_dataloader=eval_dataloader, total_steps=int(max_iter_length * num_epochs / batch_size), output_dir=output_dir, num_epochs=num_epochs, learning_rate=learning_rate, min_learning_rate=min_learning_rate, weight_decay=weight_decay, warmup_ratio=warmup_ratio, label_smoothing=label_smoothing, grad_accum_steps=grad_accum_steps, clip_grad_norm=clip_grad_norm, eval_frequency=eval_frequency, save_frequency=save_frequency, mixed_precision=mixed_precision, use_tensorboard=use_tensorboard, status_file="training_status.json", ) console.print("[green]✓ 训练器创建完成[/green]") # 开始训练 console.print("\n[bold cyan]开始训练...[/bold cyan]") console.print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") try: trainer.train( resume_from=resume_from, reset_training_state=reset_training_state ) except KeyboardInterrupt: console.print("[bold green]训练被终止[/bold green]") trainer.save_checkpoint("interrupted_model.pt") console.print("[bold green]✓ 训练完成![/bold green]") console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") console.print(f"模型和日志保存在: {output_dir}") @app.command() def evaluate( checkpoint_path: str = typer.Option(..., "--checkpoint", "-c", help="检查点路径"), data_path: str = typer.Option(..., "--data-path", "-d", help="数据集路径"), batch_size: int = typer.Option(32, "--batch-size", "-b", help="批次大小"), ): """ 评估训练好的模型 """ console = Console() console.print(f"[bold cyan]评估模型: {checkpoint_path}[/bold cyan]") # 这里应该实现评估逻辑 # 1. 加载检查点 # 2. 创建数据加载器 # 3. 评估模型 console.print("[yellow]评估功能待实现[/yellow]") @app.command() def export( checkpoint_path: str = typer.Option(..., "--checkpoint", "-c", help="检查点路径"), output_path: str = typer.Option( "./exported_model.onnx", "--output", "-o", help="输出路径" ), ): """ 导出模型为ONNX格式 """ console = Console() console.print(f"[bold cyan]导出模型到: {output_path}[/bold cyan]") # 这里应该实现导出逻辑 # 1. 加载检查点 # 2. 导出为ONNX console.print("[yellow]导出功能待实现[/yellow]") @app.command() def expand_and_train( # 数据参数 train_data_path: str = typer.Option( ..., "--train-data-path", "-t", help="训练数据集路径" ), eval_data_path: str = typer.Option( ..., "--eval-data-path", "-e", help="评估数据集路径" ), output_dir: str = typer.Option("./output", "--output-dir", "-o", help="输出目录"), # 模型参数 base_model_path: str = typer.Option( ..., "--base-model-path", help="预训练基础模型检查点路径" ), new_model_spec: str = typer.Option( ..., "--new-model-spec", "-m", help="新模型规格,格式:模块名:类名,如 'model:InputMethodEngine'。支持任意路径,自定义模型类必须是 InputMethodEngine 的子类", ), vocab_size: int = typer.Option(10019, "--vocab-size", help="词汇表大小"), pinyin_vocab_size: int = typer.Option( 30, "--pinyin-vocab-size", help="拼音词汇表大小" ), max_iter_length: int = typer.Option( 1024 * 1024 * 128, "--max_iter_length", help="数据集大小" ), dim: int = typer.Option(512, "--dim", help="模型维度"), num_slots: int = typer.Option(8, "--num-slots", help="历史槽位数量"), n_layers: int = typer.Option(4, "--n-layers", help="Transformer层数"), n_heads: int = typer.Option(4, "--n-heads", help="注意力头数"), num_experts: int = typer.Option(20, "--num-experts", help="MoE专家数量"), max_seq_len: int = typer.Option(128, "--max-seq-len", help="最大序列长度"), use_pinyin: bool = typer.Option(False, "--use-pinyin", help="是否使用拼音特征"), # 两阶段训练参数 frozen_patience: int = typer.Option( 10, "--frozen-patience", help="冻结阶段验证损失连续不下降的epoch数,触发切换到全量微调", ), frozen_lr: float = typer.Option(1e-3, "--frozen-lr", help="冻结阶段学习率"), full_lr: float = typer.Option(1e-4, "--full-lr", help="全量微调阶段学习率"), frozen_scheduler: str = typer.Option( "cosine", "--frozen-scheduler", help="冻结阶段学习率调度器类型:cosine或plateau" ), full_scheduler: str = typer.Option( "cosine", "--full-scheduler", help="全量微调阶段学习率调度器类型:cosine或plateau", ), # 训练参数 batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"), num_epochs: int = typer.Option(10, "--num-epochs", help="训练轮数"), min_learning_rate: float = typer.Option( 1e-9, "--min-learning-rate", help="最小学习率" ), weight_decay: float = typer.Option(0.1, "--weight-decay", help="权重衰减"), warmup_ratio: float = typer.Option(0.1, "--warmup-ratio", help="热身步数比例"), label_smoothing: float = typer.Option( 0.15, "--label-smoothing", help="标签平滑参数" ), grad_accum_steps: int = typer.Option(1, "--grad-accum-steps", help="梯度累积步数"), clip_grad_norm: float = typer.Option(1.0, "--clip-grad-norm", help="梯度裁剪范数"), eval_frequency: int = typer.Option(500, "--eval-frequency", help="评估频率"), save_frequency: int = typer.Option(1000, "--save-frequency", help="保存频率"), # 其他参数 mixed_precision: bool = typer.Option( True, "--mixed-precision/--no-mixed-precision", help="是否使用混合精度训练" ), num_workers: int = typer.Option( 2, "--num-workers", help="数据加载worker数量(流式数据集建议为2)" ), use_tensorboard: bool = typer.Option( True, "--tensorboard/--no-tensorboard", help="是否使用TensorBoard" ), resume_from: Optional[str] = typer.Option( None, "--resume-from", help="从检查点恢复训练" ), reset_training_state: bool = typer.Option( False, "--reset-training-state", help="重置训练状态,只加载模型权重从头开始训练" ), seed: int = typer.Option(42, "--seed", help="随机种子"), compile: bool = typer.Option( False, "--compile/--no-compile", help="是否开启 torch.compile 优化(需 PyTorch 2.0+)", ), ): torch.multiprocessing.set_sharing_strategy("file_system") # 启用 TensorFloat32 加速矩阵乘法 (解决 UserWarning 并提升性能) if torch.cuda.is_available(): torch.set_float32_matmul_precision("high") # 设置随机种子 torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) console = Console() # 打印配置信息 console.print( Panel.fit( "[bold cyan]模型扩容第一阶段训练配置[/bold cyan]", border_style="cyan" ) ) config_table = Table(show_header=True, header_style="bold magenta") config_table.add_column("Category", style="cyan") config_table.add_column("Parameter", style="green") config_table.add_column("Value", style="yellow") # 添加配置信息 config_table.add_row("数据", "训练数据路径", train_data_path) config_table.add_row("数据", "评估数据路径", eval_data_path) config_table.add_row("数据", "输出目录", output_dir) config_table.add_row("数据", "批次大小", str(batch_size)) config_table.add_row("数据", "Worker数量", str(num_workers)) config_table.add_row("模型", "基础模型路径", base_model_path) config_table.add_row("模型", "新模型规格", new_model_spec) config_table.add_row("模型", "词汇表大小", str(vocab_size)) config_table.add_row("模型", "拼音词汇表", str(pinyin_vocab_size)) config_table.add_row("模型", "模型维度", str(dim)) config_table.add_row("模型", "槽位数量", str(num_slots)) config_table.add_row("模型", "Transformer层数", str(n_layers)) config_table.add_row("模型", "注意力头数", str(n_heads)) config_table.add_row("模型", "MoE专家数", str(num_experts)) config_table.add_row("模型", "使用拼音", str(use_pinyin)) config_table.add_row("模型", "编译优化", str(compile)) config_table.add_row("训练", "训练轮数", str(num_epochs)) config_table.add_row("训练", "学习率", f"{learning_rate:.2e}") config_table.add_row("训练", "最小学习率", f"{min_learning_rate:.2e}") config_table.add_row("训练", "权重衰减", str(weight_decay)) config_table.add_row("训练", "热身比例", str(warmup_ratio)) config_table.add_row("训练", "标签平滑", str(label_smoothing)) config_table.add_row("训练", "梯度累积", str(grad_accum_steps)) config_table.add_row("训练", "梯度裁剪", str(clip_grad_norm)) config_table.add_row("训练", "混合精度", str(mixed_precision)) console.print(config_table) # 创建输出目录 output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) # 保存配置 config = { "train_data_path": train_data_path, "eval_data_path": eval_data_path, "output_dir": output_dir, "base_model_path": base_model_path, "new_model_spec": new_model_spec, "vocab_size": vocab_size, "pinyin_vocab_size": pinyin_vocab_size, "dim": dim, "num_slots": num_slots, "n_layers": n_layers, "n_heads": n_heads, "num_experts": num_experts, "max_seq_len": max_seq_len, "use_pinyin": use_pinyin, "batch_size": batch_size, "num_workers": num_workers, "num_epochs": num_epochs, "learning_rate": learning_rate, "min_learning_rate": min_learning_rate, "weight_decay": weight_decay, "warmup_ratio": warmup_ratio, "label_smoothing": label_smoothing, "grad_accum_steps": grad_accum_steps, "clip_grad_norm": clip_grad_norm, "eval_frequency": eval_frequency, "save_frequency": save_frequency, "mixed_precision": mixed_precision, "use_tensorboard": use_tensorboard, "seed": seed, "max_iter_length": max_iter_length, "compile": compile, } config_file = output_path / "expansion_training_config.json" with open(config_file, "w", encoding="utf-8") as f: json.dump(config, f, indent=2, ensure_ascii=False) logger.info(f"Configuration saved to {config_file}") # 创建数据加载器 console.print("[bold cyan]正在创建数据加载器...[/bold cyan]") # 训练数据集 train_dataset = PinyinInputDataset( data_path=train_data_path, max_workers=-1, # 自动选择worker数量 max_iter_length=max_iter_length, max_seq_length=max_seq_len, text_field="text", py_style_weight=(9, 2, 1), shuffle_buffer_size=5000, length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, ) # 训练数据加载器 train_dataloader = create_dataloader( dataset=train_dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=torch.cuda.is_available(), max_iter_length=max_iter_length, ) # 评估数据集 eval_dataset = PinyinInputDataset( data_path=eval_data_path, max_workers=-1, max_iter_length=batch_size * 64, # 评估集较小 max_seq_length=max_seq_len, text_field="text", py_style_weight=(9, 2, 1), shuffle_buffer_size=50000, length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, ) eval_dataloader = create_dataloader( dataset=eval_dataset, batch_size=batch_size, num_workers=1, # 评估使用较少的worker pin_memory=torch.cuda.is_available(), max_iter_length=batch_size * 64, ) # 创建扩容模型 console.print("[bold cyan]正在创建扩容模型...[/bold cyan]") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model_kwargs = { "vocab_size": vocab_size, "pinyin_vocab_size": pinyin_vocab_size, "dim": dim, "num_slots": num_slots, "n_layers": n_layers, "n_heads": n_heads, "num_experts": num_experts, "max_seq_len": max_seq_len, "compile": compile, } model = load_expanded_model( base_model_path=base_model_path, new_model_spec=new_model_spec, device=device, **model_kwargs, ) console.print( f"[green]✓ 扩容模型创建完成,参数量: {sum(p.numel() for p in model.parameters()):,}[/green]" ) # 统计冻结参数比例 total_params = sum(p.numel() for p in model.parameters()) frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad) console.print( f"[green]✓ 冻结参数: {frozen_params:,}/{total_params:,} ({frozen_params / total_params * 100:.1f}%)[/green]" ) # 创建训练器(使用普通 Trainer,只进行第一阶段冻结训练) console.print("[bold cyan]正在创建训练器...[/bold cyan]") trainer = Trainer( model=model, train_dataloader=train_dataloader, eval_dataloader=eval_dataloader, total_steps=int(max_iter_length * num_epochs / batch_size), output_dir=output_dir, num_epochs=num_epochs, learning_rate=learning_rate, min_learning_rate=min_learning_rate, weight_decay=weight_decay, warmup_ratio=warmup_ratio, label_smoothing=label_smoothing, grad_accum_steps=grad_accum_steps, clip_grad_norm=clip_grad_norm, eval_frequency=eval_frequency, save_frequency=save_frequency, mixed_precision=mixed_precision, use_tensorboard=use_tensorboard, status_file="training_status.json", ) console.print("[green]✓ 训练器创建完成[/green]") # 开始训练 console.print("\n[bold cyan]开始扩容模型第一阶段训练...[/bold cyan]") console.print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") try: trainer.train( resume_from=resume_from, reset_training_state=reset_training_state ) except KeyboardInterrupt: console.print("[bold green]训练被终止[/bold green]") trainer.save_checkpoint("interrupted_model.pt") # 保存扩容信息供第二阶段使用 expansion_info = { "stage1_checkpoint_path": str(output_path / "checkpoints" / "best_model.pt"), "model_spec": new_model_spec, "model_kwargs": model_kwargs, "train_data_path": train_data_path, "eval_data_path": eval_data_path, "output_dir": output_dir, "batch_size": batch_size, "max_iter_length": max_iter_length, "max_seq_len": max_seq_len, "num_workers": num_workers, } expansion_info_file = output_path / "expansion_info.json" with open(expansion_info_file, "w", encoding="utf-8") as f: json.dump(expansion_info, f, indent=2, ensure_ascii=False) logger.info(f"Expansion info saved to {expansion_info_file}") console.print("[bold green]✓ 第一阶段训练完成![/bold green]") console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") console.print(f"模型和日志保存在: {output_dir}") console.print(f"[bold cyan]扩容信息已保存到: {expansion_info_file}[/bold cyan]") console.print( "[yellow]请手动检查模型后,使用 expand-finetune 命令进行第二阶段全量微调[/yellow]" ) @app.command() def expand_finetune( expand_config: str = typer.Option( ..., "--expand-config", "-c", help="新模型类规格,格式:模块名:类名,如 'big_expert:BigExpert'", ), stage1_info: str = typer.Option( ..., "--stage1-info", "-i", help="第一阶段保存的 expansion_info.json 路径" ), # 可选覆盖参数 checkpoint: Optional[str] = typer.Option( None, "--checkpoint", help="第一阶段模型检查点路径(覆盖 JSON 文件中的路径)" ), output_dir: Optional[str] = typer.Option( None, "--output-dir", "-o", help="输出目录(覆盖 JSON 文件中的目录)" ), train_data_path: Optional[str] = typer.Option( None, "--train-data-path", "-t", help="训练数据路径(覆盖 JSON 文件)" ), eval_data_path: Optional[str] = typer.Option( None, "--eval-data-path", "-e", help="评估数据路径(覆盖 JSON 文件)" ), batch_size: Optional[int] = typer.Option( None, "--batch-size", "-b", help="批次大小(覆盖 JSON 文件)" ), num_epochs: Optional[int] = typer.Option( None, "--num-epochs", help="训练轮数(覆盖 JSON 文件)" ), learning_rate: Optional[float] = typer.Option( None, "--learning-rate", "-lr", help="学习率" ), min_learning_rate: Optional[float] = typer.Option( None, "--min-learning-rate", help="最小学习率" ), weight_decay: Optional[float] = typer.Option( None, "--weight-decay", help="权重衰减" ), warmup_ratio: Optional[float] = typer.Option( None, "--warmup-ratio", help="热身步数比例" ), label_smoothing: Optional[float] = typer.Option( None, "--label-smoothing", help="标签平滑参数" ), grad_accum_steps: Optional[int] = typer.Option( None, "--grad-accum-steps", help="梯度累积步数" ), clip_grad_norm: Optional[float] = typer.Option( None, "--clip-grad-norm", help="梯度裁剪范数" ), eval_frequency: Optional[int] = typer.Option( None, "--eval-frequency", help="评估频率" ), save_frequency: Optional[int] = typer.Option( None, "--save-frequency", help="保存频率" ), max_iter_length: Optional[int] = typer.Option( None, "--max-iter-length", help="数据集大小(覆盖 JSON 文件)" ), max_seq_len: Optional[int] = typer.Option( None, "--max-seq-len", help="最大序列长度(覆盖 JSON 文件)" ), num_workers: Optional[int] = typer.Option( None, "--num-workers", help="数据加载worker数量" ), mixed_precision: bool = typer.Option( True, "--mixed-precision/--no-mixed-precision", help="是否使用混合精度训练" ), use_tensorboard: bool = typer.Option( True, "--tensorboard/--no-tensorboard", help="是否使用TensorBoard" ), resume_from: Optional[str] = typer.Option( None, "--resume-from", help="从检查点恢复训练" ), reset_training_state: bool = typer.Option( False, "--reset-training-state", help="重置训练状态" ), seed: int = typer.Option(42, "--seed", help="随机种子"), compile: Optional[bool] = typer.Option( None, "--compile/--no-compile", help="是否开启 torch.compile 优化" ), ): """ 模型扩容第二阶段训练:读取第一阶段的 expansion_info.json,加载扩容模型进行全量微调。 命令行参数优先级高于 JSON 文件中的配置。 """ torch.multiprocessing.set_sharing_strategy("file_system") if torch.cuda.is_available(): torch.set_float32_matmul_precision("high") torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) console = Console() # 加载第一阶段信息 stage1_info_path = Path(stage1_info) if not stage1_info_path.exists(): console.print( f"[bold red]错误: 找不到第一阶段信息文件 {stage1_info}[/bold red]" ) raise typer.Exit(1) with open(stage1_info_path, "r", encoding="utf-8") as f: info = json.load(f) # 命令行参数优先级高于 JSON 文件 final_checkpoint = checkpoint or info["stage1_checkpoint_path"] final_output_dir = output_dir or info["output_dir"] final_train_data_path = train_data_path or info["train_data_path"] final_eval_data_path = eval_data_path or info["eval_data_path"] final_batch_size = batch_size if batch_size is not None else info["batch_size"] final_num_epochs = ( num_epochs if num_epochs is not None else info.get("num_epochs", 10) ) final_max_iter_length = ( max_iter_length if max_iter_length is not None else info["max_iter_length"] ) final_max_seq_len = max_seq_len if max_seq_len is not None else info["max_seq_len"] final_num_workers = ( num_workers if num_workers is not None else info.get("num_workers", 2) ) # 训练参数(有默认值,不覆盖则使用默认) final_learning_rate = learning_rate if learning_rate is not None else 1e-4 final_min_learning_rate = ( min_learning_rate if min_learning_rate is not None else 1e-9 ) final_weight_decay = weight_decay if weight_decay is not None else 0.1 final_warmup_ratio = warmup_ratio if warmup_ratio is not None else 0.1 final_label_smoothing = label_smoothing if label_smoothing is not None else 0.15 final_grad_accum_steps = grad_accum_steps if grad_accum_steps is not None else 1 final_clip_grad_norm = clip_grad_norm if clip_grad_norm is not None else 1.0 final_eval_frequency = eval_frequency if eval_frequency is not None else 500 final_save_frequency = save_frequency if save_frequency is not None else 1000 # 模型参数从 JSON 获取 model_kwargs = info["model_kwargs"] if compile is not None: model_kwargs["compile"] = compile console.print( Panel.fit( "[bold cyan]模型扩容第二阶段训练配置[/bold cyan]", border_style="cyan" ) ) config_table = Table(show_header=True, header_style="bold magenta") config_table.add_column("Category", style="cyan") config_table.add_column("Parameter", style="green") config_table.add_column("Value", style="yellow") config_table.add_row("数据", "第一阶段信息文件", str(stage1_info_path)) config_table.add_row("数据", "训练数据路径", final_train_data_path) config_table.add_row("数据", "评估数据路径", final_eval_data_path) config_table.add_row("数据", "输出目录", final_output_dir) config_table.add_row("数据", "批次大小", str(final_batch_size)) config_table.add_row("数据", "Worker数量", str(final_num_workers)) config_table.add_row("模型", "新模型规格", expand_config) config_table.add_row("模型", "检查点路径", final_checkpoint) for k, v in model_kwargs.items(): config_table.add_row("模型", k, str(v)) config_table.add_row("训练", "训练轮数", str(final_num_epochs)) config_table.add_row("训练", "学习率", f"{final_learning_rate:.2e}") config_table.add_row("训练", "最小学习率", f"{final_min_learning_rate:.2e}") config_table.add_row("训练", "权重衰减", str(final_weight_decay)) config_table.add_row("训练", "热身比例", str(final_warmup_ratio)) config_table.add_row("训练", "标签平滑", str(final_label_smoothing)) config_table.add_row("训练", "梯度累积", str(final_grad_accum_steps)) config_table.add_row("训练", "梯度裁剪", str(final_clip_grad_norm)) config_table.add_row("训练", "混合精度", str(mixed_precision)) console.print(config_table) output_path = Path(final_output_dir) output_path.mkdir(parents=True, exist_ok=True) console.print("[bold cyan]正在创建数据加载器...[/bold cyan]") train_dataset = PinyinInputDataset( data_path=final_train_data_path, max_workers=-1, max_iter_length=final_max_iter_length, max_seq_length=final_max_seq_len, text_field="text", py_style_weight=(9, 2, 1), shuffle_buffer_size=5000, length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, ) train_dataloader = create_dataloader( dataset=train_dataset, batch_size=final_batch_size, num_workers=final_num_workers, pin_memory=torch.cuda.is_available(), max_iter_length=final_max_iter_length, ) eval_dataset = PinyinInputDataset( data_path=final_eval_data_path, max_workers=-1, max_iter_length=final_batch_size * 64, max_seq_length=final_max_seq_len, text_field="text", py_style_weight=(9, 2, 1), shuffle_buffer_size=50000, length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, ) eval_dataloader = create_dataloader( dataset=eval_dataset, batch_size=final_batch_size, num_workers=1, pin_memory=torch.cuda.is_available(), max_iter_length=final_batch_size * 64, ) console.print("[bold cyan]正在加载扩容模型(全量微调模式)...[/bold cyan]") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = load_expanded_model( base_model_path=final_checkpoint, new_model_spec=expand_config, device=device, **model_kwargs, ) # 全量微调:解冻所有参数 for param in model.parameters(): param.requires_grad = True console.print( f"[green]✓ 模型加载完成,参数量: {sum(p.numel() for p in model.parameters()):,}[/green]" ) console.print("[green]✓ 所有参数已解冻,进入全量微调模式[/green]") console.print("[bold cyan]正在创建训练器...[/bold cyan]") trainer = Trainer( model=model, train_dataloader=train_dataloader, eval_dataloader=eval_dataloader, total_steps=int(final_max_iter_length * final_num_epochs / final_batch_size), output_dir=final_output_dir, num_epochs=final_num_epochs, learning_rate=final_learning_rate, min_learning_rate=final_min_learning_rate, weight_decay=final_weight_decay, warmup_ratio=final_warmup_ratio, label_smoothing=final_label_smoothing, grad_accum_steps=final_grad_accum_steps, clip_grad_norm=final_clip_grad_norm, eval_frequency=final_eval_frequency, save_frequency=final_save_frequency, mixed_precision=mixed_precision, use_tensorboard=use_tensorboard, status_file="training_status_finetune.json", ) console.print("[green]✓ 训练器创建完成[/green]") console.print("\n[bold cyan]开始第二阶段全量微调...[/bold cyan]") console.print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") try: trainer.train( resume_from=resume_from, reset_training_state=reset_training_state ) except KeyboardInterrupt: console.print("[bold green]训练被终止[/bold green]") trainer.save_checkpoint("interrupted_model.pt") console.print("[bold green]✓ 第二阶段全量微调完成![/bold green]") console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") console.print(f"模型和日志保存在: {final_output_dir}") if __name__ == "__main__": app()