改进 diff 解析与应用逻辑,清理路径并增强文件处理准确性
This commit is contained in:
parent
75fc8659a0
commit
7999d35d07
|
|
@ -3,6 +3,14 @@ import os
|
|||
from typing import List, Dict, Any
|
||||
from unidiff import PatchSet, Hunk # 需要安装 unidiff2
|
||||
|
||||
def _clean_path(path: str) -> str:
    """Strip the leading ``a/`` or ``b/`` prefix from a diff path.

    Only one leading prefix is removed; any other path (for example
    ``/dev/null``) is returned unchanged.

    Args:
        path: A source or target path as it appears in a unified diff header.

    Returns:
        The path without its leading ``a/`` or ``b/`` prefix. An empty or
        ``None`` value is returned as-is so callers can run their own
        "missing path" check afterwards.
    """
    # Pass falsy values straight through instead of failing on .startswith().
    if not path:
        return path
    # str.startswith accepts a tuple, so both prefixes are handled in one test.
    if path.startswith(('a/', 'b/')):
        return path[2:]
    return path
|
||||
|
||||
def parse_diff(diff: str) -> List[str]:
|
||||
"""
|
||||
解析 unified diff 字符串,提取受影响的文件路径。
|
||||
|
|
@ -16,11 +24,14 @@ def parse_diff(diff: str) -> List[str]:
|
|||
"""
|
||||
try:
|
||||
patch_set = PatchSet(diff)
|
||||
# unidiff2 中的 patch 对象有一个 path 属性,代表修改的文件
|
||||
# 使用集合去重
|
||||
files = {patch.target_file for patch in patch_set} # 使用target_file获取修改后的文件名
|
||||
# 过滤掉 /dev/null
|
||||
return [f.lstrip('/') for f in files if f != '/dev/null' and not f.endswith('/dev/null')]
|
||||
# unidiff2 中的 patch 对象有 source_file 和 target_file
|
||||
# 我们关心的是目标文件,即修改后/创建的文件
|
||||
files = set()
|
||||
for patch in patch_set:
|
||||
if patch.target_file and patch.target_file != '/dev/null':
|
||||
cleaned_path = _clean_path(patch.target_file)
|
||||
files.add(cleaned_path)
|
||||
return list(files)
|
||||
except Exception as e:
|
||||
# 如果解析失败,抛出异常
|
||||
raise e
|
||||
|
|
@ -46,30 +57,30 @@ def _apply_single_patch_to_content(file_content_lines: List[str], patch_hunks: L
|
|||
source_len = hunk.source_length
|
||||
|
||||
# 验证源文件内容是否与diff中的源行匹配 (这是一个简化的验证)
|
||||
source_lines_from_diff = [str(l).strip() for l in hunk.source_lines()] # strip()去除行首的'-'或' '
|
||||
actual_source_lines = current_lines[source_start : source_start + source_len]
|
||||
# source_lines() 包含了删除行(-)和上下文行( )
|
||||
source_lines_from_diff = []
|
||||
for line in hunk.source_lines():
|
||||
# unidiff2的line对象转字符串会带有 +/-/ 等符号,需要strip掉
|
||||
source_lines_from_diff.append(str(line).strip())
|
||||
|
||||
# 提取实际源行和上下文/删除行进行比较,rstrip换行符
|
||||
actual_source_lines = current_lines[source_start : source_start + source_len]
|
||||
actual_source_for_comparison = [line.rstrip('\n\r') for line in actual_source_lines]
|
||||
|
||||
if source_lines_from_diff != actual_source_for_comparison:
|
||||
raise ValueError(f"Hunk at line {hunk.source_start} does not match the source file content. Expected: {source_lines_from_diff}, Got: {actual_source_for_comparison}")
|
||||
|
||||
# 构建新的内容部分
|
||||
# 构建新的内容部分 (target lines)
|
||||
# target_lines() 包含了新增行(+)和上下文行( )
|
||||
new_part = []
|
||||
for line_obj in hunk:
|
||||
if line_obj.is_added or line_obj.is_context:
|
||||
# 添加新增行或上下文行
|
||||
# unidiff2的line对象转字符串会带有 +/-/ 等符号,我们需要获取实际内容
|
||||
# line.value 包含符号,line.text 不包含符号,但可能不准确
|
||||
# 最可靠的方法是 strip() 掉行首的 +/-/ 空格
|
||||
clean_line = str(line_obj).strip()
|
||||
# 如果原文件行尾有换行符,我们也要加上
|
||||
original_trailing = ""
|
||||
if current_lines and current_lines[0].endswith('\n'):
|
||||
original_trailing = "\n"
|
||||
|
||||
new_part.append(clean_line + original_trailing)
|
||||
for line_obj in hunk.target_lines():
|
||||
# 获取实际内容,strip掉符号,并保持原有的换行符风格
|
||||
clean_line = str(line_obj).strip()
|
||||
# 如果原文件有换行符,则恢复它
|
||||
if current_lines and current_lines[0].endswith('\n'):
|
||||
original_trailing = "\n"
|
||||
else:
|
||||
original_trailing = ""
|
||||
new_part.append(clean_line + original_trailing)
|
||||
|
||||
# 替换原内容
|
||||
new_lines = current_lines[:source_start] + new_part + current_lines[source_start + source_len:]
|
||||
|
|
@ -111,7 +122,11 @@ def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]:
|
|||
try:
|
||||
patch_set = PatchSet(diff)
|
||||
# 收集所有需要修改的目标文件路径
|
||||
affected_files = [p.target_file.lstrip('/') for p in patch_set if p.target_file != '/dev/null']
|
||||
affected_files = []
|
||||
for patch in patch_set:
|
||||
if patch.target_file and patch.target_file != '/dev/null':
|
||||
cleaned_path = _clean_path(patch.target_file)
|
||||
affected_files.append(cleaned_path)
|
||||
except Exception as e:
|
||||
result['message'] = f"Failed to parse diff: {str(e)}"
|
||||
result['error_details'] = str(e)
|
||||
|
|
@ -126,15 +141,15 @@ def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]:
|
|||
# 遍历每个 patch (即每个文件的变更)
|
||||
for patch_obj in patch_set:
|
||||
# 获取目标文件路径 (修改后的文件名)
|
||||
target_path = patch_obj.target_file.lstrip('/')
|
||||
target_path = _clean_path(patch_obj.target_file)
|
||||
if not target_path or target_path == '/dev/null':
|
||||
# 如果目标是 /dev/null,则是删除操作,我们跳过或记录
|
||||
# 如果目标是 /dev/null,则是删除操作,我们跳过
|
||||
continue
|
||||
|
||||
full_file_path = os.path.join(target_dir, target_path)
|
||||
|
||||
# 检查源文件是否存在 (对于新增操作,源文件可能是 /dev/null)
|
||||
source_path = patch_obj.source_file
|
||||
source_path = _clean_path(patch_obj.source_file)
|
||||
if source_path == '/dev/null':
|
||||
# 源是 /dev/null,说明这是一个新文件,内容从空开始
|
||||
original_content_lines = []
|
||||
|
|
@ -144,9 +159,7 @@ def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]:
|
|||
with open(full_file_path, 'r', encoding='utf-8') as f:
|
||||
original_content_lines = f.readlines()
|
||||
except FileNotFoundError:
|
||||
# 如果文件不存在,但diff不是新文件,这可能是个错误
|
||||
# 但对于新文件,这种情况也可能发生,因为路径可能不同
|
||||
# 我们继续,但要注意如果内容不匹配会报错
|
||||
# 如果文件不存在,但diff期望它存在,这会导致冲突
|
||||
original_content_lines = []
|
||||
|
||||
# 应用此文件的补丁 (所有 hunks)
|
||||
|
|
|
|||
Loading…
Reference in New Issue