重构代码以支持并发生成和线程安全状态保存,优化拓扑排序与补丁应用逻辑
This commit is contained in:
parent
6e536f141c
commit
75fc8659a0
|
|
@ -23,9 +23,9 @@
|
|||
"dependencies": ["src/llm_codegen/core.py", "src/llm_codegen/models.py"],
|
||||
"functions": [
|
||||
{
|
||||
"name": "main",
|
||||
"summary": "主CLI入口,处理命令行参数并启动生成器",
|
||||
"inputs": ["readme", "output_dir", "api_key", "base_url", "model", "log_file"],
|
||||
"name": "",
|
||||
"summary": "",
|
||||
"inputs": [],
|
||||
"outputs": []
|
||||
}
|
||||
],
|
||||
|
|
|
|||
|
|
@ -14,7 +14,8 @@ dependencies = [
|
|||
"loguru>=0.7.0",
|
||||
"openai>=1.0.0",
|
||||
"pathspec>=1.0.4",
|
||||
"gitpython>=3.1.0",
|
||||
"python-patch>=0.0.1",
|
||||
"unidiff2>=0.7.8",
|
||||
]
|
||||
authors = [
|
||||
{name = "Your Name", email = "your.email@example.com"}
|
||||
|
|
@ -33,6 +34,10 @@ dev = [
|
|||
"black>=23.0.0",
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
url="https://pypi.tuna.tsinghua.edu.cn/simple"
|
||||
default=true
|
||||
|
||||
[project.scripts]
|
||||
llm-codegen = "llm_codegen.cli:app"
|
||||
|
||||
|
|
@ -45,3 +50,8 @@ max_concurrent_requests = 5
|
|||
# 新增:指定包所在目录
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"pytest>=8.4.2",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import difflib
|
|||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from pathlib import Path
|
||||
from collections import deque
|
||||
import threading
|
||||
|
||||
from rich.console import Console
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskID
|
||||
|
|
@ -50,6 +51,7 @@ class CodeGenerator:
|
|||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.state_file = self.output_dir / ".llm_generator_state.json"
|
||||
self.console = Console() # 添加console实例用于rich打印
|
||||
self._state_lock = threading.Lock()
|
||||
|
||||
self.max_concurrency = max_concurrency
|
||||
|
||||
|
|
@ -171,18 +173,20 @@ class CodeGenerator:
|
|||
return None
|
||||
|
||||
def save_state(self, generated_files: List[str], dependencies_map: Dict[str, List[str]]) -> None:
|
||||
"""保存断点续写状态,适应并发生成"""
|
||||
state = StateModel(
|
||||
current_file_index=0, # 在并发中无效,设为0保持兼容
|
||||
generated_files=generated_files,
|
||||
dependencies_map=dependencies_map,
|
||||
total_files=len(self.design.files) if self.design else 0,
|
||||
output_dir=str(self.output_dir),
|
||||
readme_path=self.readme_content[:100] if self.readme_content else ""
|
||||
)
|
||||
with open(self.state_file, "w", encoding="utf-8") as f:
|
||||
json.dump(state.model_dump(), f, indent=2, ensure_ascii=False)
|
||||
logger.debug(f"状态已保存: {self.state_file}")
|
||||
"""保存断点续写状态,适应并发生成(线程安全)"""
|
||||
with self._state_lock: # 串行化写入
|
||||
state = StateModel(
|
||||
current_file_index=0,
|
||||
generated_files=generated_files,
|
||||
dependencies_map=dependencies_map,
|
||||
total_files=len(self.design.files) if self.design else 0,
|
||||
output_dir=str(self.output_dir),
|
||||
readme_path=self.readme_content[:100] if self.readme_content else ""
|
||||
)
|
||||
with open(self.state_file, "w", encoding="utf-8") as f:
|
||||
json.dump(state.model_dump(), f, indent=2, ensure_ascii=False)
|
||||
logger.debug(f"状态已保存: {self.state_file}")
|
||||
|
||||
|
||||
def get_project_structure(self) -> Tuple[List[str], Dict[str, List[str]]]:
|
||||
"""
|
||||
|
|
@ -471,36 +475,36 @@ class CodeGenerator:
|
|||
def _topological_sort(self, files: List[str], dependencies: Dict[str, List[str]]) -> List[str]:
|
||||
"""
|
||||
对文件列表进行拓扑排序,基于依赖关系。
|
||||
|
||||
Args:
|
||||
files: 文件路径列表
|
||||
dependencies: 依赖字典,{file: [依赖文件]}
|
||||
|
||||
Returns:
|
||||
List[str]: 拓扑排序后的文件列表
|
||||
|
||||
Raises:
|
||||
ValueError: 如果检测到循环依赖
|
||||
返回排序后的列表,满足每个文件的依赖项都出现在该文件之前。
|
||||
如果检测到循环依赖,抛出ValueError。
|
||||
"""
|
||||
from collections import deque
|
||||
# 构建图
|
||||
graph = {file: dependencies.get(file, []) for file in files}
|
||||
in_degree = {file: 0 for file in files}
|
||||
for file in files:
|
||||
for dep in graph[file]:
|
||||
if dep in in_degree:
|
||||
in_degree[dep] += 1
|
||||
|
||||
# 初始化入度和反向邻接表
|
||||
in_degree = {f: 0 for f in files}
|
||||
rev_graph = {f: [] for f in files} # 记录哪些文件依赖于f
|
||||
|
||||
# 构建图:如果文件f依赖于dep,则增加f的入度,并将f加入rev_graph[dep]
|
||||
for f in files:
|
||||
for dep in dependencies.get(f, []):
|
||||
if dep in files: # 只考虑在files中的依赖
|
||||
in_degree[f] += 1 # f依赖于dep,所以f的入度增加
|
||||
rev_graph[dep].append(f) # dep被f依赖
|
||||
|
||||
# 队列初始化为入度为0的文件(无依赖的文件)
|
||||
queue = deque([f for f in files if in_degree[f] == 0])
|
||||
sorted_files = []
|
||||
|
||||
while queue:
|
||||
node = queue.popleft()
|
||||
sorted_files.append(node)
|
||||
for dep in graph[node]:
|
||||
in_degree[dep] -= 1
|
||||
if in_degree[dep] == 0:
|
||||
queue.append(dep)
|
||||
# 所有依赖于node的文件入度减1
|
||||
for dependent in rev_graph[node]:
|
||||
in_degree[dependent] -= 1
|
||||
if in_degree[dependent] == 0:
|
||||
queue.append(dependent)
|
||||
|
||||
# 检查是否所有文件都已排序(无循环依赖)
|
||||
if len(sorted_files) != len(files):
|
||||
raise ValueError(f"检测到循环依赖,排序失败。已排序 {len(sorted_files)} 个文件,总共 {len(files)} 个文件。")
|
||||
|
||||
|
|
@ -671,33 +675,50 @@ class CodeGenerator:
|
|||
done, not_done = concurrent.futures.wait(futures.keys(), return_when=concurrent.futures.FIRST_COMPLETED, timeout=1.0)
|
||||
for future in done:
|
||||
file = futures.pop(future)
|
||||
success, error_msg = future.result()
|
||||
# 更新文件进度任务
|
||||
if file in file_tasks:
|
||||
try:
|
||||
success, error_msg = future.result()
|
||||
# 更新文件进度任务
|
||||
if file in file_tasks:
|
||||
if success:
|
||||
progress.update(file_tasks[file], completed=1)
|
||||
progress.remove_task(file_tasks[file]) # 移除任务
|
||||
else:
|
||||
# 如果失败,标记为错误状态
|
||||
progress.update(file_tasks[file], description=f"生成失败: {file}")
|
||||
progress.remove_task(file_tasks[file])
|
||||
del file_tasks[file] # 清理映射
|
||||
if success:
|
||||
progress.update(file_tasks[file], completed=1)
|
||||
progress.remove_task(file_tasks[file]) # 移除任务
|
||||
processed_files.add(file)
|
||||
# 更新入度:减少依赖该文件的节点的入度
|
||||
for other_file in files:
|
||||
if file in dependencies.get(other_file, []):
|
||||
in_degree[other_file] -= 1
|
||||
if in_degree[other_file] == 0 and other_file not in processed_files:
|
||||
queue.append(other_file)
|
||||
# 保存状态
|
||||
self.save_state(list(processed_files), dependencies)
|
||||
progress.update(total_task, advance=1) # 更新整体进度
|
||||
else:
|
||||
# 如果失败,标记为错误状态
|
||||
logger.error(f"文件 {file} 生成失败,错误: {error_msg}")
|
||||
self.console.print(f"[bold red]❌ 文件 {file} 生成失败,错误: {error_msg}[/bold red]")
|
||||
# 错误处理:继续处理其他文件,但记录失败
|
||||
except Exception as e:
|
||||
# 捕获 Future 中存储的异常
|
||||
logger.error(f"任务 {file} 执行时发生异常: {e}")
|
||||
self.console.print(f"[bold red]❌ 任务 {file} 执行时发生异常: {e}[/bold red]")
|
||||
# 将其视为失败
|
||||
success = False
|
||||
error_msg = str(e)
|
||||
# 然后执行和上面 `else` 分支相同的失败处理逻辑
|
||||
if file in file_tasks:
|
||||
progress.update(file_tasks[file], description=f"生成失败: {file}")
|
||||
progress.remove_task(file_tasks[file])
|
||||
del file_tasks[file] # 清理映射
|
||||
if success:
|
||||
processed_files.add(file)
|
||||
# 更新入度:减少依赖该文件的节点的入度
|
||||
for other_file in files:
|
||||
if file in dependencies.get(other_file, []):
|
||||
in_degree[other_file] -= 1
|
||||
if in_degree[other_file] == 0 and other_file not in processed_files:
|
||||
queue.append(other_file)
|
||||
# 保存状态
|
||||
self.save_state(list(processed_files), dependencies)
|
||||
progress.update(total_task, advance=1) # 更新整体进度
|
||||
else:
|
||||
del file_tasks[file] # 清理映射
|
||||
logger.error(f"文件 {file} 生成失败,错误: {error_msg}")
|
||||
self.console.print(f"[bold red]❌ 文件 {file} 生成失败,错误: {error_msg}[/bold red]")
|
||||
# 错误处理:继续处理其他文件,但记录失败
|
||||
|
||||
|
||||
logger.success("所有文件处理完成!")
|
||||
# 清理状态文件
|
||||
if self.state_file.exists():
|
||||
|
|
@ -708,7 +729,6 @@ class CodeGenerator:
|
|||
logger.error(f"清理状态文件失败: {e}")
|
||||
self.console.print(f"[bold red]❌ 清理状态文件失败: {e}[/bold red]")
|
||||
|
||||
|
||||
def process_issue(self, issue_content: str, issue_type: str) -> bool:
|
||||
"""
|
||||
处理需求增强或 Bug 修复工单
|
||||
|
|
|
|||
|
|
@ -1,16 +1,12 @@
|
|||
"""
|
||||
Diff 应用模块,使用 GitPython 解析和应用 unified diff 格式。
|
||||
"""
|
||||
|
||||
"""Diff 应用模块,使用 unidiff2 解析和应用 unified diff 格式。"""
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Dict, Any
|
||||
import git # 需要安装 GitPython
|
||||
|
||||
from unidiff import PatchSet, Hunk # 需要安装 unidiff2
|
||||
|
||||
def parse_diff(diff: str) -> List[str]:
|
||||
"""
|
||||
解析 unified diff 字符串,提取受影响的文件路径。
|
||||
此函数使用 unidiff2 库来解析 diff。
|
||||
|
||||
Args:
|
||||
diff: unified diff 格式的字符串。
|
||||
|
|
@ -18,27 +14,74 @@ def parse_diff(diff: str) -> List[str]:
|
|||
Returns:
|
||||
文件路径列表。
|
||||
"""
|
||||
files = set()
|
||||
for line in diff.split('\n'):
|
||||
if line.startswith('--- a/'):
|
||||
# 提取旧文件路径
|
||||
path = line[6:].strip()
|
||||
if path and path != '/dev/null': # /dev/null 表示新文件
|
||||
files.add(path)
|
||||
elif line.startswith('+++ b/'):
|
||||
# 提取新文件路径
|
||||
path = line[6:].strip()
|
||||
if path and path != '/dev/null': # /dev/null 表示删除文件
|
||||
files.add(path)
|
||||
return list(files)
|
||||
try:
|
||||
patch_set = PatchSet(diff)
|
||||
# unidiff2 中的 patch 对象有一个 path 属性,代表修改的文件
|
||||
# 使用集合去重
|
||||
files = {patch.target_file for patch in patch_set} # 使用target_file获取修改后的文件名
|
||||
# 过滤掉 /dev/null
|
||||
return [f.lstrip('/') for f in files if f != '/dev/null' and not f.endswith('/dev/null')]
|
||||
except Exception as e:
|
||||
# 如果解析失败,抛出异常
|
||||
raise e
|
||||
|
||||
def _apply_single_patch_to_content(file_content_lines: List[str], patch_hunks: "List[Hunk]") -> List[str]:
    """Apply every hunk of one file's patch to that file's lines.

    Hunks are applied from the bottom of the file upwards so that an edit
    never shifts the line numbers of hunks that are still to be applied.

    Args:
        file_content_lines: Current file content, one entry per line, each
            entry keeping its own line terminator (as from ``readlines()``).
            The list is not modified.
        patch_hunks: Hunk objects for this file. Each must expose
            ``source_start`` / ``source_length`` and iterate over line
            objects exposing ``value``, ``is_added``, ``is_removed`` and
            ``is_context``.

    Returns:
        A new list of lines with all hunks applied.

    Raises:
        ValueError: If a hunk points outside the file or its context/removed
            lines do not match the file content at that position.
    """
    # Added lines adopt the file's existing newline style; a brand-new
    # (empty) file defaults to "\n" so multi-line additions stay separate.
    newline = "\n"
    for existing in file_content_lines:
        if existing.endswith("\r\n"):
            newline = "\r\n"
            break
        if existing.endswith("\n"):
            break

    patched = list(file_content_lines)

    for hunk in sorted(patch_hunks, key=lambda h: h.source_start, reverse=True):
        length = hunk.source_length
        # In unified diff, a zero-length source range names the line *before*
        # the insertion point, so it is already the 0-based insert index.
        start = hunk.source_start if length == 0 else hunk.source_start - 1
        if start < 0 or start > len(patched):
            raise ValueError(
                f"Hunk at line {hunk.source_start} does not match the source file content. "
                f"Position is outside the file ({len(patched)} lines)."
            )

        # Compare on the raw line value (no +/-/space marker); only the line
        # terminator is ignored, so leading indentation stays significant.
        expected = [
            line.value.rstrip("\r\n")
            for line in hunk
            if line.is_context or line.is_removed
        ]
        original_slice = patched[start:start + length]
        actual = [text.rstrip("\r\n") for text in original_slice]
        if expected != actual:
            raise ValueError(
                f"Hunk at line {hunk.source_start} does not match the source file content. "
                f"Expected: {expected}, Got: {actual}"
            )

        replacement = []
        cursor = 0  # position inside original_slice
        for line in hunk:
            if line.is_context:
                # Keep the file's own bytes for unchanged lines.
                replacement.append(original_slice[cursor])
                cursor += 1
            elif line.is_removed:
                cursor += 1
            elif line.is_added:
                replacement.append(line.value.rstrip("\r\n") + newline)
            # Any other line kind (e.g. the "\ No newline at end of file"
            # marker) carries no content and is ignored.

        patched[start:start + length] = replacement

    return patched
|
||||
|
||||
|
||||
def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]:
|
||||
"""
|
||||
应用 unified diff 到指定目录,使用 GitPython 库。
|
||||
|
||||
该函数解析 diff 并尝试应用更改到文件系统。如果目标目录不是 git 仓库,
|
||||
将尝试初始化一个临时仓库或报告错误。
|
||||
应用 unified diff 到指定目录。
|
||||
该函数解析 diff,读取磁盘上的文件,应用更改,并写回文件。
|
||||
|
||||
Args:
|
||||
diff: unified diff 格式的字符串。
|
||||
|
|
@ -64,9 +107,11 @@ def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]:
|
|||
result['message'] = 'Diff string is empty'
|
||||
return result
|
||||
|
||||
# 解析 diff 获取文件列表
|
||||
# 解析 diff 获取 PatchSet 对象
|
||||
try:
|
||||
affected_files = parse_diff_files(diff)
|
||||
patch_set = PatchSet(diff)
|
||||
# 收集所有需要修改的目标文件路径
|
||||
affected_files = [p.target_file.lstrip('/') for p in patch_set if p.target_file != '/dev/null']
|
||||
except Exception as e:
|
||||
result['message'] = f"Failed to parse diff: {str(e)}"
|
||||
result['error_details'] = str(e)
|
||||
|
|
@ -78,39 +123,57 @@ def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]:
|
|||
return result
|
||||
|
||||
try:
|
||||
# 尝试获取或初始化 git 仓库
|
||||
try:
|
||||
repo = git.Repo(target_dir)
|
||||
except git.exc.InvalidGitRepositoryError:
|
||||
# 如果不是 git 仓库,初始化一个
|
||||
repo = git.Repo.init(target_dir)
|
||||
# 添加所有现有文件到索引,以便应用 diff
|
||||
repo.git.add('--all')
|
||||
# 遍历每个 patch (即每个文件的变更)
|
||||
for patch_obj in patch_set:
|
||||
# 获取目标文件路径 (修改后的文件名)
|
||||
target_path = patch_obj.target_file.lstrip('/')
|
||||
if not target_path or target_path == '/dev/null':
|
||||
# 如果目标是 /dev/null,则是删除操作,我们跳过或记录
|
||||
continue
|
||||
|
||||
# 应用 diff
|
||||
# 使用 git apply 命令,通过 stdin 传入 diff
|
||||
# '--whitespace=nowarn' 忽略空白警告
|
||||
output = repo.git.apply('--whitespace=nowarn', '--', input=diff)
|
||||
full_file_path = os.path.join(target_dir, target_path)
|
||||
|
||||
# 如果成功,更新结果
|
||||
# 检查源文件是否存在 (对于新增操作,源文件可能是 /dev/null)
|
||||
source_path = patch_obj.source_file
|
||||
if source_path == '/dev/null':
|
||||
# 源是 /dev/null,说明这是一个新文件,内容从空开始
|
||||
original_content_lines = []
|
||||
else:
|
||||
# 源是普通文件,尝试读取
|
||||
try:
|
||||
with open(full_file_path, 'r', encoding='utf-8') as f:
|
||||
original_content_lines = f.readlines()
|
||||
except FileNotFoundError:
|
||||
# 如果文件不存在,但diff不是新文件,这可能是个错误
|
||||
# 但对于新文件,这种情况也可能发生,因为路径可能不同
|
||||
# 我们继续,但要注意如果内容不匹配会报错
|
||||
original_content_lines = []
|
||||
|
||||
# 应用此文件的补丁 (所有 hunks)
|
||||
modified_content_lines = _apply_single_patch_to_content(
|
||||
original_content_lines,
|
||||
list(patch_obj) # patch_obj 本身就是一个 Hunk 对象的迭代器
|
||||
)
|
||||
|
||||
# 将修改后的内容写回文件
|
||||
# 确保目录存在
|
||||
os.makedirs(os.path.dirname(full_file_path), exist_ok=True)
|
||||
|
||||
with open(full_file_path, 'w', encoding='utf-8', newline='') as f:
|
||||
f.writelines(modified_content_lines)
|
||||
|
||||
# 如果所有 patch 都成功应用
|
||||
result['success'] = True
|
||||
result['message'] = 'Diff applied successfully'
|
||||
result['applied_files'] = affected_files
|
||||
if output:
|
||||
result['message'] += f". Output: {output}"
|
||||
|
||||
except git.exc.GitError as e:
|
||||
# 处理 git 错误,如冲突、格式错误等
|
||||
result['message'] = f"Git error while applying diff: {str(e)}"
|
||||
result['error_details'] = str(e)
|
||||
except Exception as e:
|
||||
# 处理其他异常
|
||||
result['message'] = f"Unexpected error: {str(e)}"
|
||||
# 处理应用过程中可能出现的任何错误
|
||||
result['message'] = f"Error while applying diff: {str(e)}"
|
||||
result['error_details'] = str(e)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# 如果作为脚本运行,可以提供简单的测试
|
||||
if __name__ == "__main__":
|
||||
# 示例用法
|
||||
|
|
@ -120,6 +183,7 @@ if __name__ == "__main__":
|
|||
-Hello World
|
||||
+Hello Universe
|
||||
"""
|
||||
print("Testing apply_diff...")
|
||||
|
||||
print("Testing apply_diff with unidiff2...")
|
||||
res = apply_diff(sample_diff, ".")
|
||||
print(res)
|
||||
|
|
|
|||
|
|
@ -1,23 +1,20 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Unit tests for diff_applier.py, covering various scenarios such as new file creation,
|
||||
modification of existing files, conflict handling, and error cases.
|
||||
"""
|
||||
|
||||
"""Unit tests for diff_applier.py, covering various scenarios such as new file creation,
|
||||
modification of existing files, conflict handling, and error cases."""
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import shutil
|
||||
import pytest
|
||||
import git # GitPython is required; assumed installed via project dependencies
|
||||
|
||||
# Add src directory to path for module import
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
||||
|
||||
from llm_codegen.diff_applier import parse_diff_files, apply_diff
|
||||
# 假设您的文件名为 diff_applier_unidiff2.py
|
||||
from src.llm_codegen.diff_applier import parse_diff, apply_diff
|
||||
|
||||
|
||||
def test_parse_diff_files():
|
||||
def test_parse_diff():
|
||||
"""Test parsing unified diff strings to extract file paths."""
|
||||
# Diff with modification and new file
|
||||
diff = """--- a/file1.txt
|
||||
|
|
@ -25,128 +22,102 @@ def test_parse_diff_files():
|
|||
@@ -1 +1 @@
|
||||
-old content
|
||||
+new content
|
||||
--- a/new_file.txt
|
||||
--- /dev/null
|
||||
+++ b/new_file.txt
|
||||
@@ -0,0 +1 @@
|
||||
+new file content
|
||||
"""
|
||||
files = parse_diff_files(diff)
|
||||
+new file content"""
|
||||
files = parse_diff(diff)
|
||||
assert set(files) == {'file1.txt', 'new_file.txt'}
|
||||
|
||||
# Diff with file deletion
|
||||
diff_del = """--- a/deleted.txt
|
||||
+++ /dev/null
|
||||
@@ -1 +0,0 @@
|
||||
-content to delete
|
||||
"""
|
||||
files = parse_diff_files(diff_del)
|
||||
assert files == ['deleted.txt'] # Old file path is extracted
|
||||
-content to delete"""
|
||||
files = parse_diff(diff_del)
|
||||
# 删除操作的解析结果应为空,因为没有实际的“目标”文件被创建或修改
|
||||
assert files == []
|
||||
|
||||
# Empty diff
|
||||
assert parse_diff_files('') == []
|
||||
assert parse_diff_files('\n') == []
|
||||
assert parse_diff('') == []
|
||||
assert parse_diff('\n') == []
|
||||
|
||||
# Diff with only new file (no old file)
|
||||
diff_new_only = """--- /dev/null
|
||||
+++ b/only_new.txt
|
||||
@@ -0,0 +1 @@
|
||||
+only new
|
||||
"""
|
||||
files = parse_diff_files(diff_new_only)
|
||||
assert files == [] /dev/null is ignored
|
||||
+only new"""
|
||||
files = parse_diff(diff_new_only)
|
||||
assert files == ['only_new.txt']
|
||||
|
||||
# Invalid diff format (should still handle gracefully)
|
||||
diff_invalid = "invalid diff string"
|
||||
files = parse_diff_files(diff_invalid)
|
||||
assert files == [] # No valid file paths found
|
||||
files = parse_diff(diff_invalid)
|
||||
assert files == [] # No valid file paths found
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_git_repo():
|
||||
"""Create a temporary git repository for testing apply_diff."""
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
repo = git.Repo.init(temp_dir)
|
||||
# Make an initial commit to have a clean state
|
||||
init_file = os.path.join(temp_dir, 'README.md')
|
||||
with open(init_file, 'w') as f:
|
||||
f.write('Initial commit content')
|
||||
repo.git.add('README.md')
|
||||
repo.git.commit('-m', 'initial commit')
|
||||
yield temp_dir, repo
|
||||
# Cleanup after test
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
def test_apply_diff_new_file(temp_git_repo):
|
||||
def test_apply_diff_new_file():
|
||||
"""Test applying a diff that creates a new file."""
|
||||
temp_dir, repo = temp_git_repo
|
||||
diff = """--- /dev/null
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
diff = """--- /dev/null
|
||||
+++ b/test_new.txt
|
||||
@@ -0,0 +1 @@
|
||||
+This is a newly created file.
|
||||
"""
|
||||
result = apply_diff(diff, temp_dir)
|
||||
assert result['success'] == True
|
||||
assert 'test_new.txt' in result['applied_files']
|
||||
new_file_path = os.path.join(temp_dir, 'test_new.txt')
|
||||
assert os.path.exists(new_file_path)
|
||||
with open(new_file_path, 'r') as f:
|
||||
content = f.read()
|
||||
assert content.strip() == 'This is a newly created file.'
|
||||
+This is a newly created file."""
|
||||
|
||||
result = apply_diff(diff, temp_dir)
|
||||
assert result['success'] == True
|
||||
assert 'test_new.txt' in result['applied_files']
|
||||
|
||||
def test_apply_diff_modify_existing_file(temp_git_repo):
|
||||
new_file_path = os.path.join(temp_dir, 'test_new.txt')
|
||||
assert os.path.exists(new_file_path)
|
||||
with open(new_file_path, 'r') as f:
|
||||
content = f.read()
|
||||
# 修复断言,strip()会去掉行末的换行符,使内容匹配
|
||||
assert content.strip() == 'This is a newly created file.'
|
||||
|
||||
def test_apply_diff_modify_existing_file():
|
||||
"""Test applying a diff that modifies an existing file."""
|
||||
temp_dir, repo = temp_git_repo
|
||||
# Create an existing file and commit it
|
||||
existing_file = os.path.join(temp_dir, 'existing.txt')
|
||||
with open(existing_file, 'w') as f:
|
||||
f.write('Original line 1\nOriginal line 2')
|
||||
repo.git.add('existing.txt')
|
||||
repo.git.commit('-m', 'add existing file')
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Create an existing file
|
||||
existing_file = os.path.join(temp_dir, 'existing.txt')
|
||||
with open(existing_file, 'w') as f:
|
||||
f.write('Original line 1\nOriginal line 2\n')
|
||||
|
||||
diff = """--- a/existing.txt
|
||||
diff = """--- a/existing.txt
|
||||
+++ b/existing.txt
|
||||
@@ -1,2 +1,2 @@
|
||||
-Original line 1
|
||||
+Modified line 1
|
||||
Original line 2
|
||||
"""
|
||||
result = apply_diff(diff, temp_dir)
|
||||
assert result['success'] == True
|
||||
assert 'existing.txt' in result['applied_files']
|
||||
with open(existing_file, 'r') as f:
|
||||
content = f.read()
|
||||
assert content == 'Modified line 1\nOriginal line 2'
|
||||
Original line 2"""
|
||||
|
||||
result = apply_diff(diff, temp_dir)
|
||||
assert result['success'] == True
|
||||
assert 'existing.txt' in result['applied_files']
|
||||
|
||||
def test_apply_diff_conflict_handling(temp_git_repo):
|
||||
"""Test applying a diff that causes a conflict with uncommitted changes."""
|
||||
temp_dir, repo = temp_git_repo
|
||||
# Create and commit a file
|
||||
conflict_file = os.path.join(temp_dir, 'conflict.txt')
|
||||
with open(conflict_file, 'w') as f:
|
||||
f.write('Initial line 1\nInitial line 2')
|
||||
repo.git.add('conflict.txt')
|
||||
repo.git.commit('-m', 'add conflict file')
|
||||
with open(existing_file, 'r') as f:
|
||||
content = f.read()
|
||||
assert content == 'Modified line 1\nOriginal line 2\n'
|
||||
|
||||
# Modify the file without committing to create conflict
|
||||
with open(conflict_file, 'w') as f:
|
||||
f.write('Changed line 1\nInitial line 2') # Change first line
|
||||
def test_apply_diff_conflict_handling():
|
||||
"""Test applying a diff that causes a conflict because the source doesn't match."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Create a file with specific content
|
||||
conflict_file = os.path.join(temp_dir, 'conflict.txt')
|
||||
with open(conflict_file, 'w') as f:
|
||||
f.write('Different line 1\nOriginal line 2\n') # This is different from what diff expects
|
||||
|
||||
diff = """--- a/conflict.txt
|
||||
# This diff expects 'Initial line 1' but finds 'Different line 1'
|
||||
diff = """--- a/conflict.txt
|
||||
+++ b/conflict.txt
|
||||
@@ -1,2 +1,2 @@
|
||||
-Initial line 1
|
||||
+Diff line 1
|
||||
Initial line 2
|
||||
"""
|
||||
result = apply_diff(diff, temp_dir)
|
||||
assert result['success'] == False
|
||||
# Check for conflict or error in message
|
||||
assert 'conflict' in result['message'].lower() or 'error' in result['message'].lower()
|
||||
assert result['error_details'] != ''
|
||||
Initial line 2"""
|
||||
|
||||
result = apply_diff(diff, temp_dir)
|
||||
assert result['success'] == False # Should fail due to mismatch
|
||||
# Check for conflict or error in message
|
||||
assert 'error' in result['message'].lower() or 'does not match' in result['message'].lower()
|
||||
assert result['error_details'] != ''
|
||||
|
||||
def test_apply_diff_empty_diff():
|
||||
"""Test applying an empty diff string."""
|
||||
|
|
@ -154,7 +125,6 @@ def test_apply_diff_empty_diff():
|
|||
assert result['success'] == False
|
||||
assert 'empty' in result['message'].lower()
|
||||
|
||||
|
||||
def test_apply_diff_invalid_directory():
|
||||
"""Test applying a diff to a non-existent directory."""
|
||||
non_existent_dir = '/tmp/non_existent_dir_12345'
|
||||
|
|
@ -162,36 +132,33 @@ def test_apply_diff_invalid_directory():
|
|||
+++ b/dummy.txt
|
||||
@@ -1 +1 @@
|
||||
-old
|
||||
+new
|
||||
"""
|
||||
+new"""
|
||||
|
||||
result = apply_diff(diff, non_existent_dir)
|
||||
assert result['success'] == False
|
||||
assert 'does not exist' in result['message'].lower()
|
||||
|
||||
|
||||
def test_apply_diff_no_git_repo_initialization():
|
||||
"""Test applying a diff to a non-git directory, which should initialize a repo."""
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
"""Test applying a diff to a non-git directory. This test is now redundant as there's no git dependency."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Create a non-git directory with a file
|
||||
non_git_file = os.path.join(temp_dir, 'non_git.txt')
|
||||
with open(non_git_file, 'w') as f:
|
||||
f.write('Pre-existing content')
|
||||
f.write('Pre-existing content\n')
|
||||
|
||||
diff = """--- a/non_git.txt
|
||||
+++ b/non_git.txt
|
||||
@@ -1 +1 @@
|
||||
-Pre-existing content
|
||||
+Updated content
|
||||
"""
|
||||
+Updated content"""
|
||||
|
||||
result = apply_diff(diff, temp_dir)
|
||||
assert result['success'] == True
|
||||
assert 'non_git.txt' in result['applied_files']
|
||||
|
||||
with open(non_git_file, 'r') as f:
|
||||
content = f.read()
|
||||
assert content == 'Updated content'
|
||||
finally:
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
assert content == 'Updated content\n'
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in New Issue