"""Unit tests for ``Checker`` that use hand-written fakes and ``monkeypatch``."""

import json
import subprocess
from pathlib import Path

import pytest

from src.llm_codegen.checker import Checker
# The next two names are not referenced directly in this module; the imports
# are kept because other parts of the project may rely on them being present.
from src.llm_codegen.core import CodeGenerator
from src.llm_codegen.utils import is_dangerous_command

# Dotted path of the danger check as it is looked up from the checker module.
_DANGER_CHECK_TARGET = "src.llm_codegen.checker.is_dangerous_command"


def _stub_danger_check(monkeypatch, verdict):
    """Patch the checker's ``is_dangerous_command`` to always return *verdict*.

    Args:
        monkeypatch: The pytest ``monkeypatch`` fixture.
        verdict: The ``(is_dangerous, reason)`` tuple the stub returns.
    """
    monkeypatch.setattr(_DANGER_CHECK_TARGET, lambda cmd: verdict)


# ---------- Fake objects ----------


class FakeCodeGenerator:
    """Stand-in for ``CodeGenerator`` that records calls instead of hitting an LLM."""

    def __init__(self, return_value=None):
        """Set up the fake.

        Args:
            return_value: Payload that ``_call_llm`` hands back. When omitted
                (``None``), a default payload with no patches is used.
        """
        self._call_llm_called = False
        self._call_llm_args = None
        # Compare against None explicitly so that a deliberately supplied
        # falsy payload (for example ``{}``) is not replaced by the default.
        if return_value is None:
            return_value = {"patches": [], "description": "模拟修复"}
        self.return_value = return_value

    def _call_llm(self, system_prompt, user_prompt, temperature=0.1):
        """Record the call arguments and return the configured payload."""
        self._call_llm_called = True
        self._call_llm_args = (system_prompt, user_prompt, temperature)
        return self.return_value


# ---------- Fixtures ----------


@pytest.fixture
def fake_code_generator():
    """Return a fresh ``FakeCodeGenerator`` instance."""
    return FakeCodeGenerator()


@pytest.fixture
def checker(fake_code_generator, tmp_path):
    """Build a ``Checker`` backed by a temporary output dir and the fake generator."""
    output_dir = tmp_path / "test_output"
    output_dir.mkdir()
    return Checker(
        output_dir=output_dir,
        check_tools=["pylint", "mypy", "black"],
        code_generator=fake_code_generator,
    )


# ---------- Tests ----------


