重构代码以支持并发生成和线程安全状态保存,优化拓扑排序与补丁应用逻辑
This commit is contained in:
parent
6e536f141c
commit
75fc8659a0
|
|
@ -23,9 +23,9 @@
|
||||||
"dependencies": ["src/llm_codegen/core.py", "src/llm_codegen/models.py"],
|
"dependencies": ["src/llm_codegen/core.py", "src/llm_codegen/models.py"],
|
||||||
"functions": [
|
"functions": [
|
||||||
{
|
{
|
||||||
"name": "main",
|
"name": "",
|
||||||
"summary": "主CLI入口,处理命令行参数并启动生成器",
|
"summary": "",
|
||||||
"inputs": ["readme", "output_dir", "api_key", "base_url", "model", "log_file"],
|
"inputs": [],
|
||||||
"outputs": []
|
"outputs": []
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,8 @@ dependencies = [
|
||||||
"loguru>=0.7.0",
|
"loguru>=0.7.0",
|
||||||
"openai>=1.0.0",
|
"openai>=1.0.0",
|
||||||
"pathspec>=1.0.4",
|
"pathspec>=1.0.4",
|
||||||
"gitpython>=3.1.0",
|
"python-patch>=0.0.1",
|
||||||
|
"unidiff2>=0.7.8",
|
||||||
]
|
]
|
||||||
authors = [
|
authors = [
|
||||||
{name = "Your Name", email = "your.email@example.com"}
|
{name = "Your Name", email = "your.email@example.com"}
|
||||||
|
|
@ -33,6 +34,10 @@ dev = [
|
||||||
"black>=23.0.0",
|
"black>=23.0.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[tool.uv.index]]
|
||||||
|
url="https://pypi.tuna.tsinghua.edu.cn/simple"
|
||||||
|
default=true
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
llm-codegen = "llm_codegen.cli:app"
|
llm-codegen = "llm_codegen.cli:app"
|
||||||
|
|
||||||
|
|
@ -45,3 +50,8 @@ max_concurrent_requests = 5
|
||||||
# 新增:指定包所在目录
|
# 新增:指定包所在目录
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
where = ["src"]
|
where = ["src"]
|
||||||
|
|
||||||
|
[dependency-groups]
|
||||||
|
dev = [
|
||||||
|
"pytest>=8.4.2",
|
||||||
|
]
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ import difflib
|
||||||
from typing import List, Dict, Optional, Any, Tuple
|
from typing import List, Dict, Optional, Any, Tuple
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
import threading
|
||||||
|
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskID
|
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskID
|
||||||
|
|
@ -50,6 +51,7 @@ class CodeGenerator:
|
||||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
self.state_file = self.output_dir / ".llm_generator_state.json"
|
self.state_file = self.output_dir / ".llm_generator_state.json"
|
||||||
self.console = Console() # 添加console实例用于rich打印
|
self.console = Console() # 添加console实例用于rich打印
|
||||||
|
self._state_lock = threading.Lock()
|
||||||
|
|
||||||
self.max_concurrency = max_concurrency
|
self.max_concurrency = max_concurrency
|
||||||
|
|
||||||
|
|
@ -171,9 +173,10 @@ class CodeGenerator:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def save_state(self, generated_files: List[str], dependencies_map: Dict[str, List[str]]) -> None:
|
def save_state(self, generated_files: List[str], dependencies_map: Dict[str, List[str]]) -> None:
|
||||||
"""保存断点续写状态,适应并发生成"""
|
"""保存断点续写状态,适应并发生成(线程安全)"""
|
||||||
|
with self._state_lock: # 串行化写入
|
||||||
state = StateModel(
|
state = StateModel(
|
||||||
current_file_index=0, # 在并发中无效,设为0保持兼容
|
current_file_index=0,
|
||||||
generated_files=generated_files,
|
generated_files=generated_files,
|
||||||
dependencies_map=dependencies_map,
|
dependencies_map=dependencies_map,
|
||||||
total_files=len(self.design.files) if self.design else 0,
|
total_files=len(self.design.files) if self.design else 0,
|
||||||
|
|
@ -184,6 +187,7 @@ class CodeGenerator:
|
||||||
json.dump(state.model_dump(), f, indent=2, ensure_ascii=False)
|
json.dump(state.model_dump(), f, indent=2, ensure_ascii=False)
|
||||||
logger.debug(f"状态已保存: {self.state_file}")
|
logger.debug(f"状态已保存: {self.state_file}")
|
||||||
|
|
||||||
|
|
||||||
def get_project_structure(self) -> Tuple[List[str], Dict[str, List[str]]]:
|
def get_project_structure(self) -> Tuple[List[str], Dict[str, List[str]]]:
|
||||||
"""
|
"""
|
||||||
从design.json获取文件列表和依赖关系
|
从design.json获取文件列表和依赖关系
|
||||||
|
|
@ -471,36 +475,36 @@ class CodeGenerator:
|
||||||
def _topological_sort(self, files: List[str], dependencies: Dict[str, List[str]]) -> List[str]:
|
def _topological_sort(self, files: List[str], dependencies: Dict[str, List[str]]) -> List[str]:
|
||||||
"""
|
"""
|
||||||
对文件列表进行拓扑排序,基于依赖关系。
|
对文件列表进行拓扑排序,基于依赖关系。
|
||||||
|
返回排序后的列表,满足每个文件的依赖项都出现在该文件之前。
|
||||||
Args:
|
如果检测到循环依赖,抛出ValueError。
|
||||||
files: 文件路径列表
|
|
||||||
dependencies: 依赖字典,{file: [依赖文件]}
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[str]: 拓扑排序后的文件列表
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 如果检测到循环依赖
|
|
||||||
"""
|
"""
|
||||||
from collections import deque
|
from collections import deque
|
||||||
# 构建图
|
|
||||||
graph = {file: dependencies.get(file, []) for file in files}
|
|
||||||
in_degree = {file: 0 for file in files}
|
|
||||||
for file in files:
|
|
||||||
for dep in graph[file]:
|
|
||||||
if dep in in_degree:
|
|
||||||
in_degree[dep] += 1
|
|
||||||
|
|
||||||
|
# 初始化入度和反向邻接表
|
||||||
|
in_degree = {f: 0 for f in files}
|
||||||
|
rev_graph = {f: [] for f in files} # 记录哪些文件依赖于f
|
||||||
|
|
||||||
|
# 构建图:如果文件f依赖于dep,则增加f的入度,并将f加入rev_graph[dep]
|
||||||
|
for f in files:
|
||||||
|
for dep in dependencies.get(f, []):
|
||||||
|
if dep in files: # 只考虑在files中的依赖
|
||||||
|
in_degree[f] += 1 # f依赖于dep,所以f的入度增加
|
||||||
|
rev_graph[dep].append(f) # dep被f依赖
|
||||||
|
|
||||||
|
# 队列初始化为入度为0的文件(无依赖的文件)
|
||||||
queue = deque([f for f in files if in_degree[f] == 0])
|
queue = deque([f for f in files if in_degree[f] == 0])
|
||||||
sorted_files = []
|
sorted_files = []
|
||||||
|
|
||||||
while queue:
|
while queue:
|
||||||
node = queue.popleft()
|
node = queue.popleft()
|
||||||
sorted_files.append(node)
|
sorted_files.append(node)
|
||||||
for dep in graph[node]:
|
# 所有依赖于node的文件入度减1
|
||||||
in_degree[dep] -= 1
|
for dependent in rev_graph[node]:
|
||||||
if in_degree[dep] == 0:
|
in_degree[dependent] -= 1
|
||||||
queue.append(dep)
|
if in_degree[dependent] == 0:
|
||||||
|
queue.append(dependent)
|
||||||
|
|
||||||
|
# 检查是否所有文件都已排序(无循环依赖)
|
||||||
if len(sorted_files) != len(files):
|
if len(sorted_files) != len(files):
|
||||||
raise ValueError(f"检测到循环依赖,排序失败。已排序 {len(sorted_files)} 个文件,总共 {len(files)} 个文件。")
|
raise ValueError(f"检测到循环依赖,排序失败。已排序 {len(sorted_files)} 个文件,总共 {len(files)} 个文件。")
|
||||||
|
|
||||||
|
|
@ -671,6 +675,7 @@ class CodeGenerator:
|
||||||
done, not_done = concurrent.futures.wait(futures.keys(), return_when=concurrent.futures.FIRST_COMPLETED, timeout=1.0)
|
done, not_done = concurrent.futures.wait(futures.keys(), return_when=concurrent.futures.FIRST_COMPLETED, timeout=1.0)
|
||||||
for future in done:
|
for future in done:
|
||||||
file = futures.pop(future)
|
file = futures.pop(future)
|
||||||
|
try:
|
||||||
success, error_msg = future.result()
|
success, error_msg = future.result()
|
||||||
# 更新文件进度任务
|
# 更新文件进度任务
|
||||||
if file in file_tasks:
|
if file in file_tasks:
|
||||||
|
|
@ -697,6 +702,22 @@ class CodeGenerator:
|
||||||
logger.error(f"文件 {file} 生成失败,错误: {error_msg}")
|
logger.error(f"文件 {file} 生成失败,错误: {error_msg}")
|
||||||
self.console.print(f"[bold red]❌ 文件 {file} 生成失败,错误: {error_msg}[/bold red]")
|
self.console.print(f"[bold red]❌ 文件 {file} 生成失败,错误: {error_msg}[/bold red]")
|
||||||
# 错误处理:继续处理其他文件,但记录失败
|
# 错误处理:继续处理其他文件,但记录失败
|
||||||
|
except Exception as e:
|
||||||
|
# 捕获 Future 中存储的异常
|
||||||
|
logger.error(f"任务 {file} 执行时发生异常: {e}")
|
||||||
|
self.console.print(f"[bold red]❌ 任务 {file} 执行时发生异常: {e}[/bold red]")
|
||||||
|
# 将其视为失败
|
||||||
|
success = False
|
||||||
|
error_msg = str(e)
|
||||||
|
# 然后执行和上面 `else` 分支相同的失败处理逻辑
|
||||||
|
if file in file_tasks:
|
||||||
|
progress.update(file_tasks[file], description=f"生成失败: {file}")
|
||||||
|
progress.remove_task(file_tasks[file])
|
||||||
|
del file_tasks[file] # 清理映射
|
||||||
|
logger.error(f"文件 {file} 生成失败,错误: {error_msg}")
|
||||||
|
self.console.print(f"[bold red]❌ 文件 {file} 生成失败,错误: {error_msg}[/bold red]")
|
||||||
|
# 错误处理:继续处理其他文件,但记录失败
|
||||||
|
|
||||||
|
|
||||||
logger.success("所有文件处理完成!")
|
logger.success("所有文件处理完成!")
|
||||||
# 清理状态文件
|
# 清理状态文件
|
||||||
|
|
@ -708,7 +729,6 @@ class CodeGenerator:
|
||||||
logger.error(f"清理状态文件失败: {e}")
|
logger.error(f"清理状态文件失败: {e}")
|
||||||
self.console.print(f"[bold red]❌ 清理状态文件失败: {e}[/bold red]")
|
self.console.print(f"[bold red]❌ 清理状态文件失败: {e}[/bold red]")
|
||||||
|
|
||||||
|
|
||||||
def process_issue(self, issue_content: str, issue_type: str) -> bool:
|
def process_issue(self, issue_content: str, issue_type: str) -> bool:
|
||||||
"""
|
"""
|
||||||
处理需求增强或 Bug 修复工单
|
处理需求增强或 Bug 修复工单
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,12 @@
|
||||||
"""
|
"""Diff 应用模块,使用 unidiff2 解析和应用 unified diff 格式。"""
|
||||||
Diff 应用模块,使用 GitPython 解析和应用 unified diff 格式。
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any
|
||||||
import git # 需要安装 GitPython
|
from unidiff import PatchSet, Hunk # 需要安装 unidiff2
|
||||||
|
|
||||||
|
|
||||||
def parse_diff(diff: str) -> List[str]:
|
def parse_diff(diff: str) -> List[str]:
|
||||||
"""
|
"""
|
||||||
解析 unified diff 字符串,提取受影响的文件路径。
|
解析 unified diff 字符串,提取受影响的文件路径。
|
||||||
|
此函数使用 unidiff2 库来解析 diff。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
diff: unified diff 格式的字符串。
|
diff: unified diff 格式的字符串。
|
||||||
|
|
@ -18,27 +14,74 @@ def parse_diff(diff: str) -> List[str]:
|
||||||
Returns:
|
Returns:
|
||||||
文件路径列表。
|
文件路径列表。
|
||||||
"""
|
"""
|
||||||
files = set()
|
try:
|
||||||
for line in diff.split('\n'):
|
patch_set = PatchSet(diff)
|
||||||
if line.startswith('--- a/'):
|
# unidiff2 中的 patch 对象有一个 path 属性,代表修改的文件
|
||||||
# 提取旧文件路径
|
# 使用集合去重
|
||||||
path = line[6:].strip()
|
files = {patch.target_file for patch in patch_set} # 使用target_file获取修改后的文件名
|
||||||
if path and path != '/dev/null': # /dev/null 表示新文件
|
# 过滤掉 /dev/null
|
||||||
files.add(path)
|
return [f.lstrip('/') for f in files if f != '/dev/null' and not f.endswith('/dev/null')]
|
||||||
elif line.startswith('+++ b/'):
|
except Exception as e:
|
||||||
# 提取新文件路径
|
# 如果解析失败,抛出异常
|
||||||
path = line[6:].strip()
|
raise e
|
||||||
if path and path != '/dev/null': # /dev/null 表示删除文件
|
|
||||||
files.add(path)
|
def _apply_single_patch_to_content(file_content_lines: List[str], patch_hunks: List[Hunk]) -> List[str]:
|
||||||
return list(files)
|
"""
|
||||||
|
将一个文件的补丁(多个hunk)应用到其内容上。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_content_lines: 文件内容的行列表。
|
||||||
|
patch_hunks: 针对该文件的一个或多个Hunk对象列表。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
应用了补丁后的新内容行列表。
|
||||||
|
"""
|
||||||
|
# 为了正确应用多个 hunk,必须从后往前处理,这样前面的修改才不会影响后面 hunk 的行号
|
||||||
|
sorted_hunks = sorted(patch_hunks, key=lambda x: x.source_start, reverse=True)
|
||||||
|
|
||||||
|
current_lines = file_content_lines[:]
|
||||||
|
|
||||||
|
for hunk in sorted_hunks:
|
||||||
|
source_start = hunk.source_start - 1 # 转换为0索引
|
||||||
|
source_len = hunk.source_length
|
||||||
|
|
||||||
|
# 验证源文件内容是否与diff中的源行匹配 (这是一个简化的验证)
|
||||||
|
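source_lines_from_diff = [str(l).strip() for l in hunk.source_lines()] # NOTE(review): strip() only trims surrounding whitespace; it does not remove a leading '-' marker produced by str(line) — confirm against unidiff's Line.__str__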
|
||||||
|
actual_source_lines = current_lines[source_start : source_start + source_len]
|
||||||
|
|
||||||
|
# 提取实际源行和上下文/删除行进行比较,rstrip换行符
|
||||||
|
actual_source_for_comparison = [line.rstrip('\n\r') for line in actual_source_lines]
|
||||||
|
|
||||||
|
if source_lines_from_diff != actual_source_for_comparison:
|
||||||
|
raise ValueError(f"Hunk at line {hunk.source_start} does not match the source file content. Expected: {source_lines_from_diff}, Got: {actual_source_for_comparison}")
|
||||||
|
|
||||||
|
# 构建新的内容部分
|
||||||
|
new_part = []
|
||||||
|
for line_obj in hunk:
|
||||||
|
if line_obj.is_added or line_obj.is_context:
|
||||||
|
# 添加新增行或上下文行
|
||||||
|
# unidiff2的line对象转字符串会带有 +/-/ 等符号,我们需要获取实际内容
|
||||||
|
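# NOTE(review): in unidiff, Line.value appears to hold the content without the marker and Line.line_type the marker itself; no Line.text attribute is documented — confirm for unidiff2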
|
||||||
|
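# NOTE(review): strip() removes only surrounding whitespace (including indentation), not a leading '+'/'-' marker — verify this yields the intended line content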
|
||||||
|
clean_line = str(line_obj).strip()
|
||||||
|
# 如果原文件行尾有换行符,我们也要加上
|
||||||
|
original_trailing = ""
|
||||||
|
if current_lines and current_lines[0].endswith('\n'):
|
||||||
|
original_trailing = "\n"
|
||||||
|
|
||||||
|
new_part.append(clean_line + original_trailing)
|
||||||
|
|
||||||
|
# 替换原内容
|
||||||
|
new_lines = current_lines[:source_start] + new_part + current_lines[source_start + source_len:]
|
||||||
|
current_lines = new_lines
|
||||||
|
|
||||||
|
return current_lines
|
||||||
|
|
||||||
|
|
||||||
def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]:
|
def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
应用 unified diff 到指定目录,使用 GitPython 库。
|
应用 unified diff 到指定目录。
|
||||||
|
该函数解析 diff,读取磁盘上的文件,应用更改,并写回文件。
|
||||||
该函数解析 diff 并尝试应用更改到文件系统。如果目标目录不是 git 仓库,
|
|
||||||
将尝试初始化一个临时仓库或报告错误。
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
diff: unified diff 格式的字符串。
|
diff: unified diff 格式的字符串。
|
||||||
|
|
@ -64,9 +107,11 @@ def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]:
|
||||||
result['message'] = 'Diff string is empty'
|
result['message'] = 'Diff string is empty'
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# 解析 diff 获取文件列表
|
# 解析 diff 获取 PatchSet 对象
|
||||||
try:
|
try:
|
||||||
affected_files = parse_diff_files(diff)
|
patch_set = PatchSet(diff)
|
||||||
|
# 收集所有需要修改的目标文件路径
|
||||||
|
affected_files = [p.target_file.lstrip('/') for p in patch_set if p.target_file != '/dev/null']
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
result['message'] = f"Failed to parse diff: {str(e)}"
|
result['message'] = f"Failed to parse diff: {str(e)}"
|
||||||
result['error_details'] = str(e)
|
result['error_details'] = str(e)
|
||||||
|
|
@ -78,39 +123,57 @@ def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 尝试获取或初始化 git 仓库
|
# 遍历每个 patch (即每个文件的变更)
|
||||||
|
for patch_obj in patch_set:
|
||||||
|
# 获取目标文件路径 (修改后的文件名)
|
||||||
|
target_path = patch_obj.target_file.lstrip('/')
|
||||||
|
if not target_path or target_path == '/dev/null':
|
||||||
|
# 如果目标是 /dev/null,则是删除操作,我们跳过或记录
|
||||||
|
continue
|
||||||
|
|
||||||
|
full_file_path = os.path.join(target_dir, target_path)
|
||||||
|
|
||||||
|
# 检查源文件是否存在 (对于新增操作,源文件可能是 /dev/null)
|
||||||
|
source_path = patch_obj.source_file
|
||||||
|
if source_path == '/dev/null':
|
||||||
|
# 源是 /dev/null,说明这是一个新文件,内容从空开始
|
||||||
|
original_content_lines = []
|
||||||
|
else:
|
||||||
|
# 源是普通文件,尝试读取
|
||||||
try:
|
try:
|
||||||
repo = git.Repo(target_dir)
|
with open(full_file_path, 'r', encoding='utf-8') as f:
|
||||||
except git.exc.InvalidGitRepositoryError:
|
original_content_lines = f.readlines()
|
||||||
# 如果不是 git 仓库,初始化一个
|
except FileNotFoundError:
|
||||||
repo = git.Repo.init(target_dir)
|
# 如果文件不存在,但diff不是新文件,这可能是个错误
|
||||||
# 添加所有现有文件到索引,以便应用 diff
|
# 但对于新文件,这种情况也可能发生,因为路径可能不同
|
||||||
repo.git.add('--all')
|
# 我们继续,但要注意如果内容不匹配会报错
|
||||||
|
original_content_lines = []
|
||||||
|
|
||||||
# 应用 diff
|
# 应用此文件的补丁 (所有 hunks)
|
||||||
# 使用 git apply 命令,通过 stdin 传入 diff
|
modified_content_lines = _apply_single_patch_to_content(
|
||||||
# '--whitespace=nowarn' 忽略空白警告
|
original_content_lines,
|
||||||
output = repo.git.apply('--whitespace=nowarn', '--', input=diff)
|
list(patch_obj) # patch_obj 本身就是一个 Hunk 对象的迭代器
|
||||||
|
)
|
||||||
|
|
||||||
# 如果成功,更新结果
|
# 将修改后的内容写回文件
|
||||||
|
# 确保目录存在
|
||||||
|
os.makedirs(os.path.dirname(full_file_path), exist_ok=True)
|
||||||
|
|
||||||
|
with open(full_file_path, 'w', encoding='utf-8', newline='') as f:
|
||||||
|
f.writelines(modified_content_lines)
|
||||||
|
|
||||||
|
# 如果所有 patch 都成功应用
|
||||||
result['success'] = True
|
result['success'] = True
|
||||||
result['message'] = 'Diff applied successfully'
|
result['message'] = 'Diff applied successfully'
|
||||||
result['applied_files'] = affected_files
|
result['applied_files'] = affected_files
|
||||||
if output:
|
|
||||||
result['message'] += f". Output: {output}"
|
|
||||||
|
|
||||||
except git.exc.GitError as e:
|
|
||||||
# 处理 git 错误,如冲突、格式错误等
|
|
||||||
result['message'] = f"Git error while applying diff: {str(e)}"
|
|
||||||
result['error_details'] = str(e)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 处理其他异常
|
# 处理应用过程中可能出现的任何错误
|
||||||
result['message'] = f"Unexpected error: {str(e)}"
|
result['message'] = f"Error while applying diff: {str(e)}"
|
||||||
result['error_details'] = str(e)
|
result['error_details'] = str(e)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
# 如果作为脚本运行,可以提供简单的测试
|
# 如果作为脚本运行,可以提供简单的测试
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 示例用法
|
# 示例用法
|
||||||
|
|
@ -120,6 +183,7 @@ if __name__ == "__main__":
|
||||||
-Hello World
|
-Hello World
|
||||||
+Hello Universe
|
+Hello Universe
|
||||||
"""
|
"""
|
||||||
print("Testing apply_diff...")
|
|
||||||
|
print("Testing apply_diff with unidiff2...")
|
||||||
res = apply_diff(sample_diff, ".")
|
res = apply_diff(sample_diff, ".")
|
||||||
print(res)
|
print(res)
|
||||||
|
|
|
||||||
|
|
@ -1,23 +1,20 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""
|
"""Unit tests for diff_applier.py, covering various scenarios such as new file creation,
|
||||||
Unit tests for diff_applier.py, covering various scenarios such as new file creation,
|
modification of existing files, conflict handling, and error cases."""
|
||||||
modification of existing files, conflict handling, and error cases.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import shutil
|
import shutil
|
||||||
import pytest
|
import pytest
|
||||||
import git # GitPython is required; assumed installed via project dependencies
|
|
||||||
|
|
||||||
# Add src directory to path for module import
|
# Add src directory to path for module import
|
||||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
||||||
|
|
||||||
from llm_codegen.diff_applier import parse_diff_files, apply_diff
|
# Import the functions under test from the src.llm_codegen.diff_applier module
|
||||||
|
from src.llm_codegen.diff_applier import parse_diff, apply_diff
|
||||||
|
|
||||||
|
|
||||||
def test_parse_diff_files():
|
def test_parse_diff():
|
||||||
"""Test parsing unified diff strings to extract file paths."""
|
"""Test parsing unified diff strings to extract file paths."""
|
||||||
# Diff with modification and new file
|
# Diff with modification and new file
|
||||||
diff = """--- a/file1.txt
|
diff = """--- a/file1.txt
|
||||||
|
|
@ -25,128 +22,102 @@ def test_parse_diff_files():
|
||||||
@@ -1 +1 @@
|
@@ -1 +1 @@
|
||||||
-old content
|
-old content
|
||||||
+new content
|
+new content
|
||||||
--- a/new_file.txt
|
--- /dev/null
|
||||||
+++ b/new_file.txt
|
+++ b/new_file.txt
|
||||||
@@ -0,0 +1 @@
|
@@ -0,0 +1 @@
|
||||||
+new file content
|
+new file content"""
|
||||||
"""
|
files = parse_diff(diff)
|
||||||
files = parse_diff_files(diff)
|
|
||||||
assert set(files) == {'file1.txt', 'new_file.txt'}
|
assert set(files) == {'file1.txt', 'new_file.txt'}
|
||||||
|
|
||||||
# Diff with file deletion
|
# Diff with file deletion
|
||||||
diff_del = """--- a/deleted.txt
|
diff_del = """--- a/deleted.txt
|
||||||
+++ /dev/null
|
+++ /dev/null
|
||||||
@@ -1 +0,0 @@
|
@@ -1 +0,0 @@
|
||||||
-content to delete
|
-content to delete"""
|
||||||
"""
|
files = parse_diff(diff_del)
|
||||||
files = parse_diff_files(diff_del)
|
# 删除操作的解析结果应为空,因为没有实际的“目标”文件被创建或修改
|
||||||
assert files == ['deleted.txt'] # Old file path is extracted
|
assert files == []
|
||||||
|
|
||||||
# Empty diff
|
# Empty diff
|
||||||
assert parse_diff_files('') == []
|
assert parse_diff('') == []
|
||||||
assert parse_diff_files('\n') == []
|
assert parse_diff('\n') == []
|
||||||
|
|
||||||
# Diff with only new file (no old file)
|
# Diff with only new file (no old file)
|
||||||
diff_new_only = """--- /dev/null
|
diff_new_only = """--- /dev/null
|
||||||
+++ b/only_new.txt
|
+++ b/only_new.txt
|
||||||
@@ -0,0 +1 @@
|
@@ -0,0 +1 @@
|
||||||
+only new
|
+only new"""
|
||||||
"""
|
files = parse_diff(diff_new_only)
|
||||||
files = parse_diff_files(diff_new_only)
|
assert files == ['only_new.txt']
|
||||||
assert files == [] /dev/null is ignored
|
|
||||||
|
|
||||||
# Invalid diff format (should still handle gracefully)
|
# Invalid diff format (should still handle gracefully)
|
||||||
diff_invalid = "invalid diff string"
|
diff_invalid = "invalid diff string"
|
||||||
files = parse_diff_files(diff_invalid)
|
files = parse_diff(diff_invalid)
|
||||||
assert files == [] # No valid file paths found
|
assert files == [] # No valid file paths found
|
||||||
|
|
||||||
|
def test_apply_diff_new_file():
|
||||||
@pytest.fixture
|
|
||||||
def temp_git_repo():
|
|
||||||
"""Create a temporary git repository for testing apply_diff."""
|
|
||||||
temp_dir = tempfile.mkdtemp()
|
|
||||||
repo = git.Repo.init(temp_dir)
|
|
||||||
# Make an initial commit to have a clean state
|
|
||||||
init_file = os.path.join(temp_dir, 'README.md')
|
|
||||||
with open(init_file, 'w') as f:
|
|
||||||
f.write('Initial commit content')
|
|
||||||
repo.git.add('README.md')
|
|
||||||
repo.git.commit('-m', 'initial commit')
|
|
||||||
yield temp_dir, repo
|
|
||||||
# Cleanup after test
|
|
||||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
|
||||||
|
|
||||||
|
|
||||||
def test_apply_diff_new_file(temp_git_repo):
|
|
||||||
"""Test applying a diff that creates a new file."""
|
"""Test applying a diff that creates a new file."""
|
||||||
temp_dir, repo = temp_git_repo
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
diff = """--- /dev/null
|
diff = """--- /dev/null
|
||||||
+++ b/test_new.txt
|
+++ b/test_new.txt
|
||||||
@@ -0,0 +1 @@
|
@@ -0,0 +1 @@
|
||||||
+This is a newly created file.
|
+This is a newly created file."""
|
||||||
"""
|
|
||||||
result = apply_diff(diff, temp_dir)
|
result = apply_diff(diff, temp_dir)
|
||||||
assert result['success'] == True
|
assert result['success'] == True
|
||||||
assert 'test_new.txt' in result['applied_files']
|
assert 'test_new.txt' in result['applied_files']
|
||||||
|
|
||||||
new_file_path = os.path.join(temp_dir, 'test_new.txt')
|
new_file_path = os.path.join(temp_dir, 'test_new.txt')
|
||||||
assert os.path.exists(new_file_path)
|
assert os.path.exists(new_file_path)
|
||||||
with open(new_file_path, 'r') as f:
|
with open(new_file_path, 'r') as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
|
# 修复断言,strip()会去掉行末的换行符,使内容匹配
|
||||||
assert content.strip() == 'This is a newly created file.'
|
assert content.strip() == 'This is a newly created file.'
|
||||||
|
|
||||||
|
def test_apply_diff_modify_existing_file():
|
||||||
def test_apply_diff_modify_existing_file(temp_git_repo):
|
|
||||||
"""Test applying a diff that modifies an existing file."""
|
"""Test applying a diff that modifies an existing file."""
|
||||||
temp_dir, repo = temp_git_repo
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
# Create an existing file and commit it
|
# Create an existing file
|
||||||
existing_file = os.path.join(temp_dir, 'existing.txt')
|
existing_file = os.path.join(temp_dir, 'existing.txt')
|
||||||
with open(existing_file, 'w') as f:
|
with open(existing_file, 'w') as f:
|
||||||
f.write('Original line 1\nOriginal line 2')
|
f.write('Original line 1\nOriginal line 2\n')
|
||||||
repo.git.add('existing.txt')
|
|
||||||
repo.git.commit('-m', 'add existing file')
|
|
||||||
|
|
||||||
diff = """--- a/existing.txt
|
diff = """--- a/existing.txt
|
||||||
+++ b/existing.txt
|
+++ b/existing.txt
|
||||||
@@ -1,2 +1,2 @@
|
@@ -1,2 +1,2 @@
|
||||||
-Original line 1
|
-Original line 1
|
||||||
+Modified line 1
|
+Modified line 1
|
||||||
Original line 2
|
Original line 2"""
|
||||||
"""
|
|
||||||
result = apply_diff(diff, temp_dir)
|
result = apply_diff(diff, temp_dir)
|
||||||
assert result['success'] == True
|
assert result['success'] == True
|
||||||
assert 'existing.txt' in result['applied_files']
|
assert 'existing.txt' in result['applied_files']
|
||||||
|
|
||||||
with open(existing_file, 'r') as f:
|
with open(existing_file, 'r') as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
assert content == 'Modified line 1\nOriginal line 2'
|
assert content == 'Modified line 1\nOriginal line 2\n'
|
||||||
|
|
||||||
|
def test_apply_diff_conflict_handling():
|
||||||
def test_apply_diff_conflict_handling(temp_git_repo):
|
"""Test applying a diff that causes a conflict because the source doesn't match."""
|
||||||
"""Test applying a diff that causes a conflict with uncommitted changes."""
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
temp_dir, repo = temp_git_repo
|
# Create a file with specific content
|
||||||
# Create and commit a file
|
|
||||||
conflict_file = os.path.join(temp_dir, 'conflict.txt')
|
conflict_file = os.path.join(temp_dir, 'conflict.txt')
|
||||||
with open(conflict_file, 'w') as f:
|
with open(conflict_file, 'w') as f:
|
||||||
f.write('Initial line 1\nInitial line 2')
|
f.write('Different line 1\nOriginal line 2\n') # This is different from what diff expects
|
||||||
repo.git.add('conflict.txt')
|
|
||||||
repo.git.commit('-m', 'add conflict file')
|
|
||||||
|
|
||||||
# Modify the file without committing to create conflict
|
|
||||||
with open(conflict_file, 'w') as f:
|
|
||||||
f.write('Changed line 1\nInitial line 2') # Change first line
|
|
||||||
|
|
||||||
|
# This diff expects 'Initial line 1' but finds 'Different line 1'
|
||||||
diff = """--- a/conflict.txt
|
diff = """--- a/conflict.txt
|
||||||
+++ b/conflict.txt
|
+++ b/conflict.txt
|
||||||
@@ -1,2 +1,2 @@
|
@@ -1,2 +1,2 @@
|
||||||
-Initial line 1
|
-Initial line 1
|
||||||
+Diff line 1
|
+Diff line 1
|
||||||
Initial line 2
|
Initial line 2"""
|
||||||
"""
|
|
||||||
result = apply_diff(diff, temp_dir)
|
|
||||||
assert result['success'] == False
|
|
||||||
# Check for conflict or error in message
|
|
||||||
assert 'conflict' in result['message'].lower() or 'error' in result['message'].lower()
|
|
||||||
assert result['error_details'] != ''
|
|
||||||
|
|
||||||
|
result = apply_diff(diff, temp_dir)
|
||||||
|
assert result['success'] == False # Should fail due to mismatch
|
||||||
|
# Check for conflict or error in message
|
||||||
|
assert 'error' in result['message'].lower() or 'does not match' in result['message'].lower()
|
||||||
|
assert result['error_details'] != ''
|
||||||
|
|
||||||
def test_apply_diff_empty_diff():
|
def test_apply_diff_empty_diff():
|
||||||
"""Test applying an empty diff string."""
|
"""Test applying an empty diff string."""
|
||||||
|
|
@ -154,7 +125,6 @@ def test_apply_diff_empty_diff():
|
||||||
assert result['success'] == False
|
assert result['success'] == False
|
||||||
assert 'empty' in result['message'].lower()
|
assert 'empty' in result['message'].lower()
|
||||||
|
|
||||||
|
|
||||||
def test_apply_diff_invalid_directory():
|
def test_apply_diff_invalid_directory():
|
||||||
"""Test applying a diff to a non-existent directory."""
|
"""Test applying a diff to a non-existent directory."""
|
||||||
non_existent_dir = '/tmp/non_existent_dir_12345'
|
non_existent_dir = '/tmp/non_existent_dir_12345'
|
||||||
|
|
@ -162,36 +132,33 @@ def test_apply_diff_invalid_directory():
|
||||||
+++ b/dummy.txt
|
+++ b/dummy.txt
|
||||||
@@ -1 +1 @@
|
@@ -1 +1 @@
|
||||||
-old
|
-old
|
||||||
+new
|
+new"""
|
||||||
"""
|
|
||||||
result = apply_diff(diff, non_existent_dir)
|
result = apply_diff(diff, non_existent_dir)
|
||||||
assert result['success'] == False
|
assert result['success'] == False
|
||||||
assert 'does not exist' in result['message'].lower()
|
assert 'does not exist' in result['message'].lower()
|
||||||
|
|
||||||
|
|
||||||
def test_apply_diff_no_git_repo_initialization():
|
def test_apply_diff_no_git_repo_initialization():
|
||||||
"""Test applying a diff to a non-git directory, which should initialize a repo."""
|
"""Test applying a diff to a non-git directory. This test is now redundant as there's no git dependency."""
|
||||||
temp_dir = tempfile.mkdtemp()
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
try:
|
|
||||||
# Create a non-git directory with a file
|
# Create a non-git directory with a file
|
||||||
non_git_file = os.path.join(temp_dir, 'non_git.txt')
|
non_git_file = os.path.join(temp_dir, 'non_git.txt')
|
||||||
with open(non_git_file, 'w') as f:
|
with open(non_git_file, 'w') as f:
|
||||||
f.write('Pre-existing content')
|
f.write('Pre-existing content\n')
|
||||||
|
|
||||||
diff = """--- a/non_git.txt
|
diff = """--- a/non_git.txt
|
||||||
+++ b/non_git.txt
|
+++ b/non_git.txt
|
||||||
@@ -1 +1 @@
|
@@ -1 +1 @@
|
||||||
-Pre-existing content
|
-Pre-existing content
|
||||||
+Updated content
|
+Updated content"""
|
||||||
"""
|
|
||||||
result = apply_diff(diff, temp_dir)
|
result = apply_diff(diff, temp_dir)
|
||||||
assert result['success'] == True
|
assert result['success'] == True
|
||||||
assert 'non_git.txt' in result['applied_files']
|
assert 'non_git.txt' in result['applied_files']
|
||||||
|
|
||||||
with open(non_git_file, 'r') as f:
|
with open(non_git_file, 'r') as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
assert content == 'Updated content'
|
assert content == 'Updated content\n'
|
||||||
finally:
|
|
||||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue