feat(cli): 升级 CLI 为 Typer + Rich 版本,新增配置命令与交互式向导
This commit is contained in:
parent
98b75a732c
commit
5129fc1bcf
|
|
@ -1,13 +1,34 @@
|
||||||
"""命令行入口"""
|
"""命令行入口 - Typer + Rich 版本"""
|
||||||
import fire
|
import typer
|
||||||
|
from typing import Optional
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.prompt import Prompt, Confirm
|
from rich.prompt import Prompt, Confirm
|
||||||
from typing import Optional
|
from rich.table import Table
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich import print as rprint
|
||||||
|
|
||||||
from .config import ConfigManager
|
from .config import ConfigManager
|
||||||
from .llm import LLMClient
|
from .llm import LLMClient
|
||||||
|
|
||||||
|
# 创建 Typer 应用
|
||||||
|
app = typer.Typer(
|
||||||
|
help="AutoCommit - 自动生成提交消息工具",
|
||||||
|
context_settings={"help_option_names": ["-h", "--help"]}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建 Rich 控制台
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
|
# 全局 CLI 实例(类似单例模式)
|
||||||
|
_cli_instance = None
|
||||||
|
|
||||||
|
def get_cli() -> "AutoCommitCLI":
|
||||||
|
"""获取 CLI 单例实例"""
|
||||||
|
global _cli_instance
|
||||||
|
if _cli_instance is None:
|
||||||
|
_cli_instance = AutoCommitCLI()
|
||||||
|
return _cli_instance
|
||||||
|
|
||||||
class AutoCommitCLI:
|
class AutoCommitCLI:
|
||||||
"""AutoCommit命令行接口"""
|
"""AutoCommit命令行接口"""
|
||||||
|
|
||||||
|
|
@ -15,140 +36,317 @@ class AutoCommitCLI:
|
||||||
self.config_manager = ConfigManager()
|
self.config_manager = ConfigManager()
|
||||||
self.llm_client = LLMClient(self.config_manager)
|
self.llm_client = LLMClient(self.config_manager)
|
||||||
|
|
||||||
def config(
|
def _create_config_table(self, config_data: dict) -> Table:
|
||||||
self,
|
"""创建配置表格 - 美观的 Rich 表格"""
|
||||||
api_key: Optional[str] = None,
|
table = Table(title="LLM 配置", show_header=True, header_style="bold cyan")
|
||||||
base_url: Optional[str] = None,
|
table.add_column("配置项", style="cyan", width=20)
|
||||||
model: Optional[str] = None,
|
table.add_column("当前值", style="green", width=40)
|
||||||
timeout: Optional[int] = None,
|
|
||||||
show: bool = False,
|
|
||||||
clear: bool = False
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
配置LLM参数
|
|
||||||
|
|
||||||
Args:
|
for key, value in config_data.items():
|
||||||
api_key: LLM API密钥
|
if key == "api_key" and value:
|
||||||
base_url: API基础URL (默认: https://api.openai.com/v1)
|
display_value = "[yellow]已设置[/yellow] (安全隐藏)"
|
||||||
model: 模型名称 (默认: gpt-3.5-turbo)
|
elif value:
|
||||||
timeout: 请求超时时间(秒)
|
display_value = f"[green]{value}[/green]"
|
||||||
show: 显示当前配置
|
else:
|
||||||
clear: 清除所有配置
|
display_value = "[dim]未设置[/dim]"
|
||||||
"""
|
table.add_row(key.replace("_", " ").title(), display_value)
|
||||||
if clear:
|
|
||||||
if Confirm.ask("确定要清除所有配置吗?"):
|
|
||||||
self.config_manager.clear_config()
|
|
||||||
console.print("[green]配置已清除[/green]")
|
|
||||||
return
|
|
||||||
|
|
||||||
if show:
|
return table
|
||||||
config_data = self.config_manager.show_config()
|
|
||||||
console.print("[bold cyan]当前配置:[/bold cyan]")
|
|
||||||
|
|
||||||
|
# 配置命令组
|
||||||
|
config_app = typer.Typer(name="config", help="配置管理")
|
||||||
|
app.add_typer(config_app)
|
||||||
|
|
||||||
|
@config_app.command("show", help="显示当前配置")
|
||||||
|
def config_show():
|
||||||
|
"""显示当前配置"""
|
||||||
|
cli = get_cli()
|
||||||
|
config_data = cli.config_manager.show_config()
|
||||||
llm_config = config_data.get("llm", {})
|
llm_config = config_data.get("llm", {})
|
||||||
table = self._create_config_table(llm_config)
|
|
||||||
console.print(table)
|
|
||||||
return
|
|
||||||
|
|
||||||
# 交互式配置
|
console.print("[bold cyan]📋 当前配置[/bold cyan]")
|
||||||
if not any([api_key, base_url, model, timeout is not None]):
|
console.print(cli._create_config_table(llm_config))
|
||||||
console.print("[bold yellow]交互式配置模式[/bold yellow]")
|
|
||||||
console.print("(直接按Enter跳过当前项)\n")
|
|
||||||
|
|
||||||
current_config = self.config_manager.show_config()["llm"]
|
|
||||||
|
|
||||||
api_key = Prompt.ask(
|
@config_app.command("clear", help="清除所有配置")
|
||||||
f"API密钥 (当前: {'*' * 8 if current_config.get('api_key') else '未设置'})",
|
def config_clear():
|
||||||
password=True,
|
"""清除配置"""
|
||||||
default=""
|
cli = get_cli()
|
||||||
) or None
|
if Confirm.ask("[yellow]⚠️ 确定要清除所有配置吗?[/yellow]"):
|
||||||
|
cli.config_manager.clear_config()
|
||||||
|
console.print("[green]✅ 配置已清除[/green]")
|
||||||
|
else:
|
||||||
|
console.print("[dim]操作已取消[/dim]")
|
||||||
|
|
||||||
base_url = Prompt.ask(
|
@config_app.command("set", help="设置配置参数")
|
||||||
f"基础URL (当前: {current_config.get('base_url', 'https://api.openai.com/v1')})",
|
def config_set(
|
||||||
default=current_config.get('base_url', 'https://api.openai.com/v1')
|
api_key: Optional[str] = typer.Option(None, "--api-key", "-k", help="LLM API密钥"),
|
||||||
)
|
base_url: Optional[str] = typer.Option(None, "--base-url", "-u", help="API基础URL"),
|
||||||
|
model: Optional[str] = typer.Option(None, "--model", "-m", help="模型名称"),
|
||||||
model = Prompt.ask(
|
timeout: Optional[int] = typer.Option(None, "--timeout", "-t", help="请求超时时间(秒)")
|
||||||
f"模型名称 (当前: {current_config.get('model', 'gpt-3.5-turbo')})",
|
):
|
||||||
default=current_config.get('model', 'gpt-3.5-turbo')
|
"""设置配置参数"""
|
||||||
)
|
cli = get_cli()
|
||||||
|
|
||||||
timeout_input = Prompt.ask(
|
|
||||||
f"超时时间(秒) (当前: {current_config.get('timeout', 30)})",
|
|
||||||
default=str(current_config.get('timeout', 30))
|
|
||||||
)
|
|
||||||
timeout = int(timeout_input) if timeout_input else None
|
|
||||||
|
|
||||||
# 更新配置
|
# 更新配置
|
||||||
updated_config = self.config_manager.update_llm_config(
|
updated_config = cli.config_manager.update_llm_config(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
base_url=base_url,
|
base_url=base_url,
|
||||||
model=model,
|
model=model,
|
||||||
timeout=timeout
|
timeout=timeout
|
||||||
)
|
)
|
||||||
|
|
||||||
console.print("[green]配置已更新[/green]")
|
console.print("[green]✅ 配置已更新[/green]")
|
||||||
table = self._create_config_table(updated_config.model_dump())
|
console.print(cli._create_config_table(updated_config.model_dump()))
|
||||||
console.print(table)
|
|
||||||
|
@config_app.command("interactive", help="交互式配置向导")
|
||||||
|
def config_interactive():
|
||||||
|
"""交互式配置向导"""
|
||||||
|
cli = get_cli()
|
||||||
|
|
||||||
|
console.print("[bold yellow]🎛️ 交互式配置向导[/bold yellow]")
|
||||||
|
console.print("[dim](直接按Enter跳过当前项)[/dim]\n")
|
||||||
|
|
||||||
|
current_config = cli.config_manager.show_config()["llm"]
|
||||||
|
|
||||||
|
api_key = Prompt.ask(
|
||||||
|
f"[cyan]API密钥[/cyan] (当前: {'[yellow]已设置[/yellow]' if current_config.get('api_key') else '[dim]未设置[/dim]'})",
|
||||||
|
password=True,
|
||||||
|
default=""
|
||||||
|
) or None
|
||||||
|
|
||||||
|
base_url = Prompt.ask(
|
||||||
|
f"[cyan]基础URL[/cyan] (当前: {current_config.get('base_url', 'https://api.openai.com/v1')})",
|
||||||
|
default=current_config.get('base_url', 'https://api.openai.com/v1')
|
||||||
|
)
|
||||||
|
|
||||||
|
model = Prompt.ask(
|
||||||
|
f"[cyan]模型名称[/cyan] (当前: {current_config.get('model', 'gpt-3.5-turbo')})",
|
||||||
|
default=current_config.get('model', 'gpt-3.5-turbo')
|
||||||
|
)
|
||||||
|
|
||||||
|
timeout_input = Prompt.ask(
|
||||||
|
f"[cyan]超时时间(秒)[/cyan] (当前: {current_config.get('timeout', 30)})",
|
||||||
|
default=str(current_config.get('timeout', 30))
|
||||||
|
)
|
||||||
|
timeout = int(timeout_input) if timeout_input else None
|
||||||
|
|
||||||
|
# 更新配置
|
||||||
|
updated_config = cli.config_manager.update_llm_config(
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
model=model,
|
||||||
|
timeout=timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
console.print("\n[green]✅ 配置已更新[/green]")
|
||||||
|
console.print(cli._create_config_table(updated_config.model_dump()))
|
||||||
|
|
||||||
# 询问是否测试连接
|
# 询问是否测试连接
|
||||||
if Confirm.ask("是否测试LLM连接?"):
|
if Confirm.ask("[cyan]是否测试LLM连接?[/cyan]"):
|
||||||
self.check()
|
check(test=True)
|
||||||
|
|
||||||
def _create_config_table(self, config_data: dict):
|
@app.command(help="检查LLM可用性")
|
||||||
"""创建配置表格"""
|
def check(
|
||||||
from rich.table import Table
|
test: bool = typer.Option(False, "--test", "-t", help="发送测试消息")
|
||||||
table = Table(show_header=False, box=None)
|
):
|
||||||
table.add_column("项目", style="cyan")
|
"""检查LLM可用性"""
|
||||||
table.add_column("值", style="green")
|
cli = get_cli()
|
||||||
|
|
||||||
for key, value in config_data.items():
|
console.print("[bold cyan]🔍 检查LLM可用性...[/bold cyan]\n")
|
||||||
if key == "api_key" and value:
|
|
||||||
table.add_row(key, "*" * 8)
|
|
||||||
else:
|
|
||||||
table.add_row(key, str(value))
|
|
||||||
|
|
||||||
return table
|
|
||||||
|
|
||||||
def check(self, test: bool = False):
|
|
||||||
"""
|
|
||||||
检查LLM可用性
|
|
||||||
|
|
||||||
Args:
|
|
||||||
test: 是否发送测试消息
|
|
||||||
"""
|
|
||||||
console.print("[bold cyan]检查LLM可用性...[/bold cyan]")
|
|
||||||
|
|
||||||
# 检查配置
|
# 检查配置
|
||||||
config = self.config_manager.load_config()
|
config = cli.config_manager.load_config()
|
||||||
if not config.llm.api_key:
|
if not config.llm.api_key:
|
||||||
console.print("[red]错误: API密钥未设置,请先运行 `autocommit config`[/red]")
|
console.print("[red]❌ 错误: API密钥未设置[/red]")
|
||||||
return
|
console.print("[yellow]💡 提示: 请先运行 `autocommit config interactive` 进行配置[/yellow]")
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
# 检查可用性
|
# 检查可用性
|
||||||
result = self.llm_client.check_availability()
|
with console.status("[bold green]正在检查LLM连接...[/bold green]"):
|
||||||
self.llm_client.display_check_result(result)
|
result = cli.llm_client.check_availability()
|
||||||
|
|
||||||
|
cli.llm_client.display_check_result(result)
|
||||||
|
|
||||||
# 测试补全
|
# 测试补全
|
||||||
if test and result["available"]:
|
if test and result.get("available"):
|
||||||
console.print("\n[bold cyan]发送测试消息...[/bold cyan]")
|
console.print("\n[bold cyan]📤 发送测试消息...[/bold cyan]")
|
||||||
response = self.llm_client.test_completion()
|
with console.status("[bold green]正在等待LLM回复...[/bold green]"):
|
||||||
|
response = cli.llm_client.test_generation()
|
||||||
|
|
||||||
if response:
|
if response:
|
||||||
console.print(Panel(response, title="测试回复", border_style="green"))
|
console.print(Panel(
|
||||||
|
response,
|
||||||
|
title="🤖 LLM 测试回复",
|
||||||
|
border_style="green",
|
||||||
|
padding=(1, 2)
|
||||||
|
))
|
||||||
|
|
||||||
def run(self):
|
def _run():
|
||||||
"""运行主程序(示例)"""
|
"""运行主程序 - 内部函数"""
|
||||||
console.print("[bold green]AutoCommit 启动[/bold green]")
|
cli = get_cli()
|
||||||
# 这里添加你的主要逻辑
|
|
||||||
|
|
||||||
def version(self):
|
console.print(Panel(
|
||||||
|
"[bold green]🚀 AutoCommit 启动[/bold green]\n"
|
||||||
|
"[dim]自动生成 Git 提交消息并提交暂存的文件[/dim]",
|
||||||
|
border_style="green",
|
||||||
|
padding=(1, 4)
|
||||||
|
))
|
||||||
|
|
||||||
|
# 检查配置
|
||||||
|
config = cli.config_manager.load_config()
|
||||||
|
if not config.llm.api_key:
|
||||||
|
console.print("[red]❌ 错误: API密钥未设置[/red]")
|
||||||
|
console.print("[yellow]💡 提示: 请先运行 `autocommit config interactive` 进行配置[/yellow]")
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
# 初始化 GitHandler
|
||||||
|
try:
|
||||||
|
from .git import GitHandler
|
||||||
|
git_handler = GitHandler(cli.llm_client)
|
||||||
|
except ImportError:
|
||||||
|
console.print("[red]❌ 错误: 无法导入 Git 模块[/red]")
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]❌ 错误: 初始化 GitHandler 失败: {e}[/red]")
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
# 执行自动提交
|
||||||
|
try:
|
||||||
|
console.print("[cyan]📤 正在检查暂存的文件...[/cyan]")
|
||||||
|
git_handler.auto_commit()
|
||||||
|
|
||||||
|
# 显示成功信息
|
||||||
|
console.print(Panel(
|
||||||
|
"[bold green]✅ 自动提交完成![/bold green]\n"
|
||||||
|
"[dim]已成功生成提交消息并提交暂存的更改[/dim]",
|
||||||
|
border_style="green",
|
||||||
|
padding=(1, 4)
|
||||||
|
))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]❌ 自动提交失败: {e}[/red]")
|
||||||
|
console.print("[yellow]💡 提示: 确保您有暂存的文件并且处于 Git 仓库中[/yellow]")
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command(help="测试生成提交消息(不实际提交)")
|
||||||
|
def test_commit():
|
||||||
|
"""测试生成提交消息而不实际提交"""
|
||||||
|
cli = get_cli()
|
||||||
|
|
||||||
|
console.print(Panel(
|
||||||
|
"[bold yellow]🧪 测试模式启动[/bold yellow]\n"
|
||||||
|
"[dim]将生成提交消息但不会实际提交到 Git[/dim]",
|
||||||
|
border_style="yellow",
|
||||||
|
padding=(1, 4)
|
||||||
|
))
|
||||||
|
|
||||||
|
# 检查配置
|
||||||
|
config = cli.config_manager.load_config()
|
||||||
|
if not config.llm.api_key:
|
||||||
|
console.print("[red]❌ 错误: API密钥未设置[/red]")
|
||||||
|
console.print("[yellow]💡 提示: 请先运行 `autocommit config interactive` 进行配置[/yellow]")
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .git import GitHandler
|
||||||
|
git_handler = GitHandler(cli.llm_client)
|
||||||
|
|
||||||
|
# 获取差异
|
||||||
|
diff = git_handler.get_diff()
|
||||||
|
if not diff:
|
||||||
|
console.print("[yellow]⚠️ 没有暂存的更改[/yellow]")
|
||||||
|
console.print("[cyan]使用示例差异进行测试...[/cyan]")
|
||||||
|
|
||||||
|
# 使用示例差异
|
||||||
|
example_diff = """diff --git a/example.py b/example.py
|
||||||
|
index abc123..def456 100644
|
||||||
|
--- a/example.py
|
||||||
|
+++ b/example.py
|
||||||
|
@@ -1,5 +1,5 @@
|
||||||
|
def hello_world():
|
||||||
|
- print("Hello, World!")
|
||||||
|
+ print("Hello, OpenAI!")
|
||||||
|
|
||||||
|
def add(a, b):
|
||||||
|
return a + b"""
|
||||||
|
|
||||||
|
# 使用 LLMClient 直接测试生成
|
||||||
|
with console.status("[bold cyan]🤖 正在生成提交消息...[/bold cyan]"):
|
||||||
|
message = cli.llm_client.generate_commit_message(example_diff)
|
||||||
|
|
||||||
|
if message:
|
||||||
|
console.print(Panel(
|
||||||
|
f"[bold green]如果提交,消息将是:[/bold green]\n\n{message}",
|
||||||
|
title="📝 测试生成的提交消息",
|
||||||
|
border_style="green",
|
||||||
|
padding=(1, 2)
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
console.print("[red]❌ 无法生成提交消息[/red]")
|
||||||
|
else:
|
||||||
|
console.print("[green]✅ 检测到暂存的更改[/green]")
|
||||||
|
console.print("[cyan]正在分析差异...[/cyan]")
|
||||||
|
|
||||||
|
# 使用 LLMClient 生成消息但不提交
|
||||||
|
with console.status("[bold cyan]🤖 正在生成提交消息...[/bold cyan]"):
|
||||||
|
message = cli.llm_client.generate_commit_message(diff)
|
||||||
|
|
||||||
|
if message:
|
||||||
|
console.print(Panel(
|
||||||
|
f"[bold green]如果提交,消息将是:[/bold green]\n\n{message}",
|
||||||
|
title="📝 测试生成的提交消息",
|
||||||
|
border_style="green",
|
||||||
|
padding=(1, 2)
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
console.print("[red]❌ 无法生成提交消息[/red]")
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
# 询问用户是否要实际提交
|
||||||
|
if Confirm.ask("[yellow]是否使用此消息进行实际提交?[/yellow]"):
|
||||||
|
git_handler.commit_with_message(message)
|
||||||
|
console.print("[green]✅ 已提交更改[/green]")
|
||||||
|
else:
|
||||||
|
console.print("[dim]测试完成,未实际提交[/dim]")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]❌ 测试失败: {e}[/red]")
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command(help="显示版本信息")
|
||||||
|
def version():
|
||||||
"""显示版本信息"""
|
"""显示版本信息"""
|
||||||
console.print("[bold cyan]AutoCommit v0.1.0[/bold cyan]")
|
console.print(Panel(
|
||||||
|
"[bold cyan]AutoCommit v0.1.0[/bold cyan]\n"
|
||||||
|
"[dim]基于 Typer + Rich 的现代化 CLI[/dim]",
|
||||||
|
border_style="cyan",
|
||||||
|
padding=(1, 4)
|
||||||
|
))
|
||||||
|
|
||||||
|
# 主回调函数 - 当没有子命令时执行
|
||||||
|
@app.callback(invoke_without_command=True)
|
||||||
|
def main_callback(ctx: typer.Context):
|
||||||
|
"""
|
||||||
|
AutoCommit - 自动生成 Git 提交消息工具
|
||||||
|
|
||||||
|
直接运行 autocommit 会自动分析暂存的 Git 更改并生成提交消息。
|
||||||
|
使用 autocommit config 进行配置管理。
|
||||||
|
"""
|
||||||
|
if ctx.invoked_subcommand is None:
|
||||||
|
_run()
|
||||||
|
|
||||||
|
# 主入口函数
|
||||||
def main():
|
def main():
|
||||||
"""主函数"""
|
"""主函数 - Typer 入口点"""
|
||||||
cli = AutoCommitCLI()
|
# 设置彩色的错误消息
|
||||||
fire.Fire(cli)
|
typer.secho = console.print
|
||||||
|
|
||||||
|
# 运行应用
|
||||||
|
app()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
||||||
|
|
@ -22,8 +22,8 @@ class LLMConfig(BaseModel):
|
||||||
timeout: int = Field(30, description="请求超时时间(秒)")
|
timeout: int = Field(30, description="请求超时时间(秒)")
|
||||||
max_tokens: int = Field(1000, description="最大生成token数")
|
max_tokens: int = Field(1000, description="最大生成token数")
|
||||||
temperature: float = Field(0.7, description="温度参数")
|
temperature: float = Field(0.7, description="温度参数")
|
||||||
extra_body: Dict = field(default_factory=lambda: {"enable_thinking": False})
|
extra_body: Dict = Field(default_factory=lambda: {"enable_thinking": False})
|
||||||
response_format: bool = field(default_factory=lambda: {"type": "json_object"})
|
response_format: Dict = Field(default_factory=lambda: {"type": "json_object"})
|
||||||
system_prompt: str = SYSTEM_PROMPT
|
system_prompt: str = SYSTEM_PROMPT
|
||||||
# stream_output: bool = False
|
# stream_output: bool = False
|
||||||
|
|
||||||
|
|
@ -85,7 +85,8 @@ class ConfigManager:
|
||||||
self.ensure_config_dir()
|
self.ensure_config_dir()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
config_dict = self._config.model_dump(exclude_unset=True)
|
# 直接保存所有字段,不排除任何内容
|
||||||
|
config_dict = self._config.model_dump()
|
||||||
with open(self.config_file, 'w', encoding='utf-8') as f:
|
with open(self.config_file, 'w', encoding='utf-8') as f:
|
||||||
yaml.dump(config_dict, f, default_flow_style=False)
|
yaml.dump(config_dict, f, default_flow_style=False)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,52 @@
|
||||||
|
import subprocess
|
||||||
|
from .llm import LLMClient
|
||||||
|
|
||||||
|
class GitHandler:
|
||||||
|
def __init__(self, llm_client=None):
|
||||||
|
if llm_client:
|
||||||
|
self.llm_client = llm_client
|
||||||
|
else:
|
||||||
|
from .llm import LLMClient
|
||||||
|
self.llm_client = LLMClient()
|
||||||
|
|
||||||
|
def get_diff(self) -> str:
|
||||||
|
# 使用 subprocess 获取 git diff
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
["git", "diff", "--staged"],
|
||||||
|
capture_output=True,
|
||||||
|
text=True
|
||||||
|
)
|
||||||
|
return result.stdout
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise Exception("Git 未安装或不在系统路径中")
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"获取 Git 差异失败: {e}")
|
||||||
|
|
||||||
|
def commit_with_message(self, message: str):
|
||||||
|
# 使用 git commit -m
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
["git", "commit", "-m", message],
|
||||||
|
capture_output=True,
|
||||||
|
text=True
|
||||||
|
)
|
||||||
|
if result.returncode != 0:
|
||||||
|
raise Exception(f"Git 提交失败: {result.stderr}")
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise Exception("Git 未安装或不在系统路径中")
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"提交失败: {e}")
|
||||||
|
|
||||||
|
def auto_commit(self):
|
||||||
|
diff = self.get_diff()
|
||||||
|
if not diff:
|
||||||
|
print("没有暂存的更改")
|
||||||
|
console.print("[yellow]⚠️ 没有暂存的更改[/yellow]")
|
||||||
|
console.print("[cyan]提示: 请先使用 `git add` 暂存文件[/cyan]")
|
||||||
|
return
|
||||||
|
|
||||||
|
message = self.llm_client.generate_commit_message(diff)
|
||||||
|
if message:
|
||||||
|
self.commit_with_message(message)
|
||||||
|
console.print(f"[green]✅ 已提交: {message}[/green]")
|
||||||
|
|
@ -1,122 +1,323 @@
|
||||||
"""LLM客户端模块"""
|
import os
|
||||||
import requests
|
import asyncio
|
||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
|
from openai import AsyncOpenAI, OpenAI
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.table import Table
|
from rich.table import Table
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
from .config import ConfigManager
|
from rich.progress import Progress, SpinnerColumn, TextColumn
|
||||||
|
from rich.status import Status
|
||||||
|
|
||||||
|
from .config import ConfigManager, LLMConfig
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
class LLMClient:
|
class LLMClient:
|
||||||
"""LLM客户端"""
|
"""LLM客户端 - 使用OpenAI库"""
|
||||||
|
|
||||||
def __init__(self, config_manager: Optional[ConfigManager] = None):
|
def __init__(self, config_manager: Optional[ConfigManager] = None):
|
||||||
self.config_manager = config_manager or ConfigManager()
|
self.config_manager = config_manager or ConfigManager()
|
||||||
self.config = self.config_manager.load_config()
|
self.config = self.config_manager.load_config().llm
|
||||||
|
self.client = None
|
||||||
|
self.async_client = None
|
||||||
|
self._init_clients()
|
||||||
|
|
||||||
|
def _init_clients(self):
|
||||||
|
"""初始化OpenAI客户端"""
|
||||||
|
if not self.config.api_key:
|
||||||
|
console.print("[yellow]⚠️ API密钥未设置,客户端未初始化[/yellow]")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 初始化同步客户端
|
||||||
|
self.client = OpenAI(
|
||||||
|
api_key=self.config.api_key,
|
||||||
|
base_url=self.config.base_url,
|
||||||
|
timeout=self.config.timeout,
|
||||||
|
max_retries=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 初始化异步客户端
|
||||||
|
self.async_client = AsyncOpenAI(
|
||||||
|
api_key=self.config.api_key,
|
||||||
|
base_url=self.config.base_url,
|
||||||
|
timeout=self.config.timeout,
|
||||||
|
max_retries=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]❌ 初始化OpenAI客户端失败: {e}[/red]")
|
||||||
|
|
||||||
|
def _ensure_client_initialized(self) -> bool:
|
||||||
|
"""确保客户端已初始化"""
|
||||||
|
if not self.config.api_key:
|
||||||
|
console.print("[red]❌ 错误: API密钥未设置[/red]")
|
||||||
|
console.print("[yellow]💡 提示: 请先运行 `autocommit config set --api-key <你的密钥>`[/yellow]")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not self.client or not self.async_client:
|
||||||
|
self._init_clients()
|
||||||
|
|
||||||
|
return self.client is not None
|
||||||
|
|
||||||
def check_availability(self) -> Dict[str, Any]:
|
def check_availability(self) -> Dict[str, Any]:
|
||||||
"""检查LLM是否可用"""
|
"""检查LLM服务可用性"""
|
||||||
result = {
|
result = {
|
||||||
"available": False,
|
"available": False,
|
||||||
"message": "",
|
"message": "",
|
||||||
"details": {}
|
"details": {},
|
||||||
|
"error": None
|
||||||
}
|
}
|
||||||
|
|
||||||
if not self.config.llm.api_key:
|
if not self._ensure_client_initialized():
|
||||||
result["message"] = "API密钥未设置"
|
result["message"] = "客户端未初始化"
|
||||||
return result
|
return result
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 测试OpenAI兼容API
|
# 检查模型是否可用
|
||||||
headers = {
|
with console.status("[bold cyan]正在检查模型可用性...[/bold cyan]"):
|
||||||
"Authorization": f"Bearer {self.config.llm.api_key}",
|
# 尝试列出模型(部分API可能不支持)
|
||||||
"Content-Type": "application/json"
|
try:
|
||||||
}
|
models = self.client.models.list()
|
||||||
|
model_names = [model.id for model in models.data]
|
||||||
|
result["details"]["model_count"] = len(model_names)
|
||||||
|
result["details"]["models"] = model_names[:5] # 只显示前5个
|
||||||
|
|
||||||
# 尝试列出模型
|
# 检查配置的模型是否在列表中
|
||||||
response = requests.get(
|
configured_in_list = self.config.model in model_names
|
||||||
f"{self.config.llm.base_url}/models",
|
result["details"]["configured_model"] = self.config.model
|
||||||
headers=headers,
|
result["details"]["model_available"] = configured_in_list
|
||||||
timeout=self.config.llm.timeout
|
|
||||||
)
|
|
||||||
|
|
||||||
if response.status_code == 200:
|
if not configured_in_list:
|
||||||
models = response.json().get("data", [])
|
console.print(f"[yellow]⚠️ 配置的模型 '{self.config.model}' 不在可用模型列表中[/yellow]")
|
||||||
model_names = [model["id"] for model in models]
|
console.print(f"[dim]可用模型: {', '.join(model_names[:3])}...[/dim]")
|
||||||
|
|
||||||
result["available"] = True
|
result["available"] = True
|
||||||
result["message"] = "LLM服务可用"
|
result["message"] = f"LLM服务可用 (找到 {len(model_names)} 个模型)"
|
||||||
result["details"] = {
|
|
||||||
"model_count": len(models),
|
|
||||||
"models": model_names[:10], # 只显示前10个模型
|
|
||||||
"configured_model": self.config.llm.model,
|
|
||||||
"model_in_list": self.config.llm.model in model_names
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
result["message"] = f"API请求失败: {response.status_code}"
|
|
||||||
result["details"] = {"status_code": response.status_code}
|
|
||||||
|
|
||||||
except requests.exceptions.ConnectionError:
|
|
||||||
result["message"] = "无法连接到API服务器"
|
|
||||||
except requests.exceptions.Timeout:
|
|
||||||
result["message"] = f"请求超时({self.config.llm.timeout}秒)"
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
result["message"] = f"未知错误: {str(e)}"
|
# 如果列出模型失败,尝试通过简单请求测试
|
||||||
|
console.print("[dim]模型列表不可用,尝试通过聊天请求测试...[/dim]")
|
||||||
|
try:
|
||||||
|
# 发送一个极简的测试请求
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
|
model=self.config.model,
|
||||||
|
messages=[{"role": "user", "content": "Hello"}],
|
||||||
|
max_tokens=1,
|
||||||
|
)
|
||||||
|
result["available"] = True
|
||||||
|
result["message"] = f"LLM服务可用 (模型: {self.config.model})"
|
||||||
|
result["details"]["test_response"] = "成功"
|
||||||
|
except Exception as inner_e:
|
||||||
|
raise inner_e
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
result["message"] = f"LLM服务检查失败"
|
||||||
|
result["error"] = str(e)
|
||||||
|
console.print(f"[red]❌ 错误详情: {e}[/red]")
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def display_check_result(self, result: Dict[str, Any]):
|
def display_check_result(self, result: Dict[str, Any]):
|
||||||
"""显示检查结果"""
|
"""显示检查结果 - 美观的Rich表格"""
|
||||||
table = Table(title="LLM可用性检查")
|
table = Table(title="🤖 LLM 服务状态检查", box=None)
|
||||||
table.add_column("项目", style="cyan")
|
table.add_column("项目", style="cyan bold", width=20)
|
||||||
table.add_column("状态", style="green")
|
table.add_column("状态", style="green", width=40)
|
||||||
|
|
||||||
table.add_row("可用性", "✅ 可用" if result["available"] else "❌ 不可用")
|
# 可用性状态
|
||||||
table.add_row("状态信息", result["message"])
|
if result["available"]:
|
||||||
|
status_icon = "✅"
|
||||||
|
status_text = "[green]可用[/green]"
|
||||||
|
else:
|
||||||
|
status_icon = "❌"
|
||||||
|
status_text = "[red]不可用[/red]"
|
||||||
|
|
||||||
|
table.add_row(f"{status_icon} 可用性", status_text)
|
||||||
|
table.add_row("📝 状态信息", result["message"])
|
||||||
|
|
||||||
|
# 详细信息
|
||||||
if result["details"]:
|
if result["details"]:
|
||||||
for key, value in result["details"].items():
|
for key, value in result["details"].items():
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
table.add_row(key, ", ".join(value[:5]) + ("..." if len(value) > 5 else ""))
|
display_value = f"[cyan]{', '.join(map(str, value))}[/cyan]"
|
||||||
|
elif isinstance(value, bool):
|
||||||
|
display_value = "[green]是[/green]" if value else "[red]否[/red]"
|
||||||
else:
|
else:
|
||||||
table.add_row(key, str(value))
|
display_value = f"[cyan]{value}[/cyan]"
|
||||||
|
|
||||||
|
# 美化键名
|
||||||
|
key_display = key.replace("_", " ").title()
|
||||||
|
table.add_row(f"📊 {key_display}", display_value)
|
||||||
|
|
||||||
|
# 错误信息
|
||||||
|
if result.get("error"):
|
||||||
|
table.add_row("🚨 错误详情", f"[red]{result['error']}[/red]")
|
||||||
|
|
||||||
console.print(table)
|
console.print(table)
|
||||||
|
|
||||||
def test_completion(self, prompt: str = "Hello, are you working?") -> Optional[str]:
|
async def generate_commit_message_async(self, diff_content: str) -> Optional[str]:
|
||||||
"""测试补全功能"""
|
"""异步生成提交消息(推荐用于异步应用)"""
|
||||||
if not self.config.llm.api_key:
|
if not self._ensure_client_initialized():
|
||||||
console.print("[red]错误: API密钥未设置[/red]")
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
headers = {
|
# 构建完整的提示
|
||||||
"Authorization": f"Bearer {self.config.llm.api_key}",
|
full_prompt = f"{self.config.system_prompt}\n{diff_content}"
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
|
|
||||||
data = {
|
messages = [
|
||||||
"model": self.config.llm.model,
|
{"role": "system", "content": "你是一个专业的代码审查和版本控制助手。"},
|
||||||
"messages": [{"role": "user", "content": prompt}],
|
{"role": "user", "content": full_prompt}
|
||||||
"max_tokens": 50,
|
]
|
||||||
"temperature": 0.5
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.post(
|
# 发送请求
|
||||||
f"{self.config.llm.base_url}/chat/completions",
|
response = await self.async_client.chat.completions.create(
|
||||||
headers=headers,
|
model=self.config.model,
|
||||||
json=data,
|
messages=messages,
|
||||||
timeout=self.config.llm.timeout
|
max_tokens=self.config.max_tokens,
|
||||||
|
temperature=self.config.temperature,
|
||||||
|
extra_body=self.config.extra_body,
|
||||||
|
# response_format=self.config.response_format, # 暂时注释,因为可能不是所有API都支持
|
||||||
|
stream=False # 非流式
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code == 200:
|
# 提取消息内容
|
||||||
result = response.json()
|
if response.choices and len(response.choices) > 0:
|
||||||
return result["choices"][0]["message"]["content"]
|
content = response.choices[0].message.content
|
||||||
|
|
||||||
|
# 清理返回的文本(移除多余的引号、空格等)
|
||||||
|
content = content.strip().strip('"').strip("'").strip()
|
||||||
|
|
||||||
|
return content
|
||||||
else:
|
else:
|
||||||
console.print(f"[red]请求失败: {response.status_code}[/red]")
|
console.print("[red]❌ 未收到有效的响应内容[/red]")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
console.print(f"[red]错误: {str(e)}[/red]")
|
console.print(f"[red]❌ 生成提交消息失败: {e}[/red]")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def generate_commit_message(self, diff_content: str) -> Optional[str]:
|
||||||
|
"""同步生成提交消息"""
|
||||||
|
if not self._ensure_client_initialized():
|
||||||
|
return None
|
||||||
|
|
||||||
|
if len(diff_content) >= 1024 * 16:
|
||||||
|
diff_content = diff_content[0: 1024 * 16]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 构建完整的提示
|
||||||
|
full_prompt = f"{self.config.system_prompt}\n{diff_content}"
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": "你是一个专业的代码审查和版本控制助手。"},
|
||||||
|
{"role": "user", "content": full_prompt}
|
||||||
|
]
|
||||||
|
|
||||||
|
# 使用Rich显示进度
|
||||||
|
with Progress(
|
||||||
|
SpinnerColumn(),
|
||||||
|
TextColumn("[progress.description]{task.description}"),
|
||||||
|
transient=True,
|
||||||
|
) as progress:
|
||||||
|
task = progress.add_task(
|
||||||
|
description="[cyan]🤖 正在生成提交消息...[/cyan]",
|
||||||
|
total=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# 发送请求
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
|
model=self.config.model,
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=self.config.max_tokens,
|
||||||
|
temperature=self.config.temperature,
|
||||||
|
extra_body=self.config.extra_body,
|
||||||
|
# response_format=self.config.response_format, # 暂时注释
|
||||||
|
stream=False # 非流式
|
||||||
|
)
|
||||||
|
|
||||||
|
progress.update(task, completed=100)
|
||||||
|
|
||||||
|
# 提取消息内容
|
||||||
|
if response.choices and len(response.choices) > 0:
|
||||||
|
content = response.choices[0].message.content
|
||||||
|
|
||||||
|
# 清理返回的文本
|
||||||
|
content = content.strip().strip('"').strip("'").strip()
|
||||||
|
|
||||||
|
# 显示结果
|
||||||
|
console.print(Panel(
|
||||||
|
f"[bold green]{content}[/bold green]",
|
||||||
|
title="📝 生成的提交消息",
|
||||||
|
border_style="green",
|
||||||
|
padding=(1, 2)
|
||||||
|
))
|
||||||
|
|
||||||
|
return content
|
||||||
|
else:
|
||||||
|
console.print("[red]❌ 未收到有效的响应内容[/red]")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]❌ 生成提交消息失败: {e}[/red]")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def test_generation(self, test_diff: Optional[str] = None) -> Optional[str]:
|
||||||
|
"""测试消息生成功能"""
|
||||||
|
if not self._ensure_client_initialized():
|
||||||
|
return None
|
||||||
|
|
||||||
|
console.print("[bold cyan]🧪 测试消息生成...[/bold cyan]")
|
||||||
|
|
||||||
|
# 使用示例diff或提供的测试diff
|
||||||
|
if test_diff is None:
|
||||||
|
test_diff = """diff --git a/example.py b/example.py
|
||||||
|
index abc123..def456 100644
|
||||||
|
--- a/example.py
|
||||||
|
+++ b/example.py
|
||||||
|
@@ -1,5 +1,5 @@
|
||||||
|
def hello_world():
|
||||||
|
- print("Hello, World!")
|
||||||
|
+ print("Hello, OpenAI!")
|
||||||
|
|
||||||
|
def add(a, b):
|
||||||
|
return a + b"""
|
||||||
|
|
||||||
|
console.print("[dim]使用示例diff进行测试...[/dim]")
|
||||||
|
|
||||||
|
return self.generate_commit_message(test_diff)
|
||||||
|
|
||||||
|
def update_config(self, **kwargs):
|
||||||
|
"""更新配置并重新初始化客户端"""
|
||||||
|
# 更新配置
|
||||||
|
self.config_manager.update_llm_config(**kwargs)
|
||||||
|
|
||||||
|
# 重新加载配置
|
||||||
|
self.config = self.config_manager.load_config().llm
|
||||||
|
|
||||||
|
# 重新初始化客户端
|
||||||
|
self._init_clients()
|
||||||
|
|
||||||
|
if self.client:
|
||||||
|
console.print("[green]✅ 配置已更新,客户端重新初始化[/green]")
|
||||||
|
else:
|
||||||
|
console.print("[yellow]⚠️ 配置已更新,但客户端初始化失败[/yellow]")
|
||||||
|
|
||||||
|
|
||||||
|
# 辅助函数:同步调用异步函数
|
||||||
|
def generate_commit_message_sync(diff_content: str, config_manager: Optional[ConfigManager] = None) -> Optional[str]:
|
||||||
|
"""同步辅助函数:生成提交消息"""
|
||||||
|
client = LLMClient(config_manager)
|
||||||
|
|
||||||
|
# 尝试同步方法
|
||||||
|
result = client.generate_commit_message(diff_content)
|
||||||
|
|
||||||
|
# 如果同步方法失败,尝试异步方法
|
||||||
|
if result is None and client.async_client:
|
||||||
|
try:
|
||||||
|
console.print("[dim]尝试异步生成...[/dim]")
|
||||||
|
result = asyncio.run(client.generate_commit_message_async(diff_content))
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]❌ 异步生成也失败: {e}[/red]")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
|
||||||
|
|
@ -6,12 +6,12 @@ readme = "README.md"
|
||||||
requires-python = ">=3.11"
|
requires-python = ">=3.11"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"openai>=1.0.0",
|
"openai>=1.0.0",
|
||||||
"fire>=0.5.0",
|
|
||||||
"rich>=13.0.0",
|
"rich>=13.0.0",
|
||||||
"pydantic>=2.0.0",
|
"pydantic>=2.0.0",
|
||||||
"pyyaml>=6.0", # 用于配置文件
|
"pyyaml>=6.0", # 用于配置文件
|
||||||
"requests>=2.31.0", # 用于网络请求
|
"requests>=2.31.0", # 用于网络请求
|
||||||
"setuptools>=65.0"
|
"setuptools>=65.0",
|
||||||
|
"typer>=0.21.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue