260 lines
8.7 KiB
Python
260 lines
8.7 KiB
Python
import json
|
||
import subprocess
|
||
import pytest
|
||
from pathlib import Path
|
||
|
||
from src.llm_codegen.checker import Checker
|
||
from src.llm_codegen.core import CodeGenerator
|
||
from src.llm_codegen.utils import is_dangerous_command
|
||
|
||
|
||
# ---------- Fake 对象 ----------
|
||
class FakeCodeGenerator:
    """Stand-in for CodeGenerator that records LLM calls instead of making them.

    ``_call_llm`` returns a canned response (``return_value``) and remembers
    that it was called and with which arguments, so tests can assert on it.
    """

    def __init__(self, return_value=None):
        """Create the fake.

        Args:
            return_value: Response that ``_call_llm`` should return. ``None``
                selects a default payload with no patches; any other value,
                including an empty dict, is returned as-is.
        """
        # Becomes True once _call_llm has been invoked.
        self._call_llm_called = False
        # (system_prompt, user_prompt, temperature) of the most recent call.
        self._call_llm_args = None
        # Compare against None explicitly: `return_value or default` would
        # silently replace a deliberately empty response such as {}.
        if return_value is None:
            return_value = {"patches": [], "description": "模拟修复"}
        self.return_value = return_value

    def _call_llm(self, system_prompt, user_prompt, temperature=0.1):
        """Record the call arguments and return the canned ``return_value``."""
        self._call_llm_called = True
        self._call_llm_args = (system_prompt, user_prompt, temperature)
        return self.return_value
|
||
|
||
|
||
# ---------- Fixtures ----------
|
||
@pytest.fixture
def fake_code_generator():
    """Provide a fresh FakeCodeGenerator carrying its default canned response."""
    generator = FakeCodeGenerator()
    return generator
|
||
|
||
|
||
@pytest.fixture
def checker(fake_code_generator, tmp_path):
    """Build a Checker rooted at a new ``test_output`` directory under tmp_path.

    The fake code generator is injected in place of a real CodeGenerator.
    """
    out_dir = tmp_path.joinpath("test_output")
    out_dir.mkdir()
    return Checker(output_dir=out_dir, code_generator=fake_code_generator)
