feat: 添加项目初始化文件和README模板以支持代码生成工具
This commit is contained in:
commit
d431b24755
|
|
@ -0,0 +1,10 @@
|
|||
# Python-generated files
|
||||
__pycache__/
|
||||
*.py[oc]
|
||||
build/
|
||||
dist/
|
||||
wheels/
|
||||
*.egg-info
|
||||
|
||||
# Virtual environments
|
||||
.venv
|
||||
|
|
@ -0,0 +1 @@
|
|||
3.13
|
||||
|
|
@ -0,0 +1,533 @@
|
|||
# LLM 代码生成工具(自举版)
|
||||
|
||||
本项目是一个基于大语言模型的代码生成工具,能够根据项目 `README.md` 描述自动生成完整的 Python 包代码,并具备代码检查、测试和自动修复能力。它是前一个代码生成器的升级版本,采用 `uv` 进行包管理,包含完整的单元测试、并行检查模块,并可通过命令行直接调用。
|
||||
|
||||
## 特别说明
|
||||
|
||||
我已经实现了一个简易版本,请在此基础上拓展开发:
|
||||
|
||||
```python
|
||||
#!/home/songsenand/env/.venv/bin/python
|
||||
"""
|
||||
基于LLM的自动化代码生成工具
|
||||
根据README.md文件,自动生成项目文件结构并填充代码,执行必要命令。
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
from rich.console import Console
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskID
|
||||
from loguru import logger
|
||||
from openai import OpenAI
|
||||
|
||||
# ==================== 配置 ====================
|
||||
DANGEROUS_COMMANDS = ["rm", "sudo", "chmod", "dd", "mkfs", "> /dev/sda", "format"]
|
||||
ALLOWED_COMMANDS = [] # 可设置白名单,为空则只检查黑名单
|
||||
|
||||
app = typer.Typer(help="基于LLM的自动化代码生成工具")
|
||||
console = Console()
|
||||
|
||||
# ==================== 工具函数 ====================
|
||||
def is_dangerous_command(cmd: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
判断命令是否危险
|
||||
返回 (是否危险, 原因)
|
||||
"""
|
||||
cmd_lower = cmd.lower()
|
||||
for danger in DANGEROUS_COMMANDS:
|
||||
if danger in cmd_lower:
|
||||
return True, f"包含危险关键词 '{danger}'"
|
||||
return False, ""
|
||||
|
||||
# ==================== 核心类 ====================
|
||||
class CodeGenerator:
|
||||
"""代码生成器,封装所有逻辑"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
base_url: str = "https://api.deepseek.com",
|
||||
model: str = "deepseek-reasoner",
|
||||
output_dir: str = "./generated",
|
||||
log_file: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
初始化生成器
|
||||
|
||||
Args:
|
||||
api_key: OpenAI API密钥,默认从环境变量DEEPSEEK_APIKEY读取
|
||||
base_url: API基础URL
|
||||
model: 使用的模型
|
||||
output_dir: 输出根目录
|
||||
log_file: 日志文件路径,默认自动生成
|
||||
"""
|
||||
self.api_key = api_key or os.getenv("DEEPSEEK_APIKEY")
|
||||
if not self.api_key:
|
||||
raise ValueError("必须提供API密钥,或设置环境变量DEEPSEEK_APIKEY")
|
||||
|
||||
self.client = OpenAI(api_key=self.api_key, base_url=base_url)
|
||||
self.model = model
|
||||
self.output_dir = Path(output_dir)
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 配置日志
|
||||
if log_file is None:
|
||||
log_file = self.output_dir / "generator.log"
|
||||
logger.remove() # 移除默认handler
|
||||
logger.add(sys.stderr, level="WARNING")  # 控制台输出WARNING及以上
|
||||
logger.add(log_file, rotation="10 MB", level="DEBUG") # 文件记录DEBUG
|
||||
logger.info(f"日志已初始化,保存至: {log_file}")
|
||||
|
||||
self.readme_content = None
|
||||
|
||||
self.progress: Optional[Progress] = None
|
||||
self.tasks: Dict[str, TaskID] = {} # 任务ID映射
|
||||
|
||||
def _call_llm(
|
||||
self,
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
temperature: float = 0.2,
|
||||
expect_json: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
调用LLM并返回解析后的JSON
|
||||
"""
|
||||
logger.debug(f"调用LLM,模型: {self.model}")
|
||||
logger.debug(f"System: {system_prompt[:200]}...")
|
||||
logger.debug(f"User: {user_prompt[:200]}...")
|
||||
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
temperature=temperature,
|
||||
response_format={"type": "json_object"} if expect_json else None,
|
||||
)
|
||||
|
||||
message = response.choices[0].message
|
||||
content = message.content
|
||||
|
||||
# 记录思考过程(如果存在)
|
||||
if hasattr(message, "reasoning_content") and message.reasoning_content:
|
||||
logger.info(f"模型思考过程: {message.reasoning_content}")
|
||||
|
||||
logger.debug(f"LLM原始响应: {content[:500]}...")
|
||||
|
||||
if expect_json:
|
||||
result = json.loads(content)
|
||||
else:
|
||||
result = {"content": content}
|
||||
|
||||
return result
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON解析失败: {e}")
|
||||
raise ValueError(f"LLM返回的不是有效JSON: {content[:200]}")
|
||||
except Exception as e:
|
||||
logger.error(f"LLM调用失败: {e}")
|
||||
raise
|
||||
|
||||
def parse_readme(self, readme_path: Path) -> str:
|
||||
"""
|
||||
读取README文件内容
|
||||
"""
|
||||
logger.info(f"读取README文件: {readme_path}")
|
||||
try:
|
||||
with open(readme_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
logger.debug(f"README内容长度: {len(content)} 字符")
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.error(f"读取README失败: {e}")
|
||||
raise
|
||||
|
||||
def get_project_structure(self) -> Tuple[List[str], Dict[str, List[str]]]:
|
||||
"""
|
||||
根据README内容,让LLM生成文件列表和依赖关系
|
||||
|
||||
Returns:
|
||||
(files, dependencies)
|
||||
files: 按顺序需要生成的文件路径列表
|
||||
dependencies: 字典 {file: [依赖文件路径]}
|
||||
"""
|
||||
system_prompt = (
|
||||
"你是一个软件架构师。请根据README描述,分析需要生成哪些源代码文件,并确定它们的生成顺序,"
|
||||
"同时给出每个文件生成时最少需要读取哪些已有文件作为上下文。"
|
||||
"返回严格的JSON对象,包含两个字段:\n"
|
||||
"- files: 数组,按生成顺序排列的文件路径(相对于项目根目录)\n"
|
||||
"- dependencies: 对象,键为文件路径,值为该文件依赖的已有文件路径列表(可为空)\n"
|
||||
"注意:依赖文件必须是已存在的参考文件,不要包含待生成的文件。"
|
||||
)
|
||||
user_prompt = f"README内容如下:\n\n{self.readme_content}"
|
||||
|
||||
result = self._call_llm(system_prompt, user_prompt)
|
||||
|
||||
files = result.get("files", [])
|
||||
dependencies = result.get("dependencies", {})
|
||||
|
||||
if not files:
|
||||
raise ValueError("LLM未返回任何文件列表")
|
||||
|
||||
logger.info(f"解析到 {len(files)} 个待生成文件")
|
||||
logger.debug(f"文件列表: {files}")
|
||||
logger.debug(f"依赖关系: {dependencies}")
|
||||
|
||||
return files, dependencies
|
||||
|
||||
def generate_file(
|
||||
self,
|
||||
file_path: str,
|
||||
prompt_instruction: str,
|
||||
dependency_files: List[str],
|
||||
) -> Tuple[str, str, List[str]]:
|
||||
"""
|
||||
生成单个文件,返回 (代码, 描述, 命令列表)
|
||||
"""
|
||||
# 读取依赖文件内容
|
||||
context_content = []
|
||||
|
||||
if self.readme_content:
|
||||
context_content.append(f"### 项目 README ###\n{self.readme_content}\n")
|
||||
|
||||
for dep in dependency_files:
|
||||
dep_path = Path(dep)
|
||||
if not dep_path.exists():
|
||||
# 尝试相对于当前目录或输出目录查找
|
||||
alt_path = self.output_dir / dep
|
||||
if alt_path.exists():
|
||||
dep_path = alt_path
|
||||
else:
|
||||
raise FileNotFoundError(f"依赖文件不存在: {dep}")
|
||||
|
||||
with open(dep_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
context_content.append(f"### 文件: {dep_path.name} (路径: {dep}) ###\n{content}\n")
|
||||
|
||||
full_context = "\n".join(context_content)
|
||||
|
||||
system_prompt = (
|
||||
"你是一个专业的编程助手。根据用户指令和提供的上下文文件,生成完整的代码。"
|
||||
"返回严格的JSON对象,包含三个字段:\n"
|
||||
"- code: (string) 生成的完整代码\n"
|
||||
"- description: (string) 简短的中文功能描述\n"
|
||||
"- commands: (array of string) 生成此文件后需要执行的操作系统命令列表(如编译、安装依赖等),若无则返回空数组"
|
||||
)
|
||||
user_prompt = f"{prompt_instruction}\n\n参考文件上下文:\n{full_context}"
|
||||
|
||||
result = self._call_llm(system_prompt, user_prompt)
|
||||
|
||||
code = result.get("code", "")
|
||||
description = result.get("description", "")
|
||||
commands = result.get("commands", [])
|
||||
|
||||
if not isinstance(commands, list):
|
||||
commands = []
|
||||
|
||||
return code, description, commands
|
||||
|
||||
def execute_command(self, cmd: str, cwd: Optional[Path] = None) -> None:
|
||||
"""
|
||||
执行单个命令,检查风险
|
||||
"""
|
||||
dangerous, reason = is_dangerous_command(cmd)
|
||||
if dangerous:
|
||||
logger.error(f"危险命令被阻止: {cmd},原因: {reason}")
|
||||
raise RuntimeError(f"危险命令: {cmd} ({reason})")
|
||||
|
||||
logger.info(f"执行命令: {cmd}")
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
shell=True,
|
||||
cwd=cwd or self.output_dir,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300, # 5分钟超时
|
||||
)
|
||||
logger.debug(f"命令返回码: {result.returncode}")
|
||||
if result.stdout:
|
||||
logger.debug(f"stdout: {result.stdout[:500]}")
|
||||
if result.stderr:
|
||||
logger.warning(f"stderr: {result.stderr[:500]}")
|
||||
if result.returncode != 0:
|
||||
raise subprocess.CalledProcessError(result.returncode, cmd)
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error(f"命令执行超时: {cmd}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"命令执行失败: {e}")
|
||||
raise
|
||||
|
||||
def run(self, readme_path: Path):
|
||||
"""
|
||||
主执行流程
|
||||
"""
|
||||
logger.info("=" * 50)
|
||||
logger.info("开始代码生成流程")
|
||||
logger.info(f"README: {readme_path}")
|
||||
logger.info(f"输出目录: {self.output_dir}")
|
||||
|
||||
# 初始化阶段:用rich输出状态(不会被日志级别过滤)
|
||||
console.print("[bold yellow]🔍 正在解析README...[/bold yellow]")
|
||||
self.readme_content = self.parse_readme(readme_path)
|
||||
|
||||
console.print("[bold yellow]📋 正在分析项目结构...[/bold yellow]")
|
||||
files, dependencies = self.get_project_structure()
|
||||
|
||||
console.print(f"[green]✅ 解析完成,共 {len(files)} 个文件待生成[/green]")
|
||||
|
||||
# 3. 创建进度条
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||
console=console,
|
||||
) as progress:
|
||||
self.progress = progress
|
||||
# 创建总任务
|
||||
total_task = progress.add_task("[cyan]整体进度...", total=len(files))
|
||||
|
||||
# 依次生成每个文件
|
||||
for idx, file in enumerate(files, 1):
|
||||
logger.info(f"处理文件 [{idx}/{len(files)}]: {file}")
|
||||
|
||||
# 创建子任务(可选)
|
||||
file_task = progress.add_task(f"生成 {file}", total=None)
|
||||
|
||||
try:
|
||||
# 获取依赖文件
|
||||
deps = dependencies.get(file, [])
|
||||
|
||||
# 构造生成指令
|
||||
instruction = f"请根据README描述和依赖文件,生成文件 '{file}' 的完整代码。"
|
||||
|
||||
# 调用LLM生成代码
|
||||
code, desc, commands = self.generate_file(file, instruction, deps)
|
||||
|
||||
logger.info(f"生成完成: {file} - {desc}")
|
||||
|
||||
# 写入文件
|
||||
output_path = self.output_dir / file
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
f.write(code)
|
||||
logger.info(f"已写入: {output_path}")
|
||||
|
||||
# 执行命令
|
||||
for cmd in commands:
|
||||
logger.info(f"准备执行命令: {cmd}")
|
||||
self.execute_command(cmd, cwd=self.output_dir)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理文件 {file} 失败: {e}")
|
||||
# 可选:继续或终止
|
||||
raise
|
||||
finally:
|
||||
progress.remove_task(file_task)
|
||||
progress.update(total_task, advance=1)
|
||||
|
||||
logger.success("所有文件处理完成!")
|
||||
|
||||
# ==================== CLI入口 ====================
|
||||
@app.command()
|
||||
def main(
|
||||
readme: Path = typer.Argument(..., exists=True, file_okay=True, dir_okay=False, help="README.md文件路径"),
|
||||
output_dir: Optional[Path] = typer.Option(None, "--output", "-o", help="输出根目录,默认为readme所在目录"),
|
||||
api_key: Optional[str] = typer.Option(None, "--api-key", envvar="DEEPSEEK_APIKEY", help="API密钥,也可通过环境变量DEEPSEEK_APIKEY设置"),
|
||||
base_url: str = typer.Option("https://api.deepseek.com", "--base-url", help="API基础URL"),
|
||||
model: str = typer.Option("deepseek-reasoner", "--model", "-m", help="使用的模型"),
|
||||
log_file: Optional[str] = typer.Option(None, "--log", help="日志文件路径(默认输出目录下generator.log)"),
|
||||
):
|
||||
"""
|
||||
根据README自动生成项目代码
|
||||
"""
|
||||
if output_dir is None:
|
||||
output_dir = readme.parent
|
||||
|
||||
try:
|
||||
generator = CodeGenerator(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
output_dir=output_dir,
|
||||
log_file=log_file,
|
||||
)
|
||||
generator.run(readme)
|
||||
except Exception as e:
|
||||
logger.error(f"程序异常退出: {e}")
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
```
|
||||
|
||||
## 功能特性
|
||||
|
||||
- 📦 **自动生成**:解析 `README.md`,分析需要生成的文件列表及依赖关系,按顺序生成每个文件的代码。
|
||||
- 🔧 **命令执行**:生成文件后可自动执行建议命令(如安装依赖、运行构建),内置危险命令拦截。
|
||||
- ✅ **单元测试**:使用 `pytest` 编写测试用例,支持测试覆盖率统计。
|
||||
- 🔍 **并行检查**:生成代码后并行运行多个检查工具(如 `pylint`、`mypy`、`black`),收集错误信息。
|
||||
- 🔄 **自修复**:将检查错误、`README` 和相关代码作为上下文提交给 LLM,自动生成修复补丁并应用。
|
||||
- ⏯️ **断点续写**:如果生成过程意外中断(如网络问题、API 限制),重新运行时会从上次中断处继续,已生成的文件和已执行的命令不会重复执行,状态自动保存在输出目录下的 `.llm_generator_state.json` 文件中。
|
||||
- 🖥️ **命令行工具**:提供 `llm-codegen` 命令,参数兼容原脚本(`--output`、`--api-key`、`--model` 等)。
|
||||
- 📝 **详细日志**:所有操作、LLM 响应、错误均通过 `loguru` 记录到文件。
|
||||
- 🎨 **美观输出**:使用 `rich` 显示进度条和彩色状态。
|
||||
|
||||
## 安装
|
||||
|
||||
### 依赖
|
||||
|
||||
- Python 3.9+
|
||||
- 使用 `uv` 管理包
|
||||
|
||||
```bash
|
||||
# 使用 uv
|
||||
uv pip install -e ".[dev]"
|
||||
```
|
||||
|
||||
### 配置 API 密钥
|
||||
|
||||
设置环境变量(推荐):
|
||||
```bash
|
||||
export DEEPSEEK_APIKEY="your-api-key"
|
||||
```
|
||||
|
||||
或在命令行中通过 `--api-key` 传入。
|
||||
|
||||
## 使用方法
|
||||
|
||||
```bash
|
||||
llm-codegen [OPTIONS] README
|
||||
```
|
||||
|
||||
### 参数
|
||||
|
||||
| 参数 | 说明 |
|
||||
|------|------|
|
||||
| `README` | `README.md` 文件路径(必须) |
|
||||
| `--output, -o` | 输出根目录(默认:README 所在目录) |
|
||||
| `--api-key` | API 密钥(默认:环境变量 `DEEPSEEK_APIKEY`) |
|
||||
| `--base-url` | API 基础 URL(默认:`https://api.deepseek.com`) |
|
||||
| `--model, -m` | 使用的模型(默认:`deepseek-reasoner`) |
|
||||
| `--log` | 日志文件路径(默认:输出目录下 `generator.log`) |
|
||||
| `--resume/--no-resume` | 是否启用断点续写(默认:`--resume`,即自动从上次中断处继续) |
|
||||
| `--no-check` | 跳过生成后的检查和修复 |
|
||||
| `--help` | 显示帮助信息 |
|
||||
|
||||
### 示例
|
||||
|
||||
```bash
|
||||
llm-codegen my_project/README.md -o ./generated
|
||||
```
|
||||
|
||||
如果中途中断,只需再次运行相同的命令,工具会自动检测状态文件并从上次中断处继续生成。
|
||||
|
||||
## 项目结构
|
||||
|
||||
生成的项目将包含以下文件和目录:
|
||||
|
||||
```
|
||||
.
|
||||
├── README.md # 项目说明(原始输入)
|
||||
├── pyproject.toml # 项目元数据、依赖、脚本入口
|
||||
├── src/
|
||||
│ └── llm_codegen/ # 主代码包
|
||||
│ ├── __init__.py
|
||||
│ ├── cli.py # 命令行入口(typer)
|
||||
│ ├── core.py # 核心生成逻辑(CodeGenerator 类)
|
||||
│ ├── checker.py # 并行检查与修复模块
|
||||
│ ├── utils.py # 工具函数(危险命令判断、文件操作)
|
||||
│ └── models.py # 数据模型(Pydantic)
|
||||
├── tests/ # 单元测试
|
||||
│ ├── __init__.py
|
||||
│ ├── test_cli.py
|
||||
│ ├── test_core.py
|
||||
│ └── test_checker.py
|
||||
└── logs/ # 运行日志(自动创建)
|
||||
```
|
||||
|
||||
## 核心流程
|
||||
|
||||
1. **解析阶段**:读取 `README.md`,调用 LLM 获取 `files`(按生成顺序的文件路径列表)和 `dependencies`(每个文件依赖的已有文件列表)。
|
||||
2. **生成阶段**:按顺序生成每个文件,使用 `README` 和依赖文件作为上下文,同时获取 LLM 建议的命令。每成功生成一个文件并执行命令后,状态会自动保存到 `.llm_generator_state.json`。
|
||||
3. **命令执行**:对每个建议命令进行危险检查,低风险则执行。已执行的命令记录在状态文件中,避免重复执行。
|
||||
4. **检查阶段**(可选):生成完成后,并行运行配置的检查工具(如 `pytest`、`pylint`、`mypy`),收集错误。
|
||||
5. **修复阶段**(可选):若检查失败,将错误信息、`README` 和相关文件内容提交给 LLM,请求生成修复方案,并自动应用修改。重复直到检查通过或达到重试次数上限。
|
||||
|
||||
## 断点续写机制
|
||||
|
||||
- 状态文件保存在输出目录下的 `.llm_generator_state.json`,记录已成功生成的文件列表和已执行的命令。
|
||||
- 重新运行工具时(默认启用 `--resume`),会自动读取状态文件,跳过已完成的部分,从下一个文件开始继续。
|
||||
- 如果 `README` 发生重大变更导致文件列表不一致,工具会检测并提示用户重新开始(可通过 `--no-resume` 强制从头生成)。
|
||||
- 状态文件在全部流程成功完成后可手动删除,工具不会自动删除,以便后续查看或用于调试。
|
||||
|
||||
## 开发指南
|
||||
|
||||
### 环境设置
|
||||
|
||||
```bash
|
||||
# 安装 uv(若未安装)
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
# 创建虚拟环境并激活
|
||||
uv venv
|
||||
source .venv/bin/activate # Linux/macOS
|
||||
# 或 .venv\Scripts\activate # Windows
|
||||
|
||||
# 安装项目(可编辑模式)和开发依赖
|
||||
uv pip install -e ".[dev]"
|
||||
```
|
||||
|
||||
### 运行测试
|
||||
|
||||
```bash
|
||||
pytest tests/ --cov=src/llm_codegen
|
||||
```
|
||||
|
||||
### 代码检查
|
||||
|
||||
```bash
|
||||
# 运行所有检查
|
||||
pre-commit run --all-files
|
||||
|
||||
# 或手动运行
|
||||
pylint src/llm_codegen
|
||||
mypy src/llm_codegen
|
||||
black --check src/llm_codegen
|
||||
```
|
||||
|
||||
### 添加新功能
|
||||
|
||||
1. 在 `src/llm_codegen/` 下添加或修改模块。
|
||||
2. 在 `tests/` 中添加对应的单元测试。
|
||||
3. 更新 `README.md` 和命令行帮助信息。
|
||||
|
||||
## 配置
|
||||
|
||||
通过 `pyproject.toml` 的 `[tool.llm-codegen]` 部分可以自定义检查工具和修复行为:
|
||||
|
||||
```toml
|
||||
[tool.llm-codegen]
|
||||
check_tools = ["pytest", "pylint", "mypy", "black"]
|
||||
max_retries = 3
|
||||
dangerous_commands = ["rm", "sudo", "chmod", "dd"]
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
[build-system]
|
||||
requires = ["setuptools>=61.0", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "llm-codegen"
|
||||
version = "0.1.0"
|
||||
description = "基于大语言模型的自动化代码生成工具,根据README.md描述自动生成完整的Python包代码,具备代码检查、测试和自动修复能力。"
|
||||
authors = [
|
||||
{name = "Your Name", email = "your.email@example.com"}
|
||||
]
|
||||
readme = "README.md"
|
||||
license = {text = "MIT"}
|
||||
requires-python = ">=3.9"
|
||||
dependencies = [
|
||||
"typer>=0.9.0",
|
||||
"rich>=13.0.0",
|
||||
"loguru>=0.7.0",
|
||||
"openai>=1.0.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=7.0.0",
|
||||
"pytest-cov>=4.0.0",
|
||||
"pylint>=3.0.0",
|
||||
"mypy>=1.0.0",
|
||||
"black>=23.0.0",
|
||||
"pre-commit>=3.0.0",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/yourusername/llm-codegen"
|
||||
|
||||
[project.scripts]
|
||||
llm-codegen = "llm_codegen.cli:app"
|
||||
|
||||
[tool.llm-codegen]
|
||||
check_tools = ["pytest", "pylint", "mypy", "black"]
|
||||
max_retries = 3
|
||||
dangerous_commands = ["rm", "sudo", "chmod", "dd"]
|
||||
|
||||
[tool.black]
|
||||
line-length = 88
|
||||
target-version = ['py39']
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
addopts = "--cov=src/llm_codegen --cov-report=term-missing"
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
"""
|
||||
LLM Code Generator package.
|
||||
|
||||
This package provides an automated code generation tool based on large language models (LLMs).
|
||||
It can generate complete Python package code from README descriptions, with features like code checking, testing, and auto-fixing.
|
||||
"""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
__author__ = "LLM CodeGen Team"
|
||||
__description__ = "A self-bootstrapping LLM-based code generation tool"
|
||||
|
||||
# Export main components for easy access from the package
|
||||
from .core import CodeGenerator
|
||||
from .cli import app
|
||||
from .utils import is_dangerous_command
|
||||
|
||||
__all__ = [
|
||||
"CodeGenerator",
|
||||
"app",
|
||||
"is_dangerous_command",
|
||||
]
|
||||
|
|
@ -0,0 +1,242 @@
|
|||
"""
|
||||
checker.py - 并行检查与修复模块
|
||||
负责在代码生成后运行配置的检查工具(如pylint、mypy、black)并收集错误,
|
||||
然后使用LLM自动生成和应用修复补丁。
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
from loguru import logger
|
||||
from openai import OpenAI
|
||||
|
||||
from .models import ConfigModel # 从models.py导入配置模型
|
||||
from .utils import safe_read_file, safe_write_file # 工具函数
|
||||
|
||||
|
||||
class Checker:
    """
    Checker/fixer: runs the configured check tools in parallel over generated
    files, collects their errors, and uses the LLM to generate and apply fix
    patches until the checks pass or the retry budget is exhausted.
    """

    def __init__(
        self,
        output_dir: Path,
        config: ConfigModel,
        api_key: Optional[str] = None,
        base_url: str = "https://api.deepseek.com",
        model: str = "deepseek-reasoner",
    ):
        """
        Initialize the checker.

        Args:
            output_dir: Directory containing the generated code.
            config: Configuration model (check_tools, max_retries, ...).
            api_key: LLM API key; falls back to the DEEPSEEK_APIKEY env var.
            base_url: LLM API base URL.
            model: LLM model name.

        Raises:
            ValueError: If no API key is available from either source.
        """
        self.output_dir = output_dir
        self.config = config
        self.api_key = api_key or os.getenv("DEEPSEEK_APIKEY")
        if not self.api_key:
            raise ValueError("API密钥未提供,请设置环境变量DEEPSEEK_APIKEY或传入api_key")
        self.client = OpenAI(api_key=self.api_key, base_url=base_url)
        self.model = model
        self.max_retries = config.max_retries

    def run_check_tool(self, tool: str, file_path: Path) -> Tuple[bool, str]:
        """
        Run a single check tool against one file.

        Args:
            tool: Tool name, e.g. "pylint".
            file_path: File to check.

        Returns:
            (success, output): True with empty output on success; on failure,
            the combined stdout+stderr (or a short error string).
        """
        commands = {
            "pylint": f"pylint {file_path}",
            "mypy": f"mypy {file_path}",
            "black": f"black --check {file_path}",
            "pytest": f"pytest {file_path}",  # assumes file_path is a test file — TODO confirm
        }
        if tool not in commands:
            logger.warning(f"未知检查工具: {tool}")
            return True, ""  # unknown tools are skipped (treated as passing)

        cmd = commands[tool]
        try:
            # NOTE(review): shell=True with an interpolated path — paths containing
            # spaces or shell metacharacters will break; confirm inputs are trusted.
            result = subprocess.run(
                cmd,
                shell=True,
                cwd=self.output_dir,
                capture_output=True,
                text=True,
                timeout=60,  # 1-minute timeout per tool invocation
            )
            if result.returncode == 0:
                return True, ""
            else:
                output = result.stdout + result.stderr
                return False, output
        except subprocess.TimeoutExpired:
            logger.error(f"检查工具 {tool} 超时")
            return False, "超时"
        except Exception as e:
            logger.error(f"运行检查工具 {tool} 失败: {e}")
            return False, str(e)

    def run_parallel_checks(self, files: List[Path]) -> Dict[Path, List[Tuple[str, str]]]:
        """
        Run every configured check tool over every file, in parallel.

        Args:
            files: File paths to check.

        Returns:
            Error mapping keyed by file path (note: keys are Path objects —
            the previous annotation incorrectly said str); each value is a
            list of (tool name, error output) pairs. Clean files are absent.
        """
        errors: Dict[Path, List[Tuple[str, str]]] = {}
        check_tools = self.config.check_tools

        with ThreadPoolExecutor() as executor:
            futures = []
            for file in files:
                for tool in check_tools:
                    future = executor.submit(self.run_check_tool, tool, file)
                    futures.append((future, file, tool))

            # Collect in submission order; .result() blocks until each task finishes.
            for future, file, tool in futures:
                success, output = future.result()
                if not success:
                    if file not in errors:
                        errors[file] = []
                    errors[file].append((tool, output))

        return errors

    def call_llm_for_fix(self, file_path: Path, errors: List[Tuple[str, str]], readme_content: str) -> Optional[str]:
        """
        Ask the LLM for a fixed version of a failing file.

        Args:
            file_path: File that needs fixing.
            errors: List of (tool name, error output) pairs for that file.
            readme_content: Project README text, passed as context.

        Returns:
            The fixed full-file code string, or None on any failure.
        """
        error_summary = "\n".join([f"{tool}: {err}" for tool, err in errors])
        file_content = safe_read_file(file_path)

        system_prompt = (
            "你是一个专业的代码修复助手。给定代码、错误信息和项目README,请生成修复后的完整代码。"
            "返回严格的JSON对象,包含字段:\n"
            "- code: (string) 修复后的完整代码\n"
            "- description: (string) 修复描述\n"
        )
        user_prompt = (
            f"项目README:\n{readme_content}\n\n"
            f"文件内容:\n{file_content}\n\n"
            f"错误信息:\n{error_summary}\n\n"
            "请生成修复后的代码,确保所有检查通过。"
        )

        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=0.2,
                response_format={"type": "json_object"},
            )
            result = json.loads(response.choices[0].message.content)
            return result.get("code")
        except Exception as e:
            # Best-effort: any API/parse failure is logged and reported as "no fix".
            logger.error(f"调用LLM修复失败: {e}")
            return None

    def apply_fix(self, file_path: Path, new_code: str) -> bool:
        """
        Write the fixed code back to disk.

        Args:
            file_path: Target file path.
            new_code: Replacement file content.

        Returns:
            True if the write succeeded, False otherwise.
        """
        try:
            safe_write_file(file_path, new_code)
            logger.info(f"已应用修复到 {file_path}")
            return True
        except Exception as e:
            logger.error(f"应用修复失败: {e}")
            return False

    def run_checks_and_fixes(self, readme_content: str, files: Optional[List[Path]] = None) -> bool:
        """
        Main entry point: run the checks and auto-fix failures, retrying up to
        ``max_retries`` times.

        Args:
            readme_content: Project README text (context for the LLM fixer).
            files: Files to check; defaults to every *.py under output_dir.

        Returns:
            True if all checks eventually pass, False otherwise.
        """
        if files is None:
            # Recursively collect every Python file in the output tree.
            files = list(self.output_dir.rglob("*.py"))

        for attempt in range(self.max_retries):
            logger.info(f"检查尝试 {attempt + 1}/{self.max_retries}")
            errors = self.run_parallel_checks(files)

            if not errors:
                logger.success("所有检查通过!")
                return True

            # Failures found — try to fix each failing file via the LLM.
            logger.warning(f"发现 {len(errors)} 个文件有错误,尝试修复")
            all_fixed = True
            for file_path, error_list in errors.items():
                new_code = self.call_llm_for_fix(file_path, error_list, readme_content)
                if new_code:
                    if self.apply_fix(file_path, new_code):
                        # Re-check the patched file with a single tool only —
                        # a simplification; the full suite runs on the next attempt.
                        success, _ = self.run_check_tool("pylint", file_path)
                        if not success:
                            all_fixed = False
                    else:
                        all_fixed = False
                else:
                    all_fixed = False

            if all_fixed:
                logger.info("修复成功,重新检查...")
                continue  # re-run the full check suite on the next attempt
            else:
                logger.error("修复失败或仍有错误")
                break

        # Final verification pass after the retry loop ends.
        final_errors = self.run_parallel_checks(files)
        if not final_errors:
            logger.success("检查最终通过")
            return True
        else:
            logger.error(f"检查失败,剩余错误: {final_errors}")
            return False
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
import typer
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from loguru import logger
|
||||
from rich.console import Console
|
||||
from .core import CodeGenerator
|
||||
|
||||
app = typer.Typer(help="基于LLM的自动化代码生成工具")
|
||||
console = Console()
|
||||
|
||||
# CLI entry point. The Chinese docstring and help= strings below are
# user-visible --help output and are therefore left unchanged.
@app.command()
def main(
    readme: Path = typer.Argument(..., exists=True, file_okay=True, dir_okay=False, help="README.md文件路径"),
    output_dir: Optional[Path] = typer.Option(None, "--output", "-o", help="输出根目录,默认为readme所在目录"),
    api_key: Optional[str] = typer.Option(None, "--api-key", envvar="DEEPSEEK_APIKEY", help="API密钥,也可通过环境变量DEEPSEEK_APIKEY设置"),
    base_url: str = typer.Option("https://api.deepseek.com", "--base-url", help="API基础URL"),
    model: str = typer.Option("deepseek-reasoner", "--model", "-m", help="使用的模型"),
    log_file: Optional[str] = typer.Option(None, "--log", help="日志文件路径(默认输出目录下generator.log)"),
    resume: bool = typer.Option(True, "--resume/--no-resume", help="是否启用断点续写(默认启用)"),
    no_check: bool = typer.Option(False, "--no-check", help="跳过生成后的检查和修复"),
):
    """
    根据README自动生成项目代码,支持断点续写和可选检查。
    """
    # Default the output directory to wherever the README file lives.
    if output_dir is None:
        output_dir = readme.parent

    try:
        generator = CodeGenerator(
            api_key=api_key,
            base_url=base_url,
            model=model,
            output_dir=output_dir,
            log_file=log_file,
            resume=resume,
            config_path=None,  # config file path; could be loaded from pyproject.toml, but no CLI flag yet
        )
        generator.run(readme)

        # Checking/fixing is not wired into this entry point yet; tell the
        # user instead of silently skipping (unless --no-check was given).
        if not no_check:
            console.print("[yellow]注意:检查和修复功能暂未在此版本中实现,请手动运行检查工具(如pytest、pylint)。[/yellow]")

    except Exception as e:
        # Log and convert any failure into a non-zero process exit code.
        logger.error(f"程序异常退出: {e}")
        raise typer.Exit(code=1)


if __name__ == "__main__":
    app()
|
||||
|
|
@ -0,0 +1,365 @@
|
|||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
|
||||
from loguru import logger
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskID
|
||||
from rich.console import Console
|
||||
from openai import OpenAI
|
||||
|
||||
# 导入本地模块
|
||||
from .utils import is_dangerous_command, safe_read_file, safe_write_file, load_state, save_state, normalize_path, load_dangerous_commands
|
||||
from .models import GeneratorState, FileInfo, ProjectStructure, ConfigModel
|
||||
|
||||
|
||||
class CodeGenerator:
|
||||
"""
|
||||
核心代码生成器类,负责解析README、生成代码、执行命令并支持断点续写。
|
||||
"""
|
||||
|
||||
    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: str = "https://api.deepseek.com",
        model: str = "deepseek-reasoner",
        output_dir: str = "./generated",
        log_file: Optional[str] = None,
        resume: bool = True,
        config_path: Optional[Path] = None,
    ):
        """
        Initialize the generator.

        Args:
            api_key: OpenAI-compatible API key; defaults to the DEEPSEEK_APIKEY env var.
            base_url: API base URL.
            model: Model name to use.
            output_dir: Root directory for generated output (created if missing).
            log_file: Log file path; defaults to <output_dir>/generator.log.
            resume: Whether to resume from a previous interrupted run (default True).
            config_path: Optional config file path used when loading the
                dangerous-command list.

        Raises:
            ValueError: If no API key is available from either source.
        """
        self.api_key = api_key or os.getenv("DEEPSEEK_APIKEY")
        if not self.api_key:
            raise ValueError("必须提供API密钥,或设置环境变量DEEPSEEK_APIKEY")

        self.client = OpenAI(api_key=self.api_key, base_url=base_url)
        self.model = model
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.resume = resume
        self.config_path = config_path

        # Load configuration (check tools, retry budget, dangerous commands).
        self.config = self._load_config()
        self.dangerous_commands = self.config.dangerous_commands

        # Configure logging.
        if log_file is None:
            log_file = self.output_dir / "generator.log"
        logger.remove()  # drop loguru's default handler
        logger.add(sys.stderr, level="WARNING")  # console: WARNING and above
        logger.add(log_file, rotation="10 MB", level="DEBUG")  # file: DEBUG and above, rotated at 10 MB
        logger.info(f"日志已初始化,保存至: {log_file}")

        # README text is filled in later by parse_readme().
        self.readme_content: Optional[str] = None
        # Resume-state file lives inside the output directory.
        self.state_path = self.output_dir / ".llm_generator_state.json"
        self.state: Optional[GeneratorState] = None
        self.progress: Optional[Progress] = None
        self.tasks: Dict[str, TaskID] = {}  # progress task-ID map
        self.console = Console()
|
||||
|
||||
def _load_config(self) -> ConfigModel:
|
||||
"""
|
||||
加载配置,如果配置文件不存在则使用默认值。
|
||||
"""
|
||||
try:
|
||||
# 简化实现:从环境或固定路径加载,实际应从pyproject.toml解析
|
||||
dangerous_cmds = load_dangerous_commands(self.config_path)
|
||||
return ConfigModel(
|
||||
check_tools=["pytest", "pylint", "mypy", "black"],
|
||||
max_retries=3,
|
||||
dangerous_commands=dangerous_cmds,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"加载配置失败,使用默认值: {e}")
|
||||
return ConfigModel()
|
||||
|
||||
def _call_llm(
|
||||
self,
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
temperature: float = 0.2,
|
||||
expect_json: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
调用LLM并返回解析后的JSON。
|
||||
"""
|
||||
logger.debug(f"调用LLM,模型: {self.model}")
|
||||
logger.debug(f"System: {system_prompt[:200]}...")
|
||||
logger.debug(f"User: {user_prompt[:200]}...")
|
||||
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
temperature=temperature,
|
||||
response_format={"type": "json_object"} if expect_json else None,
|
||||
)
|
||||
|
||||
message = response.choices[0].message
|
||||
content = message.content
|
||||
|
||||
# 记录思考过程(如果存在)
|
||||
if hasattr(message, "reasoning_content") and message.reasoning_content:
|
||||
logger.info(f"模型思考过程: {message.reasoning_content}")
|
||||
|
||||
logger.debug(f"LLM原始响应: {content[:500]}...")
|
||||
|
||||
if expect_json:
|
||||
result = json.loads(content)
|
||||
else:
|
||||
result = {"content": content}
|
||||
|
||||
return result
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON解析失败: {e}")
|
||||
raise ValueError(f"LLM返回的不是有效JSON: {content[:200]}")
|
||||
except Exception as e:
|
||||
logger.error(f"LLM调用失败: {e}")
|
||||
raise
|
||||
|
||||
    def parse_readme(self, readme_path: Path) -> str:
        """
        Read the README file and return its full text.

        NOTE(review): the original docstring claimed a hash is computed here
        for resume detection, but no hashing happens in this method — confirm
        where (or whether) README-change detection is actually implemented.

        Args:
            readme_path: Path to the README file.

        Returns:
            The README content as a string.

        Raises:
            Exception: Re-raised from safe_read_file after logging.
        """
        logger.info(f"读取README文件: {readme_path}")
        try:
            content = safe_read_file(readme_path)
            logger.debug(f"README内容长度: {len(content)} 字符")
            return content
        except Exception as e:
            logger.error(f"读取README失败: {e}")
            raise
|
||||
|
||||
def get_project_structure(self) -> Tuple[List[str], Dict[str, List[str]]]:
|
||||
"""
|
||||
根据README内容,让LLM生成文件列表和依赖关系。
|
||||
|
||||
Returns:
|
||||
(files, dependencies)
|
||||
files: 按顺序需要生成的文件路径列表
|
||||
dependencies: 字典 {file: [依赖文件路径]}
|
||||
"""
|
||||
system_prompt = (
|
||||
"你是一个软件架构师。请根据README描述,分析需要生成哪些源代码文件,并确定它们的生成顺序,"
|
||||
"同时给出每个文件生成时最少需要读取哪些已有文件作为上下文。"
|
||||
"返回严格的JSON对象,包含两个字段:\n"
|
||||
"- files: 数组,按生成顺序排列的文件路径(相对于项目根目录)\n"
|
||||
"- dependencies: 对象,键为文件路径,值为该文件依赖的已有文件路径列表(可为空)\n"
|
||||
"注意:依赖文件必须是已存在的参考文件,不要包含待生成的文件。"
|
||||
)
|
||||
user_prompt = f"README内容如下:\n\n{self.readme_content}"
|
||||
|
||||
result = self._call_llm(system_prompt, user_prompt)
|
||||
|
||||
files = result.get("files", [])
|
||||
dependencies = result.get("dependencies", {})
|
||||
|
||||
if not files:
|
||||
raise ValueError("LLM未返回任何文件列表")
|
||||
|
||||
logger.info(f"解析到 {len(files)} 个待生成文件")
|
||||
logger.debug(f"文件列表: {files}")
|
||||
logger.debug(f"依赖关系: {dependencies}")
|
||||
|
||||
return files, dependencies
|
||||
|
||||
def generate_file(
|
||||
self,
|
||||
file_path: str,
|
||||
prompt_instruction: str,
|
||||
dependency_files: List[str],
|
||||
) -> Tuple[str, str, List[str]]:
|
||||
"""
|
||||
生成单个文件,返回 (代码, 描述, 命令列表)。
|
||||
"""
|
||||
# 读取依赖文件内容
|
||||
context_content = []
|
||||
|
||||
if self.readme_content:
|
||||
context_content.append(f"### 项目 README ###\n{self.readme_content}\n")
|
||||
|
||||
for dep in dependency_files:
|
||||
dep_path = Path(dep)
|
||||
if not dep_path.exists():
|
||||
# 尝试相对于输出目录查找
|
||||
alt_path = self.output_dir / dep
|
||||
if alt_path.exists():
|
||||
dep_path = alt_path
|
||||
else:
|
||||
raise FileNotFoundError(f"依赖文件不存在: {dep}")
|
||||
|
||||
content = safe_read_file(dep_path)
|
||||
context_content.append(f"### 文件: {dep_path.name} (路径: {dep}) ###\n{content}\n")
|
||||
|
||||
full_context = "\n".join(context_content)
|
||||
|
||||
system_prompt = (
|
||||
"你是一个专业的编程助手。根据用户指令和提供的上下文文件,生成完整的代码。"
|
||||
"返回严格的JSON对象,包含三个字段:\n"
|
||||
"- code: (string) 生成的完整代码\n"
|
||||
"- description: (string) 简短的中文功能描述\n"
|
||||
"- commands: (array of string) 生成此文件后需要执行的操作系统命令列表(如编译、安装依赖等),若无则返回空数组"
|
||||
)
|
||||
user_prompt = f"{prompt_instruction}\n\n参考文件上下文:\n{full_context}"
|
||||
|
||||
result = self._call_llm(system_prompt, user_prompt)
|
||||
|
||||
code = result.get("code", "")
|
||||
description = result.get("description", "")
|
||||
commands = result.get("commands", [])
|
||||
|
||||
if not isinstance(commands, list):
|
||||
commands = []
|
||||
|
||||
return code, description, commands
|
||||
|
||||
def execute_command(self, cmd: str, cwd: Optional[Path] = None) -> None:
|
||||
"""
|
||||
执行单个命令,检查风险。
|
||||
"""
|
||||
dangerous, reason = is_dangerous_command(cmd, self.dangerous_commands)
|
||||
if dangerous:
|
||||
logger.error(f"危险命令被阻止: {cmd},原因: {reason}")
|
||||
raise RuntimeError(f"危险命令: {cmd} ({reason})")
|
||||
|
||||
logger.info(f"执行命令: {cmd}")
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
shell=True,
|
||||
cwd=cwd or self.output_dir,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300, # 5分钟超时
|
||||
)
|
||||
logger.debug(f"命令返回码: {result.returncode}")
|
||||
if result.stdout:
|
||||
logger.debug(f"stdout: {result.stdout[:500]}")
|
||||
if result.stderr:
|
||||
logger.warning(f"stderr: {result.stderr[:500]}")
|
||||
if result.returncode != 0:
|
||||
raise subprocess.CalledProcessError(result.returncode, cmd)
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error(f"命令执行超时: {cmd}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"命令执行失败: {e}")
|
||||
raise
|
||||
|
||||
def _update_state(self, generated_file: str, executed_commands: List[str]) -> None:
|
||||
"""
|
||||
更新断点续写状态。
|
||||
"""
|
||||
if self.state is None:
|
||||
self.state = GeneratorState()
|
||||
self.state.generated_files.append(generated_file)
|
||||
self.state.executed_commands.extend(executed_commands)
|
||||
self.state.updated_at = datetime.now()
|
||||
save_state(self.state_path, self.state.model_dump())
|
||||
|
||||
def run(self, readme_path: Path) -> None:
|
||||
"""
|
||||
主执行流程,支持断点续写。
|
||||
"""
|
||||
logger.info("=" * 50)
|
||||
logger.info("开始代码生成流程")
|
||||
logger.info(f"README: {readme_path}")
|
||||
logger.info(f"输出目录: {self.output_dir}")
|
||||
logger.info(f"断点续写: {self.resume}")
|
||||
|
||||
# 初始化阶段
|
||||
self.console.print("[bold yellow]🔍 正在解析README...[/bold yellow]")
|
||||
self.readme_content = self.parse_readme(readme_path)
|
||||
|
||||
# 加载或初始化状态
|
||||
if self.resume and self.state_path.exists():
|
||||
raw_state = load_state(self.state_path)
|
||||
self.state = GeneratorState(**raw_state) if raw_state else GeneratorState()
|
||||
logger.info(f"加载状态文件: {self.state_path}")
|
||||
# 检查README是否变更
|
||||
if self.state.readme_hash and self.state.readme_hash != hash(self.readme_content):
|
||||
logger.warning("README内容已变更,建议使用 --no-resume 重新开始")
|
||||
else:
|
||||
self.state = GeneratorState()
|
||||
|
||||
self.console.print("[bold yellow]📋 正在分析项目结构...[/bold yellow]")
|
||||
files, dependencies = self.get_project_structure()
|
||||
|
||||
# 过滤已生成的文件
|
||||
if self.resume and self.state:
|
||||
pending_files = [f for f in files if f not in self.state.generated_files]
|
||||
logger.info(f"跳过了 {len(files) - len(pending_files)} 个已生成文件,剩余 {len(pending_files)} 个")
|
||||
files = pending_files
|
||||
|
||||
if not files:
|
||||
logger.success("所有文件已生成,无需继续")
|
||||
return
|
||||
|
||||
self.console.print(f"[green]✅ 解析完成,共 {len(files)} 个文件待生成[/green]")
|
||||
|
||||
# 创建进度条
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||
console=self.console,
|
||||
) as progress:
|
||||
self.progress = progress
|
||||
total_task = progress.add_task("[cyan]整体进度...", total=len(files))
|
||||
|
||||
for idx, file in enumerate(files, 1):
|
||||
logger.info(f"处理文件 [{idx}/{len(files)}]: {file}")
|
||||
file_task = progress.add_task(f"生成 {file}", total=None)
|
||||
|
||||
try:
|
||||
deps = dependencies.get(file, [])
|
||||
instruction = f"请根据README描述和依赖文件,生成文件 '{file}' 的完整代码。"
|
||||
code, desc, commands = self.generate_file(file, instruction, deps)
|
||||
logger.info(f"生成完成: {file} - {desc}")
|
||||
|
||||
# 写入文件
|
||||
output_path = self.output_dir / file
|
||||
safe_write_file(output_path, code)
|
||||
logger.info(f"已写入: {output_path}")
|
||||
|
||||
# 执行命令,跳过已执行的
|
||||
executed_in_this_file = []
|
||||
for cmd in commands:
|
||||
if self.resume and self.state and cmd in self.state.executed_commands:
|
||||
logger.info(f"跳过已执行命令: {cmd}")
|
||||
continue
|
||||
logger.info(f"准备执行命令: {cmd}")
|
||||
self.execute_command(cmd, cwd=self.output_dir)
|
||||
executed_in_this_file.append(cmd)
|
||||
|
||||
# 更新状态
|
||||
self._update_state(file, executed_in_this_file)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理文件 {file} 失败: {e}")
|
||||
raise
|
||||
finally:
|
||||
progress.remove_task(file_task)
|
||||
progress.update(total_task, advance=1)
|
||||
|
||||
logger.success("所有文件处理完成!")
|
||||
|
|
@ -0,0 +1,100 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
数据模型定义模块,使用 Pydantic 进行数据验证和序列化。
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Optional
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class GeneratorState(BaseModel):
    """
    Resume-state model, saved to and loaded from .llm_generator_state.json.
    """
    # Files already generated successfully (skipped on a resumed run).
    generated_files: List[str] = Field(
        default_factory=list,
        description="已成功生成的文件路径列表"
    )
    # OS commands already executed (skipped on a resumed run).
    executed_commands: List[str] = Field(
        default_factory=list,
        description="已执行的操作系统命令列表"
    )
    # Hash of the README content, used to detect changes between runs.
    readme_hash: Optional[str] = Field(
        default=None,
        description="README 内容的哈希值,用于检测变更"
    )
    # Timestamp when the state file was first created.
    created_at: datetime = Field(
        default_factory=datetime.now,
        description="状态文件创建时间"
    )
    # Timestamp of the last state update.
    updated_at: datetime = Field(
        default_factory=datetime.now,
        description="状态文件最后更新时间"
    )

    class Config:
        # Serialize datetimes as ISO-8601 strings in JSON output.
        # NOTE(review): json_encoders is pydantic-v1 style configuration;
        # confirm it still applies under the project's pydantic version.
        json_encoders = {
            datetime: lambda v: v.isoformat()
        }
|
||||
|
||||
|
||||
class FileInfo(BaseModel):
    """
    Metadata for one generated file.
    """
    # Target path of the file, relative to the project root.
    path: str = Field(..., description="文件路径")
    # Generated source code; None until generation has run.
    code: Optional[str] = Field(default=None, description="生成的代码内容")
    # Short human-readable description of the file's purpose.
    description: Optional[str] = Field(default=None, description="文件功能描述")
    # Shell commands to run after this file is written.
    commands: List[str] = Field(
        default_factory=list,
        description="生成后需要执行的命令列表"
    )
|
||||
|
||||
|
||||
class ProjectStructure(BaseModel):
    """
    Project layout model: ordered file list plus dependency map.
    """
    # File paths in the order they must be generated.
    files: List[str] = Field(
        ...,
        description="按生成顺序排列的文件路径列表"
    )
    # Maps a file path to the existing files it needs as context.
    dependencies: Dict[str, List[str]] = Field(
        default_factory=dict,
        description="文件依赖关系,键为文件路径,值为依赖文件列表"
    )
|
||||
|
||||
|
||||
class CheckResult(BaseModel):
    """
    Result of running one code-check tool.
    """
    # Name of the check tool, e.g. 'pylint'.
    tool: str = Field(..., description="检查工具名称,如 'pylint'")
    # Whether the check succeeded.
    passed: bool = Field(..., description="检查是否通过")
    # Error messages reported by the tool.
    errors: List[str] = Field(
        default_factory=list,
        description="错误信息列表"
    )
    # Warning messages reported by the tool.
    warnings: List[str] = Field(
        default_factory=list,
        description="警告信息列表"
    )
|
||||
|
||||
|
||||
class ConfigModel(BaseModel):
    """
    Configuration model mapping the [tool.llm-codegen] section of
    pyproject.toml.
    """
    # Check tools to run over generated code.
    check_tools: List[str] = Field(
        default=["pytest", "pylint", "mypy", "black"],
        description="要运行的检查工具列表"
    )
    # Maximum number of automatic fix retries.
    max_retries: int = Field(
        default=3,
        description="修复的最大重试次数"
    )
    # Keyword blacklist used to screen generated commands.
    dangerous_commands: List[str] = Field(
        default=["rm", "sudo", "chmod", "dd"],
        description="危险命令列表"
    )
|
||||
|
|
@ -0,0 +1,130 @@
|
|||
"""
|
||||
utils.py - 工具函数模块
|
||||
包含危险命令判断、文件操作、状态管理等通用函数。
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
def is_dangerous_command(cmd: str, dangerous_commands: Optional[List[str]] = None) -> Tuple[bool, str]:
    """
    Check whether a shell command contains a dangerous keyword.

    Matching is a case-insensitive substring test, so it is deliberately
    conservative (e.g. the keyword "rm" also matches inside "format").

    Args:
        cmd: Command string to inspect.
        dangerous_commands: Keyword list; the built-in default list is
            used when None.

    Returns:
        Tuple[bool, str]: (is dangerous, reason).
    """
    keywords = dangerous_commands
    if keywords is None:
        # Default blacklist; callers may supply a configured list instead.
        keywords = ["rm", "sudo", "chmod", "dd", "mkfs", "> /dev/sda", "format"]

    lowered = cmd.lower()
    hits = [k for k in keywords if k in lowered]
    if hits:
        return True, f"包含危险关键词 '{hits[0]}'"
    return False, ""
|
||||
|
||||
|
||||
def load_dangerous_commands(config_path: Optional[Path] = None) -> List[str]:
    """
    Return the dangerous-command keyword list (simplified implementation).

    A full implementation would parse the [tool.llm-codegen] section of
    pyproject.toml (e.g. via tomllib); for now the default list is
    returned regardless of config_path.

    Args:
        config_path: Configuration file path; currently ignored.

    Returns:
        List[str]: Dangerous-command keyword list.
    """
    defaults = ("rm", "sudo", "chmod", "dd", "mkfs", "> /dev/sda", "format")
    return list(defaults)
|
||||
|
||||
|
||||
def safe_read_file(file_path: Path, encoding: str = "utf-8") -> str:
    """
    Read a text file and return its content.

    Args:
        file_path: Path of the file to read.
        encoding: Text encoding, utf-8 by default.

    Returns:
        str: The file content.

    Raises:
        FileNotFoundError: If the file does not exist.
        IOError: If reading fails for any other reason.
    """
    try:
        with open(file_path, 'r', encoding=encoding) as f:
            return f.read()
    except FileNotFoundError as e:
        # Re-raise with a clearer message while preserving the original
        # cause for debugging (exception chaining).
        raise FileNotFoundError(f"文件不存在: {file_path}") from e
    except Exception as e:
        raise IOError(f"读取文件失败 {file_path}: {e}") from e
|
||||
|
||||
|
||||
def safe_write_file(file_path: Path, content: str, encoding: str = "utf-8") -> None:
    """
    Write text to a file, creating parent directories as needed.

    Args:
        file_path: Destination path.
        content: Text to write.
        encoding: Text encoding, utf-8 by default.
    """
    parent = file_path.parent
    parent.mkdir(parents=True, exist_ok=True)
    with open(file_path, mode='w', encoding=encoding) as handle:
        handle.write(content)
|
||||
|
||||
|
||||
def load_state(state_path: Path) -> Dict[str, Any]:
    """
    Load the resume-state JSON file.

    Args:
        state_path: Path of the state file (e.g. .llm_generator_state.json).

    Returns:
        Dict[str, Any]: Parsed state; an empty dict when the file is
        missing, unreadable, or not valid JSON.
    """
    if not state_path.exists():
        return {}
    try:
        with open(state_path, 'r', encoding='utf-8') as handle:
            raw = handle.read()
        return json.loads(raw)
    except (json.JSONDecodeError, IOError):
        # Treat a corrupt or unreadable state file as "no state".
        return {}
|
||||
|
||||
|
||||
def save_state(state_path: Path, state: Dict[str, Any]) -> None:
    """
    Persist the resume state as pretty-printed UTF-8 JSON.

    Args:
        state_path: Destination path of the state file.
        state: State data to serialize.
    """
    serialized = json.dumps(state, indent=2, ensure_ascii=False)
    with open(state_path, 'w', encoding='utf-8') as handle:
        handle.write(serialized)
|
||||
|
||||
|
||||
def normalize_path(path: str, base_dir: Optional[Path] = None) -> Path:
    """
    Resolve a path string against a base directory.

    Args:
        path: Path string, possibly relative.
        base_dir: Base directory; the current working directory when None.

    Returns:
        Path: The fully resolved absolute path.
    """
    root = Path.cwd() if base_dir is None else base_dir
    return (root / path).resolve()
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
"""
|
||||
Initialization file for the tests package of the LLM code generation tool.
|
||||
|
||||
This file marks the 'tests' directory as a Python package, enabling proper
|
||||
imports and test discovery with pytest.
|
||||
"""
|
||||
|
|
@ -0,0 +1,184 @@
|
|||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from pathlib import Path
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
from src.llm_codegen.core import CodeGenerator
|
||||
from src.llm_codegen.models import GeneratorState, ConfigModel
|
||||
from src.llm_codegen.utils import is_dangerous_command
|
||||
|
||||
|
||||
class TestCodeGenerator:
    """Unit tests for the CodeGenerator core class."""

    @pytest.fixture
    def mock_openai_client(self):
        """Mock the OpenAI client so no network calls are made."""
        # Patch the symbol where core looks it up, not the openai package.
        with patch('src.llm_codegen.core.OpenAI') as mock:
            client = Mock()
            mock.return_value = client
            yield client

    @pytest.fixture
    def generator(self, mock_openai_client, tmp_path):
        """Create a CodeGenerator using a temp directory and the mocked API."""
        output_dir = tmp_path / "output"
        return CodeGenerator(
            api_key="test-api-key",
            base_url="https://api.deepseek.com",
            model="deepseek-reasoner",
            output_dir=str(output_dir),
            resume=False,
        )

    def test_init(self, generator, tmp_path):
        """Constructor wires up config, paths and defaults."""
        assert generator.api_key == "test-api-key"
        assert generator.model == "deepseek-reasoner"
        assert generator.output_dir == tmp_path / "output"
        assert generator.resume is False
        assert isinstance(generator.config, ConfigModel)
        assert generator.dangerous_commands == ["rm", "sudo", "chmod", "dd", "mkfs", "> /dev/sda", "format"]
        assert generator.state is None

    def test_parse_readme(self, generator, tmp_path):
        """parse_readme returns the README file content verbatim."""
        readme_path = tmp_path / "README.md"
        readme_content = "# Test Project\nThis is a test README."
        readme_path.write_text(readme_content)

        result = generator.parse_readme(readme_path)
        assert result == readme_content

    def test_get_project_structure(self, generator, mock_openai_client):
        """get_project_structure parses the mocked LLM JSON response."""
        generator.readme_content = "# Test README"
        mock_response = {
            "files": ["src/__init__.py", "src/core.py"],
            "dependencies": {"src/core.py": ["src/__init__.py"]}
        }
        mock_openai_client.chat.completions.create.return_value.choices[0].message.content = json.dumps(mock_response)

        files, dependencies = generator.get_project_structure()
        assert files == ["src/__init__.py", "src/core.py"]
        assert dependencies == {"src/core.py": ["src/__init__.py"]}
        mock_openai_client.chat.completions.create.assert_called_once()

    def test_generate_file(self, generator, mock_openai_client, tmp_path):
        """generate_file returns code/description/commands from the mocked LLM."""
        generator.readme_content = "# Test README"
        dep_file = tmp_path / "dep.txt"
        dep_file.write_text("Dependency content")

        mock_response = {
            "code": "print('Hello, World!')",
            "description": "测试文件生成",
            "commands": ["echo 'test'"]
        }
        mock_openai_client.chat.completions.create.return_value.choices[0].message.content = json.dumps(mock_response)

        code, desc, commands = generator.generate_file(
            "test.py",
            "生成测试文件",
            [str(dep_file)]
        )
        assert code == "print('Hello, World!')"
        assert desc == "测试文件生成"
        assert commands == ["echo 'test'"]
        mock_openai_client.chat.completions.create.assert_called_once()

    def test_execute_command_safe(self, generator, tmp_path):
        """A safe command is passed to subprocess.run with the expected kwargs."""
        with patch('subprocess.run') as mock_run:
            mock_run.return_value.returncode = 0
            mock_run.return_value.stdout = "output"
            mock_run.return_value.stderr = ""

            generator.execute_command("echo 'test'", cwd=tmp_path)
            mock_run.assert_called_once_with(
                "echo 'test'",
                shell=True,
                cwd=tmp_path,
                capture_output=True,
                text=True,
                timeout=300
            )

    def test_execute_command_dangerous(self, generator):
        """A dangerous command is blocked with RuntimeError."""
        with pytest.raises(RuntimeError, match="危险命令"):
            generator.execute_command("rm -rf /")  # assumed to be on the dangerous-command list

    def test_run_without_resume(self, generator, mock_openai_client, tmp_path):
        """Full run pipeline with resume disabled generates every file."""
        readme_path = tmp_path / "README.md"
        readme_path.write_text("# Test README")
        generator.readme_content = "# Test README"

        # Mock the get_project_structure response.
        mock_structure = {
            "files": ["file1.py", "file2.py"],
            "dependencies": {}
        }
        # side_effect: first call returns the structure, the next two
        # return the per-file generation payloads.
        mock_openai_client.chat.completions.create.side_effect = [
            Mock(choices=[Mock(message=Mock(content=json.dumps(mock_structure)))]),
            Mock(choices=[Mock(message=Mock(content=json.dumps({"code": "code1", "description": "desc1", "commands": []})))]),
            Mock(choices=[Mock(message=Mock(content=json.dumps({"code": "code2", "description": "desc2", "commands": []})))])
        ]

        with patch('src.llm_codegen.core.safe_write_file') as mock_write, \
             patch('src.llm_codegen.core.safe_read_file') as mock_read, \
             patch('src.llm_codegen.core.save_state') as mock_save:
            mock_read.return_value = "content"
            generator.run(readme_path)

            # Verify file writes and state persistence happened.
            assert mock_write.call_count == 2
            assert mock_save.called

    def test_run_with_resume(self, generator, mock_openai_client, tmp_path):
        """Resumed run skips files already recorded in the saved state."""
        generator.resume = True
        generator.state = GeneratorState(generated_files=["file1.py"], executed_commands=[])
        readme_path = tmp_path / "README.md"
        readme_path.write_text("# Test README")
        generator.readme_content = "# Test README"

        mock_structure = {
            "files": ["file1.py", "file2.py"],
            "dependencies": {}
        }
        mock_openai_client.chat.completions.create.return_value.choices[0].message.content = json.dumps(mock_structure)

        with patch('src.llm_codegen.core.safe_write_file') as mock_write, \
             patch('src.llm_codegen.core.safe_read_file') as mock_read:
            mock_read.return_value = "content"
            generator.run(readme_path)

            # Only file2.py should be generated; file1.py is skipped.
            assert mock_write.call_count == 1

    def test_load_config_default(self, generator):
        """_load_config falls back to the default ConfigModel values."""
        config = generator._load_config()
        assert isinstance(config, ConfigModel)
        assert config.check_tools == ["pytest", "pylint", "mypy", "black"]
        assert config.max_retries == 3

    def test_update_state(self, generator, tmp_path):
        """_update_state appends progress and persists the state file."""
        generator.state_path = tmp_path / "state.json"
        generator.state = GeneratorState()

        with patch('src.llm_codegen.core.save_state') as mock_save:
            generator._update_state("new_file.py", ["cmd1"])
            assert generator.state.generated_files == ["new_file.py"]
            assert generator.state.executed_commands == ["cmd1"]
            mock_save.assert_called_once()
|
||||
|
||||
|
||||
# Allow running this test module directly (outside the pytest CLI).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
Loading…
Reference in New Issue