# llmcodegen/tests/test_checker.py
#
# Note: this file contains non-ASCII Unicode characters (Chinese text and
# full-width punctuation) that may be confused with similar-looking ASCII
# characters.

import json
import subprocess
import pytest
from pathlib import Path
from src.llm_codegen.checker import Checker
from src.llm_codegen.core import CodeGenerator
from src.llm_codegen.utils import is_dangerous_command
# ---------- Fake objects ----------
class FakeCodeGenerator:
    """Fake CodeGenerator used in place of real LLM calls.

    The fake records whether ``_call_llm`` was invoked and with which
    arguments, and always answers with a canned ``return_value``.
    """

    def __init__(self, return_value=None):
        """Create the fake.

        Args:
            return_value: Value that ``_call_llm`` should return. When it is
                ``None`` a default "no patches" response is used.
        """
        self._call_llm_called = False
        self._call_llm_args = None
        # Compare against None explicitly: a falsy but deliberate canned value
        # (for example an empty dict) must be returned as given rather than
        # being replaced by the default.
        if return_value is None:
            return_value = {"patches": [], "description": "模拟修复"}
        self.return_value = return_value

    def _call_llm(self, system_prompt, user_prompt, temperature=0.1):
        """Record the call and return the canned response.

        Args:
            system_prompt: System prompt passed by the caller.
            user_prompt: User prompt passed by the caller.
            temperature: Sampling temperature passed by the caller.

        Returns:
            The current ``return_value`` attribute.
        """
        self._call_llm_called = True
        self._call_llm_args = (system_prompt, user_prompt, temperature)
        return self.return_value
# ---------- Fixtures ----------
@pytest.fixture
def fake_code_generator():
    """Provide a fresh FakeCodeGenerator built with its default arguments."""
    generator = FakeCodeGenerator()
    return generator
@pytest.fixture
def checker(fake_code_generator, tmp_path):
    """Build a Checker backed by a temporary directory and the fake generator.

    The output directory ``test_output`` is created under ``tmp_path`` before
    the Checker is constructed.
    """
    workdir = tmp_path.joinpath("test_output")
    workdir.mkdir()
    tools = ["pylint", "mypy", "black"]
    return Checker(
        output_dir=workdir,
        check_tools=tools,
        code_generator=fake_code_generator,
    )
