361 lines
14 KiB
Python
361 lines
14 KiB
Python
#!/home/songsenand/env/.venv/bin/python
|
||
#!
|
||
"""
|
||
基于LLM的自动化代码生成工具
|
||
根据README.md文件,自动生成项目文件结构并填充代码,执行必要命令。
|
||
"""
|
||
|
||
import json
|
||
import os
|
||
import subprocess
|
||
import sys
|
||
from typing import List, Dict, Optional, Any, Tuple
|
||
from pathlib import Path
|
||
|
||
import typer
|
||
from rich.console import Console
|
||
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskID
|
||
from loguru import logger
|
||
from openai import OpenAI
|
||
|
||
# ==================== Configuration ====================
# Keywords that cause a generated shell command to be rejected before it runs.
DANGEROUS_COMMANDS = [
    "rm",
    "sudo",
    "chmod",
    "dd",
    "mkfs",
    "> /dev/sda",
    "format",
]
# Placeholder for an optional whitelist; nothing visible in this file reads it,
# so only the blacklist above is effectively checked.
ALLOWED_COMMANDS = []
|
||
|
||
# Typer application object; CLI commands are registered on it with
# ``@app.command()``.
app = typer.Typer(
    help="基于LLM的自动化代码生成工具",
)
# Shared rich console used for status messages and the progress display.
console = Console()
|
||
|
||
# ==================== 工具函数 ====================
|
||
def is_dangerous_command(
    cmd: str,
    dangerous: Optional[List[str]] = None,
) -> Tuple[bool, str]:
    """Decide whether a shell command looks dangerous.

    The check is a case-insensitive keyword heuristic, not a sandbox.
    Plain keywords (letters, digits, ``_`` and ``-`` only, e.g. ``rm``) must
    appear as a standalone word: they are ignored when they are merely part
    of a longer word or option (``git add``, ``confirm``, ``--rm``,
    ``--format``). Keywords containing other characters (e.g.
    ``"> /dev/sda"``) are matched as plain substrings.

    Args:
        cmd: The command line to inspect.
        dangerous: Keywords to look for. Defaults to the module-level
            ``DANGEROUS_COMMANDS`` list when omitted.

    Returns:
        ``(True, reason)`` for the first keyword that matches, otherwise
        ``(False, "")``.
    """
    # Imported locally because the module header does not import ``re``.
    import re

    keywords = DANGEROUS_COMMANDS if dangerous is None else dangerous
    cmd_lower = cmd.lower()

    for danger in keywords:
        needle = danger.lower()
        if not needle:
            # An empty keyword would match every command; ignore it.
            continue
        if re.fullmatch(r"[\w-]+", needle):
            # Require a word boundary on both sides so that "dd" does not
            # fire on "add" and "rm" does not fire on "--rm" or "confirm".
            pattern = rf"(?<![\w-]){re.escape(needle)}(?![\w-])"
            matched = re.search(pattern, cmd_lower) is not None
        else:
            matched = needle in cmd_lower
        if matched:
            return True, f"包含危险关键词 '{danger}'"

    return False, ""
