# llmcodegen/tests/test_core.py — unit tests for CodeGenerator (fakes, no mock library)

import json
import subprocess
from pathlib import Path
import pytest
from src.llm_codegen.core import CodeGenerator
from src.llm_codegen.models import DesignModel
# ---------- Fake classes ----------
class FakeChatCompletion:
    """Stand-in for the object returned by OpenAI chat.completions.create."""

    def __init__(self, content):
        # Mirror the real response shape: a list with one choice holding a message.
        self.choices = [FakeChoice(FakeMessage(content))]
class FakeChoice:
    """Wraps a single message, mirroring an OpenAI response choice."""

    def __init__(self, message):
        self.message = message
class FakeMessage:
    """Mirrors an OpenAI message: text content plus optional reasoning trace."""

    def __init__(self, content):
        self.content = content
        # deepseek-reasoner responses expose reasoning_content; fakes leave it unset.
        self.reasoning_content = None
class FakeOpenAIClient:
    """Fake OpenAI client used in place of the real one during tests."""

    def __init__(self):
        # Reproduce the client.chat.completions.create(...) attribute chain.
        self.chat = FakeChat()
class FakeChat:
    """Intermediate namespace so client.chat.completions resolves on the fake."""

    def __init__(self):
        self.completions = FakeCompletions()
class FakeCompletions:
    """Records create() calls and replies with a canned or default completion."""

    def __init__(self):
        # Call-recording state, inspected by tests after the code under test runs.
        self.create_called = False
        self.create_kwargs = None
        # Optional canned response; when unset, a default JSON completion is built.
        self.create_return_value = None

    def create(self, *args, **kwargs):
        """Record the invocation, then return the canned value or a default."""
        self.create_called = True
        self.create_kwargs = kwargs
        if self.create_return_value is not None:
            return self.create_return_value
        # Fall back to a completion whose content is a minimal JSON document.
        return FakeChatCompletion(json.dumps({"content": "default"}))
# ---------- Fixtures ----------
@pytest.fixture
def fake_openai_client(monkeypatch):
    """Replace the real OpenAI client class with a shared fake instance."""
    client = FakeOpenAIClient()
    # Any OpenAI(...) construction inside core now hands back our fake.
    monkeypatch.setattr(
        "src.llm_codegen.core.OpenAI",
        lambda *args, **kwargs: client,
    )
    return client
@pytest.fixture
def code_generator(tmp_path, monkeypatch, fake_openai_client):
    """Build a CodeGenerator with a temp output dir and a fake API key in the env."""
    monkeypatch.setenv("DEEPSEEK_APIKEY", "fake-api-key")
    return CodeGenerator(output_dir=str(tmp_path / "test_output"))
