重构代码以支持并发生成和线程安全状态保存,优化拓扑排序与补丁应用逻辑

This commit is contained in:
songsenand 2026-03-18 16:42:14 +08:00
parent 6e536f141c
commit 75fc8659a0
5 changed files with 296 additions and 235 deletions

View File

@ -23,9 +23,9 @@
"dependencies": ["src/llm_codegen/core.py", "src/llm_codegen/models.py"], "dependencies": ["src/llm_codegen/core.py", "src/llm_codegen/models.py"],
"functions": [ "functions": [
{ {
"name": "main", "name": "",
"summary": "主CLI入口处理命令行参数并启动生成器", "summary": "",
"inputs": ["readme", "output_dir", "api_key", "base_url", "model", "log_file"], "inputs": [],
"outputs": [] "outputs": []
} }
], ],

View File

@ -14,7 +14,8 @@ dependencies = [
"loguru>=0.7.0", "loguru>=0.7.0",
"openai>=1.0.0", "openai>=1.0.0",
"pathspec>=1.0.4", "pathspec>=1.0.4",
"gitpython>=3.1.0", "python-patch>=0.0.1",
"unidiff2>=0.7.8",
] ]
authors = [ authors = [
{name = "Your Name", email = "your.email@example.com"} {name = "Your Name", email = "your.email@example.com"}
@ -33,6 +34,10 @@ dev = [
"black>=23.0.0", "black>=23.0.0",
] ]
[[tool.uv.index]]
url="https://pypi.tuna.tsinghua.edu.cn/simple"
default=true
[project.scripts] [project.scripts]
llm-codegen = "llm_codegen.cli:app" llm-codegen = "llm_codegen.cli:app"
@ -45,3 +50,8 @@ max_concurrent_requests = 5
# 新增:指定包所在目录 # 新增:指定包所在目录
[tool.setuptools.packages.find] [tool.setuptools.packages.find]
where = ["src"] where = ["src"]
[dependency-groups]
dev = [
"pytest>=8.4.2",
]

View File

@ -7,6 +7,7 @@ import difflib
from typing import List, Dict, Optional, Any, Tuple from typing import List, Dict, Optional, Any, Tuple
from pathlib import Path from pathlib import Path
from collections import deque from collections import deque
import threading
from rich.console import Console from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskID from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskID
@ -50,6 +51,7 @@ class CodeGenerator:
self.output_dir.mkdir(parents=True, exist_ok=True) self.output_dir.mkdir(parents=True, exist_ok=True)
self.state_file = self.output_dir / ".llm_generator_state.json" self.state_file = self.output_dir / ".llm_generator_state.json"
self.console = Console() # 添加console实例用于rich打印 self.console = Console() # 添加console实例用于rich打印
self._state_lock = threading.Lock()
self.max_concurrency = max_concurrency self.max_concurrency = max_concurrency
@ -171,9 +173,10 @@ class CodeGenerator:
return None return None
def save_state(self, generated_files: List[str], dependencies_map: Dict[str, List[str]]) -> None: def save_state(self, generated_files: List[str], dependencies_map: Dict[str, List[str]]) -> None:
"""保存断点续写状态,适应并发生成""" """保存断点续写状态,适应并发生成(线程安全)"""
with self._state_lock: # 串行化写入
state = StateModel( state = StateModel(
current_file_index=0, # 在并发中无效设为0保持兼容 current_file_index=0,
generated_files=generated_files, generated_files=generated_files,
dependencies_map=dependencies_map, dependencies_map=dependencies_map,
total_files=len(self.design.files) if self.design else 0, total_files=len(self.design.files) if self.design else 0,
@ -184,6 +187,7 @@ class CodeGenerator:
json.dump(state.model_dump(), f, indent=2, ensure_ascii=False) json.dump(state.model_dump(), f, indent=2, ensure_ascii=False)
logger.debug(f"状态已保存: {self.state_file}") logger.debug(f"状态已保存: {self.state_file}")
def get_project_structure(self) -> Tuple[List[str], Dict[str, List[str]]]: def get_project_structure(self) -> Tuple[List[str], Dict[str, List[str]]]:
""" """
从design.json获取文件列表和依赖关系 从design.json获取文件列表和依赖关系
@ -471,36 +475,36 @@ class CodeGenerator:
def _topological_sort(self, files: List[str], dependencies: Dict[str, List[str]]) -> List[str]: def _topological_sort(self, files: List[str], dependencies: Dict[str, List[str]]) -> List[str]:
""" """
对文件列表进行拓扑排序基于依赖关系 对文件列表进行拓扑排序基于依赖关系
返回排序后的列表满足每个文件的依赖项都出现在该文件之前
Args: 如果检测到循环依赖抛出ValueError
files: 文件路径列表
dependencies: 依赖字典{file: [依赖文件]}
Returns:
List[str]: 拓扑排序后的文件列表
Raises:
ValueError: 如果检测到循环依赖
""" """
from collections import deque from collections import deque
# 构建图
graph = {file: dependencies.get(file, []) for file in files}
in_degree = {file: 0 for file in files}
for file in files:
for dep in graph[file]:
if dep in in_degree:
in_degree[dep] += 1
# 初始化入度和反向邻接表
in_degree = {f: 0 for f in files}
rev_graph = {f: [] for f in files} # 记录哪些文件依赖于f
# 构建图如果文件f依赖于dep则增加f的入度并将f加入rev_graph[dep]
for f in files:
for dep in dependencies.get(f, []):
if dep in files: # 只考虑在files中的依赖
in_degree[f] += 1 # f依赖于dep所以f的入度增加
rev_graph[dep].append(f) # dep被f依赖
# 队列初始化为入度为0的文件无依赖的文件
queue = deque([f for f in files if in_degree[f] == 0]) queue = deque([f for f in files if in_degree[f] == 0])
sorted_files = [] sorted_files = []
while queue: while queue:
node = queue.popleft() node = queue.popleft()
sorted_files.append(node) sorted_files.append(node)
for dep in graph[node]: # 所有依赖于node的文件入度减1
in_degree[dep] -= 1 for dependent in rev_graph[node]:
if in_degree[dep] == 0: in_degree[dependent] -= 1
queue.append(dep) if in_degree[dependent] == 0:
queue.append(dependent)
# 检查是否所有文件都已排序(无循环依赖)
if len(sorted_files) != len(files): if len(sorted_files) != len(files):
raise ValueError(f"检测到循环依赖,排序失败。已排序 {len(sorted_files)} 个文件,总共 {len(files)} 个文件。") raise ValueError(f"检测到循环依赖,排序失败。已排序 {len(sorted_files)} 个文件,总共 {len(files)} 个文件。")
@ -671,6 +675,7 @@ class CodeGenerator:
done, not_done = concurrent.futures.wait(futures.keys(), return_when=concurrent.futures.FIRST_COMPLETED, timeout=1.0) done, not_done = concurrent.futures.wait(futures.keys(), return_when=concurrent.futures.FIRST_COMPLETED, timeout=1.0)
for future in done: for future in done:
file = futures.pop(future) file = futures.pop(future)
try:
success, error_msg = future.result() success, error_msg = future.result()
# 更新文件进度任务 # 更新文件进度任务
if file in file_tasks: if file in file_tasks:
@ -697,6 +702,22 @@ class CodeGenerator:
logger.error(f"文件 {file} 生成失败,错误: {error_msg}") logger.error(f"文件 {file} 生成失败,错误: {error_msg}")
self.console.print(f"[bold red]❌ 文件 {file} 生成失败,错误: {error_msg}[/bold red]") self.console.print(f"[bold red]❌ 文件 {file} 生成失败,错误: {error_msg}[/bold red]")
# 错误处理:继续处理其他文件,但记录失败 # 错误处理:继续处理其他文件,但记录失败
except Exception as e:
# 捕获 Future 中存储的异常
logger.error(f"任务 {file} 执行时发生异常: {e}")
self.console.print(f"[bold red]❌ 任务 {file} 执行时发生异常: {e}[/bold red]")
# 将其视为失败
success = False
error_msg = str(e)
# 然后执行和上面 `else` 分支相同的失败处理逻辑
if file in file_tasks:
progress.update(file_tasks[file], description=f"生成失败: {file}")
progress.remove_task(file_tasks[file])
del file_tasks[file] # 清理映射
logger.error(f"文件 {file} 生成失败,错误: {error_msg}")
self.console.print(f"[bold red]❌ 文件 {file} 生成失败,错误: {error_msg}[/bold red]")
# 错误处理:继续处理其他文件,但记录失败
logger.success("所有文件处理完成!") logger.success("所有文件处理完成!")
# 清理状态文件 # 清理状态文件
@ -708,7 +729,6 @@ class CodeGenerator:
logger.error(f"清理状态文件失败: {e}") logger.error(f"清理状态文件失败: {e}")
self.console.print(f"[bold red]❌ 清理状态文件失败: {e}[/bold red]") self.console.print(f"[bold red]❌ 清理状态文件失败: {e}[/bold red]")
def process_issue(self, issue_content: str, issue_type: str) -> bool: def process_issue(self, issue_content: str, issue_type: str) -> bool:
""" """
处理需求增强或 Bug 修复工单 处理需求增强或 Bug 修复工单

View File

@ -1,16 +1,12 @@
""" """Diff 应用模块,使用 unidiff2 解析和应用 unified diff 格式。"""
Diff 应用模块使用 GitPython 解析和应用 unified diff 格式
"""
import os import os
import sys
from typing import List, Dict, Any from typing import List, Dict, Any
import git # 需要安装 GitPython from unidiff import PatchSet, Hunk # 需要安装 unidiff2
def parse_diff(diff: str) -> List[str]: def parse_diff(diff: str) -> List[str]:
""" """
解析 unified diff 字符串提取受影响的文件路径 解析 unified diff 字符串提取受影响的文件路径
此函数使用 unidiff2 库来解析 diff
Args: Args:
diff: unified diff 格式的字符串 diff: unified diff 格式的字符串
@ -18,27 +14,74 @@ def parse_diff(diff: str) -> List[str]:
Returns: Returns:
文件路径列表 文件路径列表
""" """
files = set() try:
for line in diff.split('\n'): patch_set = PatchSet(diff)
if line.startswith('--- a/'): # unidiff2 中的 patch 对象有一个 path 属性,代表修改的文件
# 提取旧文件路径 # 使用集合去重
path = line[6:].strip() files = {patch.target_file for patch in patch_set} # 使用target_file获取修改后的文件名
if path and path != '/dev/null': # /dev/null 表示新文件 # 过滤掉 /dev/null
files.add(path) return [f.lstrip('/') for f in files if f != '/dev/null' and not f.endswith('/dev/null')]
elif line.startswith('+++ b/'): except Exception as e:
# 提取新文件路径 # 如果解析失败,抛出异常
path = line[6:].strip() raise e
if path and path != '/dev/null': # /dev/null 表示删除文件
files.add(path) def _apply_single_patch_to_content(file_content_lines: List[str], patch_hunks: List[Hunk]) -> List[str]:
return list(files) """
将一个文件的补丁多个hunk应用到其内容上
Args:
file_content_lines: 文件内容的行列表
patch_hunks: 针对该文件的一个或多个Hunk对象列表
Returns:
应用了补丁后的新内容行列表
"""
# 为了正确应用多个 hunk必须从后往前处理这样前面的修改才不会影响后面 hunk 的行号
sorted_hunks = sorted(patch_hunks, key=lambda x: x.source_start, reverse=True)
current_lines = file_content_lines[:]
for hunk in sorted_hunks:
source_start = hunk.source_start - 1 # 转换为0索引
source_len = hunk.source_length
# 验证源文件内容是否与diff中的源行匹配 (这是一个简化的验证)
source_lines_from_diff = [str(l).strip() for l in hunk.source_lines()] # strip()去除行首的'-'或' '
actual_source_lines = current_lines[source_start : source_start + source_len]
# 提取实际源行和上下文/删除行进行比较rstrip换行符
actual_source_for_comparison = [line.rstrip('\n\r') for line in actual_source_lines]
if source_lines_from_diff != actual_source_for_comparison:
raise ValueError(f"Hunk at line {hunk.source_start} does not match the source file content. Expected: {source_lines_from_diff}, Got: {actual_source_for_comparison}")
# 构建新的内容部分
new_part = []
for line_obj in hunk:
if line_obj.is_added or line_obj.is_context:
# 添加新增行或上下文行
# unidiff2的line对象转字符串会带有 +/-/ 等符号,我们需要获取实际内容
# line.value 包含符号line.text 不包含符号,但可能不准确
# 最可靠的方法是 strip() 掉行首的 +/-/ 空格
clean_line = str(line_obj).strip()
# 如果原文件行尾有换行符,我们也要加上
original_trailing = ""
if current_lines and current_lines[0].endswith('\n'):
original_trailing = "\n"
new_part.append(clean_line + original_trailing)
# 替换原内容
new_lines = current_lines[:source_start] + new_part + current_lines[source_start + source_len:]
current_lines = new_lines
return current_lines
def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]: def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]:
""" """
应用 unified diff 到指定目录使用 GitPython 应用 unified diff 到指定目录
该函数解析 diff读取磁盘上的文件应用更改并写回文件
该函数解析 diff 并尝试应用更改到文件系统如果目标目录不是 git 仓库
将尝试初始化一个临时仓库或报告错误
Args: Args:
diff: unified diff 格式的字符串 diff: unified diff 格式的字符串
@ -64,9 +107,11 @@ def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]:
result['message'] = 'Diff string is empty' result['message'] = 'Diff string is empty'
return result return result
# 解析 diff 获取文件列表 # 解析 diff 获取 PatchSet 对象
try: try:
affected_files = parse_diff_files(diff) patch_set = PatchSet(diff)
# 收集所有需要修改的目标文件路径
affected_files = [p.target_file.lstrip('/') for p in patch_set if p.target_file != '/dev/null']
except Exception as e: except Exception as e:
result['message'] = f"Failed to parse diff: {str(e)}" result['message'] = f"Failed to parse diff: {str(e)}"
result['error_details'] = str(e) result['error_details'] = str(e)
@ -78,39 +123,57 @@ def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]:
return result return result
try: try:
# 尝试获取或初始化 git 仓库 # 遍历每个 patch (即每个文件的变更)
for patch_obj in patch_set:
# 获取目标文件路径 (修改后的文件名)
target_path = patch_obj.target_file.lstrip('/')
if not target_path or target_path == '/dev/null':
# 如果目标是 /dev/null则是删除操作我们跳过或记录
continue
full_file_path = os.path.join(target_dir, target_path)
# 检查源文件是否存在 (对于新增操作,源文件可能是 /dev/null)
source_path = patch_obj.source_file
if source_path == '/dev/null':
# 源是 /dev/null说明这是一个新文件内容从空开始
original_content_lines = []
else:
# 源是普通文件,尝试读取
try: try:
repo = git.Repo(target_dir) with open(full_file_path, 'r', encoding='utf-8') as f:
except git.exc.InvalidGitRepositoryError: original_content_lines = f.readlines()
# 如果不是 git 仓库,初始化一个 except FileNotFoundError:
repo = git.Repo.init(target_dir) # 如果文件不存在但diff不是新文件这可能是个错误
# 添加所有现有文件到索引,以便应用 diff # 但对于新文件,这种情况也可能发生,因为路径可能不同
repo.git.add('--all') # 我们继续,但要注意如果内容不匹配会报错
original_content_lines = []
# 应用 diff # 应用此文件的补丁 (所有 hunks)
# 使用 git apply 命令,通过 stdin 传入 diff modified_content_lines = _apply_single_patch_to_content(
# '--whitespace=nowarn' 忽略空白警告 original_content_lines,
output = repo.git.apply('--whitespace=nowarn', '--', input=diff) list(patch_obj) # patch_obj 本身就是一个 Hunk 对象的迭代器
)
# 如果成功,更新结果 # 将修改后的内容写回文件
# 确保目录存在
os.makedirs(os.path.dirname(full_file_path), exist_ok=True)
with open(full_file_path, 'w', encoding='utf-8', newline='') as f:
f.writelines(modified_content_lines)
# 如果所有 patch 都成功应用
result['success'] = True result['success'] = True
result['message'] = 'Diff applied successfully' result['message'] = 'Diff applied successfully'
result['applied_files'] = affected_files result['applied_files'] = affected_files
if output:
result['message'] += f". Output: {output}"
except git.exc.GitError as e:
# 处理 git 错误,如冲突、格式错误等
result['message'] = f"Git error while applying diff: {str(e)}"
result['error_details'] = str(e)
except Exception as e: except Exception as e:
# 处理其他异常 # 处理应用过程中可能出现的任何错误
result['message'] = f"Unexpected error: {str(e)}" result['message'] = f"Error while applying diff: {str(e)}"
result['error_details'] = str(e) result['error_details'] = str(e)
return result return result
# 如果作为脚本运行,可以提供简单的测试 # 如果作为脚本运行,可以提供简单的测试
if __name__ == "__main__": if __name__ == "__main__":
# 示例用法 # 示例用法
@ -120,6 +183,7 @@ if __name__ == "__main__":
-Hello World -Hello World
+Hello Universe +Hello Universe
""" """
print("Testing apply_diff...")
print("Testing apply_diff with unidiff2...")
res = apply_diff(sample_diff, ".") res = apply_diff(sample_diff, ".")
print(res) print(res)

