# ocrag/tests/unit/test_commands.py
# (listing metadata: 158 lines, 5.4 KiB, Python)

import pytest
import os
from click.testing import CliRunner
from ocrag.cli import main
from ocrag.db import RagDB
from ocrag.commands.add import collect_files
@pytest.fixture
def runner():
    """Provide a fresh Click ``CliRunner`` for invoking the ``ocrag`` CLI."""
    cli_runner = CliRunner()
    return cli_runner
@pytest.fixture
def setup_test_env(tmpdir):
    """Populate *tmpdir* with a tiny tree of Python files and return its path.

    Layout created::

        <tmpdir>/test_file.py
        <tmpdir>/test_dir/file1.py

    Returns:
        The temporary directory as a plain ``str``.
    """
    root = str(tmpdir)
    os.makedirs(os.path.join(root, "test_dir"))
    contents = {
        os.path.join(root, "test_dir", "file1.py"): "def test_func():\n pass",
        os.path.join(root, "test_file.py"): "print('Hello')",
    }
    for path, text in contents.items():
        with open(path, "w") as handle:
            handle.write(text)
    return root
def test_add_command(runner, setup_test_env, tmpdir, monkeypatch):
    """`ocrag add <file>` indexes a single file and prints a summary.

    The database location is pointed at a per-test path through the
    ``OCRAG_DB_PATH`` environment variable. ``monkeypatch`` restores the
    variable after the test, so it cannot leak into other tests.
    """
    db_path = os.path.join(tmpdir, "test_db.lance")
    monkeypatch.setenv("OCRAG_DB_PATH", db_path)
    result = runner.invoke(main, ["add", os.path.join(setup_test_env, "test_file.py")])
    # Check the exit code first so a crash is reported together with the CLI
    # output, instead of surfacing as a confusing substring mismatch.
    assert result.exit_code == 0, result.output
    assert "test_file.py" in result.output
    assert "总计添加" in result.output  # "总计添加" ~ "total added"
def test_search_command(runner, setup_test_env, tmpdir, monkeypatch):
    """`ocrag search` finds text from a previously added file and shows its source.

    ``OCRAG_DB_PATH`` is set through ``monkeypatch`` so the per-test database
    path does not leak into other tests.
    """
    db_path = os.path.join(tmpdir, "test_db.lance")
    monkeypatch.setenv("OCRAG_DB_PATH", db_path)
    add_result = runner.invoke(
        main, ["add", os.path.join(setup_test_env, "test_file.py")]
    )
    assert add_result.exit_code == 0, add_result.output
    result = runner.invoke(main, ["search", "Hello"])
    # The search itself must succeed; without this check a failing search
    # could still pass if the error text happened to echo the query.
    assert result.exit_code == 0, result.output
    assert "Hello" in result.output
    assert "来源: " in result.output  # "来源" ~ "source"
def test_list_command(runner, setup_test_env, tmpdir, monkeypatch):
    """`ocrag list` shows a file after it has been added.

    ``OCRAG_DB_PATH`` is set through ``monkeypatch`` so the per-test database
    path does not leak into other tests.
    """
    db_path = os.path.join(tmpdir, "test_db.lance")
    monkeypatch.setenv("OCRAG_DB_PATH", db_path)
    add_result = runner.invoke(
        main, ["add", os.path.join(setup_test_env, "test_file.py")]
    )
    assert add_result.exit_code == 0, add_result.output
    result = runner.invoke(main, ["list"])
    assert result.exit_code == 0, result.output
    assert "test_file.py" in result.output
def test_remove_wildcard(runner, setup_test_env, tmpdir, monkeypatch):
    """`ocrag remove "*.py"` deletes every indexed file matching the pattern.

    ``OCRAG_DB_PATH`` is set through ``monkeypatch`` so the per-test database
    path does not leak into other tests.
    """
    db_path = os.path.join(tmpdir, "test_db.lance")
    monkeypatch.setenv("OCRAG_DB_PATH", db_path)
    for rel_parts in (("test_file.py",), ("test_dir", "file1.py")):
        add_result = runner.invoke(
            main, ["add", os.path.join(setup_test_env, *rel_parts)]
        )
        # Make sure setup really indexed the file; otherwise the "not in list"
        # assertions below would pass vacuously.
        assert add_result.exit_code == 0, add_result.output
    result = runner.invoke(main, ["remove", "*.py"])
    assert result.exit_code == 0, result.output
    assert "已删除" in result.output  # "已删除" ~ "deleted"
    list_result = runner.invoke(main, ["list"])
    # "*.py" matches both files (see test_list_wildcard), so both must be gone.
    assert "test_file.py" not in list_result.output
    assert "file1.py" not in list_result.output
