llmcodegen/llmcodegen.py

361 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/home/songsenand/env/.venv/bin/python
#!
"""
基于LLM的自动化代码生成工具
根据README.md文件自动生成项目文件结构并填充代码执行必要命令。
"""
import json
import os
import re
import subprocess
import sys
from typing import List, Dict, Optional, Any, Tuple
from pathlib import Path

import typer
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskID
from loguru import logger
from openai import OpenAI
# ==================== Configuration ====================
# Keyword / pattern denylist consulted by is_dangerous_command() before an
# LLM-suggested shell command is executed.
DANGEROUS_COMMANDS = [
    "rm",
    "sudo",
    "chmod",
    "dd",
    "mkfs",
    "> /dev/sda",
    "format",
]
# Intended as an optional allowlist (empty => only the denylist is checked).
# NOTE(review): nothing in this file reads ALLOWED_COMMANDS, so filling it in
# currently has no effect on which commands run.
ALLOWED_COMMANDS = []

# Shared rich console for user-facing status output.
console = Console()
# Typer application; main() below is registered on it as the CLI command.
app = typer.Typer(help="基于LLM的自动化代码生成工具")
# ==================== Helpers ====================
def is_dangerous_command(cmd: str, patterns: Optional[List[str]] = None) -> Tuple[bool, str]:
    """Check a shell command against a denylist of dangerous keywords.

    Word-like entries (e.g. ``"rm"``, ``"dd"``) only match as whole words, so
    harmless commands such as ``poetry add x`` or ``pip install transformers``
    are not blocked merely because ``dd`` / ``rm`` occur inside another word.
    Entries that contain non-word characters (e.g. ``"> /dev/sda"``) are still
    matched as plain substrings. Matching is case-insensitive.

    This is a best-effort keyword filter, not a sandbox: shell quoting or
    expansion (e.g. ``r''m``) can hide a keyword from it.

    Args:
        cmd: The shell command line to inspect.
        patterns: Denylist to check against; defaults to the module-level
            ``DANGEROUS_COMMANDS``.

    Returns:
        ``(is_dangerous, reason)`` where ``reason`` names the first matching
        entry, or ``(False, "")`` when nothing matched.
    """
    if patterns is None:
        patterns = DANGEROUS_COMMANDS
    cmd_lower = cmd.lower()
    for danger in patterns:
        needle = danger.lower()
        if re.fullmatch(r"\w+", needle):
            # Whole-word match: not preceded/followed by a word character, so
            # "/bin/rm", "rm -rf" and "mkfs.ext4" match but "add" / "transformers" do not.
            hit = re.search(rf"(?<!\w){re.escape(needle)}(?!\w)", cmd_lower) is not None
        else:
            # Patterns with spaces/symbols keep the original substring semantics.
            hit = needle in cmd_lower
        if hit:
            return True, f"包含危险关键词 '{danger}'"
    return False, ""
