feat: 添加项目初始化文件和README模板以支持代码生成工具

This commit is contained in:
songsenand 2026-03-16 08:59:59 +08:00
commit d431b24755
12 changed files with 1690 additions and 0 deletions

10
.gitignore vendored Normal file
View File

@ -0,0 +1,10 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info
# Virtual environments
.venv

1
.python-version Normal file
View File

@ -0,0 +1 @@
3.13

533
README.md Normal file
View File

@ -0,0 +1,533 @@
# LLM 代码生成工具(自举版)
本项目是一个基于大语言模型的代码生成工具,能够根据项目 `README.md` 描述自动生成完整的 Python 包代码,并具备代码检查、测试和自动修复能力。它是前一个代码生成器的升级版本,采用 `uv` 进行包管理,包含完整的单元测试、并行检查模块,并可通过命令行直接调用。
## 特别说明
我已经实现了一个简易版本,请在此基础上拓展开发:
```python
#!/home/songsenand/env/.venv/bin/python
"""
基于LLM的自动化代码生成工具
根据README.md文件自动生成项目文件结构并填充代码执行必要命令。
"""
import json
import os
import subprocess
import sys
from typing import List, Dict, Optional, Any, Tuple
from pathlib import Path
import typer
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskID
from loguru import logger
from openai import OpenAI
# ==================== 配置 ====================
DANGEROUS_COMMANDS = ["rm", "sudo", "chmod", "dd", "mkfs", "> /dev/sda", "format"]
ALLOWED_COMMANDS = [] # 可设置白名单,为空则只检查黑名单
app = typer.Typer(help="基于LLM的自动化代码生成工具")
console = Console()
# ==================== 工具函数 ====================
def is_dangerous_command(cmd: str) -> Tuple[bool, str]:
"""
判断命令是否危险
返回 (是否危险, 原因)
"""
cmd_lower = cmd.lower()
for danger in DANGEROUS_COMMANDS:
if danger in cmd_lower:
return True, f"包含危险关键词 '{danger}'"
return False, ""
# ==================== 核心类 ====================
class CodeGenerator:
"""代码生成器,封装所有逻辑"""
def __init__(
self,
api_key: Optional[str] = None,
base_url: str = "https://api.deepseek.com",
model: str = "deepseek-reasoner",
output_dir: str = "./generated",
log_file: Optional[str] = None,
):
"""
初始化生成器
Args:
api_key: OpenAI API密钥默认从环境变量DEEPSEEK_APIKEY读取
base_url: API基础URL
model: 使用的模型
output_dir: 输出根目录
log_file: 日志文件路径,默认自动生成
"""
self.api_key = api_key or os.getenv("DEEPSEEK_APIKEY")
if not self.api_key:
raise ValueError("必须提供API密钥或设置环境变量DEEPSEEK_APIKEY")
self.client = OpenAI(api_key=self.api_key, base_url=base_url)
self.model = model
self.output_dir = Path(output_dir)
self.output_dir.mkdir(parents=True, exist_ok=True)
# 配置日志
if log_file is None:
log_file = self.output_dir / "generator.log"
logger.remove() # 移除默认handler
logger.add(sys.stderr, level="WARNING") # 控制台输出INFO及以上
logger.add(log_file, rotation="10 MB", level="DEBUG") # 文件记录DEBUG
logger.info(f"日志已初始化,保存至: {log_file}")
self.readme_content = None
self.progress: Optional[Progress] = None
self.tasks: Dict[str, TaskID] = {} # 任务ID映射
def _call_llm(
self,
system_prompt: str,
user_prompt: str,
temperature: float = 0.2,
expect_json: bool = True,
) -> Dict[str, Any]:
"""
调用LLM并返回解析后的JSON
"""
logger.debug(f"调用LLM模型: {self.model}")
logger.debug(f"System: {system_prompt[:200]}...")
logger.debug(f"User: {user_prompt[:200]}...")
try:
response = self.client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
temperature=temperature,
response_format={"type": "json_object"} if expect_json else None,
)
message = response.choices[0].message
content = message.content
# 记录思考过程(如果存在)
if hasattr(message, "reasoning_content") and message.reasoning_content:
logger.info(f"模型思考过程: {message.reasoning_content}")
logger.debug(f"LLM原始响应: {content[:500]}...")
if expect_json:
result = json.loads(content)
else:
result = {"content": content}
return result
except json.JSONDecodeError as e:
logger.error(f"JSON解析失败: {e}")
raise ValueError(f"LLM返回的不是有效JSON: {content[:200]}")
except Exception as e:
logger.error(f"LLM调用失败: {e}")
raise
def parse_readme(self, readme_path: Path) -> str:
"""
读取README文件内容
"""
logger.info(f"读取README文件: {readme_path}")
try:
with open(readme_path, "r", encoding="utf-8") as f:
content = f.read()
logger.debug(f"README内容长度: {len(content)} 字符")
return content
except Exception as e:
logger.error(f"读取README失败: {e}")
raise
def get_project_structure(self) -> Tuple[List[str], Dict[str, List[str]]]:
"""
根据README内容让LLM生成文件列表和依赖关系
Returns:
(files, dependencies)
files: 按顺序需要生成的文件路径列表
dependencies: 字典 {file: [依赖文件路径]}
"""
system_prompt = (
"你是一个软件架构师。请根据README描述分析需要生成哪些源代码文件并确定它们的生成顺序"
"同时给出每个文件生成时最少需要读取哪些已有文件作为上下文。"
"返回严格的JSON对象包含两个字段\n"
"- files: 数组,按生成顺序排列的文件路径(相对于项目根目录)\n"
"- dependencies: 对象,键为文件路径,值为该文件依赖的已有文件路径列表(可为空)\n"
"注意:依赖文件必须是已存在的参考文件,不要包含待生成的文件。"
)
user_prompt = f"README内容如下\n\n{self.readme_content}"
result = self._call_llm(system_prompt, user_prompt)
files = result.get("files", [])
dependencies = result.get("dependencies", {})
if not files:
raise ValueError("LLM未返回任何文件列表")
logger.info(f"解析到 {len(files)} 个待生成文件")
logger.debug(f"文件列表: {files}")
logger.debug(f"依赖关系: {dependencies}")
return files, dependencies
def generate_file(
self,
file_path: str,
prompt_instruction: str,
dependency_files: List[str],
) -> Tuple[str, str, List[str]]:
"""
生成单个文件,返回 (代码, 描述, 命令列表)
"""
# 读取依赖文件内容
context_content = []
if self.readme_content:
context_content.append(f"### 项目 README ###\n{self.readme_content}\n")
for dep in dependency_files:
dep_path = Path(dep)
if not dep_path.exists():
# 尝试相对于当前目录或输出目录查找
alt_path = self.output_dir / dep
if alt_path.exists():
dep_path = alt_path
else:
raise FileNotFoundError(f"依赖文件不存在: {dep}")
with open(dep_path, "r", encoding="utf-8") as f:
content = f.read()
context_content.append(f"### 文件: {dep_path.name} (路径: {dep}) ###\n{content}\n")
full_context = "\n".join(context_content)
system_prompt = (
"你是一个专业的编程助手。根据用户指令和提供的上下文文件,生成完整的代码。"
"返回严格的JSON对象包含三个字段\n"
"- code: (string) 生成的完整代码\n"
"- description: (string) 简短的中文功能描述\n"
"- commands: (array of string) 生成此文件后需要执行的操作系统命令列表(如编译、安装依赖等),若无则返回空数组"
)
user_prompt = f"{prompt_instruction}\n\n参考文件上下文:\n{full_context}"
result = self._call_llm(system_prompt, user_prompt)
code = result.get("code", "")
description = result.get("description", "")
commands = result.get("commands", [])
if not isinstance(commands, list):
commands = []
return code, description, commands
def execute_command(self, cmd: str, cwd: Optional[Path] = None) -> None:
"""
执行单个命令,检查风险
"""
dangerous, reason = is_dangerous_command(cmd)
if dangerous:
logger.error(f"危险命令被阻止: {cmd},原因: {reason}")
raise RuntimeError(f"危险命令: {cmd} ({reason})")
logger.info(f"执行命令: {cmd}")
try:
result = subprocess.run(
cmd,
shell=True,
cwd=cwd or self.output_dir,
capture_output=True,
text=True,
timeout=300, # 5分钟超时
)
logger.debug(f"命令返回码: {result.returncode}")
if result.stdout:
logger.debug(f"stdout: {result.stdout[:500]}")
if result.stderr:
logger.warning(f"stderr: {result.stderr[:500]}")
if result.returncode != 0:
raise subprocess.CalledProcessError(result.returncode, cmd)
except subprocess.TimeoutExpired:
logger.error(f"命令执行超时: {cmd}")
raise
except Exception as e:
logger.error(f"命令执行失败: {e}")
raise
def run(self, readme_path: Path):
"""
主执行流程
"""
logger.info("=" * 50)
logger.info("开始代码生成流程")
logger.info(f"README: {readme_path}")
logger.info(f"输出目录: {self.output_dir}")
# 初始化阶段用rich输出状态不会被日志级别过滤
console.print("[bold yellow]🔍 正在解析README...[/bold yellow]")
self.readme_content = self.parse_readme(readme_path)
console.print("[bold yellow]📋 正在分析项目结构...[/bold yellow]")
files, dependencies = self.get_project_structure()
console.print(f"[green]✅ 解析完成,共 {len(files)} 个文件待生成[/green]")
# 3. 创建进度条
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
console=console,
) as progress:
self.progress = progress
# 创建总任务
total_task = progress.add_task("[cyan]整体进度...", total=len(files))
# 依次生成每个文件
for idx, file in enumerate(files, 1):
logger.info(f"处理文件 [{idx}/{len(files)}]: {file}")
# 创建子任务(可选)
file_task = progress.add_task(f"生成 {file}", total=None)
try:
# 获取依赖文件
deps = dependencies.get(file, [])
# 构造生成指令
instruction = f"请根据README描述和依赖文件生成文件 '{file}' 的完整代码。"
# 调用LLM生成代码
code, desc, commands = self.generate_file(file, instruction, deps)
logger.info(f"生成完成: {file} - {desc}")
# 写入文件
output_path = self.output_dir / file
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f:
f.write(code)
logger.info(f"已写入: {output_path}")
# 执行命令
for cmd in commands:
logger.info(f"准备执行命令: {cmd}")
self.execute_command(cmd, cwd=self.output_dir)
except Exception as e:
logger.error(f"处理文件 {file} 失败: {e}")
# 可选:继续或终止
raise
finally:
progress.remove_task(file_task)
progress.update(total_task, advance=1)
logger.success("所有文件处理完成!")
# ==================== CLI入口 ====================
@app.command()
def main(
readme: Path = typer.Argument(..., exists=True, file_okay=True, dir_okay=False, help="README.md文件路径"),
output_dir: Optional[Path] = typer.Option(None, "--output", "-o", help="输出根目录默认为readme所在目录"),
api_key: Optional[str] = typer.Option(None, "--api-key", envvar="DEEPSEEK_APIKEY", help="API密钥也可通过环境变量DEEPSEEK_APIKEY设置"),
base_url: str = typer.Option("https://api.deepseek.com", "--base-url", help="API基础URL"),
model: str = typer.Option("deepseek-reasoner", "--model", "-m", help="使用的模型"),
log_file: Optional[str] = typer.Option(None, "--log", help="日志文件路径默认输出目录下generator.log"),
):
"""
根据README自动生成项目代码
"""
if output_dir is None:
output_dir = readme.parent
try:
generator = CodeGenerator(
api_key=api_key,
base_url=base_url,
model=model,
output_dir=output_dir,
log_file=log_file,
)
generator.run(readme)
except Exception as e:
logger.error(f"程序异常退出: {e}")
raise typer.Exit(code=1)
if __name__ == "__main__":
app()
```
## 功能特性
- 📦 **自动生成**:解析 `README.md`,分析需要生成的文件列表及依赖关系,按顺序生成每个文件的代码。
- 🔧 **命令执行**:生成文件后可自动执行建议命令(如安装依赖、运行构建),内置危险命令拦截。
- ✅ **单元测试**:使用 `pytest` 编写测试用例,支持测试覆盖率统计。
- 🔍 **并行检查**:生成代码后并行运行多个检查工具(如 `pylint`、`mypy`、`black`),收集错误信息。
- 🔄 **自修复**:将检查错误、`README` 和相关代码作为上下文提交给 LLM自动生成修复补丁并应用。
- ⏯️ **断点续写**如果生成过程意外中断如网络问题、API 限制),重新运行时会从上次中断处继续,已生成的文件和已执行的命令不会重复执行,状态自动保存在输出目录下的 `.llm_generator_state.json` 文件中。
- 🖥️ **命令行工具**:提供 `llm-codegen` 命令,参数兼容原脚本(`--output`、`--api-key`、`--model` 等)。
- 📝 **详细日志**所有操作、LLM 响应、错误均通过 `loguru` 记录到文件。
- 🎨 **美观输出**:使用 `rich` 显示进度条和彩色状态。
## 安装
### 依赖
- Python 3.9+
- 使用 `uv` 管理包
```bash
# 使用 uv
uv add [dev]
```
### 配置 API 密钥
设置环境变量(推荐):
```bash
export DEEPSEEK_APIKEY="your-api-key"
```
或在命令行中通过 `--api-key` 传入。
## 使用方法
```bash
llm-codegen [OPTIONS] README
```
### 参数
| 参数 | 说明 |
|------|------|
| `README` | `README.md` 文件路径(必须) |
| `--output, -o` | 输出根目录默认README 所在目录) |
| `--api-key` | API 密钥(默认:环境变量 `DEEPSEEK_APIKEY` |
| `--base-url` | API 基础 URL默认`https://api.deepseek.com` |
| `--model, -m` | 使用的模型(默认:`deepseek-reasoner` |
| `--log` | 日志文件路径(默认:输出目录下 `generator.log` |
| `--resume/--no-resume` | 是否启用断点续写(默认:`--resume`,即自动从上次中断处继续) |
| `--no-check` | 跳过生成后的检查和修复 |
| `--help` | 显示帮助信息 |
### 示例
```bash
llm-codegen my_project/README.md -o ./generated
```
如果中途中断,只需再次运行相同的命令,工具会自动检测状态文件并从上次中断处继续生成。
## 项目结构
生成的项目将包含以下文件和目录:
```
.
├── README.md # 项目说明(原始输入)
├── pyproject.toml # 项目元数据、依赖、脚本入口
├── src/
│ └── llm_codegen/ # 主代码包
│ ├── __init__.py
│ ├── cli.py # 命令行入口typer
│ ├── core.py # 核心生成逻辑CodeGenerator 类)
│ ├── checker.py # 并行检查与修复模块
│ ├── utils.py # 工具函数(危险命令判断、文件操作)
│ └── models.py # 数据模型Pydantic
├── tests/ # 单元测试
│ ├── __init__.py
│ ├── test_cli.py
│ ├── test_core.py
│ └── test_checker.py
└── logs/ # 运行日志(自动创建)
```
## 核心流程
1. **解析阶段**:读取 `README.md`,调用 LLM 获取 `files`(按生成顺序的文件路径列表)和 `dependencies`(每个文件依赖的已有文件列表)。
2. **生成阶段**:按顺序生成每个文件,使用 `README` 和依赖文件作为上下文,同时获取 LLM 建议的命令。每成功生成一个文件并执行命令后,状态会自动保存到 `.llm_generator_state.json`
3. **命令执行**:对每个建议命令进行危险检查,低风险则执行。已执行的命令记录在状态文件中,避免重复执行。
4. **检查阶段**(可选):生成完成后,并行运行配置的检查工具(如 `pytest`、`pylint`、`mypy`),收集错误。
5. **修复阶段**(可选):若检查失败,将错误信息、`README` 和相关文件内容提交给 LLM请求生成修复方案并自动应用修改。重复直到检查通过或达到重试次数上限。
## 断点续写机制
- 状态文件保存在输出目录下的 `.llm_generator_state.json`,记录已成功生成的文件列表和已执行的命令。
- 重新运行工具时(默认启用 `--resume`),会自动读取状态文件,跳过已完成的部分,从下一个文件开始继续。
- 如果 `README` 发生重大变更导致文件列表不一致,工具会检测并提示用户重新开始(可通过 `--no-resume` 强制从头生成)。
- 状态文件在全部流程成功完成后可手动删除,工具不会自动删除,以便后续查看或用于调试。
## 开发指南
### 环境设置
```bash
# 安装 uv若未安装
curl -LsSf https://astral.sh/uv/install.sh | sh
# 创建虚拟环境并激活
uv venv
source .venv/bin/activate # Linux/macOS
# 或 .venv\Scripts\activate # Windows
# 安装项目(可编辑模式)和开发依赖
uv pip install -e ".[dev]"
```
### 运行测试
```bash
pytest tests/ --cov=src/llm_codegen
```
### 代码检查
```bash
# 运行所有检查
pre-commit run --all-files
# 或手动运行
pylint src/llm_codegen
mypy src/llm_codegen
black --check src/llm_codegen
```
### 添加新功能
1. 在 `src/llm_codegen/` 下添加或修改模块。
2. 在 `tests/` 中添加对应的单元测试。
3. 更新 `README.md` 和命令行帮助信息。
## 配置
通过 `pyproject.toml``[tool.llm-codegen]` 部分可以自定义检查工具和修复行为:
```toml
[tool.llm-codegen]
check_tools = ["pytest", "pylint", "mypy", "black"]
max_retries = 3
dangerous_commands = ["rm", "sudo", "chmod", "dd"]
```

49
pyproject.toml Normal file
View File

@ -0,0 +1,49 @@
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "llm-codegen"
version = "0.1.0"
description = "基于大语言模型的自动化代码生成工具根据README.md描述自动生成完整的Python包代码具备代码检查、测试和自动修复能力。"
authors = [
{name = "Your Name", email = "your.email@example.com"}
]
readme = "README.md"
license = {text = "MIT"}
requires-python = ">=3.9"
dependencies = [
"typer>=0.9.0",
"rich>=13.0.0",
"loguru>=0.7.0",
"openai>=1.0.0",
]
[project.optional-dependencies]
dev = [
"pytest>=7.0.0",
"pytest-cov>=4.0.0",
"pylint>=3.0.0",
"mypy>=1.0.0",
"black>=23.0.0",
"pre-commit>=3.0.0",
]
[project.urls]
Homepage = "https://github.com/yourusername/llm-codegen"
[project.scripts]
llm-codegen = "llm_codegen.cli:app"
[tool.llm-codegen]
check_tools = ["pytest", "pylint", "mypy", "black"]
max_retries = 3
dangerous_commands = ["rm", "sudo", "chmod", "dd"]
[tool.black]
line-length = 88
target-version = ['py39']
[tool.pytest.ini_options]
testpaths = ["tests"]
addopts = "--cov=src/llm_codegen --cov-report=term-missing"

View File

@ -0,0 +1,21 @@
"""
LLM Code Generator package.
This package provides an automated code generation tool based on large language models (LLMs).
It can generate complete Python package code from README descriptions, with features like code checking, testing, and auto-fixing.
"""
__version__ = "0.1.0"
__author__ = "LLM CodeGen Team"
__description__ = "A self-bootstrapping LLM-based code generation tool"
# Export main components for easy access from the package
from .core import CodeGenerator
from .cli import app
from .utils import is_dangerous_command
__all__ = [
"CodeGenerator",
"app",
"is_dangerous_command",
]

242
src/llm_codegen/checker.py Normal file
View File

@ -0,0 +1,242 @@
"""
checker.py - 并行检查与修复模块
负责在代码生成后运行配置的检查工具如pylintmypyblack并收集错误
然后使用LLM自动生成和应用修复补丁
"""
import json
import os
import subprocess
import sys
from pathlib import Path
from typing import List, Dict, Optional, Any, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
from loguru import logger
from openai import OpenAI
from .models import ConfigModel # 从models.py导入配置模型
from .utils import safe_read_file, safe_write_file # 工具函数
class Checker:
"""
检查与修复器类提供并行运行检查工具和自动修复功能
"""
def __init__(
self,
output_dir: Path,
config: ConfigModel,
api_key: Optional[str] = None,
base_url: str = "https://api.deepseek.com",
model: str = "deepseek-reasoner",
):
"""
初始化检查器
Args:
output_dir: 输出目录包含生成的代码
config: 配置模型包含check_toolsmax_retries等
api_key: LLM API密钥如果None则从环境变量DEEPSEEK_APIKEY获取
base_url: LLM API基础URL
model: LLM模型
"""
self.output_dir = output_dir
self.config = config
self.api_key = api_key or os.getenv("DEEPSEEK_APIKEY")
if not self.api_key:
raise ValueError("API密钥未提供请设置环境变量DEEPSEEK_APIKEY或传入api_key")
self.client = OpenAI(api_key=self.api_key, base_url=base_url)
self.model = model
self.max_retries = config.max_retries
def run_check_tool(self, tool: str, file_path: Path) -> Tuple[bool, str]:
"""
运行单个检查工具并返回结果
Args:
tool: 工具名称"pylint"
file_path: 要检查的文件路径
Returns:
(success, output): 成功为True错误输出字符串
"""
commands = {
"pylint": f"pylint {file_path}",
"mypy": f"mypy {file_path}",
"black": f"black --check {file_path}",
"pytest": f"pytest {file_path}", # 假设检查测试文件
}
if tool not in commands:
logger.warning(f"未知检查工具: {tool}")
return True, "" # 跳过未知工具
cmd = commands[tool]
try:
result = subprocess.run(
cmd,
shell=True,
cwd=self.output_dir,
capture_output=True,
text=True,
timeout=60, # 1分钟超时
)
if result.returncode == 0:
return True, ""
else:
output = result.stdout + result.stderr
return False, output
except subprocess.TimeoutExpired:
logger.error(f"检查工具 {tool} 超时")
return False, "超时"
except Exception as e:
logger.error(f"运行检查工具 {tool} 失败: {e}")
return False, str(e)
def run_parallel_checks(self, files: List[Path]) -> Dict[str, List[Tuple[str, str]]]:
"""
并行运行所有配置的检查工具
Args:
files: 要检查的文件路径列表
Returns:
错误字典键为文件路径值为列表每个元素为(工具名, 错误输出)
"""
errors = {}
check_tools = self.config.check_tools
with ThreadPoolExecutor() as executor:
futures = []
for file in files:
for tool in check_tools:
future = executor.submit(self.run_check_tool, tool, file)
futures.append((future, file, tool))
for future, file, tool in futures:
success, output = future.result()
if not success:
if file not in errors:
errors[file] = []
errors[file].append((tool, output))
return errors
def call_llm_for_fix(self, file_path: Path, errors: List[Tuple[str, str]], readme_content: str) -> Optional[str]:
"""
调用LLM生成修复补丁
Args:
file_path: 需要修复的文件
errors: 错误列表每个元素为(工具名, 错误输出)
readme_content: README内容
Returns:
修复后的代码字符串如果失败返回None
"""
error_summary = "\n".join([f"{tool}: {err}" for tool, err in errors])
file_content = safe_read_file(file_path)
system_prompt = (
"你是一个专业的代码修复助手。给定代码、错误信息和项目README请生成修复后的完整代码。"
"返回严格的JSON对象包含字段\n"
"- code: (string) 修复后的完整代码\n"
"- description: (string) 修复描述\n"
)
user_prompt = (
f"项目README:\n{readme_content}\n\n"
f"文件内容:\n{file_content}\n\n"
f"错误信息:\n{error_summary}\n\n"
"请生成修复后的代码,确保所有检查通过。"
)
try:
response = self.client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
temperature=0.2,
response_format={"type": "json_object"},
)
result = json.loads(response.choices[0].message.content)
return result.get("code")
except Exception as e:
logger.error(f"调用LLM修复失败: {e}")
return None
def apply_fix(self, file_path: Path, new_code: str) -> bool:
"""
应用修复代码
Args:
file_path: 文件路径
new_code: 新代码
Returns:
是否成功应用
"""
try:
safe_write_file(file_path, new_code)
logger.info(f"已应用修复到 {file_path}")
return True
except Exception as e:
logger.error(f"应用修复失败: {e}")
return False
def run_checks_and_fixes(self, readme_content: str, files: Optional[List[Path]] = None) -> bool:
"""
主方法运行检查并自动修复
Args:
readme_content: README内容
files: 要检查的文件列表如果None则检查output_dir下所有Python文件
Returns:
是否所有检查最终通过
"""
if files is None:
# 递归查找所有Python文件
files = list(self.output_dir.rglob("*.py"))
for attempt in range(self.max_retries):
logger.info(f"检查尝试 {attempt + 1}/{self.max_retries}")
errors = self.run_parallel_checks(files)
if not errors:
logger.success("所有检查通过!")
return True
# 有错误,尝试修复
logger.warning(f"发现 {len(errors)} 个文件有错误,尝试修复")
all_fixed = True
for file_path, error_list in errors.items():
new_code = self.call_llm_for_fix(file_path, error_list, readme_content)
if new_code:
if self.apply_fix(file_path, new_code):
# 修复后重新检查这个文件
success, _ = self.run_check_tool("pylint", file_path) # 简化检查,重新运行一个工具
if not success:
all_fixed = False
else:
all_fixed = False
else:
all_fixed = False
if all_fixed:
logger.info("修复成功,重新检查...")
continue # 重新检查所有文件
else:
logger.error("修复失败或仍有错误")
break
# 最终检查
final_errors = self.run_parallel_checks(files)
if not final_errors:
logger.success("检查最终通过")
return True
else:
logger.error(f"检查失败,剩余错误: {final_errors}")
return False

49
src/llm_codegen/cli.py Normal file
View File

@ -0,0 +1,49 @@
import typer
from pathlib import Path
from typing import Optional
from loguru import logger
from rich.console import Console
from .core import CodeGenerator
app = typer.Typer(help="基于LLM的自动化代码生成工具")
console = Console()
@app.command()
def main(
readme: Path = typer.Argument(..., exists=True, file_okay=True, dir_okay=False, help="README.md文件路径"),
output_dir: Optional[Path] = typer.Option(None, "--output", "-o", help="输出根目录默认为readme所在目录"),
api_key: Optional[str] = typer.Option(None, "--api-key", envvar="DEEPSEEK_APIKEY", help="API密钥也可通过环境变量DEEPSEEK_APIKEY设置"),
base_url: str = typer.Option("https://api.deepseek.com", "--base-url", help="API基础URL"),
model: str = typer.Option("deepseek-reasoner", "--model", "-m", help="使用的模型"),
log_file: Optional[str] = typer.Option(None, "--log", help="日志文件路径默认输出目录下generator.log"),
resume: bool = typer.Option(True, "--resume/--no-resume", help="是否启用断点续写(默认启用)"),
no_check: bool = typer.Option(False, "--no-check", help="跳过生成后的检查和修复"),
):
"""
根据README自动生成项目代码支持断点续写和可选检查
"""
if output_dir is None:
output_dir = readme.parent
try:
generator = CodeGenerator(
api_key=api_key,
base_url=base_url,
model=model,
output_dir=output_dir,
log_file=log_file,
resume=resume,
config_path=None, # 配置文件路径可从pyproject.toml加载但CLI中暂不提供参数
)
generator.run(readme)
# 如果未跳过检查,提示用户检查功能暂未实现
if not no_check:
console.print("[yellow]注意检查和修复功能暂未在此版本中实现请手动运行检查工具如pytest、pylint。[/yellow]")
except Exception as e:
logger.error(f"程序异常退出: {e}")
raise typer.Exit(code=1)
if __name__ == "__main__":
app()

365
src/llm_codegen/core.py Normal file
View File

@ -0,0 +1,365 @@
import json
import os
import subprocess
import sys
from pathlib import Path
from typing import List, Dict, Optional, Any, Tuple
from loguru import logger
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskID
from rich.console import Console
from openai import OpenAI
# 导入本地模块
from .utils import is_dangerous_command, safe_read_file, safe_write_file, load_state, save_state, normalize_path, load_dangerous_commands
from .models import GeneratorState, FileInfo, ProjectStructure, ConfigModel
class CodeGenerator:
"""
核心代码生成器类负责解析README生成代码执行命令并支持断点续写
"""
def __init__(
self,
api_key: Optional[str] = None,
base_url: str = "https://api.deepseek.com",
model: str = "deepseek-reasoner",
output_dir: str = "./generated",
log_file: Optional[str] = None,
resume: bool = True,
config_path: Optional[Path] = None,
):
"""
初始化生成器
Args:
api_key: OpenAI API密钥默认从环境变量DEEPSEEK_APIKEY读取
base_url: API基础URL
model: 使用的模型
output_dir: 输出根目录
log_file: 日志文件路径默认自动生成
resume: 是否启用断点续写默认为True
config_path: 配置文件路径用于加载危险命令等配置
"""
self.api_key = api_key or os.getenv("DEEPSEEK_APIKEY")
if not self.api_key:
raise ValueError("必须提供API密钥或设置环境变量DEEPSEEK_APIKEY")
self.client = OpenAI(api_key=self.api_key, base_url=base_url)
self.model = model
self.output_dir = Path(output_dir)
self.output_dir.mkdir(parents=True, exist_ok=True)
self.resume = resume
self.config_path = config_path
# 加载配置
self.config = self._load_config()
self.dangerous_commands = self.config.dangerous_commands
# 配置日志
if log_file is None:
log_file = self.output_dir / "generator.log"
logger.remove() # 移除默认handler
logger.add(sys.stderr, level="WARNING") # 控制台输出INFO及以上
logger.add(log_file, rotation="10 MB", level="DEBUG") # 文件记录DEBUG
logger.info(f"日志已初始化,保存至: {log_file}")
self.readme_content = None
self.state_path = self.output_dir / ".llm_generator_state.json"
self.state: Optional[GeneratorState] = None
self.progress: Optional[Progress] = None
self.tasks: Dict[str, TaskID] = {} # 任务ID映射
self.console = Console()
def _load_config(self) -> ConfigModel:
"""
加载配置如果配置文件不存在则使用默认值
"""
try:
# 简化实现从环境或固定路径加载实际应从pyproject.toml解析
dangerous_cmds = load_dangerous_commands(self.config_path)
return ConfigModel(
check_tools=["pytest", "pylint", "mypy", "black"],
max_retries=3,
dangerous_commands=dangerous_cmds,
)
except Exception as e:
logger.warning(f"加载配置失败,使用默认值: {e}")
return ConfigModel()
def _call_llm(
self,
system_prompt: str,
user_prompt: str,
temperature: float = 0.2,
expect_json: bool = True,
) -> Dict[str, Any]:
"""
调用LLM并返回解析后的JSON
"""
logger.debug(f"调用LLM模型: {self.model}")
logger.debug(f"System: {system_prompt[:200]}...")
logger.debug(f"User: {user_prompt[:200]}...")
try:
response = self.client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
temperature=temperature,
response_format={"type": "json_object"} if expect_json else None,
)
message = response.choices[0].message
content = message.content
# 记录思考过程(如果存在)
if hasattr(message, "reasoning_content") and message.reasoning_content:
logger.info(f"模型思考过程: {message.reasoning_content}")
logger.debug(f"LLM原始响应: {content[:500]}...")
if expect_json:
result = json.loads(content)
else:
result = {"content": content}
return result
except json.JSONDecodeError as e:
logger.error(f"JSON解析失败: {e}")
raise ValueError(f"LLM返回的不是有效JSON: {content[:200]}")
except Exception as e:
logger.error(f"LLM调用失败: {e}")
raise
def parse_readme(self, readme_path: Path) -> str:
"""
读取README文件内容并计算哈希值用于断点续写检测
"""
logger.info(f"读取README文件: {readme_path}")
try:
content = safe_read_file(readme_path)
logger.debug(f"README内容长度: {len(content)} 字符")
return content
except Exception as e:
logger.error(f"读取README失败: {e}")
raise
def get_project_structure(self) -> Tuple[List[str], Dict[str, List[str]]]:
"""
根据README内容让LLM生成文件列表和依赖关系
Returns:
(files, dependencies)
files: 按顺序需要生成的文件路径列表
dependencies: 字典 {file: [依赖文件路径]}
"""
system_prompt = (
"你是一个软件架构师。请根据README描述分析需要生成哪些源代码文件并确定它们的生成顺序"
"同时给出每个文件生成时最少需要读取哪些已有文件作为上下文。"
"返回严格的JSON对象包含两个字段\n"
"- files: 数组,按生成顺序排列的文件路径(相对于项目根目录)\n"
"- dependencies: 对象,键为文件路径,值为该文件依赖的已有文件路径列表(可为空)\n"
"注意:依赖文件必须是已存在的参考文件,不要包含待生成的文件。"
)
user_prompt = f"README内容如下\n\n{self.readme_content}"
result = self._call_llm(system_prompt, user_prompt)
files = result.get("files", [])
dependencies = result.get("dependencies", {})
if not files:
raise ValueError("LLM未返回任何文件列表")
logger.info(f"解析到 {len(files)} 个待生成文件")
logger.debug(f"文件列表: {files}")
logger.debug(f"依赖关系: {dependencies}")
return files, dependencies
def generate_file(
self,
file_path: str,
prompt_instruction: str,
dependency_files: List[str],
) -> Tuple[str, str, List[str]]:
"""
生成单个文件返回 (代码, 描述, 命令列表)
"""
# 读取依赖文件内容
context_content = []
if self.readme_content:
context_content.append(f"### 项目 README ###\n{self.readme_content}\n")
for dep in dependency_files:
dep_path = Path(dep)
if not dep_path.exists():
# 尝试相对于输出目录查找
alt_path = self.output_dir / dep
if alt_path.exists():
dep_path = alt_path
else:
raise FileNotFoundError(f"依赖文件不存在: {dep}")
content = safe_read_file(dep_path)
context_content.append(f"### 文件: {dep_path.name} (路径: {dep}) ###\n{content}\n")
full_context = "\n".join(context_content)
system_prompt = (
"你是一个专业的编程助手。根据用户指令和提供的上下文文件,生成完整的代码。"
"返回严格的JSON对象包含三个字段\n"
"- code: (string) 生成的完整代码\n"
"- description: (string) 简短的中文功能描述\n"
"- commands: (array of string) 生成此文件后需要执行的操作系统命令列表(如编译、安装依赖等),若无则返回空数组"
)
user_prompt = f"{prompt_instruction}\n\n参考文件上下文:\n{full_context}"
result = self._call_llm(system_prompt, user_prompt)
code = result.get("code", "")
description = result.get("description", "")
commands = result.get("commands", [])
if not isinstance(commands, list):
commands = []
return code, description, commands
def execute_command(self, cmd: str, cwd: Optional[Path] = None) -> None:
"""
执行单个命令检查风险
"""
dangerous, reason = is_dangerous_command(cmd, self.dangerous_commands)
if dangerous:
logger.error(f"危险命令被阻止: {cmd},原因: {reason}")
raise RuntimeError(f"危险命令: {cmd} ({reason})")
logger.info(f"执行命令: {cmd}")
try:
result = subprocess.run(
cmd,
shell=True,
cwd=cwd or self.output_dir,
capture_output=True,
text=True,
timeout=300, # 5分钟超时
)
logger.debug(f"命令返回码: {result.returncode}")
if result.stdout:
logger.debug(f"stdout: {result.stdout[:500]}")
if result.stderr:
logger.warning(f"stderr: {result.stderr[:500]}")
if result.returncode != 0:
raise subprocess.CalledProcessError(result.returncode, cmd)
except subprocess.TimeoutExpired:
logger.error(f"命令执行超时: {cmd}")
raise
except Exception as e:
logger.error(f"命令执行失败: {e}")
raise
def _update_state(self, generated_file: str, executed_commands: List[str]) -> None:
"""
更新断点续写状态
"""
if self.state is None:
self.state = GeneratorState()
self.state.generated_files.append(generated_file)
self.state.executed_commands.extend(executed_commands)
self.state.updated_at = datetime.now()
save_state(self.state_path, self.state.model_dump())
def run(self, readme_path: Path) -> None:
"""
主执行流程支持断点续写
"""
logger.info("=" * 50)
logger.info("开始代码生成流程")
logger.info(f"README: {readme_path}")
logger.info(f"输出目录: {self.output_dir}")
logger.info(f"断点续写: {self.resume}")
# 初始化阶段
self.console.print("[bold yellow]🔍 正在解析README...[/bold yellow]")
self.readme_content = self.parse_readme(readme_path)
# 加载或初始化状态
if self.resume and self.state_path.exists():
raw_state = load_state(self.state_path)
self.state = GeneratorState(**raw_state) if raw_state else GeneratorState()
logger.info(f"加载状态文件: {self.state_path}")
# 检查README是否变更
if self.state.readme_hash and self.state.readme_hash != hash(self.readme_content):
logger.warning("README内容已变更建议使用 --no-resume 重新开始")
else:
self.state = GeneratorState()
self.console.print("[bold yellow]📋 正在分析项目结构...[/bold yellow]")
files, dependencies = self.get_project_structure()
# 过滤已生成的文件
if self.resume and self.state:
pending_files = [f for f in files if f not in self.state.generated_files]
logger.info(f"跳过了 {len(files) - len(pending_files)} 个已生成文件,剩余 {len(pending_files)}")
files = pending_files
if not files:
logger.success("所有文件已生成,无需继续")
return
self.console.print(f"[green]✅ 解析完成,共 {len(files)} 个文件待生成[/green]")
# 创建进度条
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
console=self.console,
) as progress:
self.progress = progress
total_task = progress.add_task("[cyan]整体进度...", total=len(files))
for idx, file in enumerate(files, 1):
logger.info(f"处理文件 [{idx}/{len(files)}]: {file}")
file_task = progress.add_task(f"生成 {file}", total=None)
try:
deps = dependencies.get(file, [])
instruction = f"请根据README描述和依赖文件生成文件 '{file}' 的完整代码。"
code, desc, commands = self.generate_file(file, instruction, deps)
logger.info(f"生成完成: {file} - {desc}")
# 写入文件
output_path = self.output_dir / file
safe_write_file(output_path, code)
logger.info(f"已写入: {output_path}")
# 执行命令,跳过已执行的
executed_in_this_file = []
for cmd in commands:
if self.resume and self.state and cmd in self.state.executed_commands:
logger.info(f"跳过已执行命令: {cmd}")
continue
logger.info(f"准备执行命令: {cmd}")
self.execute_command(cmd, cwd=self.output_dir)
executed_in_this_file.append(cmd)
# 更新状态
self._update_state(file, executed_in_this_file)
except Exception as e:
logger.error(f"处理文件 {file} 失败: {e}")
raise
finally:
progress.remove_task(file_task)
progress.update(total_task, advance=1)
logger.success("所有文件处理完成!")