# ---------- Test class ----------
class TestCodeGenerator:
    """Tests for the CodeGenerator class (hand-rolled fakes, no mock library)."""

    def test_init_success(self, code_generator, tmp_path, fake_openai_client):
        """Successful init picks up the env API key, model and output dir."""
        assert code_generator.api_key == "fake-api-key"
        assert code_generator.model == "deepseek-reasoner"
        assert code_generator.output_dir == tmp_path / "test_output"
        # The real OpenAI client must have been replaced by our fake.
        assert code_generator.client is fake_openai_client

    def test_init_no_api_key(self, monkeypatch):
        """Construction raises ValueError when no API key is configured."""
        monkeypatch.delenv("DEEPSEEK_APIKEY", raising=False)
        with pytest.raises(ValueError, match="必须提供API密钥"):
            CodeGenerator()

    def test_parse_readme_success(self, code_generator, tmp_path):
        """parse_readme returns the README file's text verbatim."""
        readme_path = tmp_path / "README.md"
        readme_path.write_text("# Test README\nThis is a test.")
        content = code_generator.parse_readme(readme_path)
        assert content == "# Test README\nThis is a test."

    def test_parse_readme_file_not_found(self, code_generator):
        """parse_readme raises when the README file does not exist."""
        with pytest.raises(Exception):
            code_generator.parse_readme(Path("nonexistent.md"))

    def test_generate_design_json(self, code_generator, monkeypatch):
        """generate_design_json builds a DesignModel and persists design.json."""
        code_generator.readme_content = "# Test Project\nA test project."
        # Canned LLM reply used in place of a real API call.
        mock_response = {
            "project_name": "test-project",
            "version": "1.0.0",
            "description": "A test project",
            "files": [],
            "commands": [],
            "check_tools": []
        }
        def fake_call_llm(system_prompt, user_prompt, temperature=0.2, expect_json=True):
            return mock_response
        monkeypatch.setattr(code_generator, "_call_llm", fake_call_llm)
        design = code_generator.generate_design_json()
        assert isinstance(design, DesignModel)
        assert design.project_name == "test-project"
        # The design must also have been written to disk.
        design_path = code_generator.output_dir / "design.json"
        assert design_path.exists()
        with open(design_path) as f:
            saved = json.load(f)
        assert saved["project_name"] == "test-project"

    def test_generate_file_with_dependencies(self, code_generator, monkeypatch, tmp_path):
        """generate_file returns the LLM's (code, description, commands) triple."""
        # Create a dependency file the generator should read for context.
        dep_path = tmp_path / "dep.py"
        dep_path.write_text("# Dependency file")
        code_generator.output_dir = tmp_path
        code_generator.readme_content = "# README"
        # Canned LLM reply.
        llm_response = {
            "code": "print('Hello, world!')",
            "description": "测试文件",
            "commands": []
        }
        def fake_call_llm(system_prompt, user_prompt, temperature=0.2, expect_json=True):
            return llm_response
        monkeypatch.setattr(code_generator, "_call_llm", fake_call_llm)
        code, desc, commands = code_generator.generate_file(
            file_path="test.py",
            prompt_instruction="生成测试文件",
            dependency_files=[str(dep_path)]
        )
        assert code == "print('Hello, world!')"
        assert desc == "测试文件"
        assert commands == []

    def test_execute_command_success(self, code_generator, monkeypatch):
        """execute_command returns True when the subprocess exits with code 0."""
        def fake_run(cmd, *args, **kwargs):
            return subprocess.CompletedProcess(args=cmd, returncode=0, stdout="", stderr="")
        monkeypatch.setattr(subprocess, "run", fake_run)
        success = code_generator.execute_command("echo test")
        assert success is True

    def test_execute_command_dangerous(self, code_generator, monkeypatch):
        """execute_command refuses to run commands flagged as dangerous."""
        def fake_dangerous(cmd):
            return (True, "包含危险关键词")
        monkeypatch.setattr("src.llm_codegen.core.is_dangerous_command", fake_dangerous)
        success = code_generator.execute_command("rm -rf /")
        assert success is False

    def test_execute_command_failure(self, code_generator, monkeypatch):
        """execute_command returns False on a non-zero exit code."""
        def fake_run(cmd, *args, **kwargs):
            return subprocess.CompletedProcess(args=cmd, returncode=1, stdout="", stderr="")
        monkeypatch.setattr(subprocess, "run", fake_run)
        success = code_generator.execute_command("false")
        assert success is False

    def test_run_with_state_resume(self, code_generator, monkeypatch, tmp_path):
        """run() resumes from a saved state file and cleans it up afterwards."""
        # Pre-existing state: one of three files already generated.
        state_file = tmp_path / ".llm_generator_state.json"
        state_data = {
            "current_file_index": 1,
            "generated_files": ["file1.py"],
            "dependencies_map": {},
            "total_files": 3,
            "output_dir": str(tmp_path),
            "readme_path": "test"
        }
        state_file.write_text(json.dumps(state_data))
        # Matching design.json describing the three files.
        design_path = tmp_path / "design.json"
        design_data = {
            "project_name": "test",
            "version": "1.0.0",
            "description": "test",
            "files": [
                {"path": "file1.py", "summary": "", "dependencies": [], "functions": [], "classes": []},
                {"path": "file2.py", "summary": "", "dependencies": [], "functions": [], "classes": []},
                {"path": "file3.py", "summary": "", "dependencies": [], "functions": [], "classes": []}
            ],
            "commands": [],
            "check_tools": []
        }
        design_path.write_text(json.dumps(design_data))
        code_generator.output_dir = tmp_path
        code_generator.state_file = state_file
        # Stub out the expensive internal calls.
        def fake_parse_readme(path):
            return "# README"
        monkeypatch.setattr(code_generator, "parse_readme", fake_parse_readme)
        def fake_generate_file(file_path, prompt_instruction, dependency_files):
            return ("code", "desc", [])
        monkeypatch.setattr(code_generator, "generate_file", fake_generate_file)
        def fake_execute_command(cmd, cwd=None):
            return True
        monkeypatch.setattr(code_generator, "execute_command", fake_execute_command)
        # Must run to completion without raising.
        code_generator.run(Path(tmp_path / "README.md"))
        # The state file is removed once the run finishes.
        assert not state_file.exists()

    def test_run_without_state(self, code_generator, monkeypatch, tmp_path):
        """run() works on a first run with no prior state file."""
        code_generator.output_dir = tmp_path
        # Stub parse_readme.
        def fake_parse_readme(path):
            return "# README"
        monkeypatch.setattr(code_generator, "parse_readme", fake_parse_readme)
        # Stub generate_design_json to return a minimal design.
        fake_design = DesignModel(
            project_name="test",
            version="1.0.0",
            description="test",
            files=[],  # no files keeps the flow minimal
            commands=[],
            check_tools=[]
        )
        def fake_generate_design_json():
            return fake_design
        monkeypatch.setattr(code_generator, "generate_design_json", fake_generate_design_json)
        # Stub get_project_structure.
        def fake_get_project_structure():
            return [], {}
        monkeypatch.setattr(code_generator, "get_project_structure", fake_get_project_structure)
        # Must run to completion without raising.
        code_generator.run(Path(tmp_path / "README.md"))

    def test_topological_sort_normal(self, code_generator):
        """_topological_sort orders dependencies before their dependents."""
        files = ["a", "b", "c"]
        dependencies = {"a": ["b"], "b": ["c"], "c": []}
        result = code_generator._topological_sort(files, dependencies)
        # Expected order: c before b, b before a.
        assert result == ["c", "b", "a"]  # deterministic given the implementation
        # Every file's dependencies appear before the file itself.
        for i, node in enumerate(result):
            deps = dependencies.get(node, [])
            for dep in deps:
                assert result.index(dep) < i

    def test_topological_sort_cycle_detection(self, code_generator):
        """_topological_sort raises ValueError on a dependency cycle."""
        files = ["a", "b"]
        dependencies = {"a": ["b"], "b": ["a"]}
        with pytest.raises(ValueError, match="检测到循环依赖"):
            code_generator._topological_sort(files, dependencies)

    def test_topological_sort_empty(self, code_generator):
        """_topological_sort handles empty input."""
        files = []
        dependencies = {}
        result = code_generator._topological_sort(files, dependencies)
        assert result == []

    def test_topological_sort_partial_deps(self, code_generator):
        """Dependencies that are not in the file list are ignored."""
        files = ["a", "b"]
        dependencies = {"a": ["b", "c"], "b": []}  # "c" is not in files
        result = code_generator._topological_sort(files, dependencies)
        # "c" is skipped (absent from in_degree); ordering follows the "b" dependency.
        assert result == ["b", "a"]

    def test_topological_sort_complex(self, code_generator):
        """_topological_sort handles a diamond-shaped dependency graph."""
        files = ["a", "b", "c", "d"]
        dependencies = {"a": ["b", "c"], "b": ["d"], "c": ["d"], "d": []}
        result = code_generator._topological_sort(files, dependencies)
        # The result satisfies every dependency constraint...
        for node in result:
            deps = dependencies.get(node, [])
            for dep in deps:
                assert result.index(dep) < result.index(node)
        # ...and contains every file exactly once.
        assert set(result) == set(files)