重构代码以支持并发生成和线程安全状态保存,优化拓扑排序与补丁应用逻辑

This commit is contained in:
songsenand 2026-03-18 16:42:14 +08:00
parent 6e536f141c
commit 75fc8659a0
5 changed files with 296 additions and 235 deletions

View File

@ -23,9 +23,9 @@
"dependencies": ["src/llm_codegen/core.py", "src/llm_codegen/models.py"], "dependencies": ["src/llm_codegen/core.py", "src/llm_codegen/models.py"],
"functions": [ "functions": [
{ {
"name": "main", "name": "",
"summary": "主CLI入口处理命令行参数并启动生成器", "summary": "",
"inputs": ["readme", "output_dir", "api_key", "base_url", "model", "log_file"], "inputs": [],
"outputs": [] "outputs": []
} }
], ],

View File

@ -14,7 +14,8 @@ dependencies = [
"loguru>=0.7.0", "loguru>=0.7.0",
"openai>=1.0.0", "openai>=1.0.0",
"pathspec>=1.0.4", "pathspec>=1.0.4",
"gitpython>=3.1.0", "python-patch>=0.0.1",
"unidiff2>=0.7.8",
] ]
authors = [ authors = [
{name = "Your Name", email = "your.email@example.com"} {name = "Your Name", email = "your.email@example.com"}
@ -33,6 +34,10 @@ dev = [
"black>=23.0.0", "black>=23.0.0",
] ]
[[tool.uv.index]]
url="https://pypi.tuna.tsinghua.edu.cn/simple"
default=true
[project.scripts] [project.scripts]
llm-codegen = "llm_codegen.cli:app" llm-codegen = "llm_codegen.cli:app"
@ -45,3 +50,8 @@ max_concurrent_requests = 5
# 新增:指定包所在目录 # 新增:指定包所在目录
[tool.setuptools.packages.find] [tool.setuptools.packages.find]
where = ["src"] where = ["src"]
[dependency-groups]
dev = [
"pytest>=8.4.2",
]

View File

@ -7,6 +7,7 @@ import difflib
from typing import List, Dict, Optional, Any, Tuple from typing import List, Dict, Optional, Any, Tuple
from pathlib import Path from pathlib import Path
from collections import deque from collections import deque
import threading
from rich.console import Console from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskID from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskID
@ -50,6 +51,7 @@ class CodeGenerator:
self.output_dir.mkdir(parents=True, exist_ok=True) self.output_dir.mkdir(parents=True, exist_ok=True)
self.state_file = self.output_dir / ".llm_generator_state.json" self.state_file = self.output_dir / ".llm_generator_state.json"
self.console = Console() # 添加console实例用于rich打印 self.console = Console() # 添加console实例用于rich打印
self._state_lock = threading.Lock()
self.max_concurrency = max_concurrency self.max_concurrency = max_concurrency
@ -171,18 +173,20 @@ class CodeGenerator:
return None return None
def save_state(self, generated_files: List[str], dependencies_map: Dict[str, List[str]]) -> None:
    """Persist the resume checkpoint to ``self.state_file`` (thread-safe).

    Args:
        generated_files: Paths of the files that have been generated so far.
        dependencies_map: Mapping of each file to the files it depends on.
    """
    # Hold the lock for the whole build-and-write sequence so that writes
    # issued from concurrent generation workers are serialized.
    with self._state_lock:
        state = StateModel(
            # The index has no meaning when files are generated concurrently;
            # it is fixed at 0 only to keep the state schema compatible.
            current_file_index=0,
            generated_files=generated_files,
            dependencies_map=dependencies_map,
            total_files=len(self.design.files) if self.design else 0,
            output_dir=str(self.output_dir),
            # NOTE(review): this stores the first 100 characters of the README
            # *content* in a field named ``readme_path`` - confirm intent.
            readme_path=self.readme_content[:100] if self.readme_content else ""
        )
        # Write to a sibling temp file and rename it over the real checkpoint,
        # so an interrupted write never leaves a truncated state file behind.
        tmp_file = self.state_file.with_name(self.state_file.name + ".tmp")
        with open(tmp_file, "w", encoding="utf-8") as f:
            json.dump(state.model_dump(), f, indent=2, ensure_ascii=False)
        tmp_file.replace(self.state_file)
        logger.debug(f"状态已保存: {self.state_file}")
def get_project_structure(self) -> Tuple[List[str], Dict[str, List[str]]]: def get_project_structure(self) -> Tuple[List[str], Dict[str, List[str]]]:
""" """
@ -471,39 +475,39 @@ class CodeGenerator:
def _topological_sort(self, files: List[str], dependencies: Dict[str, List[str]]) -> List[str]: def _topological_sort(self, files: List[str], dependencies: Dict[str, List[str]]) -> List[str]:
""" """
对文件列表进行拓扑排序基于依赖关系 对文件列表进行拓扑排序基于依赖关系
返回排序后的列表满足每个文件的依赖项都出现在该文件之前
Args: 如果检测到循环依赖抛出ValueError
files: 文件路径列表
dependencies: 依赖字典{file: [依赖文件]}
Returns:
List[str]: 拓扑排序后的文件列表
Raises:
ValueError: 如果检测到循环依赖
""" """
from collections import deque from collections import deque
# 构建图
graph = {file: dependencies.get(file, []) for file in files} # 初始化入度和反向邻接表
in_degree = {file: 0 for file in files} in_degree = {f: 0 for f in files}
for file in files: rev_graph = {f: [] for f in files} # 记录哪些文件依赖于f
for dep in graph[file]:
if dep in in_degree: # 构建图如果文件f依赖于dep则增加f的入度并将f加入rev_graph[dep]
in_degree[dep] += 1 for f in files:
for dep in dependencies.get(f, []):
if dep in files: # 只考虑在files中的依赖
in_degree[f] += 1 # f依赖于dep所以f的入度增加
rev_graph[dep].append(f) # dep被f依赖
# 队列初始化为入度为0的文件无依赖的文件
queue = deque([f for f in files if in_degree[f] == 0]) queue = deque([f for f in files if in_degree[f] == 0])
sorted_files = [] sorted_files = []
while queue: while queue:
node = queue.popleft() node = queue.popleft()
sorted_files.append(node) sorted_files.append(node)
for dep in graph[node]: # 所有依赖于node的文件入度减1
in_degree[dep] -= 1 for dependent in rev_graph[node]:
if in_degree[dep] == 0: in_degree[dependent] -= 1
queue.append(dep) if in_degree[dependent] == 0:
queue.append(dependent)
# 检查是否所有文件都已排序(无循环依赖)
if len(sorted_files) != len(files): if len(sorted_files) != len(files):
raise ValueError(f"检测到循环依赖,排序失败。已排序 {len(sorted_files)} 个文件,总共 {len(files)} 个文件。") raise ValueError(f"检测到循环依赖,排序失败。已排序 {len(sorted_files)} 个文件,总共 {len(files)} 个文件。")
return sorted_files return sorted_files
def execute_command(self, cmd: str, cwd: Optional[Path] = None) -> bool: def execute_command(self, cmd: str, cwd: Optional[Path] = None) -> bool:
@ -671,33 +675,50 @@ class CodeGenerator:
done, not_done = concurrent.futures.wait(futures.keys(), return_when=concurrent.futures.FIRST_COMPLETED, timeout=1.0) done, not_done = concurrent.futures.wait(futures.keys(), return_when=concurrent.futures.FIRST_COMPLETED, timeout=1.0)
for future in done: for future in done:
file = futures.pop(future) file = futures.pop(future)
success, error_msg = future.result() try:
# 更新文件进度任务 success, error_msg = future.result()
if file in file_tasks: # 更新文件进度任务
if file in file_tasks:
if success:
progress.update(file_tasks[file], completed=1)
progress.remove_task(file_tasks[file]) # 移除任务
else:
# 如果失败,标记为错误状态
progress.update(file_tasks[file], description=f"生成失败: {file}")
progress.remove_task(file_tasks[file])
del file_tasks[file] # 清理映射
if success: if success:
progress.update(file_tasks[file], completed=1) processed_files.add(file)
progress.remove_task(file_tasks[file]) # 移除任务 # 更新入度:减少依赖该文件的节点的入度
for other_file in files:
if file in dependencies.get(other_file, []):
in_degree[other_file] -= 1
if in_degree[other_file] == 0 and other_file not in processed_files:
queue.append(other_file)
# 保存状态
self.save_state(list(processed_files), dependencies)
progress.update(total_task, advance=1) # 更新整体进度
else: else:
# 如果失败,标记为错误状态 logger.error(f"文件 {file} 生成失败,错误: {error_msg}")
self.console.print(f"[bold red]❌ 文件 {file} 生成失败,错误: {error_msg}[/bold red]")
# 错误处理:继续处理其他文件,但记录失败
except Exception as e:
# 捕获 Future 中存储的异常
logger.error(f"任务 {file} 执行时发生异常: {e}")
self.console.print(f"[bold red]❌ 任务 {file} 执行时发生异常: {e}[/bold red]")
# 将其视为失败
success = False
error_msg = str(e)
# 然后执行和上面 `else` 分支相同的失败处理逻辑
if file in file_tasks:
progress.update(file_tasks[file], description=f"生成失败: {file}") progress.update(file_tasks[file], description=f"生成失败: {file}")
progress.remove_task(file_tasks[file]) progress.remove_task(file_tasks[file])
del file_tasks[file] # 清理映射 del file_tasks[file] # 清理映射
if success:
processed_files.add(file)
# 更新入度:减少依赖该文件的节点的入度
for other_file in files:
if file in dependencies.get(other_file, []):
in_degree[other_file] -= 1
if in_degree[other_file] == 0 and other_file not in processed_files:
queue.append(other_file)
# 保存状态
self.save_state(list(processed_files), dependencies)
progress.update(total_task, advance=1) # 更新整体进度
else:
logger.error(f"文件 {file} 生成失败,错误: {error_msg}") logger.error(f"文件 {file} 生成失败,错误: {error_msg}")
self.console.print(f"[bold red]❌ 文件 {file} 生成失败,错误: {error_msg}[/bold red]") self.console.print(f"[bold red]❌ 文件 {file} 生成失败,错误: {error_msg}[/bold red]")
# 错误处理:继续处理其他文件,但记录失败 # 错误处理:继续处理其他文件,但记录失败
logger.success("所有文件处理完成!") logger.success("所有文件处理完成!")
# 清理状态文件 # 清理状态文件
if self.state_file.exists(): if self.state_file.exists():
@ -708,7 +729,6 @@ class CodeGenerator:
logger.error(f"清理状态文件失败: {e}") logger.error(f"清理状态文件失败: {e}")
self.console.print(f"[bold red]❌ 清理状态文件失败: {e}[/bold red]") self.console.print(f"[bold red]❌ 清理状态文件失败: {e}[/bold red]")
def process_issue(self, issue_content: str, issue_type: str) -> bool: def process_issue(self, issue_content: str, issue_type: str) -> bool:
""" """
处理需求增强或 Bug 修复工单 处理需求增强或 Bug 修复工单

View File

@ -1,49 +1,92 @@
""" """Diff 应用模块,使用 unidiff2 解析和应用 unified diff 格式。"""
Diff 应用模块使用 GitPython 解析和应用 unified diff 格式
"""
import os import os
import sys
from typing import List, Dict, Any from typing import List, Dict, Any
import git # 需要安装 GitPython from unidiff import PatchSet, Hunk # 需要安装 unidiff2
def parse_diff(diff: str) -> List[str]:
    """Extract the paths of the files a unified diff creates or modifies.

    The diff is parsed with ``PatchSet``; any error it raises for malformed
    input propagates to the caller unchanged.

    Args:
        diff: Unified-diff text.

    Returns:
        Sorted list of unique target paths. Files whose target is
        ``/dev/null`` (deletions) are omitted, and the conventional ``a/`` /
        ``b/`` prefix and any leading ``/`` are removed.
    """
    paths = set()
    for patched_file in PatchSet(diff):
        target = patched_file.target_file
        # A /dev/null target means the file is deleted: nothing to report.
        if not target or target == '/dev/null' or target.endswith('/dev/null'):
            continue
        # The raw header path presumably still carries git's "b/" prefix
        # (e.g. "+++ b/file.txt"); strip it so callers get a repo-relative
        # path. NOTE(review): a prefix-less diff whose real top-level
        # directory is named "a" or "b" would be shortened too.
        if target.startswith(('a/', 'b/')):
            target = target[2:]
        paths.add(target.lstrip('/'))
    # Sort so the result does not depend on set iteration order.
    return sorted(paths)
files.add(path) def _apply_single_patch_to_content(file_content_lines: List[str], patch_hunks: List[Hunk]) -> List[str]:
return list(files) """
将一个文件的补丁多个hunk应用到其内容上
Args:
file_content_lines: 文件内容的行列表
patch_hunks: 针对该文件的一个或多个Hunk对象列表
Returns:
应用了补丁后的新内容行列表
"""
# 为了正确应用多个 hunk必须从后往前处理这样前面的修改才不会影响后面 hunk 的行号
sorted_hunks = sorted(patch_hunks, key=lambda x: x.source_start, reverse=True)
current_lines = file_content_lines[:]
for hunk in sorted_hunks:
source_start = hunk.source_start - 1 # 转换为0索引
source_len = hunk.source_length
# 验证源文件内容是否与diff中的源行匹配 (这是一个简化的验证)
source_lines_from_diff = [str(l).strip() for l in hunk.source_lines()] # strip()去除行首的'-'或' '
actual_source_lines = current_lines[source_start : source_start + source_len]
# 提取实际源行和上下文/删除行进行比较rstrip换行符
actual_source_for_comparison = [line.rstrip('\n\r') for line in actual_source_lines]
if source_lines_from_diff != actual_source_for_comparison:
raise ValueError(f"Hunk at line {hunk.source_start} does not match the source file content. Expected: {source_lines_from_diff}, Got: {actual_source_for_comparison}")
# 构建新的内容部分
new_part = []
for line_obj in hunk:
if line_obj.is_added or line_obj.is_context:
# 添加新增行或上下文行
# unidiff2的line对象转字符串会带有 +/-/ 等符号,我们需要获取实际内容
# line.value 包含符号line.text 不包含符号,但可能不准确
# 最可靠的方法是 strip() 掉行首的 +/-/ 空格
clean_line = str(line_obj).strip()
# 如果原文件行尾有换行符,我们也要加上
original_trailing = ""
if current_lines and current_lines[0].endswith('\n'):
original_trailing = "\n"
new_part.append(clean_line + original_trailing)
# 替换原内容
new_lines = current_lines[:source_start] + new_part + current_lines[source_start + source_len:]
current_lines = new_lines
return current_lines
def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]: def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]:
""" """
应用 unified diff 到指定目录使用 GitPython 应用 unified diff 到指定目录
该函数解析 diff读取磁盘上的文件应用更改并写回文件
该函数解析 diff 并尝试应用更改到文件系统如果目标目录不是 git 仓库
将尝试初始化一个临时仓库或报告错误
Args: Args:
diff: unified diff 格式的字符串 diff: unified diff 格式的字符串
target_dir: 目标目录路径默认为当前目录 target_dir: 目标目录路径默认为当前目录
Returns: Returns:
字典包含以下键 字典包含以下键
- 'success' (bool): 是否成功应用 - 'success' (bool): 是否成功应用
@ -58,58 +101,78 @@ def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]:
'applied_files': [], 'applied_files': [],
'error_details': '' 'error_details': ''
} }
# 检查 diff 是否为空 # 检查 diff 是否为空
if not diff or diff.strip() == '': if not diff or diff.strip() == '':
result['message'] = 'Diff string is empty' result['message'] = 'Diff string is empty'
return result return result
# 解析 diff 获取文件列表 # 解析 diff 获取 PatchSet 对象
try: try:
affected_files = parse_diff_files(diff) patch_set = PatchSet(diff)
# 收集所有需要修改的目标文件路径
affected_files = [p.target_file.lstrip('/') for p in patch_set if p.target_file != '/dev/null']
except Exception as e: except Exception as e:
result['message'] = f"Failed to parse diff: {str(e)}" result['message'] = f"Failed to parse diff: {str(e)}"
result['error_details'] = str(e) result['error_details'] = str(e)
return result return result
# 检查目标目录是否存在 # 检查目标目录是否存在
if not os.path.isdir(target_dir): if not os.path.isdir(target_dir):
result['message'] = f"Target directory does not exist: {target_dir}" result['message'] = f"Target directory does not exist: {target_dir}"
return result return result
try: try:
# 尝试获取或初始化 git 仓库 # 遍历每个 patch (即每个文件的变更)
try: for patch_obj in patch_set:
repo = git.Repo(target_dir) # 获取目标文件路径 (修改后的文件名)
except git.exc.InvalidGitRepositoryError: target_path = patch_obj.target_file.lstrip('/')
# 如果不是 git 仓库,初始化一个 if not target_path or target_path == '/dev/null':
repo = git.Repo.init(target_dir) # 如果目标是 /dev/null则是删除操作我们跳过或记录
# 添加所有现有文件到索引,以便应用 diff continue
repo.git.add('--all')
full_file_path = os.path.join(target_dir, target_path)
# 应用 diff
# 使用 git apply 命令,通过 stdin 传入 diff # 检查源文件是否存在 (对于新增操作,源文件可能是 /dev/null)
# '--whitespace=nowarn' 忽略空白警告 source_path = patch_obj.source_file
output = repo.git.apply('--whitespace=nowarn', '--', input=diff) if source_path == '/dev/null':
# 源是 /dev/null说明这是一个新文件内容从空开始
# 如果成功,更新结果 original_content_lines = []
else:
# 源是普通文件,尝试读取
try:
with open(full_file_path, 'r', encoding='utf-8') as f:
original_content_lines = f.readlines()
except FileNotFoundError:
# 如果文件不存在但diff不是新文件这可能是个错误
# 但对于新文件,这种情况也可能发生,因为路径可能不同
# 我们继续,但要注意如果内容不匹配会报错
original_content_lines = []
# 应用此文件的补丁 (所有 hunks)
modified_content_lines = _apply_single_patch_to_content(
original_content_lines,
list(patch_obj) # patch_obj 本身就是一个 Hunk 对象的迭代器
)
# 将修改后的内容写回文件
# 确保目录存在
os.makedirs(os.path.dirname(full_file_path), exist_ok=True)
with open(full_file_path, 'w', encoding='utf-8', newline='') as f:
f.writelines(modified_content_lines)
# 如果所有 patch 都成功应用
result['success'] = True result['success'] = True
result['message'] = 'Diff applied successfully' result['message'] = 'Diff applied successfully'
result['applied_files'] = affected_files result['applied_files'] = affected_files
if output:
result['message'] += f". Output: {output}"
except git.exc.GitError as e:
# 处理 git 错误,如冲突、格式错误等
result['message'] = f"Git error while applying diff: {str(e)}"
result['error_details'] = str(e)
except Exception as e:
# 处理其他异常
result['message'] = f"Unexpected error: {str(e)}"
result['error_details'] = str(e)
return result
except Exception as e:
# 处理应用过程中可能出现的任何错误
result['message'] = f"Error while applying diff: {str(e)}"
result['error_details'] = str(e)
return result
# 如果作为脚本运行,可以提供简单的测试 # 如果作为脚本运行,可以提供简单的测试
if __name__ == "__main__": if __name__ == "__main__":
@ -120,6 +183,7 @@ if __name__ == "__main__":
-Hello World -Hello World
+Hello Universe +Hello Universe
""" """
print("Testing apply_diff...")
print("Testing apply_diff with unidiff2...")
res = apply_diff(sample_diff, ".") res = apply_diff(sample_diff, ".")
print(res) print(res)

