feat: add support for code changes returned by the LLM in diff format

songsenand 2026-03-18 14:13:45 +08:00
parent d69c6ed171
commit 6e536f141c
12 changed files with 719 additions and 26 deletions


@ -17,6 +17,7 @@
- 🖥️ **Command-line tool**: provides the `llm-codegen` command with multiple operation modes.
- 📝 **Detailed logging**: all operations, LLM responses, and errors are written to a log file via `loguru`.
- 🎨 **Polished output**: uses `rich` to display progress bars and colored status.
- 📊 **Diff output format support**: code changes can be emitted in standard unified diff format, which eases code review and version-control integration.
## 🚀 Installation
@ -173,6 +174,53 @@ affected_files:
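4. Apply the changes and update the summaries in `design.json` (if functions or classes were added).
5. Run checks and fixes.
## 📊 Diff output format support
The tool can emit generated code changes in diff format, which makes them easier to review and to integrate with version control systems. The diff is presented in standard unified diff format and is available for the `enhance` and `fix` operations.
### Field descriptions
Diff output contains the following key fields:
- **File name**: path of the changed file, such as `src/llm_codegen/core.py`.
- **Line numbers**: the line range where the change occurs, marked with `@@`.
- **Old code**: modified or deleted lines, prefixed with `-`.
- **New code**: added or modified lines, prefixed with `+`.
- **Change type**: implicit in the diff: addition (only `+` lines), deletion (only `-` lines), or modification (both `-` and `+` lines).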
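The fields above can be read mechanically from a unified diff. The following is a minimal sketch, not the tool's own parser: `HunkRange`, `parse_hunk_header`, and `classify_hunk` are hypothetical names introduced here for illustration.

```python
import re
from typing import List, NamedTuple

# "@@ -old_start[,old_count] +new_start[,new_count] @@"; a missing count means 1
HUNK_HEADER = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@")


class HunkRange(NamedTuple):
    old_start: int
    old_count: int
    new_start: int
    new_count: int


def parse_hunk_header(line: str) -> HunkRange:
    """Extract the line ranges from a `@@ ... @@` hunk header."""
    match = HUNK_HEADER.match(line)
    if match is None:
        raise ValueError(f"not a hunk header: {line!r}")
    old_start, old_count, new_start, new_count = match.groups()
    return HunkRange(
        int(old_start),
        int(old_count) if old_count is not None else 1,
        int(new_start),
        int(new_count) if new_count is not None else 1,
    )


def classify_hunk(body_lines: List[str]) -> str:
    """Derive the implicit change type from the `+` / `-` prefixes of a hunk body."""
    has_add = any(line.startswith("+") for line in body_lines)
    has_del = any(line.startswith("-") for line in body_lines)
    if has_add and has_del:
        return "modify"
    if has_add:
        return "add"
    if has_del:
        return "delete"
    return "unchanged"
```

`classify_hunk` expects only the lines of a hunk body, so the `---` / `+++` file headers must be stripped before calling it.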
### Usage example
When running `llm-codegen enhance` or `llm-codegen fix`, enable diff output with the `--diff` option. For example:
```bash
llm-codegen enhance feature.issue -o ./project --diff
```
Example output:
```diff
--- a/src/llm_codegen/core.py
+++ b/src/llm_codegen/core.py
@@ -10,5 +10,5 @@
 def generate_file(self, file_path, prompt_instruction, dependency_files):
     # code generation logic
     code = self._call_llm(...)
-    commands = []
+    commands = ["install dependencies"]
     return code, commands
```
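For comparison, output of this shape can be produced with Python's standard `difflib` module. This is a sketch under the assumption that both versions of the file are available as strings. `make_unified_diff` is a hypothetical helper, not part of `llm-codegen`.

```python
import difflib


def make_unified_diff(old: str, new: str, path: str) -> str:
    """Return a unified diff between two versions of one file (empty if identical)."""
    return "".join(
        difflib.unified_diff(
            old.splitlines(keepends=True),
            new.splitlines(keepends=True),
            fromfile=f"a/{path}",
            tofile=f"b/{path}",
        )
    )


if __name__ == "__main__":
    before = "code = call()\ncommands = []\nreturn code, commands\n"
    after = 'code = call()\ncommands = ["pip install x"]\nreturn code, commands\n'
    print(make_unified_diff(before, after, "src/llm_codegen/core.py"))
```

Both inputs should end with a newline. Otherwise the last diff line runs into the next one, because `difflib` does not emit the `\ No newline at end of file` marker.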
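### Notes
- Diff output applies only to the enhance (`enhance`) and fix (`fix`) operations; the initialization (`init`) operation produces no diff because it generates files from scratch.
- Use a diff-aware tool (such as `git diff` or the `diff` command) to view and apply the changes.
- If diff output is not needed, omit the `--diff` option and the tool applies the changes directly to the files.
- Diff output is an optional auxiliary feature and does not affect the tool's core functionality.
### Internal implementation
After producing diff output, the tool uses the `src/llm_codegen/diff_applier.py` module internally to parse the diff and apply it to the code files. The module reads the diff format, validates the changes, and updates files safely. Developers can read the module's code for further details, such as how it integrates with LLM responses and how it ensures that changes are correct.
## 📝 Issue templates
### Feature issue (`feature.issue`)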
@ -241,6 +289,7 @@ uv pip install -e ".[dev]"
│   ├── checker.py          # parallel check-and-fix module
│   ├── utils.py            # utility functions (dangerous-command detection, file operations)
│   ├── models.py           # data models (Pydantic)
│   └── diff_applier.py     # applies diffs returned by the LLM
├── tests/                  # unit tests
│   ├── __init__.py
│   ├── test_cli.py


@ -20,7 +20,7 @@
{ {
"path": "src/llm_codegen/cli.py", "path": "src/llm_codegen/cli.py",
"summary": "命令行接口使用typer定义命令", "summary": "命令行接口使用typer定义命令",
"dependencies": ["src/llm_codegen/core.py"], "dependencies": ["src/llm_codegen/core.py", "src/llm_codegen/models.py"],
"functions": [ "functions": [
{ {
"name": "main", "name": "main",
@ -34,7 +34,7 @@
{ {
"path": "src/llm_codegen/core.py", "path": "src/llm_codegen/core.py",
"summary": "核心生成逻辑包含CodeGenerator类", "summary": "核心生成逻辑包含CodeGenerator类",
"dependencies": ["src/llm_codegen/utils.py"], "dependencies": ["src/llm_codegen/utils.py", "src/llm_codegen/diff_applier.py", "src/llm_codegen/models.py"],
"functions": [ "functions": [
{ {
"name": "_call_llm", "name": "_call_llm",
@ -84,14 +84,14 @@
{ {
"path": "src/llm_codegen/checker.py", "path": "src/llm_codegen/checker.py",
"summary": "并行检查与修复模块,运行检查工具并收集错误", "summary": "并行检查与修复模块,运行检查工具并收集错误",
"dependencies": ["src/llm_codegen/core.py"], "dependencies": ["src/llm_codegen/core.py", "src/llm_codegen/models.py"],
"functions": [], "functions": [],
"classes": [] "classes": []
}, },
{ {
"path": "src/llm_codegen/utils.py", "path": "src/llm_codegen/utils.py",
"summary": "工具函数,如危险命令判断和文件操作", "summary": "工具函数,如危险命令判断和文件操作",
"dependencies": [], "dependencies": ["src/llm_codegen/models.py"],
"functions": [ "functions": [
{ {
"name": "is_dangerous_command", "name": "is_dangerous_command",
@ -102,6 +102,20 @@
], ],
"classes": [] "classes": []
}, },
{
"path": "src/llm_codegen/diff_applier.py",
"summary": "",
"dependencies": ["src/llm_codegen/models.py"],
"functions": [
{
"name": "",
"summary": "",
"inputs": [],
"outputs": []
}
],
"classes": []
},
{ {
"path": "src/llm_codegen/models.py", "path": "src/llm_codegen/models.py",
"summary": "数据模型使用Pydantic定义数据结构", "summary": "数据模型使用Pydantic定义数据结构",


@ -0,0 +1,13 @@
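# Bug issue: generate files strictly in dependency order
name: Generate files strictly in dependency order
description: Although `design.json` records the dependencies between files, the code generator may not generate them in strict topological order. A file can therefore be generated before the files it depends on exist, and it may not be generated correctly (for example, it needs to import a module that has not been generated yet). The file list must be topologically sorted before generation so that all of a file's dependencies exist when it is generated.
steps_to_reproduce:
- Create a `design.json` with file dependencies, for example a file list ordered A.py, B.py, C.py in which A depends on B and B depends on C.
- Run `llm-codegen init` to generate code from that `design.json`.
- Observe the generation order; if A.py is generated first, it may try to import module B before B exists, so generation fails or the output is incomplete.
expected_behavior: The generator topologically sorts the file list by dependency so that every file is generated after all of its dependencies. If there is a circular dependency, it reports an error and stops.
actual_behavior: Files may be generated in the original list order, or with dependencies only partly taken into account, which yields a wrong generation order and can cause import errors or incomplete code.
affected_files:
- src/llm_codegen/core.py # the main generation logic must call the sorting function
- src/llm_codegen/utils.py # add or modify the dependency-sorting function (for example `topological_sort`)
- tests/test_core.py # add test cases that verify the sorting logic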
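
The expected behavior can be sketched with the standard library's `graphlib` module (Python 3.9+). This is an illustration under assumptions, not the fix applied in this commit: `generation_order` is a hypothetical name, and the input is assumed to map each file to the list of files it depends on.

```python
from graphlib import CycleError, TopologicalSorter
from typing import Dict, List


def generation_order(dependencies: Dict[str, List[str]]) -> List[str]:
    """Return the files ordered so that each one comes after its dependencies."""
    try:
        # TopologicalSorter expects {node: predecessors}, i.e. {file: its dependencies}
        return list(TopologicalSorter(dependencies).static_order())
    except CycleError as exc:
        # args[1] holds the detected cycle as a list of nodes
        raise ValueError(f"circular dependency detected: {exc.args[1]}") from exc


if __name__ == "__main__":
    deps = {"A.py": ["B.py"], "B.py": ["C.py"], "C.py": []}
    print(generation_order(deps))
```

Unlike the `_topological_sort` method in this commit, `graphlib` also returns nodes that appear only as dependencies, so entries outside the file list would need to be filtered out beforehand.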


@ -0,0 +1,18 @@
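# Feature: support code changes returned by the LLM in diff format
name: Support code changes returned by the LLM in diff format
description: The code generator currently only accepts complete file source from the LLM. To reduce token consumption and improve efficiency, the LLM may add an `output_format` field to its response JSON with the value `"full"` or `"diff"`. When it is `"diff"`, the `code` field contains diff content (unified diff format) that the generator must parse and apply to the existing file; when it is `"full"`, the `code` field is the complete source and is written directly. A dedicated module, `diff_applier.py`, is needed to parse and apply diffs (using the GitPython library for analysis and application). `README.md` must also be updated to describe the new feature and how to use it.
affected_files:
- src/llm_codegen/core.py
- src/llm_codegen/models.py
- src/llm_codegen/diff_applier.py # new file
- README.md # documentation update
acceptance_criteria:
- File generation requests in the LLM-generated design.json may contain an optional `output_format` field (default `"full"` for compatibility).
- During code generation, if `output_format` is `"diff"`, call `diff_applier.apply_diff(file_path, diff_content)` to apply the diff to the existing file content, then write the resulting content.
- If applying the diff fails (for example a conflict or a malformed diff), log the error and trigger the self-repair flow, or fall back to requesting the complete source.
- `diff_applier.py` provides robust diff parsing and application, supports standard unified diff format, and handles a missing target file by creating it (the diff is then treated as the new file content).
- Update the `generate_file` function in `core.py` so that it branches on `output_format`.
- Update the file generation request model in `models.py` with an `output_format` field (optional, enum `"full" | "diff"`).
- Update `README.md` in the relevant sections (such as "Intermediate design layer" or "Core workflow", or in a new section) to describe diff output support, with an example or explanation.
- Add unit tests covering the diff application scenarios (new file, modification of an existing file, conflict handling, and so on).
- Make sure the feature does not affect the existing generation logic (default `full`).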
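
The branching described in the acceptance criteria might look roughly like the sketch below. It makes several assumptions: `resolve_content` is a hypothetical helper, the `apply_diff` callable is a stand-in taking `(existing_content, diff_text)` rather than the real `diff_applier` API, and `OutputFormat` is a local copy of the enum this commit adds to `models.py`.

```python
from enum import Enum
from typing import Callable, Optional


class OutputFormat(str, Enum):
    """Local stand-in for the enum added to models.py."""
    FULL = "full"
    DIFF = "diff"


def resolve_content(
    response: dict,
    existing_content: Optional[str],
    apply_diff: Callable[[str, str], str],  # hypothetical: (existing, diff) -> new content
) -> str:
    """Turn an LLM response into the final file content."""
    # A missing field falls back to "full", which keeps old responses working
    fmt = OutputFormat(response.get("output_format", "full"))
    code = response["code"]
    if fmt is OutputFormat.FULL:
        return code
    if existing_content is None:
        # Per the acceptance criteria: no existing file, so treat the payload as new content
        return code
    return apply_diff(existing_content, code)
```

Passing the applier in as a parameter keeps the branch testable without touching the file system. An unknown `output_format` value raises `ValueError` from the enum constructor.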


@ -14,6 +14,7 @@ dependencies = [
"loguru>=0.7.0", "loguru>=0.7.0",
"openai>=1.0.0", "openai>=1.0.0",
"pathspec>=1.0.4", "pathspec>=1.0.4",
"gitpython>=3.1.0",
] ]
authors = [ authors = [
{name = "Your Name", email = "your.email@example.com"} {name = "Your Name", email = "your.email@example.com"}


@ -106,7 +106,7 @@ def enhance(
except Exception as e: except Exception as e:
logger.error(f"读取工单文件失败: {e}") logger.error(f"读取工单文件失败: {e}")
raise typer.Exit(code=1) raise typer.Exit(code=1)
""""
try: try:
with Progress( with Progress(
SpinnerColumn(), SpinnerColumn(),
@ -132,6 +132,29 @@ def enhance(
except Exception as e: except Exception as e:
logger.error(f"增强失败: {e}") logger.error(f"增强失败: {e}")
raise typer.Exit(code=1) raise typer.Exit(code=1)
"""
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
console=console
) as progress:
task_id = progress.add_task("正在增强项目...", total=None)
generator = CodeGenerator(
api_key=api_key,
base_url=base_url,
model=model,
output_dir=str(output_dir),
log_file=log_file_path,
max_concurrency=max_concurrency,
)
success = generator.process_issue(issue_content, issue_type="enhance")
progress.update(task_id, description="增强处理完成")
if not success:
logger.error("增强处理失败")
raise typer.Exit(code=1)
console.print("[green]增强处理完成。成功处理文件,详情请查看日志。[/green]")
@app.command() @app.command()


@ -3,6 +3,7 @@ import os
import subprocess import subprocess
import sys import sys
import concurrent.futures import concurrent.futures
import difflib
from typing import List, Dict, Optional, Any, Tuple from typing import List, Dict, Optional, Any, Tuple
from pathlib import Path from pathlib import Path
from collections import deque from collections import deque
@ -14,6 +15,7 @@ from openai import OpenAI
from .utils import is_dangerous_command from .utils import is_dangerous_command
from .models import DesignModel, StateModel, LLMResponse, FileModel from .models import DesignModel, StateModel, LLMResponse, FileModel
from .diff_applier import parse_diff, apply_diff
class CodeGenerator: class CodeGenerator:
@ -230,12 +232,74 @@ class CodeGenerator:
logger.debug(f"为文件 {file} 添加隐式依赖: {implicit_deps}") logger.debug(f"为文件 {file} 添加隐式依赖: {implicit_deps}")
return enhanced return enhanced
    def _apply_diff(self, diff: str, original_content: str) -> str:
        """
        Apply a unified diff to the original content and return the modified content.

        Args:
            diff: unified diff as a string
            original_content: original file content

        Returns:
            str: content after the diff is applied

        Raises:
            RuntimeError: if the diff cannot be applied
        """
        try:
            diff_lines = diff.splitlines(keepends=True)
            if not diff_lines:
                raise ValueError("diff 为空")

            # Simplified applier: it trusts the hunk headers and does not verify that
            # context or removed lines match the original. Prefer a dedicated patch
            # library for complex diffs.
            original_lines = original_content.splitlines(keepends=True)
            result_lines = []
            j = 0  # index of the next unconsumed original line
            for line in diff_lines:
                if line.startswith('--- ') or line.startswith('+++ '):
                    continue  # file headers
                if line.startswith('@@ '):
                    # "@@ -start,count +start,count @@": copy the untouched lines that
                    # precede this hunk so the hunk lands at the right position
                    old_start = int(line.split()[1].lstrip('-').split(',')[0])
                    hunk_start = max(old_start - 1, j)
                    result_lines.extend(original_lines[j:hunk_start])
                    j = hunk_start
                elif line.startswith(' '):
                    # unchanged line: keep the original text
                    if j < len(original_lines):
                        result_lines.append(original_lines[j])
                    j += 1
                elif line.startswith('-'):
                    j += 1  # removed line: skip it in the original
                elif line.startswith('+'):
                    result_lines.append(line[1:])  # added line
                # anything else (e.g. "\ No newline at end of file") is ignored

            # copy whatever follows the last hunk
            result_lines.extend(original_lines[j:])
            return ''.join(result_lines)
        except Exception as e:
            logger.error(f"应用 diff 时出错: {e}")
            raise RuntimeError(f"无法应用 diff: {e}") from e
def generate_file( def generate_file(
self, self,
file_path: str, file_path: str,
prompt_instruction: str, prompt_instruction: str,
dependency_files: List[str], dependency_files: List[str],
existing_content: Optional[str] = None, existing_content: Optional[str] = None,
output_format: str = "full", # 新增参数,默认 'full'
) -> Tuple[str, str, List[str]]: ) -> Tuple[str, str, List[str]]:
""" """
生成单个文件返回 (代码, 描述, 命令列表) 生成单个文件返回 (代码, 描述, 命令列表)
@ -245,6 +309,7 @@ class CodeGenerator:
prompt_instruction: 生成指令 prompt_instruction: 生成指令
dependency_files: 依赖文件列表用于上下文 dependency_files: 依赖文件列表用于上下文
existing_content: 文件现有内容若为修改模式 existing_content: 文件现有内容若为修改模式
output_format: 输出格式'full' 'diff'来自 models.py
""" """
# 收集上下文内容 # 收集上下文内容
context_content = [] context_content = []
@ -291,30 +356,75 @@ class CodeGenerator:
full_context = "\n".join(context_content) full_context = "\n".join(context_content)
# 根据是否有现有内容调整系统提示 # 根据 output_format 设置 system_prompt
if existing_content is not None: if output_format == "diff":
if existing_content is None:
logger.error(f"对于 output_format='diff',必须提供 existing_content")
self.console.print(f"[bold red]❌ 对于 output_format='diff',必须提供 existing_content[/bold red]")
return "# 错误:缺少现有内容", "生成失败,缺少现有内容", []
system_prompt = ( system_prompt = (
"你是一个专业的编程助手。根据用户指令和提供的上下文文件,**修改**现有的代码文件。" "你是一个专业的编程助手。根据用户指令和提供的上下文文件,生成文件的差异diff"
"返回严格的 JSON 对象,包含三个字段:\n" "返回严格的 JSON 对象,包含个字段:\n"
"- code: (string) 修改后的完整代码\n" "- diff: (string) 文件的差异,使用 unified diff 格式\n"
"- description: (string) 简短的中文修改描述\n" "- description: (string) 简短的中文修改描述\n"
"- commands: (array of string) 修改此文件后需要执行的操作系统命令列表(如编译、安装依赖等),若无则返回空数组" "- commands: (array of string) 修改此文件后需要执行的操作系统命令列表,若无则返回空数组\n"
"- output_format: (string) 应为 'diff'"
) )
else: else:
system_prompt = ( # output_format 为 'full' 或其他,保持现有逻辑
"你是一个专业的编程助手。根据用户指令和提供的上下文文件,生成完整的代码。" if existing_content is not None:
"返回严格的 JSON 对象,包含三个字段:\n" system_prompt = (
"- code: (string) 生成的完整代码\n" "你是一个专业的编程助手。根据用户指令和提供的上下文文件,**修改**现有的代码文件。"
"- description: (string) 简短的中文功能描述\n" "返回严格的 JSON 对象,包含四个字段:\n"
"- commands: (array of string) 生成此文件后需要执行的操作系统命令列表(如编译、安装依赖等),若无则返回空数组" "- code: (string) 修改后的完整代码\n"
) "- description: (string) 简短的中文修改描述\n"
"- commands: (array of string) 修改此文件后需要执行的操作系统命令列表(如编译、安装依赖等),若无则返回空数组\n"
"- output_format: (string) 应为 'full'"
)
else:
system_prompt = (
"你是一个专业的编程助手。根据用户指令和提供的上下文文件,生成完整的代码。"
"返回严格的 JSON 对象,包含四个字段:\n"
"- code: (string) 生成的完整代码\n"
"- description: (string) 简短的中文功能描述\n"
"- commands: (array of string) 生成此文件后需要执行的操作系统命令列表(如编译、安装依赖等),若无则返回空数组\n"
"- output_format: (string) 应为 'full'"
)
user_prompt = f"{prompt_instruction}\n\n参考文件上下文:\n{full_context}" user_prompt = f"{prompt_instruction}\n\n参考文件上下文:\n{full_context}"
if output_format == "diff":
user_prompt += f"\noutput_format: {output_format}"
try: try:
result = self._call_llm(system_prompt, user_prompt) result = self._call_llm(system_prompt, user_prompt)
llm_response = LLMResponse(**result) # 解析响应,假设包含 output_format 字段
return llm_response.code, llm_response.description, llm_response.commands if output_format == "diff":
diff = result.get("diff")
description = result.get("description", "")
commands = result.get("commands", [])
output_format_resp = result.get("output_format", "diff")
if diff is None:
raise ValueError("LLM 响应中没有 diff 字段")
                # Apply the diff to the existing content with the in-class helper.
                # diff_applier.apply_diff(diff, target_dir) patches a directory and returns a
                # result dict, so it cannot be called with (existing_content, chunks).
                try:
                    code = self._apply_diff(diff, existing_content)
                except Exception as e:
                    logger.error(f"应用 diff 时发生意外错误: {e}")
                    self.console.print(f"[bold red]❌ 应用 diff 时发生意外错误: {e}[/bold red]")
                    return "# 应用 diff 失败", f"应用 diff 时发生意外错误: {e}", []
                return code, description, commands
else:
code = result.get("code")
description = result.get("description", "")
commands = result.get("commands", [])
output_format_resp = result.get("output_format", "full")
if code is None:
raise ValueError("LLM 响应中没有 code 字段")
return code, description, commands
except Exception as e: except Exception as e:
logger.error(f"生成文件 {file_path} 时调用LLM失败: {e}") logger.error(f"生成文件 {file_path} 时调用LLM失败: {e}")
self.console.print(f"[bold red]❌ 生成文件 {file_path} 时调用LLM失败: {e}[/bold red]") self.console.print(f"[bold red]❌ 生成文件 {file_path} 时调用LLM失败: {e}[/bold red]")
@ -358,6 +468,44 @@ class CodeGenerator:
logger.error(f"生成文件 {file_path} 失败: {e}") logger.error(f"生成文件 {file_path} 失败: {e}")
return False, str(e) return False, str(e)
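    def _topological_sort(self, files: List[str], dependencies: Dict[str, List[str]]) -> List[str]:
        """
        Topologically sort the file list so that every file comes after its dependencies.

        Args:
            files: list of file paths
            dependencies: dependency mapping, {file: [files it depends on]}

        Returns:
            List[str]: files in generation order (dependencies first)

        Raises:
            ValueError: if a circular dependency is detected
        """
        # in_degree[f] counts the dependencies of f that are part of `files`;
        # dependents[d] lists the files that become ready once d has been emitted.
        in_degree = {file: 0 for file in files}
        dependents = {file: [] for file in files}
        for file in files:
            for dep in dependencies.get(file, []):
                if dep in in_degree:  # ignore dependencies outside the file list
                    in_degree[file] += 1
                    dependents[dep].append(file)

        # `deque` is imported at module level
        queue = deque([f for f in files if in_degree[f] == 0])
        sorted_files = []
        while queue:
            node = queue.popleft()
            sorted_files.append(node)
            for dependent in dependents[node]:
                in_degree[dependent] -= 1
                if in_degree[dependent] == 0:
                    queue.append(dependent)

        if len(sorted_files) != len(files):
            raise ValueError(f"检测到循环依赖,排序失败。已排序 {len(sorted_files)} 个文件,总共 {len(files)} 个文件。")
        return sorted_files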
def execute_command(self, cmd: str, cwd: Optional[Path] = None) -> bool: def execute_command(self, cmd: str, cwd: Optional[Path] = None) -> bool:
""" """
执行单个命令检查风险失败仅记录错误不抛出异常 执行单个命令检查风险失败仅记录错误不抛出异常
@ -472,6 +620,15 @@ class CodeGenerator:
dependencies = self._add_implicit_dependencies(files, dependencies) dependencies = self._add_implicit_dependencies(files, dependencies)
logger.info("已添加隐式依赖") logger.info("已添加隐式依赖")
# 拓扑排序检查依赖关系
try:
sorted_files = self._topological_sort(files, dependencies)
logger.info(f"拓扑排序成功,文件顺序: {sorted_files}")
except ValueError as e:
logger.error(f"依赖关系错误: {e}")
self.console.print(f"[bold red]❌ 依赖关系错误: {e}[/bold red]")
return # 退出生成
# 断点续写:确定已生成文件 # 断点续写:确定已生成文件
generated_files_set = set(self.state.generated_files if self.state else []) generated_files_set = set(self.state.generated_files if self.state else [])
@ -655,7 +812,7 @@ class CodeGenerator:
self.console.print(f"[yellow]⚠ 依赖文件缺失,将不使用这些文件作为上下文: {missing_deps}[/yellow]") self.console.print(f"[yellow]⚠ 依赖文件缺失,将不使用这些文件作为上下文: {missing_deps}[/yellow]")
# 构建生成指令 # 构建生成指令
instruction = f"请根据工单描述{'修改' if action == 'modify' else '生成'}文件 '{file_path}'\n" instruction = f"请根据工单描述{'修改' if action == 'modify' else '生成'}文件 '{file_path}'.\n"
instruction += f"工单内容摘要:{description}\n" instruction += f"工单内容摘要:{description}\n"
if action == "modify": if action == "modify":
instruction += "请在现有代码基础上进行修改,保持原有风格和功能不变。" instruction += "请在现有代码基础上进行修改,保持原有风格和功能不变。"
@ -669,6 +826,7 @@ class CodeGenerator:
instruction, instruction,
dep_paths, dep_paths,
existing_content=existing, existing_content=existing,
output_format="full", # 在工单处理中默认使用 'full',可根据需求调整
) )
logger.info(f"生成完成: {file_path} - {desc}") logger.info(f"生成完成: {file_path} - {desc}")
@ -708,7 +866,8 @@ class CodeGenerator:
logger.error(f"更新design.json失败: {e}") logger.error(f"更新design.json失败: {e}")
self.console.print(f"[bold red]❌ 更新design.json失败: {e}[/bold red]") self.console.print(f"[bold red]❌ 更新design.json失败: {e}[/bold red]")
""" """
self._update_design(generated_files, change_plan.design_updates) logger.info(f'change_plan: {change_plan}')
self._update_design(generated_files, change_plan.get("design_updates", {}))
self.console.print("[green]✅ design.json 已更新[/green]") self.console.print("[green]✅ design.json 已更新[/green]")


@ -0,0 +1,125 @@
"""
Diff 应用模块使用 GitPython 解析和应用 unified diff 格式
"""
import os
import sys
from typing import List, Dict, Any
import git # 需要安装 GitPython
def parse_diff(diff: str) -> List[str]:
"""
解析 unified diff 字符串提取受影响的文件路径
Args:
diff: unified diff 格式的字符串
Returns:
文件路径列表
"""
files = set()
for line in diff.split('\n'):
if line.startswith('--- a/'):
# 提取旧文件路径
path = line[6:].strip()
if path and path != '/dev/null': # /dev/null 表示新文件
files.add(path)
elif line.startswith('+++ b/'):
# 提取新文件路径
path = line[6:].strip()
if path and path != '/dev/null': # /dev/null 表示删除文件
files.add(path)
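    return list(files)


# Alias: apply_diff() below and tests/test_diff_applier.py refer to this function
# as parse_diff_files, while core.py imports it as parse_diff.
parse_diff_files = parse_diff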
def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]:
"""
应用 unified diff 到指定目录使用 GitPython
该函数解析 diff 并尝试应用更改到文件系统如果目标目录不是 git 仓库
将尝试初始化一个临时仓库或报告错误
Args:
diff: unified diff 格式的字符串
target_dir: 目标目录路径默认为当前目录
Returns:
字典包含以下键
- 'success' (bool): 是否成功应用
- 'message' (str): 成功或错误消息
- 'applied_files' (List[str]): 成功应用的文件列表如果成功
- 'error_details' (str): 详细的错误信息如果失败
"""
# 初始化返回值
result = {
'success': False,
'message': '',
'applied_files': [],
'error_details': ''
}
# 检查 diff 是否为空
if not diff or diff.strip() == '':
result['message'] = 'Diff string is empty'
return result
# 解析 diff 获取文件列表
try:
affected_files = parse_diff_files(diff)
except Exception as e:
result['message'] = f"Failed to parse diff: {str(e)}"
result['error_details'] = str(e)
return result
# 检查目标目录是否存在
if not os.path.isdir(target_dir):
result['message'] = f"Target directory does not exist: {target_dir}"
return result
try:
# 尝试获取或初始化 git 仓库
try:
repo = git.Repo(target_dir)
except git.exc.InvalidGitRepositoryError:
# 如果不是 git 仓库,初始化一个
repo = git.Repo.init(target_dir)
# 添加所有现有文件到索引,以便应用 diff
repo.git.add('--all')
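        # Apply the diff with `git apply`. GitPython turns unrecognised keyword arguments
        # (such as input=) into command-line flags, so the diff is written to a temporary
        # file and passed by path instead.
        import tempfile  # local import keeps this fix self-contained
        with tempfile.NamedTemporaryFile('w', suffix='.diff', delete=False,
                                         encoding='utf-8', newline='') as tmp:
            tmp.write(diff)
        try:
            # '--whitespace=nowarn' suppresses whitespace warnings
            output = repo.git.apply('--whitespace=nowarn', tmp.name)
        finally:
            os.unlink(tmp.name)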
# 如果成功,更新结果
result['success'] = True
result['message'] = 'Diff applied successfully'
result['applied_files'] = affected_files
if output:
result['message'] += f". Output: {output}"
except git.exc.GitError as e:
# 处理 git 错误,如冲突、格式错误等
result['message'] = f"Git error while applying diff: {str(e)}"
result['error_details'] = str(e)
except Exception as e:
# 处理其他异常
result['message'] = f"Unexpected error: {str(e)}"
result['error_details'] = str(e)
return result
# 如果作为脚本运行,可以提供简单的测试
if __name__ == "__main__":
# 示例用法
sample_diff = """--- a/old_file.txt
+++ b/new_file.txt
@@ -1 +1 @@
-Hello World
+Hello Universe
"""
print("Testing apply_diff...")
res = apply_diff(sample_diff, ".")
print(res)


@ -11,6 +11,12 @@ class FileStatus(str, Enum):
FAILED = "failed" FAILED = "failed"
class OutputFormat(str, Enum):
"""输出格式枚举。"""
FULL = "full"
DIFF = "diff"
# 模型用于 design.json 结构 # 模型用于 design.json 结构
class FunctionModel(BaseModel): class FunctionModel(BaseModel):
"""函数模型,对应 design.json 中的 functions 字段。""" """函数模型,对应 design.json 中的 functions 字段。"""
@ -84,3 +90,4 @@ class LLMResponse(BaseModel):
code: str code: str
description: str description: str
commands: List[str] = Field(default_factory=list) commands: List[str] = Field(default_factory=list)
    output_format: OutputFormat = Field(default=OutputFormat.FULL, description="Output format, either 'full' or 'diff'; defaults to 'full'")


@ -2,6 +2,7 @@ from typing import Tuple, Dict, List, Optional, Any
import os import os
from pathlib import Path from pathlib import Path
import queue import queue
from collections import deque
from loguru import logger # 添加导入 from loguru import logger # 添加导入
from rich.progress import Progress, TextColumn, BarColumn, TimeElapsedColumn, TaskProgressColumn from rich.progress import Progress, TextColumn, BarColumn, TimeElapsedColumn, TaskProgressColumn
@ -258,3 +259,39 @@ def create_progress_bar(total: int = 100, description: str = "Processing",
] ]
progress = Progress(*columns, auto_refresh=auto_refresh) progress = Progress(*columns, auto_refresh=auto_refresh)
return progress return progress
def topological_sort(graph: Dict[str, List[str]]) -> List[str]:
"""
基于依赖图进行拓扑排序检测循环依赖并报错
Args:
graph: 依赖图字典形式键为节点文件路径值为该节点依赖的节点列表
Returns:
List[str]: 拓扑排序后的节点列表
Raises:
ValueError: 如果检测到循环依赖
"""
# 计算入度
in_degrees = compute_in_degrees(graph)
# 初始化队列入度为0的节点入队
zero_degree_queue = deque([node for node, degree in in_degrees.items() if degree == 0])
sorted_nodes = []
while zero_degree_queue:
node = zero_degree_queue.popleft()
sorted_nodes.append(node)
for neighbor in graph.get(node, []):
if neighbor in in_degrees:
in_degrees[neighbor] -= 1
if in_degrees[neighbor] == 0:
zero_degree_queue.append(neighbor)
# 检查循环依赖
if len(sorted_nodes) != len(graph):
raise ValueError(f"检测到循环依赖,排序节点数 {len(sorted_nodes)} 不等于总节点数 {len(graph)}")
return sorted_nodes


@ -262,3 +262,51 @@ class TestCodeGenerator:
# 运行,预期不抛出异常 # 运行,预期不抛出异常
code_generator.run(Path(tmp_path / "README.md")) code_generator.run(Path(tmp_path / "README.md"))
def test_topological_sort_normal(self, code_generator):
"""测试拓扑排序正常依赖排序"""
files = ["a", "b", "c"]
dependencies = {"a": ["b"], "b": ["c"], "c": []}
result = code_generator._topological_sort(files, dependencies)
# 验证排序顺序正确c 在 b 之前b 在 a 之前
assert result == ["c", "b", "a"] # 根据实现,顺序是确定的
# 验证每个文件的依赖都在其之前
for i, node in enumerate(result):
deps = dependencies.get(node, [])
for dep in deps:
assert result.index(dep) < i
def test_topological_sort_cycle_detection(self, code_generator):
"""测试拓扑排序循环依赖检测"""
files = ["a", "b"]
dependencies = {"a": ["b"], "b": ["a"]}
with pytest.raises(ValueError, match="检测到循环依赖"):
code_generator._topological_sort(files, dependencies)
def test_topological_sort_empty(self, code_generator):
"""测试拓扑排序空输入"""
files = []
dependencies = {}
result = code_generator._topological_sort(files, dependencies)
assert result == []
def test_topological_sort_partial_deps(self, code_generator):
"""测试拓扑排序部分依赖不在列表中"""
files = ["a", "b"]
dependencies = {"a": ["b", "c"], "b": []} # c 不在 files 中
result = code_generator._topological_sort(files, dependencies)
# c 被忽略,因为不在 in_degree 中,排序应基于 b 依赖
assert result == ["b", "a"]
def test_topological_sort_complex(self, code_generator):
"""测试拓扑排序复杂依赖关系"""
files = ["a", "b", "c", "d"]
dependencies = {"a": ["b", "c"], "b": ["d"], "c": ["d"], "d": []}
result = code_generator._topological_sort(files, dependencies)
# 验证排序结果满足所有依赖
for node in result:
deps = dependencies.get(node, [])
for dep in deps:
assert result.index(dep) < result.index(node)
# 验证所有文件都在结果中
assert set(result) == set(files)

tests/test_diff_applier.py

@ -0,0 +1,199 @@
#!/usr/bin/env python3
"""
Unit tests for diff_applier.py, covering various scenarios such as new file creation,
modification of existing files, conflict handling, and error cases.
"""
import os
import sys
import tempfile
import shutil
import pytest
import git # GitPython is required; assumed installed via project dependencies
# Add src directory to path for module import
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
from llm_codegen.diff_applier import parse_diff_files, apply_diff
def test_parse_diff_files():
"""Test parsing unified diff strings to extract file paths."""
# Diff with modification and new file
diff = """--- a/file1.txt
+++ b/file1.txt
@@ -1 +1 @@
-old content
+new content
--- a/new_file.txt
+++ b/new_file.txt
@@ -0,0 +1 @@
+new file content
"""
files = parse_diff_files(diff)
assert set(files) == {'file1.txt', 'new_file.txt'}
# Diff with file deletion
diff_del = """--- a/deleted.txt
+++ /dev/null
@@ -1 +0,0 @@
-content to delete
"""
files = parse_diff_files(diff_del)
assert files == ['deleted.txt'] # Old file path is extracted
# Empty diff
assert parse_diff_files('') == []
assert parse_diff_files('\n') == []
# Diff with only new file (no old file)
diff_new_only = """--- /dev/null
+++ b/only_new.txt
@@ -0,0 +1 @@
+only new
"""
files = parse_diff_files(diff_new_only)
    assert files == ['only_new.txt']  # only the /dev/null side is ignored
# Invalid diff format (should still handle gracefully)
diff_invalid = "invalid diff string"
files = parse_diff_files(diff_invalid)
assert files == [] # No valid file paths found
@pytest.fixture
def temp_git_repo():
"""Create a temporary git repository for testing apply_diff."""
temp_dir = tempfile.mkdtemp()
repo = git.Repo.init(temp_dir)
# Make an initial commit to have a clean state
init_file = os.path.join(temp_dir, 'README.md')
with open(init_file, 'w') as f:
f.write('Initial commit content')
repo.git.add('README.md')
repo.git.commit('-m', 'initial commit')
yield temp_dir, repo
# Cleanup after test
shutil.rmtree(temp_dir, ignore_errors=True)
def test_apply_diff_new_file(temp_git_repo):
"""Test applying a diff that creates a new file."""
temp_dir, repo = temp_git_repo
diff = """--- /dev/null
+++ b/test_new.txt
@@ -0,0 +1 @@
+This is a newly created file.
"""
result = apply_diff(diff, temp_dir)
assert result['success'] == True
assert 'test_new.txt' in result['applied_files']
new_file_path = os.path.join(temp_dir, 'test_new.txt')
assert os.path.exists(new_file_path)
with open(new_file_path, 'r') as f:
content = f.read()
assert content.strip() == 'This is a newly created file.'
def test_apply_diff_modify_existing_file(temp_git_repo):
"""Test applying a diff that modifies an existing file."""
temp_dir, repo = temp_git_repo
# Create an existing file and commit it
existing_file = os.path.join(temp_dir, 'existing.txt')
with open(existing_file, 'w') as f:
f.write('Original line 1\nOriginal line 2')
repo.git.add('existing.txt')
repo.git.commit('-m', 'add existing file')
diff = """--- a/existing.txt
+++ b/existing.txt
@@ -1,2 +1,2 @@
-Original line 1
+Modified line 1
Original line 2
"""
result = apply_diff(diff, temp_dir)
assert result['success'] == True
assert 'existing.txt' in result['applied_files']
with open(existing_file, 'r') as f:
content = f.read()
assert content == 'Modified line 1\nOriginal line 2'
def test_apply_diff_conflict_handling(temp_git_repo):
"""Test applying a diff that causes a conflict with uncommitted changes."""
temp_dir, repo = temp_git_repo
# Create and commit a file
conflict_file = os.path.join(temp_dir, 'conflict.txt')
with open(conflict_file, 'w') as f:
f.write('Initial line 1\nInitial line 2')
repo.git.add('conflict.txt')
repo.git.commit('-m', 'add conflict file')
# Modify the file without committing to create conflict
with open(conflict_file, 'w') as f:
f.write('Changed line 1\nInitial line 2') # Change first line
diff = """--- a/conflict.txt
+++ b/conflict.txt
@@ -1,2 +1,2 @@
-Initial line 1
+Diff line 1
Initial line 2
"""
result = apply_diff(diff, temp_dir)
assert result['success'] == False
# Check for conflict or error in message
assert 'conflict' in result['message'].lower() or 'error' in result['message'].lower()
assert result['error_details'] != ''
def test_apply_diff_empty_diff():
"""Test applying an empty diff string."""
result = apply_diff('', '.')
assert result['success'] == False
assert 'empty' in result['message'].lower()
def test_apply_diff_invalid_directory():
"""Test applying a diff to a non-existent directory."""
non_existent_dir = '/tmp/non_existent_dir_12345'
diff = """--- a/dummy.txt
+++ b/dummy.txt
@@ -1 +1 @@
-old
+new
"""
result = apply_diff(diff, non_existent_dir)
assert result['success'] == False
assert 'does not exist' in result['message'].lower()
def test_apply_diff_no_git_repo_initialization():
"""Test applying a diff to a non-git directory, which should initialize a repo."""
temp_dir = tempfile.mkdtemp()
try:
# Create a non-git directory with a file
non_git_file = os.path.join(temp_dir, 'non_git.txt')
with open(non_git_file, 'w') as f:
f.write('Pre-existing content')
diff = """--- a/non_git.txt
+++ b/non_git.txt
@@ -1 +1 @@
-Pre-existing content
+Updated content
"""
result = apply_diff(diff, temp_dir)
assert result['success'] == True
assert 'non_git.txt' in result['applied_files']
with open(non_git_file, 'r') as f:
content = f.read()
assert content == 'Updated content'
finally:
shutil.rmtree(temp_dir, ignore_errors=True)
if __name__ == "__main__":
# Run tests if executed as script
pytest.main([__file__])