llmcodegen/src/llm_codegen/core.py

385 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
import subprocess
import sys
from typing import List, Dict, Optional, Any, Tuple
from pathlib import Path
import typer
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskID
from loguru import logger
from openai import OpenAI
from .utils import is_dangerous_command, read_file, write_file, ensure_dir, safe_join
from .models import DesignModel, StateModel, LLMResponse
class CodeGenerator:
    """Generate a project's source files from a README with the help of an LLM.

    The workflow implemented by ``run`` is: read the README, obtain a
    ``design.json`` (a ``DesignModel``) listing the files to create and their
    dependencies, then generate the files one by one, writing each under
    ``output_dir`` and running the commands the model asks for. Progress is
    checkpointed to a state file so an interrupted run can resume.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: str = "https://api.deepseek.com",
        model: str = "deepseek-reasoner",
        output_dir: str = "./generated",
        log_file: Optional[str] = None,
    ):
        """Initialise the generator, its API client and logging.

        Args:
            api_key: API key; falls back to the ``DEEPSEEK_APIKEY``
                environment variable when not given.
            base_url: Base URL of the OpenAI-compatible API.
            model: Model name sent with every request.
            output_dir: Root directory for generated files (created if
                missing).
            log_file: Log file path; defaults to ``generator.log`` inside
                ``output_dir``.

        Raises:
            ValueError: If no API key is supplied and the environment
                variable is unset or empty.
        """
        self.api_key = api_key or os.getenv("DEEPSEEK_APIKEY")
        if not self.api_key:
            raise ValueError("必须提供API密钥或设置环境变量DEEPSEEK_APIKEY")
        self.client = OpenAI(api_key=self.api_key, base_url=base_url)
        self.model = model
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.state_file = self.output_dir / ".llm_generator_state.json"

        # Logging: replace loguru's default handler so that the console only
        # shows WARNING and above while the file keeps full DEBUG output.
        if log_file is None:
            log_file = self.output_dir / "generator.log"
        logger.remove()
        logger.add(sys.stderr, level="WARNING")
        logger.add(log_file, rotation="10 MB", level="DEBUG")
        logger.info(f"日志已初始化,保存至: {log_file}")

        self.readme_content = None
        # Path given to run(); persisted in the checkpoint state.
        self.readme_path: Optional[Path] = None
        self.design: Optional[DesignModel] = None
        self.state: Optional[StateModel] = None
        self.progress: Optional[Progress] = None
        self.tasks: Dict[str, TaskID] = {}  # task-id mapping

    def _call_llm(
        self,
        system_prompt: str,
        user_prompt: str,
        temperature: float = 0.2,
        expect_json: bool = True,
    ) -> Dict[str, Any]:
        """Send one chat completion request and return the result as a dict.

        Args:
            system_prompt: Content of the system message.
            user_prompt: Content of the user message.
            temperature: Sampling temperature.
            expect_json: When true, request a JSON object and parse it; when
                false, wrap the raw text as ``{"content": <text>}``.

        Returns:
            The parsed JSON object, or the wrapped raw text.

        Raises:
            ValueError: If JSON was expected but the reply is not valid JSON
                or is not a JSON object.
            Exception: Any other failure is logged and re-raised unchanged.
        """
        logger.debug(f"调用LLM模型: {self.model}")
        logger.debug(f"System: {system_prompt[:200]}...")
        logger.debug(f"User: {user_prompt[:200]}...")

        request: Dict[str, Any] = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "temperature": temperature,
        }
        # Only send response_format when it is wanted; omitting the key
        # avoids posting an explicit null to the API.
        if expect_json:
            request["response_format"] = {"type": "json_object"}

        content = ""
        try:
            response = self.client.chat.completions.create(**request)
            message = response.choices[0].message
            # content may be None; normalise so slicing and parsing below
            # fail with a meaningful error rather than a TypeError.
            content = message.content or ""

            # Log the model's reasoning trace when the message carries one.
            reasoning = getattr(message, "reasoning_content", None)
            if reasoning:
                logger.info(f"模型思考过程: {reasoning}")
            logger.debug(f"LLM原始响应: {content[:500]}...")

            if not expect_json:
                return {"content": content}

            result = json.loads(content)
            if not isinstance(result, dict):
                # Callers unpack the result with ** so it must be an object.
                raise ValueError(f"LLM返回的JSON不是对象: {content[:200]}")
            return result
        except json.JSONDecodeError as e:
            logger.error(f"JSON解析失败: {e}")
            raise ValueError(f"LLM返回的不是有效JSON: {content[:200]}") from e
        except Exception as e:
            logger.error(f"LLM调用失败: {e}")
            raise

    def parse_readme(self, readme_path: Path) -> str:
        """Read the README file as UTF-8 text and return its content.

        Any error while reading is logged and re-raised.
        """
        logger.info(f"读取README文件: {readme_path}")
        try:
            with open(readme_path, "r", encoding="utf-8") as f:
                content = f.read()
        except Exception as e:
            logger.error(f"读取README失败: {e}")
            raise
        logger.debug(f"README内容长度: {len(content)} 字符")
        return content

    def generate_design_json(self) -> "DesignModel":
        """Ask the LLM for the project design and persist it as design.json.

        Returns:
            The design parsed into a ``DesignModel``.

        Raises:
            ValueError: If no README content has been loaded yet.
        """
        if self.readme_content is None:
            # Without this guard the prompt would contain the text "None".
            raise ValueError("README内容未加载,请先调用parse_readme")

        system_prompt = (
            "你是一个软件架构师。请根据README描述生成项目的中间设计文件design.json。"
            "design.json应包含项目名称、版本、描述、文件列表含路径、摘要、依赖、函数和类、建议命令和检查工具。"
            "返回严格的JSON对象符合DesignModel结构。"
        )
        user_prompt = f"README内容如下\n\n{self.readme_content}"
        design_data = self._call_llm(system_prompt, user_prompt)
        design = DesignModel(**design_data)

        # Persist the design so later steps (and resumed runs) can reuse it.
        design_path = self.output_dir / "design.json"
        with open(design_path, "w", encoding="utf-8") as f:
            json.dump(design.dict(), f, indent=2, ensure_ascii=False)
        logger.info(f"已生成design.json: {design_path}")
        return design

    def load_state(self) -> Optional["StateModel"]:
        """Load the checkpoint state from ``state_file``.

        Returns:
            The loaded ``StateModel`` (also stored on ``self.state``), or
            ``None`` when the file does not exist or cannot be loaded.
        """
        if not self.state_file.exists():
            return None
        try:
            with open(self.state_file, "r", encoding="utf-8") as f:
                state_data = json.load(f)
            self.state = StateModel(**state_data)
        except Exception as e:
            # A broken checkpoint is not fatal: the caller starts over.
            logger.error(f"加载状态失败: {e}")
            return None
        logger.info(f"加载状态成功: 当前文件索引 {self.state.current_file_index}")
        return self.state

    def save_state(self, current_file_index: int, generated_files: List[str], dependencies_map: Dict[str, List[str]]) -> None:
        """Write the checkpoint state to ``state_file``.

        Args:
            current_file_index: Index of the next file to process on resume.
            generated_files: Paths of the files generated so far.
            dependencies_map: Mapping of file path to its dependency paths.
        """
        state = StateModel(
            current_file_index=current_file_index,
            generated_files=generated_files,
            dependencies_map=dependencies_map,
            total_files=len(self.design.files) if self.design else 0,
            output_dir=str(self.output_dir),
            # Store the README's path (set by run), not a slice of its text.
            readme_path=str(self.readme_path) if self.readme_path else "",
        )
        with open(self.state_file, "w", encoding="utf-8") as f:
            json.dump(state.dict(), f, indent=2, ensure_ascii=False)
        logger.debug(f"状态已保存: {self.state_file}")

    def get_project_structure(self) -> Tuple[List[str], Dict[str, List[str]]]:
        """Return the file list and dependency map taken from the design.

        Returns:
            ``(files, dependencies)`` where ``files`` is the list of file
            paths in design order and ``dependencies`` maps each file path to
            the list of paths it depends on.

        Raises:
            ValueError: If no design has been loaded.
        """
        if not self.design:
            raise ValueError("design.json未加载请先调用generate_design_json")
        files = [file.path for file in self.design.files]
        dependencies = {file.path: file.dependencies for file in self.design.files}
        logger.info(f"从design.json解析到 {len(files)} 个待生成文件")
        logger.debug(f"文件列表: {files}")
        logger.debug(f"依赖关系: {dependencies}")
        return files, dependencies

    def generate_file(
        self,
        file_path: str,
        prompt_instruction: str,
        dependency_files: List[str],
    ) -> Tuple[str, str, List[str]]:
        """Generate one file with the LLM.

        The prompt context is built from the README (if loaded), the
        design.json in ``output_dir`` (if present) and the content of every
        dependency file.

        Args:
            file_path: Target file path. It is not read inside this method;
                the caller names the target in ``prompt_instruction``.
            prompt_instruction: Instruction placed at the top of the user
                prompt.
            dependency_files: Paths of files whose content is added as
                context. Each is looked up as given first and then relative
                to ``output_dir``.

        Returns:
            ``(code, description, commands)`` from the model's reply.

        Raises:
            FileNotFoundError: If a dependency exists in neither location.
        """
        context_content = []
        if self.readme_content:
            context_content.append(f"### 项目 README ###\n{self.readme_content}\n")

        # Add the design file as context.
        design_path = self.output_dir / "design.json"
        if design_path.exists():
            with open(design_path, "r", encoding="utf-8") as f:
                design_content = f.read()
            context_content.append(f"### 设计文件: design.json ###\n{design_content}\n")

        # NOTE(review): dependency paths originate from the model-produced
        # design and are read without a containment check - confirm that is
        # acceptable before sending their content back to the API.
        for dep in dependency_files:
            dep_path = Path(dep)
            if not dep_path.exists():
                # Fall back to a lookup relative to the output directory.
                alt_path = self.output_dir / dep
                if not alt_path.exists():
                    raise FileNotFoundError(f"依赖文件不存在: {dep}")
                dep_path = alt_path
            with open(dep_path, "r", encoding="utf-8") as f:
                content = f.read()
            context_content.append(f"### 文件: {dep_path.name} (路径: {dep}) ###\n{content}\n")

        full_context = "\n".join(context_content)
        system_prompt = (
            "你是一个专业的编程助手。根据用户指令和提供的上下文文件,生成完整的代码。"
            "返回严格的JSON对象包含三个字段\n"
            "- code: (string) 生成的完整代码\n"
            "- description: (string) 简短的中文功能描述\n"
            "- commands: (array of string) 生成此文件后需要执行的操作系统命令列表(如编译、安装依赖等),若无则返回空数组"
        )
        user_prompt = f"{prompt_instruction}\n\n参考文件上下文:\n{full_context}"
        result = self._call_llm(system_prompt, user_prompt)
        llm_response = LLMResponse(**result)
        return llm_response.code, llm_response.description, llm_response.commands

    def execute_command(self, cmd: str, cwd: Optional[Path] = None) -> bool:
        """Run one shell command; failures are logged, never raised.

        Args:
            cmd: Command line to execute.
            cwd: Working directory; defaults to ``output_dir``.

        Returns:
            ``True`` if the command ran and exited with code 0, ``False`` if
            it was blocked as dangerous, failed, timed out or could not be
            started.
        """
        dangerous, reason = is_dangerous_command(cmd)
        if dangerous:
            logger.error(f"危险命令被阻止: {cmd},原因: {reason}")
            return False

        logger.info(f"执行命令: {cmd}")
        try:
            # NOTE(security): run() passes commands taken from model output,
            # and they are executed through the shell. is_dangerous_command
            # above is the only gate in this method.
            result = subprocess.run(
                cmd,
                shell=True,
                cwd=cwd or self.output_dir,
                capture_output=True,
                text=True,
                timeout=300,  # 5 minute limit
            )
        except subprocess.TimeoutExpired:
            logger.error(f"命令执行超时: {cmd}")
            return False
        except Exception as e:
            logger.error(f"命令执行失败: {e}")
            return False

        logger.debug(f"命令返回码: {result.returncode}")
        if result.stdout:
            logger.debug(f"stdout: {result.stdout[:500]}")
        if result.stderr:
            logger.warning(f"stderr: {result.stderr[:500]}")
        if result.returncode != 0:
            logger.error(f"命令执行失败,返回码: {result.returncode}")
            return False
        return True

    def _resolve_output_path(self, relative_path: str) -> Path:
        """Return the write target for *relative_path* inside ``output_dir``.

        Raises:
            ValueError: If the path (absolute, or containing ``..``) would
                resolve to a location outside ``output_dir``.
        """
        target = self.output_dir / relative_path
        root = self.output_dir.resolve()
        try:
            target.resolve().relative_to(root)
        except ValueError:
            raise ValueError(f"文件路径超出输出目录: {relative_path}") from None
        return target

    def run(self, readme_path: Path):
        """Run the whole generation workflow, resuming from a checkpoint if any.

        Steps: read the README, load the checkpoint state, load or generate
        the design, then generate every remaining file in design order. After
        each file the state is saved; when all files are done the state file
        is deleted.

        Args:
            readme_path: Path of the README describing the project.

        Raises:
            Exception: Any error while processing a file is logged, the state
                is saved for a later resume, and the error is re-raised.
        """
        console = Console()
        logger.info("=" * 50)
        logger.info("开始代码生成流程")
        logger.info(f"README: {readme_path}")
        logger.info(f"输出目录: {self.output_dir}")

        # Read the README.
        console.print("[bold yellow]🔍 正在解析README...[/bold yellow]")
        self.readme_path = readme_path
        self.readme_content = self.parse_readme(readme_path)

        # Load the checkpoint and the matching design.
        state = self.load_state()
        if state:
            console.print(f"[green]✅ 检测到断点状态,从文件索引 {state.current_file_index} 继续[/green]")
            self.state = state
            # A resumed run reuses the design.json written by the earlier run.
            design_path = self.output_dir / "design.json"
            if design_path.exists():
                with open(design_path, "r", encoding="utf-8") as f:
                    design_data = json.load(f)
                self.design = DesignModel(**design_data)
            else:
                console.print("[bold yellow]⚠ design.json不存在重新生成...[/bold yellow]")
                self.design = self.generate_design_json()
        else:
            console.print("[bold yellow]📋 正在生成设计文件...[/bold yellow]")
            self.design = self.generate_design_json()
            self.state = None

        # Derive the work list from the design.
        console.print("[bold yellow]📋 正在分析项目结构...[/bold yellow]")
        files, dependencies = self.get_project_structure()
        console.print(f"[green]✅ 解析完成,共 {len(files)} 个文件待生成[/green]")

        # Resume point.
        start_index = self.state.current_file_index if self.state else 0
        generated_files = self.state.generated_files if self.state else []

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
            console=console,
        ) as progress:
            self.progress = progress
            total_task = progress.add_task("[cyan]整体进度...", total=len(files))
            progress.update(total_task, completed=start_index)

            for idx in range(start_index, len(files)):
                rel_path = files[idx]
                logger.info(f"处理文件 [{idx + 1}/{len(files)}]: {rel_path}")
                file_task = progress.add_task(f"生成 {rel_path}", total=None)
                try:
                    # The path comes from model output: reject anything that
                    # would land outside the output directory before
                    # spending an LLM call on it.
                    output_path = self._resolve_output_path(rel_path)

                    deps = dependencies.get(rel_path, [])
                    instruction = f"请根据README描述和依赖文件生成文件 '{rel_path}' 的完整代码。"
                    code, desc, commands = self.generate_file(rel_path, instruction, deps)
                    logger.info(f"生成完成: {rel_path} - {desc}")

                    # Write the generated code.
                    output_path.parent.mkdir(parents=True, exist_ok=True)
                    with open(output_path, "w", encoding="utf-8") as f:
                        f.write(code)
                    logger.info(f"已写入: {output_path}")
                    generated_files.append(rel_path)

                    # Run follow-up commands; a failing command does not
                    # stop the workflow.
                    for cmd in commands:
                        logger.info(f"准备执行命令: {cmd}")
                        success = self.execute_command(cmd, cwd=self.output_dir)
                        if not success:
                            logger.warning(f"命令执行失败,但继续处理: {cmd}")
                except Exception as e:
                    logger.error(f"处理文件 {rel_path} 失败: {e}")
                    # Checkpoint at the failing file so a resume retries it.
                    self.save_state(idx, generated_files, dependencies)
                    raise
                finally:
                    progress.remove_task(file_task)

                # Success path only: count the file as done and checkpoint.
                progress.update(total_task, advance=1)
                self.save_state(idx + 1, generated_files, dependencies)

        logger.success("所有文件处理完成!")
        # Nothing left to resume: remove the checkpoint.
        if self.state_file.exists():
            self.state_file.unlink()
            logger.info("状态文件已清理")