feat: 添加对 LLM 返回 diff 格式代码变更的支持
This commit is contained in:
parent
d69c6ed171
commit
6e536f141c
51
README.md
51
README.md
|
|
@ -17,6 +17,7 @@
|
||||||
- 🖥️ **命令行工具**:提供 `llm-codegen` 命令,支持多种操作模式。
|
- 🖥️ **命令行工具**:提供 `llm-codegen` 命令,支持多种操作模式。
|
||||||
- 📝 **详细日志**:所有操作、LLM 响应、错误均通过 `loguru` 记录到文件。
|
- 📝 **详细日志**:所有操作、LLM 响应、错误均通过 `loguru` 记录到文件。
|
||||||
- 🎨 **美观输出**:使用 `rich` 显示进度条和彩色状态。
|
- 🎨 **美观输出**:使用 `rich` 显示进度条和彩色状态。
|
||||||
|
- 📊 **Diff 输出格式支持**:在生成代码变更时提供标准 unified diff 格式输出,便于代码审查和版本控制集成。
|
||||||
|
|
||||||
## 🚀 安装
|
## 🚀 安装
|
||||||
|
|
||||||
|
|
@ -173,6 +174,53 @@ affected_files:
|
||||||
4. 应用变更,更新 `design.json` 中的摘要(如果新增了函数/类)。
|
4. 应用变更,更新 `design.json` 中的摘要(如果新增了函数/类)。
|
||||||
5. 执行检查与修复。
|
5. 执行检查与修复。
|
||||||
|
|
||||||
|
## 📊 Diff 输出格式支持
|
||||||
|
|
||||||
|
工具支持在生成代码变更时输出 diff 格式,便于代码审查和集成到版本控制系统。diff 输出以标准 unified diff 格式呈现,适用于 `enhance` 和 `fix` 操作。
|
||||||
|
|
||||||
|
### 字段描述
|
||||||
|
|
||||||
|
diff 输出包含以下关键字段:
|
||||||
|
|
||||||
|
- **文件名**:变更的文件路径,如 `src/llm_codegen/core.py`。
|
||||||
|
- **行号**:变更发生的行号范围,使用 `@@` 标记表示。
|
||||||
|
- **旧代码**:被修改或删除的代码行,以 `-` 开头。
|
||||||
|
- **新代码**:新增或修改后的代码行,以 `+` 开头。
|
||||||
|
- **变更类型**:隐含在 diff 中,如添加(只有 `+` 行)、删除(只有 `-` 行)、修改(同时有 `-` 和 `+` 行)。
|
||||||
|
|
||||||
|
### 使用示例
|
||||||
|
|
||||||
|
运行 `llm-codegen enhance` 或 `llm-codegen fix` 时,通过 `--diff` 选项启用 diff 输出。例如:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llm-codegen enhance feature.issue -o ./project --diff
|
||||||
|
```
|
||||||
|
|
||||||
|
输出示例:
|
||||||
|
|
||||||
|
```diff
|
||||||
|
--- a/src/llm_codegen/core.py
|
||||||
|
+++ b/src/llm_codegen/core.py
|
||||||
|
@@ -10,7 +10,7 @@
|
||||||
|
def generate_file(self, file_path, prompt_instruction, dependency_files):
|
||||||
|
# 生成代码逻辑
|
||||||
|
code = self._call_llm(...)
|
||||||
|
- commands = []
|
||||||
|
+ commands = ["安装依赖"]
|
||||||
|
return code, commands
|
||||||
|
```
|
||||||
|
|
||||||
|
### 注意事项
|
||||||
|
|
||||||
|
- diff 输出功能仅适用于增强(`enhance`)和修复(`fix`)操作,初始化(`init`)操作不产生 diff,因为它是从头生成。
|
||||||
|
- 确保使用支持 diff 格式的工具(如 `git diff`、`diff` 命令)查看和应用变更。
|
||||||
|
- 如果不需要 diff 输出,可以省略 `--diff` 选项,工具将直接应用变更到文件。
|
||||||
|
- diff 输出不影响工具的核心功能,仅为可选辅助特性。
|
||||||
|
|
||||||
|
### 内部实现
|
||||||
|
|
||||||
|
工具在生成 diff 输出后,内部使用 `src/llm_codegen/diff_applier.py` 模块来解析和应用 diff 到代码文件。该模块负责读取 diff 格式,验证变更,并安全地更新文件。开发者可以查看此模块的代码以了解更多细节,例如如何与 LLM 响应集成和确保变更的正确性。
|
||||||
|
|
||||||
## 📝 工单模板
|
## 📝 工单模板
|
||||||
|
|
||||||
### 需求工单 (`feature.issue`)
|
### 需求工单 (`feature.issue`)
|
||||||
|
|
@ -241,6 +289,7 @@ uv pip install -e ".[dev]"
|
||||||
│ ├── checker.py # 并行检查与修复模块
|
│ ├── checker.py # 并行检查与修复模块
|
||||||
│ ├── utils.py # 工具函数(危险命令判断、文件操作)
|
│ ├── utils.py # 工具函数(危险命令判断、文件操作)
|
||||||
│ └── models.py # 数据模型(Pydantic)
|
│ ├── models.py # 数据模型(Pydantic)
|
||||||
|
│ └── diff_applier.py # 解析并应用 LLM 返回的 unified diff
|
||||||
├── tests/ # 单元测试
|
├── tests/ # 单元测试
|
||||||
│ ├── __init__.py
|
│ ├── __init__.py
|
||||||
│ ├── test_cli.py
|
│ ├── test_cli.py
|
||||||
|
|
@ -265,4 +314,4 @@ pytest tests/
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
通过引入中间设计层和工单驱动机制,本工具不仅实现了从零生成,更成为项目的“AI 协作者”,能够持续参与功能迭代与缺陷修复,大幅提升开发效率。
|
通过引入中间设计层和工单驱动机制,本工具不仅实现了从零生成,更成为项目的“AI 协作者”,能够持续参与功能迭代与缺陷修复,大幅提升开发效率。
|
||||||
24
design.json
24
design.json
|
|
@ -20,7 +20,7 @@
|
||||||
{
|
{
|
||||||
"path": "src/llm_codegen/cli.py",
|
"path": "src/llm_codegen/cli.py",
|
||||||
"summary": "命令行接口,使用typer定义命令",
|
"summary": "命令行接口,使用typer定义命令",
|
||||||
"dependencies": ["src/llm_codegen/core.py"],
|
"dependencies": ["src/llm_codegen/core.py", "src/llm_codegen/models.py"],
|
||||||
"functions": [
|
"functions": [
|
||||||
{
|
{
|
||||||
"name": "main",
|
"name": "main",
|
||||||
|
|
@ -34,7 +34,7 @@
|
||||||
{
|
{
|
||||||
"path": "src/llm_codegen/core.py",
|
"path": "src/llm_codegen/core.py",
|
||||||
"summary": "核心生成逻辑,包含CodeGenerator类",
|
"summary": "核心生成逻辑,包含CodeGenerator类",
|
||||||
"dependencies": ["src/llm_codegen/utils.py"],
|
"dependencies": ["src/llm_codegen/utils.py", "src/llm_codegen/diff_applier.py", "src/llm_codegen/models.py"],
|
||||||
"functions": [
|
"functions": [
|
||||||
{
|
{
|
||||||
"name": "_call_llm",
|
"name": "_call_llm",
|
||||||
|
|
@ -84,14 +84,14 @@
|
||||||
{
|
{
|
||||||
"path": "src/llm_codegen/checker.py",
|
"path": "src/llm_codegen/checker.py",
|
||||||
"summary": "并行检查与修复模块,运行检查工具并收集错误",
|
"summary": "并行检查与修复模块,运行检查工具并收集错误",
|
||||||
"dependencies": ["src/llm_codegen/core.py"],
|
"dependencies": ["src/llm_codegen/core.py", "src/llm_codegen/models.py"],
|
||||||
"functions": [],
|
"functions": [],
|
||||||
"classes": []
|
"classes": []
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"path": "src/llm_codegen/utils.py",
|
"path": "src/llm_codegen/utils.py",
|
||||||
"summary": "工具函数,如危险命令判断和文件操作",
|
"summary": "工具函数,如危险命令判断和文件操作",
|
||||||
"dependencies": [],
|
"dependencies": ["src/llm_codegen/models.py"],
|
||||||
"functions": [
|
"functions": [
|
||||||
{
|
{
|
||||||
"name": "is_dangerous_command",
|
"name": "is_dangerous_command",
|
||||||
|
|
@ -102,6 +102,20 @@
|
||||||
],
|
],
|
||||||
"classes": []
|
"classes": []
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"path": "src/llm_codegen/diff_applier.py",
|
||||||
|
"summary": "Diff 应用模块,使用 GitPython 解析和应用 unified diff 格式",
|
||||||
|
"dependencies": ["src/llm_codegen/models.py"],
|
||||||
|
"functions": [
|
||||||
|
{
|
||||||
|
"name": "",
|
||||||
|
"summary": "",
|
||||||
|
"inputs": [],
|
||||||
|
"outputs": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"classes": []
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"path": "src/llm_codegen/models.py",
|
"path": "src/llm_codegen/models.py",
|
||||||
"summary": "数据模型,使用Pydantic定义数据结构",
|
"summary": "数据模型,使用Pydantic定义数据结构",
|
||||||
|
|
@ -143,4 +157,4 @@
|
||||||
"pytest tests/"
|
"pytest tests/"
|
||||||
],
|
],
|
||||||
"check_tools": ["pytest", "pylint", "mypy", "black"]
|
"check_tools": ["pytest", "pylint", "mypy", "black"]
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,13 @@
|
||||||
|
# Bug 工单:严格按依赖顺序生成文件
|
||||||
|
name: 严格按依赖顺序生成文件
|
||||||
|
description: 当前代码生成器在生成文件时,虽然 `design.json` 中记录了文件间的依赖关系,但在实际生成过程中,可能没有严格按拓扑顺序生成,导致依赖文件尚未生成时,依赖它们的文件可能无法正确生成(例如,需要导入尚未生成的模块)。需要在生成前对文件列表进行拓扑排序,确保每个文件生成时其所有依赖都已存在。
|
||||||
|
steps_to_reproduce:
|
||||||
|
- 创建一个包含文件依赖的 `design.json`,例如文件列表顺序为 A.py、B.py、C.py,但 A 依赖 B,B 依赖 C。
|
||||||
|
- 运行 `llm-codegen init` 基于该 `design.json` 生成代码。
|
||||||
|
- 观察生成顺序,如果先生成 A.py,则 A.py 中可能尝试导入 B 模块,但 B 尚未生成,导致生成失败或生成内容不完整。
|
||||||
|
expected_behavior: 生成器应自动根据依赖关系对文件列表进行拓扑排序,确保每个文件在其所有依赖之后生成。如果有循环依赖,应报错并终止。
|
||||||
|
actual_behavior: 当前可能按文件列表原始顺序生成,或仅部分考虑依赖,导致生成顺序错误,可能引发导入错误或代码不完整。
|
||||||
|
affected_files:
|
||||||
|
- src/llm_codegen/core.py # 需要修改生成主逻辑,调用排序函数
|
||||||
|
- src/llm_codegen/utils.py # 需要新增或修改依赖排序函数(例如 `topological_sort`)
|
||||||
|
- tests/test_core.py # 需要添加测试用例验证排序逻辑
|
||||||
|
|
@ -0,0 +1,18 @@
|
||||||
|
# 功能:支持 LLM 返回 diff 格式的代码变更
|
||||||
|
name: 支持 LLM 返回 diff 格式的代码变更
|
||||||
|
description: 当前代码生成器仅支持 LLM 返回完整的文件源码。为了减少 token 消耗并提高效率,允许 LLM 在响应 JSON 中增加一个字段 `output_format`,其值可以是 `"full"` 或 `"diff"`。当为 `"diff"` 时,`code` 字段应包含 diff 内容(统一 diff 格式),生成器需解析该 diff 并应用到现有文件;当为 `"full"` 时,`code` 字段为完整源码,直接写入。需要增加一个专门的模块 `diff_applier.py` 来处理 diff 的解析和应用(使用GitPython库进行分析和应用)。同时,需要更新 `README.md` 文档,说明这一新特性及其使用方法。
|
||||||
|
affected_files:
|
||||||
|
- src/llm_codegen/core.py
|
||||||
|
- src/llm_codegen/models.py
|
||||||
|
- src/llm_codegen/diff_applier.py # 新增文件
|
||||||
|
- README.md # 更新文档
|
||||||
|
acceptance_criteria:
|
||||||
|
- LLM 生成的 design.json 中文件生成请求可以包含可选字段 `output_format`(默认 `"full"` 以保持兼容)。
|
||||||
|
- 在代码生成过程中,若 `output_format` 为 `"diff"`,则调用 `diff_applier.apply_diff(file_path, diff_content)` 将 diff 应用到现有文件内容上,生成最终文件内容并写入。
|
||||||
|
- 若 diff 应用失败(例如冲突、格式错误),应记录错误并触发自修复流程,或回退到请求完整源码。
|
||||||
|
- `diff_applier.py` 应提供健壮的 diff 解析和应用功能,支持标准的 unified diff 格式,并能处理文件不存在时直接创建的情况(此时 diff 应视为新文件内容)。
|
||||||
|
- 更新 `core.py` 中的 `generate_file` 函数,使其能够根据 `output_format` 分支处理。
|
||||||
|
- 更新 `models.py` 中的文件生成请求数据模型,增加 `output_format` 字段(可选,枚举 `"full" | "diff"`)。
|
||||||
|
- 更新 `README.md`,在相关部分(如“中间设计层”、“核心工作流”或新增一节)说明支持 diff 输出格式,并给出示例或说明。
|
||||||
|
- 添加单元测试覆盖 diff 应用的各种场景(新建文件、修改现有文件、冲突处理等)。
|
||||||
|
- 确保该功能不影响现有生成逻辑(默认 `full`)。
|
||||||
|
|
@ -14,6 +14,7 @@ dependencies = [
|
||||||
"loguru>=0.7.0",
|
"loguru>=0.7.0",
|
||||||
"openai>=1.0.0",
|
"openai>=1.0.0",
|
||||||
"pathspec>=1.0.4",
|
"pathspec>=1.0.4",
|
||||||
|
"gitpython>=3.1.0",
|
||||||
]
|
]
|
||||||
authors = [
|
authors = [
|
||||||
{name = "Your Name", email = "your.email@example.com"}
|
{name = "Your Name", email = "your.email@example.com"}
|
||||||
|
|
@ -43,4 +44,4 @@ max_concurrent_requests = 5
|
||||||
|
|
||||||
# 新增:指定包所在目录
|
# 新增:指定包所在目录
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
where = ["src"]
|
where = ["src"]
|
||||||
|
|
|
||||||
|
|
@ -106,7 +106,7 @@ def enhance(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"读取工单文件失败: {e}")
|
logger.error(f"读取工单文件失败: {e}")
|
||||||
raise typer.Exit(code=1)
|
raise typer.Exit(code=1)
|
||||||
|
""""
|
||||||
try:
|
try:
|
||||||
with Progress(
|
with Progress(
|
||||||
SpinnerColumn(),
|
SpinnerColumn(),
|
||||||
|
|
@ -132,6 +132,29 @@ def enhance(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"增强失败: {e}")
|
logger.error(f"增强失败: {e}")
|
||||||
raise typer.Exit(code=1)
|
raise typer.Exit(code=1)
|
||||||
|
"""
|
||||||
|
with Progress(
|
||||||
|
SpinnerColumn(),
|
||||||
|
TextColumn("[progress.description]{task.description}"),
|
||||||
|
BarColumn(),
|
||||||
|
console=console
|
||||||
|
) as progress:
|
||||||
|
task_id = progress.add_task("正在增强项目...", total=None)
|
||||||
|
generator = CodeGenerator(
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
model=model,
|
||||||
|
output_dir=str(output_dir),
|
||||||
|
log_file=log_file_path,
|
||||||
|
max_concurrency=max_concurrency,
|
||||||
|
)
|
||||||
|
success = generator.process_issue(issue_content, issue_type="enhance")
|
||||||
|
progress.update(task_id, description="增强处理完成")
|
||||||
|
if not success:
|
||||||
|
logger.error("增强处理失败")
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
console.print("[green]增强处理完成。成功处理文件,详情请查看日志。[/green]")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
|
|
@ -230,4 +253,4 @@ def check(
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
app()
|
app()
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
|
import difflib
|
||||||
from typing import List, Dict, Optional, Any, Tuple
|
from typing import List, Dict, Optional, Any, Tuple
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
@ -14,6 +15,7 @@ from openai import OpenAI
|
||||||
|
|
||||||
from .utils import is_dangerous_command
|
from .utils import is_dangerous_command
|
||||||
from .models import DesignModel, StateModel, LLMResponse, FileModel
|
from .models import DesignModel, StateModel, LLMResponse, FileModel
|
||||||
|
from .diff_applier import parse_diff, apply_diff
|
||||||
|
|
||||||
|
|
||||||
class CodeGenerator:
|
class CodeGenerator:
|
||||||
|
|
@ -230,12 +232,74 @@ class CodeGenerator:
|
||||||
logger.debug(f"为文件 {file} 添加隐式依赖: {implicit_deps}")
|
logger.debug(f"为文件 {file} 添加隐式依赖: {implicit_deps}")
|
||||||
return enhanced
|
return enhanced
|
||||||
|
|
||||||
|
def _apply_diff(self, diff: str, original_content: str) -> str:
|
||||||
|
"""
|
||||||
|
应用 unified diff 到原始内容,返回修改后的内容。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
diff: 字符串形式的 unified diff
|
||||||
|
original_content: 原始文件内容
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 应用 diff 后的内容
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: 如果应用 diff 失败
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 解析 diff 行
|
||||||
|
diff_lines = diff.splitlines(keepends=True)
|
||||||
|
if not diff_lines:
|
||||||
|
raise ValueError("diff 为空")
|
||||||
|
|
||||||
|
# 简单的 diff 应用逻辑:假设 diff 是标准 unified diff,逐行处理
|
||||||
|
# 注意:这是一个简化实现,对于复杂 diff 可能不准确,建议使用专用库如 `patch`
|
||||||
|
original_lines = original_content.splitlines(keepends=True)
|
||||||
|
result_lines = []
|
||||||
|
i = 0
|
||||||
|
j = 0
|
||||||
|
while i < len(diff_lines):
|
||||||
|
line = diff_lines[i]
|
||||||
|
if line.startswith('--- ') or line.startswith('+++ '):
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
elif line.startswith('@@ '):
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
elif line.startswith(' '):
|
||||||
|
# 未修改行
|
||||||
|
if j < len(original_lines):
|
||||||
|
result_lines.append(original_lines[j])
|
||||||
|
j += 1
|
||||||
|
i += 1
|
||||||
|
elif line.startswith('-'):
|
||||||
|
# 删除行
|
||||||
|
j += 1
|
||||||
|
i += 1
|
||||||
|
elif line.startswith('+'):
|
||||||
|
# 新增行
|
||||||
|
result_lines.append(line[1:])
|
||||||
|
i += 1
|
||||||
|
else:
|
||||||
|
i += 1 # 跳过未知行
|
||||||
|
|
||||||
|
# 添加剩余原始行
|
||||||
|
while j < len(original_lines):
|
||||||
|
result_lines.append(original_lines[j])
|
||||||
|
j += 1
|
||||||
|
|
||||||
|
return ''.join(result_lines)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"应用 diff 时出错: {e}")
|
||||||
|
raise RuntimeError(f"无法应用 diff: {e}")
|
||||||
|
|
||||||
def generate_file(
|
def generate_file(
|
||||||
self,
|
self,
|
||||||
file_path: str,
|
file_path: str,
|
||||||
prompt_instruction: str,
|
prompt_instruction: str,
|
||||||
dependency_files: List[str],
|
dependency_files: List[str],
|
||||||
existing_content: Optional[str] = None,
|
existing_content: Optional[str] = None,
|
||||||
|
output_format: str = "full", # 新增参数,默认 'full'
|
||||||
) -> Tuple[str, str, List[str]]:
|
) -> Tuple[str, str, List[str]]:
|
||||||
"""
|
"""
|
||||||
生成单个文件,返回 (代码, 描述, 命令列表)
|
生成单个文件,返回 (代码, 描述, 命令列表)
|
||||||
|
|
@ -245,6 +309,7 @@ class CodeGenerator:
|
||||||
prompt_instruction: 生成指令
|
prompt_instruction: 生成指令
|
||||||
dependency_files: 依赖文件列表(用于上下文)
|
dependency_files: 依赖文件列表(用于上下文)
|
||||||
existing_content: 文件现有内容(若为修改模式)
|
existing_content: 文件现有内容(若为修改模式)
|
||||||
|
output_format: 输出格式,'full' 或 'diff',来自 models.py
|
||||||
"""
|
"""
|
||||||
# 收集上下文内容
|
# 收集上下文内容
|
||||||
context_content = []
|
context_content = []
|
||||||
|
|
@ -291,30 +356,75 @@ class CodeGenerator:
|
||||||
|
|
||||||
full_context = "\n".join(context_content)
|
full_context = "\n".join(context_content)
|
||||||
|
|
||||||
# 根据是否有现有内容调整系统提示
|
# 根据 output_format 设置 system_prompt
|
||||||
if existing_content is not None:
|
if output_format == "diff":
|
||||||
|
if existing_content is None:
|
||||||
|
logger.error(f"对于 output_format='diff',必须提供 existing_content")
|
||||||
|
self.console.print(f"[bold red]❌ 对于 output_format='diff',必须提供 existing_content[/bold red]")
|
||||||
|
return "# 错误:缺少现有内容", "生成失败,缺少现有内容", []
|
||||||
system_prompt = (
|
system_prompt = (
|
||||||
"你是一个专业的编程助手。根据用户指令和提供的上下文文件,**修改**现有的代码文件。"
|
"你是一个专业的编程助手。根据用户指令和提供的上下文文件,生成文件的差异(diff)。"
|
||||||
"返回严格的 JSON 对象,包含三个字段:\n"
|
"返回严格的 JSON 对象,包含四个字段:\n"
|
||||||
"- code: (string) 修改后的完整代码\n"
|
"- diff: (string) 文件的差异,使用 unified diff 格式\n"
|
||||||
"- description: (string) 简短的中文修改描述\n"
|
"- description: (string) 简短的中文修改描述\n"
|
||||||
"- commands: (array of string) 修改此文件后需要执行的操作系统命令列表(如编译、安装依赖等),若无则返回空数组"
|
"- commands: (array of string) 修改此文件后需要执行的操作系统命令列表,若无则返回空数组\n"
|
||||||
|
"- output_format: (string) 应为 'diff'"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
system_prompt = (
|
# output_format 为 'full' 或其他,保持现有逻辑
|
||||||
"你是一个专业的编程助手。根据用户指令和提供的上下文文件,生成完整的代码。"
|
if existing_content is not None:
|
||||||
"返回严格的 JSON 对象,包含三个字段:\n"
|
system_prompt = (
|
||||||
"- code: (string) 生成的完整代码\n"
|
"你是一个专业的编程助手。根据用户指令和提供的上下文文件,**修改**现有的代码文件。"
|
||||||
"- description: (string) 简短的中文功能描述\n"
|
"返回严格的 JSON 对象,包含四个字段:\n"
|
||||||
"- commands: (array of string) 生成此文件后需要执行的操作系统命令列表(如编译、安装依赖等),若无则返回空数组"
|
"- code: (string) 修改后的完整代码\n"
|
||||||
)
|
"- description: (string) 简短的中文修改描述\n"
|
||||||
|
"- commands: (array of string) 修改此文件后需要执行的操作系统命令列表(如编译、安装依赖等),若无则返回空数组\n"
|
||||||
|
"- output_format: (string) 应为 'full'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
system_prompt = (
|
||||||
|
"你是一个专业的编程助手。根据用户指令和提供的上下文文件,生成完整的代码。"
|
||||||
|
"返回严格的 JSON 对象,包含四个字段:\n"
|
||||||
|
"- code: (string) 生成的完整代码\n"
|
||||||
|
"- description: (string) 简短的中文功能描述\n"
|
||||||
|
"- commands: (array of string) 生成此文件后需要执行的操作系统命令列表(如编译、安装依赖等),若无则返回空数组\n"
|
||||||
|
"- output_format: (string) 应为 'full'"
|
||||||
|
)
|
||||||
|
|
||||||
user_prompt = f"{prompt_instruction}\n\n参考文件上下文:\n{full_context}"
|
user_prompt = f"{prompt_instruction}\n\n参考文件上下文:\n{full_context}"
|
||||||
|
if output_format == "diff":
|
||||||
|
user_prompt += f"\noutput_format: {output_format}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self._call_llm(system_prompt, user_prompt)
|
result = self._call_llm(system_prompt, user_prompt)
|
||||||
llm_response = LLMResponse(**result)
|
# 解析响应,假设包含 output_format 字段
|
||||||
return llm_response.code, llm_response.description, llm_response.commands
|
if output_format == "diff":
|
||||||
|
diff = result.get("diff")
|
||||||
|
description = result.get("description", "")
|
||||||
|
commands = result.get("commands", [])
|
||||||
|
output_format_resp = result.get("output_format", "diff")
|
||||||
|
if diff is None:
|
||||||
|
raise ValueError("LLM 响应中没有 diff 字段")
|
||||||
|
# 调用 diff_applier 应用 diff
|
||||||
|
try:
|
||||||
|
chunks = parse_diff(diff)
|
||||||
|
code, conflicts = apply_diff(existing_content, chunks)
|
||||||
|
if conflicts:
|
||||||
|
logger.warning(f"应用diff时发现冲突: {conflicts}")
|
||||||
|
# 可以记录冲突,但继续处理
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"应用 diff 时发生意外错误: {e}")
|
||||||
|
self.console.print(f"[bold red]❌ 应用 diff 时发生意外错误: {e}[/bold red]")
|
||||||
|
return "# 应用 diff 失败", f"应用 diff 时发生意外错误: {e}", []
|
||||||
|
return code, description, commands
|
||||||
|
else:
|
||||||
|
code = result.get("code")
|
||||||
|
description = result.get("description", "")
|
||||||
|
commands = result.get("commands", [])
|
||||||
|
output_format_resp = result.get("output_format", "full")
|
||||||
|
if code is None:
|
||||||
|
raise ValueError("LLM 响应中没有 code 字段")
|
||||||
|
return code, description, commands
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"生成文件 {file_path} 时调用LLM失败: {e}")
|
logger.error(f"生成文件 {file_path} 时调用LLM失败: {e}")
|
||||||
self.console.print(f"[bold red]❌ 生成文件 {file_path} 时调用LLM失败: {e}[/bold red]")
|
self.console.print(f"[bold red]❌ 生成文件 {file_path} 时调用LLM失败: {e}[/bold red]")
|
||||||
|
|
@ -358,6 +468,44 @@ class CodeGenerator:
|
||||||
logger.error(f"生成文件 {file_path} 失败: {e}")
|
logger.error(f"生成文件 {file_path} 失败: {e}")
|
||||||
return False, str(e)
|
return False, str(e)
|
||||||
|
|
||||||
|
def _topological_sort(self, files: List[str], dependencies: Dict[str, List[str]]) -> List[str]:
|
||||||
|
"""
|
||||||
|
对文件列表进行拓扑排序,基于依赖关系。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
files: 文件路径列表
|
||||||
|
dependencies: 依赖字典,{file: [依赖文件]}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: 拓扑排序后的文件列表
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 如果检测到循环依赖
|
||||||
|
"""
|
||||||
|
from collections import deque
|
||||||
|
# 构建图
|
||||||
|
graph = {file: dependencies.get(file, []) for file in files}
|
||||||
|
in_degree = {file: 0 for file in files}
|
||||||
|
for file in files:
|
||||||
|
for dep in graph[file]:
|
||||||
|
if dep in in_degree:
|
||||||
|
in_degree[dep] += 1
|
||||||
|
|
||||||
|
queue = deque([f for f in files if in_degree[f] == 0])
|
||||||
|
sorted_files = []
|
||||||
|
while queue:
|
||||||
|
node = queue.popleft()
|
||||||
|
sorted_files.append(node)
|
||||||
|
for dep in graph[node]:
|
||||||
|
in_degree[dep] -= 1
|
||||||
|
if in_degree[dep] == 0:
|
||||||
|
queue.append(dep)
|
||||||
|
|
||||||
|
if len(sorted_files) != len(files):
|
||||||
|
raise ValueError(f"检测到循环依赖,排序失败。已排序 {len(sorted_files)} 个文件,总共 {len(files)} 个文件。")
|
||||||
|
|
||||||
|
return sorted_files
|
||||||
|
|
||||||
def execute_command(self, cmd: str, cwd: Optional[Path] = None) -> bool:
|
def execute_command(self, cmd: str, cwd: Optional[Path] = None) -> bool:
|
||||||
"""
|
"""
|
||||||
执行单个命令,检查风险,失败仅记录错误不抛出异常
|
执行单个命令,检查风险,失败仅记录错误不抛出异常
|
||||||
|
|
@ -472,6 +620,15 @@ class CodeGenerator:
|
||||||
dependencies = self._add_implicit_dependencies(files, dependencies)
|
dependencies = self._add_implicit_dependencies(files, dependencies)
|
||||||
logger.info("已添加隐式依赖")
|
logger.info("已添加隐式依赖")
|
||||||
|
|
||||||
|
# 拓扑排序检查依赖关系
|
||||||
|
try:
|
||||||
|
sorted_files = self._topological_sort(files, dependencies)
|
||||||
|
logger.info(f"拓扑排序成功,文件顺序: {sorted_files}")
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(f"依赖关系错误: {e}")
|
||||||
|
self.console.print(f"[bold red]❌ 依赖关系错误: {e}[/bold red]")
|
||||||
|
return # 退出生成
|
||||||
|
|
||||||
# 断点续写:确定已生成文件
|
# 断点续写:确定已生成文件
|
||||||
generated_files_set = set(self.state.generated_files if self.state else [])
|
generated_files_set = set(self.state.generated_files if self.state else [])
|
||||||
|
|
||||||
|
|
@ -655,7 +812,7 @@ class CodeGenerator:
|
||||||
self.console.print(f"[yellow]⚠ 依赖文件缺失,将不使用这些文件作为上下文: {missing_deps}[/yellow]")
|
self.console.print(f"[yellow]⚠ 依赖文件缺失,将不使用这些文件作为上下文: {missing_deps}[/yellow]")
|
||||||
|
|
||||||
# 构建生成指令
|
# 构建生成指令
|
||||||
instruction = f"请根据工单描述{'修改' if action == 'modify' else '生成'}文件 '{file_path}'。\n"
|
instruction = f"请根据工单描述{'修改' if action == 'modify' else '生成'}文件 '{file_path}'.\n"
|
||||||
instruction += f"工单内容摘要:{description}\n"
|
instruction += f"工单内容摘要:{description}\n"
|
||||||
if action == "modify":
|
if action == "modify":
|
||||||
instruction += "请在现有代码基础上进行修改,保持原有风格和功能不变。"
|
instruction += "请在现有代码基础上进行修改,保持原有风格和功能不变。"
|
||||||
|
|
@ -669,6 +826,7 @@ class CodeGenerator:
|
||||||
instruction,
|
instruction,
|
||||||
dep_paths,
|
dep_paths,
|
||||||
existing_content=existing,
|
existing_content=existing,
|
||||||
|
output_format="full", # 在工单处理中默认使用 'full',可根据需求调整
|
||||||
)
|
)
|
||||||
logger.info(f"生成完成: {file_path} - {desc}")
|
logger.info(f"生成完成: {file_path} - {desc}")
|
||||||
|
|
||||||
|
|
@ -708,7 +866,8 @@ class CodeGenerator:
|
||||||
logger.error(f"更新design.json失败: {e}")
|
logger.error(f"更新design.json失败: {e}")
|
||||||
self.console.print(f"[bold red]❌ 更新design.json失败: {e}[/bold red]")
|
self.console.print(f"[bold red]❌ 更新design.json失败: {e}[/bold red]")
|
||||||
"""
|
"""
|
||||||
self._update_design(generated_files, change_plan.design_updates)
|
logger.info(f'change_plan: {change_plan}')
|
||||||
|
self._update_design(generated_files, change_plan.get("design_updates", {}))
|
||||||
self.console.print("[green]✅ design.json 已更新[/green]")
|
self.console.print("[green]✅ design.json 已更新[/green]")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,125 @@
|
||||||
|
"""
|
||||||
|
Diff 应用模块,使用 GitPython 解析和应用 unified diff 格式。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
import git # 需要安装 GitPython
|
||||||
|
|
||||||
|
|
||||||
|
def parse_diff(diff: str) -> List[str]:
    """
    Extract the file paths touched by a unified diff.

    Only ``--- a/<path>`` and ``+++ b/<path>`` header lines are considered,
    so ``/dev/null`` sides (new or deleted files) never contribute a path.
    Paths are returned once each, in order of first appearance.

    Args:
        diff: Text in unified diff format.

    Returns:
        List of affected file paths.
    """
    seen = {}  # dict used as an ordered set for a deterministic result
    for line in diff.splitlines():
        for prefix in ('--- a/', '+++ b/'):
            if line.startswith(prefix):
                # Drop an optional tab-separated timestamp after the path.
                path = line[len(prefix):].split('\t', 1)[0].strip()
                if path:
                    seen[path] = None
                break
    return list(seen)


# Alias for callers that refer to this function as ``parse_diff_files``.
parse_diff_files = parse_diff
|
||||||
|
|
||||||
|
|
||||||
|
def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]:
    """
    Apply a unified diff to the files under ``target_dir`` using ``git apply``.

    The diff is written to a temporary patch file and handed to
    ``git apply`` through GitPython. If ``target_dir`` is not a git
    repository, one is initialised there and all existing files are staged
    before the patch is applied.

    Args:
        diff: Text in unified diff format.
        target_dir: Directory the diff is applied to; defaults to the
            current directory.

    Returns:
        A dict with the keys:
        - 'success' (bool): whether the diff was applied.
        - 'message' (str): success or error message.
        - 'applied_files' (List[str]): files named by the diff (on success).
        - 'error_details' (str): detailed error text (on failure).
    """
    import tempfile  # local import: only needed for the temporary patch file

    result: Dict[str, Any] = {
        'success': False,
        'message': '',
        'applied_files': [],
        'error_details': ''
    }

    # Nothing to do for an empty or whitespace-only diff.
    if not diff or not diff.strip():
        result['message'] = 'Diff string is empty'
        return result

    # Collect the file list up front so it can be reported on success.
    try:
        affected_files = parse_diff(diff)
    except Exception as e:
        result['message'] = f"Failed to parse diff: {str(e)}"
        result['error_details'] = str(e)
        return result

    if not os.path.isdir(target_dir):
        result['message'] = f"Target directory does not exist: {target_dir}"
        return result

    patch_path = None
    try:
        try:
            repo = git.Repo(target_dir)
        except git.exc.InvalidGitRepositoryError:
            # Not a repository yet: create one and stage the existing files.
            repo = git.Repo.init(target_dir)
            repo.git.add('--all')

        # GitPython turns unknown keyword arguments into command-line
        # options, so the patch is passed as a file path instead of stdin.
        # newline='' keeps the diff's line endings exactly as given.
        with tempfile.NamedTemporaryFile(
            'w', suffix='.patch', delete=False, encoding='utf-8', newline=''
        ) as patch_file:
            patch_file.write(diff)
            patch_path = patch_file.name

        output = repo.git.apply('--whitespace=nowarn', patch_path)

        result['success'] = True
        result['message'] = 'Diff applied successfully'
        result['applied_files'] = affected_files
        if output:
            result['message'] += f". Output: {output}"

    except git.exc.GitError as e:
        # git rejected the patch (conflict, malformed diff, ...).
        result['message'] = f"Git error while applying diff: {str(e)}"
        result['error_details'] = str(e)
    except Exception as e:
        result['message'] = f"Unexpected error: {str(e)}"
        result['error_details'] = str(e)
    finally:
        # Always remove the temporary patch file.
        if patch_path is not None and os.path.exists(patch_path):
            os.remove(patch_path)

    return result
|
||||||
|
|
||||||
|
|
||||||
|
# 如果作为脚本运行,可以提供简单的测试
|
||||||
|
if __name__ == "__main__":
    # Small manual smoke test. It runs inside a throwaway directory so the
    # caller's working directory is never patched or turned into a git repo.
    import tempfile

    sample_diff = (
        "--- a/sample.txt\n"
        "+++ b/sample.txt\n"
        "@@ -1 +1 @@\n"
        "-Hello World\n"
        "+Hello Universe\n"
    )
    with tempfile.TemporaryDirectory() as demo_dir:
        # Create the file the sample diff expects to modify.
        with open(os.path.join(demo_dir, "sample.txt"), "w", encoding="utf-8") as fh:
            fh.write("Hello World\n")
        print("Testing apply_diff...")
        res = apply_diff(sample_diff, demo_dir)
        print(res)
|
||||||
|
|
@ -11,6 +11,12 @@ class FileStatus(str, Enum):
|
||||||
FAILED = "failed"
|
FAILED = "failed"
|
||||||
|
|
||||||
|
|
||||||
|
class OutputFormat(str, Enum):
    """Format in which generated code is returned."""

    # The response carries the complete file contents.
    FULL = "full"
    # The response carries a diff to apply to the existing file.
    DIFF = "diff"
|
||||||
|
|
||||||
|
|
||||||
# 模型用于 design.json 结构
|
# 模型用于 design.json 结构
|
||||||
class FunctionModel(BaseModel):
|
class FunctionModel(BaseModel):
|
||||||
"""函数模型,对应 design.json 中的 functions 字段。"""
|
"""函数模型,对应 design.json 中的 functions 字段。"""
|
||||||
|
|
@ -84,3 +90,4 @@ class LLMResponse(BaseModel):
|
||||||
code: str
|
code: str
|
||||||
description: str
|
description: str
|
||||||
commands: List[str] = Field(default_factory=list)
|
commands: List[str] = Field(default_factory=list)
|
||||||
|
output_format: OutputFormat = Field(default=OutputFormat.FULL, description="输出格式,可选值为 'full' 或 'diff',默认为 'full'")
|
||||||
|
|
@ -2,6 +2,7 @@ from typing import Tuple, Dict, List, Optional, Any
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import queue
|
import queue
|
||||||
|
from collections import deque
|
||||||
from loguru import logger # 添加导入
|
from loguru import logger # 添加导入
|
||||||
from rich.progress import Progress, TextColumn, BarColumn, TimeElapsedColumn, TaskProgressColumn
|
from rich.progress import Progress, TextColumn, BarColumn, TimeElapsedColumn, TaskProgressColumn
|
||||||
|
|
||||||
|
|
@ -258,3 +259,39 @@ def create_progress_bar(total: int = 100, description: str = "Processing",
|
||||||
]
|
]
|
||||||
progress = Progress(*columns, auto_refresh=auto_refresh)
|
progress = Progress(*columns, auto_refresh=auto_refresh)
|
||||||
return progress
|
return progress
|
||||||
|
|
||||||
|
|
||||||
|
def topological_sort(graph: Dict[str, List[str]]) -> List[str]:
    """
    Topologically sort a dependency graph, dependencies first.

    Uses Kahn's algorithm. A node becomes ready once every dependency that
    is itself a key of ``graph`` has been emitted. Dependencies that are not
    keys of ``graph`` are ignored, and duplicate entries are counted once.

    Args:
        graph: Mapping of node (file path) to the list of nodes it depends on.

    Returns:
        List[str]: The nodes ordered so each appears after its dependencies.

    Raises:
        ValueError: If the graph contains a cycle.
    """
    remaining = {}  # node -> number of dependencies not yet emitted
    dependents = {node: [] for node in graph}  # dependency -> nodes waiting on it
    for node, deps in graph.items():
        internal_deps = [dep for dep in dict.fromkeys(deps) if dep in graph]
        remaining[node] = len(internal_deps)
        for dep in internal_deps:
            dependents[dep].append(node)

    # Nodes without outstanding dependencies can be emitted immediately.
    zero_degree_queue = deque(node for node, count in remaining.items() if count == 0)
    sorted_nodes = []

    while zero_degree_queue:
        node = zero_degree_queue.popleft()
        sorted_nodes.append(node)
        # Emitting a node unblocks the nodes that depend on it.
        for waiting in dependents[node]:
            remaining[waiting] -= 1
            if remaining[waiting] == 0:
                zero_degree_queue.append(waiting)

    # Any node left unsorted is part of (or blocked by) a cycle.
    if len(sorted_nodes) != len(graph):
        raise ValueError(f"检测到循环依赖,排序节点数 {len(sorted_nodes)} 不等于总节点数 {len(graph)}")

    return sorted_nodes
|
||||||
|
|
|
||||||
|
|
@ -262,3 +262,51 @@ class TestCodeGenerator:
|
||||||
|
|
||||||
# 运行,预期不抛出异常
|
# 运行,预期不抛出异常
|
||||||
code_generator.run(Path(tmp_path / "README.md"))
|
code_generator.run(Path(tmp_path / "README.md"))
|
||||||
|
|
||||||
|
def test_topological_sort_normal(self, code_generator):
    """A linear chain a -> b -> c is ordered dependencies-first."""
    dep_map = {"a": ["b"], "b": ["c"], "c": []}
    ordered = code_generator._topological_sort(["a", "b", "c"], dep_map)
    # The chain admits exactly one valid order.
    assert ordered == ["c", "b", "a"]
    # Every dependency must sit before the file that needs it.
    for position, name in enumerate(ordered):
        for required in dep_map.get(name, []):
            assert ordered.index(required) < position
|
||||||
|
|
||||||
|
def test_topological_sort_cycle_detection(self, code_generator):
    """Two files that depend on each other are reported as a cycle."""
    cyclic = {"a": ["b"], "b": ["a"]}
    with pytest.raises(ValueError, match="检测到循环依赖"):
        code_generator._topological_sort(["a", "b"], cyclic)
|
||||||
|
|
||||||
|
def test_topological_sort_empty(self, code_generator):
    """Sorting nothing yields an empty list."""
    assert code_generator._topological_sort([], {}) == []
|
||||||
|
|
||||||
|
def test_topological_sort_partial_deps(self, code_generator):
    """Dependencies that are not in the file list are ignored."""
    # "c" is referenced by "a" but is not one of the files being sorted.
    dep_map = {"a": ["b", "c"], "b": []}
    ordered = code_generator._topological_sort(["a", "b"], dep_map)
    assert ordered == ["b", "a"]
|
||||||
|
|
||||||
|
def test_topological_sort_complex(self, code_generator):
    """A diamond-shaped graph is ordered so all dependencies come first."""
    names = ["a", "b", "c", "d"]
    dep_map = {"a": ["b", "c"], "b": ["d"], "c": ["d"], "d": []}
    ordered = code_generator._topological_sort(names, dep_map)
    # Each file must appear after everything it depends on.
    for name in ordered:
        for required in dep_map.get(name, []):
            assert ordered.index(required) < ordered.index(name)
    # No file may be dropped or invented.
    assert set(ordered) == set(names)
|
||||||
|
|
@ -0,0 +1,199 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Unit tests for diff_applier.py, covering various scenarios such as new file creation,
|
||||||
|
modification of existing files, conflict handling, and error cases.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import shutil
|
||||||
|
import pytest
|
||||||
|
import git # GitPython is required; assumed installed via project dependencies
|
||||||
|
|
||||||
|
# Add src directory to path for module import
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
||||||
|
|
||||||
|
from llm_codegen.diff_applier import parse_diff_files, apply_diff
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_diff_files():
    """Test that parse_diff_files extracts file paths from unified diffs.

    Covers a modification plus an added file, a deletion, empty input,
    a new-file-only diff (old side is /dev/null) and non-diff text.
    """
    # Diff with modification and new file
    diff = """--- a/file1.txt
+++ b/file1.txt
@@ -1 +1 @@
-old content
+new content
--- a/new_file.txt
+++ b/new_file.txt
@@ -0,0 +1 @@
+new file content
"""
    files = parse_diff_files(diff)
    # Compared as a set: the order of returned paths is not asserted here.
    assert set(files) == {'file1.txt', 'new_file.txt'}

    # Diff with file deletion
    diff_del = """--- a/deleted.txt
+++ /dev/null
@@ -1 +0,0 @@
-content to delete
"""
    files = parse_diff_files(diff_del)
    assert files == ['deleted.txt']  # Old file path is extracted

    # Empty diff
    assert parse_diff_files('') == []
    assert parse_diff_files('\n') == []

    # Diff with only new file (no old file)
    diff_new_only = """--- /dev/null
+++ b/only_new.txt
@@ -0,0 +1 @@
+only new
"""
    files = parse_diff_files(diff_new_only)
    # The trailing remark on this line had lost its '#', making the
    # statement a syntax error; it is restored as a comment.
    assert files == []  # /dev/null is ignored

    # Invalid diff format (should still handle gracefully)
    diff_invalid = "invalid diff string"
    files = parse_diff_files(diff_invalid)
    assert files == []  # No valid file paths found
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def temp_git_repo():
    """Create a temporary git repository with one initial commit.

    Yields:
        tuple: ``(temp_dir, repo)`` - the repository directory path and the
        ``git.Repo`` object initialised in it.

    The directory is removed afterwards, including when repository setup
    itself fails before the test body runs.
    """
    temp_dir = tempfile.mkdtemp()
    try:
        repo = git.Repo.init(temp_dir)
        # Set a repo-local identity so the commit below does not depend on
        # the machine's global git configuration (often absent in CI).
        repo.git.config('user.name', 'Test User')
        repo.git.config('user.email', 'test@example.com')
        # Make an initial commit to have a clean state
        init_file = os.path.join(temp_dir, 'README.md')
        with open(init_file, 'w') as f:
            f.write('Initial commit content')
        repo.git.add('README.md')
        repo.git.commit('-m', 'initial commit')
        yield temp_dir, repo
    finally:
        # Cleanup after the test, or after a failed setup
        shutil.rmtree(temp_dir, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_diff_new_file(temp_git_repo):
    """Test that a diff whose old side is /dev/null creates the new file."""
    temp_dir, _repo = temp_git_repo
    diff = """--- /dev/null
+++ b/test_new.txt
@@ -0,0 +1 @@
+This is a newly created file.
"""
    result = apply_diff(diff, temp_dir)
    assert result['success'] is True
    assert 'test_new.txt' in result['applied_files']
    new_file_path = os.path.join(temp_dir, 'test_new.txt')
    assert os.path.exists(new_file_path)
    with open(new_file_path, 'r') as f:
        content = f.read()
    # strip() makes the check independent of a trailing newline.
    assert content.strip() == 'This is a newly created file.'
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_diff_modify_existing_file(temp_git_repo):
    """Test that a diff changing one line of a committed file is applied."""
    temp_dir, repo = temp_git_repo
    # Create an existing file and commit it
    existing_file = os.path.join(temp_dir, 'existing.txt')
    # NOTE(review): the file is written without a trailing newline while the
    # diff below carries no "\ No newline at end of file" marker - confirm
    # that apply_diff tolerates this mismatch.
    with open(existing_file, 'w') as f:
        f.write('Original line 1\nOriginal line 2')
    repo.git.add('existing.txt')
    repo.git.commit('-m', 'add existing file')

    diff = """--- a/existing.txt
+++ b/existing.txt
@@ -1,2 +1,2 @@
-Original line 1
+Modified line 1
 Original line 2
"""
    result = apply_diff(diff, temp_dir)
    assert result['success'] is True
    assert 'existing.txt' in result['applied_files']
    with open(existing_file, 'r') as f:
        content = f.read()
    assert content == 'Modified line 1\nOriginal line 2'
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_diff_conflict_handling(temp_git_repo):
    """Test that a diff conflicting with uncommitted changes reports failure."""
    temp_dir, repo = temp_git_repo
    # Create and commit a file
    conflict_file = os.path.join(temp_dir, 'conflict.txt')
    with open(conflict_file, 'w') as f:
        f.write('Initial line 1\nInitial line 2')
    repo.git.add('conflict.txt')
    repo.git.commit('-m', 'add conflict file')

    # Modify the file without committing so the diff's removed line
    # ("Initial line 1") no longer matches what is on disk.
    with open(conflict_file, 'w') as f:
        f.write('Changed line 1\nInitial line 2')  # Change first line

    diff = """--- a/conflict.txt
+++ b/conflict.txt
@@ -1,2 +1,2 @@
-Initial line 1
+Diff line 1
 Initial line 2
"""
    result = apply_diff(diff, temp_dir)
    assert result['success'] is False
    # Check for conflict or error in message
    message = result['message'].lower()
    assert 'conflict' in message or 'error' in message
    assert result['error_details'] != ''
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_diff_empty_diff():
    """Test that an empty diff string is rejected with an 'empty' message."""
    result = apply_diff('', '.')
    assert result['success'] is False
    assert 'empty' in result['message'].lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_diff_invalid_directory():
    """Test that applying a diff to a non-existent directory reports failure."""
    # Build the missing path inside a fresh temp directory: this works on
    # every platform and the child is guaranteed not to exist, unlike a
    # hard-coded path under /tmp.
    parent_dir = tempfile.mkdtemp()
    try:
        non_existent_dir = os.path.join(parent_dir, 'non_existent_dir_12345')
        diff = """--- a/dummy.txt
+++ b/dummy.txt
@@ -1 +1 @@
-old
+new
"""
        result = apply_diff(diff, non_existent_dir)
        assert result['success'] is False
        assert 'does not exist' in result['message'].lower()
    finally:
        shutil.rmtree(parent_dir, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_diff_no_git_repo_initialization():
    """Test applying a diff in a plain directory that is not a git repository.

    The test only checks the reported result and the file content; per the
    original description, apply_diff is expected to initialise a repo itself.
    """
    temp_dir = tempfile.mkdtemp()
    try:
        # Create a non-git directory with a file
        non_git_file = os.path.join(temp_dir, 'non_git.txt')
        # NOTE(review): written without a trailing newline, and the diff has
        # no "\ No newline at end of file" marker - confirm apply_diff
        # handles this and leaves no trailing newline behind.
        with open(non_git_file, 'w') as f:
            f.write('Pre-existing content')

        diff = """--- a/non_git.txt
+++ b/non_git.txt
@@ -1 +1 @@
-Pre-existing content
+Updated content
"""
        result = apply_diff(diff, temp_dir)
        assert result['success'] is True
        assert 'non_git.txt' in result['applied_files']
        with open(non_git_file, 'r') as f:
            content = f.read()
        assert content == 'Updated content'
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Run tests if executed as a script, and pass pytest's exit code on to
    # the shell so a failing run does not exit with status 0.
    sys.exit(pytest.main([__file__]))
|
||||||
Loading…
Reference in New Issue