203 lines
7.5 KiB
Python
203 lines
7.5 KiB
Python
"""Diff 应用模块,使用 unidiff2 解析和应用 unified diff 格式。"""
|
||
import os
|
||
from typing import List, Dict, Any
|
||
from unidiff import PatchSet, Hunk # 需要安装 unidiff2
|
||
|
||
def _clean_path(path: str) -> str:
|
||
"""清理路径,移除 a/ 或 b/ 前缀。"""
|
||
if path.startswith('a/'):
|
||
return path[2:]
|
||
if path.startswith('b/'):
|
||
return path[2:]
|
||
return path
|
||
|
||
def parse_diff(diff: str) -> List[str]:
    """
    Parse a unified diff string and return the paths of the files it affects.

    Only target (post-change / newly created) paths are reported; patches
    whose target is ``/dev/null`` (deletions) are omitted. A leading ``a/`` or
    ``b/`` prefix is stripped from each path. Each path appears once, in the
    order it first occurs in the diff.

    Args:
        diff: A string in unified diff format.

    Returns:
        List of cleaned target file paths.

    Raises:
        Propagates any exception raised by ``PatchSet`` while parsing *diff*.
    """
    patch_set = PatchSet(diff)
    # dict.fromkeys de-duplicates while preserving first-seen order, so the
    # result is deterministic (a plain set() gave an arbitrary order).
    return list(dict.fromkeys(
        _clean_path(patch.target_file)
        for patch in patch_set
        if patch.target_file and patch.target_file != '/dev/null'
    ))
|
||
|
||
def _apply_single_patch_to_content(file_content_lines: List[str], patch_hunks: List[Hunk]) -> List[str]:
|
||
"""
|
||
将一个文件的补丁(多个hunk)应用到其内容上。
|
||
|
||
Args:
|
||
file_content_lines: 文件内容的行列表。
|
||
patch_hunks: 针对该文件的一个或多个Hunk对象列表。
|
||
|
||
Returns:
|
||
应用了补丁后的新内容行列表。
|
||
"""
|
||
# 为了正确应用多个 hunk,必须从后往前处理,这样前面的修改才不会影响后面 hunk 的行号
|
||
sorted_hunks = sorted(patch_hunks, key=lambda x: x.source_start, reverse=True)
|
||
|
||
current_lines = file_content_lines[:]
|
||
|
||
for hunk in sorted_hunks:
|
||
source_start = hunk.source_start - 1 # 转换为0索引
|
||
source_len = hunk.source_length
|
||
|
||
# 验证源文件内容是否与diff中的源行匹配 (这是一个简化的验证)
|
||
# source_lines() 包含了删除行(-)和上下文行( )
|
||
source_lines_from_diff = []
|
||
for line in hunk.source_lines():
|
||
# unidiff2的line对象转字符串会带有 +/-/ 等符号,需要strip掉
|
||
source_lines_from_diff.append(str(line).strip())
|
||
|
||
actual_source_lines = current_lines[source_start : source_start + source_len]
|
||
actual_source_for_comparison = [line.rstrip('\n\r') for line in actual_source_lines]
|
||
|
||
if source_lines_from_diff != actual_source_for_comparison:
|
||
raise ValueError(f"Hunk at line {hunk.source_start} does not match the source file content. Expected: {source_lines_from_diff}, Got: {actual_source_for_comparison}")
|
||
|
||
# 构建新的内容部分 (target lines)
|
||
# target_lines() 包含了新增行(+)和上下文行( )
|
||
new_part = []
|
||
for line_obj in hunk.target_lines():
|
||
# 获取实际内容,strip掉符号,并保持原有的换行符风格
|
||
clean_line = str(line_obj).strip()
|
||
# 如果原文件有换行符,则恢复它
|
||
if current_lines and current_lines[0].endswith('\n'):
|
||
original_trailing = "\n"
|
||
else:
|
||
original_trailing = ""
|
||
new_part.append(clean_line + original_trailing)
|
||
|
||
# 替换原内容
|
||
new_lines = current_lines[:source_start] + new_part + current_lines[source_start + source_len:]
|
||
current_lines = new_lines
|
||
|
||
return current_lines
|
||
|
||
|
||
def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]:
    """
    Apply a unified diff to the files under a directory.

    The diff is parsed, each affected file is read from disk, every hunk is
    verified against the current content and applied, and the results are
    written back. All files are patched in memory first and written only
    after every hunk has applied cleanly, so a conflict in a later file does
    not leave earlier files modified. (A failure during the final write
    phase itself, e.g. a permission error, can still leave a partial result.)

    Paths in the diff (after stripping ``a/`` / ``b/``) are resolved relative
    to ``target_dir``. A path that would resolve outside it (absolute path,
    ``..`` components, or a symlink pointing elsewhere) is rejected.

    Patches whose target is ``/dev/null`` (deletions) are skipped. When a
    patch's source and target paths differ, the source file is read (falling
    back to the target path if the source does not exist) and the result is
    written to the target path; the source file itself is left untouched.

    NOTE(review): files are read in text mode with universal newlines, so a
    CRLF file is written back with LF line endings.

    Args:
        diff: A string in unified diff format.
        target_dir: Directory the diff's paths are relative to. Defaults to
            the current directory.

    Returns:
        A dict with the keys:
        - 'success' (bool): Whether the diff was applied.
        - 'message' (str): Success or error message.
        - 'applied_files' (List[str]): Target paths written (on success).
        - 'error_details' (str): Detailed error text (on failure).
    """
    result: Dict[str, Any] = {
        'success': False,
        'message': '',
        'applied_files': [],
        'error_details': '',
    }

    if not diff or not diff.strip():
        result['message'] = 'Diff string is empty'
        return result

    try:
        patch_set = PatchSet(diff)
    except Exception as e:  # report any parse failure instead of raising
        result['message'] = f"Failed to parse diff: {str(e)}"
        result['error_details'] = str(e)
        return result

    affected_files = [
        _clean_path(patch.target_file)
        for patch in patch_set
        if patch.target_file and patch.target_file != '/dev/null'
    ]

    if not os.path.isdir(target_dir):
        result['message'] = f"Target directory does not exist: {target_dir}"
        return result

    try:
        root = os.path.realpath(target_dir)

        # Phase 1: compute the new content of every file without touching disk.
        # Keyed by resolved path; insertion order is the write order.
        staged: Dict[str, List[str]] = {}
        for patch_obj in patch_set:
            target_path = _clean_path(patch_obj.target_file)
            if not target_path or target_path == '/dev/null':
                # Deletion: not supported, skipped.
                continue
            full_target = _resolve_inside(root, target_path)

            source_path = _clean_path(patch_obj.source_file)
            if source_path == '/dev/null':
                # New file: start from empty content.
                original_content_lines: List[str] = []
            else:
                original_content_lines = _load_lines(
                    [_resolve_inside(root, source_path), full_target], staged
                )

            staged[full_target] = _apply_single_patch_to_content(
                original_content_lines, list(patch_obj)
            )

        # Phase 2: every hunk applied cleanly, now write the results.
        for full_path, lines in staged.items():
            os.makedirs(os.path.dirname(full_path), exist_ok=True)
            with open(full_path, 'w', encoding='utf-8', newline='') as f:
                f.writelines(lines)

        result['success'] = True
        result['message'] = 'Diff applied successfully'
        result['applied_files'] = affected_files
    except Exception as e:
        # Any conflict, unsafe path, or I/O error is reported, not raised.
        result['message'] = f"Error while applying diff: {str(e)}"
        result['error_details'] = str(e)

    return result


