feat(cli): 升级 CLI 为 Typer + Rich 版本,新增配置命令与交互式向导
This commit is contained in:
parent
98b75a732c
commit
5129fc1bcf
|
|
@ -1,13 +1,34 @@
|
|||
"""命令行入口"""
|
||||
import fire
|
||||
"""命令行入口 - Typer + Rich 版本"""
|
||||
import typer
|
||||
from typing import Optional
|
||||
from rich.console import Console
|
||||
from rich.prompt import Prompt, Confirm
|
||||
from typing import Optional
|
||||
from rich.table import Table
|
||||
from rich.panel import Panel
|
||||
from rich import print as rprint
|
||||
|
||||
from .config import ConfigManager
|
||||
from .llm import LLMClient
|
||||
|
||||
# 创建 Typer 应用
|
||||
app = typer.Typer(
|
||||
help="AutoCommit - 自动生成提交消息工具",
|
||||
context_settings={"help_option_names": ["-h", "--help"]}
|
||||
)
|
||||
|
||||
# 创建 Rich 控制台
|
||||
console = Console()
|
||||
|
||||
# 全局 CLI 实例(类似单例模式)
|
||||
_cli_instance = None
|
||||
|
||||
def get_cli() -> "AutoCommitCLI":
|
||||
"""获取 CLI 单例实例"""
|
||||
global _cli_instance
|
||||
if _cli_instance is None:
|
||||
_cli_instance = AutoCommitCLI()
|
||||
return _cli_instance
|
||||
|
||||
class AutoCommitCLI:
    """AutoCommit command-line facade.

    Owns the ConfigManager and LLMClient shared by all Typer commands
    (accessed through get_cli()).  The old Fire-era `config()` method was
    superseded by the `config_*` Typer commands and is removed here.
    """

    def __init__(self):
        # The LLM client reads its settings from the same config manager,
        # so config updates and generation stay consistent.
        self.config_manager = ConfigManager()
        self.llm_client = LLMClient(self.config_manager)

    def _create_config_table(self, config_data: dict) -> Table:
        """Build a Rich table rendering of the LLM configuration section."""
        table = Table(title="LLM 配置", show_header=True, header_style="bold cyan")
        table.add_column("配置项", style="cyan", width=20)
        table.add_column("当前值", style="green", width=40)

        for key, value in config_data.items():
            if key == "api_key" and value:
                # Never echo the API key back to the terminal.
                display_value = "[yellow]已设置[/yellow] (安全隐藏)"
            elif value:
                display_value = f"[green]{value}[/green]"
            else:
                display_value = "[dim]未设置[/dim]"
            table.add_row(key.replace("_", " ").title(), display_value)

        return table
|
||||
|
||||
# 配置命令组
# Sub-application that groups all `autocommit config ...` commands.
config_app = typer.Typer(name="config", help="配置管理")
app.add_typer(config_app)


@config_app.command("show", help="显示当前配置")
def config_show():
    """显示当前配置"""
    cli = get_cli()
    # Fix: this is a module-level function — the table must be built through
    # the CLI instance (`cli.`), not `self.` which is undefined here.
    config_data = cli.config_manager.show_config()
    llm_config = config_data.get("llm", {})
    console.print("[bold cyan]📋 当前配置[/bold cyan]")
    console.print(cli._create_config_table(llm_config))
