commit d431b24755ad71e6bc98a8e9a4ed71c4dfb34e15 Author: songsenand Date: Mon Mar 16 08:59:59 2026 +0800 feat: 添加项目初始化文件和README模板以支持代码生成工具 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..505a3b1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +# Python-generated files +__pycache__/ +*.py[oc] +build/ +dist/ +wheels/ +*.egg-info + +# Virtual environments +.venv diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..24ee5b1 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.13 diff --git a/README.md b/README.md new file mode 100644 index 0000000..7b42251 --- /dev/null +++ b/README.md @@ -0,0 +1,533 @@ +# LLM 代码生成工具(自举版) + +本项目是一个基于大语言模型的代码生成工具,能够根据项目 `README.md` 描述自动生成完整的 Python 包代码,并具备代码检查、测试和自动修复能力。它是前一个代码生成器的升级版本,采用 `uv` 进行包管理,包含完整的单元测试、并行检查模块,并可通过命令行直接调用。 + +## 特别说明 + +我已经实现了一个简易版本,请在此基础上拓展开发: + +```python +#!/home/songsenand/env/.venv/bin/python +""" +基于LLM的自动化代码生成工具 +根据README.md文件,自动生成项目文件结构并填充代码,执行必要命令。 +""" + +import json +import os +import subprocess +import sys +from typing import List, Dict, Optional, Any, Tuple +from pathlib import Path + +import typer +from rich.console import Console +from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskID +from loguru import logger +from openai import OpenAI + +# ==================== 配置 ==================== +DANGEROUS_COMMANDS = ["rm", "sudo", "chmod", "dd", "mkfs", "> /dev/sda", "format"] +ALLOWED_COMMANDS = [] # 可设置白名单,为空则只检查黑名单 + +app = typer.Typer(help="基于LLM的自动化代码生成工具") +console = Console() + +# ==================== 工具函数 ==================== +def is_dangerous_command(cmd: str) -> Tuple[bool, str]: + """ + 判断命令是否危险 + 返回 (是否危险, 原因) + """ + cmd_lower = cmd.lower() + for danger in DANGEROUS_COMMANDS: + if danger in cmd_lower: + return True, f"包含危险关键词 '{danger}'" + return False, "" + +# ==================== 核心类 ==================== +class CodeGenerator: + """代码生成器,封装所有逻辑""" + + def __init__( + self, + api_key: Optional[str] = None, + base_url: str = "https://api.deepseek.com", + model: str = "deepseek-reasoner", + output_dir: str = "./generated", + log_file: Optional[str] = None, + ): + """ + 初始化生成器 + + Args: + api_key: OpenAI API密钥,默认从环境变量DEEPSEEK_APIKEY读取 + base_url: API基础URL + model: 使用的模型 + output_dir: 输出根目录 + log_file: 日志文件路径,默认自动生成 + """ + self.api_key = api_key or os.getenv("DEEPSEEK_APIKEY") + if not self.api_key: + raise ValueError("必须提供API密钥,或设置环境变量DEEPSEEK_APIKEY") + + self.client = OpenAI(api_key=self.api_key, base_url=base_url) + self.model = model + self.output_dir = Path(output_dir) + self.output_dir.mkdir(parents=True, exist_ok=True) + + # 配置日志 + if log_file is None: + log_file = self.output_dir / "generator.log" + logger.remove() # 移除默认handler + logger.add(sys.stderr, level="WARNING") # 控制台输出INFO及以上 + logger.add(log_file, rotation="10 MB", level="DEBUG") # 文件记录DEBUG + logger.info(f"日志已初始化,保存至: {log_file}") + + self.readme_content = None + + self.progress: Optional[Progress] = None + self.tasks: Dict[str, TaskID] = {} # 任务ID映射 + + def _call_llm( + self, + system_prompt: str, + user_prompt: str, + temperature: float = 0.2, + expect_json: bool = True, + ) -> Dict[str, Any]: + """ + 调用LLM并返回解析后的JSON + """ + logger.debug(f"调用LLM,模型: {self.model}") + logger.debug(f"System: {system_prompt[:200]}...") + logger.debug(f"User: {user_prompt[:200]}...") + + try: + response = self.client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + temperature=temperature, + response_format={"type": "json_object"} if expect_json else None, + ) + + message = response.choices[0].message + content = message.content + + # 记录思考过程(如果存在) + if hasattr(message, "reasoning_content") and message.reasoning_content: + logger.info(f"模型思考过程: {message.reasoning_content}") + + logger.debug(f"LLM原始响应: {content[:500]}...") + + if expect_json: + result = json.loads(content) + else: + result = {"content": content} + + return result + + except json.JSONDecodeError as e: + logger.error(f"JSON解析失败: {e}") + raise ValueError(f"LLM返回的不是有效JSON: {content[:200]}") + except Exception as e: + logger.error(f"LLM调用失败: {e}") + raise + + def parse_readme(self, readme_path: Path) -> str: + """ + 读取README文件内容 + """ + logger.info(f"读取README文件: {readme_path}") + try: + with open(readme_path, "r", encoding="utf-8") as f: + content = f.read() + logger.debug(f"README内容长度: {len(content)} 字符") + return content + except Exception as e: + logger.error(f"读取README失败: {e}") + raise + + def get_project_structure(self) -> Tuple[List[str], Dict[str, List[str]]]: + """ + 根据README内容,让LLM生成文件列表和依赖关系 + + Returns: + (files, dependencies) + files: 按顺序需要生成的文件路径列表 + dependencies: 字典 {file: [依赖文件路径]} + """ + system_prompt = ( + "你是一个软件架构师。请根据README描述,分析需要生成哪些源代码文件,并确定它们的生成顺序," + "同时给出每个文件生成时最少需要读取哪些已有文件作为上下文。" + "返回严格的JSON对象,包含两个字段:\n" + "- files: 数组,按生成顺序排列的文件路径(相对于项目根目录)\n" + "- dependencies: 对象,键为文件路径,值为该文件依赖的已有文件路径列表(可为空)\n" + "注意:依赖文件必须是已存在的参考文件,不要包含待生成的文件。" + ) + user_prompt = f"README内容如下:\n\n{self.readme_content}" + + result = self._call_llm(system_prompt, user_prompt) + + files = result.get("files", []) + dependencies = result.get("dependencies", {}) + + if not files: + raise ValueError("LLM未返回任何文件列表") + + logger.info(f"解析到 {len(files)} 个待生成文件") + logger.debug(f"文件列表: {files}") + logger.debug(f"依赖关系: {dependencies}") + + return files, dependencies + + def generate_file( + self, + file_path: str, + prompt_instruction: str, + dependency_files: List[str], + ) -> Tuple[str, str, List[str]]: + """ + 生成单个文件,返回 (代码, 描述, 命令列表) + """ + # 读取依赖文件内容 + context_content = [] + + if self.readme_content: + context_content.append(f"### 项目 README ###\n{self.readme_content}\n") + + for dep in dependency_files: + dep_path = Path(dep) + if not dep_path.exists(): + # 尝试相对于当前目录或输出目录查找 + alt_path = self.output_dir / dep + if alt_path.exists(): + dep_path = alt_path + else: + raise FileNotFoundError(f"依赖文件不存在: {dep}") + + with open(dep_path, "r", encoding="utf-8") as f: + content = f.read() + context_content.append(f"### 文件: {dep_path.name} (路径: {dep}) ###\n{content}\n") + + full_context = "\n".join(context_content) + + system_prompt = ( + "你是一个专业的编程助手。根据用户指令和提供的上下文文件,生成完整的代码。" + "返回严格的JSON对象,包含三个字段:\n" + "- code: (string) 生成的完整代码\n" + "- description: (string) 简短的中文功能描述\n" + "- commands: (array of string) 生成此文件后需要执行的操作系统命令列表(如编译、安装依赖等),若无则返回空数组" + ) + user_prompt = f"{prompt_instruction}\n\n参考文件上下文:\n{full_context}" + + result = self._call_llm(system_prompt, user_prompt) + + code = result.get("code", "") + description = result.get("description", "") + commands = result.get("commands", []) + + if not isinstance(commands, list): + commands = [] + + return code, description, commands + + def execute_command(self, cmd: str, cwd: Optional[Path] = None) -> None: + """ + 执行单个命令,检查风险 + """ + dangerous, reason = is_dangerous_command(cmd) + if dangerous: + logger.error(f"危险命令被阻止: {cmd},原因: {reason}") + raise RuntimeError(f"危险命令: {cmd} ({reason})") + + logger.info(f"执行命令: {cmd}") + try: + result = subprocess.run( + cmd, + shell=True, + cwd=cwd or self.output_dir, + capture_output=True, + text=True, + timeout=300, # 5分钟超时 + ) + logger.debug(f"命令返回码: {result.returncode}") + if result.stdout: + logger.debug(f"stdout: {result.stdout[:500]}") + if result.stderr: + logger.warning(f"stderr: {result.stderr[:500]}") + if result.returncode != 0: + raise subprocess.CalledProcessError(result.returncode, cmd) + except subprocess.TimeoutExpired: + logger.error(f"命令执行超时: {cmd}") + raise + except Exception as e: + logger.error(f"命令执行失败: {e}") + raise + + def run(self, readme_path: Path): + """ + 主执行流程 + """ + logger.info("=" * 50) + logger.info("开始代码生成流程") + logger.info(f"README: {readme_path}") + logger.info(f"输出目录: {self.output_dir}") + + # 初始化阶段:用rich输出状态(不会被日志级别过滤) + console.print("[bold yellow]🔍 正在解析README...[/bold yellow]") + self.readme_content = self.parse_readme(readme_path) + + console.print("[bold yellow]📋 正在分析项目结构...[/bold yellow]") + files, dependencies = self.get_project_structure() + + console.print(f"[green]✅ 解析完成,共 {len(files)} 个文件待生成[/green]") + + # 3. 创建进度条 + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + console=console, + ) as progress: + self.progress = progress + # 创建总任务 + total_task = progress.add_task("[cyan]整体进度...", total=len(files)) + + # 依次生成每个文件 + for idx, file in enumerate(files, 1): + logger.info(f"处理文件 [{idx}/{len(files)}]: {file}") + + # 创建子任务(可选) + file_task = progress.add_task(f"生成 {file}", total=None) + + try: + # 获取依赖文件 + deps = dependencies.get(file, []) + + # 构造生成指令 + instruction = f"请根据README描述和依赖文件,生成文件 '{file}' 的完整代码。" + + # 调用LLM生成代码 + code, desc, commands = self.generate_file(file, instruction, deps) + + logger.info(f"生成完成: {file} - {desc}") + + # 写入文件 + output_path = self.output_dir / file + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w", encoding="utf-8") as f: + f.write(code) + logger.info(f"已写入: {output_path}") + + # 执行命令 + for cmd in commands: + logger.info(f"准备执行命令: {cmd}") + self.execute_command(cmd, cwd=self.output_dir) + + except Exception as e: + logger.error(f"处理文件 {file} 失败: {e}") + # 可选:继续或终止 + raise + finally: + progress.remove_task(file_task) + progress.update(total_task, advance=1) + + logger.success("所有文件处理完成!") + +# ==================== CLI入口 ==================== +@app.command() +def main( + readme: Path = typer.Argument(..., exists=True, file_okay=True, dir_okay=False, help="README.md文件路径"), + output_dir: Optional[Path] = typer.Option(None, "--output", "-o", help="输出根目录,默认为readme所在目录"), + api_key: Optional[str] = typer.Option(None, "--api-key", envvar="DEEPSEEK_APIKEY", help="API密钥,也可通过环境变量DEEPSEEK_APIKEY设置"), + base_url: str = typer.Option("https://api.deepseek.com", "--base-url", help="API基础URL"), + model: str = typer.Option("deepseek-reasoner", "--model", "-m", help="使用的模型"), + log_file: Optional[str] = typer.Option(None, "--log", help="日志文件路径(默认输出目录下generator.log)"), +): + """ + 根据README自动生成项目代码 + """ + if output_dir is None: + output_dir = readme.parent + + try: + generator = CodeGenerator( + api_key=api_key, + base_url=base_url, + model=model, + output_dir=output_dir, + log_file=log_file, + ) + generator.run(readme) + except Exception as e: + logger.error(f"程序异常退出: {e}") + raise typer.Exit(code=1) + +if __name__ == "__main__": + app() +``` + +## 功能特性 + +- 📦 **自动生成**:解析 `README.md`,分析需要生成的文件列表及依赖关系,按顺序生成每个文件的代码。 +- 🔧 **命令执行**:生成文件后可自动执行建议命令(如安装依赖、运行构建),内置危险命令拦截。 +- ✅ **单元测试**:使用 `pytest` 编写测试用例,支持测试覆盖率统计。 +- 🔍 **并行检查**:生成代码后并行运行多个检查工具(如 `pylint`、`mypy`、`black`),收集错误信息。 +- 🔄 **自修复**:将检查错误、`README` 和相关代码作为上下文提交给 LLM,自动生成修复补丁并应用。 +- ⏯️ **断点续写**:如果生成过程意外中断(如网络问题、API 限制),重新运行时会从上次中断处继续,已生成的文件和已执行的命令不会重复执行,状态自动保存在输出目录下的 `.llm_generator_state.json` 文件中。 +- 🖥️ **命令行工具**:提供 `llm-codegen` 命令,参数兼容原脚本(`--output`、`--api-key`、`--model` 等)。 +- 📝 **详细日志**:所有操作、LLM 响应、错误均通过 `loguru` 记录到文件。 +- 🎨 **美观输出**:使用 `rich` 显示进度条和彩色状态。 + +## 安装 + +### 依赖 + +- Python 3.9+ +- 使用 `uv` 管理包 + +```bash +# 使用 uv +uv add [dev] +``` + +### 配置 API 密钥 + +设置环境变量(推荐): +```bash +export DEEPSEEK_APIKEY="your-api-key" +``` + +或在命令行中通过 `--api-key` 传入。 + +## 使用方法 + +```bash +llm-codegen [OPTIONS] README +``` + +### 参数 + +| 参数 | 说明 | +|------|------| +| `README` | `README.md` 文件路径(必须) | +| `--output, -o` | 输出根目录(默认:README 所在目录) | +| `--api-key` | API 密钥(默认:环境变量 `DEEPSEEK_APIKEY`) | +| `--base-url` | API 基础 URL(默认:`https://api.deepseek.com`) | +| `--model, -m` | 使用的模型(默认:`deepseek-reasoner`) | +| `--log` | 日志文件路径(默认:输出目录下 `generator.log`) | +| `--resume/--no-resume` | 是否启用断点续写(默认:`--resume`,即自动从上次中断处继续) | +| `--no-check` | 跳过生成后的检查和修复 | +| `--help` | 显示帮助信息 | + +### 示例 + +```bash +llm-codegen my_project/README.md -o ./generated +``` + +如果中途中断,只需再次运行相同的命令,工具会自动检测状态文件并从上次中断处继续生成。 + +## 项目结构 + +生成的项目将包含以下文件和目录: + +``` +. +├── README.md # 项目说明(原始输入) +├── pyproject.toml # 项目元数据、依赖、脚本入口 +├── src/ +│ └── llm_codegen/ # 主代码包 +│ ├── __init__.py +│ ├── cli.py # 命令行入口(typer) +│ ├── core.py # 核心生成逻辑(CodeGenerator 类) +│ ├── checker.py # 并行检查与修复模块 +│ ├── utils.py # 工具函数(危险命令判断、文件操作) +│ └── models.py # 数据模型(Pydantic) +├── tests/ # 单元测试 +│ ├── __init__.py +│ ├── test_cli.py +│ ├── test_core.py +│ └── test_checker.py +└── logs/ # 运行日志(自动创建) +``` + +## 核心流程 + +1. **解析阶段**:读取 `README.md`,调用 LLM 获取 `files`(按生成顺序的文件路径列表)和 `dependencies`(每个文件依赖的已有文件列表)。 +2. **生成阶段**:按顺序生成每个文件,使用 `README` 和依赖文件作为上下文,同时获取 LLM 建议的命令。每成功生成一个文件并执行命令后,状态会自动保存到 `.llm_generator_state.json`。 +3. **命令执行**:对每个建议命令进行危险检查,低风险则执行。已执行的命令记录在状态文件中,避免重复执行。 +4. **检查阶段**(可选):生成完成后,并行运行配置的检查工具(如 `pytest`、`pylint`、`mypy`),收集错误。 +5. **修复阶段**(可选):若检查失败,将错误信息、`README` 和相关文件内容提交给 LLM,请求生成修复方案,并自动应用修改。重复直到检查通过或达到重试次数上限。 + +## 断点续写机制 + +- 状态文件保存在输出目录下的 `.llm_generator_state.json`,记录已成功生成的文件列表和已执行的命令。 +- 重新运行工具时(默认启用 `--resume`),会自动读取状态文件,跳过已完成的部分,从下一个文件开始继续。 +- 如果 `README` 发生重大变更导致文件列表不一致,工具会检测并提示用户重新开始(可通过 `--no-resume` 强制从头生成)。 +- 状态文件在全部流程成功完成后可手动删除,工具不会自动删除,以便后续查看或用于调试。 + +## 开发指南 + +### 环境设置 + +```bash +# 安装 uv(若未安装) +curl -LsSf https://astral.sh/uv/install.sh | sh + +# 创建虚拟环境并激活 +uv venv +source .venv/bin/activate # Linux/macOS +# 或 .venv\Scripts\activate # Windows + +# 安装项目(可编辑模式)和开发依赖 +uv pip install -e ".[dev]" +``` + +### 运行测试 + +```bash +pytest tests/ --cov=src/llm_codegen +``` + +### 代码检查 + +```bash +# 运行所有检查 +pre-commit run --all-files + +# 或手动运行 +pylint src/llm_codegen +mypy src/llm_codegen +black --check src/llm_codegen +``` + +### 添加新功能 + +1. 在 `src/llm_codegen/` 下添加或修改模块。 +2. 在 `tests/` 中添加对应的单元测试。 +3. 更新 `README.md` 和命令行帮助信息。 + +## 配置 + +通过 `pyproject.toml` 的 `[tool.llm-codegen]` 部分可以自定义检查工具和修复行为: + +```toml +[tool.llm-codegen] +check_tools = ["pytest", "pylint", "mypy", "black"] +max_retries = 3 +dangerous_commands = ["rm", "sudo", "chmod", "dd"] +``` + + + + + + + + diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..ca1d38f --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,49 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "llm-codegen" +version = "0.1.0" +description = "基于大语言模型的自动化代码生成工具,根据README.md描述自动生成完整的Python包代码,具备代码检查、测试和自动修复能力。" +authors = [ + {name = "Your Name", email = "your.email@example.com"} +] +readme = "README.md" +license = {text = "MIT"} +requires-python = ">=3.9" +dependencies = [ + "typer>=0.9.0", + "rich>=13.0.0", + "loguru>=0.7.0", + "openai>=1.0.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.0.0", + "pytest-cov>=4.0.0", + "pylint>=3.0.0", + "mypy>=1.0.0", + "black>=23.0.0", + "pre-commit>=3.0.0", +] + +[project.urls] +Homepage = "https://github.com/yourusername/llm-codegen" + +[project.scripts] +llm-codegen = "llm_codegen.cli:app" + +[tool.llm-codegen] +check_tools = ["pytest", "pylint", "mypy", "black"] +max_retries = 3 +dangerous_commands = ["rm", "sudo", "chmod", "dd"] + +[tool.black] +line-length = 88 +target-version = ['py39'] + +[tool.pytest.ini_options] +testpaths = ["tests"] +addopts = "--cov=src/llm_codegen --cov-report=term-missing" \ No newline at end of file diff --git a/src/llm_codegen/__init__.py b/src/llm_codegen/__init__.py new file mode 100644 index 0000000..13ad370 --- /dev/null +++ b/src/llm_codegen/__init__.py @@ -0,0 +1,21 @@ +""" +LLM Code Generator package. + +This package provides an automated code generation tool based on large language models (LLMs). +It can generate complete Python package code from README descriptions, with features like code checking, testing, and auto-fixing. +""" + +__version__ = "0.1.0" +__author__ = "LLM CodeGen Team" +__description__ = "A self-bootstrapping LLM-based code generation tool" + +# Export main components for easy access from the package +from .core import CodeGenerator +from .cli import app +from .utils import is_dangerous_command + +__all__ = [ + "CodeGenerator", + "app", + "is_dangerous_command", +] \ No newline at end of file diff --git a/src/llm_codegen/checker.py b/src/llm_codegen/checker.py new file mode 100644 index 0000000..b85694d --- /dev/null +++ b/src/llm_codegen/checker.py @@ -0,0 +1,242 @@ +""" +checker.py - 并行检查与修复模块 +负责在代码生成后运行配置的检查工具(如pylint、mypy、black)并收集错误, +然后使用LLM自动生成和应用修复补丁。 +""" + +import json +import os +import subprocess +import sys +from pathlib import Path +from typing import List, Dict, Optional, Any, Tuple +from concurrent.futures import ThreadPoolExecutor, as_completed + +from loguru import logger +from openai import OpenAI + +from .models import ConfigModel # 从models.py导入配置模型 +from .utils import safe_read_file, safe_write_file # 工具函数 + + +class Checker: + """ + 检查与修复器类,提供并行运行检查工具和自动修复功能。 + """ + + def __init__( + self, + output_dir: Path, + config: ConfigModel, + api_key: Optional[str] = None, + base_url: str = "https://api.deepseek.com", + model: str = "deepseek-reasoner", + ): + """ + 初始化检查器。 + + Args: + output_dir: 输出目录,包含生成的代码。 + config: 配置模型,包含check_tools、max_retries等。 + api_key: LLM API密钥,如果None则从环境变量DEEPSEEK_APIKEY获取。 + base_url: LLM API基础URL。 + model: LLM模型。 + """ + self.output_dir = output_dir + self.config = config + self.api_key = api_key or os.getenv("DEEPSEEK_APIKEY") + if not self.api_key: + raise ValueError("API密钥未提供,请设置环境变量DEEPSEEK_APIKEY或传入api_key") + self.client = OpenAI(api_key=self.api_key, base_url=base_url) + self.model = model + self.max_retries = config.max_retries + + def run_check_tool(self, tool: str, file_path: Path) -> Tuple[bool, str]: + """ + 运行单个检查工具并返回结果。 + + Args: + tool: 工具名称,如"pylint"。 + file_path: 要检查的文件路径。 + + Returns: + (success, output): 成功为True,错误输出字符串。 + """ + commands = { + "pylint": f"pylint {file_path}", + "mypy": f"mypy {file_path}", + "black": f"black --check {file_path}", + "pytest": f"pytest {file_path}", # 假设检查测试文件 + } + if tool not in commands: + logger.warning(f"未知检查工具: {tool}") + return True, "" # 跳过未知工具 + + cmd = commands[tool] + try: + result = subprocess.run( + cmd, + shell=True, + cwd=self.output_dir, + capture_output=True, + text=True, + timeout=60, # 1分钟超时 + ) + if result.returncode == 0: + return True, "" + else: + output = result.stdout + result.stderr + return False, output + except subprocess.TimeoutExpired: + logger.error(f"检查工具 {tool} 超时") + return False, "超时" + except Exception as e: + logger.error(f"运行检查工具 {tool} 失败: {e}") + return False, str(e) + + def run_parallel_checks(self, files: List[Path]) -> Dict[str, List[Tuple[str, str]]]: + """ + 并行运行所有配置的检查工具。 + + Args: + files: 要检查的文件路径列表。 + + Returns: + 错误字典,键为文件路径,值为列表,每个元素为(工具名, 错误输出)。 + """ + errors = {} + check_tools = self.config.check_tools + + with ThreadPoolExecutor() as executor: + futures = [] + for file in files: + for tool in check_tools: + future = executor.submit(self.run_check_tool, tool, file) + futures.append((future, file, tool)) + + for future, file, tool in futures: + success, output = future.result() + if not success: + if file not in errors: + errors[file] = [] + errors[file].append((tool, output)) + + return errors + + def call_llm_for_fix(self, file_path: Path, errors: List[Tuple[str, str]], readme_content: str) -> Optional[str]: + """ + 调用LLM生成修复补丁。 + + Args: + file_path: 需要修复的文件。 + errors: 错误列表,每个元素为(工具名, 错误输出)。 + readme_content: README内容。 + + Returns: + 修复后的代码字符串,如果失败返回None。 + """ + error_summary = "\n".join([f"{tool}: {err}" for tool, err in errors]) + file_content = safe_read_file(file_path) + + system_prompt = ( + "你是一个专业的代码修复助手。给定代码、错误信息和项目README,请生成修复后的完整代码。" + "返回严格的JSON对象,包含字段:\n" + "- code: (string) 修复后的完整代码\n" + "- description: (string) 修复描述\n" + ) + user_prompt = ( + f"项目README:\n{readme_content}\n\n" + f"文件内容:\n{file_content}\n\n" + f"错误信息:\n{error_summary}\n\n" + "请生成修复后的代码,确保所有检查通过。" + ) + + try: + response = self.client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + temperature=0.2, + response_format={"type": "json_object"}, + ) + result = json.loads(response.choices[0].message.content) + return result.get("code") + except Exception as e: + logger.error(f"调用LLM修复失败: {e}") + return None + + def apply_fix(self, file_path: Path, new_code: str) -> bool: + """ + 应用修复代码。 + + Args: + file_path: 文件路径。 + new_code: 新代码。 + + Returns: + 是否成功应用。 + """ + try: + safe_write_file(file_path, new_code) + logger.info(f"已应用修复到 {file_path}") + return True + except Exception as e: + logger.error(f"应用修复失败: {e}") + return False + + def run_checks_and_fixes(self, readme_content: str, files: Optional[List[Path]] = None) -> bool: + """ + 主方法:运行检查并自动修复。 + + Args: + readme_content: README内容。 + files: 要检查的文件列表,如果None则检查output_dir下所有Python文件。 + + Returns: + 是否所有检查最终通过。 + """ + if files is None: + # 递归查找所有Python文件 + files = list(self.output_dir.rglob("*.py")) + + for attempt in range(self.max_retries): + logger.info(f"检查尝试 {attempt + 1}/{self.max_retries}") + errors = self.run_parallel_checks(files) + + if not errors: + logger.success("所有检查通过!") + return True + + # 有错误,尝试修复 + logger.warning(f"发现 {len(errors)} 个文件有错误,尝试修复") + all_fixed = True + for file_path, error_list in errors.items(): + new_code = self.call_llm_for_fix(file_path, error_list, readme_content) + if new_code: + if self.apply_fix(file_path, new_code): + # 修复后重新检查这个文件 + success, _ = self.run_check_tool("pylint", file_path) # 简化检查,重新运行一个工具 + if not success: + all_fixed = False + else: + all_fixed = False + else: + all_fixed = False + + if all_fixed: + logger.info("修复成功,重新检查...") + continue # 重新检查所有文件 + else: + logger.error("修复失败或仍有错误") + break + + # 最终检查 + final_errors = self.run_parallel_checks(files) + if not final_errors: + logger.success("检查最终通过") + return True + else: + logger.error(f"检查失败,剩余错误: {final_errors}") + return False diff --git a/src/llm_codegen/cli.py b/src/llm_codegen/cli.py new file mode 100644 index 0000000..e9bc679 --- /dev/null +++ b/src/llm_codegen/cli.py @@ -0,0 +1,49 @@ +import typer +from pathlib import Path +from typing import Optional +from loguru import logger +from rich.console import Console +from .core import CodeGenerator + +app = typer.Typer(help="基于LLM的自动化代码生成工具") +console = Console() + +@app.command() +def main( + readme: Path = typer.Argument(..., exists=True, file_okay=True, dir_okay=False, help="README.md文件路径"), + output_dir: Optional[Path] = typer.Option(None, "--output", "-o", help="输出根目录,默认为readme所在目录"), + api_key: Optional[str] = typer.Option(None, "--api-key", envvar="DEEPSEEK_APIKEY", help="API密钥,也可通过环境变量DEEPSEEK_APIKEY设置"), + base_url: str = typer.Option("https://api.deepseek.com", "--base-url", help="API基础URL"), + model: str = typer.Option("deepseek-reasoner", "--model", "-m", help="使用的模型"), + log_file: Optional[str] = typer.Option(None, "--log", help="日志文件路径(默认输出目录下generator.log)"), + resume: bool = typer.Option(True, "--resume/--no-resume", help="是否启用断点续写(默认启用)"), + no_check: bool = typer.Option(False, "--no-check", help="跳过生成后的检查和修复"), +): + """ + 根据README自动生成项目代码,支持断点续写和可选检查。 + """ + if output_dir is None: + output_dir = readme.parent + + try: + generator = CodeGenerator( + api_key=api_key, + base_url=base_url, + model=model, + output_dir=output_dir, + log_file=log_file, + resume=resume, + config_path=None, # 配置文件路径,可从pyproject.toml加载,但CLI中暂不提供参数 + ) + generator.run(readme) + + # 如果未跳过检查,提示用户检查功能暂未实现 + if not no_check: + console.print("[yellow]注意:检查和修复功能暂未在此版本中实现,请手动运行检查工具(如pytest、pylint)。[/yellow]") + + except Exception as e: + logger.error(f"程序异常退出: {e}") + raise typer.Exit(code=1) + +if __name__ == "__main__": + app() \ No newline at end of file diff --git a/src/llm_codegen/core.py b/src/llm_codegen/core.py new file mode 100644 index 0000000..f818954 --- /dev/null +++ b/src/llm_codegen/core.py @@ -0,0 +1,365 @@ +import json +import os +import subprocess +import sys +from pathlib import Path +from typing import List, Dict, Optional, Any, Tuple + +from loguru import logger +from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskID +from rich.console import Console +from openai import OpenAI + +# 导入本地模块 +from .utils import is_dangerous_command, safe_read_file, safe_write_file, load_state, save_state, normalize_path, load_dangerous_commands +from .models import GeneratorState, FileInfo, ProjectStructure, ConfigModel + + +class CodeGenerator: + """ + 核心代码生成器类,负责解析README、生成代码、执行命令并支持断点续写。 + """ + + def __init__( + self, + api_key: Optional[str] = None, + base_url: str = "https://api.deepseek.com", + model: str = "deepseek-reasoner", + output_dir: str = "./generated", + log_file: Optional[str] = None, + resume: bool = True, + config_path: Optional[Path] = None, + ): + """ + 初始化生成器。 + + Args: + api_key: OpenAI API密钥,默认从环境变量DEEPSEEK_APIKEY读取。 + base_url: API基础URL。 + model: 使用的模型。 + output_dir: 输出根目录。 + log_file: 日志文件路径,默认自动生成。 + resume: 是否启用断点续写,默认为True。 + config_path: 配置文件路径,用于加载危险命令等配置。 + """ + self.api_key = api_key or os.getenv("DEEPSEEK_APIKEY") + if not self.api_key: + raise ValueError("必须提供API密钥,或设置环境变量DEEPSEEK_APIKEY") + + self.client = OpenAI(api_key=self.api_key, base_url=base_url) + self.model = model + self.output_dir = Path(output_dir) + self.output_dir.mkdir(parents=True, exist_ok=True) + self.resume = resume + self.config_path = config_path + + # 加载配置 + self.config = self._load_config() + self.dangerous_commands = self.config.dangerous_commands + + # 配置日志 + if log_file is None: + log_file = self.output_dir / "generator.log" + logger.remove() # 移除默认handler + logger.add(sys.stderr, level="WARNING") # 控制台输出INFO及以上 + logger.add(log_file, rotation="10 MB", level="DEBUG") # 文件记录DEBUG + logger.info(f"日志已初始化,保存至: {log_file}") + + self.readme_content = None + self.state_path = self.output_dir / ".llm_generator_state.json" + self.state: Optional[GeneratorState] = None + self.progress: Optional[Progress] = None + self.tasks: Dict[str, TaskID] = {} # 任务ID映射 + self.console = Console() + + def _load_config(self) -> ConfigModel: + """ + 加载配置,如果配置文件不存在则使用默认值。 + """ + try: + # 简化实现:从环境或固定路径加载,实际应从pyproject.toml解析 + dangerous_cmds = load_dangerous_commands(self.config_path) + return ConfigModel( + check_tools=["pytest", "pylint", "mypy", "black"], + max_retries=3, + dangerous_commands=dangerous_cmds, + ) + except Exception as e: + logger.warning(f"加载配置失败,使用默认值: {e}") + return ConfigModel() + + def _call_llm( + self, + system_prompt: str, + user_prompt: str, + temperature: float = 0.2, + expect_json: bool = True, + ) -> Dict[str, Any]: + """ + 调用LLM并返回解析后的JSON。 + """ + logger.debug(f"调用LLM,模型: {self.model}") + logger.debug(f"System: {system_prompt[:200]}...") + logger.debug(f"User: {user_prompt[:200]}...") + + try: + response = self.client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + temperature=temperature, + response_format={"type": "json_object"} if expect_json else None, + ) + + message = response.choices[0].message + content = message.content + + # 记录思考过程(如果存在) + if hasattr(message, "reasoning_content") and message.reasoning_content: + logger.info(f"模型思考过程: {message.reasoning_content}") + + logger.debug(f"LLM原始响应: {content[:500]}...") + + if expect_json: + result = json.loads(content) + else: + result = {"content": content} + + return result + + except json.JSONDecodeError as e: + logger.error(f"JSON解析失败: {e}") + raise ValueError(f"LLM返回的不是有效JSON: {content[:200]}") + except Exception as e: + logger.error(f"LLM调用失败: {e}") + raise + + def parse_readme(self, readme_path: Path) -> str: + """ + 读取README文件内容并计算哈希值用于断点续写检测。 + """ + logger.info(f"读取README文件: {readme_path}") + try: + content = safe_read_file(readme_path) + logger.debug(f"README内容长度: {len(content)} 字符") + return content + except Exception as e: + logger.error(f"读取README失败: {e}") + raise + + def get_project_structure(self) -> Tuple[List[str], Dict[str, List[str]]]: + """ + 根据README内容,让LLM生成文件列表和依赖关系。 + + Returns: + (files, dependencies) + files: 按顺序需要生成的文件路径列表 + dependencies: 字典 {file: [依赖文件路径]} + """ + system_prompt = ( + "你是一个软件架构师。请根据README描述,分析需要生成哪些源代码文件,并确定它们的生成顺序," + "同时给出每个文件生成时最少需要读取哪些已有文件作为上下文。" + "返回严格的JSON对象,包含两个字段:\n" + "- files: 数组,按生成顺序排列的文件路径(相对于项目根目录)\n" + "- dependencies: 对象,键为文件路径,值为该文件依赖的已有文件路径列表(可为空)\n" + "注意:依赖文件必须是已存在的参考文件,不要包含待生成的文件。" + ) + user_prompt = f"README内容如下:\n\n{self.readme_content}" + + result = self._call_llm(system_prompt, user_prompt) + + files = result.get("files", []) + dependencies = result.get("dependencies", {}) + + if not files: + raise ValueError("LLM未返回任何文件列表") + + logger.info(f"解析到 {len(files)} 个待生成文件") + logger.debug(f"文件列表: {files}") + logger.debug(f"依赖关系: {dependencies}") + + return files, dependencies + + def generate_file( + self, + file_path: str, + prompt_instruction: str, + dependency_files: List[str], + ) -> Tuple[str, str, List[str]]: + """ + 生成单个文件,返回 (代码, 描述, 命令列表)。 + """ + # 读取依赖文件内容 + context_content = [] + + if self.readme_content: + context_content.append(f"### 项目 README ###\n{self.readme_content}\n") + + for dep in dependency_files: + dep_path = Path(dep) + if not dep_path.exists(): + # 尝试相对于输出目录查找 + alt_path = self.output_dir / dep + if alt_path.exists(): + dep_path = alt_path + else: + raise FileNotFoundError(f"依赖文件不存在: {dep}") + + content = safe_read_file(dep_path) + context_content.append(f"### 文件: {dep_path.name} (路径: {dep}) ###\n{content}\n") + + full_context = "\n".join(context_content) + + system_prompt = ( + "你是一个专业的编程助手。根据用户指令和提供的上下文文件,生成完整的代码。" + "返回严格的JSON对象,包含三个字段:\n" + "- code: (string) 生成的完整代码\n" + "- description: (string) 简短的中文功能描述\n" + "- commands: (array of string) 生成此文件后需要执行的操作系统命令列表(如编译、安装依赖等),若无则返回空数组" + ) + user_prompt = f"{prompt_instruction}\n\n参考文件上下文:\n{full_context}" + + result = self._call_llm(system_prompt, user_prompt) + + code = result.get("code", "") + description = result.get("description", "") + commands = result.get("commands", []) + + if not isinstance(commands, list): + commands = [] + + return code, description, commands + + def execute_command(self, cmd: str, cwd: Optional[Path] = None) -> None: + """ + 执行单个命令,检查风险。 + """ + dangerous, reason = is_dangerous_command(cmd, self.dangerous_commands) + if dangerous: + logger.error(f"危险命令被阻止: {cmd},原因: {reason}") + raise RuntimeError(f"危险命令: {cmd} ({reason})") + + logger.info(f"执行命令: {cmd}") + try: + result = subprocess.run( + cmd, + shell=True, + cwd=cwd or self.output_dir, + capture_output=True, + text=True, + timeout=300, # 5分钟超时 + ) + logger.debug(f"命令返回码: {result.returncode}") + if result.stdout: + logger.debug(f"stdout: {result.stdout[:500]}") + if result.stderr: + logger.warning(f"stderr: {result.stderr[:500]}") + if result.returncode != 0: + raise subprocess.CalledProcessError(result.returncode, cmd) + except subprocess.TimeoutExpired: + logger.error(f"命令执行超时: {cmd}") + raise + except Exception as e: + logger.error(f"命令执行失败: {e}") + raise + + def _update_state(self, generated_file: str, executed_commands: List[str]) -> None: + """ + 更新断点续写状态。 + """ + if self.state is None: + self.state = GeneratorState() + self.state.generated_files.append(generated_file) + self.state.executed_commands.extend(executed_commands) + self.state.updated_at = datetime.now() + save_state(self.state_path, self.state.model_dump()) + + def run(self, readme_path: Path) -> None: + """ + 主执行流程,支持断点续写。 + """ + logger.info("=" * 50) + logger.info("开始代码生成流程") + logger.info(f"README: {readme_path}") + logger.info(f"输出目录: {self.output_dir}") + logger.info(f"断点续写: {self.resume}") + + # 初始化阶段 + self.console.print("[bold yellow]🔍 正在解析README...[/bold yellow]") + self.readme_content = self.parse_readme(readme_path) + + # 加载或初始化状态 + if self.resume and self.state_path.exists(): + raw_state = load_state(self.state_path) + self.state = GeneratorState(**raw_state) if raw_state else GeneratorState() + logger.info(f"加载状态文件: {self.state_path}") + # 检查README是否变更 + if self.state.readme_hash and self.state.readme_hash != hash(self.readme_content): + logger.warning("README内容已变更,建议使用 --no-resume 重新开始") + else: + self.state = GeneratorState() + + self.console.print("[bold yellow]📋 正在分析项目结构...[/bold yellow]") + files, dependencies = self.get_project_structure() + + # 过滤已生成的文件 + if self.resume and self.state: + pending_files = [f for f in files if f not in self.state.generated_files] + logger.info(f"跳过了 {len(files) - len(pending_files)} 个已生成文件,剩余 {len(pending_files)} 个") + files = pending_files + + if not files: + logger.success("所有文件已生成,无需继续") + return + + self.console.print(f"[green]✅ 解析完成,共 {len(files)} 个文件待生成[/green]") + + # 创建进度条 + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + console=self.console, + ) as progress: + self.progress = progress + total_task = progress.add_task("[cyan]整体进度...", total=len(files)) + + for idx, file in enumerate(files, 1): + logger.info(f"处理文件 [{idx}/{len(files)}]: {file}") + file_task = progress.add_task(f"生成 {file}", total=None) + + try: + deps = dependencies.get(file, []) + instruction = f"请根据README描述和依赖文件,生成文件 '{file}' 的完整代码。" + code, desc, commands = self.generate_file(file, instruction, deps) + logger.info(f"生成完成: {file} - {desc}") + + # 写入文件 + output_path = self.output_dir / file + safe_write_file(output_path, code) + logger.info(f"已写入: {output_path}") + + # 执行命令,跳过已执行的 + executed_in_this_file = [] + for cmd in commands: + if self.resume and self.state and cmd in self.state.executed_commands: + logger.info(f"跳过已执行命令: {cmd}") + continue + logger.info(f"准备执行命令: {cmd}") + self.execute_command(cmd, cwd=self.output_dir) + executed_in_this_file.append(cmd) + + # 更新状态 + self._update_state(file, executed_in_this_file) + + except Exception as e: + logger.error(f"处理文件 {file} 失败: {e}") + raise + finally: + progress.remove_task(file_task) + progress.update(total_task, advance=1) + + logger.success("所有文件处理完成!") diff --git a/src/llm_codegen/models.py b/src/llm_codegen/models.py new file mode 100644 index 0000000..6fa4885 --- /dev/null +++ b/src/llm_codegen/models.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +""" +数据模型定义模块,使用 Pydantic 进行数据验证和序列化。 +""" + +from typing import List, Dict, Optional +from datetime import datetime +from pydantic import BaseModel, Field + + +class GeneratorState(BaseModel): + """ + 断点续写状态模型,用于保存和加载 .llm_generator_state.json 文件。 + """ + generated_files: List[str] = Field( + default_factory=list, + description="已成功生成的文件路径列表" + ) + executed_commands: List[str] = Field( + default_factory=list, + description="已执行的操作系统命令列表" + ) + readme_hash: Optional[str] = Field( + default=None, + description="README 内容的哈希值,用于检测变更" + ) + created_at: datetime = Field( + default_factory=datetime.now, + description="状态文件创建时间" + ) + updated_at: datetime = Field( + default_factory=datetime.now, + description="状态文件最后更新时间" + ) + + class Config: + json_encoders = { + datetime: lambda v: v.isoformat() + } + + +class FileInfo(BaseModel): + """ + 生成文件的信息模型。 + """ + path: str = Field(..., description="文件路径") + code: Optional[str] = Field(default=None, description="生成的代码内容") + description: Optional[str] = Field(default=None, description="文件功能描述") + commands: List[str] = Field( + default_factory=list, + description="生成后需要执行的命令列表" + ) + + +class ProjectStructure(BaseModel): + """ + 项目结构模型,包括文件列表和依赖关系。 + """ + files: List[str] = Field( + ..., + description="按生成顺序排列的文件路径列表" + ) + dependencies: Dict[str, List[str]] = Field( + default_factory=dict, + description="文件依赖关系,键为文件路径,值为依赖文件列表" + ) + + +class CheckResult(BaseModel): + """ + 检查工具的结果模型。 + """ + tool: str = Field(..., description="检查工具名称,如 'pylint'") + passed: bool = Field(..., description="检查是否通过") + errors: List[str] = Field( + default_factory=list, + description="错误信息列表" + ) + warnings: List[str] = Field( + default_factory=list, + description="警告信息列表" + ) + + +class ConfigModel(BaseModel): + """ + 配置模型,对应 pyproject.toml 中的 [tool.llm-codegen] 部分。 + """ + check_tools: List[str] = Field( + default=["pytest", "pylint", "mypy", "black"], + description="要运行的检查工具列表" + ) + max_retries: int = Field( + default=3, + description="修复的最大重试次数" + ) + dangerous_commands: List[str] = Field( + default=["rm", "sudo", "chmod", "dd"], + description="危险命令列表" + ) diff --git a/src/llm_codegen/utils.py b/src/llm_codegen/utils.py new file mode 100644 index 0000000..43f1345 --- /dev/null +++ b/src/llm_codegen/utils.py @@ -0,0 +1,130 @@ +""" +utils.py - 工具函数模块 +包含危险命令判断、文件操作、状态管理等通用函数。 +""" + +import json +import os +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + + +def is_dangerous_command(cmd: str, dangerous_commands: Optional[List[str]] = None) -> Tuple[bool, str]: + """ + 判断命令是否危险。 + + Args: + cmd: 要检查的命令字符串。 + dangerous_commands: 危险命令关键词列表,如果为None则使用默认列表。 + + Returns: + Tuple[bool, str]: (是否危险, 原因) + """ + if dangerous_commands is None: + # 默认危险命令列表,可以从配置读取 + dangerous_commands = ["rm", "sudo", "chmod", "dd", "mkfs", "> /dev/sda", "format"] + + cmd_lower = cmd.lower() + for danger in dangerous_commands: + if danger in cmd_lower: + return True, f"包含危险关键词 '{danger}'" + return False, "" + + +def load_dangerous_commands(config_path: Optional[Path] = None) -> List[str]: + """ + 从配置文件加载危险命令列表(简化实现,实际应从pyproject.toml读取)。 + + Args: + config_path: 配置文件路径,默认为None,表示使用默认列表。 + + Returns: + List[str]: 危险命令关键词列表。 + """ + # 在实际实现中,应使用tomli库解析配置,这里返回默认列表 + return ["rm", "sudo", "chmod", "dd", "mkfs", "> /dev/sda", "format"] + + +def safe_read_file(file_path: Path, encoding: str = "utf-8") -> str: + """ + 安全读取文件内容。 + + Args: + file_path: 文件路径。 + encoding: 文件编码,默认为utf-8。 + + Returns: + str: 文件内容。 + + Raises: + FileNotFoundError: 如果文件不存在。 + IOError: 如果读取失败。 + """ + try: + with open(file_path, 'r', encoding=encoding) as f: + return f.read() + except FileNotFoundError: + raise FileNotFoundError(f"文件不存在: {file_path}") + except Exception as e: + raise IOError(f"读取文件失败 {file_path}: {e}") + + +def safe_write_file(file_path: Path, content: str, encoding: str = "utf-8") -> None: + """ + 安全写入文件内容,确保目录存在。 + + Args: + file_path: 文件路径。 + content: 要写入的内容。 + encoding: 文件编码,默认为utf-8。 + """ + file_path.parent.mkdir(parents=True, exist_ok=True) + with open(file_path, 'w', encoding=encoding) as f: + f.write(content) + + +def load_state(state_path: Path) -> Dict[str, Any]: + """ + 加载断点续写状态文件。 + + Args: + state_path: 状态文件路径(如.llm_generator_state.json)。 + + Returns: + Dict[str, Any]: 状态数据,如果文件不存在或解析失败则返回空字典。 + """ + if not state_path.exists(): + return {} + try: + with open(state_path, 'r', encoding='utf-8') as f: + return json.load(f) + except (json.JSONDecodeError, IOError): + return {} + + +def save_state(state_path: Path, state: Dict[str, Any]) -> None: + """ + 保存断点续写状态文件。 + + Args: + state_path: 状态文件路径。 + state: 状态数据。 + """ + with open(state_path, 'w', encoding='utf-8') as f: + json.dump(state, f, indent=2, ensure_ascii=False) + + +def normalize_path(path: str, base_dir: Optional[Path] = None) -> Path: + """ + 规范化路径,相对于基础目录。 + + Args: + path: 路径字符串。 + base_dir: 基础目录,默认为None(使用当前工作目录)。 + + Returns: + Path: 规范化后的Path对象。 + """ + if base_dir is None: + base_dir = Path.cwd() + return (base_dir / path).resolve() diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..fdb5620 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,6 @@ +""" +Initialization file for the tests package of the LLM code generation tool. + +This file marks the 'tests' directory as a Python package, enabling proper +imports and test discovery with pytest. +""" \ No newline at end of file diff --git a/tests/test_core.py b/tests/test_core.py new file mode 100644 index 0000000..6ad9f7a --- /dev/null +++ b/tests/test_core.py @@ -0,0 +1,184 @@ +import pytest +from unittest.mock import Mock, patch, MagicMock +from pathlib import Path +import json +import os +import sys +from datetime import datetime + +from src.llm_codegen.core import CodeGenerator +from src.llm_codegen.models import GeneratorState, ConfigModel +from src.llm_codegen.utils import is_dangerous_command + + +class TestCodeGenerator: + """测试 CodeGenerator 核心类的单元测试。""" + + @pytest.fixture + def mock_openai_client(self): + """模拟 OpenAI 客户端。""" + with patch('src.llm_codegen.core.OpenAI') as mock: + client = Mock() + mock.return_value = client + yield client + + @pytest.fixture + def generator(self, mock_openai_client, tmp_path): + """创建 CodeGenerator 实例,使用临时目录和模拟 API。""" + output_dir = tmp_path / "output" + return CodeGenerator( + api_key="test-api-key", + base_url="https://api.deepseek.com", + model="deepseek-reasoner", + output_dir=str(output_dir), + resume=False, + ) + + def test_init(self, generator, tmp_path): + """测试初始化。""" + assert generator.api_key == "test-api-key" + assert generator.model == "deepseek-reasoner" + assert generator.output_dir == tmp_path / "output" + assert generator.resume is False + assert isinstance(generator.config, ConfigModel) + assert generator.dangerous_commands == ["rm", "sudo", "chmod", "dd", "mkfs", "> /dev/sda", "format"] + assert generator.state is None + + def test_parse_readme(self, generator, tmp_path): + """测试读取 README 文件。""" + readme_path = tmp_path / "README.md" + readme_content = "# Test Project\nThis is a test README." + readme_path.write_text(readme_content) + + result = generator.parse_readme(readme_path) + assert result == readme_content + + def test_get_project_structure(self, generator, mock_openai_client): + """测试获取项目结构,模拟 LLM 响应。""" + generator.readme_content = "# Test README" + mock_response = { + "files": ["src/__init__.py", "src/core.py"], + "dependencies": {"src/core.py": ["src/__init__.py"]} + } + mock_openai_client.chat.completions.create.return_value.choices[0].message.content = json.dumps(mock_response) + + files, dependencies = generator.get_project_structure() + assert files == ["src/__init__.py", "src/core.py"] + assert dependencies == {"src/core.py": ["src/__init__.py"]} + mock_openai_client.chat.completions.create.assert_called_once() + + def test_generate_file(self, generator, mock_openai_client, tmp_path): + """测试生成单个文件,模拟依赖文件和 LLM 响应。""" + generator.readme_content = "# Test README" + dep_file = tmp_path / "dep.txt" + dep_file.write_text("Dependency content") + + mock_response = { + "code": "print('Hello, World!')", + "description": "测试文件生成", + "commands": ["echo 'test'"] + } + mock_openai_client.chat.completions.create.return_value.choices[0].message.content = json.dumps(mock_response) + + code, desc, commands = generator.generate_file( + "test.py", + "生成测试文件", + [str(dep_file)] + ) + assert code == "print('Hello, World!')" + assert desc == "测试文件生成" + assert commands == ["echo 'test'"] + mock_openai_client.chat.completions.create.assert_called_once() + + def test_execute_command_safe(self, generator, tmp_path): + """测试执行安全命令。""" + with patch('subprocess.run') as mock_run: + mock_run.return_value.returncode = 0 + mock_run.return_value.stdout = "output" + mock_run.return_value.stderr = "" + + generator.execute_command("echo 'test'", cwd=tmp_path) + mock_run.assert_called_once_with( + "echo 'test'", + shell=True, + cwd=tmp_path, + capture_output=True, + text=True, + timeout=300 + ) + + def test_execute_command_dangerous(self, generator): + """测试阻止危险命令。""" + with pytest.raises(RuntimeError, match="危险命令"): + generator.execute_command("rm -rf /") # 假设在危险命令列表中 + + def test_run_without_resume(self, generator, mock_openai_client, tmp_path): + """测试完整运行流程,禁用断点续写。""" + readme_path = tmp_path / "README.md" + readme_path.write_text("# Test README") + generator.readme_content = "# Test README" + + # 模拟 get_project_structure 响应 + mock_structure = { + "files": ["file1.py", "file2.py"], + "dependencies": {} + } + mock_openai_client.chat.completions.create.side_effect = [ + Mock(choices=[Mock(message=Mock(content=json.dumps(mock_structure)))]), + Mock(choices=[Mock(message=Mock(content=json.dumps({"code": "code1", "description": "desc1", "commands": []})))]), + Mock(choices=[Mock(message=Mock(content=json.dumps({"code": "code2", "description": "desc2", "commands": []})))]) + ] + + with patch('src.llm_codegen.core.safe_write_file') as mock_write, \ + patch('src.llm_codegen.core.safe_read_file') as mock_read, \ + patch('src.llm_codegen.core.save_state') as mock_save: + mock_read.return_value = "content" + generator.run(readme_path) + + # 验证文件生成和状态保存 + assert mock_write.call_count == 2 + assert mock_save.called + + def test_run_with_resume(self, generator, mock_openai_client, tmp_path): + """测试断点续写功能。""" + generator.resume = True + generator.state = GeneratorState(generated_files=["file1.py"], executed_commands=[]) + readme_path = tmp_path / "README.md" + readme_path.write_text("# Test README") + generator.readme_content = "# Test README" + + mock_structure = { + "files": ["file1.py", "file2.py"], + "dependencies": {} + } + mock_openai_client.chat.completions.create.return_value.choices[0].message.content = json.dumps(mock_structure) + + with patch('src.llm_codegen.core.safe_write_file') as mock_write, \ + patch('src.llm_codegen.core.safe_read_file') as mock_read: + mock_read.return_value = "content" + generator.run(readme_path) + + # 只应生成 file2.py,跳过 file1.py + assert mock_write.call_count == 1 + + def test_load_config_default(self, generator): + """测试加载默认配置。""" + config = generator._load_config() + assert isinstance(config, ConfigModel) + assert config.check_tools == ["pytest", "pylint", "mypy", "black"] + assert config.max_retries == 3 + + def test_update_state(self, generator, tmp_path): + """测试更新状态文件。""" + generator.state_path = tmp_path / "state.json" + generator.state = GeneratorState() + + with patch('src.llm_codegen.core.save_state') as mock_save: + generator._update_state("new_file.py", ["cmd1"]) + assert generator.state.generated_files == ["new_file.py"] + assert generator.state.executed_commands == ["cmd1"] + mock_save.assert_called_once() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])