ocrag/scripts/benchmark.py

#!/usr/bin/env python3
"""
性能测试脚本 - 测试 ocrag 的写入和读取性能
使用方法:
uv run python scripts/benchmark.py
uv run python scripts/benchmark.py --cleanup # 测试后清理测试数据
"""
import argparse
import os
import shutil
import sys
import tempfile
import time
from pathlib import Path
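
# Make ocrag importable from src/ without installing the package.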
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
from ocrag.chunker import chunk_code
from ocrag.embedder import embedder
from ocrag.db import RagDB


class PerformanceBenchmark:
    def __init__(self, db_path=None):
        if db_path is None:
            # No path given: benchmark against a throwaway temp directory.
            self.db_path = tempfile.mkdtemp(prefix="ocrag_benchmark_")
        else:
            self.db_path = db_path
        self.db = RagDB(self.db_path)

    def cleanup(self):
        """Remove the test database (destructive: deletes db_path entirely)."""
        if os.path.exists(self.db_path):
            shutil.rmtree(self.db_path)
            print(f"✅ Removed test database: {self.db_path}")

    def generate_test_code(self, num_lines=100, language="python"):
        """Generate synthetic test code: ``num_lines`` copies of a simple
        function (each copy is roughly six source lines)."""
        if language == "python":
            code = (
                "def function_{i}():\n"
                "    '''Docstring for function {i}'''\n"
                "    result = 0\n"
                "    for i in range(100):\n"
                "        result += i\n"
                "    return result\n\n"
            )
        elif language == "javascript":
            code = (
                "function function_{i}() {{\n"
                "    // Docstring for function {i}\n"
                "    let result = 0;\n"
                "    for (let i = 0; i < 100; i++) {{\n"
                "        result += i;\n"
                "    }}\n"
                "    return result;\n"
                "}}\n\n"
            )
        else:
            raise ValueError(f"Unsupported language: {language}")
        full_code = ""
        for i in range(num_lines):
            full_code += code.format(i=i)
        return full_code
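
    # For reference, each generated Python block expands to roughly:
    #
    #   def function_0():
    #       '''Docstring for function 0'''
    #       result = 0
    #       for i in range(100):
    #           result += i
    #       return result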

    def benchmark_single_file_write(self, num_lines=100):
        """Benchmark writing a single file."""
        print(f"\n{'=' * 60}")
        print(f"Test: single-file write ({num_lines} generated functions)")
        print("=" * 60)
        code = self.generate_test_code(num_lines)
        file_path = "test_file.py"

        # Chunking
        start_time = time.time()
        chunks = chunk_code(code, file_path)
        chunk_time = time.time() - start_time
        print(f"Chunking: {chunk_time * 1000:.2f} ms")
        print(f"Chunks produced: {len(chunks)}")

        # Embedding
        start_time = time.time()
        texts = [c["text"] for c in chunks]
        vectors = embedder.embed(texts)
        embed_time = time.time() - start_time
        print(f"Embedding: {embed_time * 1000:.2f} ms")

        # Build documents
        documents = []
        for chunk, vec in zip(chunks, vectors):
            documents.append(
                {
                    "text": chunk["text"],
                    "vector": vec,
                    "metadata": chunk["metadata"],
                }
            )

        # Write to the database
        start_time = time.time()
        self.db.add_documents(documents)
        db_write_time = time.time() - start_time
        print(f"Database write: {db_write_time * 1000:.2f} ms")

        total_time = chunk_time + embed_time + db_write_time
        print(f"\n📊 Total: {total_time * 1000:.2f} ms")
        return {
            "chunk_time": chunk_time * 1000,
            "embed_time": embed_time * 1000,
            "db_write_time": db_write_time * 1000,
            "total_time": total_time * 1000,
            "num_chunks": len(chunks),
        }

    def benchmark_batch_write(self, num_files=10, lines_per_file=50):
        """Benchmark writing a batch of files."""
        print(f"\n{'=' * 60}")
        print(f"Test: batch write ({num_files} files, {lines_per_file} functions each)")
        print("=" * 60)
        total_chunks = 0
        total_embed_time = 0
        total_db_write_time = 0
        for i in range(num_files):
            code = self.generate_test_code(lines_per_file)
            file_path = f"test_file_{i}.py"

            # Chunking
            chunks = chunk_code(code, file_path)
            total_chunks += len(chunks)

            # Embedding
            texts = [c["text"] for c in chunks]
            start_time = time.time()
            vectors = embedder.embed(texts)
            total_embed_time += time.time() - start_time

            # Build documents and write them
            documents = []
            for chunk, vec in zip(chunks, vectors):
                documents.append(
                    {
                        "text": chunk["text"],
                        "vector": vec,
                        "metadata": chunk["metadata"],
                    }
                )
            start_time = time.time()
            self.db.add_documents(documents)
            total_db_write_time += time.time() - start_time

        total_time = total_embed_time + total_db_write_time
        print(f"Total chunks: {total_chunks}")
        # These averages are per file, not per chunk (we divide by num_files).
        print(f"Average embedding per file: {total_embed_time / num_files * 1000:.2f} ms")
        print(f"Average database write per file: {total_db_write_time / num_files * 1000:.2f} ms")
        print(f"📊 Total: {total_time * 1000:.2f} ms")
        # Throughput is over embedding + DB write time; chunking is not timed here.
        print(f"📊 Throughput: {total_chunks / total_time:.2f} chunks/s")
        return {
            "total_chunks": total_chunks,
            "total_embed_time": total_embed_time * 1000,
            "total_db_write_time": total_db_write_time * 1000,
            "total_time": total_time * 1000,
            "throughput": total_chunks / total_time,
        }

    def benchmark_search(self, num_queries=10):
        """Benchmark search performance."""
        print(f"\n{'=' * 60}")
        print(f"Test: search performance ({num_queries} queries)")
        print("=" * 60)
        queries = [
            "how to implement user authentication",
            "database configuration",
            "API endpoint handler",
            "error handling function",
            "data validation logic",
        ]
        search_times = []
        for i in range(num_queries):
            query = queries[i % len(queries)]

            # Query embedding
            start_time = time.time()
            query_vec = embedder.embed_single(query)
            embed_time = time.time() - start_time

            # Search
            start_time = time.time()
            results = self.db.search(query_vec, top_k=5)
            search_time = time.time() - start_time

            # Reported latency includes both query embedding and the DB search.
            total_time = embed_time + search_time
            search_times.append(total_time)
            print(
                f"Query {i + 1}: {query[:30]}... → {len(results)} results ({total_time * 1000:.2f} ms)"
            )
        avg_time = sum(search_times) / len(search_times)
        min_time = min(search_times)
        max_time = max(search_times)
        print("\n📊 Search latency:")
        print(f"  Average: {avg_time * 1000:.2f} ms")
        print(f"  Min: {min_time * 1000:.2f} ms")
        print(f"  Max: {max_time * 1000:.2f} ms")
        return {
            "avg_time": avg_time * 1000,
            "min_time": min_time * 1000,
            "max_time": max_time * 1000,
            "num_queries": num_queries,
        }

    def benchmark_list_sources(self):
        """Benchmark list_sources performance."""
        print(f"\n{'=' * 60}")
        print("Test: list_sources performance")
        print("=" * 60)
        start_time = time.time()
        sources = self.db.list_sources()
        elapsed = time.time() - start_time
        print(f"Sources listed: {len(sources)}")
        print(f"📊 Elapsed: {elapsed * 1000:.2f} ms")
        return {
            "num_sources": len(sources),
            "time": elapsed * 1000,
        }

    def run_full_benchmark(self):
        """Run the full benchmark suite."""
        print("\n" + "=" * 60)
        print("🚀 Ocrag performance benchmark")
        print("=" * 60)

        # Seed the database with test data
        print("\n📦 Preparing test data...")
        self.benchmark_single_file_write(num_lines=100)

        # Batch write test
        self.benchmark_batch_write(num_files=20, lines_per_file=50)

        # Search test
        self.benchmark_search(num_queries=10)

        # List test
        self.benchmark_list_sources()

        print("\n" + "=" * 60)
        print("✅ Benchmark complete!")
        print("=" * 60)


def main():
    parser = argparse.ArgumentParser(description="Ocrag performance benchmark")
    parser.add_argument("--cleanup", action="store_true", help="clean up test data after the run")
    parser.add_argument("--db-path", type=str, help="database path to benchmark against")
    args = parser.parse_args()
    benchmark = PerformanceBenchmark(db_path=args.db_path)
    try:
        benchmark.run_full_benchmark()
    finally:
        if args.cleanup:
            benchmark.cleanup()
if __name__ == "__main__":
main()
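
# Hypothetical ad-hoc usage (a sketch, not part of the CLI):
#   bench = PerformanceBenchmark()                     # fresh temp directory
#   stats = bench.benchmark_single_file_write(num_lines=500)
#   bench.cleanup()                                    # deletes the temp directory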