312 lines
12 KiB
Python
312 lines
12 KiB
Python
import json
|
||
import subprocess
|
||
from pathlib import Path
|
||
|
||
import pytest
|
||
|
||
from src.llm_codegen.core import CodeGenerator
|
||
from src.llm_codegen.models import DesignModel
|
||
|
||
|
||
# ---------- Fake classes ----------
class FakeChatCompletion:
    """Mimics the object returned by OpenAI ``chat.completions.create``."""

    def __init__(self, content):
        # A single choice whose message carries ``content``.
        message = FakeMessage(content)
        self.choices = [FakeChoice(message)]


class FakeChoice:
    """One entry of ``FakeChatCompletion.choices``; it only wraps a message."""

    def __init__(self, message) -> None:
        self.message = message


class FakeMessage:
    """Chat message with ``content`` and an always-empty ``reasoning_content``."""

    def __init__(self, content) -> None:
        # This fake never carries reasoning output.
        self.reasoning_content = None
        self.content = content


class FakeOpenAIClient:
    """Fake OpenAI client used in place of the real one; exposes ``chat``."""

    def __init__(self) -> None:
        chat = FakeChat()
        self.chat = chat


class FakeChat:
    """Holds the ``completions`` fake, mirroring ``client.chat``."""

    def __init__(self) -> None:
        completions = FakeCompletions()
        self.completions = completions


class FakeCompletions:
    """Stand-in for ``client.chat.completions`` that records its last call."""

    def __init__(self):
        self.create_called = False
        # Positional and keyword arguments of the most recent create() call;
        # positional ones are kept too so they are not silently lost.
        self.create_args = None
        self.create_kwargs = None
        # How many times create() has been called.
        self.call_count = 0
        # Tests set this to control create()'s result; None means "use the
        # default completion whose content is {"content": "default"}".
        self.create_return_value = None

    def create(self, *args, **kwargs):
        """Record the call and return the configured (or default) completion."""
        self.create_called = True
        self.call_count += 1
        self.create_args = args
        self.create_kwargs = kwargs
        if self.create_return_value is None:
            return FakeChatCompletion(json.dumps({"content": "default"}))
        return self.create_return_value


# ---------- Fixtures ----------
@pytest.fixture
def fake_openai_client(monkeypatch):
    """Patch ``src.llm_codegen.core.OpenAI`` so it always yields one fake client.

    Returns the shared FakeOpenAIClient so tests can inspect recorded calls.
    """
    client = FakeOpenAIClient()

    def _make_client(*args, **kwargs):
        # Whatever constructor arguments are passed are ignored; every
        # construction returns the same fake instance.
        return client

    monkeypatch.setattr("src.llm_codegen.core.OpenAI", _make_client)
    return client


@pytest.fixture
def code_generator(tmp_path, monkeypatch, fake_openai_client):
    """Build a CodeGenerator whose output directory is ``tmp_path / "test_output"``.

    A fake key is placed in DEEPSEEK_APIKEY before construction. Requesting
    ``fake_openai_client`` guarantees the OpenAI constructor is already
    patched when CodeGenerator is created.
    """
    monkeypatch.setenv("DEEPSEEK_APIKEY", "fake-api-key")
    return CodeGenerator(output_dir=str(tmp_path / "test_output"))


