"""CLI and file-collection tests for ocrag.

Every CLI test gets its own database through the ``db_path`` fixture, which
sets ``OCRAG_DB_PATH`` via ``monkeypatch`` so the variable is restored after
each test instead of leaking into the rest of the session.
"""

import os

import pytest
from click.testing import CliRunner

from ocrag.cli import main
from ocrag.db import RagDB
from ocrag.commands.add import collect_files


def _invoke(runner, *args):
    """Invoke the ``ocrag`` CLI with *args* and require a zero exit code.

    The assertion message carries the captured output and any exception so a
    failing command is diagnosable from the pytest report alone (otherwise a
    crashed command can make "X not in output" style assertions pass vacuously).
    """
    result = runner.invoke(main, list(args))
    assert result.exit_code == 0, (
        f"ocrag {' '.join(args)} exited with {result.exit_code}\n"
        f"output:\n{result.output}\nexception: {result.exception!r}"
    )
    return result


@pytest.fixture
def runner():
    """Return a fresh Click test runner."""
    return CliRunner()


@pytest.fixture
def setup_test_env(tmp_path):
    """Create a small source tree and return its root directory as a string.

    Layout::

        <root>/test_file.py
        <root>/test_dir/file1.py
    """
    (tmp_path / "test_dir").mkdir()
    (tmp_path / "test_dir" / "file1.py").write_text(
        "def test_func():\n pass", encoding="utf-8"
    )
    (tmp_path / "test_file.py").write_text("print('Hello')", encoding="utf-8")
    return str(tmp_path)


@pytest.fixture
def db_path(tmp_path, monkeypatch):
    """Point ``OCRAG_DB_PATH`` at a per-test database and return that path."""
    path = str(tmp_path / "test_db.lance")
    # monkeypatch restores the previous value (or absence) at teardown; a bare
    # os.environ assignment would leak this path into every later test.
    monkeypatch.setenv("OCRAG_DB_PATH", path)
    return path


def test_add_command(runner, setup_test_env, db_path):
    result = _invoke(runner, "add", os.path.join(setup_test_env, "test_file.py"))
    assert "test_file.py" in result.output
    assert "总计添加" in result.output


def test_search_command(runner, setup_test_env, db_path):
    _invoke(runner, "add", os.path.join(setup_test_env, "test_file.py"))
    result = _invoke(runner, "search", "Hello")
    assert "Hello" in result.output
    assert "来源: " in result.output


def test_list_command(runner, setup_test_env, db_path):
    _invoke(runner, "add", os.path.join(setup_test_env, "test_file.py"))
    result = _invoke(runner, "list")
    assert "test_file.py" in result.output


def test_remove_wildcard(runner, setup_test_env, db_path):
    _invoke(runner, "add", os.path.join(setup_test_env, "test_file.py"))
    _invoke(runner, "add", os.path.join(setup_test_env, "test_dir", "file1.py"))
    result = _invoke(runner, "remove", "*.py")
    assert "已删除" in result.output
    # The database may be empty now; only check that the removed file is gone.
    list_result = runner.invoke(main, ["list"])
    assert "test_file.py" not in list_result.output


def test_list_wildcard(runner, setup_test_env, db_path):
    _invoke(runner, "add", os.path.join(setup_test_env, "test_file.py"))
    _invoke(runner, "add", os.path.join(setup_test_env, "test_dir", "file1.py"))
    result = _invoke(runner, "list", "*.py")
    assert "test_file.py" in result.output
    assert "file1.py" in result.output


class TestCollectFiles:
    """Tests for ``collect_files`` include/exclude, binary and .gitignore filtering."""

    def test_exclude_pattern(self, setup_test_env):
        files = collect_files(
            [setup_test_env], recursive=True, exclude_patterns=["**/test_*.py"]
        )
        names = {f.name for f in files}
        assert "test_file.py" not in names
        assert "file1.py" in names

    def test_include_pattern(self, setup_test_env):
        files = collect_files(
            [setup_test_env], recursive=True, include_patterns=["*.py"]
        )
        names = {f.name for f in files}
        assert "test_file.py" in names
        assert "file1.py" in names

    def test_binary_skip(self, tmp_path):
        # A NUL byte in the content marks the file as binary.
        (tmp_path / "binary.bin").write_bytes(b"\x00\x01\x02binary")
        (tmp_path / "text.py").write_text("print('hello')", encoding="utf-8")
        files = collect_files([str(tmp_path)], recursive=True)
        names = {f.name for f in files}
        assert "binary.bin" not in names
        assert "text.py" in names

    def test_gitignore(self, tmp_path):
        src = tmp_path / "src"
        src.mkdir()
        (src / "main.py").write_text("print('hello')", encoding="utf-8")
        (src / "test_main.py").write_text("def test(): pass", encoding="utf-8")
        # The .gitignore lives one level above the collected directory.
        (tmp_path / ".gitignore").write_text("**/test_*.py\n", encoding="utf-8")
        files = collect_files([str(src)], recursive=True)
        names = {f.name for f in files}
        assert "main.py" in names
        assert "test_main.py" not in names

    def test_no_ignore_disables_gitignore(self, setup_test_env):
        with open(
            os.path.join(setup_test_env, ".gitignore"), "w", encoding="utf-8"
        ) as f:
            f.write("test_file.py\n")
        files = collect_files([setup_test_env], recursive=True, use_gitignore=False)
        names = {f.name for f in files}
        assert "test_file.py" in names

    def test_combined_filters(self, tmp_path):
        lib = tmp_path / "lib"
        lib.mkdir()
        (lib / "main.py").write_text("def main(): pass", encoding="utf-8")
        (lib / "test_main.py").write_text("def test(): pass", encoding="utf-8")
        (lib / "data.bin").write_bytes(b"\x00binary")
        files = collect_files(
            [str(lib)],
            recursive=True,
            include_patterns=["*.py"],
            exclude_patterns=["**/test_*.py"],
        )
        names = {f.name for f in files}
        assert "main.py" in names
        assert "test_main.py" not in names
        assert "data.bin" not in names