# autocommit/llm.py — LLM client for generating git commit messages via the OpenAI API.
# NOTE(review): `os`, `Status` and `LLMConfig` appear unused in this module —
# confirm against the rest of the package before removing.
import os
import asyncio
from typing import Optional, Dict, Any
from openai import AsyncOpenAI, OpenAI
from rich.console import Console
from rich.table import Table
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.status import Status
from .config import ConfigManager, LLMConfig
# Shared Rich console for all user-facing output in this module.
console = Console()
class LLMClient:
    """LLM client built on the official OpenAI SDK.

    Wraps a synchronous and an asynchronous OpenAI client configured from the
    project's ``ConfigManager`` and provides commit-message generation,
    service availability checks, and Rich-based console reporting.
    """

    # Diffs longer than this are truncated before being sent to the model to
    # keep the request within a reasonable token budget.
    _MAX_DIFF_CHARS = 1024 * 16

    def __init__(self, config_manager: Optional[ConfigManager] = None):
        """Create the client and eagerly initialize both OpenAI clients.

        Args:
            config_manager: Optional pre-built ``ConfigManager``; a fresh one
                is created when omitted.
        """
        self.config_manager = config_manager or ConfigManager()
        # Only the ``llm`` section of the loaded config is used here.
        self.config = self.config_manager.load_config().llm
        self.client: Optional[OpenAI] = None
        self.async_client: Optional[AsyncOpenAI] = None
        self._init_clients()

    def _init_clients(self) -> None:
        """Initialize the sync and async OpenAI clients.

        Leaves both clients as ``None`` (with a console warning) when no API
        key is configured or construction fails.
        """
        if not self.config.api_key:
            console.print("[yellow]⚠️ API密钥未设置客户端未初始化[/yellow]")
            return
        try:
            # Both clients share the same connection settings.
            client_kwargs = dict(
                api_key=self.config.api_key,
                base_url=self.config.base_url,
                timeout=self.config.timeout,
                max_retries=2,
            )
            self.client = OpenAI(**client_kwargs)
            self.async_client = AsyncOpenAI(**client_kwargs)
        except Exception as e:
            console.print(f"[red]❌ 初始化OpenAI客户端失败: {e}[/red]")

    def _ensure_client_initialized(self) -> bool:
        """Ensure both clients exist, re-initializing once if needed.

        Returns:
            True when the synchronous client is ready for use.
        """
        if not self.config.api_key:
            console.print("[red]❌ 错误: API密钥未设置[/red]")
            console.print("[yellow]💡 提示: 请先运行 `autocommit config set --api-key <你的密钥>`[/yellow]")
            return False
        if not self.client or not self.async_client:
            self._init_clients()
        return self.client is not None

    def _build_messages(self, diff_content: str) -> list:
        """Build the chat messages for a commit-message request.

        The diff is truncated to ``_MAX_DIFF_CHARS`` (previously only the
        sync path did this; applied to both paths for consistency).
        """
        if len(diff_content) >= self._MAX_DIFF_CHARS:
            diff_content = diff_content[: self._MAX_DIFF_CHARS]
        full_prompt = f"{self.config.system_prompt}\n{diff_content}"
        return [
            {"role": "system", "content": "你是一个专业的代码审查和版本控制助手。"},
            {"role": "user", "content": full_prompt},
        ]

    @staticmethod
    def _clean_content(content: str) -> str:
        """Strip surrounding whitespace and quote characters from a reply."""
        return content.strip().strip('"').strip("'").strip()

    def check_availability(self) -> Dict[str, Any]:
        """Probe the LLM service and report whether it is usable.

        Returns:
            Dict with keys ``available`` (bool), ``message`` (str),
            ``details`` (dict) and ``error`` (str or ``None``).
        """
        result: Dict[str, Any] = {
            "available": False,
            "message": "",
            "details": {},
            "error": None,
        }
        if not self._ensure_client_initialized():
            result["message"] = "客户端未初始化"
            return result
        try:
            with console.status("[bold cyan]正在检查模型可用性...[/bold cyan]"):
                try:
                    # Some OpenAI-compatible backends do not implement the
                    # model-listing endpoint, hence the nested try.
                    models = self.client.models.list()
                    model_names = [model.id for model in models.data]
                    result["details"]["model_count"] = len(model_names)
                    result["details"]["models"] = model_names[:5]  # show at most 5
                    configured_in_list = self.config.model in model_names
                    result["details"]["configured_model"] = self.config.model
                    result["details"]["model_available"] = configured_in_list
                    if not configured_in_list:
                        console.print(f"[yellow]⚠️ 配置的模型 '{self.config.model}' 不在可用模型列表中[/yellow]")
                        console.print(f"[dim]可用模型: {', '.join(model_names[:3])}...[/dim]")
                    result["available"] = True
                    result["message"] = f"LLM服务可用 (找到 {len(model_names)} 个模型)"
                except Exception:
                    # Fall back to a minimal one-token chat request; any
                    # failure here propagates to the outer handler (the old
                    # `raise inner_e` re-raise was a no-op and is removed).
                    console.print("[dim]模型列表不可用,尝试通过聊天请求测试...[/dim]")
                    self.client.chat.completions.create(
                        model=self.config.model,
                        messages=[{"role": "user", "content": "Hello"}],
                        max_tokens=1,
                    )
                    result["available"] = True
                    result["message"] = f"LLM服务可用 (模型: {self.config.model})"
                    result["details"]["test_response"] = "成功"
        except Exception as e:
            result["message"] = "LLM服务检查失败"  # was an f-string with no placeholder
            result["error"] = str(e)
            console.print(f"[red]❌ 错误详情: {e}[/red]")
        return result

    def display_check_result(self, result: Dict[str, Any]) -> None:
        """Render a ``check_availability`` result as a Rich table."""
        table = Table(title="🤖 LLM 服务状态检查", box=None)
        table.add_column("项目", style="cyan bold", width=20)
        table.add_column("状态", style="green", width=40)
        # Availability row. NOTE(review): the icon characters were blank in
        # the source (stripped "ambiguous Unicode"); ✅/❌ restored — confirm
        # against the original file.
        if result["available"]:
            status_icon = "✅"
            status_text = "[green]可用[/green]"
        else:
            status_icon = "❌"
            status_text = "[red]不可用[/red]"
        table.add_row(f"{status_icon} 可用性", status_text)
        table.add_row("📝 状态信息", result["message"])
        # One row per detail entry, formatted by value type.
        if result["details"]:
            for key, value in result["details"].items():
                if isinstance(value, list):
                    display_value = f"[cyan]{', '.join(map(str, value))}[/cyan]"
                elif isinstance(value, bool):
                    display_value = "[green]是[/green]" if value else "[red]否[/red]"
                else:
                    display_value = f"[cyan]{value}[/cyan]"
                # Prettify the key name: snake_case -> Title Case.
                key_display = key.replace("_", " ").title()
                table.add_row(f"📊 {key_display}", display_value)
        if result.get("error"):
            table.add_row("🚨 错误详情", f"[red]{result['error']}[/red]")
        console.print(table)

    async def generate_commit_message_async(self, diff_content: str) -> Optional[str]:
        """Generate a commit message asynchronously (preferred in async apps).

        Args:
            diff_content: Raw ``git diff`` text to summarize.

        Returns:
            The cleaned commit message, or ``None`` on failure.
        """
        if not self._ensure_client_initialized():
            return None
        try:
            response = await self.async_client.chat.completions.create(
                model=self.config.model,
                messages=self._build_messages(diff_content),
                max_tokens=self.config.max_tokens,
                temperature=self.config.temperature,
                extra_body=self.config.extra_body,
                # response_format intentionally omitted: not all
                # OpenAI-compatible backends support it.
                stream=False,
            )
            # Guard both a missing choice and a None content payload; the
            # old code would raise AttributeError on None content.
            if response.choices and response.choices[0].message.content:
                return self._clean_content(response.choices[0].message.content)
            console.print("[red]❌ 未收到有效的响应内容[/red]")
            return None
        except Exception as e:
            console.print(f"[red]❌ 生成提交消息失败: {e}[/red]")
            return None

    def generate_commit_message(self, diff_content: str) -> Optional[str]:
        """Generate a commit message synchronously, with a progress spinner.

        Args:
            diff_content: Raw ``git diff`` text to summarize.

        Returns:
            The cleaned commit message (also echoed in a Rich panel), or
            ``None`` on failure.
        """
        if not self._ensure_client_initialized():
            return None
        try:
            # Transient spinner while the request is in flight.
            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                transient=True,
            ) as progress:
                task = progress.add_task(
                    description="[cyan]🤖 正在生成提交消息...[/cyan]",
                    total=None,
                )
                response = self.client.chat.completions.create(
                    model=self.config.model,
                    messages=self._build_messages(diff_content),
                    max_tokens=self.config.max_tokens,
                    temperature=self.config.temperature,
                    extra_body=self.config.extra_body,
                    # response_format intentionally omitted: not all
                    # OpenAI-compatible backends support it.
                    stream=False,
                )
                progress.update(task, completed=100)
            # Guard both a missing choice and a None content payload.
            if response.choices and response.choices[0].message.content:
                content = self._clean_content(response.choices[0].message.content)
                console.print(Panel(
                    f"[bold green]{content}[/bold green]",
                    title="📝 生成的提交消息",
                    border_style="green",
                    padding=(1, 2),
                ))
                return content
            console.print("[red]❌ 未收到有效的响应内容[/red]")
            return None
        except Exception as e:
            console.print(f"[red]❌ 生成提交消息失败: {e}[/red]")
            return None

    def test_generation(self, test_diff: Optional[str] = None) -> Optional[str]:
        """Exercise message generation with a sample (or provided) diff.

        Returns:
            Whatever ``generate_commit_message`` returns for the diff.
        """
        if not self._ensure_client_initialized():
            return None
        console.print("[bold cyan]🧪 测试消息生成...[/bold cyan]")
        if test_diff is None:
            # Built-in sample diff used when the caller supplies none.
            test_diff = """diff --git a/example.py b/example.py
index abc123..def456 100644
--- a/example.py
+++ b/example.py
@@ -1,5 +1,5 @@
def hello_world():
- print("Hello, World!")
+ print("Hello, OpenAI!")
def add(a, b):
return a + b"""
            console.print("[dim]使用示例diff进行测试...[/dim]")
        return self.generate_commit_message(test_diff)

    def update_config(self, **kwargs) -> None:
        """Persist LLM config changes and re-initialize both clients."""
        self.config_manager.update_llm_config(**kwargs)
        # Reload so self.config reflects what was actually persisted.
        self.config = self.config_manager.load_config().llm
        self._init_clients()
        if self.client:
            console.print("[green]✅ 配置已更新,客户端重新初始化[/green]")
        else:
            console.print("[yellow]⚠️ 配置已更新,但客户端初始化失败[/yellow]")
# Module-level convenience wrapper around LLMClient.
def generate_commit_message_sync(diff_content: str, config_manager: Optional[ConfigManager] = None) -> Optional[str]:
    """Synchronously generate a commit message for *diff_content*.

    Tries the synchronous client first; when that yields nothing and an
    async client is available, retries once via ``asyncio.run``.

    Args:
        diff_content: Raw ``git diff`` text to summarize.
        config_manager: Optional ``ConfigManager`` forwarded to ``LLMClient``.

    Returns:
        The generated commit message, or ``None`` when both attempts fail.
    """
    llm = LLMClient(config_manager)
    message = llm.generate_commit_message(diff_content)
    # Guard clause: done if the sync path succeeded or no async fallback exists.
    if message is not None or not llm.async_client:
        return message
    try:
        console.print("[dim]尝试异步生成...[/dim]")
        message = asyncio.run(llm.generate_commit_message_async(diff_content))
    except Exception as e:
        console.print(f"[red]❌ 异步生成也失败: {e}[/red]")
    return message