import json
import math
import os
import random
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import typer
from loguru import logger
from rich.console import Console
from rich.panel import Panel
from rich.progress import (
    BarColumn,
    Progress,
    SpinnerColumn,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)
from rich.table import Table
from torch import autocast
from torch.amp.grad_scaler import GradScaler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from .dataset import PinyinInputDataset

# Model and data imports
from .model import InputMethodEngine


class Trainer:
    """
    Trainer for the InputMethodEngine model.

    Supports:
    - Linear warmup followed by cosine-annealing learning-rate schedule
    - TensorBoard logging
    - AdamW optimizer (weight_decay defaults to 0.1)
    - Mixed-precision training (autocast + GradScaler)
    - CrossEntropyLoss (optional class weights and label smoothing)
    - Rich terminal output
    """

    def __init__(
        self,
        model: InputMethodEngine,
        train_dataloader: DataLoader,
        eval_dataloader: Optional[DataLoader] = None,
        total_steps: Optional[int] = None,
        output_dir: str = "./output",
        num_epochs: int = 10,
        learning_rate: float = 1e-4,
        min_learning_rate: float = 1e-6,
        weight_decay: float = 0.1,
        warmup_ratio: float = 0.1,
        label_smoothing: float = 0.15,
        loss_weight: Optional[torch.Tensor] = None,
        grad_accum_steps: int = 1,
        clip_grad_norm: float = 1.0,
        eval_frequency: int = 500,
        save_frequency: int = 10000,
        mixed_precision: bool = True,
        device: Optional[Union[str, torch.device]] = None,
        use_tensorboard: bool = True,
    ):
        """
        Initialize the trainer.

        Args:
            model: The InputMethodEngine model to train.
            train_dataloader: Training data loader.
            eval_dataloader: Evaluation data loader (optional). It is iterated
                once here and its batches are cached for every later evaluation.
            total_steps: Total number of training steps (batches). If None, it is
                computed as ``len(train_dataloader) * num_epochs``; this requires
                a loader with a length.
            output_dir: Output directory for checkpoints and logs.
            num_epochs: Number of training epochs.
            learning_rate: Peak learning rate (reached at the end of warmup).
            min_learning_rate: Floor of the cosine-annealed learning rate.
            weight_decay: AdamW weight decay.
            warmup_ratio: Fraction of total_steps used for linear warmup.
            label_smoothing: Label smoothing for CrossEntropyLoss.
            loss_weight: Per-class weights for CrossEntropyLoss.
            grad_accum_steps: Number of batches to accumulate per optimizer step.
            clip_grad_norm: Max gradient norm for clipping.
            eval_frequency: Evaluate/log every this many steps.
            save_frequency: Save a checkpoint every this many steps.
            mixed_precision: Whether to use mixed-precision training.
            device: Training device (torch.device or string). Auto-selected if None.
            use_tensorboard: Whether to log to TensorBoard.

        Raises:
            ValueError: If a step-count argument is not positive, or if
                total_steps is None and the training loader has no length.
        """
        # Validate step counts up front: they are used as modulo divisors in the
        # training loop, where a zero would otherwise surface as ZeroDivisionError.
        for name, value in (
            ("grad_accum_steps", grad_accum_steps),
            ("eval_frequency", eval_frequency),
            ("save_frequency", save_frequency),
        ):
            if value < 1:
                raise ValueError(f"{name} must be >= 1, got {value}")

        if total_steps is None:
            total_steps = self._infer_total_steps(train_dataloader, num_epochs)
        if total_steps < 1:
            raise ValueError(f"total_steps must be >= 1, got {total_steps}")

        self.model = model
        self.train_dataloader = train_dataloader
        # Cache the eval batches once so that repeated evaluations iterate the same
        # data (an iterable/shuffled dataset could otherwise differ between passes).
        self.eval_dataloader: Optional[List[Dict[str, Any]]] = (
            list(eval_dataloader) if eval_dataloader is not None else None
        )
        self.output_dir = Path(output_dir)
        self.num_epochs = num_epochs
        self.learning_rate = learning_rate
        self.min_learning_rate = min_learning_rate
        self.weight_decay = weight_decay
        self.warmup_ratio = warmup_ratio
        self.label_smoothing = label_smoothing
        self.loss_weight = loss_weight
        self.grad_accum_steps = grad_accum_steps
        self.clip_grad_norm = clip_grad_norm
        self.eval_frequency = eval_frequency
        self.save_frequency = save_frequency
        self.mixed_precision = mixed_precision
        self.use_tensorboard = use_tensorboard

        # Device selection; normalize to torch.device because `.type` is used later.
        logger.info(f"GPU可用: {torch.cuda.is_available()}")
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        self.model.to(self.device)

        # Output directories.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.checkpoint_dir = self.output_dir / "checkpoints"
        self.checkpoint_dir.mkdir(exist_ok=True)

        # Step budget (counted in batches, not optimizer updates).
        self.total_steps = int(total_steps)
        self.warmup_steps = int(self.total_steps * warmup_ratio)

        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=(0.9, 0.999),
            eps=1e-8,
        )

        if loss_weight is not None:
            self.criterion = nn.CrossEntropyLoss(
                weight=loss_weight.to(self.device), label_smoothing=label_smoothing
            )
        else:
            self.criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

        # Gradient scaler for mixed precision (a no-op when disabled).
        device_type = "cuda" if self.device.type == "cuda" else "cpu"
        self.scaler = GradScaler(device_type, enabled=mixed_precision)

        if use_tensorboard:
            self.writer = SummaryWriter(log_dir=str(self.output_dir / "tensorboard"))
        else:
            self.writer = None

        self.console = Console()

        # Training state. current_step counts completed batches.
        self.current_step = 0
        self.current_epoch = 0
        self.best_eval_loss = float("inf")

        self.lr_scheduler = self._create_lr_scheduler()

        logger.info(f"Trainer initialized with device: {self.device}")
        logger.info(
            f"Total steps: {self.total_steps}, Warmup steps: {self.warmup_steps}"
        )
        logger.info(f"Learning rate: {learning_rate}, Weight decay: {weight_decay}")

    @staticmethod
    def _infer_total_steps(train_dataloader: DataLoader, num_epochs: int) -> int:
        """Return ``len(train_dataloader) * num_epochs``.

        Raises:
            ValueError: If the loader has no length (e.g. it wraps an
                IterableDataset without ``__len__``).
        """
        try:
            steps_per_epoch = len(train_dataloader)
        except TypeError as err:
            raise ValueError(
                "total_steps must be given explicitly when the training "
                "DataLoader has no length (e.g. it wraps an IterableDataset)"
            ) from err
        return steps_per_epoch * num_epochs

    def _create_lr_scheduler(self) -> Callable[[int], float]:
        """Build the learning-rate schedule: linear warmup, then cosine annealing.

        Returns:
            A function mapping a step index to the learning rate for that step.
            During warmup it grows linearly from 0 towards ``learning_rate``; after
            warmup it follows a half cosine down to ``min_learning_rate`` and stays
            there for any step at or beyond ``total_steps``.
        """

        def lr_scheduler(step: int) -> float:
            if step < self.warmup_steps:
                # Linear warmup.
                return self.learning_rate * (step / self.warmup_steps)
            # Cosine annealing. Guard the denominator (warmup_ratio may be 1.0) and
            # clamp progress so the rate does not climb back up past total_steps.
            decay_steps = max(1, self.total_steps - self.warmup_steps)
            progress = min(1.0, (step - self.warmup_steps) / decay_steps)
            cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
            return (
                self.min_learning_rate
                + (self.learning_rate - self.min_learning_rate) * cosine_decay
            )

        return lr_scheduler

    def _get_current_lr(self) -> float:
        """Return the scheduled learning rate for ``self.current_step``."""
        return self.lr_scheduler(self.current_step)

    def _update_learning_rate(self) -> float:
        """Write the scheduled learning rate into every optimizer param group and return it."""
        current_lr = self._get_current_lr()
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = current_lr
        return current_lr

    def _optimizer_step(self) -> None:
        """Unscale and clip the accumulated gradients, apply them, then reset them."""
        # Gradients must be unscaled before clipping so the norm is measured in
        # real units rather than loss-scaled units.
        self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad_norm)
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

    def train_step(
        self, batch: Dict[str, torch.Tensor]
    ) -> Tuple[float, Dict[str, Any]]:
        """
        Run forward and backward passes for one batch (no optimizer step).

        Args:
            batch: Batch dict as produced by ``collate_fn``.

        Returns:
            loss: Unscaled loss value for this batch (i.e. not divided by
                grad_accum_steps).
            metrics: Dict with "loss", "lr" and "accuracy".
        """
        self.model.train()

        input_ids = batch["input_ids"].to(self.device)
        token_type_ids = batch["token_type_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        history_slot_ids = batch["history_slot_ids"].to(self.device)
        pinyin_ids = batch["pinyin_ids"].to(self.device)
        labels = batch["labels"].to(self.device).squeeze(-1)  # [batch_size]

        with autocast(device_type=self.device.type, enabled=self.mixed_precision):
            logits = self.model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                pinyin_ids=pinyin_ids,
                history_slot_ids=history_slot_ids,
            )
            loss = self.criterion(logits, labels)
            # Scale down so accumulated gradients average over the micro-batches.
            loss = loss / self.grad_accum_steps

        self.scaler.scale(loss).backward()

        # Single host sync for the scalar loss.
        loss_value = loss.item() * self.grad_accum_steps
        metrics: Dict[str, Any] = {
            "loss": loss_value,
            "lr": self._get_current_lr(),
        }

        with torch.no_grad():
            preds = torch.argmax(logits, dim=-1)
            correct = (preds == labels).sum().item()
            total = labels.size(0)
            metrics["accuracy"] = correct / total if total > 0 else 0.0

        return loss_value, metrics

    def evaluate(self) -> Dict[str, float]:
        """
        Evaluate the model on the cached evaluation batches.

        Returns:
            ``{"eval_loss": ..., "eval_accuracy": ...}``, or an empty dict when no
            evaluation data is available.
        """
        # Skip both "no eval loader" and "eval loader produced nothing"; the latter
        # would otherwise report a 0.0 loss and be saved as the best model.
        if not self.eval_dataloader:
            return {}

        self.model.eval()
        total_loss = 0.0
        total_correct = 0
        total_samples = 0

        with torch.no_grad():
            for batch in self.eval_dataloader:
                input_ids = batch["input_ids"].to(self.device)
                token_type_ids = batch["token_type_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                history_slot_ids = batch["history_slot_ids"].to(self.device)
                pinyin_ids = batch["pinyin_ids"].to(self.device)
                labels = batch["labels"].to(self.device).squeeze(-1)

                logits = self.model(
                    input_ids=input_ids,
                    token_type_ids=token_type_ids,
                    attention_mask=attention_mask,
                    pinyin_ids=pinyin_ids,
                    history_slot_ids=history_slot_ids,
                )

                # Weight each batch's mean loss by its size for a per-sample average.
                loss = self.criterion(logits, labels)
                total_loss += loss.item() * labels.size(0)

                preds = torch.argmax(logits, dim=-1)
                total_correct += (preds == labels).sum().item()
                total_samples += labels.size(0)

        avg_loss = total_loss / total_samples if total_samples > 0 else 0.0
        accuracy = total_correct / total_samples if total_samples > 0 else 0.0

        return {"eval_loss": avg_loss, "eval_accuracy": accuracy}

    def save_checkpoint(self, filename: str, is_best: bool = False):
        """
        Save model/optimizer/scaler state and training progress.

        Args:
            filename: File name inside ``checkpoint_dir``.
            is_best: If True, also write the same checkpoint to ``best_model.pt``.
        """
        checkpoint_path = self.checkpoint_dir / filename
        checkpoint = {
            "step": self.current_step,
            "epoch": self.current_epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scaler_state_dict": self.scaler.state_dict(),
            "best_eval_loss": self.best_eval_loss,
            "config": {
                "learning_rate": self.learning_rate,
                "weight_decay": self.weight_decay,
                "warmup_ratio": self.warmup_ratio,
                "label_smoothing": self.label_smoothing,
                "total_steps": self.total_steps,
            },
        }

        torch.save(checkpoint, checkpoint_path)
        logger.info(f"Checkpoint saved to {checkpoint_path}")

        if is_best:
            best_path = self.checkpoint_dir / "best_model.pt"
            torch.save(checkpoint, best_path)
            logger.info(f"Best model saved to {best_path}")

    def load_checkpoint(self, checkpoint_path: Union[str, Path]):
        """
        Restore state written by ``save_checkpoint``.

        Args:
            checkpoint_path: Path to the checkpoint file.
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scaler.load_state_dict(checkpoint["scaler_state_dict"])
        self.current_step = checkpoint["step"]
        self.current_epoch = checkpoint["epoch"]
        self.best_eval_loss = checkpoint["best_eval_loss"]

        logger.info(f"Checkpoint loaded from {checkpoint_path}")
        logger.info(
            f"Resuming from step {self.current_step}, epoch {self.current_epoch}"
        )

    def _log_to_tensorboard(self, metrics: Dict[str, float], step: int):
        """Write each metric as a scalar at ``step`` (no-op without a writer)."""
        if self.writer is None:
            return
        for key, value in metrics.items():
            self.writer.add_scalar(key, value, step)

    def _create_progress_bar(self) -> Progress:
        """Create the Rich progress bar used by ``train``."""
        return Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
            TimeElapsedColumn(),
            TimeRemainingColumn(),
            console=self.console,
            expand=True,
        )

    def _print_training_info(self):
        """Print a table of the main training hyper-parameters."""
        info_table = Table(
            title="Training Configuration",
            show_header=True,
            header_style="bold magenta",
        )
        info_table.add_column("Parameter", style="cyan")
        info_table.add_column("Value", style="green")

        info_table.add_row("Device", str(self.device))
        info_table.add_row("Total Steps", str(self.total_steps))
        info_table.add_row("Warmup Steps", str(self.warmup_steps))
        info_table.add_row("Learning Rate", f"{self.learning_rate:.2e}")
        info_table.add_row("Min Learning Rate", f"{self.min_learning_rate:.2e}")
        info_table.add_row("Weight Decay", str(self.weight_decay))
        info_table.add_row("Label Smoothing", str(self.label_smoothing))
        info_table.add_row("Gradient Accumulation", str(self.grad_accum_steps))
        info_table.add_row("Mixed Precision", str(self.mixed_precision))

        self.console.print(info_table)

    def train(self, resume_from: Optional[str] = None):
        """
        Main training loop.

        Runs until ``total_steps`` batches have been processed or ``num_epochs``
        epochs have been completed, whichever comes first.

        Args:
            resume_from: Optional checkpoint path to resume from.
        """
        if resume_from is not None:
            self.load_checkpoint(resume_from)

        self._print_training_info()

        # global_step counts completed batches (same meaning as self.current_step).
        global_step = self.current_step
        accumulated_loss = 0.0
        accumulated_accuracy = 0.0
        accumulation_counter = 0
        current_lr = self._get_current_lr()

        self.optimizer.zero_grad()

        try:
            with self._create_progress_bar() as progress:
                # The bar tracks global steps across all epochs, so on resume it
                # starts from the restored step instead of 0.
                step_task = progress.add_task(
                    f"[cyan]Epoch {self.current_epoch + 1}/{self.num_epochs}",
                    total=self.total_steps,
                    completed=min(global_step, self.total_steps),
                )

                for epoch in range(self.current_epoch, self.num_epochs):
                    if global_step >= self.total_steps:
                        break

                    self.current_epoch = epoch
                    progress.update(
                        step_task,
                        description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs}",
                    )

                    for batch in self.train_dataloader:
                        current_lr = self._update_learning_rate()

                        loss, metrics = self.train_step(batch)

                        accumulated_loss += loss
                        accumulated_accuracy += metrics.get("accuracy", 0.0)
                        accumulation_counter += 1

                        # Apply an optimizer update once every grad_accum_steps batches.
                        if (global_step + 1) % self.grad_accum_steps == 0:
                            self._optimizer_step()

                        # This batch is complete; from here on global_step is the
                        # number of finished batches, which is what checkpoints
                        # store so that resuming continues with the next batch.
                        global_step += 1
                        self.current_step = global_step

                        progress.update(
                            step_task,
                            advance=1,
                            description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs} | "
                            f"Step {global_step}/{self.total_steps} | "
                            f"Loss: {loss:.4f} | "
                            f"LR: {current_lr:.2e}",
                        )

                        if global_step % self.eval_frequency == 0:
                            avg_loss = accumulated_loss / accumulation_counter
                            avg_accuracy = accumulated_accuracy / accumulation_counter

                            eval_metrics = self.evaluate()

                            log_metrics = {
                                "train/loss": avg_loss,
                                "train/accuracy": avg_accuracy,
                                "train/learning_rate": current_lr,
                            }

                            if eval_metrics:
                                log_metrics.update(
                                    {
                                        "eval/loss": eval_metrics["eval_loss"],
                                        "eval/accuracy": eval_metrics["eval_accuracy"],
                                    }
                                )

                                if eval_metrics["eval_loss"] < self.best_eval_loss:
                                    self.best_eval_loss = eval_metrics["eval_loss"]
                                    self.save_checkpoint(
                                        f"step_{global_step}.pt", is_best=True
                                    )

                            self._log_to_tensorboard(log_metrics, global_step)

                            log_text = (
                                f"[Epoch {epoch + 1}/{self.num_epochs}] "
                                f"[Step {global_step}/{self.total_steps}] "
                                f"Train Loss: {avg_loss:.4f} | "
                                f"Train Acc: {avg_accuracy:.4f} | "
                                f"LR: {current_lr:.2e}"
                            )
                            if eval_metrics:
                                log_text += (
                                    f" | Eval Loss: {eval_metrics['eval_loss']:.4f} | "
                                    f"Eval Acc: {eval_metrics['eval_accuracy']:.4f}"
                                )
                            progress.console.log(log_text)

                            # Reset the running averages for the next window.
                            accumulated_loss = 0.0
                            accumulated_accuracy = 0.0
                            accumulation_counter = 0

                        if global_step % self.save_frequency == 0:
                            self.save_checkpoint(f"step_{global_step}.pt")

                        if global_step >= self.total_steps:
                            break

                    # Mark the epoch as finished before saving so that resuming from
                    # this checkpoint starts the next epoch instead of repeating it.
                    self.current_epoch = epoch + 1
                    self.save_checkpoint(f"epoch_{epoch + 1}.pt")

            logger.info("Training completed!")
            self.save_checkpoint("final_model.pt")
        finally:
            # Flush/close TensorBoard even if training raised.
            if self.writer is not None:
                self.writer.close()


def worker_init_fn(worker_id: int) -> None:
    """
    Seed ``numpy`` and ``random`` in each DataLoader worker for reproducibility.

    PyTorch already seeds each worker's torch RNG distinctly; this derives the
    numpy/random seeds from it.

    Args:
        worker_id: DataLoader worker index (unused; the seed comes from torch).
    """
    worker_seed = torch.initial_seed() % (2**32)
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Combine a list of sample dicts into one batch dict.

    Args:
        batch: List of samples, each a dict.

    Returns:
        Dict whose tensor fields are stacked along a new batch dimension and whose
        string fields ("prefix", "suffix", "pinyin") are kept as lists. The
        per-sample "label" field is returned under the key "labels".
    """
    # squeeze(0) drops a leading size-1 dimension on these per-sample tensors
    # (presumably left over from tokenization) before stacking.
    input_ids = torch.stack([item["input_ids"].squeeze(0) for item in batch])
    token_type_ids = torch.stack([item["token_type_ids"].squeeze(0) for item in batch])
    attention_mask = torch.stack([item["attention_mask"].squeeze(0) for item in batch])
    labels = torch.stack([item["label"].squeeze(0) for item in batch])
    history_slot_ids = torch.stack([item["history_slot_ids"] for item in batch])
    pinyin_ids = torch.stack([item["pinyin_ids"] for item in batch])

    prefixes = [item["prefix"] for item in batch]
    suffixes = [item["suffix"] for item in batch]
    pinyins = [item["pinyin"] for item in batch]

    return {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "history_slot_ids": history_slot_ids,
        "prefix": prefixes,
        "suffix": suffixes,
        "pinyin": pinyins,
        "pinyin_ids": pinyin_ids,
    }


# Typer CLI application
app = typer.Typer(help="输入法模型训练命令行工具", add_completion=False)


@app.command()
def train(
    # Data arguments
    train_data_path: str = typer.Option(
        ..., "--train-data-path", "-t", help="训练数据集路径"
    ),
    eval_data_path: str = typer.Option(
        ..., "--eval-data-path", "-e", help="评估数据集路径"
    ),
    output_dir: str = typer.Option("./output", "--output-dir", "-o", help="输出目录"),
    # Model arguments
    vocab_size: int = typer.Option(10019, "--vocab-size", help="词汇表大小"),
    pinyin_vocab_size: int = typer.Option(
        30, "--pinyin-vocab-size", help="拼音词汇表大小"
    ),
    # The dashed spelling matches the other flags; the original underscore
    # spelling is kept as an alias for existing scripts.
    max_iter_length: int = typer.Option(
        1024 * 1024 * 128,
        "--max-iter-length",
        "--max_iter_length",
        help="数据集大小",
    ),
    dim: int = typer.Option(512, "--dim", help="模型维度"),
    num_slots: int = typer.Option(8, "--num-slots", help="历史槽位数量"),
    n_layers: int = typer.Option(4, "--n-layers", help="Transformer层数"),
    n_heads: int = typer.Option(4, "--n-heads", help="注意力头数"),
    num_experts: int = typer.Option(20, "--num-experts", help="MoE专家数量"),
    max_seq_len: int = typer.Option(128, "--max-seq-len", help="最大序列长度"),
    use_pinyin: bool = typer.Option(False, "--use-pinyin", help="是否使用拼音特征"),
    # Training arguments
    batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"),
    num_epochs: int = typer.Option(10, "--num-epochs", help="训练轮数"),
    learning_rate: float = typer.Option(1e-5, "--learning-rate", "-lr", help="学习率"),
    min_learning_rate: float = typer.Option(
        1e-9, "--min-learning-rate", help="最小学习率"
    ),
    weight_decay: float = typer.Option(0.1, "--weight-decay", help="权重衰减"),
    warmup_ratio: float = typer.Option(0.1, "--warmup-ratio", help="热身步数比例"),
    label_smoothing: float = typer.Option(
        0.15, "--label-smoothing", help="标签平滑参数"
    ),
    grad_accum_steps: int = typer.Option(1, "--grad-accum-steps", help="梯度累积步数"),
    clip_grad_norm: float = typer.Option(1.0, "--clip-grad-norm", help="梯度裁剪范数"),
    eval_frequency: int = typer.Option(500, "--eval-frequency", help="评估频率"),
    save_frequency: int = typer.Option(10000, "--save-frequency", help="保存频率"),
    # Other arguments
    mixed_precision: bool = typer.Option(
        True, "--mixed-precision/--no-mixed-precision", help="是否使用混合精度训练"
    ),
    use_tensorboard: bool = typer.Option(
        True, "--tensorboard/--no-tensorboard", help="是否使用TensorBoard"
    ),
    resume_from: Optional[str] = typer.Option(
        None, "--resume-from", help="从检查点恢复训练"
    ),
    seed: int = typer.Option(42, "--seed", help="随机种子"),
):
    """
    训练输入法模型
    """
    # CLI entry point: build datasets, loaders, model and Trainer, then train.
    # (The docstring above is the Typer help text and is intentionally kept as-is.)
    torch.multiprocessing.set_sharing_strategy("file_system")

    # Seed the main-process RNGs (DataLoader workers are re-seeded by worker_init_fn).
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    console = Console()

    # Print the configuration.
    console.print(
        Panel.fit("[bold cyan]输入法模型训练配置[/bold cyan]", border_style="cyan")
    )

    config_table = Table(show_header=True, header_style="bold magenta")
    config_table.add_column("Category", style="cyan")
    config_table.add_column("Parameter", style="green")
    config_table.add_column("Value", style="yellow")

    config_table.add_row("数据", "训练数据路径", train_data_path)
    config_table.add_row("数据", "评估数据路径", eval_data_path)
    config_table.add_row("数据", "输出目录", output_dir)
    config_table.add_row("数据", "批次大小", str(batch_size))
    config_table.add_row("模型", "词汇表大小", str(vocab_size))
    config_table.add_row("模型", "拼音词汇表", str(pinyin_vocab_size))
    config_table.add_row("模型", "模型维度", str(dim))
    config_table.add_row("模型", "槽位数量", str(num_slots))
    config_table.add_row("模型", "Transformer层数", str(n_layers))
    config_table.add_row("模型", "注意力头数", str(n_heads))
    config_table.add_row("模型", "MoE专家数", str(num_experts))
    config_table.add_row("模型", "使用拼音", str(use_pinyin))
    config_table.add_row("训练", "训练轮数", str(num_epochs))
    config_table.add_row("训练", "学习率", f"{learning_rate:.2e}")
    config_table.add_row("训练", "最小学习率", f"{min_learning_rate:.2e}")
    config_table.add_row("训练", "权重衰减", str(weight_decay))
    config_table.add_row("训练", "热身比例", str(warmup_ratio))
    config_table.add_row("训练", "标签平滑", str(label_smoothing))
    config_table.add_row("训练", "梯度累积", str(grad_accum_steps))
    config_table.add_row("训练", "梯度裁剪", str(clip_grad_norm))
    config_table.add_row("训练", "混合精度", str(mixed_precision))

    console.print(config_table)

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Persist the configuration next to the outputs.
    config = {
        "train_data_path": train_data_path,
        "eval_data_path": eval_data_path,
        "output_dir": output_dir,
        "vocab_size": vocab_size,
        "pinyin_vocab_size": pinyin_vocab_size,
        "dim": dim,
        "num_slots": num_slots,
        "n_layers": n_layers,
        "n_heads": n_heads,
        "num_experts": num_experts,
        "max_seq_len": max_seq_len,
        "use_pinyin": use_pinyin,
        "batch_size": batch_size,
        "num_epochs": num_epochs,
        "learning_rate": learning_rate,
        "min_learning_rate": min_learning_rate,
        "weight_decay": weight_decay,
        "warmup_ratio": warmup_ratio,
        "label_smoothing": label_smoothing,
        "grad_accum_steps": grad_accum_steps,
        "clip_grad_norm": clip_grad_norm,
        "eval_frequency": eval_frequency,
        "save_frequency": save_frequency,
        "mixed_precision": mixed_precision,
        "use_tensorboard": use_tensorboard,
        "seed": seed,
        "max_iter_length": max_iter_length,
    }
    config_file = output_path / "training_config.json"
    with open(config_file, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)
    logger.info(f"Configuration saved to {config_file}")

    console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")

    train_dataset = PinyinInputDataset(
        data_path=train_data_path,
        max_workers=-1,  # let the dataset pick its worker count
        max_iter_length=max_iter_length,
        max_seq_length=max_seq_len,
        text_field="text",
        py_style_weight=(9, 2, 1),
        shuffle_buffer_size=5000,
        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
    )

    # NOTE: PinyinInputDataset is (per the original comments) an IterableDataset,
    # so `shuffle=` cannot be used; sharding across workers is expected to be
    # handled inside the dataset's __iter__.
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=min(max(1, (os.cpu_count() or 1) - 1), 25),
        pin_memory=torch.cuda.is_available(),
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        prefetch_factor=64,  # 64 batches prefetched per worker (memory-hungry)
        persistent_workers=True,  # keep workers alive across epochs
    )

    eval_dataset = PinyinInputDataset(
        data_path=eval_data_path,
        max_workers=-1,
        max_iter_length=1024,  # small evaluation set
        max_seq_length=max_seq_len,
        text_field="text",
        py_style_weight=(9, 2, 1),
        shuffle_buffer_size=1000,
        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
    )

    # The Trainer iterates this loader exactly once and caches the batches, so
    # persistent workers would only keep an idle process alive for the whole run.
    eval_dataloader = DataLoader(
        eval_dataset,
        batch_size=batch_size,
        num_workers=1,
        pin_memory=torch.cuda.is_available(),
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        prefetch_factor=64,
        persistent_workers=False,
    )

    console.print("[bold cyan]正在创建模型...[/bold cyan]")
    # NOTE(review): use_pinyin is only recorded in the config table/file; it is not
    # passed to the model or dataset here — confirm whether it should be wired in.
    model = InputMethodEngine(
        vocab_size=vocab_size,
        pinyin_vocab_size=pinyin_vocab_size,
        dim=dim,
        num_slots=num_slots,
        n_layers=n_layers,
        n_heads=n_heads,
        num_experts=num_experts,
        max_seq_len=max_seq_len,
    )
    console.print(
        f"[green]✓ 模型创建完成,参数量: {sum(p.numel() for p in model.parameters()):,}[/green]"
    )

    # The step budget must cover every requested epoch: the Trainer stops as soon
    # as total_steps is reached, so a single epoch's worth would silently ignore
    # --num-epochs. Keep at least one step per epoch even if batch_size is larger
    # than the nominal dataset size.
    steps_per_epoch = max(1, max_iter_length // batch_size)
    total_steps = steps_per_epoch * num_epochs

    console.print("[bold cyan]正在创建训练器...[/bold cyan]")
    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        total_steps=total_steps,
        output_dir=output_dir,
        num_epochs=num_epochs,
        learning_rate=learning_rate,
        min_learning_rate=min_learning_rate,
        weight_decay=weight_decay,
        warmup_ratio=warmup_ratio,
        label_smoothing=label_smoothing,
        grad_accum_steps=grad_accum_steps,
        clip_grad_norm=clip_grad_norm,
        eval_frequency=eval_frequency,
        save_frequency=save_frequency,
        mixed_precision=mixed_precision,
        use_tensorboard=use_tensorboard,
    )
    console.print("[green]✓ 训练器创建完成[/green]")

    console.print("\n[bold cyan]开始训练...[/bold cyan]")
    console.print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    trainer.train(resume_from=resume_from)

    console.print("[bold green]✓ 训练完成![/bold green]")
    console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    console.print(f"模型和日志保存在: {output_dir}")