def test_list_wildcard(runner, setup_test_env, tmpdir, monkeypatch):
    """`ocrag list "*.py"` shows every indexed file matching the pattern.

    ``OCRAG_DB_PATH`` is set through ``monkeypatch`` so the per-test database
    path does not leak into other tests.
    """
    db_path = os.path.join(tmpdir, "test_db.lance")
    monkeypatch.setenv("OCRAG_DB_PATH", db_path)
    for rel_parts in (("test_file.py",), ("test_dir", "file1.py")):
        add_result = runner.invoke(
            main, ["add", os.path.join(setup_test_env, *rel_parts)]
        )
        # A failed add would make the listing assertions below misleading.
        assert add_result.exit_code == 0, add_result.output
    result = runner.invoke(main, ["list", "*.py"])
    assert result.exit_code == 0, result.output
    assert "test_file.py" in result.output
    assert "file1.py" in result.output
class TestCollectFiles:
    """Unit tests for ``collect_files`` filtering: patterns, binaries and ignores."""

    @staticmethod
    def _write(path, data):
        """Create *path* containing *data*; ``bytes`` are written in binary mode."""
        mode = "wb" if isinstance(data, bytes) else "w"
        with open(path, mode) as fh:
            fh.write(data)

    @staticmethod
    def _names(files):
        """Return the set of base names of the collected file paths."""
        return {entry.name for entry in files}

    def test_exclude_pattern(self, setup_test_env):
        # A recursive walk with a "test_*" exclusion keeps only non-test files.
        found = self._names(
            collect_files(
                [setup_test_env], recursive=True, exclude_patterns=["**/test_*.py"]
            )
        )
        assert "test_file.py" not in found
        assert "file1.py" in found

    def test_include_pattern(self, setup_test_env):
        # "*.py" should pick up Python files at every depth.
        found = self._names(
            collect_files([setup_test_env], recursive=True, include_patterns=["*.py"])
        )
        assert "test_file.py" in found
        assert "file1.py" in found

    def test_binary_skip(self, tmpdir):
        root = str(tmpdir)
        self._write(os.path.join(root, "binary.bin"), b"\x00\x01\x02binary")
        self._write(os.path.join(root, "text.py"), "print('hello')")
        found = self._names(collect_files([root], recursive=True))
        assert "binary.bin" not in found
        assert "text.py" in found

    def test_gitignore(self, tmpdir):
        src = os.path.join(tmpdir, "src")
        os.makedirs(src)
        self._write(os.path.join(src, "main.py"), "print('hello')")
        self._write(os.path.join(src, "test_main.py"), "def test(): pass")
        # The ignore file lives one level above the directory being scanned.
        self._write(os.path.join(tmpdir, ".gitignore"), "**/test_*.py\n")
        found = self._names(collect_files([src], recursive=True))
        assert "main.py" in found
        assert "test_main.py" not in found

    def test_no_ignore_disables_gitignore(self, setup_test_env):
        self._write(os.path.join(setup_test_env, ".gitignore"), "test_file.py\n")
        found = self._names(
            collect_files([setup_test_env], recursive=True, use_gitignore=False)
        )
        assert "test_file.py" in found

    def test_combined_filters(self, tmpdir):
        lib = os.path.join(tmpdir, "lib")
        os.makedirs(lib)
        self._write(os.path.join(lib, "main.py"), "def main(): pass")
        self._write(os.path.join(lib, "test_main.py"), "def test(): pass")
        self._write(os.path.join(lib, "data.bin"), b"\x00binary")
        found = self._names(
            collect_files(
                [lib],
                recursive=True,
                include_patterns=["*.py"],
                exclude_patterns=["**/test_*.py"],
            )
        )
        assert "main.py" in found
        assert "test_main.py" not in found
        assert "data.bin" not in found