# ---------- Test class ----------
class TestCodeGenerator:
    """Tests for CodeGenerator.

    Collaborators are replaced by the in-file fakes and ``monkeypatch`` stubs
    (OpenAI client, ``_call_llm``, ``subprocess.run``, internal methods), so
    no network calls or real commands are made.
    """

    def test_init_success(self, code_generator, tmp_path, fake_openai_client):
        """The constructor picks up the API key, default model and output dir."""
        assert code_generator.api_key == "fake-api-key"
        assert code_generator.model == "deepseek-reasoner"
        assert code_generator.output_dir == tmp_path / "test_output"
        # The client must be the fake injected by the fixture.
        assert code_generator.client is fake_openai_client

    def test_init_no_api_key(self, monkeypatch, tmp_path, fake_openai_client):
        """Constructing without an API key raises ValueError."""
        monkeypatch.delenv("DEEPSEEK_APIKEY", raising=False)
        # Keep the fake OpenAI client installed and point output_dir into
        # tmp_path. If the key check ever regresses, the test then depends on
        # neither the real OpenAI client nor the default output directory.
        with pytest.raises(ValueError, match="必须提供API密钥"):
            CodeGenerator(output_dir=str(tmp_path / "test_output"))

    def test_parse_readme_success(self, code_generator, tmp_path):
        """parse_readme returns the README text unchanged."""
        readme_path = tmp_path / "README.md"
        readme_path.write_text("# Test README\nThis is a test.", encoding="utf-8")
        content = code_generator.parse_readme(readme_path)
        assert content == "# Test README\nThis is a test."

    def test_parse_readme_file_not_found(self, code_generator, tmp_path):
        """parse_readme raises when the README file does not exist."""
        # Use a path inside tmp_path: a CWD-relative name could exist by accident.
        missing_path = tmp_path / "nonexistent.md"
        # The concrete exception type depends on parse_readme, which is not
        # visible here. TODO: narrow this once the raised type is confirmed.
        with pytest.raises(Exception):
            code_generator.parse_readme(missing_path)

    def test_generate_design_json(self, code_generator, monkeypatch):
        """generate_design_json returns a DesignModel and writes design.json."""
        code_generator.readme_content = "# Test Project\nA test project."

        # Canned reply returned instead of calling the real LLM.
        mock_response = {
            "project_name": "test-project",
            "version": "1.0.0",
            "description": "A test project",
            "files": [],
            "commands": [],
            "check_tools": [],
        }

        def fake_call_llm(system_prompt, user_prompt, temperature=0.2, expect_json=True):
            return mock_response

        monkeypatch.setattr(code_generator, "_call_llm", fake_call_llm)

        design = code_generator.generate_design_json()

        assert isinstance(design, DesignModel)
        assert design.project_name == "test-project"
        # The design must also have been persisted to the output directory.
        design_path = code_generator.output_dir / "design.json"
        assert design_path.exists()
        # Read explicitly as UTF-8 rather than relying on the platform default.
        saved = json.loads(design_path.read_text(encoding="utf-8"))
        assert saved["project_name"] == "test-project"

    def test_generate_file_with_dependencies(self, code_generator, monkeypatch, tmp_path):
        """generate_file returns (code, description, commands) from the LLM reply."""
        # Dependency file passed to generate_file.
        dep_path = tmp_path / "dep.py"
        dep_path.write_text("# Dependency file", encoding="utf-8")
        code_generator.output_dir = tmp_path
        code_generator.readme_content = "# README"

        # Canned reply returned instead of calling the real LLM.
        llm_response = {
            "code": "print('Hello, world!')",
            "description": "测试文件",
            "commands": [],
        }

        def fake_call_llm(system_prompt, user_prompt, temperature=0.2, expect_json=True):
            return llm_response

        monkeypatch.setattr(code_generator, "_call_llm", fake_call_llm)

        code, desc, commands = code_generator.generate_file(
            file_path="test.py",
            prompt_instruction="生成测试文件",
            dependency_files=[str(dep_path)],
        )

        assert code == "print('Hello, world!')"
        assert desc == "测试文件"
        assert commands == []

    def test_execute_command_success(self, code_generator, monkeypatch):
        """execute_command returns True when the command exits with status 0."""
        def fake_run(cmd, *args, **kwargs):
            return subprocess.CompletedProcess(args=cmd, returncode=0, stdout="", stderr="")

        monkeypatch.setattr(subprocess, "run", fake_run)

        success = code_generator.execute_command("echo test")
        assert success is True

    def test_execute_command_dangerous(self, code_generator, monkeypatch):
        """Commands flagged by is_dangerous_command are refused (returns False)."""
        def fake_dangerous(cmd):
            return (True, "包含危险关键词")

        monkeypatch.setattr("src.llm_codegen.core.is_dangerous_command", fake_dangerous)

        success = code_generator.execute_command("rm -rf /")
        assert success is False

    def test_execute_command_failure(self, code_generator, monkeypatch):
        """A non-zero exit status makes execute_command return False."""
        def fake_run(cmd, *args, **kwargs):
            return subprocess.CompletedProcess(args=cmd, returncode=1, stdout="", stderr="")

        monkeypatch.setattr(subprocess, "run", fake_run)

        success = code_generator.execute_command("false")
        assert success is False

    def test_run_with_state_resume(self, code_generator, monkeypatch, tmp_path):
        """run() resumes from an existing state file and removes it when done."""
        # State file recording that file1.py has already been generated.
        state_file = tmp_path / ".llm_generator_state.json"
        state_data = {
            "current_file_index": 1,
            "generated_files": ["file1.py"],
            "dependencies_map": {},
            "total_files": 3,
            "output_dir": str(tmp_path),
            "readme_path": "test",
        }
        state_file.write_text(json.dumps(state_data), encoding="utf-8")

        # design.json listing the three project files.
        design_path = tmp_path / "design.json"
        design_data = {
            "project_name": "test",
            "version": "1.0.0",
            "description": "test",
            "files": [
                {"path": "file1.py", "summary": "", "dependencies": [], "functions": [], "classes": []},
                {"path": "file2.py", "summary": "", "dependencies": [], "functions": [], "classes": []},
                {"path": "file3.py", "summary": "", "dependencies": [], "functions": [], "classes": []},
            ],
            "commands": [],
            "check_tools": [],
        }
        design_path.write_text(json.dumps(design_data), encoding="utf-8")

        code_generator.output_dir = tmp_path
        code_generator.state_file = state_file

        # Stub out the internal steps that would read files or call the LLM.
        def fake_parse_readme(path):
            return "# README"

        monkeypatch.setattr(code_generator, "parse_readme", fake_parse_readme)

        def fake_generate_file(file_path, prompt_instruction, dependency_files):
            return ("code", "desc", [])

        monkeypatch.setattr(code_generator, "generate_file", fake_generate_file)

        def fake_execute_command(cmd, cwd=None):
            return True

        monkeypatch.setattr(code_generator, "execute_command", fake_execute_command)

        # Must complete without raising.
        code_generator.run(tmp_path / "README.md")

        # A finished run cleans up its state file.
        assert not state_file.exists()

    def test_run_without_state(self, code_generator, monkeypatch, tmp_path):
        """A first run (no state file) completes without raising."""
        code_generator.output_dir = tmp_path

        def fake_parse_readme(path):
            return "# README"

        monkeypatch.setattr(code_generator, "parse_readme", fake_parse_readme)

        # generate_design_json returns a design with no files to keep the flow minimal.
        fake_design = DesignModel(
            project_name="test",
            version="1.0.0",
            description="test",
            files=[],
            commands=[],
            check_tools=[],
        )

        def fake_generate_design_json():
            return fake_design

        monkeypatch.setattr(code_generator, "generate_design_json", fake_generate_design_json)

        def fake_get_project_structure():
            return [], {}

        monkeypatch.setattr(code_generator, "get_project_structure", fake_get_project_structure)

        # Must complete without raising.
        code_generator.run(tmp_path / "README.md")

    def test_topological_sort_normal(self, code_generator):
        """A linear dependency chain is sorted dependencies-first."""
        files = ["a", "b", "c"]
        dependencies = {"a": ["b"], "b": ["c"], "c": []}
        result = code_generator._topological_sort(files, dependencies)
        # A chain admits exactly one valid order: c, then b, then a.
        assert result == ["c", "b", "a"]
        # Every file's dependencies appear before it.
        for i, node in enumerate(result):
            for dep in dependencies.get(node, []):
                assert result.index(dep) < i

    def test_topological_sort_cycle_detection(self, code_generator):
        """A dependency cycle raises ValueError."""
        files = ["a", "b"]
        dependencies = {"a": ["b"], "b": ["a"]}
        with pytest.raises(ValueError, match="检测到循环依赖"):
            code_generator._topological_sort(files, dependencies)

    def test_topological_sort_empty(self, code_generator):
        """Empty input yields an empty ordering."""
        files = []
        dependencies = {}
        result = code_generator._topological_sort(files, dependencies)
        assert result == []

    def test_topological_sort_partial_deps(self, code_generator):
        """Dependencies that are not in ``files`` are ignored."""
        files = ["a", "b"]
        dependencies = {"a": ["b", "c"], "b": []}  # "c" is not in files
        result = code_generator._topological_sort(files, dependencies)
        # "c" is skipped; only a's dependency on b constrains the order.
        assert result == ["b", "a"]

    def test_topological_sort_complex(self, code_generator):
        """A diamond-shaped dependency graph is ordered consistently."""
        files = ["a", "b", "c", "d"]
        dependencies = {"a": ["b", "c"], "b": ["d"], "c": ["d"], "d": []}
        result = code_generator._topological_sort(files, dependencies)
        # Every dependency precedes its dependent.
        for node in result:
            for dep in dependencies.get(node, []):
                assert result.index(dep) < result.index(node)
        # Every file appears in the result.
        assert set(result) == set(files)