def _resolve_inside(root: str, relative_path: str) -> str:
    """Resolve *relative_path* under the already-resolved *root*; raise ValueError if it escapes."""
    candidate = os.path.realpath(os.path.join(root, relative_path))
    if os.path.commonpath([root, candidate]) != root:
        raise ValueError(f"Path escapes target directory: {relative_path}")
    return candidate


def _load_lines(candidates: List[str], staged: Dict[str, List[str]]) -> List[str]:
    """Return the lines of the first existing candidate (staged content wins), or [] if none exists."""
    for path in candidates:
        if path in staged:
            # An earlier patch in this diff already produced this file.
            return list(staged[path])
        try:
            with open(path, 'r', encoding='utf-8') as f:
                return f.readlines()
        except FileNotFoundError:
            continue
    # Nothing on disk: start empty and let hunk verification reject the
    # patch if it expected existing content.
    return []
|
||
|
||
# When run as a script, demonstrate apply_diff on a throwaway directory.
if __name__ == "__main__":
    import tempfile

    # Example diff: change the single line of hello.txt.
    sample_diff = (
        "--- a/hello.txt\n"
        "+++ b/hello.txt\n"
        "@@ -1 +1 @@\n"
        "-Hello World\n"
        "+Hello Universe\n"
    )

    print("Testing apply_diff with unidiff2...")
    # Work in a temporary directory so the demo never modifies files in the
    # caller's current working directory.
    with tempfile.TemporaryDirectory() as demo_dir:
        demo_file = os.path.join(demo_dir, "hello.txt")
        with open(demo_file, "w", encoding="utf-8") as f:
            f.write("Hello World\n")

        res = apply_diff(sample_diff, demo_dir)
        print(res)

        if res['success']:
            with open(demo_file, encoding="utf-8") as f:
                print(f.read(), end="")
|