@app.command()
def evaluate(
    checkpoint_path: str = typer.Option(..., "--checkpoint", "-c", help="检查点路径"),
    data_path: str = typer.Option(..., "--data-path", "-d", help="数据集路径"),
    batch_size: int = typer.Option(32, "--batch-size", "-b", help="批次大小"),
):
    """
    评估训练好的模型
    """
    # Placeholder command (docstring above is the Typer help text).
    console = Console()
    console.print(f"[bold cyan]评估模型: {checkpoint_path}[/bold cyan]")

    # TODO: implement evaluation:
    # 1. load the checkpoint
    # 2. build the data loader
    # 3. evaluate the model
    console.print("[yellow]评估功能待实现[/yellow]")


@app.command()
def export(
    checkpoint_path: str = typer.Option(..., "--checkpoint", "-c", help="检查点路径"),
    output_path: str = typer.Option(
        "./exported_model.onnx", "--output", "-o", help="输出路径"
    ),
):
    """
    导出模型为ONNX格式
    """
    # Placeholder command (docstring above is the Typer help text).
    console = Console()
    console.print(f"[bold cyan]导出模型到: {output_path}[/bold cyan]")

    # TODO: implement export:
    # 1. load the checkpoint
    # 2. export to ONNX
    console.print("[yellow]导出功能待实现[/yellow]")


if __name__ == "__main__":
    app()