#!/usr/bin/env python3
"""Performance benchmark script for the ocrag write and read paths.

Usage:
    uv run python scripts/benchmark.py
    uv run python scripts/benchmark.py --cleanup  # delete the test data afterwards
"""

import sys
import os
import time
import tempfile
import shutil
from pathlib import Path

# Make the in-repo ``src`` directory importable when run as a plain script.
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

from ocrag.chunker import chunk_code
from ocrag.embedder import embedder
from ocrag.db import RagDB


# Source templates for one generated function per language. ``{i}`` is filled
# in with the function index through ``str.format``, which is why the literal
# braces of the JavaScript template are doubled.
_CODE_TEMPLATES = {
    "python": (
        "def function_{i}():\n"
        "    '''Docstring for function {i}'''\n"
        "    result = 0\n"
        "    for i in range(100):\n"
        "        result += i\n"
        "    return result\n\n"
    ),
    "javascript": (
        "function function_{i}() {{\n"
        "    // Docstring for function {i}\n"
        "    let result = 0;\n"
        "    for (let i = 0; i < 100; i++) {{\n"
        "        result += i;\n"
        "    }}\n"
        "    return result;\n"
        "}}\n\n"
    ),
}


class PerformanceBenchmark:
    """Run chunking, embedding, database-write and search timings against a RagDB."""

    def __init__(self, db_path=None):
        """Open the benchmark database.

        Args:
            db_path: Directory passed to ``RagDB``. When ``None``, a fresh
                temporary directory prefixed ``ocrag_benchmark_`` is created.
        """
        if db_path is None:
            self.db_path = tempfile.mkdtemp(prefix="ocrag_benchmark_")
        else:
            self.db_path = db_path
        self.db = RagDB(self.db_path)

    def cleanup(self):
        """Delete the benchmark database directory if it exists."""
        # NOTE(review): this removes ``self.db_path`` even when it was supplied
        # by the caller (``--db-path``), not only the auto-created temp dir --
        # confirm that deleting a user-provided directory is intended.
        if os.path.exists(self.db_path):
            shutil.rmtree(self.db_path)
            print(f"✅ 已清理测试数据库: {self.db_path}")

    def generate_test_code(self, num_lines=100, language="python"):
        """Generate synthetic source code made of small numbered functions.

        Args:
            num_lines: Number of functions to emit. Despite the name, each
                unit is a whole multi-line function, not a single line.
            language: ``"python"`` or ``"javascript"``.

        Returns:
            The generated source as one string (empty when ``num_lines <= 0``).

        Raises:
            ValueError: If ``language`` has no template.
        """
        try:
            template = _CODE_TEMPLATES[language]
        except KeyError:
            supported = ", ".join(sorted(_CODE_TEMPLATES))
            raise ValueError(
                f"Unsupported language {language!r}; expected one of: {supported}"
            ) from None
        # join() keeps generation linear instead of repeated string +=.
        return "".join(template.format(i=i) for i in range(num_lines))

    @staticmethod
    def _build_documents(chunks, vectors):
        """Pair each chunk with its vector in the record layout given to ``add_documents``."""
        return [
            {
                "text": chunk["text"],
                "vector": vec,
                "metadata": chunk["metadata"],
            }
            for chunk, vec in zip(chunks, vectors)
        ]

    def benchmark_single_file_write(self, num_lines=100):
        """Time chunking, embedding and database write for one generated file.

        Args:
            num_lines: Forwarded to ``generate_test_code``.

        Returns:
            Dict with ``chunk_time``, ``embed_time``, ``db_write_time`` and
            ``total_time`` in milliseconds, plus ``num_chunks``.
        """
        print(f"\n{'=' * 60}")
        print(f"测试: 单个文件写入 ({num_lines} 行代码)")
        print("=" * 60)

        code = self.generate_test_code(num_lines)
        file_path = "test_file.py"

        # Chunking. perf_counter() is monotonic and high resolution, unlike
        # time.time(), which can jump when the system clock is adjusted.
        start_time = time.perf_counter()
        chunks = chunk_code(code, file_path)
        chunk_time = time.perf_counter() - start_time
        print(f"分块耗时: {chunk_time * 1000:.2f} ms")
        print(f"生成块数: {len(chunks)}")

        # Embedding
        start_time = time.perf_counter()
        texts = [c["text"] for c in chunks]
        vectors = embedder.embed(texts)
        embed_time = time.perf_counter() - start_time
        print(f"Embedding 耗时: {embed_time * 1000:.2f} ms")

        # Build the documents (not included in any timed section)
        documents = self._build_documents(chunks, vectors)

        # Write to the database
        start_time = time.perf_counter()
        self.db.add_documents(documents)
        db_write_time = time.perf_counter() - start_time
        print(f"数据库写入耗时: {db_write_time * 1000:.2f} ms")

        total_time = chunk_time + embed_time + db_write_time
        print(f"\n📊 总耗时: {total_time * 1000:.2f} ms")

        return {
            "chunk_time": chunk_time * 1000,
            "embed_time": embed_time * 1000,
            "db_write_time": db_write_time * 1000,
            "total_time": total_time * 1000,
            "num_chunks": len(chunks),
        }

    def benchmark_batch_write(self, num_files=10, lines_per_file=50):
        """Time embedding and database writes across several generated files.

        Chunking time is not measured here; only embedding and database
        writes contribute to the totals.

        Args:
            num_files: Number of files to generate and write.
            lines_per_file: Forwarded to ``generate_test_code`` for each file.

        Returns:
            Dict with ``total_chunks``, ``total_embed_time``,
            ``total_db_write_time`` and ``total_time`` (milliseconds), and
            ``throughput`` in chunks per second (0.0 when nothing was timed).
        """
        print(f"\n{'=' * 60}")
        print(f"测试: 批量写入 ({num_files} 个文件, 每个 {lines_per_file} 行)")
        print("=" * 60)

        total_chunks = 0
        total_embed_time = 0.0
        total_db_write_time = 0.0

        for i in range(num_files):
            code = self.generate_test_code(lines_per_file)
            file_path = f"test_file_{i}.py"

            # Chunking
            chunks = chunk_code(code, file_path)
            total_chunks += len(chunks)

            # Embedding
            texts = [c["text"] for c in chunks]
            start_time = time.perf_counter()
            vectors = embedder.embed(texts)
            total_embed_time += time.perf_counter() - start_time

            # Build the documents and write them
            documents = self._build_documents(chunks, vectors)
            start_time = time.perf_counter()
            self.db.add_documents(documents)
            total_db_write_time += time.perf_counter() - start_time

        total_time = total_embed_time + total_db_write_time

        # The labels below say "per chunk", so average over chunks rather than
        # files; guard the divisions so an empty run reports zeros.
        avg_embed = total_embed_time / total_chunks if total_chunks else 0.0
        avg_db_write = total_db_write_time / total_chunks if total_chunks else 0.0
        throughput = total_chunks / total_time if total_time > 0 else 0.0

        print(f"总块数: {total_chunks}")
        print(f"平均每块 embedding: {avg_embed * 1000:.2f} ms")
        print(f"平均每块数据库写入: {avg_db_write * 1000:.2f} ms")
        print(f"📊 总耗时: {total_time * 1000:.2f} ms")
        print(f"📊 吞吐量: {throughput:.2f} 块/秒")

        return {
            "total_chunks": total_chunks,
            "total_embed_time": total_embed_time * 1000,
            "total_db_write_time": total_db_write_time * 1000,
            "total_time": total_time * 1000,
            "throughput": throughput,
        }

    def benchmark_search(self, num_queries=10):
        """Time query embedding plus database search over a fixed query set.

        Args:
            num_queries: Number of searches to run; the five built-in queries
                are cycled through in order.

        Returns:
            Dict with ``avg_time``, ``min_time`` and ``max_time`` in
            milliseconds (all 0.0 when no query ran) and ``num_queries``.
        """
        print(f"\n{'=' * 60}")
        print(f"测试: 搜索性能 ({num_queries} 次查询)")
        print("=" * 60)

        queries = [
            "how to implement user authentication",
            "database configuration",
            "API endpoint handler",
            "error handling function",
            "data validation logic",
        ]

        search_times = []
        for i in range(num_queries):
            query = queries[i % len(queries)]

            # Query embedding
            start_time = time.perf_counter()
            query_vec = embedder.embed_single(query)
            embed_time = time.perf_counter() - start_time

            # Search
            start_time = time.perf_counter()
            results = self.db.search(query_vec, top_k=5)
            search_time = time.perf_counter() - start_time

            total_time = embed_time + search_time
            search_times.append(total_time)
            print(
                f"查询 {i + 1}: {query[:30]}... → {len(results)} 结果 ({total_time * 1000:.2f} ms)"
            )

        if search_times:
            avg_time = sum(search_times) / len(search_times)
            min_time = min(search_times)
            max_time = max(search_times)
        else:
            # No query ran (num_queries <= 0): report zeros instead of raising.
            avg_time = min_time = max_time = 0.0

        print("\n📊 搜索性能统计:")
        print(f" 平均延迟: {avg_time * 1000:.2f} ms")
        print(f" 最小延迟: {min_time * 1000:.2f} ms")
        print(f" 最大延迟: {max_time * 1000:.2f} ms")

        return {
            "avg_time": avg_time * 1000,
            "min_time": min_time * 1000,
            "max_time": max_time * 1000,
            "num_queries": num_queries,
        }

    def benchmark_list_sources(self):
        """Time a single ``list_sources`` call on the database.

        Returns:
            Dict with ``num_sources`` and ``time`` in milliseconds.
        """
        print(f"\n{'=' * 60}")
        print("测试: list_sources 性能")
        print("=" * 60)

        start_time = time.perf_counter()
        sources = self.db.list_sources()
        elapsed = time.perf_counter() - start_time

        print(f"列出来源数: {len(sources)}")
        print(f"📊 耗时: {elapsed * 1000:.2f} ms")

        return {
            "num_sources": len(sources),
            "time": elapsed * 1000,
        }

    def run_full_benchmark(self):
        """Run every benchmark in order: single write, batch write, search, list."""
        print("\n" + "=" * 60)
        print("🚀 Ocrag 性能基准测试")
        print("=" * 60)

        # Prepare test data (also serves as the single-file write benchmark)
        print("\n📦 准备测试数据...")
        self.benchmark_single_file_write(num_lines=100)

        # Batch write benchmark
        self.benchmark_batch_write(num_files=20, lines_per_file=50)

        # Search benchmark
        self.benchmark_search(num_queries=10)

        # list_sources benchmark
        self.benchmark_list_sources()

        print("\n" + "=" * 60)
        print("✅ 性能测试完成!")
        print("=" * 60)


def main():
    """Parse command-line options, run the full benchmark, optionally clean up."""
    import argparse

    parser = argparse.ArgumentParser(description="Ocrag 性能测试")
    parser.add_argument("--cleanup", action="store_true", help="测试后清理数据")
    parser.add_argument("--db-path", type=str, help="指定数据库路径")
    args = parser.parse_args()

    benchmark = PerformanceBenchmark(db_path=args.db_path)
    try:
        benchmark.run_full_benchmark()
    finally:
        # Clean up even when a benchmark step raised.
        if args.cleanup:
            benchmark.cleanup()


if __name__ == "__main__":
    main()