View File

@ -1,23 +1,20 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """Unit tests for diff_applier.py, covering various scenarios such as new file creation,
Unit tests for diff_applier.py, covering various scenarios such as new file creation, modification of existing files, conflict handling, and error cases."""
modification of existing files, conflict handling, and error cases.
"""
import os import os
import sys import sys
import tempfile import tempfile
import shutil import shutil
import pytest import pytest
import git # GitPython is required; assumed installed via project dependencies
# Add src directory to path for module import # Add src directory to path for module import
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
from llm_codegen.diff_applier import parse_diff_files, apply_diff # 假设您的文件名为 diff_applier_unidiff2.py
from src.llm_codegen.diff_applier import parse_diff, apply_diff
def test_parse_diff_files(): def test_parse_diff():
"""Test parsing unified diff strings to extract file paths.""" """Test parsing unified diff strings to extract file paths."""
# Diff with modification and new file # Diff with modification and new file
diff = """--- a/file1.txt diff = """--- a/file1.txt
@ -25,128 +22,102 @@ def test_parse_diff_files():
@@ -1 +1 @@ @@ -1 +1 @@
-old content -old content
+new content +new content
--- a/new_file.txt --- /dev/null
+++ b/new_file.txt +++ b/new_file.txt
@@ -0,0 +1 @@ @@ -0,0 +1 @@
+new file content +new file content"""
""" files = parse_diff(diff)
files = parse_diff_files(diff)
assert set(files) == {'file1.txt', 'new_file.txt'} assert set(files) == {'file1.txt', 'new_file.txt'}
# Diff with file deletion # Diff with file deletion
diff_del = """--- a/deleted.txt diff_del = """--- a/deleted.txt
+++ /dev/null +++ /dev/null
@@ -1 +0,0 @@ @@ -1 +0,0 @@
-content to delete -content to delete"""
""" files = parse_diff(diff_del)
files = parse_diff_files(diff_del) # 删除操作的解析结果应为空,因为没有实际的“目标”文件被创建或修改
assert files == ['deleted.txt'] # Old file path is extracted assert files == []
# Empty diff # Empty diff
assert parse_diff_files('') == [] assert parse_diff('') == []
assert parse_diff_files('\n') == [] assert parse_diff('\n') == []
# Diff with only new file (no old file) # Diff with only new file (no old file)
diff_new_only = """--- /dev/null diff_new_only = """--- /dev/null
+++ b/only_new.txt +++ b/only_new.txt
@@ -0,0 +1 @@ @@ -0,0 +1 @@
+only new +only new"""
""" files = parse_diff(diff_new_only)
files = parse_diff_files(diff_new_only) assert files == ['only_new.txt']
assert files == [] /dev/null is ignored
# Invalid diff format (should still handle gracefully) # Invalid diff format (should still handle gracefully)
diff_invalid = "invalid diff string" diff_invalid = "invalid diff string"
files = parse_diff_files(diff_invalid) files = parse_diff(diff_invalid)
assert files == [] # No valid file paths found assert files == [] # No valid file paths found
def test_apply_diff_new_file():
@pytest.fixture
def temp_git_repo():
"""Create a temporary git repository for testing apply_diff."""
temp_dir = tempfile.mkdtemp()
repo = git.Repo.init(temp_dir)
# Make an initial commit to have a clean state
init_file = os.path.join(temp_dir, 'README.md')
with open(init_file, 'w') as f:
f.write('Initial commit content')
repo.git.add('README.md')
repo.git.commit('-m', 'initial commit')
yield temp_dir, repo
# Cleanup after test
shutil.rmtree(temp_dir, ignore_errors=True)
def test_apply_diff_new_file(temp_git_repo):
"""Test applying a diff that creates a new file.""" """Test applying a diff that creates a new file."""
temp_dir, repo = temp_git_repo with tempfile.TemporaryDirectory() as temp_dir:
diff = """--- /dev/null diff = """--- /dev/null
+++ b/test_new.txt +++ b/test_new.txt
@@ -0,0 +1 @@ @@ -0,0 +1 @@
+This is a newly created file. +This is a newly created file."""
"""
result = apply_diff(diff, temp_dir)
assert result['success'] == True
assert 'test_new.txt' in result['applied_files']
new_file_path = os.path.join(temp_dir, 'test_new.txt')
assert os.path.exists(new_file_path)
with open(new_file_path, 'r') as f:
content = f.read()
assert content.strip() == 'This is a newly created file.'
result = apply_diff(diff, temp_dir)
assert result['success'] == True
assert 'test_new.txt' in result['applied_files']
def test_apply_diff_modify_existing_file(temp_git_repo): new_file_path = os.path.join(temp_dir, 'test_new.txt')
assert os.path.exists(new_file_path)
with open(new_file_path, 'r') as f:
content = f.read()
# 修复断言strip()会去掉行末的换行符,使内容匹配
assert content.strip() == 'This is a newly created file.'
def test_apply_diff_modify_existing_file():
"""Test applying a diff that modifies an existing file.""" """Test applying a diff that modifies an existing file."""
temp_dir, repo = temp_git_repo with tempfile.TemporaryDirectory() as temp_dir:
# Create an existing file and commit it # Create an existing file
existing_file = os.path.join(temp_dir, 'existing.txt') existing_file = os.path.join(temp_dir, 'existing.txt')
with open(existing_file, 'w') as f: with open(existing_file, 'w') as f:
f.write('Original line 1\nOriginal line 2') f.write('Original line 1\nOriginal line 2\n')
repo.git.add('existing.txt')
repo.git.commit('-m', 'add existing file') diff = """--- a/existing.txt
diff = """--- a/existing.txt
+++ b/existing.txt +++ b/existing.txt
@@ -1,2 +1,2 @@ @@ -1,2 +1,2 @@
-Original line 1 -Original line 1
+Modified line 1 +Modified line 1
Original line 2 Original line 2"""
"""
result = apply_diff(diff, temp_dir)
assert result['success'] == True
assert 'existing.txt' in result['applied_files']
with open(existing_file, 'r') as f:
content = f.read()
assert content == 'Modified line 1\nOriginal line 2'
result = apply_diff(diff, temp_dir)
assert result['success'] == True
assert 'existing.txt' in result['applied_files']
def test_apply_diff_conflict_handling(temp_git_repo): with open(existing_file, 'r') as f:
"""Test applying a diff that causes a conflict with uncommitted changes.""" content = f.read()
temp_dir, repo = temp_git_repo assert content == 'Modified line 1\nOriginal line 2\n'
# Create and commit a file
conflict_file = os.path.join(temp_dir, 'conflict.txt') def test_apply_diff_conflict_handling():
with open(conflict_file, 'w') as f: """Test applying a diff that causes a conflict because the source doesn't match."""
f.write('Initial line 1\nInitial line 2') with tempfile.TemporaryDirectory() as temp_dir:
repo.git.add('conflict.txt') # Create a file with specific content
repo.git.commit('-m', 'add conflict file') conflict_file = os.path.join(temp_dir, 'conflict.txt')
with open(conflict_file, 'w') as f:
# Modify the file without committing to create conflict f.write('Different line 1\nOriginal line 2\n') # This is different from what diff expects
with open(conflict_file, 'w') as f:
f.write('Changed line 1\nInitial line 2') # Change first line # This diff expects 'Initial line 1' but finds 'Different line 1'
diff = """--- a/conflict.txt
diff = """--- a/conflict.txt
+++ b/conflict.txt +++ b/conflict.txt
@@ -1,2 +1,2 @@ @@ -1,2 +1,2 @@
-Initial line 1 -Initial line 1
+Diff line 1 +Diff line 1
Initial line 2 Initial line 2"""
"""
result = apply_diff(diff, temp_dir)
assert result['success'] == False
# Check for conflict or error in message
assert 'conflict' in result['message'].lower() or 'error' in result['message'].lower()
assert result['error_details'] != ''
result = apply_diff(diff, temp_dir)
assert result['success'] == False # Should fail due to mismatch
# Check for conflict or error in message
assert 'error' in result['message'].lower() or 'does not match' in result['message'].lower()
assert result['error_details'] != ''
def test_apply_diff_empty_diff(): def test_apply_diff_empty_diff():
"""Test applying an empty diff string.""" """Test applying an empty diff string."""
@ -154,7 +125,6 @@ def test_apply_diff_empty_diff():
assert result['success'] == False assert result['success'] == False
assert 'empty' in result['message'].lower() assert 'empty' in result['message'].lower()
def test_apply_diff_invalid_directory(): def test_apply_diff_invalid_directory():
"""Test applying a diff to a non-existent directory.""" """Test applying a diff to a non-existent directory."""
non_existent_dir = '/tmp/non_existent_dir_12345' non_existent_dir = '/tmp/non_existent_dir_12345'
@ -162,36 +132,33 @@ def test_apply_diff_invalid_directory():
+++ b/dummy.txt +++ b/dummy.txt
@@ -1 +1 @@ @@ -1 +1 @@
-old -old
+new +new"""
"""
result = apply_diff(diff, non_existent_dir) result = apply_diff(diff, non_existent_dir)
assert result['success'] == False assert result['success'] == False
assert 'does not exist' in result['message'].lower() assert 'does not exist' in result['message'].lower()
def test_apply_diff_no_git_repo_initialization():
    """Apply a diff inside a plain directory that is not a git repository.

    The diff must be applied directly to the file on disk, and no git
    repository may be created in the target directory as a side effect.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        # Create a non-git directory with a file
        non_git_file = os.path.join(temp_dir, 'non_git.txt')
        with open(non_git_file, 'w', encoding='utf-8') as f:
            f.write('Pre-existing content\n')

        diff = """--- a/non_git.txt
+++ b/non_git.txt
@@ -1 +1 @@
-Pre-existing content
+Updated content"""

        result = apply_diff(diff, temp_dir)
        assert result['success'] is True
        assert 'non_git.txt' in result['applied_files']

        with open(non_git_file, 'r', encoding='utf-8') as f:
            content = f.read()
        assert content == 'Updated content\n'

        # Applying a diff must not turn the directory into a git repository.
        assert not os.path.exists(os.path.join(temp_dir, '.git'))
if __name__ == "__main__": if __name__ == "__main__":