feat(core): 修复 process_issue 方法以遵循 design.json 中的文件依赖关系

This commit is contained in:
songsenand 2026-03-19 00:31:18 +08:00
parent 368c2b2e02
commit eba9bd5863
6 changed files with 225 additions and 87 deletions

2
.gitignore vendored
View File

@ -15,3 +15,5 @@ test_output/
*/__pycache__/ */__pycache__/
*.egg-info *.egg-info
*/*.egg-info */*.egg-info
/logs/*
/llm_responses/*

View File

@ -0,0 +1,25 @@
# Bug 工单
name: fix/enhance 命令未遵循文件依赖关系进行修改或生成
description: |
在执行 `llm-codegen fix` 或 `llm-codegen enhance` 命令时,代码生成过程没有严格遵循 `design.json` 中定义的文件依赖关系。
例如,如果文件 A 依赖于文件 B在修改工单中同时包含了 A 和 B那么在生成 A 的新内容时,其上下文可能仍然是 B 修改前的旧内容,
因为 B 的修改尚未应用到文件系统,或者应用顺序没有被强制保证。
steps_to_reproduce: |
1. 初始化一个项目,使其 `design.json` 中包含具有明确依赖关系的文件例如A 依赖于 B
2. 创建一个 `enhance` 或 `fix` 工单,该工单要求同时修改文件 A 和 B。
3. 运行 `llm-codegen enhance` 或 `llm-codegen fix` 命令。
4. 观察生成的日志和最终代码,可以发现 A 的内容可能是基于 B 的旧内容生成的,违背了依赖关系。
expected_behavior: |
`process_issue` 方法应该像 `run` 方法一样,解析 `design.json` 中的依赖关系。
在处理 `affected_files` 列表时,必须确保在处理文件 A 之前其所有依赖项B, C, D...)都已经被成功修改或生成并写入磁盘。
这样可以保证 LLM 在生成 A 时,能够看到最新的依赖文件内容。
actual_behavior: |
`process_issue` 方法只是简单地顺序处理 `affected_files` 列表,没有考虑它们之间在 `design.json` 中定义的依赖关系。
这导致了生成的代码可能基于陈旧的上下文,从而产生错误或不一致的代码。
affected_files:
- src/llm_codegen/core.py

View File

@ -16,6 +16,7 @@ dependencies = [
"pathspec>=1.0.4", "pathspec>=1.0.4",
"python-patch>=0.0.1", "python-patch>=0.0.1",
"unidiff2>=0.7.8", "unidiff2>=0.7.8",
"pendulum>=3.2.0",
] ]
authors = [ authors = [
{name = "Your Name", email = "your.email@example.com"} {name = "Your Name", email = "your.email@example.com"}

View File

@ -20,7 +20,9 @@ app = typer.Typer(help="基于LLM的自动化代码生成与维护工具")
console = Console() console = Console()
def init_logging(output_dir: Path, log_file: Optional[str] = None, command_name: str = "cli") -> str: def init_logging(
output_dir: Path, log_file: Optional[str] = None, command_name: str = "cli"
) -> str:
"""初始化日志配置到logs/目录""" """初始化日志配置到logs/目录"""
log_dir = output_dir / "logs" log_dir = output_dir / "logs"
log_dir.mkdir(parents=True, exist_ok=True) log_dir.mkdir(parents=True, exist_ok=True)
@ -35,13 +37,23 @@ def init_logging(output_dir: Path, log_file: Optional[str] = None, command_name:
@app.command() @app.command()
def init( def init(
readme: Path = typer.Argument(..., exists=True, file_okay=True, dir_okay=False, help="README.md 文件路径"), readme: Path = typer.Argument(
output_dir: Optional[Path] = typer.Option(None, "--output", "-o", help="输出根目录,默认为当前目录"), ..., exists=True, file_okay=True, dir_okay=False, help="README.md 文件路径"
api_key: Optional[str] = typer.Option(None, "--api-key", envvar="DEEPSEEK_APIKEY", help="API密钥"), ),
base_url: str = typer.Option("https://api.deepseek.com", "--base-url", help="API基础URL"), output_dir: Optional[Path] = typer.Option(
None, "--output", "-o", help="输出根目录,默认为当前目录"
),
api_key: Optional[str] = typer.Option(
None, "--api-key", envvar="DEEPSEEK_APIKEY", help="API密钥"
),
base_url: str = typer.Option(
"https://api.deepseek.com", "--base-url", help="API基础URL"
),
model: str = typer.Option("deepseek-reasoner", "--model", "-m", help="使用的模型"), model: str = typer.Option("deepseek-reasoner", "--model", "-m", help="使用的模型"),
log_file: Optional[str] = typer.Option(None, "--log", help="日志文件路径"), log_file: Optional[str] = typer.Option(None, "--log", help="日志文件路径"),
max_concurrency: int = typer.Option(4, "--max-concurrency", help="并发生成的最大工作线程数默认4"), max_concurrency: int = typer.Option(
4, "--max-concurrency", help="并发生成的最大工作线程数默认4"
),
): ):
"""初始化项目:根据 README.md 自动生成完整的代码。""" """初始化项目:根据 README.md 自动生成完整的代码。"""
if output_dir is None: if output_dir is None:
@ -56,7 +68,7 @@ def init(
SpinnerColumn(), SpinnerColumn(),
TextColumn("[progress.description]{task.description}"), TextColumn("[progress.description]{task.description}"),
BarColumn(), BarColumn(),
console=console console=console,
) as progress: ) as progress:
task_id = progress.add_task("正在初始化项目...", total=None) task_id = progress.add_task("正在初始化项目...", total=None)
generator = CodeGenerator( generator = CodeGenerator(
@ -78,13 +90,27 @@ def init(
@app.command() @app.command()
def enhance( def enhance(
issue_file: Path = typer.Argument(..., exists=True, file_okay=True, dir_okay=False, help="需求工单文件路径(如 feature.issue"), issue_file: Path = typer.Argument(
output_dir: Optional[Path] = typer.Option(None, "--output", "-o", help="项目根目录,默认为当前目录"), ...,
api_key: Optional[str] = typer.Option(None, "--api-key", envvar="DEEPSEEK_APIKEY", help="API密钥"), exists=True,
base_url: str = typer.Option("https://api.deepseek.com", "--base-url", help="API基础URL"), file_okay=True,
dir_okay=False,
help="需求工单文件路径(如 feature.issue",
),
output_dir: Optional[Path] = typer.Option(
None, "--output", "-o", help="项目根目录,默认为当前目录"
),
api_key: Optional[str] = typer.Option(
None, "--api-key", envvar="DEEPSEEK_APIKEY", help="API密钥"
),
base_url: str = typer.Option(
"https://api.deepseek.com", "--base-url", help="API基础URL"
),
model: str = typer.Option("deepseek-reasoner", "--model", "-m", help="使用的模型"), model: str = typer.Option("deepseek-reasoner", "--model", "-m", help="使用的模型"),
log_file: Optional[str] = typer.Option(None, "--log", help="日志文件路径"), log_file: Optional[str] = typer.Option(None, "--log", help="日志文件路径"),
max_concurrency: int = typer.Option(4, "--max-concurrency", help="并发生成的最大工作线程数默认4"), max_concurrency: int = typer.Option(
4, "--max-concurrency", help="并发生成的最大工作线程数默认4"
),
): ):
"""增强项目:根据需求工单添加新功能。""" """增强项目:根据需求工单添加新功能。"""
if output_dir is None: if output_dir is None:
@ -96,7 +122,9 @@ def enhance(
# 处理致命错误检查design.json是否存在 # 处理致命错误检查design.json是否存在
design_path = output_dir / "design.json" design_path = output_dir / "design.json"
if not design_path.exists(): if not design_path.exists():
logger.error(f"design.json 不存在于 {output_dir},请先运行 init 命令初始化项目。") logger.error(
f"design.json 不存在于 {output_dir},请先运行 init 命令初始化项目。"
)
raise typer.Exit(code=1) raise typer.Exit(code=1)
# 读取工单文件 # 读取工单文件
@ -137,7 +165,7 @@ def enhance(
SpinnerColumn(), SpinnerColumn(),
TextColumn("[progress.description]{task.description}"), TextColumn("[progress.description]{task.description}"),
BarColumn(), BarColumn(),
console=console console=console,
) as progress: ) as progress:
task_id = progress.add_task("正在增强项目...", total=None) task_id = progress.add_task("正在增强项目...", total=None)
generator = CodeGenerator( generator = CodeGenerator(
@ -158,13 +186,27 @@ def enhance(
@app.command() @app.command()
def fix( def fix(
issue_file: Path = typer.Argument(..., exists=True, file_okay=True, dir_okay=False, help="Bug工单文件路径如 bug.issue"), issue_file: Path = typer.Argument(
output_dir: Optional[Path] = typer.Option(None, "--output", "-o", help="项目根目录,默认为当前目录"), ...,
api_key: Optional[str] = typer.Option(None, "--api-key", envvar="DEEPSEEK_APIKEY", help="API密钥"), exists=True,
base_url: str = typer.Option("https://api.deepseek.com", "--base-url", help="API基础URL"), file_okay=True,
dir_okay=False,
help="Bug工单文件路径如 bug.issue",
),
output_dir: Optional[Path] = typer.Option(
None, "--output", "-o", help="项目根目录,默认为当前目录"
),
api_key: Optional[str] = typer.Option(
None, "--api-key", envvar="DEEPSEEK_APIKEY", help="API密钥"
),
base_url: str = typer.Option(
"https://api.deepseek.com", "--base-url", help="API基础URL"
),
model: str = typer.Option("deepseek-reasoner", "--model", "-m", help="使用的模型"), model: str = typer.Option("deepseek-reasoner", "--model", "-m", help="使用的模型"),
log_file: Optional[str] = typer.Option(None, "--log", help="日志文件路径"), log_file: Optional[str] = typer.Option(None, "--log", help="日志文件路径"),
max_concurrency: int = typer.Option(4, "--max-concurrency", help="并发生成的最大工作线程数默认4"), max_concurrency: int = typer.Option(
4, "--max-concurrency", help="并发生成的最大工作线程数默认4"
),
): ):
"""修复项目根据Bug工单自动修复 Bug。""" """修复项目根据Bug工单自动修复 Bug。"""
if output_dir is None: if output_dir is None:
@ -186,7 +228,6 @@ def fix(
except Exception as e: except Exception as e:
logger.error(f"读取工单文件失败: {e}") logger.error(f"读取工单文件失败: {e}")
raise typer.Exit(code=1) raise typer.Exit(code=1)
try: try:
with Progress( with Progress(
SpinnerColumn(), SpinnerColumn(),
@ -212,18 +253,35 @@ def fix(
except Exception as e: except Exception as e:
logger.error(f"修复失败: {e}") logger.error(f"修复失败: {e}")
raise typer.Exit(code=1) raise typer.Exit(code=1)
console.print("[green]修复处理完成。成功处理文件,详情请查看日志。[/green]")
@app.command() @app.command()
def design( def design(
file: Path = typer.Option(..., "--file", "-f", help="README文件路径用于生成design.json", exists=True, file_okay=True, dir_okay=False), file: Path = typer.Option(
output_dir: Optional[Path] = typer.Option(None, "--output", "-o", help="输出目录design.json将保存在此默认为当前目录"), ...,
"--file",
"-f",
help="README文件路径用于生成design.json",
exists=True,
file_okay=True,
dir_okay=False,
),
output_dir: Optional[Path] = typer.Option(
None, "--output", "-o", help="输出目录design.json将保存在此默认为当前目录"
),
force: bool = typer.Option(False, "--force", help="强制覆盖已存在的design.json"), force: bool = typer.Option(False, "--force", help="强制覆盖已存在的design.json"),
api_key: Optional[str] = typer.Option(None, "--api-key", envvar="DEEPSEEK_APIKEY", help="API密钥"), api_key: Optional[str] = typer.Option(
base_url: str = typer.Option("https://api.deepseek.com", "--base-url", help="API基础URL"), None, "--api-key", envvar="DEEPSEEK_APIKEY", help="API密钥"
),
base_url: str = typer.Option(
"https://api.deepseek.com", "--base-url", help="API基础URL"
),
model: str = typer.Option("deepseek-reasoner", "--model", "-m", help="使用的模型"), model: str = typer.Option("deepseek-reasoner", "--model", "-m", help="使用的模型"),
log_file: Optional[str] = typer.Option(None, "--log", help="日志文件路径"), log_file: Optional[str] = typer.Option(None, "--log", help="日志文件路径"),
max_concurrency: int = typer.Option(4, "--max-concurrency", help="并发生成的最大工作线程数默认4"), max_concurrency: int = typer.Option(
4, "--max-concurrency", help="并发生成的最大工作线程数默认4"
),
): ):
"""生成或更新design.json根据README文件生成中间设计文件不生成完整代码。""" """生成或更新design.json根据README文件生成中间设计文件不生成完整代码。"""
if output_dir is None: if output_dir is None:
@ -235,7 +293,9 @@ def design(
# 检查design.json是否存在并处理强制覆盖 # 检查design.json是否存在并处理强制覆盖
design_path = output_dir / "design.json" design_path = output_dir / "design.json"
if not force and design_path.exists(): if not force and design_path.exists():
logger.error(f"design.json 已存在于 {design_path}。使用 --force 参数以强制覆盖。") logger.error(
f"design.json 已存在于 {design_path}。使用 --force 参数以强制覆盖。"
)
raise typer.Exit(code=1) raise typer.Exit(code=1)
try: try:
@ -243,7 +303,7 @@ def design(
SpinnerColumn(), SpinnerColumn(),
TextColumn("[progress.description]{task.description}"), TextColumn("[progress.description]{task.description}"),
BarColumn(), BarColumn(),
console=console console=console,
) as progress: ) as progress:
task_id = progress.add_task("正在生成design.json...", total=None) task_id = progress.add_task("正在生成design.json...", total=None)
generator = CodeGenerator( generator = CodeGenerator(
@ -267,13 +327,21 @@ def design(
@app.command() @app.command()
def check( def check(
output_dir: Optional[Path] = typer.Option(None, "--output", "-o", help="项目根目录,默认为当前目录"), output_dir: Optional[Path] = typer.Option(
api_key: Optional[str] = typer.Option(None, "--api-key", envvar="DEEPSEEK_APIKEY", help="API密钥"), None, "--output", "-o", help="项目根目录,默认为当前目录"
base_url: str = typer.Option("https://api.deepseek.com", "--base-url", help="API基础URL"), ),
api_key: Optional[str] = typer.Option(
None, "--api-key", envvar="DEEPSEEK_APIKEY", help="API密钥"
),
base_url: str = typer.Option(
"https://api.deepseek.com", "--base-url", help="API基础URL"
),
model: str = typer.Option("deepseek-reasoner", "--model", "-m", help="使用的模型"), model: str = typer.Option("deepseek-reasoner", "--model", "-m", help="使用的模型"),
log_file: Optional[str] = typer.Option(None, "--log", help="日志文件路径"), log_file: Optional[str] = typer.Option(None, "--log", help="日志文件路径"),
max_retries: int = typer.Option(3, "--max-retries", help="最大修复重试次数"), max_retries: int = typer.Option(3, "--max-retries", help="最大修复重试次数"),
max_concurrency: int = typer.Option(4, "--max-concurrency", help="并发生成的最大工作线程数默认4"), max_concurrency: int = typer.Option(
4, "--max-concurrency", help="并发生成的最大工作线程数默认4"
),
): ):
"""运行代码检查和自动修复(不依赖于工单)""" """运行代码检查和自动修复(不依赖于工单)"""
if output_dir is None: if output_dir is None:

View File

@ -3,7 +3,7 @@ import os
import subprocess import subprocess
import sys import sys
import concurrent.futures import concurrent.futures
import difflib import pendulum
from typing import List, Dict, Optional, Any, Tuple from typing import List, Dict, Optional, Any, Tuple
from pathlib import Path from pathlib import Path
from collections import deque from collections import deque
@ -15,7 +15,7 @@ from loguru import logger
from openai import OpenAI from openai import OpenAI
from .utils import is_dangerous_command from .utils import is_dangerous_command
from .models import DesignModel, StateModel, LLMResponse, FileModel from .models import DesignModel, StateModel, FileModel
from .diff_applier import parse_diff, apply_diff from .diff_applier import parse_diff, apply_diff
@ -98,10 +98,35 @@ class CodeGenerator:
content = message.content content = message.content
# 记录思考过程(如果存在) # 记录思考过程(如果存在)
reasoning_content = None
if hasattr(message, "reasoning_content") and message.reasoning_content: if hasattr(message, "reasoning_content") and message.reasoning_content:
logger.info(f"模型思考过程: {message.reasoning_content}") reasoning_content = message.reasoning_content
logger.info("模型思考过程已记录")
logger.debug(f"LLM原始响应: {content[:500]}...") # 创建响应目录
responses_dir = self.output_dir / "llm_responses"
responses_dir.mkdir(parents=True, exist_ok=True)
# 生成文件名(使用当前时间)
timestamp = pendulum.now().format("YYYYMMDD_HHmmss_SSS")
response_file = responses_dir / f"response_{timestamp}.json"
# 保存响应到JSON文件
response_data = {
"timestamp": timestamp,
"model": self.model,
"content": content,
"reasoning_content": reasoning_content,
"system_prompt": system_prompt,
"user_prompt": user_prompt,
"temperature": temperature,
"expect_json": expect_json
}
with open(response_file, "w", encoding="utf-8") as f:
json.dump(response_data, f, indent=2, ensure_ascii=False)
logger.debug(f"LLM原始响应: {response_file.name}")
if expect_json: if expect_json:
result = json.loads(content) result = json.loads(content)
@ -119,6 +144,7 @@ class CodeGenerator:
self.console.print(f"[bold red]❌ LLM调用失败: {e}[/bold red]") self.console.print(f"[bold red]❌ LLM调用失败: {e}[/bold red]")
raise raise
def parse_readme(self, readme_path: Path) -> str: def parse_readme(self, readme_path: Path) -> str:
""" """
读取README文件内容 读取README文件内容
@ -362,8 +388,8 @@ class CodeGenerator:
# 根据 output_format 设置 system_prompt # 根据 output_format 设置 system_prompt
if output_format == "diff": if output_format == "diff":
if existing_content is None: if existing_content is None:
logger.error(f"对于 output_format='diff',必须提供 existing_content") logger.error("对于 output_format='diff',必须提供 existing_content")
self.console.print(f"[bold red]❌ 对于 output_format='diff',必须提供 existing_content[/bold red]") self.console.print("[bold red]❌ 对于 output_format='diff',必须提供 existing_content[/bold red]")
return "# 错误:缺少现有内容", "生成失败,缺少现有内容", [] return "# 错误:缺少现有内容", "生成失败,缺少现有内容", []
system_prompt = ( system_prompt = (
"你是一个专业的编程助手。根据用户指令和提供的上下文文件生成文件的差异diff" "你是一个专业的编程助手。根据用户指令和提供的上下文文件生成文件的差异diff"
@ -405,7 +431,7 @@ class CodeGenerator:
diff = result.get("diff") diff = result.get("diff")
description = result.get("description", "") description = result.get("description", "")
commands = result.get("commands", []) commands = result.get("commands", [])
output_format_resp = result.get("output_format", "diff") result.get("output_format", "diff")
if diff is None: if diff is None:
raise ValueError("LLM 响应中没有 diff 字段") raise ValueError("LLM 响应中没有 diff 字段")
# 调用 diff_applier 应用 diff # 调用 diff_applier 应用 diff
@ -424,7 +450,7 @@ class CodeGenerator:
code = result.get("code") code = result.get("code")
description = result.get("description", "") description = result.get("description", "")
commands = result.get("commands", []) commands = result.get("commands", [])
output_format_resp = result.get("output_format", "full") result.get("output_format", "full")
if code is None: if code is None:
raise ValueError("LLM 响应中没有 code 字段") raise ValueError("LLM 响应中没有 code 字段")
return code, description, commands return code, description, commands
@ -791,9 +817,37 @@ class CodeGenerator:
self.console.print(f"[green]✅ 分析完成,将处理 {len(affected_files)} 个文件[/green]") self.console.print(f"[green]✅ 分析完成,将处理 {len(affected_files)} 个文件[/green]")
# 步骤2: 逐个处理文件 # 添加依赖关系排序:解析 design.json 中的依赖,确保依赖项先于被依赖项处理
generated_files = [] # 构建依赖关系字典用于拓扑排序
dependencies_dict = {}
for file_info in affected_files: for file_info in affected_files:
path = file_info["path"]
# 从 design.json 中获取依赖关系
deps = []
for f in self.design.files:
if f.path == path:
deps = f.dependencies
break
# 只考虑在 affected_files 中的依赖文件,以确保内部依赖顺序
affected_paths_set = set(info["path"] for info in affected_files)
filtered_deps = [dep for dep in deps if dep in affected_paths_set]
dependencies_dict[path] = filtered_deps
# 对 affected_files 进行拓扑排序
try:
sorted_paths = self._topological_sort([info["path"] for info in affected_files], dependencies_dict)
except ValueError as e:
logger.error(f"依赖关系排序失败: {e}")
self.console.print(f"[bold red]❌ 依赖关系排序失败: {e}[/bold red]")
return False # 排序失败,处理中止
# 重新排序 affected_files 基于 sorted_paths
file_info_map = {info["path"]: info for info in affected_files}
sorted_affected_files = [file_info_map[path] for path in sorted_paths]
# 步骤2: 逐个处理文件(按依赖顺序)
generated_files = []
for file_info in sorted_affected_files:
file_path = file_info["path"] file_path = file_info["path"]
action = file_info.get("action", "modify") # modify 或 create action = file_info.get("action", "modify") # modify 或 create
description = file_info.get("description", "") description = file_info.get("description", "")
@ -839,42 +893,34 @@ class CodeGenerator:
instruction += "请生成完整的代码文件。" instruction += "请生成完整的代码文件。"
# 调用 generate_file # 调用 generate_file
output_format = "full" if action == "create" else "diff" code, desc, commands = self.generate_file(
try: file_path,
code, desc, commands = self.generate_file( instruction,
file_path, dep_paths,
instruction, existing_content=existing,
dep_paths, output_format="full",
existing_content=existing,
output_format=output_format,
) )
logger.info(f"生成完成: {file_path} - {desc}") logger.info(f"生成完成: {file_path} - {desc}")
# 写入文件
full_path.parent.mkdir(parents=True, exist_ok=True)
try:
with open(full_path, "w", encoding="utf-8") as f:
f.write(code)
logger.info(f"已写入: {full_path}")
generated_files.append(file_path)
except Exception as e:
logger.error(f"写入文件 {file_path} 失败: {e}")
self.console.print(f"[bold red]❌ 写入文件 {file_path} 失败: {e}[/bold red]")
# 跳过命令执行
commands = []
# 执行关联命令
for cmd in commands:
logger.info(f"准备执行命令: {cmd}")
success = self.execute_command(cmd, cwd=self.output_dir)
if not success:
logger.warning(f"命令执行失败,但继续处理: {cmd}")
# 写入文件
full_path.parent.mkdir(parents=True, exist_ok=True)
try:
with open(full_path, "w", encoding="utf-8") as f:
f.write(code)
logger.info(f"已写入: {full_path}")
generated_files.append(file_path)
except Exception as e: except Exception as e:
logger.error(f"处理文件 {file_path} 失败: {e}") logger.error(f"写入文件 {file_path} 失败: {e}")
self.console.print(f"[bold red]❌ 处理文件 {file_path} 失败: {e}[/bold red]") self.console.print(f"[bold red]❌ 写入文件 {file_path} 失败: {e}[/bold red]")
# 继续处理其他文件 # 跳过命令执行
continue commands = []
# 执行关联命令
for cmd in commands:
logger.info(f"准备执行命令: {cmd}")
success = self.execute_command(cmd, cwd=self.output_dir)
if not success:
logger.warning(f"命令执行失败,但继续处理: {cmd}")
# 步骤3: 更新 design.json # 步骤3: 更新 design.json
if generated_files: if generated_files:
@ -978,7 +1024,7 @@ class CodeGenerator:
return False return False
else: else:
logger.error("没有README内容且README.md文件不存在无法刷新design") logger.error("没有README内容且README.md文件不存在无法刷新design")
self.console.print(f"[bold red]❌ 没有README内容且README.md文件不存在无法刷新design[/bold red]") self.console.print("[bold red]❌ 没有README内容且README.md文件不存在无法刷新design[/bold red]")
return False return False
try: try:

View File

@ -1,10 +1,6 @@
#!/usr/bin/env python3
"""Unit tests for diff_applier.py, covering various scenarios such as new file creation,
modification of existing files, conflict handling, and error cases."""
import os import os
import sys import sys
import tempfile import tempfile
import shutil
import pytest import pytest
# Add src directory to path for module import # Add src directory to path for module import
@ -64,7 +60,7 @@ def test_apply_diff_new_file():
+This is a newly created file.""" +This is a newly created file."""
result = apply_diff(diff, temp_dir) result = apply_diff(diff, temp_dir)
assert result['success'] == True assert result['success']
assert 'test_new.txt' in result['applied_files'] assert 'test_new.txt' in result['applied_files']
new_file_path = os.path.join(temp_dir, 'test_new.txt') new_file_path = os.path.join(temp_dir, 'test_new.txt')
@ -122,7 +118,7 @@ def test_apply_diff_conflict_handling():
def test_apply_diff_empty_diff(): def test_apply_diff_empty_diff():
"""Test applying an empty diff string.""" """Test applying an empty diff string."""
result = apply_diff('', '.') result = apply_diff('', '.')
assert result['success'] == False assert not result['success']
assert 'empty' in result['message'].lower() assert 'empty' in result['message'].lower()
def test_apply_diff_invalid_directory(): def test_apply_diff_invalid_directory():
@ -135,7 +131,7 @@ def test_apply_diff_invalid_directory():
+new""" +new"""
result = apply_diff(diff, non_existent_dir) result = apply_diff(diff, non_existent_dir)
assert result['success'] == False assert not result['success']
assert 'does not exist' in result['message'].lower() assert 'does not exist' in result['message'].lower()
def test_apply_diff_no_git_repo_initialization(): def test_apply_diff_no_git_repo_initialization():
@ -153,7 +149,7 @@ def test_apply_diff_no_git_repo_initialization():
+Updated content""" +Updated content"""
result = apply_diff(diff, temp_dir) result = apply_diff(diff, temp_dir)
assert result['success'] == True assert result['success']
assert 'non_git.txt' in result['applied_files'] assert 'non_git.txt' in result['applied_files']
with open(non_git_file, 'r') as f: with open(non_git_file, 'r') as f: