385 lines
15 KiB
Python
385 lines
15 KiB
Python
import json
|
||
import os
|
||
import subprocess
|
||
import sys
|
||
from typing import List, Dict, Optional, Any, Tuple
|
||
from pathlib import Path
|
||
|
||
import typer
|
||
from rich.console import Console
|
||
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskID
|
||
from loguru import logger
|
||
from openai import OpenAI
|
||
|
||
from .utils import is_dangerous_command, read_file, write_file, ensure_dir, safe_join
|
||
from .models import DesignModel, StateModel, LLMResponse
|
||
|
||
|
||
class CodeGenerator:
    """Code generator: wraps design-file creation, resumable (checkpointed)
    file generation, and post-generation command execution."""

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: str = "https://api.deepseek.com",
        model: str = "deepseek-reasoner",
        output_dir: str = "./generated",
        log_file: Optional[str] = None,
    ):
        """
        Initialize the generator.

        Args:
            api_key: API key; falls back to the DEEPSEEK_APIKEY environment variable.
            base_url: API base URL.
            model: Model name to use.
            output_dir: Root directory for all generated output.
            log_file: Log file path; auto-created inside output_dir when None.

        Raises:
            ValueError: If no API key is given and DEEPSEEK_APIKEY is unset.
        """
        self.api_key = api_key or os.getenv("DEEPSEEK_APIKEY")
        if not self.api_key:
            raise ValueError("必须提供API密钥,或设置环境变量DEEPSEEK_APIKEY")

        self.client = OpenAI(api_key=self.api_key, base_url=base_url)
        self.model = model
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        # Checkpoint file used by load_state/save_state for resumable runs.
        self.state_file = self.output_dir / ".llm_generator_state.json"

        # Configure logging: console gets WARNING+, file gets full DEBUG.
        # NOTE(review): logger.remove() strips ALL existing loguru handlers,
        # including any the host application configured — confirm this
        # global side effect in a constructor is intended.
        if log_file is None:
            log_file = self.output_dir / "generator.log"
        logger.remove()  # remove loguru's default handler
        logger.add(sys.stderr, level="WARNING")  # console: WARNING and above
        logger.add(log_file, rotation="10 MB", level="DEBUG")  # file: DEBUG
        logger.info(f"日志已初始化,保存至: {log_file}")

        self.readme_content: Optional[str] = None  # raw README text, set by run()
        self.design: Optional[DesignModel] = None
        self.state: Optional[StateModel] = None
        self.progress: Optional[Progress] = None
        self.tasks: Dict[str, TaskID] = {}  # task-name -> rich TaskID mapping