100
src/llm_codegen/models.py Normal file
View File

@ -0,0 +1,100 @@
#!/usr/bin/env python3
"""
数据模型定义模块使用 Pydantic 进行数据验证和序列化
"""
from typing import List, Dict, Optional
from datetime import datetime
from pydantic import BaseModel, Field
class GeneratorState(BaseModel):
"""
断点续写状态模型用于保存和加载 .llm_generator_state.json 文件
"""
generated_files: List[str] = Field(
default_factory=list,
description="已成功生成的文件路径列表"
)
executed_commands: List[str] = Field(
default_factory=list,
description="已执行的操作系统命令列表"
)
readme_hash: Optional[str] = Field(
default=None,
description="README 内容的哈希值,用于检测变更"
)
created_at: datetime = Field(
default_factory=datetime.now,
description="状态文件创建时间"
)
updated_at: datetime = Field(
default_factory=datetime.now,
description="状态文件最后更新时间"
)
class Config:
json_encoders = {
datetime: lambda v: v.isoformat()
}
class FileInfo(BaseModel):
"""
生成文件的信息模型
"""
path: str = Field(..., description="文件路径")
code: Optional[str] = Field(default=None, description="生成的代码内容")
description: Optional[str] = Field(default=None, description="文件功能描述")
commands: List[str] = Field(
default_factory=list,
description="生成后需要执行的命令列表"
)
class ProjectStructure(BaseModel):
"""
项目结构模型包括文件列表和依赖关系
"""
files: List[str] = Field(
...,
description="按生成顺序排列的文件路径列表"
)
dependencies: Dict[str, List[str]] = Field(
default_factory=dict,
description="文件依赖关系,键为文件路径,值为依赖文件列表"
)
class CheckResult(BaseModel):
"""
检查工具的结果模型
"""
tool: str = Field(..., description="检查工具名称,如 'pylint'")
passed: bool = Field(..., description="检查是否通过")
errors: List[str] = Field(
default_factory=list,
description="错误信息列表"
)
warnings: List[str] = Field(
default_factory=list,
description="警告信息列表"
)
class ConfigModel(BaseModel):
"""
配置模型对应 pyproject.toml 中的 [tool.llm-codegen] 部分
"""
check_tools: List[str] = Field(
default=["pytest", "pylint", "mypy", "black"],
description="要运行的检查工具列表"
)
max_retries: int = Field(
default=3,
description="修复的最大重试次数"
)
dangerous_commands: List[str] = Field(
default=["rm", "sudo", "chmod", "dd"],
description="危险命令列表"
)