View File

@ -1,23 +1,20 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """Unit tests for diff_applier.py, covering various scenarios such as new file creation,
Unit tests for diff_applier.py, covering various scenarios such as new file creation, modification of existing files, conflict handling, and error cases."""
modification of existing files, conflict handling, and error cases.
"""
import os import os
import sys import sys
import tempfile import tempfile
import shutil import shutil
import pytest import pytest
import git # GitPython is required; assumed installed via project dependencies
# Add src directory to path for module import # Add src directory to path for module import
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
from llm_codegen.diff_applier import parse_diff_files, apply_diff # 假设您的文件名为 diff_applier_unidiff2.py
from src.llm_codegen.diff_applier import parse_diff, apply_diff
def test_parse_diff_files(): def test_parse_diff():
"""Test parsing unified diff strings to extract file paths.""" """Test parsing unified diff strings to extract file paths."""
# Diff with modification and new file # Diff with modification and new file
diff = """--- a/file1.txt diff = """--- a/file1.txt
@ -25,128 +22,102 @@ def test_parse_diff_files():
@@ -1 +1 @@ @@ -1 +1 @@
-old content -old content
+new content +new content
--- a/new_file.txt --- /dev/null
+++ b/new_file.txt +++ b/new_file.txt
@@ -0,0 +1 @@ @@ -0,0 +1 @@
+new file content +new file content"""
""" files = parse_diff(diff)
files = parse_diff_files(diff)
assert set(files) == {'file1.txt', 'new_file.txt'} assert set(files) == {'file1.txt', 'new_file.txt'}
# Diff with file deletion # Diff with file deletion
diff_del = """--- a/deleted.txt diff_del = """--- a/deleted.txt
+++ /dev/null +++ /dev/null
@@ -1 +0,0 @@ @@ -1 +0,0 @@
-content to delete -content to delete"""
""" files = parse_diff(diff_del)
files = parse_diff_files(diff_del) # 删除操作的解析结果应为空,因为没有实际的“目标”文件被创建或修改
assert files == ['deleted.txt'] # Old file path is extracted assert files == []
# Empty diff # Empty diff
assert parse_diff_files('') == [] assert parse_diff('') == []
assert parse_diff_files('\n') == [] assert parse_diff('\n') == []
# Diff with only new file (no old file) # Diff with only new file (no old file)
diff_new_only = """--- /dev/null diff_new_only = """--- /dev/null
+++ b/only_new.txt +++ b/only_new.txt
@@ -0,0 +1 @@ @@ -0,0 +1 @@
+only new +only new"""
""" files = parse_diff(diff_new_only)
files = parse_diff_files(diff_new_only) assert files == ['only_new.txt']
assert files == [] /dev/null is ignored
# Invalid diff format (should still handle gracefully) # Invalid diff format (should still handle gracefully)
diff_invalid = "invalid diff string" diff_invalid = "invalid diff string"
files = parse_diff_files(diff_invalid) files = parse_diff(diff_invalid)
assert files == [] # No valid file paths found assert files == [] # No valid file paths found
def test_apply_diff_new_file():
@pytest.fixture
def temp_git_repo():
"""Create a temporary git repository for testing apply_diff."""
temp_dir = tempfile.mkdtemp()
repo = git.Repo.init(temp_dir)
# Make an initial commit to have a clean state
init_file = os.path.join(temp_dir, 'README.md')
with open(init_file, 'w') as f:
f.write('Initial commit content')
repo.git.add('README.md')
repo.git.commit('-m', 'initial commit')
yield temp_dir, repo
# Cleanup after test
shutil.rmtree(temp_dir, ignore_errors=True)
def test_apply_diff_new_file(temp_git_repo):
"""Test applying a diff that creates a new file.""" """Test applying a diff that creates a new file."""
temp_dir, repo = temp_git_repo with tempfile.TemporaryDirectory() as temp_dir:
diff = """--- /dev/null diff = """--- /dev/null
+++ b/test_new.txt +++ b/test_new.txt
@@ -0,0 +1 @@ @@ -0,0 +1 @@
+This is a newly created file. +This is a newly created file."""
"""
result = apply_diff(diff, temp_dir) result = apply_diff(diff, temp_dir)
assert result['success'] == True assert result['success'] == True
assert 'test_new.txt' in result['applied_files'] assert 'test_new.txt' in result['applied_files']
new_file_path = os.path.join(temp_dir, 'test_new.txt') new_file_path = os.path.join(temp_dir, 'test_new.txt')
assert os.path.exists(new_file_path) assert os.path.exists(new_file_path)
with open(new_file_path, 'r') as f: with open(new_file_path, 'r') as f:
content = f.read() content = f.read()
# 修复断言strip()会去掉行末的换行符,使内容匹配
assert content.strip() == 'This is a newly created file.' assert content.strip() == 'This is a newly created file.'
def test_apply_diff_modify_existing_file():
def test_apply_diff_modify_existing_file(temp_git_repo):
"""Test applying a diff that modifies an existing file.""" """Test applying a diff that modifies an existing file."""
temp_dir, repo = temp_git_repo with tempfile.TemporaryDirectory() as temp_dir:
# Create an existing file and commit it # Create an existing file
existing_file = os.path.join(temp_dir, 'existing.txt') existing_file = os.path.join(temp_dir, 'existing.txt')
with open(existing_file, 'w') as f: with open(existing_file, 'w') as f:
f.write('Original line 1\nOriginal line 2') f.write('Original line 1\nOriginal line 2\n')
repo.git.add('existing.txt')
repo.git.commit('-m', 'add existing file')
diff = """--- a/existing.txt diff = """--- a/existing.txt
+++ b/existing.txt +++ b/existing.txt
@@ -1,2 +1,2 @@ @@ -1,2 +1,2 @@
-Original line 1 -Original line 1
+Modified line 1 +Modified line 1
Original line 2 Original line 2"""
"""
result = apply_diff(diff, temp_dir) result = apply_diff(diff, temp_dir)
assert result['success'] == True assert result['success'] == True
assert 'existing.txt' in result['applied_files'] assert 'existing.txt' in result['applied_files']
with open(existing_file, 'r') as f: with open(existing_file, 'r') as f:
content = f.read() content = f.read()
assert content == 'Modified line 1\nOriginal line 2' assert content == 'Modified line 1\nOriginal line 2\n'
def test_apply_diff_conflict_handling():
def test_apply_diff_conflict_handling(temp_git_repo): """Test applying a diff that causes a conflict because the source doesn't match."""
"""Test applying a diff that causes a conflict with uncommitted changes.""" with tempfile.TemporaryDirectory() as temp_dir:
temp_dir, repo = temp_git_repo # Create a file with specific content
# Create and commit a file
conflict_file = os.path.join(temp_dir, 'conflict.txt') conflict_file = os.path.join(temp_dir, 'conflict.txt')
with open(conflict_file, 'w') as f: with open(conflict_file, 'w') as f:
f.write('Initial line 1\nInitial line 2') f.write('Different line 1\nOriginal line 2\n') # This is different from what diff expects
repo.git.add('conflict.txt')
repo.git.commit('-m', 'add conflict file')
# Modify the file without committing to create conflict
with open(conflict_file, 'w') as f:
f.write('Changed line 1\nInitial line 2') # Change first line
# This diff expects 'Initial line 1' but finds 'Different line 1'
diff = """--- a/conflict.txt diff = """--- a/conflict.txt
+++ b/conflict.txt +++ b/conflict.txt
@@ -1,2 +1,2 @@ @@ -1,2 +1,2 @@
-Initial line 1 -Initial line 1
+Diff line 1 +Diff line 1
Initial line 2 Initial line 2"""
"""
result = apply_diff(diff, temp_dir)
assert result['success'] == False
# Check for conflict or error in message
assert 'conflict' in result['message'].lower() or 'error' in result['message'].lower()
assert result['error_details'] != ''
result = apply_diff(diff, temp_dir)
assert result['success'] == False # Should fail due to mismatch
# Check for conflict or error in message
assert 'error' in result['message'].lower() or 'does not match' in result['message'].lower()
assert result['error_details'] != ''
def test_apply_diff_empty_diff(): def test_apply_diff_empty_diff():
"""Test applying an empty diff string.""" """Test applying an empty diff string."""
@ -154,7 +125,6 @@ def test_apply_diff_empty_diff():
assert result['success'] == False assert result['success'] == False
assert 'empty' in result['message'].lower() assert 'empty' in result['message'].lower()
def test_apply_diff_invalid_directory(): def test_apply_diff_invalid_directory():
"""Test applying a diff to a non-existent directory.""" """Test applying a diff to a non-existent directory."""
non_existent_dir = '/tmp/non_existent_dir_12345' non_existent_dir = '/tmp/non_existent_dir_12345'
@ -162,36 +132,33 @@ def test_apply_diff_invalid_directory():
+++ b/dummy.txt +++ b/dummy.txt
@@ -1 +1 @@ @@ -1 +1 @@
-old -old
+new +new"""
"""
result = apply_diff(diff, non_existent_dir) result = apply_diff(diff, non_existent_dir)
assert result['success'] == False assert result['success'] == False
assert 'does not exist' in result['message'].lower() assert 'does not exist' in result['message'].lower()
def test_apply_diff_no_git_repo_initialization(): def test_apply_diff_no_git_repo_initialization():
"""Test applying a diff to a non-git directory, which should initialize a repo.""" """Test applying a diff to a non-git directory. This test is now redundant as there's no git dependency."""
temp_dir = tempfile.mkdtemp() with tempfile.TemporaryDirectory() as temp_dir:
try:
# Create a non-git directory with a file # Create a non-git directory with a file
non_git_file = os.path.join(temp_dir, 'non_git.txt') non_git_file = os.path.join(temp_dir, 'non_git.txt')
with open(non_git_file, 'w') as f: with open(non_git_file, 'w') as f:
f.write('Pre-existing content') f.write('Pre-existing content\n')
diff = """--- a/non_git.txt diff = """--- a/non_git.txt
+++ b/non_git.txt +++ b/non_git.txt
@@ -1 +1 @@ @@ -1 +1 @@
-Pre-existing content -Pre-existing content
+Updated content +Updated content"""
"""
result = apply_diff(diff, temp_dir) result = apply_diff(diff, temp_dir)
assert result['success'] == True assert result['success'] == True
assert 'non_git.txt' in result['applied_files'] assert 'non_git.txt' in result['applied_files']
with open(non_git_file, 'r') as f: with open(non_git_file, 'r') as f:
content = f.read() content = f.read()
assert content == 'Updated content' assert content == 'Updated content\n'
finally:
shutil.rmtree(temp_dir, ignore_errors=True)
if __name__ == "__main__": if __name__ == "__main__":