116 lines
2.8 KiB
Python
116 lines
2.8 KiB
Python
import pytest
|
|
import os
|
|
import shutil
|
|
from ocrag.db import RagDB
|
|
|
|
|
|
@pytest.fixture
def temp_db(tmpdir):
    """Yield a RagDB constructed with a path inside pytest's ``tmpdir``.

    Once the requesting test has finished, the database path is removed
    from disk if it exists.
    """
    db_path = os.path.join(tmpdir, "test_db.lance")
    yield RagDB(db_path)

    # Teardown: nothing to clean up when the path is not on disk.
    if not os.path.exists(db_path):
        return
    shutil.rmtree(db_path)
|
|
|
|
|
|
def test_db_initialization(temp_db):
    """A newly constructed RagDB must expose a non-None ``table``."""
    table = temp_db.table
    assert table is not None
|
|
|
|
|
|
def test_add_documents(temp_db):
    """Adding one document must leave exactly one row in the table."""
    single_doc = {
        "text": "Test content",
        "vector": [0.1] * 1024,
        "metadata": {"source_file": "test.py"},
    }
    temp_db.add_documents([single_doc])

    # Row count is read from the pandas view of the table.
    row_count = temp_db.table.to_pandas().shape[0]
    assert row_count == 1
|
|
|
|
|
|
def test_search(temp_db):
    """Searching with the stored document's vector must return that document."""
    # Seed the database with a single document.
    stored_doc = {
        "text": "Database configuration",
        "vector": [0.1] * 1024,
        "metadata": {"source_file": "config.py"},
    }
    temp_db.add_documents([stored_doc])

    # Query with a fresh vector holding the same values as the stored one.
    query_vector = [0.1] * 1024
    results = temp_db.search(query_vector)

    assert len(results) == 1
    top_hit = results[0]
    assert "Database configuration" in top_hit["text"]
|
|
|
|
|
|
def test_list_sources(temp_db):
    """``list_sources`` must report every distinct source file that was added."""
    # (text, vector fill value, source file) for each document to insert.
    doc_specs = (
        ("Content 1", 0.1, "file1.py"),
        ("Content 2", 0.2, "file2.py"),
    )
    documents = [
        {
            "text": text,
            "vector": [fill] * 1024,
            "metadata": {"source_file": source_file},
        }
        for text, fill, source_file in doc_specs
    ]
    temp_db.add_documents(documents)

    sources = temp_db.list_sources()

    assert len(sources) == 2
    for expected_source in ("file1.py", "file2.py"):
        assert expected_source in sources
|
|
|
|
|
|
def test_delete_by_source(temp_db):
    """Deleting by source must remove all of that source's chunks and no others."""
    # (text, vector fill value, source file); file1.py appears twice on purpose.
    doc_specs = (
        ("Content A", 0.1, "file1.py"),
        ("Content B", 0.2, "file2.py"),
        ("Content C", 0.3, "file1.py"),
        ("Content D", 0.4, "file3.py"),
    )
    documents = [
        {
            "text": text,
            "vector": [fill] * 1024,
            "metadata": {"source_file": source_file},
        }
        for text, fill, source_file in doc_specs
    ]
    temp_db.add_documents(documents)

    # Initial state: three distinct sources across the four documents.
    sources_before = temp_db.list_sources()
    assert len(sources_before) == 3

    # Both file1.py chunks should be reported as deleted.
    num_deleted = temp_db.delete_by_source("file1.py")
    assert num_deleted == 2

    # After deletion only the two untouched sources remain.
    sources_after = temp_db.list_sources()
    assert len(sources_after) == 2
    assert "file1.py" not in sources_after
    for surviving_source in ("file2.py", "file3.py"):
        assert surviving_source in sources_after
|
|
|
|
|
|
def test_delete_nonexistent_source(temp_db):
    """Deleting a source that was never added must return 0."""
    # Nothing has been added to the database, so no rows can match.
    assert temp_db.delete_by_source("nonexistent.py") == 0
|