324 lines
12 KiB
Python
324 lines
12 KiB
Python
import os
|
||
import asyncio
|
||
from typing import Optional, Dict, Any
|
||
from openai import AsyncOpenAI, OpenAI
|
||
from rich.console import Console
|
||
from rich.table import Table
|
||
from rich.panel import Panel
|
||
from rich.progress import Progress, SpinnerColumn, TextColumn
|
||
from rich.status import Status
|
||
|
||
from .config import ConfigManager, LLMConfig
|
||
|
||
console = Console()
|
||
|
||
|
||
class LLMClient:
|
||
"""LLM客户端 - 使用OpenAI库"""
|
||
|
||
def __init__(self, config_manager: Optional[ConfigManager] = None):
|
||
self.config_manager = config_manager or ConfigManager()
|
||
self.config = self.config_manager.load_config().llm
|
||
self.client = None
|
||
self.async_client = None
|
||
self._init_clients()
|
||
|
||
def _init_clients(self):
|
||
"""初始化OpenAI客户端"""
|
||
if not self.config.api_key:
|
||
console.print("[yellow]⚠️ API密钥未设置,客户端未初始化[/yellow]")
|
||
return
|
||
|
||
try:
|
||
# 初始化同步客户端
|
||
self.client = OpenAI(
|
||
api_key=self.config.api_key,
|
||
base_url=self.config.base_url,
|
||
timeout=self.config.timeout,
|
||
max_retries=2,
|
||
)
|
||
|
||
# 初始化异步客户端
|
||
self.async_client = AsyncOpenAI(
|
||
api_key=self.config.api_key,
|
||
base_url=self.config.base_url,
|
||
timeout=self.config.timeout,
|
||
max_retries=2,
|
||
)
|
||
|
||
except Exception as e:
|
||
console.print(f"[red]❌ 初始化OpenAI客户端失败: {e}[/red]")
|
||
|
||
def _ensure_client_initialized(self) -> bool:
|
||
"""确保客户端已初始化"""
|
||
if not self.config.api_key:
|
||
console.print("[red]❌ 错误: API密钥未设置[/red]")
|
||
console.print("[yellow]💡 提示: 请先运行 `autocommit config set --api-key <你的密钥>`[/yellow]")
|
||
return False
|
||
|
||
if not self.client or not self.async_client:
|
||
self._init_clients()
|
||
|
||
return self.client is not None
|
||
|
||
def check_availability(self) -> Dict[str, Any]:
|
||
"""检查LLM服务可用性"""
|
||
result = {
|
||
"available": False,
|
||
"message": "",
|
||
"details": {},
|
||
"error": None
|
||
}
|
||
|
||
if not self._ensure_client_initialized():
|
||
result["message"] = "客户端未初始化"
|
||
return result
|
||
|
||
try:
|
||
# 检查模型是否可用
|
||
with console.status("[bold cyan]正在检查模型可用性...[/bold cyan]"):
|
||
# 尝试列出模型(部分API可能不支持)
|
||
try:
|
||
models = self.client.models.list()
|
||
model_names = [model.id for model in models.data]
|
||
result["details"]["model_count"] = len(model_names)
|
||
result["details"]["models"] = model_names[:5] # 只显示前5个
|
||
|
||
# 检查配置的模型是否在列表中
|
||
configured_in_list = self.config.model in model_names
|
||
result["details"]["configured_model"] = self.config.model
|
||
result["details"]["model_available"] = configured_in_list
|
||
|
||
if not configured_in_list:
|
||
console.print(f"[yellow]⚠️ 配置的模型 '{self.config.model}' 不在可用模型列表中[/yellow]")
|
||
console.print(f"[dim]可用模型: {', '.join(model_names[:3])}...[/dim]")
|
||
|
||
result["available"] = True
|
||
result["message"] = f"LLM服务可用 (找到 {len(model_names)} 个模型)"
|
||
|
||
except Exception as e:
|
||
# 如果列出模型失败,尝试通过简单请求测试
|
||
console.print("[dim]模型列表不可用,尝试通过聊天请求测试...[/dim]")
|
||
try:
|
||
# 发送一个极简的测试请求
|
||
response = self.client.chat.completions.create(
|
||
model=self.config.model,
|
||
messages=[{"role": "user", "content": "Hello"}],
|
||
max_tokens=1,
|
||
)
|
||
result["available"] = True
|
||
result["message"] = f"LLM服务可用 (模型: {self.config.model})"
|
||
result["details"]["test_response"] = "成功"
|
||
except Exception as inner_e:
|
||
raise inner_e
|
||
|
||
except Exception as e:
|
||
result["message"] = f"LLM服务检查失败"
|
||
result["error"] = str(e)
|
||
console.print(f"[red]❌ 错误详情: {e}[/red]")
|
||
|
||
return result
|
||
|
||
def display_check_result(self, result: Dict[str, Any]):
|
||
"""显示检查结果 - 美观的Rich表格"""
|
||
table = Table(title="🤖 LLM 服务状态检查", box=None)
|
||
table.add_column("项目", style="cyan bold", width=20)
|
||
table.add_column("状态", style="green", width=40)
|
||
|
||
# 可用性状态
|
||
if result["available"]:
|
||
status_icon = "✅"
|
||
status_text = "[green]可用[/green]"
|
||
else:
|
||
status_icon = "❌"
|
||
status_text = "[red]不可用[/red]"
|
||
|
||
table.add_row(f"{status_icon} 可用性", status_text)
|
||
table.add_row("📝 状态信息", result["message"])
|
||
|
||
# 详细信息
|
||
if result["details"]:
|
||
for key, value in result["details"].items():
|
||
if isinstance(value, list):
|
||
display_value = f"[cyan]{', '.join(map(str, value))}[/cyan]"
|
||
elif isinstance(value, bool):
|
||
display_value = "[green]是[/green]" if value else "[red]否[/red]"
|
||
else:
|
||
display_value = f"[cyan]{value}[/cyan]"
|
||
|
||
# 美化键名
|
||
key_display = key.replace("_", " ").title()
|
||
table.add_row(f"📊 {key_display}", display_value)
|
||
|
||
# 错误信息
|
||
if result.get("error"):
|
||
table.add_row("🚨 错误详情", f"[red]{result['error']}[/red]")
|
||
|
||
console.print(table)
|
||
|
||
async def generate_commit_message_async(self, diff_content: str) -> Optional[str]:
|
||
"""异步生成提交消息(推荐用于异步应用)"""
|
||
if not self._ensure_client_initialized():
|
||
return None
|
||
|
||
try:
|
||
# 构建完整的提示
|
||
full_prompt = f"{self.config.system_prompt}\n{diff_content}"
|
||
|
||
messages = [
|
||
{"role": "system", "content": "你是一个专业的代码审查和版本控制助手。"},
|
||
{"role": "user", "content": full_prompt}
|
||
]
|
||
|
||
# 发送请求
|
||
response = await self.async_client.chat.completions.create(
|
||
model=self.config.model,
|
||
messages=messages,
|
||
max_tokens=self.config.max_tokens,
|
||
temperature=self.config.temperature,
|
||
extra_body=self.config.extra_body,
|
||
# response_format=self.config.response_format, # 暂时注释,因为可能不是所有API都支持
|
||
stream=False # 非流式
|
||
)
|
||
|
||
# 提取消息内容
|
||
if response.choices and len(response.choices) > 0:
|
||
content = response.choices[0].message.content
|
||
|
||
# 清理返回的文本(移除多余的引号、空格等)
|
||
content = content.strip().strip('"').strip("'").strip()
|
||
|
||
return content
|
||
else:
|
||
console.print("[red]❌ 未收到有效的响应内容[/red]")
|
||
return None
|
||
|
||
except Exception as e:
|
||
console.print(f"[red]❌ 生成提交消息失败: {e}[/red]")
|
||
return None
|
||
|
||
def generate_commit_message(self, diff_content: str) -> Optional[str]:
|
||
"""同步生成提交消息"""
|
||
if not self._ensure_client_initialized():
|
||
return None
|
||
|
||
if len(diff_content) >= 1024 * 16:
|
||
diff_content = diff_content[0: 1024 * 16]
|
||
|
||
try:
|
||
# 构建完整的提示
|
||
full_prompt = f"{self.config.system_prompt}\n{diff_content}"
|
||
|
||
messages = [
|
||
{"role": "system", "content": "你是一个专业的代码审查和版本控制助手。"},
|
||
{"role": "user", "content": full_prompt}
|
||
]
|
||
|
||
# 使用Rich显示进度
|
||
with Progress(
|
||
SpinnerColumn(),
|
||
TextColumn("[progress.description]{task.description}"),
|
||
transient=True,
|
||
) as progress:
|
||
task = progress.add_task(
|
||
description="[cyan]🤖 正在生成提交消息...[/cyan]",
|
||
total=None
|
||
)
|
||
|
||
# 发送请求
|
||
response = self.client.chat.completions.create(
|
||
model=self.config.model,
|
||
messages=messages,
|
||
max_tokens=self.config.max_tokens,
|
||
temperature=self.config.temperature,
|
||
extra_body=self.config.extra_body,
|
||
# response_format=self.config.response_format, # 暂时注释
|
||
stream=False # 非流式
|
||
)
|
||
|
||
progress.update(task, completed=100)
|
||
|
||
# 提取消息内容
|
||
if response.choices and len(response.choices) > 0:
|
||
content = response.choices[0].message.content
|
||
|
||
# 清理返回的文本
|
||
content = content.strip().strip('"').strip("'").strip()
|
||
|
||
# 显示结果
|
||
console.print(Panel(
|
||
f"[bold green]{content}[/bold green]",
|
||
title="📝 生成的提交消息",
|
||
border_style="green",
|
||
padding=(1, 2)
|
||
))
|
||
|
||
return content
|
||
else:
|
||
console.print("[red]❌ 未收到有效的响应内容[/red]")
|
||
return None
|
||
|
||
except Exception as e:
|
||
console.print(f"[red]❌ 生成提交消息失败: {e}[/red]")
|
||
return None
|
||
|
||
def test_generation(self, test_diff: Optional[str] = None) -> Optional[str]:
|
||
"""测试消息生成功能"""
|
||
if not self._ensure_client_initialized():
|
||
return None
|
||
|
||
console.print("[bold cyan]🧪 测试消息生成...[/bold cyan]")
|
||
|
||
# 使用示例diff或提供的测试diff
|
||
if test_diff is None:
|
||
test_diff = """diff --git a/example.py b/example.py
|
||
index abc123..def456 100644
|
||
--- a/example.py
|
||
+++ b/example.py
|
||
@@ -1,5 +1,5 @@
|
||
def hello_world():
|
||
- print("Hello, World!")
|
||
+ print("Hello, OpenAI!")
|
||
|
||
def add(a, b):
|
||
return a + b"""
|
||
|
||
console.print("[dim]使用示例diff进行测试...[/dim]")
|
||
|
||
return self.generate_commit_message(test_diff)
|
||
|
||
def update_config(self, **kwargs):
|
||
"""更新配置并重新初始化客户端"""
|
||
# 更新配置
|
||
self.config_manager.update_llm_config(**kwargs)
|
||
|
||
# 重新加载配置
|
||
self.config = self.config_manager.load_config().llm
|
||
|
||
# 重新初始化客户端
|
||
self._init_clients()
|
||
|
||
if self.client:
|
||
console.print("[green]✅ 配置已更新,客户端重新初始化[/green]")
|
||
else:
|
||
console.print("[yellow]⚠️ 配置已更新,但客户端初始化失败[/yellow]")
|
||
|
||
|
||
# 辅助函数:同步调用异步函数
|
||
def generate_commit_message_sync(diff_content: str, config_manager: Optional[ConfigManager] = None) -> Optional[str]:
|
||
"""同步辅助函数:生成提交消息"""
|
||
client = LLMClient(config_manager)
|
||
|
||
# 尝试同步方法
|
||
result = client.generate_commit_message(diff_content)
|
||
|
||
# 如果同步方法失败,尝试异步方法
|
||
if result is None and client.async_client:
|
||
try:
|
||
console.print("[dim]尝试异步生成...[/dim]")
|
||
result = asyncio.run(client.generate_commit_message_async(diff_content))
|
||
except Exception as e:
|
||
console.print(f"[red]❌ 异步生成也失败: {e}[/red]")
|
||
|
||
return result
|