|
||
|
||
# ==================== 核心类 ====================
|
||
class CodeGenerator:
    """Generate a project's source files from its README using an LLM.

    Workflow (see ``run``): read the README (plus an optional sibling
    ``design.json``), ask the model for a file list and per-file
    dependencies, generate every file in order, write it below
    ``output_dir`` and execute the follow-up shell commands the model
    requested, subject to the ``is_dangerous_command`` check.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: str = "https://api.deepseek.com",
        model: str = "deepseek-reasoner",
        output_dir: str = "./generated",
        log_file: Optional[str] = None,
    ):
        """Set up the API client, the output directory and logging.

        Args:
            api_key: API key; falls back to the ``DEEPSEEK_APIKEY``
                environment variable when omitted.
            base_url: Base URL of the OpenAI-compatible API.
            model: Model name sent with every request.
            output_dir: Root directory for generated files; created if it
                does not exist.
            log_file: Log file path; defaults to ``generator.log`` inside
                ``output_dir``.

        Raises:
            ValueError: If no API key is given and the environment variable
                is unset or empty.
        """
        self.api_key = api_key or os.getenv("DEEPSEEK_APIKEY")
        if not self.api_key:
            raise ValueError("必须提供API密钥,或设置环境变量DEEPSEEK_APIKEY")

        self.client = OpenAI(api_key=self.api_key, base_url=base_url)
        self.model = model
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Logging: this reconfigures the global loguru logger.
        if log_file is None:
            log_file = self.output_dir / "generator.log"
        logger.remove()  # drop the default handler
        logger.add(sys.stderr, level="WARNING")  # console: WARNING and above
        logger.add(log_file, rotation="10 MB", level="DEBUG")  # file: DEBUG and above
        logger.info(f"日志已初始化,保存至: {log_file}")

        # README text (plus design.json, if any); filled in by ``run``.
        self.readme_content = None

        # Active rich progress display while ``run`` is executing.
        self.progress: Optional[Progress] = None
        # Mapping of task names to progress task ids.
        self.tasks: Dict[str, TaskID] = {}

    def _locate_dependency(self, dep: str) -> Optional[Path]:
        """Return the existing path for *dep*, or ``None`` if it is missing.

        The path is tried as given first (relative to the current working
        directory) and then relative to ``output_dir``.
        """
        direct = Path(dep)
        if direct.exists():
            return direct
        fallback = self.output_dir / dep
        if fallback.exists():
            return fallback
        return None

    def _resolve_output_path(self, file_path: str) -> Optional[Path]:
        """Map a model-supplied path to a write target inside ``output_dir``.

        Returns ``output_dir / file_path`` when that location is strictly
        inside ``output_dir``; returns ``None`` for absolute paths, paths
        that escape through ``..`` and paths that resolve to ``output_dir``
        itself.
        """
        root = self.output_dir.resolve()
        candidate = (root / file_path).resolve()
        if candidate == root:
            return None
        try:
            candidate.relative_to(root)
        except ValueError:
            return None
        return self.output_dir / file_path

    def _call_llm(
        self,
        system_prompt: str,
        user_prompt: str,
        temperature: float = 0.2,
        expect_json: bool = True,
    ) -> Dict[str, Any]:
        """Send one chat-completion request and return the parsed reply.

        Args:
            system_prompt: Content of the system message.
            user_prompt: Content of the user message.
            temperature: Sampling temperature.
            expect_json: When true, request JSON-object output and decode
                it; otherwise wrap the raw text as ``{"content": text}``.

        Returns:
            The decoded JSON object, or ``{"content": text}``.

        Raises:
            ValueError: If JSON was expected but the reply is not valid
                JSON or is not a JSON object.
            Exception: Any error raised by the API client is logged and
                re-raised unchanged.
        """
        logger.debug(f"调用LLM,模型: {self.model}")
        logger.debug(f"System: {system_prompt[:200]}...")
        logger.debug(f"User: {user_prompt[:200]}...")

        request: Dict[str, Any] = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "temperature": temperature,
        }
        # Only send response_format when it is actually wanted instead of
        # passing an explicit None to the client.
        if expect_json:
            request["response_format"] = {"type": "json_object"}

        try:
            response = self.client.chat.completions.create(**request)
            message = response.choices[0].message
        except Exception as e:
            logger.error(f"LLM调用失败: {e}")
            raise

        # The API may return no content at all; treat that as empty text so
        # the slicing below cannot fail on None.
        content = message.content or ""

        # Log the model's reasoning trace when the backend provides one.
        reasoning = getattr(message, "reasoning_content", None)
        if reasoning:
            logger.info(f"模型思考过程: {reasoning}")

        logger.debug(f"LLM原始响应: {content[:500]}...")

        if not expect_json:
            return {"content": content}

        try:
            result = json.loads(content)
        except json.JSONDecodeError as e:
            logger.error(f"JSON解析失败: {e}")
            raise ValueError(f"LLM返回的不是有效JSON: {content[:200]}") from e

        # Callers use ``.get`` on the result, so anything but an object is
        # unusable.
        if not isinstance(result, dict):
            logger.error(f"LLM返回的JSON不是对象: {type(result).__name__}")
            raise ValueError(f"LLM返回的不是有效JSON: {content[:200]}")

        return result

    def parse_readme(self, readme_path: Path) -> str:
        """Read the README and append ``design.json`` when it sits next to it.

        Args:
            readme_path: Path of the README file.

        Returns:
            The README text, followed by the content of a sibling
            ``design.json`` if that file exists.

        Raises:
            Exception: Any read error is logged and re-raised.
        """
        logger.info(f"读取README文件: {readme_path}")
        try:
            with open(readme_path, "r", encoding="utf-8") as f:
                content = f.read()
            logger.debug(f"README内容长度: {len(content)} 字符")

            design_path = readme_path.parent / 'design.json'
            if design_path.exists():
                # Explicit encoding so the result does not depend on the
                # platform's default locale.
                with open(design_path, "r", encoding="utf-8") as f:
                    content += f'\n\ndesign.json(包含项目设计有关信息)内容如下:{f.read()}'
            return content
        except Exception as e:
            logger.error(f"读取README失败: {e}")
            raise

    def get_project_structure(self) -> Tuple[List[str], Dict[str, List[str]]]:
        """Ask the LLM which files to generate and what each depends on.

        Uses ``self.readme_content`` as the project description.

        Returns:
            ``(files, dependencies)`` where ``files`` is the ordered list of
            file paths to generate and ``dependencies`` maps a file path to
            the list of existing files to read as context.

        Raises:
            ValueError: If the reply contains no usable file list.
        """
        system_prompt = (
            "你是一个软件架构师。请根据README描述,分析需要生成哪些源代码文件,并确定它们的生成顺序,"
            "同时给出每个文件生成时最少需要读取哪些已有文件作为上下文。"
            "返回严格的JSON对象,包含两个字段:\n"
            "- files: 数组,按生成顺序排列的文件路径(相对于项目根目录)\n"
            "- dependencies: 对象,键为文件路径,值为该文件依赖的已有文件路径列表(可为空)\n"
            "注意:依赖文件必须是已存在的参考文件,不要包含待生成的文件。"
        )
        user_prompt = f"README内容如下:\n\n{self.readme_content}"

        result = self._call_llm(system_prompt, user_prompt)

        files = result.get("files", [])
        dependencies = result.get("dependencies", {})

        # The reply is model output: keep only well-formed entries.
        if isinstance(files, list):
            files = [f for f in files if isinstance(f, str) and f.strip()]
        else:
            files = []
        if not isinstance(dependencies, dict):
            dependencies = {}

        if not files:
            raise ValueError("LLM未返回任何文件列表")

        logger.info(f"解析到 {len(files)} 个待生成文件")
        logger.debug(f"文件列表: {files}")
        logger.debug(f"依赖关系: {dependencies}")

        return files, dependencies

    def generate_file(
        self,
        file_path: str,
        prompt_instruction: str,
        dependency_files: List[str],
    ) -> Tuple[str, str, List[str]]:
        """Generate one file with the LLM.

        Args:
            file_path: Path of the file being generated. It is not used in
                the request itself; the instruction text names the file.
            prompt_instruction: Instruction placed at the start of the user
                prompt.
            dependency_files: Paths of existing files to include as context.
                Missing or unreadable files are skipped with a warning.

        Returns:
            ``(code, description, commands)`` taken from the model's reply.
        """
        context_parts: List[str] = []

        if self.readme_content:
            context_parts.append(f"### 项目 README ###\n{self.readme_content}\n")

        # Tolerate a bare string or None where a list was expected.
        if isinstance(dependency_files, str):
            dependency_files = [dependency_files]

        for dep in dependency_files or []:
            if not isinstance(dep, str):
                logger.warning(f"依赖文件不存在: {dep}")
                continue
            dep_path = self._locate_dependency(dep)
            if dep_path is None:
                # Skip instead of opening a path known not to exist.
                logger.warning(f"依赖文件不存在: {dep}")
                continue
            try:
                content = dep_path.read_text(encoding="utf-8")
            except (OSError, UnicodeDecodeError) as e:
                logger.warning(f"读取依赖文件失败: {dep} ({e})")
                continue
            context_parts.append(f"### 文件: {dep_path.name} (路径: {dep}) ###\n{content}\n")

        full_context = "\n".join(context_parts)

        system_prompt = (
            "你是一个专业的编程助手。根据用户指令和提供的上下文文件,生成完整的代码。"
            "返回严格的JSON对象,包含三个字段:\n"
            "- code: (string) 生成的完整代码\n"
            "- description: (string) 简短的中文功能描述\n"
            "- commands: (array of string) 生成此文件后需要执行的操作系统命令列表(如编译、安装依赖等),若无则返回空数组"
        )
        user_prompt = f"{prompt_instruction}\n\n参考文件上下文:\n{full_context}"

        result = self._call_llm(system_prompt, user_prompt)

        code = result.get("code", "")
        description = result.get("description", "")
        commands = result.get("commands", [])

        if not isinstance(commands, list):
            commands = []

        return code, description, commands

    def execute_command(self, cmd: str, cwd: Optional[Path] = None) -> None:
        """Run one shell command unless the danger check rejects it.

        Args:
            cmd: Command line to execute.
            cwd: Working directory; defaults to ``output_dir``.

        Failures (rejection, timeout, non-zero exit, other errors) are
        logged and never raised.
        """
        if not isinstance(cmd, str) or not cmd.strip():
            logger.warning(f"忽略无效命令: {cmd!r}")
            return

        dangerous, reason = is_dangerous_command(cmd)
        if dangerous:
            logger.error(f"危险命令被阻止: {cmd},原因: {reason}")
            return

        logger.info(f"执行命令: {cmd}")
        try:
            # SECURITY NOTE: the command comes from model output and runs
            # through the shell; the keyword check above is only a
            # heuristic, not a sandbox.
            result = subprocess.run(
                cmd,
                shell=True,
                cwd=cwd or self.output_dir,
                capture_output=True,
                text=True,
                timeout=300,  # five-minute limit
            )
            logger.debug(f"命令返回码: {result.returncode}")
            if result.returncode != 0:
                logger.warning(f"命令返回非零退出码 {result.returncode}: {cmd}")
            if result.stdout:
                logger.debug(f"stdout: {result.stdout[:500]}")
            if result.stderr:
                logger.warning(f"stderr: {result.stderr[:500]}")
        except subprocess.TimeoutExpired:
            logger.error(f"命令执行超时: {cmd}")
        except Exception as e:
            logger.error(f"命令执行失败: {e}")

    def run(self, readme_path: Path):
        """Run the whole pipeline: parse, plan, generate, write, execute.

        Args:
            readme_path: Path of the README that describes the project.

        A failure while handling one file is logged and the remaining files
        are still processed; errors while reading the README or planning the
        project structure propagate to the caller.
        """
        logger.info("=" * 50)
        logger.info("开始代码生成流程")
        logger.info(f"README: {readme_path}")
        logger.info(f"输出目录: {self.output_dir}")

        # Status lines go through rich so the log level does not hide them.
        console.print("[bold yellow]🔍 正在解析README...[/bold yellow]")
        self.readme_content = self.parse_readme(readme_path)

        console.print("[bold yellow]📋 正在分析项目结构...[/bold yellow]")
        files, dependencies = self.get_project_structure()

        console.print(f"[green]✅ 解析完成,共 {len(files)} 个文件待生成[/green]")

        failed: List[str] = []

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
            console=console,
        ) as progress:
            self.progress = progress
            # Overall task: one step per file.
            total_task = progress.add_task("[cyan]整体进度...", total=len(files))

            for idx, file in enumerate(files, 1):
                logger.info(f"处理文件 [{idx}/{len(files)}]: {file}")

                # Indeterminate sub-task shown while this file is generated.
                file_task = progress.add_task(f"生成 {file}", total=None)

                try:
                    # Reject paths that would land outside the output
                    # directory before spending an API call on them.
                    output_path = self._resolve_output_path(file)
                    if output_path is None:
                        raise ValueError(f"文件路径超出输出目录: {file}")

                    deps = dependencies.get(file, [])

                    instruction = f"请根据README描述和依赖文件,生成文件 '{file}' 的完整代码。"

                    code, desc, commands = self.generate_file(file, instruction, deps)

                    logger.info(f"生成完成: {file} - {desc}")

                    output_path.parent.mkdir(parents=True, exist_ok=True)
                    with open(output_path, "w", encoding="utf-8") as f:
                        f.write(code)
                    logger.info(f"已写入: {output_path}")

                    for cmd in commands:
                        logger.info(f"准备执行命令: {cmd}")
                        self.execute_command(cmd, cwd=self.output_dir)

                except Exception as e:
                    # Best effort: record the failure and move on.
                    logger.error(f"处理文件 {file} 失败: {e}")
                    failed.append(file)
                finally:
                    progress.remove_task(file_task)
                    progress.update(total_task, advance=1)

        if failed:
            logger.warning(f"{len(failed)} 个文件处理失败: {failed}")
        logger.success("所有文件处理完成!")
|
||
|
||
# ==================== CLI入口 ====================
|
||
@app.command()
def main(
    readme: Path = typer.Argument(
        ...,
        exists=True,
        file_okay=True,
        dir_okay=False,
        help="README.md文件路径",
    ),
    output_dir: Optional[Path] = typer.Option(
        None,
        "--output",
        "-o",
        help="输出根目录,默认为readme所在目录",
    ),
    api_key: Optional[str] = typer.Option(
        None,
        "--api-key",
        envvar="DEEPSEEK_APIKEY",
        help="API密钥,也可通过环境变量DEEPSEEK_APIKEY设置",
    ),
    base_url: str = typer.Option(
        "https://api.deepseek.com",
        "--base-url",
        help="API基础URL",
    ),
    model: str = typer.Option(
        "deepseek-reasoner",
        "--model",
        "-m",
        help="使用的模型",
    ),
    log_file: Optional[str] = typer.Option(
        None,
        "--log",
        help="日志文件路径(默认输出目录下generator.log)",
    ),
):
    """
    根据README自动生成项目代码
    """
    # The docstring above is left untranslated on purpose: Typer displays it
    # as the command's help text.
    # Without --output, generate next to the README itself.
    target_dir = readme.parent if output_dir is None else output_dir

    settings = {
        "api_key": api_key,
        "base_url": base_url,
        "model": model,
        "output_dir": target_dir,
        "log_file": log_file,
    }
    generator = CodeGenerator(**settings)
    generator.run(readme)


# Script entry point: hand control to the Typer application.
if __name__ == "__main__":
    app()
|