class TestChecker:
    """Tests for the ``Checker`` class, using fakes rather than a mock library."""

    def test_init(self, checker, tmp_path):
        """The constructor stores its arguments and derives the results file path."""
        assert checker.output_dir == tmp_path / "test_output"
        assert checker.check_tools == ["pylint", "mypy", "black"]
        assert checker.results_file == checker.output_dir / "check_results.json"
        assert isinstance(checker.code_generator, FakeCodeGenerator)

    def test_run_check_success(self, checker, monkeypatch):
        """``run_check`` reports a clean result when the tool exits with 0."""
        file_path = Path("test_file.py")

        # The danger check reports the command as safe.
        _stub_danger_check(monkeypatch, (False, ""))

        # subprocess.run pretends the tool succeeded without output.
        def fake_run(cmd, *args, **kwargs):
            return subprocess.CompletedProcess(
                args=cmd, returncode=0, stdout="", stderr=""
            )

        monkeypatch.setattr(subprocess, "run", fake_run)

        result = checker.run_check("pylint", file_path)

        assert result["tool"] == "pylint"
        assert result["file"] == str(file_path)
        assert result["returncode"] == 0
        assert result["errors"] == []

    def test_run_check_dangerous_command(self, checker, monkeypatch):
        """``run_check`` returns returncode -1 when the command is flagged dangerous."""
        file_path = Path("test_file.py")

        # The danger check flags the command.
        _stub_danger_check(monkeypatch, (True, "包含危险关键词 'rm'"))

        result = checker.run_check("rm -rf /", file_path)

        assert result["returncode"] == -1
        assert "危险命令被阻止" in result["stderr"]

    def test_run_check_timeout(self, checker, monkeypatch):
        """``run_check`` returns returncode -1 when the subprocess times out."""
        file_path = Path("test_file.py")

        # The danger check reports the command as safe.
        _stub_danger_check(monkeypatch, (False, ""))

        # subprocess.run raises a timeout instead of returning.
        def fake_run_timeout(*args, **kwargs):
            raise subprocess.TimeoutExpired(cmd="pylint", timeout=60)

        monkeypatch.setattr(subprocess, "run", fake_run_timeout)

        result = checker.run_check("pylint", file_path)

        assert result["returncode"] == -1
        assert "检查超时" in result["stderr"]

    def test_run_parallel_checks(self, checker, tmp_path, monkeypatch):
        """``run_parallel_checks`` yields one result per tool for a single file."""
        test_file = tmp_path / "test.py"
        test_file.write_text("print('hello')\n", encoding="utf-8")

        # Replace run_check so that no real tool is executed. Calls are
        # recorded with list.append (atomic in CPython) rather than a shared
        # ``+=`` counter, so the count stays correct even if the fake is
        # invoked from several worker threads. The result is built from the
        # arguments instead of indexing a fixed list by call order.
        recorded_calls = []

        def fake_run_check(tool, file):
            recorded_calls.append((tool, file))
            return {
                "tool": tool,
                "file": str(file),
                "returncode": 0,
                "stdout": "",
                "stderr": "",
                "errors": [],
            }

        monkeypatch.setattr(checker, "run_check", fake_run_check)

        results = checker.run_parallel_checks([test_file])

        assert len(results) == 3
        assert all(r["returncode"] == 0 for r in results)
        assert len(recorded_calls) == 3

    def test_save_results(self, checker, tmp_path):
        """``save_results`` writes the results as JSON to ``check_results.json``."""
        results = [{"tool": "pylint", "file": "file1.py", "returncode": 0}]

        checker.save_results(results)

        results_file = checker.output_dir / "check_results.json"
        assert results_file.exists()
        # Read with an explicit encoding so the test does not depend on the
        # platform's default locale.
        loaded = json.loads(results_file.read_text(encoding="utf-8"))
        assert loaded == results

    def test_collect_errors(self, checker, tmp_path):
        """``collect_errors`` returns one entry per error found in saved results."""
        results = [
            {
                "tool": "pylint",
                "file": "file1.py",
                "returncode": 1,
                "stdout": "",
                "stderr": "",
                "errors": ["未使用的导入"],
            },
            {
                "tool": "mypy",
                "file": "file2.py",
                "returncode": 0,
                "stdout": "",
                "stderr": "",
                "errors": [],
            },
        ]
        checker.save_results(results)

        errors = checker.collect_errors()

        assert len(errors) == 1
        assert errors[0]["file"] == "file1.py"
        assert errors[0]["tool"] == "pylint"
        assert errors[0]["error"] == "未使用的导入"

    def test_collect_errors_no_results(self, checker):
        """``collect_errors`` returns an empty list when nothing was saved."""
        errors = checker.collect_errors()
        assert errors == []

    def test_auto_fix(self, checker, tmp_path):
        """``auto_fix`` applies the patch returned by the code generator."""
        errors = [{"file": "test.py", "tool": "pylint", "error": "未使用的导入"}]

        # The file to be fixed lives under output_dir.
        test_file = checker.output_dir / "test.py"
        test_file.parent.mkdir(parents=True, exist_ok=True)
        test_file.write_text("import os\nprint('hi')\n", encoding="utf-8")

        # Configure what the fake _call_llm hands back.
        fake_return = {
            "patches": [{"file": "test.py", "code": "print('hi')\n"}],
            "description": "移除未使用的导入",
        }
        checker.code_generator.return_value = fake_return

        success = checker.auto_fix(errors, context_files=["test.py"])

        assert success is True
        assert test_file.read_text(encoding="utf-8") == "print('hi')\n"
        assert checker.code_generator._call_llm_called is True

    def test_auto_fix_no_errors(self, checker):
        """``auto_fix`` returns True when given an empty error list."""
        success = checker.auto_fix([])
        assert success is True

    def test_run_full_check_and_fix(self, checker, monkeypatch):
        """The check/fix loop succeeds once a fix leaves no errors behind."""
        # Replace the collaborating methods and count how often each is used.
        fake_results = []
        fake_errors_1 = [{"error": "err"}]
        fake_errors_2 = []
        fake_fix_success = True
        calls = {"checks": 0, "collect": 0, "fix": 0}

        def fake_run_parallel_checks():
            calls["checks"] += 1
            return fake_results

        def fake_collect_errors(results=None):
            # First round reports an error, every later round reports none.
            calls["collect"] += 1
            if calls["collect"] == 1:
                return fake_errors_1
            return fake_errors_2

        def fake_auto_fix(errors, context_files=None):
            calls["fix"] += 1
            return fake_fix_success

        monkeypatch.setattr(checker, "run_parallel_checks", fake_run_parallel_checks)
        monkeypatch.setattr(checker, "collect_errors", fake_collect_errors)
        monkeypatch.setattr(checker, "auto_fix", fake_auto_fix)

        result = checker.run_full_check_and_fix(max_retries=2)

        assert result is True
        assert calls["checks"] == 2
        assert calls["collect"] == 2
        assert calls["fix"] == 1

    def test_run_full_check_and_fix_failure(self, checker, monkeypatch):
        """The check/fix loop reports failure when the fix does not succeed."""
        fake_results = []
        fake_errors = [{"error": "err"}]
        fake_fix_success = False
        calls = {"checks": 0, "collect": 0, "fix": 0}

        def fake_run_parallel_checks():
            calls["checks"] += 1
            return fake_results

        def fake_collect_errors(results=None):
            calls["collect"] += 1
            return fake_errors

        def fake_auto_fix(errors, context_files=None):
            calls["fix"] += 1
            return fake_fix_success

        monkeypatch.setattr(checker, "run_parallel_checks", fake_run_parallel_checks)
        monkeypatch.setattr(checker, "collect_errors", fake_collect_errors)
        monkeypatch.setattr(checker, "auto_fix", fake_auto_fix)

        result = checker.run_full_check_and_fix(max_retries=1)

        assert result is False
        assert calls["checks"] == 1
        assert calls["fix"] == 1