158 lines
5.4 KiB
Python
158 lines
5.4 KiB
Python
import pytest
|
|
import os
|
|
from click.testing import CliRunner
|
|
from ocrag.cli import main
|
|
from ocrag.db import RagDB
|
|
from ocrag.commands.add import collect_files
|
|
|
|
|
|
@pytest.fixture
def runner():
    """Provide a fresh Click ``CliRunner`` for invoking the CLI in-process.

    Returns:
        A new ``CliRunner`` instance; each test that requests this fixture
        gets its own runner.
    """
    cli_runner = CliRunner()
    return cli_runner
|
|
|
|
|
|
@pytest.fixture
def setup_test_env(tmpdir):
    """Populate the per-test temporary directory with a tiny Python tree.

    Layout created::

        <tmpdir>/test_dir/file1.py   -> "def test_func():\\n pass"
        <tmpdir>/test_file.py        -> "print('Hello')"

    Returns:
        The temporary directory path as a plain ``str``.
    """
    root = str(tmpdir)
    os.makedirs(os.path.join(root, "test_dir"))

    # (path components relative to root, file contents), written in order.
    fixture_files = (
        (("test_dir", "file1.py"), "def test_func():\n pass"),
        (("test_file.py",), "print('Hello')"),
    )
    for parts, contents in fixture_files:
        with open(os.path.join(root, *parts), "w") as handle:
            handle.write(contents)

    return root
|
|
|
|
|
|
def test_add_command(runner, setup_test_env, tmpdir, monkeypatch):
    """`add <file>` indexes a single file and prints a summary line.

    ``OCRAG_DB_PATH`` is set through ``monkeypatch`` so the variable is
    restored when the test ends. Assigning ``os.environ`` directly would leak
    the path into every test that runs later in the same process.
    """
    db_path = os.path.join(tmpdir, "test_db.lance")
    monkeypatch.setenv("OCRAG_DB_PATH", db_path)

    result = runner.invoke(main, ["add", os.path.join(setup_test_env, "test_file.py")])

    # Check the exit code first, with the CLI output as the failure message,
    # so a crash is reported directly instead of as a missing-substring failure.
    assert result.exit_code == 0, result.output
    assert "test_file.py" in result.output
    # "总计添加" ("total added") -- presumably the add command's summary line.
    assert "总计添加" in result.output
|
|
|
|
|
|
def test_search_command(runner, setup_test_env, tmpdir, monkeypatch):
    """`search <query>` finds content from a previously added file.

    ``OCRAG_DB_PATH`` is set through ``monkeypatch`` so it is restored after
    the test rather than leaking into later tests.
    """
    db_path = os.path.join(tmpdir, "test_db.lance")
    monkeypatch.setenv("OCRAG_DB_PATH", db_path)

    add_result = runner.invoke(
        main, ["add", os.path.join(setup_test_env, "test_file.py")]
    )
    assert add_result.exit_code == 0, add_result.output

    result = runner.invoke(main, ["search", "Hello"])
    # The search command itself must succeed, not merely print something.
    assert result.exit_code == 0, result.output
    assert "Hello" in result.output
    # "来源: " ("Source: ") -- the label shown alongside each hit.
    assert "来源: " in result.output
|
|
|
|
|
|
def test_list_command(runner, setup_test_env, tmpdir, monkeypatch):
    """`list` shows files that have been added to the database.

    ``OCRAG_DB_PATH`` is set through ``monkeypatch`` so it is restored after
    the test rather than leaking into later tests.
    """
    db_path = os.path.join(tmpdir, "test_db.lance")
    monkeypatch.setenv("OCRAG_DB_PATH", db_path)

    add_result = runner.invoke(
        main, ["add", os.path.join(setup_test_env, "test_file.py")]
    )
    assert add_result.exit_code == 0, add_result.output

    result = runner.invoke(main, ["list"])
    assert result.exit_code == 0, result.output
    assert "test_file.py" in result.output
|
|
|
|
|
|
def test_remove_wildcard(runner, setup_test_env, tmpdir, monkeypatch):
    """`remove "*.py"` deletes every matching indexed file.

    Both adds are checked so that a failed add cannot make the "not listed"
    assertions pass vacuously. ``OCRAG_DB_PATH`` is set through
    ``monkeypatch`` so it is restored after the test.
    """
    db_path = os.path.join(tmpdir, "test_db.lance")
    monkeypatch.setenv("OCRAG_DB_PATH", db_path)

    for rel_parts in (("test_file.py",), ("test_dir", "file1.py")):
        add_result = runner.invoke(
            main, ["add", os.path.join(setup_test_env, *rel_parts)]
        )
        assert add_result.exit_code == 0, add_result.output

    result = runner.invoke(main, ["remove", "*.py"])
    assert result.exit_code == 0, result.output
    # "已删除" ("deleted") -- confirmation text from the remove command.
    assert "已删除" in result.output

    list_result = runner.invoke(main, ["list"])
    # Both added files match "*.py", so neither may remain listed.
    assert "test_file.py" not in list_result.output
    assert "file1.py" not in list_result.output
