feat: 添加对 LLM 返回 diff 格式代码变更的支持
This commit is contained in:
parent
d69c6ed171
commit
6e536f141c
49
README.md
49
README.md
|
|
@ -17,6 +17,7 @@
|
|||
- 🖥️ **命令行工具**:提供 `llm-codegen` 命令,支持多种操作模式。
|
||||
- 📝 **详细日志**:所有操作、LLM 响应、错误均通过 `loguru` 记录到文件。
|
||||
- 🎨 **美观输出**:使用 `rich` 显示进度条和彩色状态。
|
||||
- 📊 **Diff 输出格式支持**:在生成代码变更时提供标准 unified diff 格式输出,便于代码审查和版本控制集成。
|
||||
|
||||
## 🚀 安装
|
||||
|
||||
|
|
@ -173,6 +174,53 @@ affected_files:
|
|||
4. 应用变更,更新 `design.json` 中的摘要(如果新增了函数/类)。
|
||||
5. 执行检查与修复。
|
||||
|
||||
## 📊 Diff 输出格式支持
|
||||
|
||||
工具支持在生成代码变更时输出 diff 格式,便于代码审查和集成到版本控制系统。diff 输出以标准 unified diff 格式呈现,适用于 `enhance` 和 `fix` 操作。
|
||||
|
||||
### 字段描述
|
||||
|
||||
diff 输出包含以下关键字段:
|
||||
|
||||
- **文件名**:变更的文件路径,如 `src/llm_codegen/core.py`。
|
||||
- **行号**:变更发生的行号范围,使用 `@@` 标记表示。
|
||||
- **旧代码**:被修改或删除的代码行,以 `-` 开头。
|
||||
- **新代码**:新增或修改后的代码行,以 `+` 开头。
|
||||
- **变更类型**:隐含在 diff 中,如添加(只有 `+` 行)、删除(只有 `-` 行)、修改(同时有 `-` 和 `+` 行)。
|
||||
|
||||
### 使用示例
|
||||
|
||||
运行 `llm-codegen enhance` 或 `llm-codegen fix` 时,通过 `--diff` 选项启用 diff 输出。例如:
|
||||
|
||||
```bash
|
||||
llm-codegen enhance feature.issue -o ./project --diff
|
||||
```
|
||||
|
||||
输出示例:
|
||||
|
||||
```diff
|
||||
--- a/src/llm_codegen/core.py
|
||||
+++ b/src/llm_codegen/core.py
|
||||
@@ -10,7 +10,7 @@
|
||||
def generate_file(self, file_path, prompt_instruction, dependency_files):
|
||||
# 生成代码逻辑
|
||||
code = self._call_llm(...)
|
||||
- commands = []
|
||||
+ commands = ["安装依赖"]
|
||||
return code, commands
|
||||
```
|
||||
|
||||
### 注意事项
|
||||
|
||||
- diff 输出功能仅适用于增强(`enhance`)和修复(`fix`)操作,初始化(`init`)操作不产生 diff,因为它是从头生成。
|
||||
- 确保使用支持 diff 格式的工具(如 `git diff`、`diff` 命令)查看和应用变更。
|
||||
- 如果不需要 diff 输出,可以省略 `--diff` 选项,工具将直接应用变更到文件。
|
||||
- diff 输出不影响工具的核心功能,仅为可选辅助特性。
|
||||
|
||||
### 内部实现
|
||||
|
||||
工具在生成 diff 输出后,内部使用 `src/llm_codegen/diff_applier.py` 模块来解析和应用 diff 到代码文件。该模块负责读取 diff 格式,验证变更,并安全地更新文件。开发者可以查看此模块的代码以了解更多细节,例如如何与 LLM 响应集成和确保变更的正确性。
|
||||
|
||||
## 📝 工单模板
|
||||
|
||||
### 需求工单 (`feature.issue`)
|
||||
|
|
@ -241,6 +289,7 @@ uv pip install -e ".[dev]"
|
|||
│ ├── checker.py # 并行检查与修复模块
|
||||
│ ├── utils.py # 工具函数(危险命令判断、文件操作)
|
||||
│ └── models.py # 数据模型(Pydantic)
|
||||
│ └── diff_applier.py # 应用llm返回的diff
|
||||
├── tests/ # 单元测试
|
||||
│ ├── __init__.py
|
||||
│ ├── test_cli.py
|
||||
|
|
|
|||
22
design.json
22
design.json
|
|
@ -20,7 +20,7 @@
|
|||
{
|
||||
"path": "src/llm_codegen/cli.py",
|
||||
"summary": "命令行接口,使用typer定义命令",
|
||||
"dependencies": ["src/llm_codegen/core.py"],
|
||||
"dependencies": ["src/llm_codegen/core.py", "src/llm_codegen/models.py"],
|
||||
"functions": [
|
||||
{
|
||||
"name": "main",
|
||||
|
|
@ -34,7 +34,7 @@
|
|||
{
|
||||
"path": "src/llm_codegen/core.py",
|
||||
"summary": "核心生成逻辑,包含CodeGenerator类",
|
||||
"dependencies": ["src/llm_codegen/utils.py"],
|
||||
"dependencies": ["src/llm_codegen/utils.py", "src/llm_codegen/diff_applier.py", "src/llm_codegen/models.py"],
|
||||
"functions": [
|
||||
{
|
||||
"name": "_call_llm",
|
||||
|
|
@ -84,14 +84,14 @@
|
|||
{
|
||||
"path": "src/llm_codegen/checker.py",
|
||||
"summary": "并行检查与修复模块,运行检查工具并收集错误",
|
||||
"dependencies": ["src/llm_codegen/core.py"],
|
||||
"dependencies": ["src/llm_codegen/core.py", "src/llm_codegen/models.py"],
|
||||
"functions": [],
|
||||
"classes": []
|
||||
},
|
||||
{
|
||||
"path": "src/llm_codegen/utils.py",
|
||||
"summary": "工具函数,如危险命令判断和文件操作",
|
||||
"dependencies": [],
|
||||
"dependencies": ["src/llm_codegen/models.py"],
|
||||
"functions": [
|
||||
{
|
||||
"name": "is_dangerous_command",
|
||||
|
|
@ -102,6 +102,20 @@
|
|||
],
|
||||
"classes": []
|
||||
},
|
||||
{
|
||||
"path": "src/llm_codegen/diff_applier.py",
|
||||
"summary": "",
|
||||
"dependencies": ["src/llm_codegen/models.py"],
|
||||
"functions": [
|
||||
{
|
||||
"name": "",
|
||||
"summary": "",
|
||||
"inputs": [],
|
||||
"outputs": []
|
||||
}
|
||||
],
|
||||
"classes": []
|
||||
},
|
||||
{
|
||||
"path": "src/llm_codegen/models.py",
|
||||
"summary": "数据模型,使用Pydantic定义数据结构",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,13 @@
|
|||
# Bug 工单:严格按依赖顺序生成文件
|
||||
name: 严格按依赖顺序生成文件
|
||||
description: 当前代码生成器在生成文件时,虽然 `design.json` 中记录了文件间的依赖关系,但在实际生成过程中,可能没有严格按拓扑顺序生成,导致依赖文件尚未生成时,依赖它们的文件可能无法正确生成(例如,需要导入尚未生成的模块)。需要在生成前对文件列表进行拓扑排序,确保每个文件生成时其所有依赖都已存在。
|
||||
steps_to_reproduce:
|
||||
- 创建一个包含文件依赖的 `design.json`,例如文件列表顺序为 A.py、B.py、C.py,但 A 依赖 B,B 依赖 C。
|
||||
- 运行 `llm-codegen init` 基于该 `design.json` 生成代码。
|
||||
- 观察生成顺序,如果先生成 A.py,则 A.py 中可能尝试导入 B 模块,但 B 尚未生成,导致生成失败或生成内容不完整。
|
||||
expected_behavior: 生成器应自动根据依赖关系对文件列表进行拓扑排序,确保每个文件在其所有依赖之后生成。如果有循环依赖,应报错并终止。
|
||||
actual_behavior: 当前可能按文件列表原始顺序生成,或仅部分考虑依赖,导致生成顺序错误,可能引发导入错误或代码不完整。
|
||||
affected_files:
|
||||
- src/llm_codegen/core.py # 需要修改生成主逻辑,调用排序函数
|
||||
- src/llm_codegen/utils.py # 需要新增或修改依赖排序函数(例如 `topological_sort`)
|
||||
- tests/test_core.py # 需要添加测试用例验证排序逻辑
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
# 功能:支持 LLM 返回 diff 格式的代码变更
|
||||
name: 支持 LLM 返回 diff 格式的代码变更
|
||||
description: 当前代码生成器仅支持 LLM 返回完整的文件源码。为了减少 token 消耗并提高效率,允许 LLM 在响应 JSON 中增加一个字段 `output_format`,其值可以是 `"full"` 或 `"diff"`。当为 `"diff"` 时,`code` 字段应包含 diff 内容(统一 diff 格式),生成器需解析该 diff 并应用到现有文件;当为 `"full"` 时,`code` 字段为完整源码,直接写入。需要增加一个专门的模块 `diff_applier.py` 来处理 diff 的解析和应用(使用GitPython库进行分析和应用)。同时,需要更新 `README.md` 文档,说明这一新特性及其使用方法。
|
||||
affected_files:
|
||||
- src/llm_codegen/core.py
|
||||
- src/llm_codegen/models.py
|
||||
- src/llm_codegen/diff_applier.py # 新增文件
|
||||
- README.md # 更新文档
|
||||
acceptance_criteria:
|
||||
- LLM 生成的 design.json 中文件生成请求可以包含可选字段 `output_format`(默认 `"full"` 以保持兼容)。
|
||||
- 在代码生成过程中,若 `output_format` 为 `"diff"`,则调用 `diff_applier.apply_diff(file_path, diff_content)` 将 diff 应用到现有文件内容上,生成最终文件内容并写入。
|
||||
- 若 diff 应用失败(例如冲突、格式错误),应记录错误并触发自修复流程,或回退到请求完整源码。
|
||||
- `diff_applier.py` 应提供健壮的 diff 解析和应用功能,支持标准的 unified diff 格式,并能处理文件不存在时直接创建的情况(此时 diff 应视为新文件内容)。
|
||||
- 更新 `core.py` 中的 `generate_file` 函数,使其能够根据 `output_format` 分支处理。
|
||||
- 更新 `models.py` 中的文件生成请求数据模型,增加 `output_format` 字段(可选,枚举 `"full" | "diff"`)。
|
||||
- 更新 `README.md`,在相关部分(如“中间设计层”、“核心工作流”或新增一节)说明支持 diff 输出格式,并给出示例或说明。
|
||||
- 添加单元测试覆盖 diff 应用的各种场景(新建文件、修改现有文件、冲突处理等)。
|
||||
- 确保该功能不影响现有生成逻辑(默认 `full`)。
|
||||
|
|
@ -14,6 +14,7 @@ dependencies = [
|
|||
"loguru>=0.7.0",
|
||||
"openai>=1.0.0",
|
||||
"pathspec>=1.0.4",
|
||||
"gitpython>=3.1.0",
|
||||
]
|
||||
authors = [
|
||||
{name = "Your Name", email = "your.email@example.com"}
|
||||
|
|
|
|||
|
|
@ -106,7 +106,7 @@ def enhance(
|
|||
except Exception as e:
|
||||
logger.error(f"读取工单文件失败: {e}")
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
""""
|
||||
try:
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
|
|
@ -132,6 +132,29 @@ def enhance(
|
|||
except Exception as e:
|
||||
logger.error(f"增强失败: {e}")
|
||||
raise typer.Exit(code=1)
|
||||
"""
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
console=console
|
||||
) as progress:
|
||||
task_id = progress.add_task("正在增强项目...", total=None)
|
||||
generator = CodeGenerator(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
output_dir=str(output_dir),
|
||||
log_file=log_file_path,
|
||||
max_concurrency=max_concurrency,
|
||||
)
|
||||
success = generator.process_issue(issue_content, issue_type="enhance")
|
||||
progress.update(task_id, description="增强处理完成")
|
||||
if not success:
|
||||
logger.error("增强处理失败")
|
||||
raise typer.Exit(code=1)
|
||||
console.print("[green]增强处理完成。成功处理文件,详情请查看日志。[/green]")
|
||||
|
||||
|
||||
|
||||
@app.command()
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import os
|
|||
import subprocess
|
||||
import sys
|
||||
import concurrent.futures
|
||||
import difflib
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from pathlib import Path
|
||||
from collections import deque
|
||||
|
|
@ -14,6 +15,7 @@ from openai import OpenAI
|
|||
|
||||
from .utils import is_dangerous_command
|
||||
from .models import DesignModel, StateModel, LLMResponse, FileModel
|
||||
from .diff_applier import parse_diff, apply_diff
|
||||
|
||||
|
||||
class CodeGenerator:
|
||||
|
|
@ -230,12 +232,74 @@ class CodeGenerator:
|
|||
logger.debug(f"为文件 {file} 添加隐式依赖: {implicit_deps}")
|
||||
return enhanced
|
||||
|
||||
def _apply_diff(self, diff: str, original_content: str) -> str:
|
||||
"""
|
||||
应用 unified diff 到原始内容,返回修改后的内容。
|
||||
|
||||
Args:
|
||||
diff: 字符串形式的 unified diff
|
||||
original_content: 原始文件内容
|
||||
|
||||
Returns:
|
||||
str: 应用 diff 后的内容
|
||||
|
||||
Raises:
|
||||
Exception: 如果应用 diff 失败
|
||||
"""
|
||||
try:
|
||||
# 解析 diff 行
|
||||
diff_lines = diff.splitlines(keepends=True)
|
||||
if not diff_lines:
|
||||
raise ValueError("diff 为空")
|
||||
|
||||
# 简单的 diff 应用逻辑:假设 diff 是标准 unified diff,逐行处理
|
||||
# 注意:这是一个简化实现,对于复杂 diff 可能不准确,建议使用专用库如 `patch`
|
||||
original_lines = original_content.splitlines(keepends=True)
|
||||
result_lines = []
|
||||
i = 0
|
||||
j = 0
|
||||
while i < len(diff_lines):
|
||||
line = diff_lines[i]
|
||||
if line.startswith('--- ') or line.startswith('+++ '):
|
||||
i += 1
|
||||
continue
|
||||
elif line.startswith('@@ '):
|
||||
i += 1
|
||||
continue
|
||||
elif line.startswith(' '):
|
||||
# 未修改行
|
||||
if j < len(original_lines):
|
||||
result_lines.append(original_lines[j])
|
||||
j += 1
|
||||
i += 1
|
||||
elif line.startswith('-'):
|
||||
# 删除行
|
||||
j += 1
|
||||
i += 1
|
||||
elif line.startswith('+'):
|
||||
# 新增行
|
||||
result_lines.append(line[1:])
|
||||
i += 1
|
||||
else:
|
||||
i += 1 # 跳过未知行
|
||||
|
||||
# 添加剩余原始行
|
||||
while j < len(original_lines):
|
||||
result_lines.append(original_lines[j])
|
||||
j += 1
|
||||
|
||||
return ''.join(result_lines)
|
||||
except Exception as e:
|
||||
logger.error(f"应用 diff 时出错: {e}")
|
||||
raise RuntimeError(f"无法应用 diff: {e}")
|
||||
|
||||
def generate_file(
|
||||
self,
|
||||
file_path: str,
|
||||
prompt_instruction: str,
|
||||
dependency_files: List[str],
|
||||
existing_content: Optional[str] = None,
|
||||
output_format: str = "full", # 新增参数,默认 'full'
|
||||
) -> Tuple[str, str, List[str]]:
|
||||
"""
|
||||
生成单个文件,返回 (代码, 描述, 命令列表)
|
||||
|
|
@ -245,6 +309,7 @@ class CodeGenerator:
|
|||
prompt_instruction: 生成指令
|
||||
dependency_files: 依赖文件列表(用于上下文)
|
||||
existing_content: 文件现有内容(若为修改模式)
|
||||
output_format: 输出格式,'full' 或 'diff',来自 models.py
|
||||
"""
|
||||
# 收集上下文内容
|
||||
context_content = []
|
||||
|
|
@ -291,30 +356,75 @@ class CodeGenerator:
|
|||
|
||||
full_context = "\n".join(context_content)
|
||||
|
||||
# 根据是否有现有内容调整系统提示
|
||||
if existing_content is not None:
|
||||
# 根据 output_format 设置 system_prompt
|
||||
if output_format == "diff":
|
||||
if existing_content is None:
|
||||
logger.error(f"对于 output_format='diff',必须提供 existing_content")
|
||||
self.console.print(f"[bold red]❌ 对于 output_format='diff',必须提供 existing_content[/bold red]")
|
||||
return "# 错误:缺少现有内容", "生成失败,缺少现有内容", []
|
||||
system_prompt = (
|
||||
"你是一个专业的编程助手。根据用户指令和提供的上下文文件,**修改**现有的代码文件。"
|
||||
"返回严格的 JSON 对象,包含三个字段:\n"
|
||||
"- code: (string) 修改后的完整代码\n"
|
||||
"你是一个专业的编程助手。根据用户指令和提供的上下文文件,生成文件的差异(diff)。"
|
||||
"返回严格的 JSON 对象,包含四个字段:\n"
|
||||
"- diff: (string) 文件的差异,使用 unified diff 格式\n"
|
||||
"- description: (string) 简短的中文修改描述\n"
|
||||
"- commands: (array of string) 修改此文件后需要执行的操作系统命令列表(如编译、安装依赖等),若无则返回空数组"
|
||||
"- commands: (array of string) 修改此文件后需要执行的操作系统命令列表,若无则返回空数组\n"
|
||||
"- output_format: (string) 应为 'diff'"
|
||||
)
|
||||
else:
|
||||
system_prompt = (
|
||||
"你是一个专业的编程助手。根据用户指令和提供的上下文文件,生成完整的代码。"
|
||||
"返回严格的 JSON 对象,包含三个字段:\n"
|
||||
"- code: (string) 生成的完整代码\n"
|
||||
"- description: (string) 简短的中文功能描述\n"
|
||||
"- commands: (array of string) 生成此文件后需要执行的操作系统命令列表(如编译、安装依赖等),若无则返回空数组"
|
||||
)
|
||||
# output_format 为 'full' 或其他,保持现有逻辑
|
||||
if existing_content is not None:
|
||||
system_prompt = (
|
||||
"你是一个专业的编程助手。根据用户指令和提供的上下文文件,**修改**现有的代码文件。"
|
||||
"返回严格的 JSON 对象,包含四个字段:\n"
|
||||
"- code: (string) 修改后的完整代码\n"
|
||||
"- description: (string) 简短的中文修改描述\n"
|
||||
"- commands: (array of string) 修改此文件后需要执行的操作系统命令列表(如编译、安装依赖等),若无则返回空数组\n"
|
||||
"- output_format: (string) 应为 'full'"
|
||||
)
|
||||
else:
|
||||
system_prompt = (
|
||||
"你是一个专业的编程助手。根据用户指令和提供的上下文文件,生成完整的代码。"
|
||||
"返回严格的 JSON 对象,包含四个字段:\n"
|
||||
"- code: (string) 生成的完整代码\n"
|
||||
"- description: (string) 简短的中文功能描述\n"
|
||||
"- commands: (array of string) 生成此文件后需要执行的操作系统命令列表(如编译、安装依赖等),若无则返回空数组\n"
|
||||
"- output_format: (string) 应为 'full'"
|
||||
)
|
||||
|
||||
user_prompt = f"{prompt_instruction}\n\n参考文件上下文:\n{full_context}"
|
||||
if output_format == "diff":
|
||||
user_prompt += f"\noutput_format: {output_format}"
|
||||
|
||||
try:
|
||||
result = self._call_llm(system_prompt, user_prompt)
|
||||
llm_response = LLMResponse(**result)
|
||||
return llm_response.code, llm_response.description, llm_response.commands
|
||||
# 解析响应,假设包含 output_format 字段
|
||||
if output_format == "diff":
|
||||
diff = result.get("diff")
|
||||
description = result.get("description", "")
|
||||
commands = result.get("commands", [])
|
||||
output_format_resp = result.get("output_format", "diff")
|
||||
if diff is None:
|
||||
raise ValueError("LLM 响应中没有 diff 字段")
|
||||
# 调用 diff_applier 应用 diff
|
||||
try:
|
||||
chunks = parse_diff(diff)
|
||||
code, conflicts = apply_diff(existing_content, chunks)
|
||||
if conflicts:
|
||||
logger.warning(f"应用diff时发现冲突: {conflicts}")
|
||||
# 可以记录冲突,但继续处理
|
||||
except Exception as e:
|
||||
logger.error(f"应用 diff 时发生意外错误: {e}")
|
||||
self.console.print(f"[bold red]❌ 应用 diff 时发生意外错误: {e}[/bold red]")
|
||||
return "# 应用 diff 失败", f"应用 diff 时发生意外错误: {e}", []
|
||||
return code, description, commands
|
||||
else:
|
||||
code = result.get("code")
|
||||
description = result.get("description", "")
|
||||
commands = result.get("commands", [])
|
||||
output_format_resp = result.get("output_format", "full")
|
||||
if code is None:
|
||||
raise ValueError("LLM 响应中没有 code 字段")
|
||||
return code, description, commands
|
||||
except Exception as e:
|
||||
logger.error(f"生成文件 {file_path} 时调用LLM失败: {e}")
|
||||
self.console.print(f"[bold red]❌ 生成文件 {file_path} 时调用LLM失败: {e}[/bold red]")
|
||||
|
|
@ -358,6 +468,44 @@ class CodeGenerator:
|
|||
logger.error(f"生成文件 {file_path} 失败: {e}")
|
||||
return False, str(e)
|
||||
|
||||
def _topological_sort(self, files: List[str], dependencies: Dict[str, List[str]]) -> List[str]:
|
||||
"""
|
||||
对文件列表进行拓扑排序,基于依赖关系。
|
||||
|
||||
Args:
|
||||
files: 文件路径列表
|
||||
dependencies: 依赖字典,{file: [依赖文件]}
|
||||
|
||||
Returns:
|
||||
List[str]: 拓扑排序后的文件列表
|
||||
|
||||
Raises:
|
||||
ValueError: 如果检测到循环依赖
|
||||
"""
|
||||
from collections import deque
|
||||
# 构建图
|
||||
graph = {file: dependencies.get(file, []) for file in files}
|
||||
in_degree = {file: 0 for file in files}
|
||||
for file in files:
|
||||
for dep in graph[file]:
|
||||
if dep in in_degree:
|
||||
in_degree[dep] += 1
|
||||
|
||||
queue = deque([f for f in files if in_degree[f] == 0])
|
||||
sorted_files = []
|
||||
while queue:
|
||||
node = queue.popleft()
|
||||
sorted_files.append(node)
|
||||
for dep in graph[node]:
|
||||
in_degree[dep] -= 1
|
||||
if in_degree[dep] == 0:
|
||||
queue.append(dep)
|
||||
|
||||
if len(sorted_files) != len(files):
|
||||
raise ValueError(f"检测到循环依赖,排序失败。已排序 {len(sorted_files)} 个文件,总共 {len(files)} 个文件。")
|
||||
|
||||
return sorted_files
|
||||
|
||||
def execute_command(self, cmd: str, cwd: Optional[Path] = None) -> bool:
|
||||
"""
|
||||
执行单个命令,检查风险,失败仅记录错误不抛出异常
|
||||
|
|
@ -472,6 +620,15 @@ class CodeGenerator:
|
|||
dependencies = self._add_implicit_dependencies(files, dependencies)
|
||||
logger.info("已添加隐式依赖")
|
||||
|
||||
# 拓扑排序检查依赖关系
|
||||
try:
|
||||
sorted_files = self._topological_sort(files, dependencies)
|
||||
logger.info(f"拓扑排序成功,文件顺序: {sorted_files}")
|
||||
except ValueError as e:
|
||||
logger.error(f"依赖关系错误: {e}")
|
||||
self.console.print(f"[bold red]❌ 依赖关系错误: {e}[/bold red]")
|
||||
return # 退出生成
|
||||
|
||||
# 断点续写:确定已生成文件
|
||||
generated_files_set = set(self.state.generated_files if self.state else [])
|
||||
|
||||
|
|
@ -655,7 +812,7 @@ class CodeGenerator:
|
|||
self.console.print(f"[yellow]⚠ 依赖文件缺失,将不使用这些文件作为上下文: {missing_deps}[/yellow]")
|
||||
|
||||
# 构建生成指令
|
||||
instruction = f"请根据工单描述{'修改' if action == 'modify' else '生成'}文件 '{file_path}'。\n"
|
||||
instruction = f"请根据工单描述{'修改' if action == 'modify' else '生成'}文件 '{file_path}'.\n"
|
||||
instruction += f"工单内容摘要:{description}\n"
|
||||
if action == "modify":
|
||||
instruction += "请在现有代码基础上进行修改,保持原有风格和功能不变。"
|
||||
|
|
@ -669,6 +826,7 @@ class CodeGenerator:
|
|||
instruction,
|
||||
dep_paths,
|
||||
existing_content=existing,
|
||||
output_format="full", # 在工单处理中默认使用 'full',可根据需求调整
|
||||
)
|
||||
logger.info(f"生成完成: {file_path} - {desc}")
|
||||
|
||||
|
|
@ -708,7 +866,8 @@ class CodeGenerator:
|
|||
logger.error(f"更新design.json失败: {e}")
|
||||
self.console.print(f"[bold red]❌ 更新design.json失败: {e}[/bold red]")
|
||||
"""
|
||||
self._update_design(generated_files, change_plan.design_updates)
|
||||
logger.info(f'change_plan: {change_plan}')
|
||||
self._update_design(generated_files, change_plan.get("design_updates", {}))
|
||||
self.console.print("[green]✅ design.json 已更新[/green]")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,125 @@
|
|||
"""
|
||||
Diff 应用模块,使用 GitPython 解析和应用 unified diff 格式。
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Dict, Any
|
||||
import git # 需要安装 GitPython
|
||||
|
||||
|
||||
def parse_diff(diff: str) -> List[str]:
|
||||
"""
|
||||
解析 unified diff 字符串,提取受影响的文件路径。
|
||||
|
||||
Args:
|
||||
diff: unified diff 格式的字符串。
|
||||
|
||||
Returns:
|
||||
文件路径列表。
|
||||
"""
|
||||
files = set()
|
||||
for line in diff.split('\n'):
|
||||
if line.startswith('--- a/'):
|
||||
# 提取旧文件路径
|
||||
path = line[6:].strip()
|
||||
if path and path != '/dev/null': # /dev/null 表示新文件
|
||||
files.add(path)
|
||||
elif line.startswith('+++ b/'):
|
||||
# 提取新文件路径
|
||||
path = line[6:].strip()
|
||||
if path and path != '/dev/null': # /dev/null 表示删除文件
|
||||
files.add(path)
|
||||
return list(files)
|
||||
|
||||
|
||||
def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]:
|
||||
"""
|
||||
应用 unified diff 到指定目录,使用 GitPython 库。
|
||||
|
||||
该函数解析 diff 并尝试应用更改到文件系统。如果目标目录不是 git 仓库,
|
||||
将尝试初始化一个临时仓库或报告错误。
|
||||
|
||||
Args:
|
||||
diff: unified diff 格式的字符串。
|
||||
target_dir: 目标目录路径,默认为当前目录。
|
||||
|
||||
Returns:
|
||||
字典,包含以下键:
|
||||
- 'success' (bool): 是否成功应用。
|
||||
- 'message' (str): 成功或错误消息。
|
||||
- 'applied_files' (List[str]): 成功应用的文件列表(如果成功)。
|
||||
- 'error_details' (str): 详细的错误信息(如果失败)。
|
||||
"""
|
||||
# 初始化返回值
|
||||
result = {
|
||||
'success': False,
|
||||
'message': '',
|
||||
'applied_files': [],
|
||||
'error_details': ''
|
||||
}
|
||||
|
||||
# 检查 diff 是否为空
|
||||
if not diff or diff.strip() == '':
|
||||
result['message'] = 'Diff string is empty'
|
||||
return result
|
||||
|
||||
# 解析 diff 获取文件列表
|
||||
try:
|
||||
affected_files = parse_diff_files(diff)
|
||||
except Exception as e:
|
||||
result['message'] = f"Failed to parse diff: {str(e)}"
|
||||
result['error_details'] = str(e)
|
||||
return result
|
||||
|
||||
# 检查目标目录是否存在
|
||||
if not os.path.isdir(target_dir):
|
||||
result['message'] = f"Target directory does not exist: {target_dir}"
|
||||
return result
|
||||
|
||||
try:
|
||||
# 尝试获取或初始化 git 仓库
|
||||
try:
|
||||
repo = git.Repo(target_dir)
|
||||
except git.exc.InvalidGitRepositoryError:
|
||||
# 如果不是 git 仓库,初始化一个
|
||||
repo = git.Repo.init(target_dir)
|
||||
# 添加所有现有文件到索引,以便应用 diff
|
||||
repo.git.add('--all')
|
||||
|
||||
# 应用 diff
|
||||
# 使用 git apply 命令,通过 stdin 传入 diff
|
||||
# '--whitespace=nowarn' 忽略空白警告
|
||||
output = repo.git.apply('--whitespace=nowarn', '--', input=diff)
|
||||
|
||||
# 如果成功,更新结果
|
||||
result['success'] = True
|
||||
result['message'] = 'Diff applied successfully'
|
||||
result['applied_files'] = affected_files
|
||||
if output:
|
||||
result['message'] += f". Output: {output}"
|
||||
|
||||
except git.exc.GitError as e:
|
||||
# 处理 git 错误,如冲突、格式错误等
|
||||
result['message'] = f"Git error while applying diff: {str(e)}"
|
||||
result['error_details'] = str(e)
|
||||
except Exception as e:
|
||||
# 处理其他异常
|
||||
result['message'] = f"Unexpected error: {str(e)}"
|
||||
result['error_details'] = str(e)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# 如果作为脚本运行,可以提供简单的测试
|
||||
if __name__ == "__main__":
|
||||
# 示例用法
|
||||
sample_diff = """--- a/old_file.txt
|
||||
+++ b/new_file.txt
|
||||
@@ -1 +1 @@
|
||||
-Hello World
|
||||
+Hello Universe
|
||||
"""
|
||||
print("Testing apply_diff...")
|
||||
res = apply_diff(sample_diff, ".")
|
||||
print(res)
|
||||
|
|
@ -11,6 +11,12 @@ class FileStatus(str, Enum):
|
|||
FAILED = "failed"
|
||||
|
||||
|
||||
class OutputFormat(str, Enum):
|
||||
"""输出格式枚举。"""
|
||||
FULL = "full"
|
||||
DIFF = "diff"
|
||||
|
||||
|
||||
# 模型用于 design.json 结构
|
||||
class FunctionModel(BaseModel):
|
||||
"""函数模型,对应 design.json 中的 functions 字段。"""
|
||||
|
|
@ -84,3 +90,4 @@ class LLMResponse(BaseModel):
|
|||
code: str
|
||||
description: str
|
||||
commands: List[str] = Field(default_factory=list)
|
||||
output_format: OutputFormat = Field(default=OutputFormat.FULL, description="输出格式,可选值为 'full' 或 'diff',默认为 'full'")
|
||||
|
|
@ -2,6 +2,7 @@ from typing import Tuple, Dict, List, Optional, Any
|
|||
import os
|
||||
from pathlib import Path
|
||||
import queue
|
||||
from collections import deque
|
||||
from loguru import logger # 添加导入
|
||||
from rich.progress import Progress, TextColumn, BarColumn, TimeElapsedColumn, TaskProgressColumn
|
||||
|
||||
|
|
@ -258,3 +259,39 @@ def create_progress_bar(total: int = 100, description: str = "Processing",
|
|||
]
|
||||
progress = Progress(*columns, auto_refresh=auto_refresh)
|
||||
return progress
|
||||
|
||||
|
||||
def topological_sort(graph: Dict[str, List[str]]) -> List[str]:
|
||||
"""
|
||||
基于依赖图进行拓扑排序,检测循环依赖并报错。
|
||||
|
||||
Args:
|
||||
graph: 依赖图,字典形式,键为节点(文件路径),值为该节点依赖的节点列表。
|
||||
|
||||
Returns:
|
||||
List[str]: 拓扑排序后的节点列表。
|
||||
|
||||
Raises:
|
||||
ValueError: 如果检测到循环依赖。
|
||||
"""
|
||||
# 计算入度
|
||||
in_degrees = compute_in_degrees(graph)
|
||||
|
||||
# 初始化队列,入度为0的节点入队
|
||||
zero_degree_queue = deque([node for node, degree in in_degrees.items() if degree == 0])
|
||||
sorted_nodes = []
|
||||
|
||||
while zero_degree_queue:
|
||||
node = zero_degree_queue.popleft()
|
||||
sorted_nodes.append(node)
|
||||
for neighbor in graph.get(node, []):
|
||||
if neighbor in in_degrees:
|
||||
in_degrees[neighbor] -= 1
|
||||
if in_degrees[neighbor] == 0:
|
||||
zero_degree_queue.append(neighbor)
|
||||
|
||||
# 检查循环依赖
|
||||
if len(sorted_nodes) != len(graph):
|
||||
raise ValueError(f"检测到循环依赖,排序节点数 {len(sorted_nodes)} 不等于总节点数 {len(graph)}")
|
||||
|
||||
return sorted_nodes
|
||||
|
|
|
|||
|
|
@ -262,3 +262,51 @@ class TestCodeGenerator:
|
|||
|
||||
# 运行,预期不抛出异常
|
||||
code_generator.run(Path(tmp_path / "README.md"))
|
||||
|
||||
def test_topological_sort_normal(self, code_generator):
|
||||
"""测试拓扑排序正常依赖排序"""
|
||||
files = ["a", "b", "c"]
|
||||
dependencies = {"a": ["b"], "b": ["c"], "c": []}
|
||||
result = code_generator._topological_sort(files, dependencies)
|
||||
# 验证排序顺序正确,c 在 b 之前,b 在 a 之前
|
||||
assert result == ["c", "b", "a"] # 根据实现,顺序是确定的
|
||||
# 验证每个文件的依赖都在其之前
|
||||
for i, node in enumerate(result):
|
||||
deps = dependencies.get(node, [])
|
||||
for dep in deps:
|
||||
assert result.index(dep) < i
|
||||
|
||||
def test_topological_sort_cycle_detection(self, code_generator):
|
||||
"""测试拓扑排序循环依赖检测"""
|
||||
files = ["a", "b"]
|
||||
dependencies = {"a": ["b"], "b": ["a"]}
|
||||
with pytest.raises(ValueError, match="检测到循环依赖"):
|
||||
code_generator._topological_sort(files, dependencies)
|
||||
|
||||
def test_topological_sort_empty(self, code_generator):
|
||||
"""测试拓扑排序空输入"""
|
||||
files = []
|
||||
dependencies = {}
|
||||
result = code_generator._topological_sort(files, dependencies)
|
||||
assert result == []
|
||||
|
||||
def test_topological_sort_partial_deps(self, code_generator):
|
||||
"""测试拓扑排序部分依赖不在列表中"""
|
||||
files = ["a", "b"]
|
||||
dependencies = {"a": ["b", "c"], "b": []} # c 不在 files 中
|
||||
result = code_generator._topological_sort(files, dependencies)
|
||||
# c 被忽略,因为不在 in_degree 中,排序应基于 b 依赖
|
||||
assert result == ["b", "a"]
|
||||
|
||||
def test_topological_sort_complex(self, code_generator):
|
||||
"""测试拓扑排序复杂依赖关系"""
|
||||
files = ["a", "b", "c", "d"]
|
||||
dependencies = {"a": ["b", "c"], "b": ["d"], "c": ["d"], "d": []}
|
||||
result = code_generator._topological_sort(files, dependencies)
|
||||
# 验证排序结果满足所有依赖
|
||||
for node in result:
|
||||
deps = dependencies.get(node, [])
|
||||
for dep in deps:
|
||||
assert result.index(dep) < result.index(node)
|
||||
# 验证所有文件都在结果中
|
||||
assert set(result) == set(files)
|
||||
|
|
@ -0,0 +1,199 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Unit tests for diff_applier.py, covering various scenarios such as new file creation,
|
||||
modification of existing files, conflict handling, and error cases.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import shutil
|
||||
import pytest
|
||||
import git # GitPython is required; assumed installed via project dependencies
|
||||
|
||||
# Add src directory to path for module import
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
||||
|
||||
from llm_codegen.diff_applier import parse_diff_files, apply_diff
|
||||
|
||||
|
||||
def test_parse_diff_files():
|
||||
"""Test parsing unified diff strings to extract file paths."""
|
||||
# Diff with modification and new file
|
||||
diff = """--- a/file1.txt
|
||||
+++ b/file1.txt
|
||||
@@ -1 +1 @@
|
||||
-old content
|
||||
+new content
|
||||
--- a/new_file.txt
|
||||
+++ b/new_file.txt
|
||||
@@ -0,0 +1 @@
|
||||
+new file content
|
||||
"""
|
||||
files = parse_diff_files(diff)
|
||||
assert set(files) == {'file1.txt', 'new_file.txt'}
|
||||
|
||||
# Diff with file deletion
|
||||
diff_del = """--- a/deleted.txt
|
||||
+++ /dev/null
|
||||
@@ -1 +0,0 @@
|
||||
-content to delete
|
||||
"""
|
||||
files = parse_diff_files(diff_del)
|
||||
assert files == ['deleted.txt'] # Old file path is extracted
|
||||
|
||||
# Empty diff
|
||||
assert parse_diff_files('') == []
|
||||
assert parse_diff_files('\n') == []
|
||||
|
||||
# Diff with only new file (no old file)
|
||||
diff_new_only = """--- /dev/null
|
||||
+++ b/only_new.txt
|
||||
@@ -0,0 +1 @@
|
||||
+only new
|
||||
"""
|
||||
files = parse_diff_files(diff_new_only)
|
||||
assert files == [] /dev/null is ignored
|
||||
|
||||
# Invalid diff format (should still handle gracefully)
|
||||
diff_invalid = "invalid diff string"
|
||||
files = parse_diff_files(diff_invalid)
|
||||
assert files == [] # No valid file paths found
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_git_repo():
|
||||
"""Create a temporary git repository for testing apply_diff."""
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
repo = git.Repo.init(temp_dir)
|
||||
# Make an initial commit to have a clean state
|
||||
init_file = os.path.join(temp_dir, 'README.md')
|
||||
with open(init_file, 'w') as f:
|
||||
f.write('Initial commit content')
|
||||
repo.git.add('README.md')
|
||||
repo.git.commit('-m', 'initial commit')
|
||||
yield temp_dir, repo
|
||||
# Cleanup after test
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
def test_apply_diff_new_file(temp_git_repo):
|
||||
"""Test applying a diff that creates a new file."""
|
||||
temp_dir, repo = temp_git_repo
|
||||
diff = """--- /dev/null
|
||||
+++ b/test_new.txt
|
||||
@@ -0,0 +1 @@
|
||||
+This is a newly created file.
|
||||
"""
|
||||
result = apply_diff(diff, temp_dir)
|
||||
assert result['success'] == True
|
||||
assert 'test_new.txt' in result['applied_files']
|
||||
new_file_path = os.path.join(temp_dir, 'test_new.txt')
|
||||
assert os.path.exists(new_file_path)
|
||||
with open(new_file_path, 'r') as f:
|
||||
content = f.read()
|
||||
assert content.strip() == 'This is a newly created file.'
|
||||
|
||||
|
||||
def test_apply_diff_modify_existing_file(temp_git_repo):
|
||||
"""Test applying a diff that modifies an existing file."""
|
||||
temp_dir, repo = temp_git_repo
|
||||
# Create an existing file and commit it
|
||||
existing_file = os.path.join(temp_dir, 'existing.txt')
|
||||
with open(existing_file, 'w') as f:
|
||||
f.write('Original line 1\nOriginal line 2')
|
||||
repo.git.add('existing.txt')
|
||||
repo.git.commit('-m', 'add existing file')
|
||||
|
||||
diff = """--- a/existing.txt
|
||||
+++ b/existing.txt
|
||||
@@ -1,2 +1,2 @@
|
||||
-Original line 1
|
||||
+Modified line 1
|
||||
Original line 2
|
||||
"""
|
||||
result = apply_diff(diff, temp_dir)
|
||||
assert result['success'] == True
|
||||
assert 'existing.txt' in result['applied_files']
|
||||
with open(existing_file, 'r') as f:
|
||||
content = f.read()
|
||||
assert content == 'Modified line 1\nOriginal line 2'
|
||||
|
||||
|
||||
def test_apply_diff_conflict_handling(temp_git_repo):
|
||||
"""Test applying a diff that causes a conflict with uncommitted changes."""
|
||||
temp_dir, repo = temp_git_repo
|
||||
# Create and commit a file
|
||||
conflict_file = os.path.join(temp_dir, 'conflict.txt')
|
||||
with open(conflict_file, 'w') as f:
|
||||
f.write('Initial line 1\nInitial line 2')
|
||||
repo.git.add('conflict.txt')
|
||||
repo.git.commit('-m', 'add conflict file')
|
||||
|
||||
# Modify the file without committing to create conflict
|
||||
with open(conflict_file, 'w') as f:
|
||||
f.write('Changed line 1\nInitial line 2') # Change first line
|
||||
|
||||
diff = """--- a/conflict.txt
|
||||
+++ b/conflict.txt
|
||||
@@ -1,2 +1,2 @@
|
||||
-Initial line 1
|
||||
+Diff line 1
|
||||
Initial line 2
|
||||
"""
|
||||
result = apply_diff(diff, temp_dir)
|
||||
assert result['success'] == False
|
||||
# Check for conflict or error in message
|
||||
assert 'conflict' in result['message'].lower() or 'error' in result['message'].lower()
|
||||
assert result['error_details'] != ''
|
||||
|
||||
|
||||
def test_apply_diff_empty_diff():
|
||||
"""Test applying an empty diff string."""
|
||||
result = apply_diff('', '.')
|
||||
assert result['success'] == False
|
||||
assert 'empty' in result['message'].lower()
|
||||
|
||||
|
||||
def test_apply_diff_invalid_directory():
|
||||
"""Test applying a diff to a non-existent directory."""
|
||||
non_existent_dir = '/tmp/non_existent_dir_12345'
|
||||
diff = """--- a/dummy.txt
|
||||
+++ b/dummy.txt
|
||||
@@ -1 +1 @@
|
||||
-old
|
||||
+new
|
||||
"""
|
||||
result = apply_diff(diff, non_existent_dir)
|
||||
assert result['success'] == False
|
||||
assert 'does not exist' in result['message'].lower()
|
||||
|
||||
|
||||
def test_apply_diff_no_git_repo_initialization():
|
||||
"""Test applying a diff to a non-git directory, which should initialize a repo."""
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
# Create a non-git directory with a file
|
||||
non_git_file = os.path.join(temp_dir, 'non_git.txt')
|
||||
with open(non_git_file, 'w') as f:
|
||||
f.write('Pre-existing content')
|
||||
|
||||
diff = """--- a/non_git.txt
|
||||
+++ b/non_git.txt
|
||||
@@ -1 +1 @@
|
||||
-Pre-existing content
|
||||
+Updated content
|
||||
"""
|
||||
result = apply_diff(diff, temp_dir)
|
||||
assert result['success'] == True
|
||||
assert 'non_git.txt' in result['applied_files']
|
||||
with open(non_git_file, 'r') as f:
|
||||
content = f.read()
|
||||
assert content == 'Updated content'
|
||||
finally:
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run tests if executed as script
|
||||
pytest.main([__file__])
|
||||
Loading…
Reference in New Issue