|
||||
|
||||
current_config = self.config_manager.show_config()["llm"]
|
||||
|
||||
api_key = Prompt.ask(
|
||||
f"API密钥 (当前: {'*' * 8 if current_config.get('api_key') else '未设置'})",
|
||||
password=True,
|
||||
default=""
|
||||
) or None
|
||||
@config_app.command("clear", help="清除所有配置")
|
||||
def config_clear():
|
||||
"""清除配置"""
|
||||
cli = get_cli()
|
||||
if Confirm.ask("[yellow]⚠️ 确定要清除所有配置吗?[/yellow]"):
|
||||
cli.config_manager.clear_config()
|
||||
console.print("[green]✅ 配置已清除[/green]")
|
||||
else:
|
||||
console.print("[dim]操作已取消[/dim]")
|
||||
|
||||
base_url = Prompt.ask(
|
||||
f"基础URL (当前: {current_config.get('base_url', 'https://api.openai.com/v1')})",
|
||||
default=current_config.get('base_url', 'https://api.openai.com/v1')
|
||||
)
|
||||
|
||||
model = Prompt.ask(
|
||||
f"模型名称 (当前: {current_config.get('model', 'gpt-3.5-turbo')})",
|
||||
default=current_config.get('model', 'gpt-3.5-turbo')
|
||||
)
|
||||
|
||||
timeout_input = Prompt.ask(
|
||||
f"超时时间(秒) (当前: {current_config.get('timeout', 30)})",
|
||||
default=str(current_config.get('timeout', 30))
|
||||
)
|
||||
timeout = int(timeout_input) if timeout_input else None
|
||||
@config_app.command("set", help="设置配置参数")
|
||||
def config_set(
|
||||
api_key: Optional[str] = typer.Option(None, "--api-key", "-k", help="LLM API密钥"),
|
||||
base_url: Optional[str] = typer.Option(None, "--base-url", "-u", help="API基础URL"),
|
||||
model: Optional[str] = typer.Option(None, "--model", "-m", help="模型名称"),
|
||||
timeout: Optional[int] = typer.Option(None, "--timeout", "-t", help="请求超时时间(秒)")
|
||||
):
|
||||
"""设置配置参数"""
|
||||
cli = get_cli()
|
||||
|
||||
# 更新配置
|
||||
updated_config = self.config_manager.update_llm_config(
|
||||
updated_config = cli.config_manager.update_llm_config(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
console.print("[green]配置已更新[/green]")
|
||||
table = self._create_config_table(updated_config.model_dump())
|
||||
console.print(table)
|
||||
console.print("[green]✅ 配置已更新[/green]")
|
||||
console.print(cli._create_config_table(updated_config.model_dump()))
|
||||
|
||||
@config_app.command("interactive", help="交互式配置向导")
|
||||
def config_interactive():
|
||||
"""交互式配置向导"""
|
||||
cli = get_cli()
|
||||
|
||||
console.print("[bold yellow]🎛️ 交互式配置向导[/bold yellow]")
|
||||
console.print("[dim](直接按Enter跳过当前项)[/dim]\n")
|
||||
|
||||
current_config = cli.config_manager.show_config()["llm"]
|
||||
|
||||
api_key = Prompt.ask(
|
||||
f"[cyan]API密钥[/cyan] (当前: {'[yellow]已设置[/yellow]' if current_config.get('api_key') else '[dim]未设置[/dim]'})",
|
||||
password=True,
|
||||
default=""
|
||||
) or None
|
||||
|
||||
base_url = Prompt.ask(
|
||||
f"[cyan]基础URL[/cyan] (当前: {current_config.get('base_url', 'https://api.openai.com/v1')})",
|
||||
default=current_config.get('base_url', 'https://api.openai.com/v1')
|
||||
)
|
||||
|
||||
model = Prompt.ask(
|
||||
f"[cyan]模型名称[/cyan] (当前: {current_config.get('model', 'gpt-3.5-turbo')})",
|
||||
default=current_config.get('model', 'gpt-3.5-turbo')
|
||||
)
|
||||
|
||||
timeout_input = Prompt.ask(
|
||||
f"[cyan]超时时间(秒)[/cyan] (当前: {current_config.get('timeout', 30)})",
|
||||
default=str(current_config.get('timeout', 30))
|
||||
)
|
||||
timeout = int(timeout_input) if timeout_input else None
|
||||
|
||||
# 更新配置
|
||||
updated_config = cli.config_manager.update_llm_config(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
console.print("\n[green]✅ 配置已更新[/green]")
|
||||
console.print(cli._create_config_table(updated_config.model_dump()))
|
||||
|
||||
# 询问是否测试连接
|
||||
if Confirm.ask("是否测试LLM连接?"):
|
||||
self.check()
|
||||
if Confirm.ask("[cyan]是否测试LLM连接?[/cyan]"):
|
||||
check(test=True)
|
||||
|
||||
@app.command(help="检查LLM可用性")
def check(
    test: bool = typer.Option(False, "--test", "-t", help="发送测试消息")
):
    """检查LLM可用性"""
    # (The old duplicate `_create_config_table` helper that lived here was
    # superseded by AutoCommitCLI._create_config_table and is removed.)
    cli = get_cli()

    console.print("[bold cyan]🔍 检查LLM可用性...[/bold cyan]\n")

    # 检查配置
    config = cli.config_manager.load_config()
    if not config.llm.api_key:
        console.print("[red]❌ 错误: API密钥未设置[/red]")
        console.print("[yellow]💡 提示: 请先运行 `autocommit config interactive` 进行配置[/yellow]")
        raise typer.Exit(code=1)

    # 检查可用性
    with console.status("[bold green]正在检查LLM连接...[/bold green]"):
        result = cli.llm_client.check_availability()

    cli.llm_client.display_check_result(result)

    # 测试补全
    # Only send a real request when asked for AND the availability probe passed.
    if test and result.get("available"):
        console.print("\n[bold cyan]📤 发送测试消息...[/bold cyan]")
        with console.status("[bold green]正在等待LLM回复...[/bold green]"):
            response = cli.llm_client.test_generation()

        if response:
            console.print(Panel(
                response,
                title="🤖 LLM 测试回复",
                border_style="green",
                padding=(1, 2)
            ))
|
||||
|
||||
def _run():
    """运行主程序 - 内部函数"""
    # Default action when `autocommit` is invoked with no sub-command:
    # generate a commit message for the staged diff and commit it.
    # (Replaces the old Fire-era `run`/`version` methods that lived here.)
    cli = get_cli()

    console.print(Panel(
        "[bold green]🚀 AutoCommit 启动[/bold green]\n"
        "[dim]自动生成 Git 提交消息并提交暂存的文件[/dim]",
        border_style="green",
        padding=(1, 4)
    ))

    # 检查配置
    config = cli.config_manager.load_config()
    if not config.llm.api_key:
        console.print("[red]❌ 错误: API密钥未设置[/red]")
        console.print("[yellow]💡 提示: 请先运行 `autocommit config interactive` 进行配置[/yellow]")
        raise typer.Exit(code=1)

    # 初始化 GitHandler
    # Imported lazily so a broken git module only breaks this command.
    try:
        from .git import GitHandler
        git_handler = GitHandler(cli.llm_client)
    except ImportError:
        console.print("[red]❌ 错误: 无法导入 Git 模块[/red]")
        raise typer.Exit(code=1)
    except Exception as e:
        console.print(f"[red]❌ 错误: 初始化 GitHandler 失败: {e}[/red]")
        raise typer.Exit(code=1)

    # 执行自动提交
    try:
        console.print("[cyan]📤 正在检查暂存的文件...[/cyan]")
        git_handler.auto_commit()

        # 显示成功信息
        console.print(Panel(
            "[bold green]✅ 自动提交完成![/bold green]\n"
            "[dim]已成功生成提交消息并提交暂存的更改[/dim]",
            border_style="green",
            padding=(1, 4)
        ))

    except Exception as e:
        console.print(f"[red]❌ 自动提交失败: {e}[/red]")
        console.print("[yellow]💡 提示: 确保您有暂存的文件并且处于 Git 仓库中[/yellow]")
        raise typer.Exit(code=1)
|
||||
|
||||
|
||||
@app.command(help="测试生成提交消息(不实际提交)")
|
||||
def test_commit():
|
||||
"""测试生成提交消息而不实际提交"""
|
||||
cli = get_cli()
|
||||
|
||||
console.print(Panel(
|
||||
"[bold yellow]🧪 测试模式启动[/bold yellow]\n"
|
||||
"[dim]将生成提交消息但不会实际提交到 Git[/dim]",
|
||||
border_style="yellow",
|
||||
padding=(1, 4)
|
||||
))
|
||||
|
||||
# 检查配置
|
||||
config = cli.config_manager.load_config()
|
||||
if not config.llm.api_key:
|
||||
console.print("[red]❌ 错误: API密钥未设置[/red]")
|
||||
console.print("[yellow]💡 提示: 请先运行 `autocommit config interactive` 进行配置[/yellow]")
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
try:
|
||||
from .git import GitHandler
|
||||
git_handler = GitHandler(cli.llm_client)
|
||||
|
||||
# 获取差异
|
||||
diff = git_handler.get_diff()
|
||||
if not diff:
|
||||
console.print("[yellow]⚠️ 没有暂存的更改[/yellow]")
|
||||
console.print("[cyan]使用示例差异进行测试...[/cyan]")
|
||||
|
||||
# 使用示例差异
|
||||
example_diff = """diff --git a/example.py b/example.py
|
||||
index abc123..def456 100644
|
||||
--- a/example.py
|
||||
+++ b/example.py
|
||||
@@ -1,5 +1,5 @@
|
||||
def hello_world():
|
||||
- print("Hello, World!")
|
||||
+ print("Hello, OpenAI!")
|
||||
|
||||
def add(a, b):
|
||||
return a + b"""
|
||||
|
||||
# 使用 LLMClient 直接测试生成
|
||||
with console.status("[bold cyan]🤖 正在生成提交消息...[/bold cyan]"):
|
||||
message = cli.llm_client.generate_commit_message(example_diff)
|
||||
|
||||
if message:
|
||||
console.print(Panel(
|
||||
f"[bold green]如果提交,消息将是:[/bold green]\n\n{message}",
|
||||
title="📝 测试生成的提交消息",
|
||||
border_style="green",
|
||||
padding=(1, 2)
|
||||
))
|
||||
else:
|
||||
console.print("[red]❌ 无法生成提交消息[/red]")
|
||||
else:
|
||||
console.print("[green]✅ 检测到暂存的更改[/green]")
|
||||
console.print("[cyan]正在分析差异...[/cyan]")
|
||||
|
||||
# 使用 LLMClient 生成消息但不提交
|
||||
with console.status("[bold cyan]🤖 正在生成提交消息...[/bold cyan]"):
|
||||
message = cli.llm_client.generate_commit_message(diff)
|
||||
|
||||
if message:
|
||||
console.print(Panel(
|
||||
f"[bold green]如果提交,消息将是:[/bold green]\n\n{message}",
|
||||
title="📝 测试生成的提交消息",
|
||||
border_style="green",
|
||||
padding=(1, 2)
|
||||
))
|
||||
else:
|
||||
console.print("[red]❌ 无法生成提交消息[/red]")
|
||||
return
|
||||
|
||||
|
||||
# 询问用户是否要实际提交
|
||||
if Confirm.ask("[yellow]是否使用此消息进行实际提交?[/yellow]"):
|
||||
git_handler.commit_with_message(message)
|
||||
console.print("[green]✅ 已提交更改[/green]")
|
||||
else:
|
||||
console.print("[dim]测试完成,未实际提交[/dim]")
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[red]❌ 测试失败: {e}[/red]")
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
|
||||
@app.command(help="显示版本信息")
|
||||
def version():
|
||||
"""显示版本信息"""
|
||||
console.print("[bold cyan]AutoCommit v0.1.0[/bold cyan]")
|
||||
console.print(Panel(
|
||||
"[bold cyan]AutoCommit v0.1.0[/bold cyan]\n"
|
||||
"[dim]基于 Typer + Rich 的现代化 CLI[/dim]",
|
||||
border_style="cyan",
|
||||
padding=(1, 4)
|
||||
))
|
||||
|
||||
# 主回调函数 - 当没有子命令时执行
# Root callback: with `invoke_without_command=True` Typer calls this even
# when the user types bare `autocommit`, letting us run the default action.
@app.callback(invoke_without_command=True)
def main_callback(ctx: typer.Context):
    """
    AutoCommit - 自动生成 Git 提交消息工具

    直接运行 autocommit 会自动分析暂存的 Git 更改并生成提交消息。
    使用 autocommit config 进行配置管理。
    """
    # A named sub-command (config/check/...) handles itself; only the bare
    # invocation falls through to the auto-commit flow.
    if ctx.invoked_subcommand is None:
        _run()
|
||||
|
||||
# 主入口函数
def main():
    """主函数 - Typer 入口点"""
    # Bug fix: the previous version did `typer.secho = console.print`.
    # `typer.secho` takes styling kwargs (fg=, bg=, bold=, err=, ...) that
    # `Console.print` does not accept, so any code calling secho with those
    # arguments would crash; the monkey-patch is removed.
    app()


if __name__ == "__main__":
    main()
|
||||
|
|
|
|||
|
|
@ -22,8 +22,8 @@ class LLMConfig(BaseModel):
|
|||
timeout: int = Field(30, description="请求超时时间(秒)")
|
||||
max_tokens: int = Field(1000, description="最大生成token数")
|
||||
temperature: float = Field(0.7, description="温度参数")
|
||||
extra_body: Dict = field(default_factory=lambda: {"enable_thinking": False})
|
||||
response_format: bool = field(default_factory=lambda: {"type": "json_object"})
|
||||
extra_body: Dict = Field(default_factory=lambda: {"enable_thinking": False})
|
||||
response_format: Dict = Field(default_factory=lambda: {"type": "json_object"})
|
||||
system_prompt: str = SYSTEM_PROMPT
|
||||
# stream_output: bool = False
|
||||
|
||||
|
|
@ -85,7 +85,8 @@ class ConfigManager:
|
|||
self.ensure_config_dir()
|
||||
|
||||
try:
|
||||
config_dict = self._config.model_dump(exclude_unset=True)
|
||||
# 直接保存所有字段,不排除任何内容
|
||||
config_dict = self._config.model_dump()
|
||||
with open(self.config_file, 'w', encoding='utf-8') as f:
|
||||
yaml.dump(config_dict, f, default_flow_style=False)
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,52 @@
|
|||
import subprocess

from rich.console import Console

from .llm import LLMClient

# Bug fix: this module used `console.print(...)` without ever defining
# `console`, which raised NameError at runtime.
console = Console()


class GitHandler:
    """Thin wrapper around the `git` CLI for staged diffs and commits."""

    def __init__(self, llm_client=None):
        # Accept an injected client (shares config with the CLI); otherwise
        # build a fresh one.  The redundant inner re-import of LLMClient is
        # gone — the module-level import suffices.
        self.llm_client = llm_client if llm_client else LLMClient()

    def get_diff(self) -> str:
        """Return the staged diff (`git diff --staged`) as text."""
        try:
            result = subprocess.run(
                ["git", "diff", "--staged"],
                capture_output=True,
                text=True,
            )
        except FileNotFoundError:
            raise Exception("Git 未安装或不在系统路径中")
        except Exception as e:
            raise Exception(f"获取 Git 差异失败: {e}")
        return result.stdout

    def commit_with_message(self, message: str):
        """Run `git commit -m <message>`; raise on any git failure."""
        try:
            result = subprocess.run(
                ["git", "commit", "-m", message],
                capture_output=True,
                text=True,
            )
        except FileNotFoundError:
            raise Exception("Git 未安装或不在系统路径中")
        # Bug fix: the returncode check used to sit inside the `try`, so its
        # own Exception was swallowed by the broad `except Exception` and
        # re-wrapped as "提交失败: Git 提交失败: ...".  Check it outside.
        if result.returncode != 0:
            raise Exception(f"Git 提交失败: {result.stderr}")

    def auto_commit(self):
        """Generate a commit message for the staged diff and commit it."""
        diff = self.get_diff()
        if not diff:
            console.print("[yellow]⚠️ 没有暂存的更改[/yellow]")
            console.print("[cyan]提示: 请先使用 `git add` 暂存文件[/cyan]")
            return

        message = self.llm_client.generate_commit_message(diff)
        if message:
            self.commit_with_message(message)
            console.print(f"[green]✅ 已提交: {message}[/green]")
|
|
@ -1,122 +1,323 @@
|
|||
"""LLM客户端模块"""
|
||||
import requests
|
||||
import os
|
||||
import asyncio
|
||||
from typing import Optional, Dict, Any
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from rich.panel import Panel
|
||||
from .config import ConfigManager
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn
|
||||
from rich.status import Status
|
||||
|
||||
from .config import ConfigManager, LLMConfig
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class LLMClient:
    """LLM client backed by the OpenAI SDK (sync and async)."""

    def __init__(self, config_manager: Optional[ConfigManager] = None):
        # Keep the manager so update_config() can persist changes later.
        self.config_manager = config_manager or ConfigManager()
        # Only the `llm` section of the config is relevant to this client
        # (the stale assignment of the whole config object is removed).
        self.config = self.config_manager.load_config().llm
        self.client = None
        self.async_client = None
        self._init_clients()

    def _init_clients(self):
        """Initialize the sync and async OpenAI clients from current config."""
        if not self.config.api_key:
            console.print("[yellow]⚠️ API密钥未设置,客户端未初始化[/yellow]")
            return

        try:
            # Sync client for blocking calls.
            self.client = OpenAI(
                api_key=self.config.api_key,
                base_url=self.config.base_url,
                timeout=self.config.timeout,
                max_retries=2,
            )

            # Async client for awaitable calls.
            self.async_client = AsyncOpenAI(
                api_key=self.config.api_key,
                base_url=self.config.base_url,
                timeout=self.config.timeout,
                max_retries=2,
            )

        except Exception as e:
            console.print(f"[red]❌ 初始化OpenAI客户端失败: {e}[/red]")

    def _ensure_client_initialized(self) -> bool:
        """Return True when a usable sync client exists, retrying init lazily."""
        if not self.config.api_key:
            console.print("[red]❌ 错误: API密钥未设置[/red]")
            console.print("[yellow]💡 提示: 请先运行 `autocommit config set --api-key <你的密钥>`[/yellow]")
            return False

        if not self.client or not self.async_client:
            # e.g. the key was set after construction via update_config().
            self._init_clients()

        return self.client is not None
|
||||
|
||||
def check_availability(self) -> Dict[str, Any]:
|
||||
"""检查LLM是否可用"""
|
||||
"""检查LLM服务可用性"""
|
||||
result = {
|
||||
"available": False,
|
||||
"message": "",
|
||||
"details": {}
|
||||
"details": {},
|
||||
"error": None
|
||||
}
|
||||
|
||||
if not self.config.llm.api_key:
|
||||
result["message"] = "API密钥未设置"
|
||||
if not self._ensure_client_initialized():
|
||||
result["message"] = "客户端未初始化"
|
||||
return result
|
||||
|
||||
try:
|
||||
# 测试OpenAI兼容API
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.config.llm.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
# 检查模型是否可用
|
||||
with console.status("[bold cyan]正在检查模型可用性...[/bold cyan]"):
|
||||
# 尝试列出模型(部分API可能不支持)
|
||||
try:
|
||||
models = self.client.models.list()
|
||||
model_names = [model.id for model in models.data]
|
||||
result["details"]["model_count"] = len(model_names)
|
||||
result["details"]["models"] = model_names[:5] # 只显示前5个
|
||||
|
||||
# 尝试列出模型
|
||||
response = requests.get(
|
||||
f"{self.config.llm.base_url}/models",
|
||||
headers=headers,
|
||||
timeout=self.config.llm.timeout
|
||||
)
|
||||
# 检查配置的模型是否在列表中
|
||||
configured_in_list = self.config.model in model_names
|
||||
result["details"]["configured_model"] = self.config.model
|
||||
result["details"]["model_available"] = configured_in_list
|
||||
|
||||
if response.status_code == 200:
|
||||
models = response.json().get("data", [])
|
||||
model_names = [model["id"] for model in models]
|
||||
if not configured_in_list:
|
||||
console.print(f"[yellow]⚠️ 配置的模型 '{self.config.model}' 不在可用模型列表中[/yellow]")
|
||||
console.print(f"[dim]可用模型: {', '.join(model_names[:3])}...[/dim]")
|
||||
|
||||
result["available"] = True
|
||||
result["message"] = "LLM服务可用"
|
||||
result["details"] = {
|
||||
"model_count": len(models),
|
||||
"models": model_names[:10], # 只显示前10个模型
|
||||
"configured_model": self.config.llm.model,
|
||||
"model_in_list": self.config.llm.model in model_names
|
||||
}
|
||||
else:
|
||||
result["message"] = f"API请求失败: {response.status_code}"
|
||||
result["details"] = {"status_code": response.status_code}
|
||||
result["message"] = f"LLM服务可用 (找到 {len(model_names)} 个模型)"
|
||||
|
||||
except requests.exceptions.ConnectionError:
|
||||
result["message"] = "无法连接到API服务器"
|
||||
except requests.exceptions.Timeout:
|
||||
result["message"] = f"请求超时({self.config.llm.timeout}秒)"
|
||||
except Exception as e:
|
||||
result["message"] = f"未知错误: {str(e)}"
|
||||
# 如果列出模型失败,尝试通过简单请求测试
|
||||
console.print("[dim]模型列表不可用,尝试通过聊天请求测试...[/dim]")
|
||||
try:
|
||||
# 发送一个极简的测试请求
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.config.model,
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
max_tokens=1,
|
||||
)
|
||||
result["available"] = True
|
||||
result["message"] = f"LLM服务可用 (模型: {self.config.model})"
|
||||
result["details"]["test_response"] = "成功"
|
||||
except Exception as inner_e:
|
||||
raise inner_e
|
||||
|
||||
except Exception as e:
|
||||
result["message"] = f"LLM服务检查失败"
|
||||
result["error"] = str(e)
|
||||
console.print(f"[red]❌ 错误详情: {e}[/red]")
|
||||
|
||||
return result
|
||||
|
||||
def display_check_result(self, result: Dict[str, Any]):
|
||||
"""显示检查结果"""
|
||||
table = Table(title="LLM可用性检查")
|
||||
table.add_column("项目", style="cyan")
|
||||
table.add_column("状态", style="green")
|
||||
"""显示检查结果 - 美观的Rich表格"""
|
||||
table = Table(title="🤖 LLM 服务状态检查", box=None)
|
||||
table.add_column("项目", style="cyan bold", width=20)
|
||||
table.add_column("状态", style="green", width=40)
|
||||
|
||||
table.add_row("可用性", "✅ 可用" if result["available"] else "❌ 不可用")
|
||||
table.add_row("状态信息", result["message"])
|
||||
# 可用性状态
|
||||
if result["available"]:
|
||||
status_icon = "✅"
|
||||
status_text = "[green]可用[/green]"
|
||||
else:
|
||||
status_icon = "❌"
|
||||
status_text = "[red]不可用[/red]"
|
||||
|
||||
table.add_row(f"{status_icon} 可用性", status_text)
|
||||
table.add_row("📝 状态信息", result["message"])
|
||||
|
||||
# 详细信息
|
||||
if result["details"]:
|
||||
for key, value in result["details"].items():
|
||||
if isinstance(value, list):
|
||||
table.add_row(key, ", ".join(value[:5]) + ("..." if len(value) > 5 else ""))
|
||||
display_value = f"[cyan]{', '.join(map(str, value))}[/cyan]"
|
||||
elif isinstance(value, bool):
|
||||
display_value = "[green]是[/green]" if value else "[red]否[/red]"
|
||||
else:
|
||||
table.add_row(key, str(value))
|
||||
display_value = f"[cyan]{value}[/cyan]"
|
||||
|
||||
# 美化键名
|
||||
key_display = key.replace("_", " ").title()
|
||||
table.add_row(f"📊 {key_display}", display_value)
|
||||
|
||||
# 错误信息
|
||||
if result.get("error"):
|
||||
table.add_row("🚨 错误详情", f"[red]{result['error']}[/red]")
|
||||
|
||||
console.print(table)
|
||||
|
||||
def test_completion(self, prompt: str = "Hello, are you working?") -> Optional[str]:
|
||||
"""测试补全功能"""
|
||||
if not self.config.llm.api_key:
|
||||
console.print("[red]错误: API密钥未设置[/red]")
|
||||
async def generate_commit_message_async(self, diff_content: str) -> Optional[str]:
|
||||
"""异步生成提交消息(推荐用于异步应用)"""
|
||||
if not self._ensure_client_initialized():
|
||||
return None
|
||||
|
||||
try:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.config.llm.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
# 构建完整的提示
|
||||
full_prompt = f"{self.config.system_prompt}\n{diff_content}"
|
||||
|
||||
data = {
|
||||
"model": self.config.llm.model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"max_tokens": 50,
|
||||
"temperature": 0.5
|
||||
}
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个专业的代码审查和版本控制助手。"},
|
||||
{"role": "user", "content": full_prompt}
|
||||
]
|
||||
|
||||
response = requests.post(
|
||||
f"{self.config.llm.base_url}/chat/completions",
|
||||
headers=headers,
|
||||
json=data,
|
||||
timeout=self.config.llm.timeout
|
||||
# 发送请求
|
||||
response = await self.async_client.chat.completions.create(
|
||||
model=self.config.model,
|
||||
messages=messages,
|
||||
max_tokens=self.config.max_tokens,
|
||||
temperature=self.config.temperature,
|
||||
extra_body=self.config.extra_body,
|
||||
# response_format=self.config.response_format, # 暂时注释,因为可能不是所有API都支持
|
||||
stream=False # 非流式
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
return result["choices"][0]["message"]["content"]
|
||||
# 提取消息内容
|
||||
if response.choices and len(response.choices) > 0:
|
||||
content = response.choices[0].message.content
|
||||
|
||||
# 清理返回的文本(移除多余的引号、空格等)
|
||||
content = content.strip().strip('"').strip("'").strip()
|
||||
|
||||
return content
|
||||
else:
|
||||
console.print(f"[red]请求失败: {response.status_code}[/red]")
|
||||
console.print("[red]❌ 未收到有效的响应内容[/red]")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[red]错误: {str(e)}[/red]")
|
||||
console.print(f"[red]❌ 生成提交消息失败: {e}[/red]")
|
||||
return None
|
||||
|
||||
def generate_commit_message(self, diff_content: str) -> Optional[str]:
|
||||
"""同步生成提交消息"""
|
||||
if not self._ensure_client_initialized():
|
||||
return None
|
||||
|
||||
if len(diff_content) >= 1024 * 16:
|
||||
diff_content = diff_content[0: 1024 * 16]
|
||||
|
||||
try:
|
||||
# 构建完整的提示
|
||||
full_prompt = f"{self.config.system_prompt}\n{diff_content}"
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个专业的代码审查和版本控制助手。"},
|
||||
{"role": "user", "content": full_prompt}
|
||||
]
|
||||
|
||||
# 使用Rich显示进度
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
transient=True,
|
||||
) as progress:
|
||||
task = progress.add_task(
|
||||
description="[cyan]🤖 正在生成提交消息...[/cyan]",
|
||||
total=None
|
||||
)
|
||||
|
||||
# 发送请求
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.config.model,
|
||||
messages=messages,
|
||||
max_tokens=self.config.max_tokens,
|
||||
temperature=self.config.temperature,
|
||||
extra_body=self.config.extra_body,
|
||||
# response_format=self.config.response_format, # 暂时注释
|
||||
stream=False # 非流式
|
||||
)
|
||||
|
||||
progress.update(task, completed=100)
|
||||
|
||||
# 提取消息内容
|
||||
if response.choices and len(response.choices) > 0:
|
||||
content = response.choices[0].message.content
|
||||
|
||||
# 清理返回的文本
|
||||
content = content.strip().strip('"').strip("'").strip()
|
||||
|
||||
# 显示结果
|
||||
console.print(Panel(
|
||||
f"[bold green]{content}[/bold green]",
|
||||
title="📝 生成的提交消息",
|
||||
border_style="green",
|
||||
padding=(1, 2)
|
||||
))
|
||||
|
||||
return content
|
||||
else:
|
||||
console.print("[red]❌ 未收到有效的响应内容[/red]")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[red]❌ 生成提交消息失败: {e}[/red]")
|
||||
return None
|
||||
|
||||
def test_generation(self, test_diff: Optional[str] = None) -> Optional[str]:
|
||||
"""测试消息生成功能"""
|
||||
if not self._ensure_client_initialized():
|
||||
return None
|
||||
|
||||
console.print("[bold cyan]🧪 测试消息生成...[/bold cyan]")
|
||||
|
||||
# 使用示例diff或提供的测试diff
|
||||
if test_diff is None:
|
||||
test_diff = """diff --git a/example.py b/example.py
|
||||
index abc123..def456 100644
|
||||
--- a/example.py
|
||||
+++ b/example.py
|
||||
@@ -1,5 +1,5 @@
|
||||
def hello_world():
|
||||
- print("Hello, World!")
|
||||
+ print("Hello, OpenAI!")
|
||||
|
||||
def add(a, b):
|
||||
return a + b"""
|
||||
|
||||
console.print("[dim]使用示例diff进行测试...[/dim]")
|
||||
|
||||
return self.generate_commit_message(test_diff)
|
||||
|
||||
def update_config(self, **kwargs):
|
||||
"""更新配置并重新初始化客户端"""
|
||||
# 更新配置
|
||||
self.config_manager.update_llm_config(**kwargs)
|
||||
|
||||
# 重新加载配置
|
||||
self.config = self.config_manager.load_config().llm
|
||||
|
||||
# 重新初始化客户端
|
||||
self._init_clients()
|
||||
|
||||
if self.client:
|
||||
console.print("[green]✅ 配置已更新,客户端重新初始化[/green]")
|
||||
else:
|
||||
console.print("[yellow]⚠️ 配置已更新,但客户端初始化失败[/yellow]")
|
||||
|
||||
|
||||
# Helper: synchronous entry point that falls back to the async pathway.
def generate_commit_message_sync(diff_content: str, config_manager: Optional[ConfigManager] = None) -> Optional[str]:
    """Generate a commit message synchronously, retrying via asyncio on failure."""
    llm = LLMClient(config_manager)

    # First attempt: plain synchronous generation.
    message = llm.generate_commit_message(diff_content)
    if message is not None:
        return message

    # Second attempt: drive the async client from a fresh event loop.
    if llm.async_client:
        try:
            console.print("[dim]尝试异步生成...[/dim]")
            message = asyncio.run(llm.generate_commit_message_async(diff_content))
        except Exception as e:
            console.print(f"[red]❌ 异步生成也失败: {e}[/red]")

    return message
|
||||
|
|
|
|||
|
|
@ -6,12 +6,12 @@ readme = "README.md"
|
|||
requires-python = ">=3.11"
|
||||
dependencies = [
|
||||
"openai>=1.0.0",
|
||||
"fire>=0.5.0",
|
||||
"rich>=13.0.0",
|
||||
"pydantic>=2.0.0",
|
||||
"pyyaml>=6.0", # 用于配置文件
|
||||
"requests>=2.31.0", # 用于网络请求
|
||||
"setuptools>=65.0"
|
||||
"setuptools>=65.0",
|
||||
"typer>=0.21.1",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
|
|
|||
Loading…
Reference in New Issue