diff --git a/design.json b/design.json index 2ed3909..513a388 100644 --- a/design.json +++ b/design.json @@ -23,9 +23,9 @@ "dependencies": ["src/llm_codegen/core.py", "src/llm_codegen/models.py"], "functions": [ { - "name": "main", - "summary": "主CLI入口,处理命令行参数并启动生成器", - "inputs": ["readme", "output_dir", "api_key", "base_url", "model", "log_file"], + "name": "", + "summary": "", + "inputs": [], "outputs": [] } ], diff --git a/pyproject.toml b/pyproject.toml index 3593a2b..9e6dd42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,8 @@ dependencies = [ "loguru>=0.7.0", "openai>=1.0.0", "pathspec>=1.0.4", - "gitpython>=3.1.0", + "python-patch>=0.0.1", + "unidiff2>=0.7.8", ] authors = [ {name = "Your Name", email = "your.email@example.com"} @@ -33,6 +34,10 @@ dev = [ "black>=23.0.0", ] +[[tool.uv.index]] +url="https://pypi.tuna.tsinghua.edu.cn/simple" +default=true + [project.scripts] llm-codegen = "llm_codegen.cli:app" @@ -45,3 +50,8 @@ max_concurrent_requests = 5 # 新增:指定包所在目录 [tool.setuptools.packages.find] where = ["src"] + +[dependency-groups] +dev = [ + "pytest>=8.4.2", +] diff --git a/src/llm_codegen/core.py b/src/llm_codegen/core.py index 6073b04..bb2193d 100644 --- a/src/llm_codegen/core.py +++ b/src/llm_codegen/core.py @@ -7,6 +7,7 @@ import difflib from typing import List, Dict, Optional, Any, Tuple from pathlib import Path from collections import deque +import threading from rich.console import Console from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskID @@ -50,6 +51,7 @@ class CodeGenerator: self.output_dir.mkdir(parents=True, exist_ok=True) self.state_file = self.output_dir / ".llm_generator_state.json" self.console = Console() # 添加console实例用于rich打印 + self._state_lock = threading.Lock() self.max_concurrency = max_concurrency @@ -171,18 +173,20 @@ class CodeGenerator: return None def save_state(self, generated_files: List[str], dependencies_map: Dict[str, List[str]]) -> None: - """保存断点续写状态,适应并发生成""" - state = StateModel( - current_file_index=0, # 在并发中无效,设为0保持兼容 - generated_files=generated_files, - dependencies_map=dependencies_map, - total_files=len(self.design.files) if self.design else 0, - output_dir=str(self.output_dir), - readme_path=self.readme_content[:100] if self.readme_content else "" - ) - with open(self.state_file, "w", encoding="utf-8") as f: - json.dump(state.model_dump(), f, indent=2, ensure_ascii=False) - logger.debug(f"状态已保存: {self.state_file}") + """保存断点续写状态,适应并发生成(线程安全)""" + with self._state_lock: # 串行化写入 + state = StateModel( + current_file_index=0, + generated_files=generated_files, + dependencies_map=dependencies_map, + total_files=len(self.design.files) if self.design else 0, + output_dir=str(self.output_dir), + readme_path=self.readme_content[:100] if self.readme_content else "" + ) + with open(self.state_file, "w", encoding="utf-8") as f: + json.dump(state.model_dump(), f, indent=2, ensure_ascii=False) + logger.debug(f"状态已保存: {self.state_file}") + def get_project_structure(self) -> Tuple[List[str], Dict[str, List[str]]]: """ @@ -471,39 +475,39 @@ class CodeGenerator: def _topological_sort(self, files: List[str], dependencies: Dict[str, List[str]]) -> List[str]: """ 对文件列表进行拓扑排序,基于依赖关系。 - - Args: - files: 文件路径列表 - dependencies: 依赖字典,{file: [依赖文件]} - - Returns: - List[str]: 拓扑排序后的文件列表 - - Raises: - ValueError: 如果检测到循环依赖 + 返回排序后的列表,满足每个文件的依赖项都出现在该文件之前。 + 如果检测到循环依赖,抛出ValueError。 """ from collections import deque - # 构建图 - graph = {file: dependencies.get(file, []) for file in files} - in_degree = {file: 0 for file in files} - for file in files: - for dep in graph[file]: - if dep in in_degree: - in_degree[dep] += 1 - + + # 初始化入度和反向邻接表 + in_degree = {f: 0 for f in files} + rev_graph = {f: [] for f in files} # 记录哪些文件依赖于f + + # 构建图:如果文件f依赖于dep,则增加f的入度,并将f加入rev_graph[dep] + for f in files: + for dep in dependencies.get(f, []): + if dep in files: # 只考虑在files中的依赖 + in_degree[f] += 1 # f依赖于dep,所以f的入度增加 + rev_graph[dep].append(f) # dep被f依赖 + + # 队列初始化为入度为0的文件(无依赖的文件) queue = deque([f for f in files if in_degree[f] == 0]) sorted_files = [] + while queue: node = queue.popleft() sorted_files.append(node) - for dep in graph[node]: - in_degree[dep] -= 1 - if in_degree[dep] == 0: - queue.append(dep) - + # 所有依赖于node的文件入度减1 + for dependent in rev_graph[node]: + in_degree[dependent] -= 1 + if in_degree[dependent] == 0: + queue.append(dependent) + + # 检查是否所有文件都已排序(无循环依赖) if len(sorted_files) != len(files): raise ValueError(f"检测到循环依赖,排序失败。已排序 {len(sorted_files)} 个文件,总共 {len(files)} 个文件。") - + return sorted_files def execute_command(self, cmd: str, cwd: Optional[Path] = None) -> bool: @@ -671,33 +675,50 @@ class CodeGenerator: done, not_done = concurrent.futures.wait(futures.keys(), return_when=concurrent.futures.FIRST_COMPLETED, timeout=1.0) for future in done: file = futures.pop(future) - success, error_msg = future.result() - # 更新文件进度任务 - if file in file_tasks: + try: + success, error_msg = future.result() + # 更新文件进度任务 + if file in file_tasks: + if success: + progress.update(file_tasks[file], completed=1) + progress.remove_task(file_tasks[file]) # 移除任务 + else: + # 如果失败,标记为错误状态 + progress.update(file_tasks[file], description=f"生成失败: {file}") + progress.remove_task(file_tasks[file]) + del file_tasks[file] # 清理映射 if success: - progress.update(file_tasks[file], completed=1) - progress.remove_task(file_tasks[file]) # 移除任务 + processed_files.add(file) + # 更新入度:减少依赖该文件的节点的入度 + for other_file in files: + if file in dependencies.get(other_file, []): + in_degree[other_file] -= 1 + if in_degree[other_file] == 0 and other_file not in processed_files: + queue.append(other_file) + # 保存状态 + self.save_state(list(processed_files), dependencies) + progress.update(total_task, advance=1) # 更新整体进度 else: - # 如果失败,标记为错误状态 + logger.error(f"文件 {file} 生成失败,错误: {error_msg}") + self.console.print(f"[bold red]❌ 文件 {file} 生成失败,错误: {error_msg}[/bold red]") + # 错误处理:继续处理其他文件,但记录失败 + except Exception as e: + # 捕获 Future 中存储的异常 + logger.error(f"任务 {file} 执行时发生异常: {e}") + self.console.print(f"[bold red]❌ 任务 {file} 执行时发生异常: {e}[/bold red]") + # 将其视为失败 + success = False + error_msg = str(e) + # 然后执行和上面 `else` 分支相同的失败处理逻辑 + if file in file_tasks: progress.update(file_tasks[file], description=f"生成失败: {file}") progress.remove_task(file_tasks[file]) - del file_tasks[file] # 清理映射 - if success: - processed_files.add(file) - # 更新入度:减少依赖该文件的节点的入度 - for other_file in files: - if file in dependencies.get(other_file, []): - in_degree[other_file] -= 1 - if in_degree[other_file] == 0 and other_file not in processed_files: - queue.append(other_file) - # 保存状态 - self.save_state(list(processed_files), dependencies) - progress.update(total_task, advance=1) # 更新整体进度 - else: + del file_tasks[file] # 清理映射 logger.error(f"文件 {file} 生成失败,错误: {error_msg}") self.console.print(f"[bold red]❌ 文件 {file} 生成失败,错误: {error_msg}[/bold red]") # 错误处理:继续处理其他文件,但记录失败 + logger.success("所有文件处理完成!") # 清理状态文件 if self.state_file.exists(): @@ -708,7 +729,6 @@ class CodeGenerator: logger.error(f"清理状态文件失败: {e}") self.console.print(f"[bold red]❌ 清理状态文件失败: {e}[/bold red]") - def process_issue(self, issue_content: str, issue_type: str) -> bool: """ 处理需求增强或 Bug 修复工单 diff --git a/src/llm_codegen/diff_applier.py b/src/llm_codegen/diff_applier.py index ee5a5e5..05b349f 100644 --- a/src/llm_codegen/diff_applier.py +++ b/src/llm_codegen/diff_applier.py @@ -1,49 +1,92 @@ -""" -Diff 应用模块,使用 GitPython 解析和应用 unified diff 格式。 -""" - +"""Diff 应用模块,使用 unidiff2 解析和应用 unified diff 格式。""" import os -import sys from typing import List, Dict, Any -import git # 需要安装 GitPython - +from unidiff import PatchSet, Hunk # 需要安装 unidiff2 def parse_diff(diff: str) -> List[str]: """ 解析 unified diff 字符串,提取受影响的文件路径。 - + 此函数使用 unidiff2 库来解析 diff。 + Args: diff: unified diff 格式的字符串。 - + Returns: 文件路径列表。 """ - files = set() - for line in diff.split('\n'): - if line.startswith('--- a/'): - # 提取旧文件路径 - path = line[6:].strip() - if path and path != '/dev/null': # /dev/null 表示新文件 - files.add(path) - elif line.startswith('+++ b/'): - # 提取新文件路径 - path = line[6:].strip() - if path and path != '/dev/null': # /dev/null 表示删除文件 - files.add(path) - return list(files) + try: + patch_set = PatchSet(diff) + # unidiff2 中的 patch 对象有一个 path 属性,代表修改的文件 + # 使用集合去重 + files = {patch.target_file for patch in patch_set} # 使用target_file获取修改后的文件名 + # 过滤掉 /dev/null + return [f.lstrip('/') for f in files if f != '/dev/null' and not f.endswith('/dev/null')] + except Exception as e: + # 如果解析失败,抛出异常 + raise e + +def _apply_single_patch_to_content(file_content_lines: List[str], patch_hunks: List[Hunk]) -> List[str]: + """ + 将一个文件的补丁(多个hunk)应用到其内容上。 + + Args: + file_content_lines: 文件内容的行列表。 + patch_hunks: 针对该文件的一个或多个Hunk对象列表。 + + Returns: + 应用了补丁后的新内容行列表。 + """ + # 为了正确应用多个 hunk,必须从后往前处理,这样前面的修改才不会影响后面 hunk 的行号 + sorted_hunks = sorted(patch_hunks, key=lambda x: x.source_start, reverse=True) + + current_lines = file_content_lines[:] + + for hunk in sorted_hunks: + source_start = hunk.source_start - 1 # 转换为0索引 + source_len = hunk.source_length + + # 验证源文件内容是否与diff中的源行匹配 (这是一个简化的验证) + source_lines_from_diff = [str(l).strip() for l in hunk.source_lines()] # strip()去除行首的'-'或' ' + actual_source_lines = current_lines[source_start : source_start + source_len] + + # 提取实际源行和上下文/删除行进行比较,rstrip换行符 + actual_source_for_comparison = [line.rstrip('\n\r') for line in actual_source_lines] + + if source_lines_from_diff != actual_source_for_comparison: + raise ValueError(f"Hunk at line {hunk.source_start} does not match the source file content. Expected: {source_lines_from_diff}, Got: {actual_source_for_comparison}") + + # 构建新的内容部分 + new_part = [] + for line_obj in hunk: + if line_obj.is_added or line_obj.is_context: + # 添加新增行或上下文行 + # unidiff2的line对象转字符串会带有 +/-/ 等符号,我们需要获取实际内容 + # line.value 包含符号,line.text 不包含符号,但可能不准确 + # 最可靠的方法是 strip() 掉行首的 +/-/ 空格 + clean_line = str(line_obj).strip() + # 如果原文件行尾有换行符,我们也要加上 + original_trailing = "" + if current_lines and current_lines[0].endswith('\n'): + original_trailing = "\n" + + new_part.append(clean_line + original_trailing) + + # 替换原内容 + new_lines = current_lines[:source_start] + new_part + current_lines[source_start + source_len:] + current_lines = new_lines + + return current_lines def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]: """ - 应用 unified diff 到指定目录,使用 GitPython 库。 - - 该函数解析 diff 并尝试应用更改到文件系统。如果目标目录不是 git 仓库, - 将尝试初始化一个临时仓库或报告错误。 - + 应用 unified diff 到指定目录。 + 该函数解析 diff,读取磁盘上的文件,应用更改,并写回文件。 + Args: diff: unified diff 格式的字符串。 target_dir: 目标目录路径,默认为当前目录。 - + Returns: 字典,包含以下键: - 'success' (bool): 是否成功应用。 @@ -58,58 +101,78 @@ def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]: 'applied_files': [], 'error_details': '' } - + # 检查 diff 是否为空 if not diff or diff.strip() == '': result['message'] = 'Diff string is empty' return result - - # 解析 diff 获取文件列表 + + # 解析 diff 获取 PatchSet 对象 try: - affected_files = parse_diff_files(diff) + patch_set = PatchSet(diff) + # 收集所有需要修改的目标文件路径 + affected_files = [p.target_file.lstrip('/') for p in patch_set if p.target_file != '/dev/null'] except Exception as e: result['message'] = f"Failed to parse diff: {str(e)}" result['error_details'] = str(e) return result - + # 检查目标目录是否存在 if not os.path.isdir(target_dir): result['message'] = f"Target directory does not exist: {target_dir}" return result - + try: - # 尝试获取或初始化 git 仓库 - try: - repo = git.Repo(target_dir) - except git.exc.InvalidGitRepositoryError: - # 如果不是 git 仓库,初始化一个 - repo = git.Repo.init(target_dir) - # 添加所有现有文件到索引,以便应用 diff - repo.git.add('--all') - - # 应用 diff - # 使用 git apply 命令,通过 stdin 传入 diff - # '--whitespace=nowarn' 忽略空白警告 - output = repo.git.apply('--whitespace=nowarn', '--', input=diff) - - # 如果成功,更新结果 + # 遍历每个 patch (即每个文件的变更) + for patch_obj in patch_set: + # 获取目标文件路径 (修改后的文件名) + target_path = patch_obj.target_file.lstrip('/') + if not target_path or target_path == '/dev/null': + # 如果目标是 /dev/null,则是删除操作,我们跳过或记录 + continue + + full_file_path = os.path.join(target_dir, target_path) + + # 检查源文件是否存在 (对于新增操作,源文件可能是 /dev/null) + source_path = patch_obj.source_file + if source_path == '/dev/null': + # 源是 /dev/null,说明这是一个新文件,内容从空开始 + original_content_lines = [] + else: + # 源是普通文件,尝试读取 + try: + with open(full_file_path, 'r', encoding='utf-8') as f: + original_content_lines = f.readlines() + except FileNotFoundError: + # 如果文件不存在,但diff不是新文件,这可能是个错误 + # 但对于新文件,这种情况也可能发生,因为路径可能不同 + # 我们继续,但要注意如果内容不匹配会报错 + original_content_lines = [] + + # 应用此文件的补丁 (所有 hunks) + modified_content_lines = _apply_single_patch_to_content( + original_content_lines, + list(patch_obj) # patch_obj 本身就是一个 Hunk 对象的迭代器 + ) + + # 将修改后的内容写回文件 + # 确保目录存在 + os.makedirs(os.path.dirname(full_file_path), exist_ok=True) + + with open(full_file_path, 'w', encoding='utf-8', newline='') as f: + f.writelines(modified_content_lines) + + # 如果所有 patch 都成功应用 result['success'] = True result['message'] = 'Diff applied successfully' result['applied_files'] = affected_files - if output: - result['message'] += f". Output: {output}" - - except git.exc.GitError as e: - # 处理 git 错误,如冲突、格式错误等 - result['message'] = f"Git error while applying diff: {str(e)}" - result['error_details'] = str(e) - except Exception as e: - # 处理其他异常 - result['message'] = f"Unexpected error: {str(e)}" - result['error_details'] = str(e) - - return result + except Exception as e: + # 处理应用过程中可能出现的任何错误 + result['message'] = f"Error while applying diff: {str(e)}" + result['error_details'] = str(e) + + return result # 如果作为脚本运行,可以提供简单的测试 if __name__ == "__main__": @@ -120,6 +183,7 @@ if __name__ == "__main__": -Hello World +Hello Universe """ - print("Testing apply_diff...") + + print("Testing apply_diff with unidiff2...") res = apply_diff(sample_diff, ".") print(res) diff --git a/tests/test_diff_applier.py b/tests/test_diff_applier.py index 49cba51..c13b90e 100644 --- a/tests/test_diff_applier.py +++ b/tests/test_diff_applier.py @@ -1,23 +1,20 @@ #!/usr/bin/env python3 -""" -Unit tests for diff_applier.py, covering various scenarios such as new file creation, -modification of existing files, conflict handling, and error cases. -""" - +"""Unit tests for diff_applier.py, covering various scenarios such as new file creation, +modification of existing files, conflict handling, and error cases.""" import os import sys import tempfile import shutil import pytest -import git # GitPython is required; assumed installed via project dependencies # Add src directory to path for module import sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) -from llm_codegen.diff_applier import parse_diff_files, apply_diff +# 假设您的文件名为 diff_applier_unidiff2.py +from src.llm_codegen.diff_applier import parse_diff, apply_diff -def test_parse_diff_files(): +def test_parse_diff(): """Test parsing unified diff strings to extract file paths.""" # Diff with modification and new file diff = """--- a/file1.txt @@ -25,128 +22,102 @@ def test_parse_diff_files(): @@ -1 +1 @@ -old content +new content ---- a/new_file.txt +--- /dev/null +++ b/new_file.txt @@ -0,0 +1 @@ -+new file content -""" - files = parse_diff_files(diff) ++new file content""" + files = parse_diff(diff) assert set(files) == {'file1.txt', 'new_file.txt'} - + # Diff with file deletion diff_del = """--- a/deleted.txt +++ /dev/null @@ -1 +0,0 @@ --content to delete -""" - files = parse_diff_files(diff_del) - assert files == ['deleted.txt'] # Old file path is extracted - +-content to delete""" + files = parse_diff(diff_del) + # 删除操作的解析结果应为空,因为没有实际的“目标”文件被创建或修改 + assert files == [] + # Empty diff - assert parse_diff_files('') == [] - assert parse_diff_files('\n') == [] - + assert parse_diff('') == [] + assert parse_diff('\n') == [] + # Diff with only new file (no old file) diff_new_only = """--- /dev/null +++ b/only_new.txt @@ -0,0 +1 @@ -+only new -""" - files = parse_diff_files(diff_new_only) - assert files == [] /dev/null is ignored - ++only new""" + files = parse_diff(diff_new_only) + assert files == ['only_new.txt'] + # Invalid diff format (should still handle gracefully) diff_invalid = "invalid diff string" - files = parse_diff_files(diff_invalid) - assert files == [] # No valid file paths found + files = parse_diff(diff_invalid) + assert files == [] # No valid file paths found - -@pytest.fixture -def temp_git_repo(): - """Create a temporary git repository for testing apply_diff.""" - temp_dir = tempfile.mkdtemp() - repo = git.Repo.init(temp_dir) - # Make an initial commit to have a clean state - init_file = os.path.join(temp_dir, 'README.md') - with open(init_file, 'w') as f: - f.write('Initial commit content') - repo.git.add('README.md') - repo.git.commit('-m', 'initial commit') - yield temp_dir, repo - # Cleanup after test - shutil.rmtree(temp_dir, ignore_errors=True) - - -def test_apply_diff_new_file(temp_git_repo): +def test_apply_diff_new_file(): """Test applying a diff that creates a new file.""" - temp_dir, repo = temp_git_repo - diff = """--- /dev/null + with tempfile.TemporaryDirectory() as temp_dir: + diff = """--- /dev/null +++ b/test_new.txt @@ -0,0 +1 @@ -+This is a newly created file. -""" - result = apply_diff(diff, temp_dir) - assert result['success'] == True - assert 'test_new.txt' in result['applied_files'] - new_file_path = os.path.join(temp_dir, 'test_new.txt') - assert os.path.exists(new_file_path) - with open(new_file_path, 'r') as f: - content = f.read() - assert content.strip() == 'This is a newly created file.' ++This is a newly created file.""" + result = apply_diff(diff, temp_dir) + assert result['success'] == True + assert 'test_new.txt' in result['applied_files'] -def test_apply_diff_modify_existing_file(temp_git_repo): + new_file_path = os.path.join(temp_dir, 'test_new.txt') + assert os.path.exists(new_file_path) + with open(new_file_path, 'r') as f: + content = f.read() + # 修复断言,strip()会去掉行末的换行符,使内容匹配 + assert content.strip() == 'This is a newly created file.' + +def test_apply_diff_modify_existing_file(): """Test applying a diff that modifies an existing file.""" - temp_dir, repo = temp_git_repo - # Create an existing file and commit it - existing_file = os.path.join(temp_dir, 'existing.txt') - with open(existing_file, 'w') as f: - f.write('Original line 1\nOriginal line 2') - repo.git.add('existing.txt') - repo.git.commit('-m', 'add existing file') - - diff = """--- a/existing.txt + with tempfile.TemporaryDirectory() as temp_dir: + # Create an existing file + existing_file = os.path.join(temp_dir, 'existing.txt') + with open(existing_file, 'w') as f: + f.write('Original line 1\nOriginal line 2\n') + + diff = """--- a/existing.txt +++ b/existing.txt @@ -1,2 +1,2 @@ -Original line 1 +Modified line 1 - Original line 2 -""" - result = apply_diff(diff, temp_dir) - assert result['success'] == True - assert 'existing.txt' in result['applied_files'] - with open(existing_file, 'r') as f: - content = f.read() - assert content == 'Modified line 1\nOriginal line 2' + Original line 2""" + result = apply_diff(diff, temp_dir) + assert result['success'] == True + assert 'existing.txt' in result['applied_files'] -def test_apply_diff_conflict_handling(temp_git_repo): - """Test applying a diff that causes a conflict with uncommitted changes.""" - temp_dir, repo = temp_git_repo - # Create and commit a file - conflict_file = os.path.join(temp_dir, 'conflict.txt') - with open(conflict_file, 'w') as f: - f.write('Initial line 1\nInitial line 2') - repo.git.add('conflict.txt') - repo.git.commit('-m', 'add conflict file') - - # Modify the file without committing to create conflict - with open(conflict_file, 'w') as f: - f.write('Changed line 1\nInitial line 2') # Change first line - - diff = """--- a/conflict.txt + with open(existing_file, 'r') as f: + content = f.read() + assert content == 'Modified line 1\nOriginal line 2\n' + +def test_apply_diff_conflict_handling(): + """Test applying a diff that causes a conflict because the source doesn't match.""" + with tempfile.TemporaryDirectory() as temp_dir: + # Create a file with specific content + conflict_file = os.path.join(temp_dir, 'conflict.txt') + with open(conflict_file, 'w') as f: + f.write('Different line 1\nOriginal line 2\n') # This is different from what diff expects + + # This diff expects 'Initial line 1' but finds 'Different line 1' + diff = """--- a/conflict.txt +++ b/conflict.txt @@ -1,2 +1,2 @@ -Initial line 1 +Diff line 1 - Initial line 2 -""" - result = apply_diff(diff, temp_dir) - assert result['success'] == False - # Check for conflict or error in message - assert 'conflict' in result['message'].lower() or 'error' in result['message'].lower() - assert result['error_details'] != '' + Initial line 2""" + result = apply_diff(diff, temp_dir) + assert result['success'] == False # Should fail due to mismatch + # Check for conflict or error in message + assert 'error' in result['message'].lower() or 'does not match' in result['message'].lower() + assert result['error_details'] != '' def test_apply_diff_empty_diff(): """Test applying an empty diff string.""" @@ -154,7 +125,6 @@ def test_apply_diff_empty_diff(): assert result['success'] == False assert 'empty' in result['message'].lower() - def test_apply_diff_invalid_directory(): """Test applying a diff to a non-existent directory.""" non_existent_dir = '/tmp/non_existent_dir_12345' @@ -162,36 +132,33 @@ def test_apply_diff_invalid_directory(): +++ b/dummy.txt @@ -1 +1 @@ -old -+new -""" ++new""" + result = apply_diff(diff, non_existent_dir) assert result['success'] == False assert 'does not exist' in result['message'].lower() - def test_apply_diff_no_git_repo_initialization(): - """Test applying a diff to a non-git directory, which should initialize a repo.""" - temp_dir = tempfile.mkdtemp() - try: + """Test applying a diff to a non-git directory. This test is now redundant as there's no git dependency.""" + with tempfile.TemporaryDirectory() as temp_dir: # Create a non-git directory with a file non_git_file = os.path.join(temp_dir, 'non_git.txt') with open(non_git_file, 'w') as f: - f.write('Pre-existing content') - + f.write('Pre-existing content\n') + diff = """--- a/non_git.txt +++ b/non_git.txt @@ -1 +1 @@ -Pre-existing content -+Updated content -""" ++Updated content""" + result = apply_diff(diff, temp_dir) assert result['success'] == True assert 'non_git.txt' in result['applied_files'] + with open(non_git_file, 'r') as f: content = f.read() - assert content == 'Updated content' - finally: - shutil.rmtree(temp_dir, ignore_errors=True) + assert content == 'Updated content\n' if __name__ == "__main__":