llmcodegen/src/llm_codegen/diff_applier.py

203 lines
7.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Diff 应用模块,使用 unidiff2 解析和应用 unified diff 格式。"""
import os
from typing import List, Dict, Any
from unidiff import PatchSet, Hunk # 需要安装 unidiff2
def _clean_path(path: str) -> str:
    """Strip one leading git-style ``a/`` or ``b/`` prefix from *path*.

    ``a/`` is checked before ``b/``; at most one prefix is removed. Any other
    path (including ``/dev/null``) is returned unchanged.
    """
    for prefix in ('a/', 'b/'):
        if path.startswith(prefix):
            return path[len(prefix):]
    return path
def parse_diff(diff: str) -> List[str]:
    """
    Parse a unified diff and return the paths of the files it creates or modifies.

    Paths come from each file's target (``+++``) header with one leading
    ``a/``/``b/`` prefix removed. Deleted files (target ``/dev/null``) are
    omitted. Each path appears once, in the order it first occurs in the diff.

    Args:
        diff: Unified-diff text.

    Returns:
        List of affected file paths.

    Raises:
        Any exception raised by ``PatchSet`` while parsing propagates unchanged.
    """
    patch_set = PatchSet(diff)
    # dict.fromkeys de-duplicates while keeping first-seen order; a set would
    # return the paths in an arbitrary order.
    return list(dict.fromkeys(
        _clean_path(patch.target_file)
        for patch in patch_set
        if patch.target_file and patch.target_file != '/dev/null'
    ))
def _hunk_side_lines(hunk: Hunk):
    """Split *hunk* into ``(source_lines, target_lines)``, line terminators included.

    Context lines go to both sides, removed lines to the source side and added
    lines to the target side. A ``\\ No newline at end of file`` marker strips
    the terminator from whichever side(s) received the line just before it.
    """
    source_lines: List[str] = []
    target_lines: List[str] = []
    previous_sides = ()
    for line in hunk:
        if line.line_type == '\\':
            for side in previous_sides:
                side[-1] = side[-1].rstrip('\r\n')
            continue
        if line.line_type == ' ':
            previous_sides = (source_lines, target_lines)
        elif line.line_type == '-':
            previous_sides = (source_lines,)
        elif line.line_type == '+':
            previous_sides = (target_lines,)
        else:
            previous_sides = ()
        for side in previous_sides:
            # line.value is the text after the +/-/space marker, including its
            # line terminator; str(line) would also include the marker itself.
            side.append(line.value)
    return source_lines, target_lines


def _apply_single_patch_to_content(file_content_lines: List[str], patch_hunks: List[Hunk]) -> List[str]:
    """
    Apply all hunks of one file's patch to that file's content.

    Args:
        file_content_lines: Current file content as a list of lines, each with
            its line terminator (as produced by ``readlines()``). Not modified.
        patch_hunks: The hunks for this file, in any order.

    Returns:
        A new list of lines with every hunk applied.

    Raises:
        ValueError: If a hunk's removed/context lines do not match the content
            at the position named by the hunk header.
    """
    current_lines = list(file_content_lines)
    # Apply bottom-up so that splicing one hunk never shifts the line numbers
    # of hunks that have not been applied yet.
    for hunk in sorted(patch_hunks, key=lambda h: h.source_start, reverse=True):
        expected_lines, replacement_lines = _hunk_side_lines(hunk)
        # "@@ -N,0 ..." (empty source range) means "insert after line N", so the
        # 0-based splice index is N; otherwise the hunk starts at line N (index N-1).
        start = hunk.source_start - 1 if hunk.source_length else hunk.source_start
        end = start + hunk.source_length

        # Compare content only; line terminators may legitimately differ.
        expected = [line.rstrip('\r\n') for line in expected_lines]
        actual = [line.rstrip('\r\n') for line in current_lines[start:end]]
        if start < 0 or actual != expected:
            raise ValueError(
                f"Hunk at line {hunk.source_start} does not match the source file content. "
                f"Expected: {expected}, Got: {actual}"
            )

        current_lines[start:end] = replacement_lines
    return current_lines
def _resolve_inside(base_dir: str, rel_path: str) -> str:
    """Return the real path of *rel_path* joined onto *base_dir*, refusing escapes.

    Raises:
        ValueError: If the resolved path lies outside *base_dir* — via ``..``
            segments, an absolute *rel_path*, or a symlink pointing elsewhere.
    """
    base = os.path.realpath(base_dir)
    full = os.path.realpath(os.path.join(base, rel_path))
    if os.path.commonpath([base, full]) != base:
        raise ValueError(f"Path escapes target directory: {rel_path}")
    return full


def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]:
    """
    Apply a unified diff to the files under *target_dir*.

    Every affected file is read and patched in memory first. Files are written
    back only after all of them patched cleanly, so a non-matching hunk in any
    file leaves every file untouched. An I/O error during the final write
    phase can still leave a partial result.

    Args:
        diff: Unified-diff text.
        target_dir: Directory the paths in the diff are relative to (default ".").
            Paths that resolve outside this directory are rejected.

    Returns:
        A dict with keys:
        - 'success' (bool): whether the whole diff was applied.
        - 'message' (str): success or error message.
        - 'applied_files' (List[str]): target paths written, in diff order (on success).
        - 'error_details' (str): error text (on failure).
    """
    result: Dict[str, Any] = {
        'success': False,
        'message': '',
        'applied_files': [],
        'error_details': '',
    }

    if not diff or not diff.strip():
        result['message'] = 'Diff string is empty'
        return result

    try:
        patch_set = PatchSet(diff)
    except Exception as e:
        result['message'] = f"Failed to parse diff: {str(e)}"
        result['error_details'] = str(e)
        return result

    if not os.path.isdir(target_dir):
        result['message'] = f"Target directory does not exist: {target_dir}"
        return result

    if not patch_set:
        # An empty PatchSet (e.g. text without any ---/+++ file headers) would
        # otherwise be reported as a successful no-op.
        result['message'] = 'Diff contains no file changes'
        return result

    applied_files: List[str] = []
    # Absolute target path -> patched lines; nothing is written until all patches succeed.
    new_contents: Dict[str, List[str]] = {}
    try:
        for patch_obj in patch_set:
            target_path = _clean_path(patch_obj.target_file)
            if not target_path or target_path == '/dev/null':
                # Deletions (target /dev/null) are skipped, not applied.
                continue
            full_target = _resolve_inside(target_dir, target_path)

            source_path = _clean_path(patch_obj.source_file)
            if source_path == '/dev/null':
                # New file: start from empty content.
                original_lines = []
            else:
                # Read the *source* side so renames (a/old -> b/new) patch the
                # old file's content. The old file itself is left in place.
                full_source = _resolve_inside(target_dir, source_path)
                if full_source in new_contents:
                    # Same file already patched earlier in this diff: build on that.
                    original_lines = new_contents[full_source]
                else:
                    try:
                        # newline='' keeps each line's original terminator (e.g. CRLF),
                        # so untouched lines are written back unchanged.
                        with open(full_source, 'r', encoding='utf-8', newline='') as f:
                            original_lines = f.readlines()
                    except FileNotFoundError:
                        # Treat a missing file as empty: hunks that expect existing
                        # lines fail validation, while "new file" diffs that forgot
                        # /dev/null still apply.
                        original_lines = []

            new_contents[full_target] = _apply_single_patch_to_content(
                original_lines,
                list(patch_obj),  # a patched file is a sequence of Hunk objects
            )
            applied_files.append(target_path)

        # Every patch applied cleanly in memory; now write the results.
        for full_path, lines in new_contents.items():
            os.makedirs(os.path.dirname(full_path), exist_ok=True)
            with open(full_path, 'w', encoding='utf-8', newline='') as f:
                f.writelines(lines)
    except Exception as e:
        # API boundary: report any failure through the result dict.
        result['message'] = f"Error while applying diff: {str(e)}"
        result['error_details'] = str(e)
        return result

    result['success'] = True
    result['message'] = 'Diff applied successfully'
    result['applied_files'] = applied_files
    return result
# Running this module as a script performs a small self-contained demo.
if __name__ == "__main__":
    # Patch a scratch file inside a temporary directory so running the demo
    # never reads or modifies files in the current working directory.
    import tempfile

    sample_diff = """--- a/hello.txt
+++ b/hello.txt
@@ -1 +1 @@
-Hello World
+Hello Universe
"""
    print("Testing apply_diff with unidiff2...")
    with tempfile.TemporaryDirectory() as tmp_dir:
        demo_path = os.path.join(tmp_dir, 'hello.txt')
        with open(demo_path, 'w', encoding='utf-8') as f:
            f.write('Hello World\n')
        res = apply_diff(sample_diff, tmp_dir)
        print(res)
        if res['success']:
            with open(demo_path, 'r', encoding='utf-8') as f:
                print(f.read(), end='')