# ---------- Tests ----------
class TestChecker:
    """Tests for the Checker class (hand-written fakes and monkeypatch, no mock library)."""

    def test_init(self, checker, tmp_path):
        """The constructor stores its arguments and exposes the results file path."""
        assert checker.output_dir == tmp_path / "test_output"
        assert checker.check_tools == ["pylint", "mypy", "black"]
        assert checker.results_file == checker.output_dir / "check_results.json"
        assert isinstance(checker.code_generator, FakeCodeGenerator)

    def test_run_check_success(self, checker, monkeypatch):
        """run_check reports a clean result when the tool exits with code 0."""
        file_path = Path("test_file.py")

        # Make the danger check report the command as safe.
        def fake_dangerous(cmd):
            return (False, "")

        monkeypatch.setattr(
            "src.llm_codegen.checker.is_dangerous_command", fake_dangerous
        )

        # Make subprocess.run report a successful, silent run.
        def fake_run(cmd, *args, **kwargs):
            return subprocess.CompletedProcess(
                args=cmd,
                returncode=0,
                stdout="",
                stderr="",
            )

        monkeypatch.setattr(subprocess, "run", fake_run)

        result = checker.run_check("pylint", file_path)

        assert result["tool"] == "pylint"
        assert result["file"] == str(file_path)
        assert result["returncode"] == 0
        assert result["errors"] == []

    def test_run_check_dangerous_command(self, checker, monkeypatch):
        """run_check refuses a command that the danger check flags."""
        file_path = Path("test_file.py")

        # Make the danger check report the command as dangerous.
        def fake_dangerous(cmd):
            return (True, "包含危险关键词 'rm'")

        monkeypatch.setattr(
            "src.llm_codegen.checker.is_dangerous_command", fake_dangerous
        )

        # Safety net: this test passes a destructive command string, so the
        # real subprocess.run must never be reachable. The stub records any
        # attempt and refuses to execute it.
        run_calls = []

        def forbidden_run(*args, **kwargs):
            run_calls.append((args, kwargs))
            raise AssertionError(
                "subprocess.run must not be called for a blocked command"
            )

        monkeypatch.setattr(subprocess, "run", forbidden_run)

        result = checker.run_check("rm -rf /", file_path)

        # Checked separately from the result so that a swallowed exception
        # inside run_check cannot hide an attempted execution.
        assert run_calls == []
        assert result["returncode"] == -1
        assert "危险命令被阻止" in result["stderr"]

    def test_run_check_timeout(self, checker, monkeypatch):
        """run_check reports a timeout when subprocess.run raises TimeoutExpired."""
        file_path = Path("test_file.py")

        # Make the danger check report the command as safe.
        def fake_dangerous(cmd):
            return (False, "")

        monkeypatch.setattr(
            "src.llm_codegen.checker.is_dangerous_command", fake_dangerous
        )

        # Make subprocess.run raise a timeout.
        def fake_run_timeout(*args, **kwargs):
            raise subprocess.TimeoutExpired(cmd="pylint", timeout=60)

        monkeypatch.setattr(subprocess, "run", fake_run_timeout)

        result = checker.run_check("pylint", file_path)

        assert result["returncode"] == -1
        assert "检查超时" in result["stderr"]

    def test_run_parallel_checks(self, checker, tmp_path, monkeypatch):
        """run_parallel_checks calls run_check once per configured tool."""
        test_file = tmp_path / "test.py"
        test_file.write_text("print('hello')\n")

        # Canned results keyed by tool name, so the value returned for a call
        # does not depend on the order in which the calls happen.
        results_by_tool = {
            tool: {
                "tool": tool,
                "file": str(test_file),
                "returncode": 0,
                "stdout": "",
                "stderr": "",
                "errors": [],
            }
            for tool in ("pylint", "mypy", "black")
        }

        # run_parallel_checks may invoke run_check from worker threads (its
        # implementation is not visible here). Appending to a list is safe in
        # that case, whereas incrementing a shared counter and indexing by it
        # could lose updates or hand out the wrong entry.
        seen_calls = []

        def fake_run_check(tool, file):
            seen_calls.append((tool, file))
            return results_by_tool[tool]

        # Replace run_check so that no real tool is executed.
        monkeypatch.setattr(checker, "run_check", fake_run_check)

        results = checker.run_parallel_checks([test_file])

        assert len(results) == 3
        assert all(r["returncode"] == 0 for r in results)
        assert len(seen_calls) == 3

    def test_save_results(self, checker, tmp_path):
        """save_results writes the results as JSON to check_results.json."""
        results = [{"tool": "pylint", "file": "file1.py", "returncode": 0}]

        checker.save_results(results)

        results_file = checker.output_dir / "check_results.json"
        assert results_file.exists()
        with open(results_file, "r", encoding="utf-8") as f:
            loaded = json.load(f)
        assert loaded == results

    def test_collect_errors(self, checker, tmp_path):
        """collect_errors returns one entry per error found in the saved results."""
        results = [
            {
                "tool": "pylint",
                "file": "file1.py",
                "returncode": 1,
                "stdout": "",
                "stderr": "",
                "errors": ["未使用的导入"],
            },
            {
                "tool": "mypy",
                "file": "file2.py",
                "returncode": 0,
                "stdout": "",
                "stderr": "",
                "errors": [],
            },
        ]
        checker.save_results(results)

        errors = checker.collect_errors()

        assert len(errors) == 1
        assert errors[0]["file"] == "file1.py"
        assert errors[0]["tool"] == "pylint"
        assert errors[0]["error"] == "未使用的导入"

    def test_collect_errors_no_results(self, checker):
        """collect_errors returns an empty list when nothing has been saved."""
        errors = checker.collect_errors()
        assert errors == []

    def test_auto_fix(self, checker, tmp_path):
        """auto_fix applies the patch returned by the code generator."""
        errors = [{"file": "test.py", "tool": "pylint", "error": "未使用的导入"}]

        # The file to fix has to live under output_dir.
        test_file = checker.output_dir / "test.py"
        test_file.parent.mkdir(parents=True, exist_ok=True)
        test_file.write_text("import os\nprint('hi')\n")

        # Canned response that the fake _call_llm will return.
        fake_return = {
            "patches": [{"file": "test.py", "code": "print('hi')\n"}],
            "description": "移除未使用的导入",
        }
        checker.code_generator.return_value = fake_return

        success = checker.auto_fix(errors, context_files=["test.py"])

        assert success is True
        with open(test_file, "r", encoding="utf-8") as f:
            assert f.read() == "print('hi')\n"
        assert checker.code_generator._call_llm_called is True

    def test_auto_fix_no_errors(self, checker):
        """auto_fix succeeds when there are no errors to fix."""
        success = checker.auto_fix([])
        assert success is True

    def test_run_full_check_and_fix(self, checker, monkeypatch):
        """The check/fix loop succeeds once the errors disappear after one fix."""
        # Stub the collaborating methods and count how often each is called.
        fake_results = []
        fake_errors_1 = [{"error": "err"}]
        fake_errors_2 = []
        fake_fix_success = True
        call_checks = 0
        call_collect = 0
        call_fix = 0

        def fake_run_parallel_checks():
            nonlocal call_checks
            call_checks += 1
            return fake_results

        def fake_collect_errors(results=None):
            nonlocal call_collect
            call_collect += 1
            # First round reports an error, every later round reports none.
            if call_collect == 1:
                return fake_errors_1
            return fake_errors_2

        def fake_auto_fix(errors, context_files=None):
            nonlocal call_fix
            call_fix += 1
            return fake_fix_success

        monkeypatch.setattr(checker, "run_parallel_checks", fake_run_parallel_checks)
        monkeypatch.setattr(checker, "collect_errors", fake_collect_errors)
        monkeypatch.setattr(checker, "auto_fix", fake_auto_fix)

        result = checker.run_full_check_and_fix(max_retries=2)

        assert result is True
        assert call_checks == 2
        assert call_collect == 2
        assert call_fix == 1

    def test_run_full_check_and_fix_failure(self, checker, monkeypatch):
        """The check/fix loop fails when errors persist and the fix does not succeed."""
        fake_results = []
        fake_errors = [{"error": "err"}]
        fake_fix_success = False
        call_checks = 0
        call_collect = 0
        call_fix = 0

        def fake_run_parallel_checks():
            nonlocal call_checks
            call_checks += 1
            return fake_results

        def fake_collect_errors(results=None):
            nonlocal call_collect
            call_collect += 1
            return fake_errors

        def fake_auto_fix(errors, context_files=None):
            nonlocal call_fix
            call_fix += 1
            return fake_fix_success

        monkeypatch.setattr(checker, "run_parallel_checks", fake_run_parallel_checks)
        monkeypatch.setattr(checker, "collect_errors", fake_collect_errors)
        monkeypatch.setattr(checker, "auto_fix", fake_auto_fix)

        result = checker.run_full_check_and_fix(max_retries=1)

        assert result is False
        assert call_checks == 1
        assert call_fix == 1