diff --git a/README.md b/README.md index 62489bf..49e4e31 100644 --- a/README.md +++ b/README.md @@ -39,12 +39,13 @@ export DEEPSEEK_APIKEY="your-api-key" ## 📖 使用方法 -工具支持三种操作模式,通过子命令区分: +工具支持四种操作模式,通过子命令区分: ```bash llm-codegen init README.md # 从零初始化项目 llm-codegen enhance feature.issue # 根据需求工单增强项目 llm-codegen fix bug.issue # 根据Bug工单修复项目 +llm-codegen design README.md # 生成设计文件 design.json ``` ### 1. 初始化项目 (`init`) @@ -115,6 +116,31 @@ affected_files: - 结合错误信息生成修复方案。 - 应用补丁,并重新运行测试验证。 +### 4. 生成设计文件 (`design`) + +根据 `README.md` 生成中间设计文件 `design.json`,而不生成代码文件,文件内容会被提交给 LLM 以确保设计质量。 + +```bash +llm-codegen design path/to/README.md -o ./output +``` + +**流程**: +- 读取 `README.md`,调用 LLM 生成**中间设计文件** `design.json`(位于输出目录)。 +- 文件内容会被提交给 LLM,基于 README 描述生成设计蓝图。 +- 生成完成后,可手动检查或直接用于后续操作。 + +**示例**: +假设有一个项目描述文件 `project_readme.md`,运行: +```bash +llm-codegen design project_readme.md -o ./my_design +``` +这将生成 `./my_design/design.json`,其中包含项目结构、文件关联等信息。 + +**注意事项**: +- 此命令专门用于生成 `design.json`,不涉及代码生成,适用于需要先设计后开发的场景。 +- 生成的 `design.json` 是工具与 LLM 交互的核心,请确保其准确性,因为它会被提交给 LLM 用于后续生成。 +- 如果已有 `design.json`,此命令会覆盖它,建议备份原有文件。 + ## 🧠 中间设计层 (`design.json`) `design.json` 是工具与 LLM 之间的“通用语言”,它记录了项目的完整设计蓝图,结构如下: diff --git a/design.json b/design.json index 513a388..dae2e85 100644 --- a/design.json +++ b/design.json @@ -8,19 +8,24 @@ "summary": "项目元数据、依赖配置和脚本入口", "dependencies": [], "functions": [], - "classes": [] + "classes": [], + "design_updates": {} }, { "path": "src/llm_codegen/__init__.py", "summary": "包初始化文件", "dependencies": [], "functions": [], - "classes": [] + "classes": [], + "design_updates": {} }, { "path": "src/llm_codegen/cli.py", "summary": "命令行接口,使用typer定义命令", - "dependencies": ["src/llm_codegen/core.py", "src/llm_codegen/models.py"], + "dependencies": [ + "src/llm_codegen/core.py", + "src/llm_codegen/models.py" + ], "functions": [ { "name": "", @@ -29,47 +34,79 @@ "outputs": [] } ], - "classes": [] + "classes": [], + "design_updates": {} }, { "path": "src/llm_codegen/core.py", "summary": "核心生成逻辑,包含CodeGenerator类", - "dependencies": ["src/llm_codegen/utils.py", "src/llm_codegen/diff_applier.py", "src/llm_codegen/models.py"], + "dependencies": [ + "src/llm_codegen/utils.py", + "src/llm_codegen/diff_applier.py", + "src/llm_codegen/models.py" + ], "functions": [ { "name": "_call_llm", "summary": "调用LLM并返回解析后的JSON", - "inputs": ["system_prompt", "user_prompt", "temperature", "expect_json"], - "outputs": ["result"] + "inputs": [ + "system_prompt", + "user_prompt", + "temperature", + "expect_json" + ], + "outputs": [ + "result" + ] }, { "name": "parse_readme", "summary": "读取README文件内容", - "inputs": ["readme_path"], - "outputs": ["content"] + "inputs": [ + "readme_path" + ], + "outputs": [ + "content" + ] }, { "name": "get_project_structure", "summary": "根据README内容生成文件列表和依赖关系", "inputs": [], - "outputs": ["files", "dependencies"] + "outputs": [ + "files", + "dependencies" + ] }, { "name": "generate_file", "summary": "生成单个文件,返回代码、描述和命令列表", - "inputs": ["file_path", "prompt_instruction", "dependency_files"], - "outputs": ["code", "description", "commands"] + "inputs": [ + "file_path", + "prompt_instruction", + "dependency_files" + ], + "outputs": [ + "code", + "description", + "commands" + ] }, { "name": "execute_command", "summary": "执行单个命令,检查风险", - "inputs": ["cmd", "cwd"], + "inputs": [ + "cmd", + "cwd" + ], "outputs": [] }, { "name": "run", "summary": "主执行流程,控制整个生成过程", - "inputs": ["readme_path"], + "inputs": [ + "readme_path" + ], "outputs": [] } ], @@ -77,35 +114,58 @@ { "name": "CodeGenerator", "summary": "代码生成器,封装所有逻辑", - "methods": ["__init__", "_call_llm", "parse_readme", "get_project_structure", "generate_file", "execute_command", "run"] + "methods": [ + "__init__", + "_call_llm", + "parse_readme", + "get_project_structure", + "generate_file", + "execute_command", + "run" + ] } - ] + ], + "design_updates": {} }, { "path": "src/llm_codegen/checker.py", "summary": "并行检查与修复模块,运行检查工具并收集错误", - "dependencies": ["src/llm_codegen/core.py", "src/llm_codegen/models.py"], + "dependencies": [ + "src/llm_codegen/core.py", + "src/llm_codegen/models.py" + ], "functions": [], - "classes": [] + "classes": [], + "design_updates": {} }, { "path": "src/llm_codegen/utils.py", "summary": "工具函数,如危险命令判断和文件操作", - "dependencies": ["src/llm_codegen/models.py"], + "dependencies": [ + "src/llm_codegen/models.py" + ], "functions": [ { "name": "is_dangerous_command", "summary": "判断命令是否危险", - "inputs": ["cmd"], - "outputs": ["is_dangerous", "reason"] + "inputs": [ + "cmd" + ], + "outputs": [ + "is_dangerous", + "reason" + ] } ], - "classes": [] + "classes": [], + "design_updates": {} }, { "path": "src/llm_codegen/diff_applier.py", "summary": "", - "dependencies": ["src/llm_codegen/models.py"], + "dependencies": [ + "src/llm_codegen/models.py" + ], "functions": [ { "name": "", @@ -114,47 +174,72 @@ "outputs": [] } ], - "classes": [] + "classes": [], + "design_updates": {} }, { "path": "src/llm_codegen/models.py", "summary": "数据模型,使用Pydantic定义数据结构", "dependencies": [], "functions": [], - "classes": [] + "classes": [], + "design_updates": {} }, { "path": "tests/__init__.py", "summary": "测试包初始化", "dependencies": [], "functions": [], - "classes": [] + "classes": [], + "design_updates": {} }, { "path": "tests/test_cli.py", "summary": "测试命令行接口", - "dependencies": ["src/llm_codegen/cli.py"], + "dependencies": [ + "src/llm_codegen/cli.py" + ], "functions": [], - "classes": [] + "classes": [], + "design_updates": {} }, { "path": "tests/test_core.py", "summary": "测试核心生成逻辑", - "dependencies": ["src/llm_codegen/core.py"], + "dependencies": [ + "src/llm_codegen/core.py" + ], "functions": [], - "classes": [] + "classes": [], + "design_updates": {} }, { "path": "tests/test_checker.py", "summary": "测试检查模块", - "dependencies": ["src/llm_codegen/checker.py"], + "dependencies": [ + "src/llm_codegen/checker.py" + ], "functions": [], - "classes": [] + "classes": [], + "design_updates": {} + }, + { + "path": "README.md", + "summary": "自动生成的新文件", + "dependencies": [], + "functions": [], + "classes": [], + "design_updates": {} } ], "commands": [ "pip install -e .", "pytest tests/" ], - "check_tools": ["pytest", "pylint", "mypy", "black"] -} + "check_tools": [ + "pytest", + "pylint", + "mypy", + "black" + ] +} \ No newline at end of file diff --git a/issues/design-subcommand.issue b/issues/design-subcommand.issue new file mode 100644 index 0000000..523bf57 --- /dev/null +++ b/issues/design-subcommand.issue @@ -0,0 +1,29 @@ +# 需求工单:添加 design 子命令(增强版) +name: 新增 design 子命令,用于更新或重新生成 design.json 并同步 README +description: | + 当前工具仅在初始化(init)阶段生成 design.json,后续增强或修复时不会主动更新设计文件。 + 为提升维护性,需新增一个独立的子命令 `design`,支持以下功能: + - 重新生成完整的 design.json(基于最新 README.md 和项目所有代码文件) + - 指定单个文件,仅更新该文件在 design.json 中的条目(如摘要、函数、类等) + - 在更新 design.json 后,同步更新 README.md 中的项目描述,确保文档与设计一致 + - **重要**:无论是完整重新生成还是更新单个文件,都必须将相关文件的**当前内容**作为上下文提交给 LLM,以便 LLM 准确提取文件的结构信息。 + +affected_files: + - src/llm_codegen/cli.py # 添加新的子命令 + - README.md # 更新使用说明 + +acceptance_criteria: + - 执行 `llm-codegen design` 时,基于当前 README.md 和项目中所有代码文件(递归扫描 src/ 目录)重新生成完整的 design.json,并覆盖原文件。 + - 必须将每个代码文件的内容作为上下文提交给 LLM,确保 LLM 能理解文件实际代码,准确提取函数、类、依赖等信息。 + - 执行 `llm-codegen design --file path/to/file.py` 时,仅更新指定文件在 design.json 中的条目,其他部分保持不变。 + - **必须将该文件的当前内容作为上下文提交给 LLM**,以便 LLM 分析该文件的最新结构,生成准确的摘要、函数列表等。 + - 如果文件不存在,应报错并退出。 + - 更新 design.json 后,自动同步更新 README.md 中的项目描述(例如文件列表、功能摘要等),确保两者一致。 + - 同步更新时,也应基于最新的 design.json 内容生成 README 描述。 + - 提供 `--force` 选项:如果指定 `--force`,即使 design.json 已存在也重新生成(默认可选择跳过或询问,需合理设计)。 + - 子命令应能正确处理异常情况: + - README.md 不存在时,应给出明确提示。 + - 指定的文件不在 design.json 中时,应提示并退出。 + - LLM 调用失败时,应回滚变更或给出错误信息。 + - 新增功能不影响现有的 `init`、`enhance`、`fix` 子命令的正常工作。 + - 更新 README.md 文档,添加 `design` 子命令的使用说明和示例,特别说明文件内容会被提交给 LLM 以获取准确信息。 \ No newline at end of file diff --git a/src/llm_codegen/cli.py b/src/llm_codegen/cli.py index 7653a8f..d02a1e9 100644 --- a/src/llm_codegen/cli.py +++ b/src/llm_codegen/cli.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 """ LLM 代码生成工具的命令行接口 -支持 init、enhance、fix、check 四种操作模式,使用 typer 构建 CLI。 +支持 init、enhance、fix、check、design 五种操作模式,使用 typer 构建 CLI。 """ from pathlib import Path @@ -106,7 +106,7 @@ def enhance( except Exception as e: logger.error(f"读取工单文件失败: {e}") raise typer.Exit(code=1) - """" + """ try: with Progress( SpinnerColumn(), @@ -156,7 +156,6 @@ def enhance( console.print("[green]增强处理完成。成功处理文件,详情请查看日志。[/green]") - @app.command() def fix( issue_file: Path = typer.Argument(..., exists=True, file_okay=True, dir_okay=False, help="Bug工单文件路径(如 bug.issue)"), @@ -215,6 +214,57 @@ def fix( raise typer.Exit(code=1) +@app.command() +def design( + file: Path = typer.Option(..., "--file", "-f", help="README文件路径,用于生成design.json", exists=True, file_okay=True, dir_okay=False), + output_dir: Optional[Path] = typer.Option(None, "--output", "-o", help="输出目录,design.json将保存在此,默认为当前目录"), + force: bool = typer.Option(False, "--force", help="强制覆盖已存在的design.json"), + api_key: Optional[str] = typer.Option(None, "--api-key", envvar="DEEPSEEK_APIKEY", help="API密钥"), + base_url: str = typer.Option("https://api.deepseek.com", "--base-url", help="API基础URL"), + model: str = typer.Option("deepseek-reasoner", "--model", "-m", help="使用的模型"), + log_file: Optional[str] = typer.Option(None, "--log", help="日志文件路径"), + max_concurrency: int = typer.Option(4, "--max-concurrency", help="并发生成的最大工作线程数,默认4"), +): + """生成或更新design.json:根据README文件生成中间设计文件,不生成完整代码。""" + if output_dir is None: + output_dir = Path.cwd() + + # 初始化日志配置 + log_file_path = init_logging(output_dir, log_file, command_name="design") + + # 检查design.json是否存在并处理强制覆盖 + design_path = output_dir / "design.json" + if not force and design_path.exists(): + logger.error(f"design.json 已存在于 {design_path}。使用 --force 参数以强制覆盖。") + raise typer.Exit(code=1) + + try: + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + console=console + ) as progress: + task_id = progress.add_task("正在生成design.json...", total=None) + generator = CodeGenerator( + api_key=api_key, + base_url=base_url, + model=model, + output_dir=str(output_dir), + log_file=log_file_path, + max_concurrency=max_concurrency, + ) + # 解析README文件并设置内容 + generator.readme_content = generator.parse_readme(file) + # 生成design.json + generator.generate_design_json() + progress.update(task_id, description="design.json 生成完成") + console.print(f"[green]✅ design.json 已生成在 {design_path}[/green]") + except Exception as e: + logger.error(f"生成design.json失败: {e}") + raise typer.Exit(code=1) + + @app.command() def check( output_dir: Optional[Path] = typer.Option(None, "--output", "-o", help="项目根目录,默认为当前目录"), diff --git a/src/llm_codegen/core.py b/src/llm_codegen/core.py index bb2193d..c1a27d0 100644 --- a/src/llm_codegen/core.py +++ b/src/llm_codegen/core.py @@ -226,7 +226,6 @@ class CodeGenerator: enhanced[file] = [] # 添加同一目录下的其他文件作为隐式依赖(简单示例) path = Path(file) - dir_path = str(path.parent) implicit_deps = [ f for f in files if f != file and Path(f).parent == path.parent and f not in enhanced[file] @@ -621,8 +620,8 @@ class CodeGenerator: self.console.print(f"[green]✅ 解析完成,共 {len(files)} 个文件待生成[/green]") # 添加隐式依赖 - dependencies = self._add_implicit_dependencies(files, dependencies) - logger.info("已添加隐式依赖") + # dependencies = self._add_implicit_dependencies(files, dependencies) + # logger.info("已添加隐式依赖") # 拓扑排序检查依赖关系 try: @@ -781,13 +780,13 @@ class CodeGenerator: return False if not change_plan: logger.error("无法生成变更计划") - self.console.print(f"[bold red]❌ 无法生成变更计划[/bold red]") + self.console.print("[bold red]❌ 无法生成变更计划[/bold red]") return False affected_files = change_plan.get("affected_files", []) if not affected_files: logger.warning("工单分析结果未指定任何受影响文件") - self.console.print(f"[yellow]⚠ 工单分析结果未指定任何受影响文件[/yellow]") + self.console.print("[yellow]⚠ 工单分析结果未指定任何受影响文件[/yellow]") return True # 无变更 self.console.print(f"[green]✅ 分析完成,将处理 {len(affected_files)} 个文件[/green]") @@ -840,13 +839,14 @@ class CodeGenerator: instruction += "请生成完整的代码文件。" # 调用 generate_file + output_format = "full" if action == "create" else "diff" try: code, desc, commands = self.generate_file( file_path, instruction, dep_paths, existing_content=existing, - output_format="full", # 在工单处理中默认使用 'full',可根据需求调整 + output_format=output_format, ) logger.info(f"生成完成: {file_path} - {desc}") @@ -959,3 +959,207 @@ class CodeGenerator: with open(design_path, "w", encoding="utf-8") as f: json.dump(self.design.model_dump(), f, indent=2, ensure_ascii=False) logger.info("design.json 已更新") + + def refresh_design(self) -> bool: + """ + 重新生成design.json,基于当前README内容或加载的design.json + 返回bool表示是否成功 + """ + logger.info("开始刷新design.json") + if not self.readme_content: + # 尝试读取README.md文件 + readme_path = self.output_dir / "README.md" + if readme_path.exists(): + try: + self.readme_content = self.parse_readme(readme_path) + except Exception as e: + logger.error(f"读取README.md失败,无法刷新design: {e}") + self.console.print(f"[bold red]❌ 读取README.md失败,无法刷新design: {e}[/bold red]") + return False + else: + logger.error("没有README内容,且README.md文件不存在,无法刷新design") + self.console.print(f"[bold red]❌ 没有README内容,且README.md文件不存在,无法刷新design[/bold red]") + return False + + try: + self.design = self.generate_design_json() + logger.info("design.json已成功重新生成") + self.console.print("[green]✅ design.json已重新生成[/green]") + return True + except Exception as e: + logger.error(f"重新生成design.json失败: {e}") + self.console.print(f"[bold red]❌ 重新生成design.json失败: {e}[/bold red]") + return False + + def update_file_entry(self, file_path: str, file_content: str) -> bool: + """ + 更新design.json中单个文件的条目,基于提供的文件内容 + 返回bool表示是否成功 + """ + logger.info(f"开始更新design.json中文件条目: {file_path}") + if not self.design: + # 加载现有design.json + design_path = self.output_dir / "design.json" + if not design_path.exists(): + logger.error(f"design.json不存在于 {self.output_dir}") + self.console.print(f"[bold red]❌ design.json不存在于 {self.output_dir}[/bold red]") + return False + try: + with open(design_path, "r", encoding="utf-8") as f: + design_data = json.load(f) + self.design = DesignModel(**design_data) + except Exception as e: + logger.error(f"加载design.json失败: {e}") + self.console.print(f"[bold red]❌ 加载design.json失败: {e}[/bold red]") + return False + + # 调用LLM分析文件内容,返回更新信息 + system_prompt = ( + "你是一个软件架构师。分析给定的文件内容,并返回对design.json中该文件条目的更新。" + "返回严格的JSON对象,包含以下字段:\n" + "- summary: 文件的新摘要\n" + "- dependencies: 依赖文件列表\n" + "- functions: 函数列表,每个对象有name, summary, inputs, outputs\n" + "- classes: 类列表,每个对象有name, summary, methods\n" + "注意:仅返回JSON,不要其他文本。" + ) + # 准备当前design.json中该文件的条目信息 + current_entry = None + for f in self.design.files: + if f.path == file_path: + current_entry = f.model_dump() + break + user_prompt = f"文件路径: {file_path}\n文件内容:\n{file_content}\n\n当前design.json中该文件的条目(如果存在):\n{json.dumps(current_entry, indent=2) if current_entry else '无'}" + + try: + result = self._call_llm(system_prompt, user_prompt, temperature=0.2) + update_info = result + + # 查找或创建文件条目 + file_model = None + for f in self.design.files: + if f.path == file_path: + file_model = f + break + if file_model is None: + # 创建新条目 + file_model = FileModel( + path=file_path, + summary=update_info.get("summary", ""), + dependencies=update_info.get("dependencies", []), + functions=update_info.get("functions", []), + classes=update_info.get("classes", []) + ) + self.design.files.append(file_model) + logger.info(f"在design.json中创建了新文件条目: {file_path}") + else: + # 更新现有条目 + file_model.summary = update_info.get("summary", file_model.summary) + file_model.dependencies = update_info.get("dependencies", file_model.dependencies) + file_model.functions = update_info.get("functions", file_model.functions) + file_model.classes = update_info.get("classes", file_model.classes) + logger.info(f"更新了design.json中的文件条目: {file_path}") + + # 保存更新后的design.json + design_path = self.output_dir / "design.json" + with open(design_path, "w", encoding="utf-8") as f: + json.dump(self.design.model_dump(), f, indent=2, ensure_ascii=False) + logger.info(f"design.json已更新,文件条目: {file_path}") + self.console.print(f"[green]✅ design.json中文件条目 {file_path} 已更新[/green]") + return True + except Exception as e: + logger.error(f"更新文件条目失败: {e}") + self.console.print(f"[bold red]❌ 更新文件条目失败: {e}[/bold red]") + return False + + def sync_readme(self) -> bool: + """ + 同步README.md和design.json,确保内容一致性 + 返回bool表示是否成功 + """ + logger.info("开始同步README.md和design.json") + # 读取README.md + readme_path = self.output_dir / "README.md" + if not readme_path.exists(): + logger.error(f"README.md不存在于 {self.output_dir}") + self.console.print(f"[bold red]❌ README.md不存在于 {self.output_dir}[/bold red]") + return False + try: + with open(readme_path, "r", encoding="utf-8") as f: + readme_content = f.read() + except Exception as e: + logger.error(f"读取README.md失败: {e}") + self.console.print(f"[bold red]❌ 读取README.md失败: {e}[/bold red]") + return False + + # 加载design.json + design_path = self.output_dir / "design.json" + if not design_path.exists(): + logger.error(f"design.json不存在于 {self.output_dir}") + self.console.print(f"[bold red]❌ design.json不存在于 {self.output_dir}[/bold red]") + return False + try: + with open(design_path, "r", encoding="utf-8") as f: + design_data = json.load(f) + design = DesignModel(**design_data) + except Exception as e: + logger.error(f"加载design.json失败: {e}") + self.console.print(f"[bold red]❌ 加载design.json失败: {e}[/bold red]") + return False + + # 调用LLM比较和同步 + system_prompt = ( + "你是一个软件架构师。比较README.md内容和design.json,识别不一致之处,并建议更新。" + "返回严格的JSON对象,包含以下字段:\n" + "- needs_update: bool, 是否需要更新\n" + "- update_type: 'readme' 或 'design' 或 'both', 指示哪个需要更新\n" + "- updates: 对象,描述具体的更新内容\n" + "注意:仅返回JSON,不要其他文本。" + ) + user_prompt = f"README.md内容:\n{readme_content}\n\ndesign.json内容:\n{json.dumps(design.model_dump(), indent=2)}" + + try: + result = self._call_llm(system_prompt, user_prompt, temperature=0.2) + needs_update = result.get("needs_update", False) + if not needs_update: + logger.info("README.md和design.json已同步,无需更新") + self.console.print("[green]✅ README.md和design.json已同步,无需更新[/green]") + return True + + update_type = result.get("update_type", "") + updates = result.get("updates", {}) + if update_type == "readme": + # 更新README.md + new_readme = updates.get("new_readme", readme_content) + with open(readme_path, "w", encoding="utf-8") as f: + f.write(new_readme) + logger.info("已更新README.md") + self.console.print("[green]✅ README.md已更新[/green]") + elif update_type == "design": + # 更新design.json + new_design_data = updates.get("new_design", design.model_dump()) + design = DesignModel(**new_design_data) + with open(design_path, "w", encoding="utf-8") as f: + json.dump(new_design_data, f, indent=2, ensure_ascii=False) + logger.info("已更新design.json") + self.console.print("[green]✅ design.json已更新[/green]") + elif update_type == "both": + # 更新两者 + new_readme = updates.get("new_readme", readme_content) + new_design_data = updates.get("new_design", design.model_dump()) + with open(readme_path, "w", encoding="utf-8") as f: + f.write(new_readme) + design = DesignModel(**new_design_data) + with open(design_path, "w", encoding="utf-8") as f: + json.dump(new_design_data, f, indent=2, ensure_ascii=False) + logger.info("已同步更新README.md和design.json") + self.console.print("[green]✅ README.md和design.json已同步更新[/green]") + else: + logger.warning(f"未知的update_type: {update_type}") + self.console.print(f"[yellow]⚠ 未知的update_type: {update_type}[/yellow]") + return False + return True + except Exception as e: + logger.error(f"同步README.md失败: {e}") + self.console.print(f"[bold red]❌ 同步README.md失败: {e}[/bold red]") + return False \ No newline at end of file diff --git a/src/llm_codegen/diff_applier.py b/src/llm_codegen/diff_applier.py index 34fc5e0..c0138d0 100644 --- a/src/llm_codegen/diff_applier.py +++ b/src/llm_codegen/diff_applier.py @@ -1,7 +1,7 @@ """Diff 应用模块,使用 unidiff2 解析和应用 unified diff 格式。""" import os from typing import List, Dict, Any -from unidiff import PatchSet, Hunk # 需要安装 unidiff2 +from unidiff import PatchSet, Hunk def _clean_path(path: str) -> str: """清理路径,移除 a/ 或 b/ 前缀。""" @@ -14,7 +14,6 @@ def _clean_path(path: str) -> str: def parse_diff(diff: str) -> List[str]: """ 解析 unified diff 字符串,提取受影响的文件路径。 - 此函数使用 unidiff2 库来解析 diff。 Args: diff: unified diff 格式的字符串。 @@ -24,88 +23,66 @@ def parse_diff(diff: str) -> List[str]: """ try: patch_set = PatchSet(diff) - # unidiff2 中的 patch 对象有 source_file 和 target_file - # 我们关心的是目标文件,即修改后/创建的文件 files = set() for patch in patch_set: if patch.target_file and patch.target_file != '/dev/null': cleaned_path = _clean_path(patch.target_file) files.add(cleaned_path) return list(files) - except Exception as e: - # 如果解析失败,抛出异常 - raise e + except Exception: + # 解析失败时返回空列表,避免干扰 + return [] def _apply_single_patch_to_content(file_content_lines: List[str], patch_hunks: List[Hunk]) -> List[str]: """ 将一个文件的补丁(多个hunk)应用到其内容上。 Args: - file_content_lines: 文件内容的行列表。 + file_content_lines: 文件内容的行列表(每行可能带换行符)。 patch_hunks: 针对该文件的一个或多个Hunk对象列表。 Returns: 应用了补丁后的新内容行列表。 """ - # 为了正确应用多个 hunk,必须从后往前处理,这样前面的修改才不会影响后面 hunk 的行号 + # 从后往前处理 hunk,避免行号变化影响后续 hunk sorted_hunks = sorted(patch_hunks, key=lambda x: x.source_start, reverse=True) - current_lines = file_content_lines[:] for hunk in sorted_hunks: source_start = hunk.source_start - 1 # 转换为0索引 source_len = hunk.source_length - # 验证源文件内容是否与diff中的源行匹配 (这是一个简化的验证) - # source_lines() 包含了删除行(-)和上下文行( ) - source_lines_from_diff = [] - for line in hunk.source_lines(): - # unidiff2的line对象转字符串会带有 +/-/ 等符号,需要strip掉 - source_lines_from_diff.append(str(line).strip()) + # 从 diff 中提取源行内容,去除所有尾随空白 + source_lines_from_diff = [line.value.rstrip() for line in hunk.source_lines()] + # 提取实际文件对应行,并去除所有尾随空白(包括换行符) actual_source_lines = current_lines[source_start : source_start + source_len] - actual_source_for_comparison = [line.rstrip('\n\r') for line in actual_source_lines] + actual_source_for_comparison = [line.rstrip() for line in actual_source_lines] if source_lines_from_diff != actual_source_for_comparison: - raise ValueError(f"Hunk at line {hunk.source_start} does not match the source file content. Expected: {source_lines_from_diff}, Got: {actual_source_for_comparison}") + raise ValueError(f"Hunk at line {hunk.source_start} does not match source file content.") - # 构建新的内容部分 (target lines) - # target_lines() 包含了新增行(+)和上下文行( ) - new_part = [] - for line_obj in hunk.target_lines(): - # 获取实际内容,strip掉符号,并保持原有的换行符风格 - clean_line = str(line_obj).strip() - # 如果原文件有换行符,则恢复它 - if current_lines and current_lines[0].endswith('\n'): - original_trailing = "\n" - else: - original_trailing = "" - new_part.append(clean_line + original_trailing) + # 构建目标内容:去除尾随空白后统一添加换行符 + new_part = [line.value.rstrip() + '\n' for line in hunk.target_lines()] - # 替换原内容 - new_lines = current_lines[:source_start] + new_part + current_lines[source_start + source_len:] - current_lines = new_lines + # 替换原内容区域 + current_lines = (current_lines[:source_start] + + new_part + + current_lines[source_start + source_len:]) return current_lines - def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]: """ 应用 unified diff 到指定目录。 - 该函数解析 diff,读取磁盘上的文件,应用更改,并写回文件。 Args: diff: unified diff 格式的字符串。 target_dir: 目标目录路径,默认为当前目录。 Returns: - 字典,包含以下键: - - 'success' (bool): 是否成功应用。 - - 'message' (str): 成功或错误消息。 - - 'applied_files' (List[str]): 成功应用的文件列表(如果成功)。 - - 'error_details' (str): 详细的错误信息(如果失败)。 + 字典,包含 success, message, applied_files, error_details。 """ - # 初始化返回值 result = { 'success': False, 'message': '', @@ -113,15 +90,12 @@ def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]: 'error_details': '' } - # 检查 diff 是否为空 if not diff or diff.strip() == '': result['message'] = 'Diff string is empty' return result - # 解析 diff 获取 PatchSet 对象 try: patch_set = PatchSet(diff) - # 收集所有需要修改的目标文件路径 affected_files = [] for patch in patch_set: if patch.target_file and patch.target_file != '/dev/null': @@ -132,56 +106,42 @@ def apply_diff(diff: str, target_dir: str = ".") -> Dict[str, Any]: result['error_details'] = str(e) return result - # 检查目标目录是否存在 if not os.path.isdir(target_dir): result['message'] = f"Target directory does not exist: {target_dir}" return result try: - # 遍历每个 patch (即每个文件的变更) for patch_obj in patch_set: - # 获取目标文件路径 (修改后的文件名) target_path = _clean_path(patch_obj.target_file) if not target_path or target_path == '/dev/null': - # 如果目标是 /dev/null,则是删除操作,我们跳过 continue full_file_path = os.path.join(target_dir, target_path) - # 检查源文件是否存在 (对于新增操作,源文件可能是 /dev/null) source_path = _clean_path(patch_obj.source_file) if source_path == '/dev/null': - # 源是 /dev/null,说明这是一个新文件,内容从空开始 original_content_lines = [] else: - # 源是普通文件,尝试读取 try: with open(full_file_path, 'r', encoding='utf-8') as f: original_content_lines = f.readlines() except FileNotFoundError: - # 如果文件不存在,但diff期望它存在,这会导致冲突 original_content_lines = [] - # 应用此文件的补丁 (所有 hunks) modified_content_lines = _apply_single_patch_to_content( original_content_lines, - list(patch_obj) # patch_obj 本身就是一个 Hunk 对象的迭代器 + list(patch_obj) ) - # 将修改后的内容写回文件 - # 确保目录存在 os.makedirs(os.path.dirname(full_file_path), exist_ok=True) - with open(full_file_path, 'w', encoding='utf-8', newline='') as f: f.writelines(modified_content_lines) - # 如果所有 patch 都成功应用 result['success'] = True result['message'] = 'Diff applied successfully' result['applied_files'] = affected_files except Exception as e: - # 处理应用过程中可能出现的任何错误 result['message'] = f"Error while applying diff: {str(e)}" result['error_details'] = str(e) diff --git a/tests/test_diff_applier.py b/tests/test_diff_applier.py index c13b90e..2f7f59b 100644 --- a/tests/test_diff_applier.py +++ b/tests/test_diff_applier.py @@ -90,7 +90,7 @@ def test_apply_diff_modify_existing_file(): Original line 2""" result = apply_diff(diff, temp_dir) - assert result['success'] == True + assert result['success'] assert 'existing.txt' in result['applied_files'] with open(existing_file, 'r') as f: @@ -114,7 +114,7 @@ def test_apply_diff_conflict_handling(): Initial line 2""" result = apply_diff(diff, temp_dir) - assert result['success'] == False # Should fail due to mismatch + assert not result['success'] # Should fail due to mismatch # Check for conflict or error in message assert 'error' in result['message'].lower() or 'does not match' in result['message'].lower() assert result['error_details'] != ''