From 7999d35d07ff73b5d8a36202c9778e590742bc99 Mon Sep 17 00:00:00 2001 From: songsenand Date: Wed, 18 Mar 2026 16:47:12 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=B9=E8=BF=9B=20diff=20=E8=A7=A3=E6=9E=90?= =?UTF-8?q?=E4=B8=8E=E5=BA=94=E7=94=A8=E9=80=BB=E8=BE=91=EF=BC=8C=E6=B8=85?= =?UTF-8?q?=E7=90=86=E8=B7=AF=E5=BE=84=E5=B9=B6=E5=A2=9E=E5=BC=BA=E6=96=87?= =?UTF-8?q?=E4=BB=B6=E5=A4=84=E7=90=86=E5=87=86=E7=A1=AE=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/llm_codegen/diff_applier.py | 71 +++++++++++++++++++-------------- 1 file changed, 42 insertions(+), 29 deletions(-) diff --git a/src/llm_codegen/diff_applier.py b/src/llm_codegen/diff_applier.py index 05b349f..34fc5e0 100644 --- a/src/llm_codegen/diff_applier.py +++ b/src/llm_codegen/diff_applier.py @@ -3,6 +3,14 @@ import os from typing import List, Dict, Any from unidiff import PatchSet, Hunk # 需要安装 unidiff2 +def _clean_path(path: str) -> str: + """清理路径,移除 a/ 或 b/ 前缀。""" + if path.startswith('a/'): + return path[2:] + if path.startswith('b/'): + return path[2:] + return path + def parse_diff(diff: str) -> List[str]: """ 解析 unified diff 字符串,提取受影响的文件路径。 @@ -16,11 +24,14 @@ def parse_diff(diff: str) -> List[str]: """ try: patch_set = PatchSet(diff) - # unidiff2 中的 patch 对象有一个 path 属性,代表修改的文件 - # 使用集合去重 - files = {patch.target_file for patch in patch_set} # 使用target_file获取修改后的文件名 - # 过滤掉 /dev/null - return [f.lstrip('/') for f in files if f != '/dev/null' and not f.endswith('/dev/null')] + # unidiff2 中的 patch 对象有 source_file 和 target_file + # 我们关心的是目标文件,即修改后/创建的文件 + files = set() + for patch in patch_set: + if patch.target_file and patch.target_file != '/dev/null': + cleaned_path = _clean_path(patch.target_file) + files.add(cleaned_path) + return list(files) except Exception as e: # 如果解析失败,抛出异常 raise e @@ -46,30 +57,30 @@ def _apply_single_patch_to_content(file_content_lines: List[str], patch_hunks: L source_len = hunk.source_length # 验证源文件内容是否与diff中的源行匹配 (这是一个简化的验证) - source_lines_from_diff = [str(l).strip() for l in hunk.source_lines()] # strip()去除行首的'-'或' ' - actual_source_lines = current_lines[source_start : source_start + source_len] + # source_lines() 包含了删除行(-)和上下文行( ) + source_lines_from_diff = [] + for line in hunk.source_lines(): + # unidiff2的line对象转字符串会带有 +/-/ 等符号,需要strip掉 + source_lines_from_diff.append(str(line).strip()) - # 提取实际源行和上下文/删除行进行比较,rstrip换行符 + actual_source_lines = current_lines[source_start : source_start + source_len] actual_source_for_comparison = [line.rstrip('\n\r') for line in actual_source_lines] if source_lines_from_diff != actual_source_for_comparison: raise ValueError(f"Hunk at line {hunk.source_start} does not match the source file content. Expected: {source_lines_from_diff}, Got: {actual_source_for_comparison}") - # 构建新的内容部分 + # 构建新的内容部分 (target lines) + # target_lines() 包含了新增行(+)和上下文行( ) new_part = [] - for line_obj in hunk: - if line_obj.is_added or line_obj.is_context: - # 添加新增行或上下文行 - # unidiff2的line对象转字符串会带有 +/-/ 等符号,我们需要获取实际内容 - # line.value 包含符号,line.text 不包含符号,但可能不准确 - # 最可靠的方法是 strip() 掉行首的 +/-/ 空格 - clean_line = str(line_obj).strip() - # 如果原文件行尾有换行符,我们也要加上 - original_trailing = "" - if current_lines and current_lines[0].endswith('\n'): - original_trailing = "\n" - - new_part.append(clean_line + original_trailing) + for line_obj in hunk.target_lines(): + # 获取实际内容,strip掉符号,并保持原有的换行符风格 + clean_line = str(line_obj).strip() + # 如果原文件有换行符,则恢复它 + if current_lines and current_lines[0].endswith('\n'): + original_trailing = "\n" + else: + original_trailing = "" + new_part.append(clean_line + original_trailing) # 替换原内容 new_lines = current_lines[:source_start] + new_part + current_lines[source_start + source_len:] @@ -111,7 +122,11 @@ def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]: try: patch_set = PatchSet(diff) # 收集所有需要修改的目标文件路径 - affected_files = [p.target_file.lstrip('/') for p in patch_set if p.target_file != '/dev/null'] + affected_files = [] + for patch in patch_set: + if patch.target_file and patch.target_file != '/dev/null': + cleaned_path = _clean_path(patch.target_file) + affected_files.append(cleaned_path) except Exception as e: result['message'] = f"Failed to parse diff: {str(e)}" result['error_details'] = str(e) @@ -126,15 +141,15 @@ def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]: # 遍历每个 patch (即每个文件的变更) for patch_obj in patch_set: # 获取目标文件路径 (修改后的文件名) - target_path = patch_obj.target_file.lstrip('/') + target_path = _clean_path(patch_obj.target_file) if not target_path or target_path == '/dev/null': - # 如果目标是 /dev/null,则是删除操作,我们跳过或记录 + # 如果目标是 /dev/null,则是删除操作,我们跳过 continue full_file_path = os.path.join(target_dir, target_path) # 检查源文件是否存在 (对于新增操作,源文件可能是 /dev/null) - source_path = patch_obj.source_file + source_path = _clean_path(patch_obj.source_file) if source_path == '/dev/null': # 源是 /dev/null,说明这是一个新文件,内容从空开始 original_content_lines = [] @@ -144,9 +159,7 @@ def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]: with open(full_file_path, 'r', encoding='utf-8') as f: original_content_lines = f.readlines() except FileNotFoundError: - # 如果文件不存在,但diff不是新文件,这可能是个错误 - # 但对于新文件,这种情况也可能发生,因为路径可能不同 - # 我们继续,但要注意如果内容不匹配会报错 + # 如果文件不存在,但diff期望它存在,这会导致冲突 original_content_lines = [] # 应用此文件的补丁 (所有 hunks)