"""Unit tests for the llm_codegen CodeGenerator core class."""
import json
import os
import sys
from datetime import datetime
from pathlib import Path
from unittest.mock import MagicMock, Mock, patch

import pytest

from src.llm_codegen.core import CodeGenerator
from src.llm_codegen.models import ConfigModel, GeneratorState
from src.llm_codegen.utils import is_dangerous_command


class TestCodeGenerator:
    """Unit tests for the CodeGenerator core class."""

    @pytest.fixture
    def mock_openai_client(self):
        """Patch the OpenAI client used by core and yield the mock instance.

        NOTE: the client must be a MagicMock, not a plain Mock. Several
        tests assign through ``create.return_value.choices[0]``, and the
        auto-created children of a plain Mock do not implement
        ``__getitem__`` — subscripting them raises TypeError. MagicMock
        children support subscripting.
        """
        with patch('src.llm_codegen.core.OpenAI') as mock:
            client = MagicMock()
            mock.return_value = client
            yield client

    @pytest.fixture
    def generator(self, mock_openai_client, tmp_path):
        """Build a CodeGenerator backed by a temp dir and the mocked API."""
        output_dir = tmp_path / "output"
        return CodeGenerator(
            api_key="test-api-key",
            base_url="https://api.deepseek.com",
            model="deepseek-reasoner",
            output_dir=str(output_dir),
            resume=False,
        )

    def test_init(self, generator, tmp_path):
        """Constructor stores its arguments and starts with empty state."""
        assert generator.api_key == "test-api-key"
        assert generator.model == "deepseek-reasoner"
        assert generator.output_dir == tmp_path / "output"
        assert generator.resume is False
        assert isinstance(generator.config, ConfigModel)
        assert generator.dangerous_commands == ["rm", "sudo", "chmod", "dd", "mkfs", "> /dev/sda", "format"]
        assert generator.state is None

    def test_parse_readme(self, generator, tmp_path):
        """parse_readme returns the README file content verbatim."""
        readme_path = tmp_path / "README.md"
        readme_content = "# Test Project\nThis is a test README."
        readme_path.write_text(readme_content)

        result = generator.parse_readme(readme_path)
        assert result == readme_content

    def test_get_project_structure(self, generator, mock_openai_client):
        """get_project_structure parses the mocked LLM JSON response."""
        generator.readme_content = "# Test README"
        mock_response = {
            "files": ["src/__init__.py", "src/core.py"],
            "dependencies": {"src/core.py": ["src/__init__.py"]}
        }
        # Requires the MagicMock fixture: choices[0] subscripts a child mock.
        mock_openai_client.chat.completions.create.return_value.choices[0].message.content = json.dumps(mock_response)

        files, dependencies = generator.get_project_structure()
        assert files == ["src/__init__.py", "src/core.py"]
        assert dependencies == {"src/core.py": ["src/__init__.py"]}
        mock_openai_client.chat.completions.create.assert_called_once()

    def test_generate_file(self, generator, mock_openai_client, tmp_path):
        """generate_file returns (code, description, commands) from the LLM."""
        generator.readme_content = "# Test README"
        dep_file = tmp_path / "dep.txt"
        dep_file.write_text("Dependency content")

        mock_response = {
            "code": "print('Hello, World!')",
            "description": "测试文件生成",
            "commands": ["echo 'test'"]
        }
        mock_openai_client.chat.completions.create.return_value.choices[0].message.content = json.dumps(mock_response)

        code, desc, commands = generator.generate_file(
            "test.py",
            "生成测试文件",
            [str(dep_file)]
        )
        assert code == "print('Hello, World!')"
        assert desc == "测试文件生成"
        assert commands == ["echo 'test'"]
        mock_openai_client.chat.completions.create.assert_called_once()

    def test_execute_command_safe(self, generator, tmp_path):
        """A safe command is delegated to subprocess.run with a timeout."""
        with patch('subprocess.run') as mock_run:
            mock_run.return_value.returncode = 0
            mock_run.return_value.stdout = "output"
            mock_run.return_value.stderr = ""

            generator.execute_command("echo 'test'", cwd=tmp_path)
            mock_run.assert_called_once_with(
                "echo 'test'",
                shell=True,
                cwd=tmp_path,
                capture_output=True,
                text=True,
                timeout=300
            )

    def test_execute_command_dangerous(self, generator):
        """Commands on the dangerous-command list are rejected."""
        with pytest.raises(RuntimeError, match="危险命令"):
            generator.execute_command("rm -rf /")  # "rm" is in dangerous_commands

    def test_run_without_resume(self, generator, mock_openai_client, tmp_path):
        """Full run without resume generates every file and saves state."""
        readme_path = tmp_path / "README.md"
        readme_path.write_text("# Test README")
        generator.readme_content = "# Test README"

        # One structure response, then one generation response per file,
        # consumed in order through side_effect.
        mock_structure = {
            "files": ["file1.py", "file2.py"],
            "dependencies": {}
        }
        mock_openai_client.chat.completions.create.side_effect = [
            Mock(choices=[Mock(message=Mock(content=json.dumps(mock_structure)))]),
            Mock(choices=[Mock(message=Mock(content=json.dumps({"code": "code1", "description": "desc1", "commands": []})))]),
            Mock(choices=[Mock(message=Mock(content=json.dumps({"code": "code2", "description": "desc2", "commands": []})))])
        ]

        with patch('src.llm_codegen.core.safe_write_file') as mock_write, \
                patch('src.llm_codegen.core.safe_read_file') as mock_read, \
                patch('src.llm_codegen.core.save_state') as mock_save:
            mock_read.return_value = "content"
            generator.run(readme_path)

        # Both files must be written and the state persisted.
        assert mock_write.call_count == 2
        assert mock_save.called

    def test_run_with_resume(self, generator, mock_openai_client, tmp_path):
        """With resume enabled, files already recorded in state are skipped."""
        generator.resume = True
        generator.state = GeneratorState(generated_files=["file1.py"], executed_commands=[])
        readme_path = tmp_path / "README.md"
        readme_path.write_text("# Test README")
        generator.readme_content = "# Test README"

        mock_structure = {
            "files": ["file1.py", "file2.py"],
            "dependencies": {}
        }
        mock_openai_client.chat.completions.create.return_value.choices[0].message.content = json.dumps(mock_structure)

        with patch('src.llm_codegen.core.safe_write_file') as mock_write, \
                patch('src.llm_codegen.core.safe_read_file') as mock_read:
            mock_read.return_value = "content"
            generator.run(readme_path)

        # Only file2.py should be generated; file1.py is skipped.
        assert mock_write.call_count == 1

    def test_load_config_default(self, generator):
        """_load_config falls back to the documented defaults."""
        config = generator._load_config()
        assert isinstance(config, ConfigModel)
        assert config.check_tools == ["pytest", "pylint", "mypy", "black"]
        assert config.max_retries == 3

    def test_update_state(self, generator, tmp_path):
        """_update_state records new files/commands and persists the state."""
        generator.state_path = tmp_path / "state.json"
        generator.state = GeneratorState()

        with patch('src.llm_codegen.core.save_state') as mock_save:
            generator._update_state("new_file.py", ["cmd1"])
            assert generator.state.generated_files == ["new_file.py"]
            assert generator.state.executed_commands == ["cmd1"]
            mock_save.assert_called_once()
if __name__ == "__main__":
    # Allow running this test module directly (`python <this file>`);
    # normally the suite is collected by pytest instead.
    pytest.main([__file__, "-v"])