"""Tests for the ``RagDB`` wrapper exposed by ``ocrag.db``.

Each test runs against a fresh database created in a per-test temporary
directory, so no state is shared between tests.
"""

import os
import shutil

import pytest

from ocrag.db import RagDB

# Length of the dummy embedding vectors used throughout this module.
_VECTOR_DIM = 1024


def _make_document(text, source_file, fill=0.1):
    """Build one document dict in the shape these tests pass to ``add_documents``.

    Args:
        text: Value stored under the ``"text"`` key.
        source_file: Value stored under ``metadata["source_file"]``.
        fill: Float repeated ``_VECTOR_DIM`` times to form the ``"vector"``.

    Returns:
        A dict with ``"text"``, ``"vector"`` and ``"metadata"`` keys.
    """
    return {
        "text": text,
        "vector": [fill] * _VECTOR_DIM,
        "metadata": {"source_file": source_file},
    }


@pytest.fixture
def temp_db(tmp_path):
    """Yield a ``RagDB`` rooted in a per-test temporary directory.

    Uses pytest's ``tmp_path`` fixture rather than the legacy ``tmpdir``.
    The path handed to ``RagDB`` is still a plain ``str``.
    """
    db_path = os.path.join(tmp_path, "test_db.lance")
    yield RagDB(db_path)
    # Explicit teardown: remove the database directory if it was created.
    if os.path.exists(db_path):
        shutil.rmtree(db_path)


def test_db_initialization(temp_db):
    """A freshly constructed database exposes a non-``None`` ``table``."""
    assert temp_db.table is not None


def test_add_documents(temp_db):
    """Adding a single document results in exactly one row in the table."""
    documents = [_make_document("Test content", "test.py")]
    temp_db.add_documents(documents)

    assert temp_db.table.to_pandas().shape[0] == 1


def test_search(temp_db):
    """Searching with the stored vector returns the one stored document."""
    # Add a document
    documents = [_make_document("Database configuration", "config.py")]
    temp_db.add_documents(documents)

    # Search
    results = temp_db.search([0.1] * _VECTOR_DIM)

    assert len(results) == 1
    assert "Database configuration" in results[0]["text"]


def test_list_sources(temp_db):
    """``list_sources`` reports every distinct source file that was added."""
    # Add documents from multiple sources
    documents = [
        _make_document("Content 1", "file1.py", 0.1),
        _make_document("Content 2", "file2.py", 0.2),
    ]
    temp_db.add_documents(documents)

    sources = temp_db.list_sources()

    assert len(sources) == 2
    assert "file1.py" in sources
    assert "file2.py" in sources


def test_delete_by_source(temp_db):
    """Deleting a source removes its chunks and leaves other sources intact."""
    # Add documents from multiple sources; file1.py contributes two chunks.
    documents = [
        _make_document("Content A", "file1.py", 0.1),
        _make_document("Content B", "file2.py", 0.2),
        _make_document("Content C", "file1.py", 0.3),
        _make_document("Content D", "file3.py", 0.4),
    ]
    temp_db.add_documents(documents)

    # Verify initial state
    sources = temp_db.list_sources()
    assert len(sources) == 3

    # Delete file1.py (should delete 2 chunks)
    num_deleted = temp_db.delete_by_source("file1.py")
    assert num_deleted == 2

    # Verify deletion
    sources = temp_db.list_sources()
    assert len(sources) == 2
    assert "file1.py" not in sources
    assert "file2.py" in sources
    assert "file3.py" in sources


def test_delete_nonexistent_source(temp_db):
    """Deleting a source that was never added reports zero deletions."""
    # Try to delete a source that doesn't exist
    num_deleted = temp_db.delete_by_source("nonexistent.py")
    assert num_deleted == 0