""" 交互式训练配置界面 提供终端风格的交互式配置界面,用于配置输入法模型训练参数。 使用Rich库创建美观的终端界面。 """ import sys from pathlib import Path from typing import Any, Dict, List, Optional, Union from rich.console import Console from rich.panel import Panel from rich.table import Table from rich.prompt import Prompt, Confirm, IntPrompt, FloatPrompt from rich.layout import Layout from rich.columns import Columns from rich.text import Text from rich import box class TrainingConfigUI: """ 训练配置交互式界面 提供终端风格的交互式配置界面,按必需程度分组参数: 1. 必需参数(必须提供) 2. 推荐参数(有合理默认值) 3. 可选参数(高级参数) 4. 恢复参数(训练恢复相关) """ def __init__(self, console: Optional[Console] = None): """ 初始化交互式界面 Args: console: Rich控制台实例,如果为None则创建新的 """ self.console = console or Console() self.config = {} # 参数定义 self.param_definitions = self._load_param_definitions() # 颜色主题 self.colors = { "primary": "cyan", "secondary": "green", "accent": "yellow", "warning": "red", "success": "green", "info": "blue", } def _load_param_definitions(self) -> Dict[str, List[Dict]]: """ 加载参数定义 Returns: 按组分类的参数定义 """ return { "required": [ { "name": "train_data_path", "type": "path", "prompt": "训练数据集路径", "required": True, "validate": self._validate_path, }, { "name": "eval_data_path", "type": "path", "prompt": "评估数据集路径", "required": True, "validate": self._validate_path, }, ], "recommended": [ { "name": "output_dir", "type": "path", "prompt": "输出目录", "default": "./output", "required": False, }, { "name": "batch_size", "type": "int", "prompt": "批次大小", "default": 128, "required": False, "validate": lambda x: x > 0, }, { "name": "num_epochs", "type": "int", "prompt": "训练轮数", "default": 10, "required": False, "validate": lambda x: x > 0, }, { "name": "learning_rate", "type": "float", "prompt": "学习率", "default": 1e-5, "required": False, "validate": lambda x: 1e-9 <= x <= 1, }, { "name": "mixed_precision", "type": "bool", "prompt": "使用混合精度训练", "default": True, "required": False, }, { "name": "compile", "type": "bool", "prompt": "启用torch.compile优化", "default": False, "required": False, }, { "name": "use_tensorboard", "type": "bool", "prompt": "使用TensorBoard记录", "default": True, "required": False, }, ], "optional": [ { "name": "min_learning_rate", "type": "float", "prompt": "最小学习率", "default": 1e-9, "required": False, "validate": lambda x: 1e-12 <= x <= 1, }, { "name": "warmup_ratio", "type": "float", "prompt": "热身步数比例", "default": 0.1, "required": False, "validate": lambda x: 0 <= x <= 1, }, { "name": "eval_frequency", "type": "int", "prompt": "评估频率(步数)", "default": 500, "required": False, "validate": lambda x: x > 0, }, { "name": "save_frequency", "type": "int", "prompt": "保存检查点频率(步数)", "default": 1000, "required": False, "validate": lambda x: x > 0, }, { "name": "max_iter_length", "type": "int", "prompt": "数据集大小", "default": 1024 * 1024 * 128, "required": False, "validate": lambda x: x > 0, }, { "name": "num_workers", "type": "int", "prompt": "数据加载worker数量", "default": 2, "required": False, "validate": lambda x: x >= 0, }, { "name": "seed", "type": "int", "prompt": "随机种子", "default": 42, "required": False, }, ], "recovery": [ { "name": "auto_resume", "type": "bool", "prompt": "自动从最新checkpoint恢复", "default": True, "required": False, }, { "name": "resume_from", "type": "path", "prompt": "从指定checkpoint恢复(可选)", "default": None, "required": False, "validate": self._validate_optional_path, }, { "name": "reset_training_state", "type": "bool", "prompt": "重置训练状态(只加载权重)", "default": False, "required": False, }, ], } def _validate_path(self, path: str) -> bool: """ 验证路径 Args: path: 路径字符串 Returns: 是否有效 """ if not path: return False path_obj = Path(path) return path_obj.exists() def _validate_optional_path(self, path: Optional[str]) -> bool: """ 验证可选路径(可以为None) Args: path: 路径字符串或None Returns: 是否有效 """ if path is None or path == "": return True return self._validate_path(path) def show_welcome(self): """ 显示欢迎界面 """ self.console.clear() welcome_text = Text() welcome_text.append("🚀 ", style="bold yellow") welcome_text.append("输入法模型训练系统", style="bold cyan") welcome_text.append("\n") welcome_text.append("=" * 50, style="dim") welcome_text.append("\n\n") welcome_text.append("欢迎使用交互式训练配置界面!\n", style="bold green") welcome_text.append("请按照提示配置训练参数。\n", style="green") welcome_text.append("\n") welcome_text.append("按 [Enter] 使用默认值\n", style="dim") welcome_text.append("按 [Ctrl+C] 退出\n", style="dim") panel = Panel( welcome_text, title="🎯 开始配置", border_style=self.colors["primary"], padding=(1, 2), expand=False, ) self.console.print(panel) self.console.print() def ask_param_group(self, group_name: str, group_title: str) -> Dict[str, Any]: """ 询问一组参数 Args: group_name: 参数组名称 group_title: 显示标题 Returns: 该组的配置字典 """ group_config = {} params = self.param_definitions.get(group_name, []) if not params: return group_config # 显示组标题 self.console.print() self.console.print(f"[bold {self.colors['primary']}]{group_title}[/bold {self.colors['primary']}]") self.console.print(f"[{self.colors['secondary']}]{'=' * 40}[/{self.colors['secondary']}]") for param in params: name = param["name"] prompt = param["prompt"] param_type = param["type"] default = param.get("default") required = param.get("required", False) validate_func = param.get("validate") # 如果已经有值(从命令行传入),跳过 if name in self.config and self.config[name] is not None: group_config[name] = self.config[name] self.console.print(f" {prompt}: [green]{self.config[name]}[/green] (已提供)") continue # 询问参数 while True: try: if param_type == "bool": value = Confirm.ask( f" {prompt}", default=default if default is not None else False, show_default=True, ) elif param_type == "int": value = IntPrompt.ask( f" {prompt}", default=default if default is not None else 0, show_default=True, ) elif param_type == "float": value = FloatPrompt.ask( f" {prompt}", default=default if default is not None else 0.0, show_default=True, ) else: # "path" or other string types default_str = str(default) if default is not None else "" value = Prompt.ask( f" {prompt}", default=default_str, show_default=True, ) # 处理空字符串 if value == "": value = None # 验证 if validate_func and value is not None: if not validate_func(value): self.console.print(f" [red]无效值,请重新输入[/red]") continue # 检查必需参数 if required and (value is None or value == ""): self.console.print(f" [red]此参数为必需参数[/red]") continue group_config[name] = value break except KeyboardInterrupt: self.console.print("\n[yellow]已取消[/yellow]") raise except Exception as e: self.console.print(f" [red]输入错误: {e}[/red]") if not required: # 宽松处理:使用默认值 group_config[name] = default self.console.print(f" [yellow]使用默认值: {default}[/yellow]") break return group_config def show_config_summary(self, config: Dict[str, Any]): """ 显示配置摘要 Args: config: 完整的配置字典 """ self.console.print() self.console.print(f"[bold {self.colors['primary']]}📋 配置摘要[/bold {self.colors['primary']}]") self.console.print(f"[{self.colors['secondary']}]{'=' * 40}[/{self.colors['secondary']}]") # 创建表格 table = Table(show_header=True, header_style=f"bold {self.colors['accent']}", box=box.ROUNDED) table.add_column("参数", style=self.colors["primary"]) table.add_column("值", style=self.colors["success"]) table.add_column("类型", style=self.colors["info"]) # 按组添加参数 groups = [ ("必需参数", "required"), ("推荐参数", "recommended"), ("可选参数", "optional"), ("恢复参数", "recovery"), ] for group_title, group_name in groups: params = self.param_definitions.get(group_name, []) if params: table.add_row(f"[bold]{group_title}[/bold]", "", "", style="bold") for param in params: name = param["name"] if name in config: value = config[name] param_type = param["type"] # 格式化值 if value is None: value_str = "[dim]None[/dim]" elif isinstance(value, bool): value_str = "是" if value else "否" else: value_str = str(value) table.add_row(f" {param['prompt']}", value_str, param_type) self.console.print(table) self.console.print() def run(self, initial_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: """ 运行完整的交互式配置流程 Args: initial_config: 初始配置(如从命令行传入的参数) Returns: 完整的配置字典 """ # 初始化配置 self.config = initial_config or {} try: # 显示欢迎界面 self.show_welcome() # 询问必需参数 self.console.print() if Confirm.ask("[bold cyan]是否配置必需参数?[/bold cyan]", default=True): required_config = self.ask_param_group("required", "📁 必需参数") self.config.update(required_config) # 检查必需参数是否已提供 required_params = self.param_definitions["required"] for param in required_params: if param["name"] not in self.config or self.config[param["name"]] is None: self.console.print(f"[red]错误: 必需参数 '{param['prompt']}' 未提供[/red]") self.console.print("[yellow]请重新配置必需参数[/yellow]") required_config = self.ask_param_group("required", "📁 必需参数") self.config.update(required_config) # 询问推荐参数 self.console.print() if Confirm.ask("[bold cyan]是否配置推荐参数?[/bold cyan]", default=True): recommended_config = self.ask_param_group("recommended", "⚙️ 推荐参数") self.config.update(recommended_config) # 询问可选参数 self.console.print() if Confirm.ask("[bold cyan]是否配置可选参数(高级)?[/bold cyan]", default=False): optional_config = self.ask_param_group("optional", "🎛️ 可选参数") self.config.update(optional_config) # 询问恢复参数 self.console.print() if Confirm.ask("[bold cyan]是否配置恢复参数?[/bold cyan]", default=True): recovery_config = self.ask_param_group("recovery", "🔄 恢复参数") self.config.update(recovery_config) # 显示配置摘要 self.show_config_summary(self.config) # 确认配置 self.console.print() if not Confirm.ask("[bold green]是否使用此配置开始训练?[/bold green]", default=True): self.console.print("[yellow]配置已取消[/yellow]") return {} return self.config except KeyboardInterrupt: self.console.print("\n[yellow]配置已取消[/yellow]") return {} except Exception as e: self.console.print(f"[red]配置过程中出现错误: {e}[/red]") return {} def get_interactive_config(provided_params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: """ 获取交互式配置 Args: provided_params: 已提供的参数(如从命令行传入) Returns: 完整的配置字典 """ console = Console() ui = TrainingConfigUI(console) # 过滤掉None值 if provided_params: provided_params = {k: v for k, v in provided_params.items() if v is not None} config = ui.run(provided_params) return config if __name__ == "__main__": # 测试交互式界面 console = Console() console.print("[bold]测试交互式配置界面[/bold]") # 模拟已提供的参数 test_params = { "train_data_path": "./data/train.txt", "eval_data_path": "./data/eval.txt", } config = get_interactive_config(test_params) if config: console.print("\n[green]✓ 配置完成[/green]") console.print(f"配置参数: {config}") else: console.print("\n[yellow]配置已取消[/yellow]")