"""Unit tests for ``ocrag.db.RagDB`` (ocrag/tests/unit/test_db.py)."""

import pytest
import os
import shutil
from ocrag.db import RagDB
@pytest.fixture
def temp_db(tmp_path):
    """Yield a ``RagDB`` backed by a fresh database path inside a per-test temp dir.

    The database path is ``<tmp_path>/test_db.lance``. After the test
    finishes, that path is deleted with ``shutil.rmtree`` if it exists.
    """
    # Pass a str (not a pathlib.Path) in case RagDB expects a string path.
    db_path = str(tmp_path / "test_db.lance")
    yield RagDB(db_path)
    # pytest manages tmp_path itself but retains recent runs by default;
    # delete the database eagerly so stale data does not accumulate.
    if os.path.exists(db_path):
        shutil.rmtree(db_path)


def test_db_initialization(temp_db):
    """A newly constructed RagDB exposes a non-None ``table`` handle."""
    table = temp_db.table
    assert table is not None


def test_add_documents(temp_db):
    """Adding a single document results in exactly one stored row."""
    doc = {
        "text": "Test content",
        "vector": [0.1] * 1024,
        "metadata": {"source_file": "test.py"},
    }
    temp_db.add_documents([doc])

    stored = temp_db.table.to_pandas()
    assert len(stored) == 1


def test_search(temp_db):
    """Searching with a stored document's vector returns that single document."""
    stored_doc = {
        "text": "Database configuration",
        "vector": [0.1] * 1024,
        "metadata": {"source_file": "config.py"},
    }
    temp_db.add_documents([stored_doc])

    hits = temp_db.search([0.1] * 1024)

    assert len(hits) == 1
    top_hit = hits[0]
    assert "Database configuration" in top_hit["text"]


def test_list_sources(temp_db):
    """list_sources reports every source file that has stored documents."""
    # (text, vector fill value, source file) for each document to insert.
    rows = [
        ("Content 1", 0.1, "file1.py"),
        ("Content 2", 0.2, "file2.py"),
    ]
    temp_db.add_documents(
        [
            {"text": text, "vector": [fill] * 1024, "metadata": {"source_file": src}}
            for text, fill, src in rows
        ]
    )

    sources = temp_db.list_sources()
    assert len(sources) == 2
    for expected in ("file1.py", "file2.py"):
        assert expected in sources


def test_delete_by_source(temp_db):
    """delete_by_source removes every chunk of one file and leaves the others."""
    # (text, vector fill value, source file); file1.py owns two chunks.
    corpus = [
        ("Content A", 0.1, "file1.py"),
        ("Content B", 0.2, "file2.py"),
        ("Content C", 0.3, "file1.py"),
        ("Content D", 0.4, "file3.py"),
    ]
    temp_db.add_documents(
        [
            {"text": text, "vector": [fill] * 1024, "metadata": {"source_file": src}}
            for text, fill, src in corpus
        ]
    )

    # Sanity check: three distinct sources before deleting anything.
    assert len(temp_db.list_sources()) == 3

    # Both file1.py chunks should be reported as deleted.
    assert temp_db.delete_by_source("file1.py") == 2

    remaining = temp_db.list_sources()
    assert len(remaining) == 2
    assert "file1.py" not in remaining
    for kept in ("file2.py", "file3.py"):
        assert kept in remaining


def test_delete_nonexistent_source(temp_db):
    """Deleting an unknown source removes nothing, on empty and populated tables."""
    # Empty table: there is nothing that could match.
    num_deleted = temp_db.delete_by_source("nonexistent.py")
    assert num_deleted == 0

    # On an empty table, an implementation that ignored the source filter and
    # deleted everything would also report 0. Re-check with data present and
    # confirm that data survives.
    documents = [
        {
            "text": "Keep me",
            "vector": [0.1] * 1024,
            "metadata": {"source_file": "keep.py"},
        }
    ]
    temp_db.add_documents(documents)

    num_deleted = temp_db.delete_by_source("nonexistent.py")
    assert num_deleted == 0

    sources = temp_db.list_sources()
    assert len(sources) == 1
    assert "keep.py" in sources
    assert temp_db.table.to_pandas().shape[0] == 1