|
|
|
|
|
|
def test_list_wildcard(runner, setup_test_env, tmpdir, monkeypatch):
    """`list "*.py"` shows indexed files at any depth that match the pattern.

    ``OCRAG_DB_PATH`` is set through ``monkeypatch`` so it is restored after
    the test rather than leaking into later tests.
    """
    db_path = os.path.join(tmpdir, "test_db.lance")
    monkeypatch.setenv("OCRAG_DB_PATH", db_path)

    for rel_parts in (("test_file.py",), ("test_dir", "file1.py")):
        add_result = runner.invoke(
            main, ["add", os.path.join(setup_test_env, *rel_parts)]
        )
        assert add_result.exit_code == 0, add_result.output

    result = runner.invoke(main, ["list", "*.py"])
    assert result.exit_code == 0, result.output
    assert "test_file.py" in result.output
    assert "file1.py" in result.output
|
|
|
|
|
|
class TestCollectFiles:
    """Tests for ``collect_files`` filtering: include/exclude globs, binary
    detection, and ``.gitignore`` handling."""

    def test_exclude_pattern(self, setup_test_env):
        """An exclude glob drops matching files and keeps the rest."""
        files = collect_files(
            [setup_test_env], recursive=True, exclude_patterns=["**/test_*.py"]
        )
        names = [f.name for f in files]
        assert "test_file.py" not in names
        assert "file1.py" in names

    def test_include_pattern(self, setup_test_env):
        """An include glob keeps matching files at any depth and drops the rest."""
        # Every file created by ``setup_test_env`` already matches "*.py", so
        # without a non-matching file this test could not tell whether
        # ``include_patterns`` is applied at all.
        with open(os.path.join(setup_test_env, "notes.txt"), "w", encoding="utf-8") as f:
            f.write("plain text notes")

        files = collect_files(
            [setup_test_env], recursive=True, include_patterns=["*.py"]
        )
        names = [f.name for f in files]
        assert "test_file.py" in names
        assert "file1.py" in names
        assert "notes.txt" not in names

    def test_binary_skip(self, tmpdir):
        """Files with binary content (NUL bytes) are skipped; text files are kept."""
        with open(os.path.join(tmpdir, "binary.bin"), "wb") as f:
            f.write(b"\x00\x01\x02binary")
        with open(os.path.join(tmpdir, "text.py"), "w", encoding="utf-8") as f:
            f.write("print('hello')")

        files = collect_files([str(tmpdir)], recursive=True)
        names = [f.name for f in files]
        assert "binary.bin" not in names
        assert "text.py" in names

    def test_gitignore(self, tmpdir):
        """A ``.gitignore`` in a parent of the scanned directory is honoured by default."""
        os.makedirs(os.path.join(tmpdir, "src"))
        with open(os.path.join(tmpdir, "src", "main.py"), "w", encoding="utf-8") as f:
            f.write("print('hello')")
        with open(os.path.join(tmpdir, "src", "test_main.py"), "w", encoding="utf-8") as f:
            f.write("def test(): pass")
        # The ignore file sits one level above the directory passed in.
        with open(os.path.join(tmpdir, ".gitignore"), "w", encoding="utf-8") as f:
            f.write("**/test_*.py\n")

        files = collect_files([os.path.join(tmpdir, "src")], recursive=True)
        names = [f.name for f in files]
        assert "main.py" in names
        assert "test_main.py" not in names

    def test_no_ignore_disables_gitignore(self, setup_test_env):
        """``use_gitignore=False`` makes ``collect_files`` disregard ``.gitignore`` rules."""
        with open(os.path.join(setup_test_env, ".gitignore"), "w", encoding="utf-8") as f:
            f.write("test_file.py\n")

        files = collect_files([setup_test_env], recursive=True, use_gitignore=False)
        names = [f.name for f in files]
        assert "test_file.py" in names

    def test_combined_filters(self, tmpdir):
        """Include globs, exclude globs and binary skipping compose correctly."""
        os.makedirs(os.path.join(tmpdir, "lib"))
        with open(os.path.join(tmpdir, "lib", "main.py"), "w", encoding="utf-8") as f:
            f.write("def main(): pass")
        with open(os.path.join(tmpdir, "lib", "test_main.py"), "w", encoding="utf-8") as f:
            f.write("def test(): pass")
        with open(os.path.join(tmpdir, "lib", "data.bin"), "wb") as f:
            f.write(b"\x00binary")

        files = collect_files(
            [os.path.join(tmpdir, "lib")],
            recursive=True,
            include_patterns=["*.py"],
            exclude_patterns=["**/test_*.py"],
        )
        names = [f.name for f in files]
        assert "main.py" in names
        assert "test_main.py" not in names
        assert "data.bin" not in names
|