llmcodegen/tests/test_core.py

185 lines
7.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pytest
from unittest.mock import Mock, patch, MagicMock
from pathlib import Path
import json
import os
import sys
from datetime import datetime
from src.llm_codegen.core import CodeGenerator
from src.llm_codegen.models import GeneratorState, ConfigModel
from src.llm_codegen.utils import is_dangerous_command
class TestCodeGenerator:
    """Unit tests for the ``CodeGenerator`` core class."""

    @pytest.fixture
    def mock_openai_client(self):
        """Patch ``src.llm_codegen.core.OpenAI`` and yield the fake client.

        The client is a ``MagicMock`` rather than a plain ``Mock`` because
        several tests configure ``...create.return_value.choices[0]``; a plain
        ``Mock`` does not support subscripting and would raise ``TypeError``
        before the code under test is ever exercised.
        """
        with patch('src.llm_codegen.core.OpenAI') as mock:
            client = MagicMock()
            mock.return_value = client
            # Yield inside the ``with`` so the patch stays active for the
            # whole duration of the test that requested this fixture.
            yield client

    @pytest.fixture
    def generator(self, mock_openai_client, tmp_path):
        """Create a ``CodeGenerator`` backed by a temp dir and the mocked API."""
        output_dir = tmp_path / "output"
        return CodeGenerator(
            api_key="test-api-key",
            base_url="https://api.deepseek.com",
            model="deepseek-reasoner",
            output_dir=str(output_dir),
            resume=False,
        )

    def test_init(self, generator, tmp_path):
        """Check the attributes set up by the constructor."""
        assert generator.api_key == "test-api-key"
        assert generator.model == "deepseek-reasoner"
        assert generator.output_dir == tmp_path / "output"
        assert generator.resume is False
        assert isinstance(generator.config, ConfigModel)
        assert generator.dangerous_commands == [
            "rm", "sudo", "chmod", "dd", "mkfs", "> /dev/sda", "format",
        ]
        assert generator.state is None

    def test_parse_readme(self, generator, tmp_path):
        """Check that ``parse_readme`` returns the README file's text."""
        readme_path = tmp_path / "README.md"
        readme_content = "# Test Project\nThis is a test README."
        readme_path.write_text(readme_content)

        result = generator.parse_readme(readme_path)

        assert result == readme_content

    def test_get_project_structure(self, generator, mock_openai_client):
        """Check ``get_project_structure`` against a mocked LLM response."""
        generator.readme_content = "# Test README"
        mock_response = {
            "files": ["src/__init__.py", "src/core.py"],
            "dependencies": {"src/core.py": ["src/__init__.py"]},
        }
        create = mock_openai_client.chat.completions.create
        create.return_value.choices[0].message.content = json.dumps(mock_response)

        files, dependencies = generator.get_project_structure()

        assert files == ["src/__init__.py", "src/core.py"]
        assert dependencies == {"src/core.py": ["src/__init__.py"]}
        create.assert_called_once()

    def test_generate_file(self, generator, mock_openai_client, tmp_path):
        """Check single-file generation with a dependency file and mocked LLM."""
        generator.readme_content = "# Test README"
        dep_file = tmp_path / "dep.txt"
        dep_file.write_text("Dependency content")
        mock_response = {
            "code": "print('Hello, World!')",
            "description": "测试文件生成",
            "commands": ["echo 'test'"],
        }
        create = mock_openai_client.chat.completions.create
        create.return_value.choices[0].message.content = json.dumps(mock_response)

        code, desc, commands = generator.generate_file(
            "test.py",
            "生成测试文件",
            [str(dep_file)],
        )

        assert code == "print('Hello, World!')"
        assert desc == "测试文件生成"
        assert commands == ["echo 'test'"]
        create.assert_called_once()

    def test_execute_command_safe(self, generator, tmp_path):
        """Check that a safe command is forwarded to ``subprocess.run``."""
        with patch('subprocess.run') as mock_run:
            mock_run.return_value.returncode = 0
            mock_run.return_value.stdout = "output"
            mock_run.return_value.stderr = ""

            generator.execute_command("echo 'test'", cwd=tmp_path)

            mock_run.assert_called_once_with(
                "echo 'test'",
                shell=True,
                cwd=tmp_path,
                capture_output=True,
                text=True,
                timeout=300,
            )

    def test_execute_command_dangerous(self, generator):
        """Check that a dangerous command is rejected with ``RuntimeError``."""
        with pytest.raises(RuntimeError, match="危险命令"):
            # Assumes "rm" is in the generator's dangerous-command list.
            generator.execute_command("rm -rf /")

    def test_run_without_resume(self, generator, mock_openai_client, tmp_path):
        """Check the full ``run`` flow with resume disabled."""
        readme_path = tmp_path / "README.md"
        readme_path.write_text("# Test README")
        generator.readme_content = "# Test README"
        # First response answers the project-structure request; the next two
        # answer the per-file generation requests, in order.
        mock_structure = {
            "files": ["file1.py", "file2.py"],
            "dependencies": {},
        }
        file_payloads = [
            {"code": "code1", "description": "desc1", "commands": []},
            {"code": "code2", "description": "desc2", "commands": []},
        ]
        mock_openai_client.chat.completions.create.side_effect = [
            Mock(choices=[Mock(message=Mock(content=json.dumps(payload)))])
            for payload in [mock_structure, *file_payloads]
        ]

        with patch('src.llm_codegen.core.safe_write_file') as mock_write, \
                patch('src.llm_codegen.core.safe_read_file') as mock_read, \
                patch('src.llm_codegen.core.save_state') as mock_save:
            mock_read.return_value = "content"
            generator.run(readme_path)

            # Verify that both files were written and the state was saved.
            assert mock_write.call_count == 2
            assert mock_save.called

    def test_run_with_resume(self, generator, mock_openai_client, tmp_path):
        """Check that ``run`` skips already-generated files when resuming."""
        generator.resume = True
        generator.state = GeneratorState(
            generated_files=["file1.py"],
            executed_commands=[],
        )
        readme_path = tmp_path / "README.md"
        readme_path.write_text("# Test README")
        generator.readme_content = "# Test README"
        mock_structure = {
            "files": ["file1.py", "file2.py"],
            "dependencies": {},
        }
        # NOTE(review): every LLM call returns this structure payload,
        # including any per-file generation call, and it has no "code" key.
        # Confirm that is intended.
        create = mock_openai_client.chat.completions.create
        create.return_value.choices[0].message.content = json.dumps(mock_structure)

        with patch('src.llm_codegen.core.safe_write_file') as mock_write, \
                patch('src.llm_codegen.core.safe_read_file') as mock_read:
            mock_read.return_value = "content"
            generator.run(readme_path)

            # Only file2.py should be generated; file1.py is skipped.
            assert mock_write.call_count == 1

    def test_load_config_default(self, generator):
        """Check the values returned by ``_load_config`` by default."""
        config = generator._load_config()

        assert isinstance(config, ConfigModel)
        assert config.check_tools == ["pytest", "pylint", "mypy", "black"]
        assert config.max_retries == 3

    def test_update_state(self, generator, tmp_path):
        """Check that ``_update_state`` records the file/commands and saves."""
        generator.state_path = tmp_path / "state.json"
        generator.state = GeneratorState()

        # ``save_state`` is patched so only the call itself is verified.
        with patch('src.llm_codegen.core.save_state') as mock_save:
            generator._update_state("new_file.py", ["cmd1"])

            assert generator.state.generated_files == ["new_file.py"]
            assert generator.state.executed_commands == ["cmd1"]
            mock_save.assert_called_once()
if __name__ == "__main__":
    # pytest.main returns the run's exit code; propagate it so that running
    # this file directly reports failures through the process exit status.
    raise SystemExit(pytest.main([__file__, "-v"]))