#!/usr/bin/env python3
|
|
"""
|
|
性能测试脚本 - 测试 ocrag 的写入和读取性能
|
|
|
|
使用方法:
|
|
uv run python scripts/benchmark.py
|
|
uv run python scripts/benchmark.py --cleanup # 测试后清理测试数据
|
|
"""
|
|
|
|
import sys
|
|
import os
|
|
import time
|
|
import tempfile
|
|
import shutil
|
|
from pathlib import Path
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
|
|
|
|
from ocrag.chunker import chunk_code
|
|
from ocrag.embedder import embedder
|
|
from ocrag.db import RagDB
|
|
|
|
|
|
class PerformanceBenchmark:
    """Benchmark harness for ocrag's write (chunk → embed → store) and read paths."""

    def __init__(self, db_path=None):
        """Create a benchmark backed by a RagDB instance.

        Args:
            db_path: Directory for the database; when omitted, a fresh
                temporary directory is created so benchmark data never
                pollutes a real database.
        """
        if db_path is None:
            self.db_path = tempfile.mkdtemp(prefix="ocrag_benchmark_")
        else:
            self.db_path = db_path
        self.db = RagDB(self.db_path)

    def cleanup(self):
        """Remove the benchmark database directory (no-op if already gone)."""
        if os.path.exists(self.db_path):
            shutil.rmtree(self.db_path)
            print(f"✅ 已清理测试数据库: {self.db_path}")

    def generate_test_code(self, num_lines=100, language="python"):
        """Generate synthetic source code made of repeated small functions.

        Note: despite its name, ``num_lines`` is the number of generated
        *functions*; each function template spans several lines.

        Args:
            num_lines: Number of functions to emit.
            language: ``"python"`` or ``"javascript"``.

        Returns:
            The generated source as one string.

        Raises:
            ValueError: For an unsupported ``language`` (the original code
                crashed with UnboundLocalError in that case).
        """
        if language == "python":
            template = (
                "def function_{i}():\n '''Docstring for function {i}'''\n result = 0\n"
                " for i in range(100):\n result += i\n"
                " return result\n\n"
            )
        elif language == "javascript":
            template = (
                "function function_{i}() {{\n // Docstring for function {i}\n"
                " let result = 0;\n"
                " for (let i = 0; i < 100; i++) {{\n result += i;\n }}\n"
                " return result;\n}}\n\n"
            )
        else:
            raise ValueError(f"unsupported language: {language!r}")

        # join() instead of repeated += (avoids quadratic string building).
        return "".join(template.format(i=i) for i in range(num_lines))

    def benchmark_single_file_write(self, num_lines=100):
        """Benchmark the full write path for a single file.

        Args:
            num_lines: Passed through to :meth:`generate_test_code`.

        Returns:
            Dict of per-stage timings in milliseconds plus the chunk count.
        """
        print(f"\n{'=' * 60}")
        print(f"测试: 单个文件写入 ({num_lines} 行代码)")
        print("=" * 60)

        code = self.generate_test_code(num_lines)
        file_path = "test_file.py"

        # Chunking. perf_counter is monotonic and unaffected by wall-clock
        # adjustments, unlike time.time().
        start_time = time.perf_counter()
        chunks = chunk_code(code, file_path)
        chunk_time = time.perf_counter() - start_time

        print(f"分块耗时: {chunk_time * 1000:.2f} ms")
        print(f"生成块数: {len(chunks)}")

        # Embedding all chunk texts in one batch call.
        texts = [c["text"] for c in chunks]
        start_time = time.perf_counter()
        vectors = embedder.embed(texts)
        embed_time = time.perf_counter() - start_time

        print(f"Embedding 耗时: {embed_time * 1000:.2f} ms")

        # Build the document records expected by RagDB.add_documents.
        documents = [
            {
                "text": chunk["text"],
                "vector": vec,
                "metadata": chunk["metadata"],
            }
            for chunk, vec in zip(chunks, vectors)
        ]

        # Database write.
        start_time = time.perf_counter()
        self.db.add_documents(documents)
        db_write_time = time.perf_counter() - start_time

        print(f"数据库写入耗时: {db_write_time * 1000:.2f} ms")

        total_time = chunk_time + embed_time + db_write_time
        print(f"\n📊 总耗时: {total_time * 1000:.2f} ms")

        return {
            "chunk_time": chunk_time * 1000,
            "embed_time": embed_time * 1000,
            "db_write_time": db_write_time * 1000,
            "total_time": total_time * 1000,
            "num_chunks": len(chunks),
        }

    def benchmark_batch_write(self, num_files=10, lines_per_file=50):
        """Benchmark the write path across multiple files.

        Args:
            num_files: Number of synthetic files to index.
            lines_per_file: Functions per file (see generate_test_code).

        Returns:
            Dict with aggregate timings (ms), total chunk count, and
            throughput in chunks per second.
        """
        print(f"\n{'=' * 60}")
        print(f"测试: 批量写入 ({num_files} 个文件, 每个 {lines_per_file} 行)")
        print("=" * 60)

        total_chunks = 0
        total_embed_time = 0
        total_db_write_time = 0

        for i in range(num_files):
            code = self.generate_test_code(lines_per_file)
            file_path = f"test_file_{i}.py"

            # Chunking is not timed here; only embed + DB write are measured.
            chunks = chunk_code(code, file_path)
            total_chunks += len(chunks)

            # Embedding
            texts = [c["text"] for c in chunks]
            start_time = time.perf_counter()
            vectors = embedder.embed(texts)
            total_embed_time += time.perf_counter() - start_time

            # Build documents and write them to the database.
            documents = [
                {
                    "text": chunk["text"],
                    "vector": vec,
                    "metadata": chunk["metadata"],
                }
                for chunk, vec in zip(chunks, vectors)
            ]

            start_time = time.perf_counter()
            self.db.add_documents(documents)
            total_db_write_time += time.perf_counter() - start_time

        total_time = total_embed_time + total_db_write_time

        # Fix: the labels below say "per chunk" (每块) but the original
        # divided by num_files; average over chunks so numbers match labels.
        # Denominator guards avoid ZeroDivisionError when num_files == 0.
        chunk_denom = total_chunks or 1
        time_denom = total_time if total_time > 0 else 1e-9

        print(f"总块数: {total_chunks}")
        print(f"平均每块 embedding: {total_embed_time / chunk_denom * 1000:.2f} ms")
        print(f"平均每块数据库写入: {total_db_write_time / chunk_denom * 1000:.2f} ms")
        print(f"📊 总耗时: {total_time * 1000:.2f} ms")
        print(f"📊 吞吐量: {total_chunks / time_denom:.2f} 块/秒")

        return {
            "total_chunks": total_chunks,
            "total_embed_time": total_embed_time * 1000,
            "total_db_write_time": total_db_write_time * 1000,
            "total_time": total_time * 1000,
            "throughput": total_chunks / time_denom,
        }

    def benchmark_search(self, num_queries=10):
        """Benchmark read latency: query embedding plus vector search.

        Args:
            num_queries: Number of searches to run; queries cycle through
                a fixed sample list.

        Returns:
            Dict with avg/min/max latency (ms) and the query count.
        """
        print(f"\n{'=' * 60}")
        print(f"测试: 搜索性能 ({num_queries} 次查询)")
        print("=" * 60)

        queries = [
            "how to implement user authentication",
            "database configuration",
            "API endpoint handler",
            "error handling function",
            "data validation logic",
        ]

        search_times = []

        for i in range(num_queries):
            query = queries[i % len(queries)]

            # Query embedding
            start_time = time.perf_counter()
            query_vec = embedder.embed_single(query)
            embed_time = time.perf_counter() - start_time

            # Vector search
            start_time = time.perf_counter()
            results = self.db.search(query_vec, top_k=5)
            search_time = time.perf_counter() - start_time

            total_time = embed_time + search_time
            search_times.append(total_time)

            print(
                f"查询 {i + 1}: {query[:30]}... → {len(results)} 结果 ({total_time * 1000:.2f} ms)"
            )

        # Guard: num_queries == 0 crashed the original with ZeroDivisionError.
        if search_times:
            avg_time = sum(search_times) / len(search_times)
            min_time = min(search_times)
            max_time = max(search_times)
        else:
            avg_time = min_time = max_time = 0.0

        print("\n📊 搜索性能统计:")
        print(f" 平均延迟: {avg_time * 1000:.2f} ms")
        print(f" 最小延迟: {min_time * 1000:.2f} ms")
        print(f" 最大延迟: {max_time * 1000:.2f} ms")

        return {
            "avg_time": avg_time * 1000,
            "min_time": min_time * 1000,
            "max_time": max_time * 1000,
            "num_queries": num_queries,
        }

    def benchmark_list_sources(self):
        """Benchmark RagDB.list_sources.

        Returns:
            Dict with the number of sources and the elapsed time in ms.
        """
        print(f"\n{'=' * 60}")
        print("测试: list_sources 性能")
        print("=" * 60)

        start_time = time.perf_counter()
        sources = self.db.list_sources()
        elapsed = time.perf_counter() - start_time

        print(f"列出来源数: {len(sources)}")
        print(f"📊 耗时: {elapsed * 1000:.2f} ms")

        return {
            "num_sources": len(sources),
            "time": elapsed * 1000,
        }

    def run_full_benchmark(self):
        """Run the whole suite: single write, batch write, search, list."""
        print("\n" + "=" * 60)
        print("🚀 Ocrag 性能基准测试")
        print("=" * 60)

        # Seed the database so the read benchmarks have data to search.
        print("\n📦 准备测试数据...")
        self.benchmark_single_file_write(num_lines=100)

        # Batch write
        self.benchmark_batch_write(num_files=20, lines_per_file=50)

        # Search
        self.benchmark_search(num_queries=10)

        # List sources
        self.benchmark_list_sources()

        print("\n" + "=" * 60)
        print("✅ 性能测试完成!")
        print("=" * 60)
|
|
|
|
|
|
def main():
    """Parse command-line options and run the full benchmark suite."""
    import argparse

    arg_parser = argparse.ArgumentParser(description="Ocrag 性能测试")
    arg_parser.add_argument("--cleanup", action="store_true", help="测试后清理数据")
    arg_parser.add_argument("--db-path", type=str, help="指定数据库路径")
    options = arg_parser.parse_args()

    bench = PerformanceBenchmark(db_path=options.db_path)
    try:
        bench.run_full_benchmark()
    finally:
        # Remove the scratch database only when explicitly requested.
        if options.cleanup:
            bench.cleanup()
|
|
|
|
|
|
# Script entry point: run the benchmark when executed directly
# (e.g. `uv run python scripts/benchmark.py`).
if __name__ == "__main__":
    main()
|