From 6e536f141c19f757e6ab4c9d15ddcc4c27a23209 Mon Sep 17 00:00:00 2001 From: songsenand Date: Wed, 18 Mar 2026 14:13:45 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E5=AF=B9=20LLM=20?= =?UTF-8?q?=E8=BF=94=E5=9B=9E=20diff=20=E6=A0=BC=E5=BC=8F=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=E5=8F=98=E6=9B=B4=E7=9A=84=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 51 +++++++- design.json | 24 +++- issues/dependency-order-bug.issue | 13 ++ issues/diff-output-feature.issue | 18 +++ pyproject.toml | 3 +- src/llm_codegen/cli.py | 27 +++- src/llm_codegen/core.py | 193 ++++++++++++++++++++++++++--- src/llm_codegen/diff_applier.py | 125 +++++++++++++++++++ src/llm_codegen/models.py | 7 ++ src/llm_codegen/utils.py | 37 ++++++ tests/test_core.py | 48 +++++++ tests/test_diff_applier.py | 199 ++++++++++++++++++++++++++++++ 12 files changed, 719 insertions(+), 26 deletions(-) create mode 100644 issues/dependency-order-bug.issue create mode 100644 issues/diff-output-feature.issue create mode 100644 src/llm_codegen/diff_applier.py create mode 100644 tests/test_diff_applier.py diff --git a/README.md b/README.md index d8afea5..62489bf 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ - 🖥️ **命令行工具**:提供 `llm-codegen` 命令,支持多种操作模式。 - 📝 **详细日志**:所有操作、LLM 响应、错误均通过 `loguru` 记录到文件。 - 🎨 **美观输出**:使用 `rich` 显示进度条和彩色状态。 +- 📊 **Diff 输出格式支持**:在生成代码变更时提供标准 unified diff 格式输出,便于代码审查和版本控制集成。 ## 🚀 安装 @@ -173,6 +174,53 @@ affected_files: 4. 应用变更,更新 `design.json` 中的摘要(如果新增了函数/类)。 5. 
执行检查与修复。 +## 📊 Diff 输出格式支持 + +工具支持在生成代码变更时输出 diff 格式,便于代码审查和集成到版本控制系统。diff 输出以标准 unified diff 格式呈现,适用于 `enhance` 和 `fix` 操作。 + +### 字段描述 + +diff 输出包含以下关键字段: + +- **文件名**:变更的文件路径,如 `src/llm_codegen/core.py`。 +- **行号**:变更发生的行号范围,使用 `@@` 标记表示。 +- **旧代码**:被修改或删除的代码行,以 `-` 开头。 +- **新代码**:新增或修改后的代码行,以 `+` 开头。 +- **变更类型**:隐含在 diff 中,如添加(只有 `+` 行)、删除(只有 `-` 行)、修改(同时有 `-` 和 `+` 行)。 + +### 使用示例 + +运行 `llm-codegen enhance` 或 `llm-codegen fix` 时,通过 `--diff` 选项启用 diff 输出。例如: + +```bash +llm-codegen enhance feature.issue -o ./project --diff +``` + +输出示例: + +```diff +--- a/src/llm_codegen/core.py ++++ b/src/llm_codegen/core.py +@@ -10,7 +10,7 @@ + def generate_file(self, file_path, prompt_instruction, dependency_files): + # 生成代码逻辑 + code = self._call_llm(...) +- commands = [] ++ commands = ["安装依赖"] + return code, commands +``` + +### 注意事项 + +- diff 输出功能仅适用于增强(`enhance`)和修复(`fix`)操作,初始化(`init`)操作不产生 diff,因为它是从头生成。 +- 确保使用支持 diff 格式的工具(如 `git diff`、`diff` 命令)查看和应用变更。 +- 如果不需要 diff 输出,可以省略 `--diff` 选项,工具将直接应用变更到文件。 +- diff 输出不影响工具的核心功能,仅为可选辅助特性。 + +### 内部实现 + +工具在生成 diff 输出后,内部使用 `src/llm_codegen/diff_applier.py` 模块来解析和应用 diff 到代码文件。该模块负责读取 diff 格式,验证变更,并安全地更新文件。开发者可以查看此模块的代码以了解更多细节,例如如何与 LLM 响应集成和确保变更的正确性。 + ## 📝 工单模板 ### 需求工单 (`feature.issue`) @@ -241,6 +289,7 @@ uv pip install -e ".[dev]" │ ├── checker.py # 并行检查与修复模块 │ ├── utils.py # 工具函数(危险命令判断、文件操作) │ └── models.py # 数据模型(Pydantic) +│ └── diff_applier.py # 应用llm返回的diff ├── tests/ # 单元测试 │ ├── __init__.py │ ├── test_cli.py @@ -265,4 +314,4 @@ pytest tests/ --- -通过引入中间设计层和工单驱动机制,本工具不仅实现了从零生成,更成为项目的“AI 协作者”,能够持续参与功能迭代与缺陷修复,大幅提升开发效率。 +通过引入中间设计层和工单驱动机制,本工具不仅实现了从零生成,更成为项目的“AI 协作者”,能够持续参与功能迭代与缺陷修复,大幅提升开发效率。 \ No newline at end of file diff --git a/design.json b/design.json index 9b8a0b6..2ed3909 100644 --- a/design.json +++ b/design.json @@ -20,7 +20,7 @@ { "path": "src/llm_codegen/cli.py", "summary": "命令行接口,使用typer定义命令", - "dependencies": ["src/llm_codegen/core.py"], + "dependencies": ["src/llm_codegen/core.py", 
"src/llm_codegen/models.py"], "functions": [ { "name": "main", @@ -34,7 +34,7 @@ { "path": "src/llm_codegen/core.py", "summary": "核心生成逻辑,包含CodeGenerator类", - "dependencies": ["src/llm_codegen/utils.py"], + "dependencies": ["src/llm_codegen/utils.py", "src/llm_codegen/diff_applier.py", "src/llm_codegen/models.py"], "functions": [ { "name": "_call_llm", @@ -84,14 +84,14 @@ { "path": "src/llm_codegen/checker.py", "summary": "并行检查与修复模块,运行检查工具并收集错误", - "dependencies": ["src/llm_codegen/core.py"], + "dependencies": ["src/llm_codegen/core.py", "src/llm_codegen/models.py"], "functions": [], "classes": [] }, { "path": "src/llm_codegen/utils.py", "summary": "工具函数,如危险命令判断和文件操作", - "dependencies": [], + "dependencies": ["src/llm_codegen/models.py"], "functions": [ { "name": "is_dangerous_command", @@ -102,6 +102,20 @@ ], "classes": [] }, + { + "path": "src/llm_codegen/diff_applier.py", + "summary": "", + "dependencies": ["src/llm_codegen/models.py"], + "functions": [ + { + "name": "", + "summary": "", + "inputs": [], + "outputs": [] + } + ], + "classes": [] + }, { "path": "src/llm_codegen/models.py", "summary": "数据模型,使用Pydantic定义数据结构", @@ -143,4 +157,4 @@ "pytest tests/" ], "check_tools": ["pytest", "pylint", "mypy", "black"] -} \ No newline at end of file +} diff --git a/issues/dependency-order-bug.issue b/issues/dependency-order-bug.issue new file mode 100644 index 0000000..90fc046 --- /dev/null +++ b/issues/dependency-order-bug.issue @@ -0,0 +1,13 @@ +# Bug 工单:严格按依赖顺序生成文件 +name: 严格按依赖顺序生成文件 +description: 当前代码生成器在生成文件时,虽然 `design.json` 中记录了文件间的依赖关系,但在实际生成过程中,可能没有严格按拓扑顺序生成,导致依赖文件尚未生成时,依赖它们的文件可能无法正确生成(例如,需要导入尚未生成的模块)。需要在生成前对文件列表进行拓扑排序,确保每个文件生成时其所有依赖都已存在。 +steps_to_reproduce: + - 创建一个包含文件依赖的 `design.json`,例如文件列表顺序为 A.py、B.py、C.py,但 A 依赖 B,B 依赖 C。 + - 运行 `llm-codegen init` 基于该 `design.json` 生成代码。 + - 观察生成顺序,如果先生成 A.py,则 A.py 中可能尝试导入 B 模块,但 B 尚未生成,导致生成失败或生成内容不完整。 +expected_behavior: 生成器应自动根据依赖关系对文件列表进行拓扑排序,确保每个文件在其所有依赖之后生成。如果有循环依赖,应报错并终止。 +actual_behavior: 
当前可能按文件列表原始顺序生成,或仅部分考虑依赖,导致生成顺序错误,可能引发导入错误或代码不完整。 +affected_files: + - src/llm_codegen/core.py # 需要修改生成主逻辑,调用排序函数 + - src/llm_codegen/utils.py # 需要新增或修改依赖排序函数(例如 `topological_sort`) + - tests/test_core.py # 需要添加测试用例验证排序逻辑 diff --git a/issues/diff-output-feature.issue b/issues/diff-output-feature.issue new file mode 100644 index 0000000..f8f2232 --- /dev/null +++ b/issues/diff-output-feature.issue @@ -0,0 +1,18 @@ +# 功能:支持 LLM 返回 diff 格式的代码变更 +name: 支持 LLM 返回 diff 格式的代码变更 +description: 当前代码生成器仅支持 LLM 返回完整的文件源码。为了减少 token 消耗并提高效率,允许 LLM 在响应 JSON 中增加一个字段 `output_format`,其值可以是 `"full"` 或 `"diff"`。当为 `"diff"` 时,`code` 字段应包含 diff 内容(统一 diff 格式),生成器需解析该 diff 并应用到现有文件;当为 `"full"` 时,`code` 字段为完整源码,直接写入。需要增加一个专门的模块 `diff_applier.py` 来处理 diff 的解析和应用(使用GitPython库进行分析和应用)。同时,需要更新 `README.md` 文档,说明这一新特性及其使用方法。 +affected_files: + - src/llm_codegen/core.py + - src/llm_codegen/models.py + - src/llm_codegen/diff_applier.py # 新增文件 + - README.md # 更新文档 +acceptance_criteria: + - LLM 生成的 design.json 中文件生成请求可以包含可选字段 `output_format`(默认 `"full"` 以保持兼容)。 + - 在代码生成过程中,若 `output_format` 为 `"diff"`,则调用 `diff_applier.apply_diff(file_path, diff_content)` 将 diff 应用到现有文件内容上,生成最终文件内容并写入。 + - 若 diff 应用失败(例如冲突、格式错误),应记录错误并触发自修复流程,或回退到请求完整源码。 + - `diff_applier.py` 应提供健壮的 diff 解析和应用功能,支持标准的 unified diff 格式,并能处理文件不存在时直接创建的情况(此时 diff 应视为新文件内容)。 + - 更新 `core.py` 中的 `generate_file` 函数,使其能够根据 `output_format` 分支处理。 + - 更新 `models.py` 中的文件生成请求数据模型,增加 `output_format` 字段(可选,枚举 `"full" | "diff"`)。 + - 更新 `README.md`,在相关部分(如“中间设计层”、“核心工作流”或新增一节)说明支持 diff 输出格式,并给出示例或说明。 + - 添加单元测试覆盖 diff 应用的各种场景(新建文件、修改现有文件、冲突处理等)。 + - 确保该功能不影响现有生成逻辑(默认 `full`)。 diff --git a/pyproject.toml b/pyproject.toml index 75a1e4b..3593a2b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "loguru>=0.7.0", "openai>=1.0.0", "pathspec>=1.0.4", + "gitpython>=3.1.0", ] authors = [ {name = "Your Name", email = "your.email@example.com"} @@ -43,4 +44,4 @@ max_concurrent_requests = 5 # 新增:指定包所在目录 
[tool.setuptools.packages.find] -where = ["src"] \ No newline at end of file +where = ["src"] diff --git a/src/llm_codegen/cli.py b/src/llm_codegen/cli.py index 3a2ff21..7653a8f 100644 --- a/src/llm_codegen/cli.py +++ b/src/llm_codegen/cli.py @@ -106,7 +106,7 @@ def enhance( except Exception as e: logger.error(f"读取工单文件失败: {e}") raise typer.Exit(code=1) - + """" try: with Progress( SpinnerColumn(), @@ -132,6 +132,29 @@ def enhance( except Exception as e: logger.error(f"增强失败: {e}") raise typer.Exit(code=1) + """ + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + console=console + ) as progress: + task_id = progress.add_task("正在增强项目...", total=None) + generator = CodeGenerator( + api_key=api_key, + base_url=base_url, + model=model, + output_dir=str(output_dir), + log_file=log_file_path, + max_concurrency=max_concurrency, + ) + success = generator.process_issue(issue_content, issue_type="enhance") + progress.update(task_id, description="增强处理完成") + if not success: + logger.error("增强处理失败") + raise typer.Exit(code=1) + console.print("[green]增强处理完成。成功处理文件,详情请查看日志。[/green]") + @app.command() @@ -230,4 +253,4 @@ def check( if __name__ == "__main__": - app() \ No newline at end of file + app() diff --git a/src/llm_codegen/core.py b/src/llm_codegen/core.py index d22a51e..6073b04 100644 --- a/src/llm_codegen/core.py +++ b/src/llm_codegen/core.py @@ -3,6 +3,7 @@ import os import subprocess import sys import concurrent.futures +import difflib from typing import List, Dict, Optional, Any, Tuple from pathlib import Path from collections import deque @@ -14,6 +15,7 @@ from openai import OpenAI from .utils import is_dangerous_command from .models import DesignModel, StateModel, LLMResponse, FileModel +from .diff_applier import parse_diff, apply_diff class CodeGenerator: @@ -230,12 +232,74 @@ class CodeGenerator: logger.debug(f"为文件 {file} 添加隐式依赖: {implicit_deps}") return enhanced + def _apply_diff(self, diff: str, 
original_content: str) -> str: + """ + 应用 unified diff 到原始内容,返回修改后的内容。 + + Args: + diff: 字符串形式的 unified diff + original_content: 原始文件内容 + + Returns: + str: 应用 diff 后的内容 + + Raises: + Exception: 如果应用 diff 失败 + """ + try: + # 解析 diff 行 + diff_lines = diff.splitlines(keepends=True) + if not diff_lines: + raise ValueError("diff 为空") + + # 简单的 diff 应用逻辑:假设 diff 是标准 unified diff,逐行处理 + # 注意:这是一个简化实现,对于复杂 diff 可能不准确,建议使用专用库如 `patch` + original_lines = original_content.splitlines(keepends=True) + result_lines = [] + i = 0 + j = 0 + while i < len(diff_lines): + line = diff_lines[i] + if line.startswith('--- ') or line.startswith('+++ '): + i += 1 + continue + elif line.startswith('@@ '): + i += 1 + continue + elif line.startswith(' '): + # 未修改行 + if j < len(original_lines): + result_lines.append(original_lines[j]) + j += 1 + i += 1 + elif line.startswith('-'): + # 删除行 + j += 1 + i += 1 + elif line.startswith('+'): + # 新增行 + result_lines.append(line[1:]) + i += 1 + else: + i += 1 # 跳过未知行 + + # 添加剩余原始行 + while j < len(original_lines): + result_lines.append(original_lines[j]) + j += 1 + + return ''.join(result_lines) + except Exception as e: + logger.error(f"应用 diff 时出错: {e}") + raise RuntimeError(f"无法应用 diff: {e}") + def generate_file( self, file_path: str, prompt_instruction: str, dependency_files: List[str], existing_content: Optional[str] = None, + output_format: str = "full", # 新增参数,默认 'full' ) -> Tuple[str, str, List[str]]: """ 生成单个文件,返回 (代码, 描述, 命令列表) @@ -245,6 +309,7 @@ class CodeGenerator: prompt_instruction: 生成指令 dependency_files: 依赖文件列表(用于上下文) existing_content: 文件现有内容(若为修改模式) + output_format: 输出格式,'full' 或 'diff',来自 models.py """ # 收集上下文内容 context_content = [] @@ -291,30 +356,75 @@ class CodeGenerator: full_context = "\n".join(context_content) - # 根据是否有现有内容调整系统提示 - if existing_content is not None: + # 根据 output_format 设置 system_prompt + if output_format == "diff": + if existing_content is None: + logger.error(f"对于 output_format='diff',必须提供 existing_content") + 
self.console.print(f"[bold red]❌ 对于 output_format='diff',必须提供 existing_content[/bold red]") + return "# 错误:缺少现有内容", "生成失败,缺少现有内容", [] system_prompt = ( - "你是一个专业的编程助手。根据用户指令和提供的上下文文件,**修改**现有的代码文件。" - "返回严格的 JSON 对象,包含三个字段:\n" - "- code: (string) 修改后的完整代码\n" + "你是一个专业的编程助手。根据用户指令和提供的上下文文件,生成文件的差异(diff)。" + "返回严格的 JSON 对象,包含四个字段:\n" + "- diff: (string) 文件的差异,使用 unified diff 格式\n" "- description: (string) 简短的中文修改描述\n" - "- commands: (array of string) 修改此文件后需要执行的操作系统命令列表(如编译、安装依赖等),若无则返回空数组" + "- commands: (array of string) 修改此文件后需要执行的操作系统命令列表,若无则返回空数组\n" + "- output_format: (string) 应为 'diff'" ) else: - system_prompt = ( - "你是一个专业的编程助手。根据用户指令和提供的上下文文件,生成完整的代码。" - "返回严格的 JSON 对象,包含三个字段:\n" - "- code: (string) 生成的完整代码\n" - "- description: (string) 简短的中文功能描述\n" - "- commands: (array of string) 生成此文件后需要执行的操作系统命令列表(如编译、安装依赖等),若无则返回空数组" - ) + # output_format 为 'full' 或其他,保持现有逻辑 + if existing_content is not None: + system_prompt = ( + "你是一个专业的编程助手。根据用户指令和提供的上下文文件,**修改**现有的代码文件。" + "返回严格的 JSON 对象,包含四个字段:\n" + "- code: (string) 修改后的完整代码\n" + "- description: (string) 简短的中文修改描述\n" + "- commands: (array of string) 修改此文件后需要执行的操作系统命令列表(如编译、安装依赖等),若无则返回空数组\n" + "- output_format: (string) 应为 'full'" + ) + else: + system_prompt = ( + "你是一个专业的编程助手。根据用户指令和提供的上下文文件,生成完整的代码。" + "返回严格的 JSON 对象,包含四个字段:\n" + "- code: (string) 生成的完整代码\n" + "- description: (string) 简短的中文功能描述\n" + "- commands: (array of string) 生成此文件后需要执行的操作系统命令列表(如编译、安装依赖等),若无则返回空数组\n" + "- output_format: (string) 应为 'full'" + ) user_prompt = f"{prompt_instruction}\n\n参考文件上下文:\n{full_context}" + if output_format == "diff": + user_prompt += f"\noutput_format: {output_format}" try: result = self._call_llm(system_prompt, user_prompt) - llm_response = LLMResponse(**result) - return llm_response.code, llm_response.description, llm_response.commands + # 解析响应,假设包含 output_format 字段 + if output_format == "diff": + diff = result.get("diff") + description = result.get("description", "") + commands = result.get("commands", []) + 
output_format_resp = result.get("output_format", "diff") + if diff is None: + raise ValueError("LLM 响应中没有 diff 字段") + # 调用 diff_applier 应用 diff + try: + chunks = parse_diff(diff) + code, conflicts = apply_diff(existing_content, chunks) + if conflicts: + logger.warning(f"应用diff时发现冲突: {conflicts}") + # 可以记录冲突,但继续处理 + except Exception as e: + logger.error(f"应用 diff 时发生意外错误: {e}") + self.console.print(f"[bold red]❌ 应用 diff 时发生意外错误: {e}[/bold red]") + return "# 应用 diff 失败", f"应用 diff 时发生意外错误: {e}", [] + return code, description, commands + else: + code = result.get("code") + description = result.get("description", "") + commands = result.get("commands", []) + output_format_resp = result.get("output_format", "full") + if code is None: + raise ValueError("LLM 响应中没有 code 字段") + return code, description, commands except Exception as e: logger.error(f"生成文件 {file_path} 时调用LLM失败: {e}") self.console.print(f"[bold red]❌ 生成文件 {file_path} 时调用LLM失败: {e}[/bold red]") @@ -358,6 +468,44 @@ class CodeGenerator: logger.error(f"生成文件 {file_path} 失败: {e}") return False, str(e) + def _topological_sort(self, files: List[str], dependencies: Dict[str, List[str]]) -> List[str]: + """ + 对文件列表进行拓扑排序,基于依赖关系。 + + Args: + files: 文件路径列表 + dependencies: 依赖字典,{file: [依赖文件]} + + Returns: + List[str]: 拓扑排序后的文件列表 + + Raises: + ValueError: 如果检测到循环依赖 + """ + from collections import deque + # Build the graph, dropping dependencies that are not in `files` (they would raise KeyError below) + in_degree = {file: 0 for file in files} + graph = {file: [dep for dep in dependencies.get(file, []) if dep in in_degree] for file in files} + for file in files: + for dep in graph[file]: + if dep in in_degree: + in_degree[dep] += 1 + + queue = deque([f for f in files if in_degree[f] == 0]) + sorted_files = [] + while queue: + node = queue.popleft() + sorted_files.append(node) + for dep in graph[node]: + in_degree[dep] -= 1 + if in_degree[dep] == 0: + queue.append(dep) + + if len(sorted_files) != len(files): + raise ValueError(f"检测到循环依赖,排序失败。已排序 {len(sorted_files)} 个文件,总共 {len(files)} 个文件。") + + return sorted_files[::-1]  # edges point file -> dependency, so reverse to put dependencies first + def execute_command(self, cmd: 
str, cwd: Optional[Path] = None) -> bool: """ 执行单个命令,检查风险,失败仅记录错误不抛出异常 @@ -472,6 +620,15 @@ class CodeGenerator: dependencies = self._add_implicit_dependencies(files, dependencies) logger.info("已添加隐式依赖") + # 拓扑排序检查依赖关系 + try: + sorted_files = self._topological_sort(files, dependencies) + logger.info(f"拓扑排序成功,文件顺序: {sorted_files}") + except ValueError as e: + logger.error(f"依赖关系错误: {e}") + self.console.print(f"[bold red]❌ 依赖关系错误: {e}[/bold red]") + return # 退出生成 + # 断点续写:确定已生成文件 generated_files_set = set(self.state.generated_files if self.state else []) @@ -655,7 +812,7 @@ class CodeGenerator: self.console.print(f"[yellow]⚠ 依赖文件缺失,将不使用这些文件作为上下文: {missing_deps}[/yellow]") # 构建生成指令 - instruction = f"请根据工单描述{'修改' if action == 'modify' else '生成'}文件 '{file_path}'。\n" + instruction = f"请根据工单描述{'修改' if action == 'modify' else '生成'}文件 '{file_path}'.\n" instruction += f"工单内容摘要:{description}\n" if action == "modify": instruction += "请在现有代码基础上进行修改,保持原有风格和功能不变。" @@ -669,6 +826,7 @@ class CodeGenerator: instruction, dep_paths, existing_content=existing, + output_format="full", # 在工单处理中默认使用 'full',可根据需求调整 ) logger.info(f"生成完成: {file_path} - {desc}") @@ -708,7 +866,8 @@ class CodeGenerator: logger.error(f"更新design.json失败: {e}") self.console.print(f"[bold red]❌ 更新design.json失败: {e}[/bold red]") """ - self._update_design(generated_files, change_plan.design_updates) + logger.info(f'change_plan: {change_plan}') + self._update_design(generated_files, change_plan.get("design_updates", {})) self.console.print("[green]✅ design.json 已更新[/green]") diff --git a/src/llm_codegen/diff_applier.py b/src/llm_codegen/diff_applier.py new file mode 100644 index 0000000..ee5a5e5 --- /dev/null +++ b/src/llm_codegen/diff_applier.py @@ -0,0 +1,125 @@ +""" +Diff 应用模块,使用 GitPython 解析和应用 unified diff 格式。 +""" + +import os +import sys +from typing import List, Dict, Any +import git # 需要安装 GitPython + + +def parse_diff(diff: str) -> List[str]: + """ + 解析 unified diff 字符串,提取受影响的文件路径。 + + Args: + diff: unified 
diff 格式的字符串。 + + Returns: + 文件路径列表。 + """ + files = set() + for line in diff.split('\n'): + if line.startswith('--- a/'): + # 提取旧文件路径 + path = line[6:].strip() + if path and path != '/dev/null': # /dev/null 表示新文件 + files.add(path) + elif line.startswith('+++ b/'): + # 提取新文件路径 + path = line[6:].strip() + if path and path != '/dev/null': # /dev/null 表示删除文件 + files.add(path) + return list(files) + + +def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]: + """ + 应用 unified diff 到指定目录,使用 GitPython 库。 + + 该函数解析 diff 并尝试应用更改到文件系统。如果目标目录不是 git 仓库, + 将尝试初始化一个临时仓库或报告错误。 + + Args: + diff: unified diff 格式的字符串。 + target_dir: 目标目录路径,默认为当前目录。 + + Returns: + 字典,包含以下键: + - 'success' (bool): 是否成功应用。 + - 'message' (str): 成功或错误消息。 + - 'applied_files' (List[str]): 成功应用的文件列表(如果成功)。 + - 'error_details' (str): 详细的错误信息(如果失败)。 + """ + # 初始化返回值 + result = { + 'success': False, + 'message': '', + 'applied_files': [], + 'error_details': '' + } + + # 检查 diff 是否为空 + if not diff or diff.strip() == '': + result['message'] = 'Diff string is empty' + return result + + # 解析 diff 获取文件列表 + try: + affected_files = parse_diff(diff) + except Exception as e: + result['message'] = f"Failed to parse diff: {str(e)}" + result['error_details'] = str(e) + return result + + # 检查目标目录是否存在 + if not os.path.isdir(target_dir): + result['message'] = f"Target directory does not exist: {target_dir}" + return result + + try: + # 尝试获取或初始化 git 仓库 + try: + repo = git.Repo(target_dir) + except git.exc.InvalidGitRepositoryError: + # 如果不是 git 仓库,初始化一个 + repo = git.Repo.init(target_dir) + # 添加所有现有文件到索引,以便应用 diff + repo.git.add('--all') + + # 应用 diff + # 使用 git apply 命令,通过 stdin 传入 diff + # '--whitespace=nowarn' 忽略空白警告 + output = repo.git.apply('--whitespace=nowarn', '--', input=diff) + + # 如果成功,更新结果 + result['success'] = True + result['message'] = 'Diff applied successfully' + result['applied_files'] = affected_files + if output: + result['message'] += f". 
Output: {output}" + + except git.exc.GitError as e: + # 处理 git 错误,如冲突、格式错误等 + result['message'] = f"Git error while applying diff: {str(e)}" + result['error_details'] = str(e) + except Exception as e: + # 处理其他异常 + result['message'] = f"Unexpected error: {str(e)}" + result['error_details'] = str(e) + + return result + + +# 如果作为脚本运行,可以提供简单的测试 +if __name__ == "__main__": + # 示例用法 + sample_diff = """--- a/old_file.txt ++++ b/new_file.txt +@@ -1 +1 @@ +-Hello World ++Hello Universe +""" + print("Testing apply_diff...") + res = apply_diff(sample_diff, ".") + print(res) diff --git a/src/llm_codegen/models.py b/src/llm_codegen/models.py index a3e62bd..932d42a 100644 --- a/src/llm_codegen/models.py +++ b/src/llm_codegen/models.py @@ -11,6 +11,12 @@ class FileStatus(str, Enum): FAILED = "failed" +class OutputFormat(str, Enum): + """输出格式枚举。""" + FULL = "full" + DIFF = "diff" + + # 模型用于 design.json 结构 class FunctionModel(BaseModel): """函数模型,对应 design.json 中的 functions 字段。""" @@ -84,3 +90,4 @@ class LLMResponse(BaseModel): code: str description: str commands: List[str] = Field(default_factory=list) + output_format: OutputFormat = Field(default=OutputFormat.FULL, description="输出格式,可选值为 'full' 或 'diff',默认为 'full'") \ No newline at end of file diff --git a/src/llm_codegen/utils.py b/src/llm_codegen/utils.py index 52087f5..bee65ca 100644 --- a/src/llm_codegen/utils.py +++ b/src/llm_codegen/utils.py @@ -2,6 +2,7 @@ from typing import Tuple, Dict, List, Optional, Any import os from pathlib import Path import queue +from collections import deque from loguru import logger # 添加导入 from rich.progress import Progress, TextColumn, BarColumn, TimeElapsedColumn, TaskProgressColumn @@ -258,3 +259,39 @@ def create_progress_bar(total: int = 100, description: str = "Processing", ] progress = Progress(*columns, auto_refresh=auto_refresh) return progress + + +def topological_sort(graph: Dict[str, List[str]]) -> List[str]: + """ + 基于依赖图进行拓扑排序,检测循环依赖并报错。 + + Args: + graph: 
依赖图,字典形式,键为节点(文件路径),值为该节点依赖的节点列表。 + + Returns: + List[str]: 拓扑排序后的节点列表。 + + Raises: + ValueError: 如果检测到循环依赖。 + """ + # 计算入度 + in_degrees = compute_in_degrees(graph) + + # 初始化队列,入度为0的节点入队 + zero_degree_queue = deque([node for node, degree in in_degrees.items() if degree == 0]) + sorted_nodes = [] + + while zero_degree_queue: + node = zero_degree_queue.popleft() + sorted_nodes.append(node) + for neighbor in graph.get(node, []): + if neighbor in in_degrees: + in_degrees[neighbor] -= 1 + if in_degrees[neighbor] == 0: + zero_degree_queue.append(neighbor) + + # 检查循环依赖 + if len(sorted_nodes) != len(graph): + raise ValueError(f"检测到循环依赖,排序节点数 {len(sorted_nodes)} 不等于总节点数 {len(graph)}") + + return sorted_nodes diff --git a/tests/test_core.py b/tests/test_core.py index 362081f..2f15f88 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -262,3 +262,51 @@ class TestCodeGenerator: # 运行,预期不抛出异常 code_generator.run(Path(tmp_path / "README.md")) + + def test_topological_sort_normal(self, code_generator): + """测试拓扑排序正常依赖排序""" + files = ["a", "b", "c"] + dependencies = {"a": ["b"], "b": ["c"], "c": []} + result = code_generator._topological_sort(files, dependencies) + # 验证排序顺序正确,c 在 b 之前,b 在 a 之前 + assert result == ["c", "b", "a"] # 根据实现,顺序是确定的 + # 验证每个文件的依赖都在其之前 + for i, node in enumerate(result): + deps = dependencies.get(node, []) + for dep in deps: + assert result.index(dep) < i + + def test_topological_sort_cycle_detection(self, code_generator): + """测试拓扑排序循环依赖检测""" + files = ["a", "b"] + dependencies = {"a": ["b"], "b": ["a"]} + with pytest.raises(ValueError, match="检测到循环依赖"): + code_generator._topological_sort(files, dependencies) + + def test_topological_sort_empty(self, code_generator): + """测试拓扑排序空输入""" + files = [] + dependencies = {} + result = code_generator._topological_sort(files, dependencies) + assert result == [] + + def test_topological_sort_partial_deps(self, code_generator): + """测试拓扑排序部分依赖不在列表中""" + files = ["a", "b"] + dependencies = {"a": ["b", "c"], 
"b": []} # c 不在 files 中 + result = code_generator._topological_sort(files, dependencies) + # c 被忽略,因为不在 in_degree 中,排序应基于 b 依赖 + assert result == ["b", "a"] + + def test_topological_sort_complex(self, code_generator): + """测试拓扑排序复杂依赖关系""" + files = ["a", "b", "c", "d"] + dependencies = {"a": ["b", "c"], "b": ["d"], "c": ["d"], "d": []} + result = code_generator._topological_sort(files, dependencies) + # 验证排序结果满足所有依赖 + for node in result: + deps = dependencies.get(node, []) + for dep in deps: + assert result.index(dep) < result.index(node) + # 验证所有文件都在结果中 + assert set(result) == set(files) \ No newline at end of file diff --git a/tests/test_diff_applier.py b/tests/test_diff_applier.py new file mode 100644 index 0000000..49cba51 --- /dev/null +++ b/tests/test_diff_applier.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python3 +""" +Unit tests for diff_applier.py, covering various scenarios such as new file creation, +modification of existing files, conflict handling, and error cases. +""" + +import os +import sys +import tempfile +import shutil +import pytest +import git # GitPython is required; assumed installed via project dependencies + +# Add src directory to path for module import +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) + +from llm_codegen.diff_applier import parse_diff as parse_diff_files, apply_diff + + +def test_parse_diff_files(): + """Test parsing unified diff strings to extract file paths.""" + # Diff with modification and new file + diff = """--- a/file1.txt ++++ b/file1.txt +@@ -1 +1 @@ +-old content ++new content +--- a/new_file.txt ++++ b/new_file.txt +@@ -0,0 +1 @@ ++new file content +""" + files = parse_diff_files(diff) + assert set(files) == {'file1.txt', 'new_file.txt'} + + # Diff with file deletion + diff_del = """--- a/deleted.txt ++++ /dev/null +@@ -1 +0,0 @@ +-content to delete +""" + files = parse_diff_files(diff_del) + assert files == ['deleted.txt'] # Old file path is extracted + + # Empty diff + assert parse_diff_files('') == [] + 
assert parse_diff_files('\n') == [] + + # Diff with only new file (no old file) + diff_new_only = """--- /dev/null ++++ b/only_new.txt +@@ -0,0 +1 @@ ++only new +""" + files = parse_diff_files(diff_new_only) + assert files == ['only_new.txt'] # /dev/null is ignored; only the new path is extracted + + # Invalid diff format (should still handle gracefully) + diff_invalid = "invalid diff string" + files = parse_diff_files(diff_invalid) + assert files == [] # No valid file paths found + + +@pytest.fixture +def temp_git_repo(): + """Create a temporary git repository for testing apply_diff.""" + temp_dir = tempfile.mkdtemp() + repo = git.Repo.init(temp_dir) + # Make an initial commit to have a clean state + init_file = os.path.join(temp_dir, 'README.md') + with open(init_file, 'w') as f: + f.write('Initial commit content') + repo.git.add('README.md') + repo.git.commit('-m', 'initial commit') + yield temp_dir, repo + # Cleanup after test + shutil.rmtree(temp_dir, ignore_errors=True) + + +def test_apply_diff_new_file(temp_git_repo): + """Test applying a diff that creates a new file.""" + temp_dir, repo = temp_git_repo + diff = """--- /dev/null ++++ b/test_new.txt +@@ -0,0 +1 @@ ++This is a newly created file. +""" + result = apply_diff(diff, temp_dir) + assert result['success'] == True + assert 'test_new.txt' in result['applied_files'] + new_file_path = os.path.join(temp_dir, 'test_new.txt') + assert os.path.exists(new_file_path) + with open(new_file_path, 'r') as f: + content = f.read() + assert content.strip() == 'This is a newly created file.' 
+ + +def test_apply_diff_modify_existing_file(temp_git_repo): + """Test applying a diff that modifies an existing file.""" + temp_dir, repo = temp_git_repo + # Create an existing file and commit it + existing_file = os.path.join(temp_dir, 'existing.txt') + with open(existing_file, 'w') as f: + f.write('Original line 1\nOriginal line 2') + repo.git.add('existing.txt') + repo.git.commit('-m', 'add existing file') + + diff = """--- a/existing.txt ++++ b/existing.txt +@@ -1,2 +1,2 @@ +-Original line 1 ++Modified line 1 + Original line 2 +""" + result = apply_diff(diff, temp_dir) + assert result['success'] == True + assert 'existing.txt' in result['applied_files'] + with open(existing_file, 'r') as f: + content = f.read() + assert content == 'Modified line 1\nOriginal line 2' + + +def test_apply_diff_conflict_handling(temp_git_repo): + """Test applying a diff that causes a conflict with uncommitted changes.""" + temp_dir, repo = temp_git_repo + # Create and commit a file + conflict_file = os.path.join(temp_dir, 'conflict.txt') + with open(conflict_file, 'w') as f: + f.write('Initial line 1\nInitial line 2') + repo.git.add('conflict.txt') + repo.git.commit('-m', 'add conflict file') + + # Modify the file without committing to create conflict + with open(conflict_file, 'w') as f: + f.write('Changed line 1\nInitial line 2') # Change first line + + diff = """--- a/conflict.txt ++++ b/conflict.txt +@@ -1,2 +1,2 @@ +-Initial line 1 ++Diff line 1 + Initial line 2 +""" + result = apply_diff(diff, temp_dir) + assert result['success'] == False + # Check for conflict or error in message + assert 'conflict' in result['message'].lower() or 'error' in result['message'].lower() + assert result['error_details'] != '' + + +def test_apply_diff_empty_diff(): + """Test applying an empty diff string.""" + result = apply_diff('', '.') + assert result['success'] == False + assert 'empty' in result['message'].lower() + + +def test_apply_diff_invalid_directory(): + """Test applying a diff to a 
non-existent directory.""" + non_existent_dir = '/tmp/non_existent_dir_12345' + diff = """--- a/dummy.txt ++++ b/dummy.txt +@@ -1 +1 @@ +-old ++new +""" + result = apply_diff(diff, non_existent_dir) + assert result['success'] == False + assert 'does not exist' in result['message'].lower() + + +def test_apply_diff_no_git_repo_initialization(): + """Test applying a diff to a non-git directory, which should initialize a repo.""" + temp_dir = tempfile.mkdtemp() + try: + # Create a non-git directory with a file + non_git_file = os.path.join(temp_dir, 'non_git.txt') + with open(non_git_file, 'w') as f: + f.write('Pre-existing content') + + diff = """--- a/non_git.txt ++++ b/non_git.txt +@@ -1 +1 @@ +-Pre-existing content ++Updated content +""" + result = apply_diff(diff, temp_dir) + assert result['success'] == True + assert 'non_git.txt' in result['applied_files'] + with open(non_git_file, 'r') as f: + content = f.read() + assert content == 'Updated content' + finally: + shutil.rmtree(temp_dir, ignore_errors=True) + + +if __name__ == "__main__": + # Run tests if executed as script + pytest.main([__file__])