130
src/llm_codegen/utils.py Normal file
View File

@ -0,0 +1,130 @@
"""
utils.py - 工具函数模块
包含危险命令判断文件操作状态管理等通用函数
"""
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
def is_dangerous_command(cmd: str, dangerous_commands: Optional[List[str]] = None) -> Tuple[bool, str]:
"""
判断命令是否危险
Args:
cmd: 要检查的命令字符串
dangerous_commands: 危险命令关键词列表如果为None则使用默认列表
Returns:
Tuple[bool, str]: (是否危险, 原因)
"""
if dangerous_commands is None:
# 默认危险命令列表,可以从配置读取
dangerous_commands = ["rm", "sudo", "chmod", "dd", "mkfs", "> /dev/sda", "format"]
cmd_lower = cmd.lower()
for danger in dangerous_commands:
if danger in cmd_lower:
return True, f"包含危险关键词 '{danger}'"
return False, ""
def load_dangerous_commands(config_path: Optional[Path] = None) -> List[str]:
"""
从配置文件加载危险命令列表简化实现实际应从pyproject.toml读取
Args:
config_path: 配置文件路径默认为None表示使用默认列表
Returns:
List[str]: 危险命令关键词列表
"""
# 在实际实现中应使用tomli库解析配置这里返回默认列表
return ["rm", "sudo", "chmod", "dd", "mkfs", "> /dev/sda", "format"]
def safe_read_file(file_path: Path, encoding: str = "utf-8") -> str:
"""
安全读取文件内容
Args:
file_path: 文件路径
encoding: 文件编码默认为utf-8
Returns:
str: 文件内容
Raises:
FileNotFoundError: 如果文件不存在
IOError: 如果读取失败
"""
try:
with open(file_path, 'r', encoding=encoding) as f:
return f.read()
except FileNotFoundError:
raise FileNotFoundError(f"文件不存在: {file_path}")
except Exception as e:
raise IOError(f"读取文件失败 {file_path}: {e}")
def safe_write_file(file_path: Path, content: str, encoding: str = "utf-8") -> None:
"""
安全写入文件内容确保目录存在
Args:
file_path: 文件路径
content: 要写入的内容
encoding: 文件编码默认为utf-8
"""
file_path.parent.mkdir(parents=True, exist_ok=True)
with open(file_path, 'w', encoding=encoding) as f:
f.write(content)
def load_state(state_path: Path) -> Dict[str, Any]:
"""
加载断点续写状态文件
Args:
state_path: 状态文件路径.llm_generator_state.json
Returns:
Dict[str, Any]: 状态数据如果文件不存在或解析失败则返回空字典
"""
if not state_path.exists():
return {}
try:
with open(state_path, 'r', encoding='utf-8') as f:
return json.load(f)
except (json.JSONDecodeError, IOError):
return {}
def save_state(state_path: Path, state: Dict[str, Any]) -> None:
"""
保存断点续写状态文件
Args:
state_path: 状态文件路径
state: 状态数据
"""
with open(state_path, 'w', encoding='utf-8') as f:
json.dump(state, f, indent=2, ensure_ascii=False)
def normalize_path(path: str, base_dir: Optional[Path] = None) -> Path:
"""
规范化路径相对于基础目录
Args:
path: 路径字符串
base_dir: 基础目录默认为None使用当前工作目录
Returns:
Path: 规范化后的Path对象
"""
if base_dir is None:
base_dir = Path.cwd()
return (base_dir / path).resolve()