|
||
|
||
|
||
# ---------- 测试 ----------
|
||
class TestChecker:
    """Tests for the Checker class, using hand-written fakes instead of mock."""

    def test_init(self, checker, tmp_path):
        """The constructor stores its output paths, tools and code generator."""
        assert checker.output_dir == tmp_path / "test_output"
        assert checker.check_tools == ["black"]
        assert checker.results_file == checker.output_dir / "check_results.json"
        assert isinstance(checker.code_generator, FakeCodeGenerator)

    def test_run_check_success(self, checker, monkeypatch):
        """run_check reports a clean result when the tool exits with code 0."""
        file_path = Path("test_file.py")

        # Replace subprocess.run so no external tool is actually executed.
        def fake_run(cmd, *args, **kwargs):
            return subprocess.CompletedProcess(
                args=cmd,
                returncode=0,
                stdout="",
                stderr="",
            )

        monkeypatch.setattr(subprocess, "run", fake_run)

        result = checker.run_check("pylint", file_path)

        assert result["tool"] == "pylint"
        assert result["file"] == str(file_path)
        assert result["returncode"] == 0
        assert result["errors"] == []

    def test_run_check_timeout(self, checker, monkeypatch):
        """run_check converts a subprocess timeout into returncode -1."""
        file_path = Path("test_file.py")

        # Make subprocess.run raise a timeout instead of running anything.
        def fake_run_timeout(*args, **kwargs):
            raise subprocess.TimeoutExpired(cmd="pylint", timeout=60)

        monkeypatch.setattr(subprocess, "run", fake_run_timeout)

        result = checker.run_check("pylint", file_path)

        assert result["returncode"] == -1
        assert "检查超时" in result["stderr"]

    def test_run_parallel_checks(self, checker, tmp_path, monkeypatch):
        """run_parallel_checks runs every configured tool once for the file."""
        test_file = tmp_path / "test.py"
        test_file.write_text("print('hello')\n", encoding="utf-8")

        # Build each fake result from the requested tool instead of indexing a
        # fixed list by call count: the fixed list mislabelled results (a
        # "pylint" result for a "black" call) and raised IndexError past three
        # calls. list.append is safe even if the checks run on worker threads.
        calls = []

        def fake_run_check(tool, file):
            calls.append(tool)
            return {
                "tool": tool,
                "file": str(file),
                "returncode": 0,
                "stdout": "",
                "stderr": "",
                "errors": [],
            }

        monkeypatch.setattr(checker, "run_check", fake_run_check)

        results = checker.run_parallel_checks([test_file])

        # One result per configured tool for the single input file
        # (check_tools is ["black"] by default, as test_init asserts).
        assert len(results) == len(checker.check_tools)
        assert all(r["returncode"] == 0 for r in results)
        assert sorted(calls) == sorted(checker.check_tools)

    def test_save_results(self, checker):
        """save_results writes the results list as JSON to check_results.json."""
        results = [{"tool": "pylint", "file": "file1.py", "returncode": 0}]
        checker.save_results(results)

        results_file = checker.output_dir / "check_results.json"
        assert results_file.exists()
        loaded = json.loads(results_file.read_text(encoding="utf-8"))
        assert loaded == results

    def test_collect_errors(self, checker):
        """collect_errors flattens saved results into one entry per error."""
        results = [
            {
                "tool": "pylint",
                "file": "file1.py",
                "returncode": 1,
                "stdout": "",
                "stderr": "",
                "errors": ["未使用的导入"],
            },
            {
                "tool": "mypy",
                "file": "file2.py",
                "returncode": 0,
                "stdout": "",
                "stderr": "",
                "errors": [],
            },
        ]
        checker.save_results(results)
        errors = checker.collect_errors()

        assert len(errors) == 1
        assert errors[0]["file"] == "file1.py"
        assert errors[0]["tool"] == "pylint"
        assert errors[0]["error"] == "未使用的导入"

    def test_collect_errors_no_results(self, checker):
        """collect_errors returns an empty list when no results file exists."""
        errors = checker.collect_errors()
        assert errors == []

    def test_auto_fix(self, checker):
        """auto_fix applies the patch returned by the (fake) LLM."""
        errors = [{"file": "test.py", "tool": "pylint", "error": "未使用的导入"}]

        # Patched files are resolved under output_dir, which the checker
        # fixture has already created.
        test_file = checker.output_dir / "test.py"
        test_file.write_text("import os\nprint('hi')\n", encoding="utf-8")

        # Canned LLM response: rewrite test.py without the unused import.
        checker.code_generator.return_value = {
            "patches": [{"file": "test.py", "code": "print('hi')\n"}],
            "description": "移除未使用的导入",
        }

        success = checker.auto_fix(errors, context_files=["test.py"])

        assert success is True
        assert test_file.read_text(encoding="utf-8") == "print('hi')\n"
        assert checker.code_generator._call_llm_called is True

    def test_auto_fix_no_errors(self, checker):
        """auto_fix with no errors succeeds trivially."""
        success = checker.auto_fix([])
        assert success is True

    @staticmethod
    def _patch_check_pipeline(checker, monkeypatch, error_batches, fix_success):
        """Stub the check / collect / fix steps used by run_full_check_and_fix.

        The i-th collect_errors call returns error_batches[i]; once the list
        is exhausted the last batch is repeated. auto_fix always returns
        fix_success and run_parallel_checks returns an empty result list.

        Returns:
            A dict counting calls under "checks", "collect" and "fix".
        """
        calls = {"checks": 0, "collect": 0, "fix": 0}

        def fake_run_parallel_checks():
            calls["checks"] += 1
            return []

        def fake_collect_errors(results=None):
            index = min(calls["collect"], len(error_batches) - 1)
            calls["collect"] += 1
            return error_batches[index]

        def fake_auto_fix(errors, context_files=None):
            calls["fix"] += 1
            return fix_success

        monkeypatch.setattr(checker, "run_parallel_checks", fake_run_parallel_checks)
        monkeypatch.setattr(checker, "collect_errors", fake_collect_errors)
        monkeypatch.setattr(checker, "auto_fix", fake_auto_fix)
        return calls

    def test_run_full_check_and_fix(self, checker, monkeypatch):
        """The loop fixes the first batch of errors, then stops on a clean re-check."""
        calls = self._patch_check_pipeline(
            checker,
            monkeypatch,
            error_batches=[[{"error": "err"}], []],
            fix_success=True,
        )

        result = checker.run_full_check_and_fix(max_retries=2)

        assert result is True
        assert calls == {"checks": 2, "collect": 2, "fix": 1}

    def test_run_full_check_and_fix_failure(self, checker, monkeypatch):
        """The loop reports failure when auto_fix cannot fix the errors."""
        calls = self._patch_check_pipeline(
            checker,
            monkeypatch,
            error_batches=[[{"error": "err"}]],
            fix_success=False,
        )

        result = checker.run_full_check_and_fix(max_retries=1)

        assert result is False
        assert calls["checks"] == 1
        # auto_fix was handed errors, so they must have been collected first.
        assert calls["collect"] >= 1
        assert calls["fix"] == 1
|