def _call_llm(
|
||
self,
|
||
system_prompt: str,
|
||
user_prompt: str,
|
||
temperature: float = 0.2,
|
||
expect_json: bool = True,
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
调用LLM并返回解析后的JSON
|
||
"""
|
||
logger.debug(f"调用LLM,模型: {self.model}")
|
||
logger.debug(f"System: {system_prompt[:200]}...")
|
||
logger.debug(f"User: {user_prompt[:200]}...")
|
||
|
||
try:
|
||
response = self.client.chat.completions.create(
|
||
model=self.model,
|
||
messages=[
|
||
{"role": "system", "content": system_prompt},
|
||
{"role": "user", "content": user_prompt},
|
||
],
|
||
temperature=temperature,
|
||
response_format={"type": "json_object"} if expect_json else None,
|
||
)
|
||
|
||
message = response.choices[0].message
|
||
content = message.content
|
||
|
||
# 记录思考过程(如果存在)
|
||
if hasattr(message, "reasoning_content") and message.reasoning_content:
|
||
logger.info(f"模型思考过程: {message.reasoning_content}")
|
||
|
||
logger.debug(f"LLM原始响应: {content[:500]}...")
|
||
|
||
if expect_json:
|
||
result = json.loads(content)
|
||
else:
|
||
result = {"content": content}
|
||
|
||
return result
|
||
|
||
except json.JSONDecodeError as e:
|
||
logger.error(f"JSON解析失败: {e}")
|
||
raise ValueError(f"LLM返回的不是有效JSON: {content[:200]}")
|
||
except Exception as e:
|
||
logger.error(f"LLM调用失败: {e}")
|
||
raise
|
||
|
||
def parse_readme(self, readme_path: Path) -> str:
|
||
"""
|
||
读取README文件内容
|
||
"""
|
||
logger.info(f"读取README文件: {readme_path}")
|
||
try:
|
||
with open(readme_path, "r", encoding="utf-8") as f:
|
||
content = f.read()
|
||
logger.debug(f"README内容长度: {len(content)} 字符")
|
||
return content
|
||
except Exception as e:
|
||
logger.error(f"读取README失败: {e}")
|
||
raise
|
||
|
||
def generate_design_json(self) -> DesignModel:
|
||
"""
|
||
调用LLM生成design.json内容,并解析为DesignModel
|
||
"""
|
||
system_prompt = (
|
||
"你是一个软件架构师。请根据README描述,生成项目的中间设计文件design.json。"
|
||
"design.json应包含项目名称、版本、描述、文件列表(含路径、摘要、依赖、函数和类)、建议命令和检查工具。"
|
||
"返回严格的JSON对象,符合DesignModel结构。"
|
||
)
|
||
user_prompt = f"README内容如下:\n\n{self.readme_content}"
|
||
|
||
result = self._call_llm(system_prompt, user_prompt)
|
||
design_data = result
|
||
design = DesignModel(**design_data)
|
||
|
||
# 写入design.json文件
|
||
design_path = self.output_dir / "design.json"
|
||
with open(design_path, "w", encoding="utf-8") as f:
|
||
json.dump(design.dict(), f, indent=2, ensure_ascii=False)
|
||
logger.info(f"已生成design.json: {design_path}")
|
||
|
||
return design
|
||
|
||
def load_state(self) -> Optional[StateModel]:
|
||
"""加载断点续写状态"""
|
||
if self.state_file.exists():
|
||
try:
|
||
with open(self.state_file, "r", encoding="utf-8") as f:
|
||
state_data = json.load(f)
|
||
self.state = StateModel(**state_data)
|
||
logger.info(f"加载状态成功: 当前文件索引 {self.state.current_file_index}")
|
||
return self.state
|
||
except Exception as e:
|
||
logger.error(f"加载状态失败: {e}")
|
||
return None
|
||
return None
|
||
|
||
    def save_state(self, current_file_index: int, generated_files: List[str], dependencies_map: Dict[str, List[str]]) -> None:
        """Persist the resume checkpoint so an interrupted run can continue.

        Args:
            current_file_index: Index of the next file to generate.
            generated_files: Relative paths of files already written.
            dependencies_map: Mapping of file path -> its dependency paths.
        """
        state = StateModel(
            current_file_index=current_file_index,
            generated_files=generated_files,
            dependencies_map=dependencies_map,
            total_files=len(self.design.files) if self.design else 0,
            output_dir=str(self.output_dir),
            # NOTE(review): this stores the first 100 chars of the README
            # *content* in a field named readme_path — looks like a bug;
            # confirm whether StateModel expects the README's path instead.
            readme_path=self.readme_content[:100] if self.readme_content else ""
        )
        with open(self.state_file, "w", encoding="utf-8") as f:
            json.dump(state.dict(), f, indent=2, ensure_ascii=False)
        logger.debug(f"状态已保存: {self.state_file}")
def get_project_structure(self) -> Tuple[List[str], Dict[str, List[str]]]:
|
||
"""
|
||
从design.json获取文件列表和依赖关系
|
||
|
||
Returns:
|
||
(files, dependencies)
|
||
files: 按顺序需要生成的文件路径列表
|
||
dependencies: 字典 {file: [依赖文件路径]}
|
||
"""
|
||
if not self.design:
|
||
raise ValueError("design.json未加载,请先调用generate_design_json")
|
||
|
||
files = [file.path for file in self.design.files]
|
||
dependencies = {file.path: file.dependencies for file in self.design.files}
|
||
|
||
logger.info(f"从design.json解析到 {len(files)} 个待生成文件")
|
||
logger.debug(f"文件列表: {files}")
|
||
logger.debug(f"依赖关系: {dependencies}")
|
||
|
||
return files, dependencies
|
||
|
||
def generate_file(
|
||
self,
|
||
file_path: str,
|
||
prompt_instruction: str,
|
||
dependency_files: List[str],
|
||
) -> Tuple[str, str, List[str]]:
|
||
"""
|
||
生成单个文件,返回 (代码, 描述, 命令列表)
|
||
"""
|
||
# 读取依赖文件内容
|
||
context_content = []
|
||
|
||
if self.readme_content:
|
||
context_content.append(f"### 项目 README ###\n{self.readme_content}\n")
|
||
|
||
# 添加design.json上下文
|
||
design_path = self.output_dir / "design.json"
|
||
if design_path.exists():
|
||
with open(design_path, "r", encoding="utf-8") as f:
|
||
design_content = f.read()
|
||
context_content.append(f"### 设计文件: design.json ###\n{design_content}\n")
|
||
|
||
for dep in dependency_files:
|
||
dep_path = Path(dep)
|
||
if not dep_path.exists():
|
||
# 尝试相对于当前目录或输出目录查找
|
||
alt_path = self.output_dir / dep
|
||
if alt_path.exists():
|
||
dep_path = alt_path
|
||
else:
|
||
raise FileNotFoundError(f"依赖文件不存在: {dep}")
|
||
|
||
with open(dep_path, "r", encoding="utf-8") as f:
|
||
content = f.read()
|
||
context_content.append(f"### 文件: {dep_path.name} (路径: {dep}) ###\n{content}\n")
|
||
|
||
full_context = "\n".join(context_content)
|
||
|
||
system_prompt = (
|
||
"你是一个专业的编程助手。根据用户指令和提供的上下文文件,生成完整的代码。"
|
||
"返回严格的JSON对象,包含三个字段:\n"
|
||
"- code: (string) 生成的完整代码\n"
|
||
"- description: (string) 简短的中文功能描述\n"
|
||
"- commands: (array of string) 生成此文件后需要执行的操作系统命令列表(如编译、安装依赖等),若无则返回空数组"
|
||
)
|
||
user_prompt = f"{prompt_instruction}\n\n参考文件上下文:\n{full_context}"
|
||
|
||
result = self._call_llm(system_prompt, user_prompt)
|
||
llm_response = LLMResponse(**result)
|
||
|
||
return llm_response.code, llm_response.description, llm_response.commands
|
||
|
||
def execute_command(self, cmd: str, cwd: Optional[Path] = None) -> bool:
|
||
"""
|
||
执行单个命令,检查风险,失败仅记录错误不抛出异常
|
||
|
||
Returns:
|
||
bool: 命令是否成功执行
|
||
"""
|
||
dangerous, reason = is_dangerous_command(cmd)
|
||
if dangerous:
|
||
logger.error(f"危险命令被阻止: {cmd},原因: {reason}")
|
||
return False
|
||
|
||
logger.info(f"执行命令: {cmd}")
|
||
try:
|
||
result = subprocess.run(
|
||
cmd,
|
||
shell=True,
|
||
cwd=cwd or self.output_dir,
|
||
capture_output=True,
|
||
text=True,
|
||
timeout=300, # 5分钟超时
|
||
)
|
||
logger.debug(f"命令返回码: {result.returncode}")
|
||
if result.stdout:
|
||
logger.debug(f"stdout: {result.stdout[:500]}")
|
||
if result.stderr:
|
||
logger.warning(f"stderr: {result.stderr[:500]}")
|
||
if result.returncode != 0:
|
||
logger.error(f"命令执行失败,返回码: {result.returncode}")
|
||
return False
|
||
return True
|
||
except subprocess.TimeoutExpired:
|
||
logger.error(f"命令执行超时: {cmd}")
|
||
return False
|
||
except Exception as e:
|
||
logger.error(f"命令执行失败: {e}")
|
||
return False
|
||
|
||
def run(self, readme_path: Path):
|
||
"""
|
||
主执行流程,支持设计层生成和断点续写
|
||
"""
|
||
console = Console()
|
||
logger.info("=" * 50)
|
||
logger.info("开始代码生成流程")
|
||
logger.info(f"README: {readme_path}")
|
||
logger.info(f"输出目录: {self.output_dir}")
|
||
|
||
# 解析README
|
||
console.print("[bold yellow]🔍 正在解析README...[/bold yellow]")
|
||
self.readme_content = self.parse_readme(readme_path)
|
||
|
||
# 加载状态
|
||
state = self.load_state()
|
||
if state:
|
||
console.print(f"[green]✅ 检测到断点状态,从文件索引 {state.current_file_index} 继续[/green]")
|
||
self.state = state
|
||
# 从状态恢复设计,假设design.json已存在
|
||
design_path = self.output_dir / "design.json"
|
||
if design_path.exists():
|
||
with open(design_path, "r", encoding="utf-8") as f:
|
||
design_data = json.load(f)
|
||
self.design = DesignModel(**design_data)
|
||
else:
|
||
console.print("[bold yellow]⚠ design.json不存在,重新生成...[/bold yellow]")
|
||
self.design = self.generate_design_json()
|
||
else:
|
||
console.print("[bold yellow]📋 正在生成设计文件...[/bold yellow]")
|
||
self.design = self.generate_design_json()
|
||
self.state = None
|
||
|
||
# 获取项目结构
|
||
console.print("[bold yellow]📋 正在分析项目结构...[/bold yellow]")
|
||
files, dependencies = self.get_project_structure()
|
||
console.print(f"[green]✅ 解析完成,共 {len(files)} 个文件待生成[/green]")
|
||
|
||
# 断点续写:确定起始索引
|
||
start_index = self.state.current_file_index if self.state else 0
|
||
generated_files = self.state.generated_files if self.state else []
|
||
|
||
# 创建进度条
|
||
with Progress(
|
||
SpinnerColumn(),
|
||
TextColumn("[progress.description]{task.description}"),
|
||
BarColumn(),
|
||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||
console=console,
|
||
) as progress:
|
||
self.progress = progress
|
||
total_task = progress.add_task("[cyan]整体进度...", total=len(files))
|
||
progress.update(total_task, completed=start_index)
|
||
|
||
# 依次生成每个文件
|
||
for idx in range(start_index, len(files)):
|
||
file = files[idx]
|
||
logger.info(f"处理文件 [{idx + 1}/{len(files)}]: {file}")
|
||
file_task = progress.add_task(f"生成 {file}", total=None)
|
||
|
||
try:
|
||
# 获取依赖文件
|
||
deps = dependencies.get(file, [])
|
||
instruction = f"请根据README描述和依赖文件,生成文件 '{file}' 的完整代码。"
|
||
code, desc, commands = self.generate_file(file, instruction, deps)
|
||
logger.info(f"生成完成: {file} - {desc}")
|
||
|
||
# 写入文件
|
||
output_path = self.output_dir / file
|
||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||
with open(output_path, "w", encoding="utf-8") as f:
|
||
f.write(code)
|
||
logger.info(f"已写入: {output_path}")
|
||
generated_files.append(file)
|
||
|
||
# 执行命令
|
||
for cmd in commands:
|
||
logger.info(f"准备执行命令: {cmd}")
|
||
success = self.execute_command(cmd, cwd=self.output_dir)
|
||
if not success:
|
||
logger.warning(f"命令执行失败,但继续处理: {cmd}")
|
||
|
||
except Exception as e:
|
||
logger.error(f"处理文件 {file} 失败: {e}")
|
||
# 保存状态以便断点续写
|
||
self.save_state(idx, generated_files, dependencies)
|
||
raise
|
||
finally:
|
||
progress.remove_task(file_task)
|
||
progress.update(total_task, advance=1)
|
||
# 更新状态
|
||
self.save_state(idx + 1, generated_files, dependencies)
|
||
|
||
logger.success("所有文件处理完成!")
|
||
# 清理状态文件
|
||
if self.state_file.exists():
|
||
self.state_file.unlink()
|
||
logger.info("状态文件已清理")
|