6
tests/__init__.py Normal file
View File

@ -0,0 +1,6 @@
"""
Initialization file for the tests package of the LLM code generation tool.
This file marks the 'tests' directory as a Python package, enabling proper
imports and test discovery with pytest.
"""

184
tests/test_core.py Normal file
View File

@ -0,0 +1,184 @@
import pytest
from unittest.mock import Mock, patch, MagicMock
from pathlib import Path
import json
import os
import sys
from datetime import datetime
from src.llm_codegen.core import CodeGenerator
from src.llm_codegen.models import GeneratorState, ConfigModel
from src.llm_codegen.utils import is_dangerous_command
class TestCodeGenerator:
"""测试 CodeGenerator 核心类的单元测试。"""
@pytest.fixture
def mock_openai_client(self):
"""模拟 OpenAI 客户端。"""
with patch('src.llm_codegen.core.OpenAI') as mock:
client = Mock()
mock.return_value = client
yield client
@pytest.fixture
def generator(self, mock_openai_client, tmp_path):
"""创建 CodeGenerator 实例,使用临时目录和模拟 API。"""
output_dir = tmp_path / "output"
return CodeGenerator(
api_key="test-api-key",
base_url="https://api.deepseek.com",
model="deepseek-reasoner",
output_dir=str(output_dir),
resume=False,
)
def test_init(self, generator, tmp_path):
"""测试初始化。"""
assert generator.api_key == "test-api-key"
assert generator.model == "deepseek-reasoner"
assert generator.output_dir == tmp_path / "output"
assert generator.resume is False
assert isinstance(generator.config, ConfigModel)
assert generator.dangerous_commands == ["rm", "sudo", "chmod", "dd", "mkfs", "> /dev/sda", "format"]
assert generator.state is None
def test_parse_readme(self, generator, tmp_path):
"""测试读取 README 文件。"""
readme_path = tmp_path / "README.md"
readme_content = "# Test Project\nThis is a test README."
readme_path.write_text(readme_content)
result = generator.parse_readme(readme_path)
assert result == readme_content
def test_get_project_structure(self, generator, mock_openai_client):
"""测试获取项目结构,模拟 LLM 响应。"""
generator.readme_content = "# Test README"
mock_response = {
"files": ["src/__init__.py", "src/core.py"],
"dependencies": {"src/core.py": ["src/__init__.py"]}
}
mock_openai_client.chat.completions.create.return_value.choices[0].message.content = json.dumps(mock_response)
files, dependencies = generator.get_project_structure()
assert files == ["src/__init__.py", "src/core.py"]
assert dependencies == {"src/core.py": ["src/__init__.py"]}
mock_openai_client.chat.completions.create.assert_called_once()
def test_generate_file(self, generator, mock_openai_client, tmp_path):
"""测试生成单个文件,模拟依赖文件和 LLM 响应。"""
generator.readme_content = "# Test README"
dep_file = tmp_path / "dep.txt"
dep_file.write_text("Dependency content")
mock_response = {
"code": "print('Hello, World!')",
"description": "测试文件生成",
"commands": ["echo 'test'"]
}
mock_openai_client.chat.completions.create.return_value.choices[0].message.content = json.dumps(mock_response)
code, desc, commands = generator.generate_file(
"test.py",
"生成测试文件",
[str(dep_file)]
)
assert code == "print('Hello, World!')"
assert desc == "测试文件生成"
assert commands == ["echo 'test'"]
mock_openai_client.chat.completions.create.assert_called_once()
def test_execute_command_safe(self, generator, tmp_path):
"""测试执行安全命令。"""
with patch('subprocess.run') as mock_run:
mock_run.return_value.returncode = 0
mock_run.return_value.stdout = "output"
mock_run.return_value.stderr = ""
generator.execute_command("echo 'test'", cwd=tmp_path)
mock_run.assert_called_once_with(
"echo 'test'",
shell=True,
cwd=tmp_path,
capture_output=True,
text=True,
timeout=300
)
def test_execute_command_dangerous(self, generator):
"""测试阻止危险命令。"""
with pytest.raises(RuntimeError, match="危险命令"):
generator.execute_command("rm -rf /") # 假设在危险命令列表中
def test_run_without_resume(self, generator, mock_openai_client, tmp_path):
"""测试完整运行流程,禁用断点续写。"""
readme_path = tmp_path / "README.md"
readme_path.write_text("# Test README")
generator.readme_content = "# Test README"
# 模拟 get_project_structure 响应
mock_structure = {
"files": ["file1.py", "file2.py"],
"dependencies": {}
}
mock_openai_client.chat.completions.create.side_effect = [
Mock(choices=[Mock(message=Mock(content=json.dumps(mock_structure)))]),
Mock(choices=[Mock(message=Mock(content=json.dumps({"code": "code1", "description": "desc1", "commands": []})))]),
Mock(choices=[Mock(message=Mock(content=json.dumps({"code": "code2", "description": "desc2", "commands": []})))])
]
with patch('src.llm_codegen.core.safe_write_file') as mock_write, \
patch('src.llm_codegen.core.safe_read_file') as mock_read, \
patch('src.llm_codegen.core.save_state') as mock_save:
mock_read.return_value = "content"
generator.run(readme_path)
# 验证文件生成和状态保存
assert mock_write.call_count == 2
assert mock_save.called
def test_run_with_resume(self, generator, mock_openai_client, tmp_path):
"""测试断点续写功能。"""
generator.resume = True
generator.state = GeneratorState(generated_files=["file1.py"], executed_commands=[])
readme_path = tmp_path / "README.md"
readme_path.write_text("# Test README")
generator.readme_content = "# Test README"
mock_structure = {
"files": ["file1.py", "file2.py"],
"dependencies": {}
}
mock_openai_client.chat.completions.create.return_value.choices[0].message.content = json.dumps(mock_structure)
with patch('src.llm_codegen.core.safe_write_file') as mock_write, \
patch('src.llm_codegen.core.safe_read_file') as mock_read:
mock_read.return_value = "content"
generator.run(readme_path)
# 只应生成 file2.py跳过 file1.py
assert mock_write.call_count == 1
def test_load_config_default(self, generator):
"""测试加载默认配置。"""
config = generator._load_config()
assert isinstance(config, ConfigModel)
assert config.check_tools == ["pytest", "pylint", "mypy", "black"]
assert config.max_retries == 3
def test_update_state(self, generator, tmp_path):
"""测试更新状态文件。"""
generator.state_path = tmp_path / "state.json"
generator.state = GeneratorState()
with patch('src.llm_codegen.core.save_state') as mock_save:
generator._update_state("new_file.py", ["cmd1"])
assert generator.state.generated_files == ["new_file.py"]
assert generator.state.executed_commands == ["cmd1"]
mock_save.assert_called_once()
if __name__ == "__main__":
pytest.main([__file__, "-v"])