ocrag/tests/unit/test_commands.py

64 lines
1.8 KiB
Python

import pytest
import os
from click.testing import CliRunner
from ocrag.cli import main
from ocrag.db import RagDB
@pytest.fixture
def runner():
    """Provide a fresh Click ``CliRunner`` for invoking the ``ocrag`` CLI in-process."""
    cli_runner = CliRunner()
    return cli_runner
@pytest.fixture
def setup_test_env(tmpdir):
    """Populate ``tmpdir`` with a tiny tree of Python sources and return it.

    Layout created::

        <tmpdir>/test_dir/file1.py   -- defines ``test_func``
        <tmpdir>/test_file.py        -- a single ``print('Hello')``

    Returns:
        The same ``tmpdir`` object the fixture received.
    """
    os.makedirs(os.path.join(tmpdir, "test_dir"))

    # (path components relative to tmpdir, file contents), in creation order.
    sample_files = (
        (("test_dir", "file1.py"), "def test_func():\n pass"),
        (("test_file.py",), "print('Hello')"),
    )
    for parts, contents in sample_files:
        with open(os.path.join(tmpdir, *parts), "w") as handle:
            handle.write(contents)

    return tmpdir
def test_add_command(runner, setup_test_env, tmpdir):
    """``ocrag add <file>`` indexes the file and reports a total.

    The DB location is supplied through ``OCRAG_DB_PATH`` using
    ``runner.invoke(env=...)``. Click applies that override only for the
    duration of the invocation and then restores the previous environment.
    Writing to ``os.environ`` directly would leave a path into this test's
    (soon deleted) tmpdir set for every later test in the session.
    """
    db_path = os.path.join(tmpdir, "test_db.lance")
    target = os.path.join(setup_test_env, "test_file.py")

    result = runner.invoke(main, ["add", target], env={"OCRAG_DB_PATH": db_path})

    # Check the exit code first so a crash surfaces the CLI output instead of
    # a less informative "substring not found" failure.
    assert result.exit_code == 0, result.output
    assert "test_file.py" in result.output
    # "总计添加" means "total added"; the CLI is expected to print it.
    assert "总计添加" in result.output
def test_search_command(runner, setup_test_env, tmpdir):
    """``ocrag search`` finds text from a previously added file.

    The file is added first, then ``Hello`` (which appears in
    ``test_file.py``) is searched for. ``OCRAG_DB_PATH`` is passed through
    ``runner.invoke(env=...)`` so both invocations share this test's DB
    without leaking the variable into ``os.environ`` for other tests.
    """
    db_path = os.path.join(tmpdir, "test_db.lance")
    env = {"OCRAG_DB_PATH": db_path}

    add_result = runner.invoke(
        main, ["add", os.path.join(setup_test_env, "test_file.py")], env=env
    )
    assert add_result.exit_code == 0, add_result.output

    result = runner.invoke(main, ["search", "Hello"], env=env)
    # Without this check, a crashing search could still pass if its partial
    # output happened to contain the strings below.
    assert result.exit_code == 0, result.output
    assert "Hello" in result.output
    # "来源: " means "source: "; presumably it labels where each hit came from.
    assert "来源: " in result.output
def test_list_command(runner, setup_test_env, tmpdir):
    """``ocrag list`` shows a file after it has been added.

    ``OCRAG_DB_PATH`` is passed through ``runner.invoke(env=...)`` so both
    invocations use this test's DB and the variable is restored afterwards,
    rather than persisting in ``os.environ`` for later tests.
    """
    db_path = os.path.join(tmpdir, "test_db.lance")
    env = {"OCRAG_DB_PATH": db_path}

    add_result = runner.invoke(
        main, ["add", os.path.join(setup_test_env, "test_file.py")], env=env
    )
    assert add_result.exit_code == 0, add_result.output

    result = runner.invoke(main, ["list"], env=env)
    # Check the exit code so a failing `list` is reported as such.
    assert result.exit_code == 0, result.output
    assert "test_file.py" in result.output