526 lines
19 KiB
Python
526 lines
19 KiB
Python
"""
|
||
交互式训练配置界面
|
||
|
||
提供终端风格的交互式配置界面,用于配置输入法模型训练参数。
|
||
使用Rich库创建美观的终端界面。
|
||
"""
|
||
|
||
import sys
|
||
from pathlib import Path
|
||
from typing import Any, Dict, List, Optional, Union
|
||
|
||
from rich.console import Console
|
||
from rich.panel import Panel
|
||
from rich.table import Table
|
||
from rich.prompt import Prompt, Confirm, IntPrompt, FloatPrompt
|
||
from rich.layout import Layout
|
||
from rich.columns import Columns
|
||
from rich.text import Text
|
||
from rich import box
|
||
|
||
|
||
class TrainingConfigUI:
|
||
"""
|
||
训练配置交互式界面
|
||
|
||
提供终端风格的交互式配置界面,按必需程度分组参数:
|
||
1. 必需参数(必须提供)
|
||
2. 推荐参数(有合理默认值)
|
||
3. 可选参数(高级参数)
|
||
4. 恢复参数(训练恢复相关)
|
||
"""
|
||
|
||
def __init__(self, console: Optional[Console] = None):
|
||
"""
|
||
初始化交互式界面
|
||
|
||
Args:
|
||
console: Rich控制台实例,如果为None则创建新的
|
||
"""
|
||
self.console = console or Console()
|
||
self.config = {}
|
||
|
||
# 参数定义
|
||
self.param_definitions = self._load_param_definitions()
|
||
|
||
# 颜色主题
|
||
self.colors = {
|
||
"primary": "cyan",
|
||
"secondary": "green",
|
||
"accent": "yellow",
|
||
"warning": "red",
|
||
"success": "green",
|
||
"info": "blue",
|
||
}
|
||
|
||
def _load_param_definitions(self) -> Dict[str, List[Dict]]:
|
||
"""
|
||
加载参数定义
|
||
|
||
Returns:
|
||
按组分类的参数定义
|
||
"""
|
||
return {
|
||
"required": [
|
||
{
|
||
"name": "train_data_path",
|
||
"type": "path",
|
||
"prompt": "训练数据集路径",
|
||
"required": True,
|
||
"validate": self._validate_path,
|
||
},
|
||
{
|
||
"name": "eval_data_path",
|
||
"type": "path",
|
||
"prompt": "评估数据集路径",
|
||
"required": True,
|
||
"validate": self._validate_path,
|
||
},
|
||
],
|
||
"recommended": [
|
||
{
|
||
"name": "output_dir",
|
||
"type": "path",
|
||
"prompt": "输出目录",
|
||
"default": "./output",
|
||
"required": False,
|
||
},
|
||
{
|
||
"name": "batch_size",
|
||
"type": "int",
|
||
"prompt": "批次大小",
|
||
"default": 128,
|
||
"required": False,
|
||
"validate": lambda x: x > 0,
|
||
},
|
||
{
|
||
"name": "num_epochs",
|
||
"type": "int",
|
||
"prompt": "训练轮数",
|
||
"default": 10,
|
||
"required": False,
|
||
"validate": lambda x: x > 0,
|
||
},
|
||
{
|
||
"name": "learning_rate",
|
||
"type": "float",
|
||
"prompt": "学习率",
|
||
"default": 1e-5,
|
||
"required": False,
|
||
"validate": lambda x: 1e-9 <= x <= 1,
|
||
},
|
||
{
|
||
"name": "mixed_precision",
|
||
"type": "bool",
|
||
"prompt": "使用混合精度训练",
|
||
"default": True,
|
||
"required": False,
|
||
},
|
||
{
|
||
"name": "compile",
|
||
"type": "bool",
|
||
"prompt": "启用torch.compile优化",
|
||
"default": False,
|
||
"required": False,
|
||
},
|
||
{
|
||
"name": "use_tensorboard",
|
||
"type": "bool",
|
||
"prompt": "使用TensorBoard记录",
|
||
"default": True,
|
||
"required": False,
|
||
},
|
||
],
|
||
"optional": [
|
||
{
|
||
"name": "min_learning_rate",
|
||
"type": "float",
|
||
"prompt": "最小学习率",
|
||
"default": 1e-9,
|
||
"required": False,
|
||
"validate": lambda x: 1e-12 <= x <= 1,
|
||
},
|
||
{
|
||
"name": "warmup_ratio",
|
||
"type": "float",
|
||
"prompt": "热身步数比例",
|
||
"default": 0.1,
|
||
"required": False,
|
||
"validate": lambda x: 0 <= x <= 1,
|
||
},
|
||
{
|
||
"name": "eval_frequency",
|
||
"type": "int",
|
||
"prompt": "评估频率(步数)",
|
||
"default": 500,
|
||
"required": False,
|
||
"validate": lambda x: x > 0,
|
||
},
|
||
{
|
||
"name": "save_frequency",
|
||
"type": "int",
|
||
"prompt": "保存检查点频率(步数)",
|
||
"default": 1000,
|
||
"required": False,
|
||
"validate": lambda x: x > 0,
|
||
},
|
||
{
|
||
"name": "max_iter_length",
|
||
"type": "int",
|
||
"prompt": "数据集大小",
|
||
"default": 1024 * 1024 * 128,
|
||
"required": False,
|
||
"validate": lambda x: x > 0,
|
||
},
|
||
{
|
||
"name": "num_workers",
|
||
"type": "int",
|
||
"prompt": "数据加载worker数量",
|
||
"default": 2,
|
||
"required": False,
|
||
"validate": lambda x: x >= 0,
|
||
},
|
||
{
|
||
"name": "seed",
|
||
"type": "int",
|
||
"prompt": "随机种子",
|
||
"default": 42,
|
||
"required": False,
|
||
},
|
||
],
|
||
"recovery": [
|
||
{
|
||
"name": "auto_resume",
|
||
"type": "bool",
|
||
"prompt": "自动从最新checkpoint恢复",
|
||
"default": True,
|
||
"required": False,
|
||
},
|
||
{
|
||
"name": "resume_from",
|
||
"type": "path",
|
||
"prompt": "从指定checkpoint恢复(可选)",
|
||
"default": None,
|
||
"required": False,
|
||
"validate": self._validate_optional_path,
|
||
},
|
||
{
|
||
"name": "reset_training_state",
|
||
"type": "bool",
|
||
"prompt": "重置训练状态(只加载权重)",
|
||
"default": False,
|
||
"required": False,
|
||
},
|
||
],
|
||
}
|
||
|
||
def _validate_path(self, path: str) -> bool:
|
||
"""
|
||
验证路径
|
||
|
||
Args:
|
||
path: 路径字符串
|
||
|
||
Returns:
|
||
是否有效
|
||
"""
|
||
if not path:
|
||
return False
|
||
path_obj = Path(path)
|
||
return path_obj.exists()
|
||
|
||
def _validate_optional_path(self, path: Optional[str]) -> bool:
|
||
"""
|
||
验证可选路径(可以为None)
|
||
|
||
Args:
|
||
path: 路径字符串或None
|
||
|
||
Returns:
|
||
是否有效
|
||
"""
|
||
if path is None or path == "":
|
||
return True
|
||
return self._validate_path(path)
|
||
|
||
def show_welcome(self):
|
||
"""
|
||
显示欢迎界面
|
||
"""
|
||
self.console.clear()
|
||
|
||
welcome_text = Text()
|
||
welcome_text.append("🚀 ", style="bold yellow")
|
||
welcome_text.append("输入法模型训练系统", style="bold cyan")
|
||
welcome_text.append("\n")
|
||
welcome_text.append("=" * 50, style="dim")
|
||
welcome_text.append("\n\n")
|
||
welcome_text.append("欢迎使用交互式训练配置界面!\n", style="bold green")
|
||
welcome_text.append("请按照提示配置训练参数。\n", style="green")
|
||
welcome_text.append("\n")
|
||
welcome_text.append("按 [Enter] 使用默认值\n", style="dim")
|
||
welcome_text.append("按 [Ctrl+C] 退出\n", style="dim")
|
||
|
||
panel = Panel(
|
||
welcome_text,
|
||
title="🎯 开始配置",
|
||
border_style=self.colors["primary"],
|
||
padding=(1, 2),
|
||
expand=False,
|
||
)
|
||
|
||
self.console.print(panel)
|
||
self.console.print()
|
||
|
||
def ask_param_group(self, group_name: str, group_title: str) -> Dict[str, Any]:
|
||
"""
|
||
询问一组参数
|
||
|
||
Args:
|
||
group_name: 参数组名称
|
||
group_title: 显示标题
|
||
|
||
Returns:
|
||
该组的配置字典
|
||
"""
|
||
group_config = {}
|
||
params = self.param_definitions.get(group_name, [])
|
||
|
||
if not params:
|
||
return group_config
|
||
|
||
# 显示组标题
|
||
self.console.print()
|
||
self.console.print(f"[bold {self.colors['primary']}]{group_title}[/bold {self.colors['primary']}]")
|
||
self.console.print(f"[{self.colors['secondary']}]{'=' * 40}[/{self.colors['secondary']}]")
|
||
|
||
for param in params:
|
||
name = param["name"]
|
||
prompt = param["prompt"]
|
||
param_type = param["type"]
|
||
default = param.get("default")
|
||
required = param.get("required", False)
|
||
validate_func = param.get("validate")
|
||
|
||
# 如果已经有值(从命令行传入),跳过
|
||
if name in self.config and self.config[name] is not None:
|
||
group_config[name] = self.config[name]
|
||
self.console.print(f" {prompt}: [green]{self.config[name]}[/green] (已提供)")
|
||
continue
|
||
|
||
# 询问参数
|
||
while True:
|
||
try:
|
||
if param_type == "bool":
|
||
value = Confirm.ask(
|
||
f" {prompt}",
|
||
default=default if default is not None else False,
|
||
show_default=True,
|
||
)
|
||
elif param_type == "int":
|
||
value = IntPrompt.ask(
|
||
f" {prompt}",
|
||
default=default if default is not None else 0,
|
||
show_default=True,
|
||
)
|
||
elif param_type == "float":
|
||
value = FloatPrompt.ask(
|
||
f" {prompt}",
|
||
default=default if default is not None else 0.0,
|
||
show_default=True,
|
||
)
|
||
else: # "path" or other string types
|
||
default_str = str(default) if default is not None else ""
|
||
value = Prompt.ask(
|
||
f" {prompt}",
|
||
default=default_str,
|
||
show_default=True,
|
||
)
|
||
# 处理空字符串
|
||
if value == "":
|
||
value = None
|
||
|
||
# 验证
|
||
if validate_func and value is not None:
|
||
if not validate_func(value):
|
||
self.console.print(f" [red]无效值,请重新输入[/red]")
|
||
continue
|
||
|
||
# 检查必需参数
|
||
if required and (value is None or value == ""):
|
||
self.console.print(f" [red]此参数为必需参数[/red]")
|
||
continue
|
||
|
||
group_config[name] = value
|
||
break
|
||
|
||
except KeyboardInterrupt:
|
||
self.console.print("\n[yellow]已取消[/yellow]")
|
||
raise
|
||
except Exception as e:
|
||
self.console.print(f" [red]输入错误: {e}[/red]")
|
||
if not required:
|
||
# 宽松处理:使用默认值
|
||
group_config[name] = default
|
||
self.console.print(f" [yellow]使用默认值: {default}[/yellow]")
|
||
break
|
||
|
||
return group_config
|
||
|
||
def show_config_summary(self, config: Dict[str, Any]):
|
||
"""
|
||
显示配置摘要
|
||
|
||
Args:
|
||
config: 完整的配置字典
|
||
"""
|
||
self.console.print()
|
||
self.console.print(f"[bold {self.colors['primary']]}📋 配置摘要[/bold {self.colors['primary']}]")
|
||
self.console.print(f"[{self.colors['secondary']}]{'=' * 40}[/{self.colors['secondary']}]")
|
||
|
||
# 创建表格
|
||
table = Table(show_header=True, header_style=f"bold {self.colors['accent']}", box=box.ROUNDED)
|
||
table.add_column("参数", style=self.colors["primary"])
|
||
table.add_column("值", style=self.colors["success"])
|
||
table.add_column("类型", style=self.colors["info"])
|
||
|
||
# 按组添加参数
|
||
groups = [
|
||
("必需参数", "required"),
|
||
("推荐参数", "recommended"),
|
||
("可选参数", "optional"),
|
||
("恢复参数", "recovery"),
|
||
]
|
||
|
||
for group_title, group_name in groups:
|
||
params = self.param_definitions.get(group_name, [])
|
||
if params:
|
||
table.add_row(f"[bold]{group_title}[/bold]", "", "", style="bold")
|
||
|
||
for param in params:
|
||
name = param["name"]
|
||
if name in config:
|
||
value = config[name]
|
||
param_type = param["type"]
|
||
|
||
# 格式化值
|
||
if value is None:
|
||
value_str = "[dim]None[/dim]"
|
||
elif isinstance(value, bool):
|
||
value_str = "是" if value else "否"
|
||
else:
|
||
value_str = str(value)
|
||
|
||
table.add_row(f" {param['prompt']}", value_str, param_type)
|
||
|
||
self.console.print(table)
|
||
self.console.print()
|
||
|
||
def run(self, initial_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||
"""
|
||
运行完整的交互式配置流程
|
||
|
||
Args:
|
||
initial_config: 初始配置(如从命令行传入的参数)
|
||
|
||
Returns:
|
||
完整的配置字典
|
||
"""
|
||
# 初始化配置
|
||
self.config = initial_config or {}
|
||
|
||
try:
|
||
# 显示欢迎界面
|
||
self.show_welcome()
|
||
|
||
# 询问必需参数
|
||
self.console.print()
|
||
if Confirm.ask("[bold cyan]是否配置必需参数?[/bold cyan]", default=True):
|
||
required_config = self.ask_param_group("required", "📁 必需参数")
|
||
self.config.update(required_config)
|
||
|
||
# 检查必需参数是否已提供
|
||
required_params = self.param_definitions["required"]
|
||
for param in required_params:
|
||
if param["name"] not in self.config or self.config[param["name"]] is None:
|
||
self.console.print(f"[red]错误: 必需参数 '{param['prompt']}' 未提供[/red]")
|
||
self.console.print("[yellow]请重新配置必需参数[/yellow]")
|
||
required_config = self.ask_param_group("required", "📁 必需参数")
|
||
self.config.update(required_config)
|
||
|
||
# 询问推荐参数
|
||
self.console.print()
|
||
if Confirm.ask("[bold cyan]是否配置推荐参数?[/bold cyan]", default=True):
|
||
recommended_config = self.ask_param_group("recommended", "⚙️ 推荐参数")
|
||
self.config.update(recommended_config)
|
||
|
||
# 询问可选参数
|
||
self.console.print()
|
||
if Confirm.ask("[bold cyan]是否配置可选参数(高级)?[/bold cyan]", default=False):
|
||
optional_config = self.ask_param_group("optional", "🎛️ 可选参数")
|
||
self.config.update(optional_config)
|
||
|
||
# 询问恢复参数
|
||
self.console.print()
|
||
if Confirm.ask("[bold cyan]是否配置恢复参数?[/bold cyan]", default=True):
|
||
recovery_config = self.ask_param_group("recovery", "🔄 恢复参数")
|
||
self.config.update(recovery_config)
|
||
|
||
# 显示配置摘要
|
||
self.show_config_summary(self.config)
|
||
|
||
# 确认配置
|
||
self.console.print()
|
||
if not Confirm.ask("[bold green]是否使用此配置开始训练?[/bold green]", default=True):
|
||
self.console.print("[yellow]配置已取消[/yellow]")
|
||
return {}
|
||
|
||
return self.config
|
||
|
||
except KeyboardInterrupt:
|
||
self.console.print("\n[yellow]配置已取消[/yellow]")
|
||
return {}
|
||
except Exception as e:
|
||
self.console.print(f"[red]配置过程中出现错误: {e}[/red]")
|
||
return {}
|
||
|
||
|
||
def get_interactive_config(provided_params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||
"""
|
||
获取交互式配置
|
||
|
||
Args:
|
||
provided_params: 已提供的参数(如从命令行传入)
|
||
|
||
Returns:
|
||
完整的配置字典
|
||
"""
|
||
console = Console()
|
||
ui = TrainingConfigUI(console)
|
||
|
||
# 过滤掉None值
|
||
if provided_params:
|
||
provided_params = {k: v for k, v in provided_params.items() if v is not None}
|
||
|
||
config = ui.run(provided_params)
|
||
return config
|
||
|
||
|
||
if __name__ == "__main__":
|
||
# 测试交互式界面
|
||
console = Console()
|
||
console.print("[bold]测试交互式配置界面[/bold]")
|
||
|
||
# 模拟已提供的参数
|
||
test_params = {
|
||
"train_data_path": "./data/train.txt",
|
||
"eval_data_path": "./data/eval.txt",
|
||
}
|
||
|
||
config = get_interactive_config(test_params)
|
||
|
||
if config:
|
||
console.print("\n[green]✓ 配置完成[/green]")
|
||
console.print(f"配置参数: {config}")
|
||
else:
|
||
console.print("\n[yellow]配置已取消[/yellow]") |