# ==================== Core class ====================
class CodeGenerator:
    """README-driven code generator; wraps the whole workflow.

    ``run`` reads the README (plus an optional sibling ``design.json``), asks
    the LLM for an ordered file list and per-file context dependencies, then
    for each file asks the LLM for its code, writes it under ``output_dir`` and
    executes the shell commands the LLM suggested for it.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: str = "https://api.deepseek.com",
        model: str = "deepseek-reasoner",
        output_dir: str = "./generated",
        log_file: Optional[str] = None,
    ):
        """Create the API client, the output directory and the log sinks.

        Args:
            api_key: API key; falls back to the ``DEEPSEEK_APIKEY`` env variable.
            base_url: Base URL of the OpenAI-compatible API.
            model: Model name used for every chat-completion request.
            output_dir: Root directory for generated files (created if missing).
            log_file: Log file path; defaults to ``<output_dir>/generator.log``.

        Raises:
            ValueError: If no API key is given and ``DEEPSEEK_APIKEY`` is unset.
        """
        self.api_key = api_key or os.getenv("DEEPSEEK_APIKEY")
        if not self.api_key:
            raise ValueError("必须提供API密钥或设置环境变量DEEPSEEK_APIKEY")
        self.client = OpenAI(api_key=self.api_key, base_url=base_url)
        self.model = model
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        # Logging setup.
        if log_file is None:
            log_file = self.output_dir / "generator.log"
        # NOTE: logger.remove() drops *every* loguru handler process-wide,
        # not only ones previously added by this class.
        logger.remove()
        # The stderr sink shows WARNING and above only; user-facing status is
        # printed through the rich `console`, which this level does not filter.
        logger.add(sys.stderr, level="WARNING")
        # The log file keeps the full DEBUG trace, rotated at 10 MB.
        logger.add(log_file, rotation="10 MB", level="DEBUG")
        logger.info(f"日志已初始化,保存至: {log_file}")
        self.readme_content = None
        self.progress: Optional[Progress] = None
        self.tasks: Dict[str, TaskID] = {}  # name -> rich TaskID (not read elsewhere in this class)

    def _call_llm(
        self,
        system_prompt: str,
        user_prompt: str,
        temperature: float = 0.2,
        expect_json: bool = True,
    ) -> Dict[str, Any]:
        """Send one chat-completion request and return the parsed reply.

        Args:
            system_prompt: System message content.
            user_prompt: User message content.
            temperature: Sampling temperature forwarded to the API.
            expect_json: If True, request JSON mode and parse the reply as a
                JSON object; otherwise return ``{"content": <raw text>}``.

        Returns:
            The parsed JSON object, or ``{"content": text}`` for plain text.

        Raises:
            ValueError: If JSON was expected but the reply is not a JSON object.
            Exception: Errors from the API client are logged and re-raised.
        """
        logger.debug(f"调用LLM模型: {self.model}")
        logger.debug(f"System: {system_prompt[:200]}...")
        logger.debug(f"User: {user_prompt[:200]}...")
        request: Dict[str, Any] = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "temperature": temperature,
        }
        # Only send response_format when JSON mode is wanted, rather than
        # passing an explicit None for plain-text requests.
        if expect_json:
            request["response_format"] = {"type": "json_object"}
        # Pre-bound so the JSONDecodeError handler can always reference it.
        content = ""
        try:
            response = self.client.chat.completions.create(**request)
            message = response.choices[0].message
            # message.content is Optional in the OpenAI SDK; normalise None to ""
            # so slicing / json.loads fail in a controlled way, not with TypeError.
            content = message.content or ""
            # Log the reasoning trace when the model provides one.
            reasoning = getattr(message, "reasoning_content", None)
            if reasoning:
                logger.info(f"模型思考过程: {reasoning}")
            logger.debug(f"LLM原始响应: {content[:500]}...")
            if not expect_json:
                return {"content": content}
            result = json.loads(content)
            if not isinstance(result, dict):
                # Callers use .get(); a JSON array/scalar would break them.
                raise ValueError(f"LLM返回的JSON不是对象: {content[:200]}")
            return result
        except json.JSONDecodeError as e:
            logger.error(f"JSON解析失败: {e}")
            raise ValueError(f"LLM返回的不是有效JSON: {content[:200]}") from e
        except Exception as e:
            logger.error(f"LLM调用失败: {e}")
            raise

    def parse_readme(self, readme_path: Path) -> str:
        """Read the README, appending a sibling ``design.json`` if it exists.

        Both files are read as UTF-8.

        Returns:
            The README text, optionally followed by the design.json contents.

        Raises:
            Exception: Any read error is logged and re-raised.
        """
        logger.info(f"读取README文件: {readme_path}")
        try:
            readme_path = Path(readme_path)
            content = readme_path.read_text(encoding="utf-8")
            logger.debug(f"README内容长度: {len(content)} 字符")
            design_path = readme_path.parent / "design.json"
            if design_path.exists():
                # Explicit UTF-8: the platform default encoding may differ.
                design = design_path.read_text(encoding="utf-8")
                content += f'\n\ndesign.json(包含项目设计有关信息)内容如下:{design}'
            return content
        except Exception as e:
            logger.error(f"读取README失败: {e}")
            raise

    @staticmethod
    def _normalize_dependencies(raw: Any) -> Dict[str, List[str]]:
        """Coerce the LLM's ``dependencies`` value into ``{path: [path, ...]}``.

        A non-dict value becomes ``{}``. A string value becomes a one-item
        list. Any other non-list value becomes ``[]``. Non-string list entries
        are dropped.
        """
        if not isinstance(raw, dict):
            return {}
        normalized: Dict[str, List[str]] = {}
        for key, value in raw.items():
            if isinstance(value, str):
                value = [value]
            elif not isinstance(value, list):
                value = []
            normalized[str(key)] = [v for v in value if isinstance(v, str)]
        return normalized

    def get_project_structure(self) -> Tuple[List[str], Dict[str, List[str]]]:
        """Ask the LLM which files to generate, in order, and their context files.

        Returns:
            ``(files, dependencies)``:
            - ``files``: file paths (relative to the project root) in generation order.
            - ``dependencies``: maps a file path to the existing files to read as context.

        Raises:
            ValueError: If the LLM returns no usable file list.
        """
        system_prompt = (
            "你是一个软件架构师。请根据README描述分析需要生成哪些源代码文件并确定它们的生成顺序"
            "同时给出每个文件生成时最少需要读取哪些已有文件作为上下文。"
            "返回严格的JSON对象包含两个字段\n"
            "- files: 数组,按生成顺序排列的文件路径(相对于项目根目录)\n"
            "- dependencies: 对象,键为文件路径,值为该文件依赖的已有文件路径列表(可为空)\n"
            "注意:依赖文件必须是已存在的参考文件,不要包含待生成的文件。"
        )
        user_prompt = f"README内容如下\n\n{self.readme_content}"
        result = self._call_llm(system_prompt, user_prompt)
        raw_files = result.get("files", [])
        # Guard against malformed replies: a bare string would otherwise be
        # iterated character by character.
        if isinstance(raw_files, list):
            files = [f for f in raw_files if isinstance(f, str) and f.strip()]
        else:
            files = []
        dependencies = self._normalize_dependencies(result.get("dependencies", {}))
        if not files:
            raise ValueError("LLM未返回任何文件列表")
        logger.info(f"解析到 {len(files)} 个待生成文件")
        logger.debug(f"文件列表: {files}")
        logger.debug(f"依赖关系: {dependencies}")
        return files, dependencies

    def generate_file(
        self,
        file_path: str,
        prompt_instruction: str,
        dependency_files: List[str],
    ) -> Tuple[str, str, List[str]]:
        """Ask the LLM for one file's code.

        Args:
            file_path: Target path of the file. It is not inserted into the
                prompt here; ``run`` embeds it in ``prompt_instruction``.
            prompt_instruction: The user instruction sent to the LLM.
            dependency_files: Paths whose contents are appended as context. Each
                is looked up as given (absolute or CWD-relative), then under
                ``output_dir``. Missing or unreadable files are skipped with a warning.

        Returns:
            ``(code, description, commands)``; ``commands`` is a list of strings.

        Raises:
            ValueError: If the LLM's ``code`` field is not a string.
        """
        context_parts: List[str] = []
        if self.readme_content:
            context_parts.append(f"### 项目 README ###\n{self.readme_content}\n")
        for dep in dependency_files:
            dep_path = Path(dep)
            if not dep_path.is_file():
                # Fall back to the output directory, where generated files live.
                dep_path = self.output_dir / dep
            if not dep_path.is_file():
                # Skip instead of crashing: one bad reference should not abort the file.
                logger.warning(f"依赖文件不存在,已跳过: {dep}")
                continue
            try:
                content = dep_path.read_text(encoding="utf-8")
            except (OSError, UnicodeDecodeError) as e:
                logger.warning(f"读取依赖文件失败,已跳过: {dep} ({e})")
                continue
            context_parts.append(f"### 文件: {dep_path.name} (路径: {dep}) ###\n{content}\n")
        full_context = "\n".join(context_parts)
        system_prompt = (
            "你是一个专业的编程助手。根据用户指令和提供的上下文文件,生成完整的代码。"
            "返回严格的JSON对象包含三个字段\n"
            "- code: (string) 生成的完整代码\n"
            "- description: (string) 简短的中文功能描述\n"
            "- commands: (array of string) 生成此文件后需要执行的操作系统命令列表(如编译、安装依赖等),若无则返回空数组"
        )
        user_prompt = f"{prompt_instruction}\n\n参考文件上下文:\n{full_context}"
        result = self._call_llm(system_prompt, user_prompt)
        code = result.get("code", "")
        if not isinstance(code, str):
            raise ValueError(f"LLM返回的code字段不是字符串: {file_path}")
        description = result.get("description", "")
        commands = result.get("commands", [])
        if not isinstance(commands, list):
            commands = []
        # execute_command() calls str methods on each entry; drop anything else.
        commands = [c for c in commands if isinstance(c, str)]
        return code, description, commands

    def execute_command(self, cmd: str, cwd: Optional[Path] = None) -> None:
        """Run one LLM-suggested shell command unless the denylist flags it.

        The command runs through the shell in ``cwd`` (default ``output_dir``)
        with a 300-second timeout. Failures are logged and never raised.
        """
        dangerous, reason = is_dangerous_command(cmd)
        if dangerous:
            logger.error(f"危险命令被阻止: {cmd},原因: {reason}")
            return
        logger.info(f"执行命令: {cmd}")
        try:
            # NOTE(review): shell=True executes model-generated text; the keyword
            # denylist above is best-effort, not a sandbox.
            result = subprocess.run(
                cmd,
                shell=True,
                cwd=cwd or self.output_dir,
                capture_output=True,
                text=True,
                timeout=300,  # 5-minute timeout
            )
        except subprocess.TimeoutExpired:
            logger.error(f"命令执行超时: {cmd}")
            return
        except Exception as e:
            logger.error(f"命令执行失败: {e}")
            return
        logger.debug(f"命令返回码: {result.returncode}")
        if result.returncode != 0:
            # Surface failures on the console sink, not only in the debug log.
            logger.warning(f"命令返回非零退出码 {result.returncode}: {cmd}")
        if result.stdout:
            logger.debug(f"stdout: {result.stdout[:500]}")
        if result.stderr:
            logger.warning(f"stderr: {result.stderr[:500]}")

    @staticmethod
    def _resolve_output_path(root: Path, relative: str) -> Path:
        """Return ``root / relative`` resolved, refusing paths that escape ``root``.

        File paths come from the LLM and are treated as untrusted. Joining an
        absolute path or ``..`` segments with ``/`` would otherwise point outside
        the output directory.

        Raises:
            ValueError: If the path resolves to ``root`` itself or outside it.
        """
        base = Path(root).resolve()
        target = (base / relative).resolve()
        if target == base:
            raise ValueError(f"无效的文件路径: {relative!r}")
        try:
            target.relative_to(base)
        except ValueError:
            raise ValueError(f"拒绝写入输出目录之外的路径: {relative}") from None
        return target

    def run(self, readme_path: Path):
        """Main pipeline: parse the README, plan the files, then generate each.

        Each file is processed independently. A failure is logged, the file
        is recorded, and processing continues with the next file. Failed files
        are summarised at the end.
        """
        logger.info("=" * 50)
        logger.info("开始代码生成流程")
        logger.info(f"README: {readme_path}")
        logger.info(f"输出目录: {self.output_dir}")
        # Status during setup goes through rich so the log level does not hide it.
        console.print("[bold yellow]🔍 正在解析README...[/bold yellow]")
        self.readme_content = self.parse_readme(readme_path)
        console.print("[bold yellow]📋 正在分析项目结构...[/bold yellow]")
        files, dependencies = self.get_project_structure()
        console.print(f"[green]✅ 解析完成,共 {len(files)} 个文件待生成[/green]")
        failed: List[str] = []
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
            console=console,
        ) as progress:
            self.progress = progress
            # Overall task advances once per file, success or failure.
            total_task = progress.add_task("[cyan]整体进度...", total=len(files))
            for idx, file in enumerate(files, 1):
                logger.info(f"处理文件 [{idx}/{len(files)}]: {file}")
                # Indeterminate per-file spinner, removed when the file is done.
                file_task = progress.add_task(f"生成 {file}", total=None)
                try:
                    deps = dependencies.get(file, [])
                    instruction = f"请根据README描述和依赖文件生成文件 '{file}' 的完整代码。"
                    # Validate the target path before spending an LLM call on it.
                    output_path = self._resolve_output_path(self.output_dir, file)
                    code, desc, commands = self.generate_file(file, instruction, deps)
                    logger.info(f"生成完成: {file} - {desc}")
                    output_path.parent.mkdir(parents=True, exist_ok=True)
                    with open(output_path, "w", encoding="utf-8") as f:
                        f.write(code)
                    logger.info(f"已写入: {output_path}")
                    for cmd in commands:
                        logger.info(f"准备执行命令: {cmd}")
                        self.execute_command(cmd, cwd=self.output_dir)
                except Exception as e:
                    # Best-effort: record the failure and move on to the next file.
                    logger.error(f"处理文件 {file} 失败: {e}")
                    failed.append(file)
                finally:
                    progress.remove_task(file_task)
                    progress.update(total_task, advance=1)
        if failed:
            logger.warning(f"{len(failed)} 个文件处理失败: {failed}")
        logger.success("所有文件处理完成!")
# ==================== CLI entry point ====================
# Typer command: builds a CodeGenerator from the CLI options and runs it on the
# given README. The docstring below doubles as the command's --help text.
@app.command()
def main(
    readme: Path = typer.Argument(..., exists=True, file_okay=True, dir_okay=False, help="README.md文件路径"),
    output_dir: Optional[Path] = typer.Option(None, "--output", "-o", help="输出根目录默认为readme所在目录"),
    api_key: Optional[str] = typer.Option(None, "--api-key", envvar="DEEPSEEK_APIKEY", help="API密钥也可通过环境变量DEEPSEEK_APIKEY设置"),
    base_url: str = typer.Option("https://api.deepseek.com", "--base-url", help="API基础URL"),
    model: str = typer.Option("deepseek-reasoner", "--model", "-m", help="使用的模型"),
    log_file: Optional[str] = typer.Option(None, "--log", help="日志文件路径默认输出目录下generator.log"),
):
    """
    根据README自动生成项目代码
    """
    # Without --output, generated files go next to the README.
    target_dir = output_dir if output_dir is not None else readme.parent
    CodeGenerator(
        api_key=api_key,
        base_url=base_url,
        model=model,
        output_dir=target_dir,
        log_file=log_file,
    ).run(readme)


if __name__ == "__main__":
    # Let Typer parse argv and dispatch to main().
    app()