#!/usr/bin/env python3
"""
Code chunking benchmark script.

Compares the output and performance of the different chunking strategies:
1. langchain - language-aware chunking
2. semantic - semantic chunking
3. simple - simple character-based chunking

Usage:
    uv run python scripts/test_chunking.py
"""

import sys
import time
from pathlib import Path
from typing import List, Dict, Any

# Make the in-repo ``src`` directory importable before importing ``ocrag``.
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

from ocrag.chunker import (
    chunk_code,
    ChunkStrategy,
    get_available_strategies,
    LANGCHAIN_AVAILABLE,
    SEMANTIC_AVAILABLE,
)


class ChunkingBenchmark:
    """Run every available chunking strategy over a fixed set of code samples.

    For each strategy the benchmark records the number of chunks produced and
    the elapsed time per sample, then prints a per-strategy report and a
    final comparison table.
    """

    # Sample inputs, keyed by a short case name. Each entry holds:
    #   "code"             - the source text handed to ``chunk_code``
    #   "file"             - the file name handed to ``chunk_code``
    #   "expected_classes" - metadata only; not read by the code in this file
    #
    # NOTE(review): the "rust" and "cpp" samples look like they lost their
    # angle-bracket content (e.g. ``Vec`` without type arguments, ``#include``
    # without a header name). They are kept as-is here; confirm against the
    # intended fixtures.
    TEST_CASES = {
        "python_small": {
            "code": '''
class UserManager:
    """用户管理类"""

    def __init__(self):
        self.users = {}

    def add_user(self, user_id, name):
        self.users[user_id] = {'name': name, 'active': True}
        return True

    def remove_user(self, user_id):
        if user_id in self.users:
            del self.users[user_id]
            return True
        return False
''',
            "file": "test.py",
            "expected_classes": 1,
        },
        "python_medium": {
            "code": '''
class User:
    """用户基类"""

    def __init__(self, name, email):
        self.name = name
        self.email = email
        self.active = True

    def deactivate(self):
        self.active = False

class Admin(User):
    """管理员类"""

    def __init__(self, name, email, permissions):
        super().__init__(name, email)
        self.permissions = permissions

    def has_permission(self, permission):
        return permission in self.permissions

class Guest(User):
    """访客类"""

    def __init__(self, name, email):
        super().__init__(name, email)
        self.active = False

def create_user(user_type, name, email, **kwargs):
    """用户工厂函数"""
    if user_type == 'admin':
        return Admin(name, email, kwargs.get('permissions', []))
    elif user_type == 'guest':
        return Guest(name, email)
    else:
        return User(name, email)

def validate_email(email):
    """验证邮箱格式"""
    return '@' in email and '.' in email.split('@')[1]

def process_batch(users):
    """批量处理用户"""
    results = []
    for user in users:
        if user.active:
            results.append({
                'name': user.name,
                'email': user.email,
                'status': 'active'
            })
    return results
''',
            "file": "user_manager.py",
            "expected_classes": 3,
        },
        "python_large": {
            "code": '''
import json
from typing import List, Dict, Optional

class DataProcessor:
    """数据处理器"""

    def __init__(self, config):
        self.config = config
        self.cache = {}

    def process(self, data):
        if not self.validate(data):
            raise ValueError("Invalid data format")
        return self.transform(data)

    def validate(self, data):
        return isinstance(data, dict) and 'id' in data

    def transform(self, data):
        return {
            'id': data['id'],
            'processed': True,
            'timestamp': time.time()
        }

class APIClient:
    """API 客户端"""

    def __init__(self, base_url, api_key):
        self.base_url = base_url
        self.api_key = api_key

    def get(self, endpoint):
        url = f"{self.base_url}/{endpoint}"
        headers = {'Authorization': f'Bearer {self.api_key}'}
        return self._request('GET', url, headers)

    def post(self, endpoint, data):
        url = f"{self.base_url}/{endpoint}"
        headers = {
            'Authorization': f'Bearer {self.api_key}',
            'Content-Type': 'application/json'
        }
        return self._request('POST', url, headers, json.dumps(data))

    def _request(self, method, url, headers, body=None):
        # 实现请求逻辑
        pass

class CacheManager:
    """缓存管理器"""

    def __init__(self, max_size=1000):
        self.max_size = max_size
        self.cache = {}
        self.access_order = []

    def get(self, key):
        if key in self.cache:
            self._update_access(key)
            return self.cache[key]
        return None

    def set(self, key, value):
        if len(self.cache) >= self.max_size:
            oldest = self.access_order.pop(0)
            del self.cache[oldest]
        self.cache[key] = value
        self.access_order.append(key)

    def _update_access(self, key):
        if key in self.access_order:
            self.access_order.remove(key)
        self.access_order.append(key)

    def clear(self):
        self.cache.clear()
        self.access_order.clear()
''',
            "file": "large_module.py",
            "expected_classes": 3,
        },
        "rust": {
            "code": """
struct User {
    name: String,
    email: String,
    active: bool,
}

impl User {
    fn new(name: String, email: String) -> Self {
        User {
            name,
            email,
            active: true,
        }
    }

    fn deactivate(&mut self) {
        self.active = false;
    }
}

struct Admin {
    user: User,
    permissions: Vec,
}

impl Admin {
    fn has_permission(&self, permission: &str) -> bool {
        self.permissions.contains(&permission.to_string())
    }
}

fn create_user(name: String, email: String) -> User {
    User::new(name, email)
}

fn process_batch(users: Vec) -> Vec {
    users
        .iter()
        .filter(|u| u.active)
        .map(|u| u.name.clone())
        .collect()
}
""",
            "file": "user.rs",
            "expected_classes": 2,
        },
        "cpp": {
            "code": """
#include
#include

class User {
private:
    std::string name;
    std::string email;
    bool active;

public:
    User(const std::string& name, const std::string& email)
        : name(name), email(email), active(true) {}

    void deactivate() {
        active = false;
    }

    bool isActive() const {
        return active;
    }
};

class Admin : public User {
private:
    std::vector permissions;

public:
    Admin(const std::string& name, const std::string& email)
        : User(name, email) {}

    void addPermission(const std::string& perm) {
        permissions.push_back(perm);
    }
};

void processUsers(std::vector& users) {
    for (auto& user : users) {
        if (user.isActive()) {
            // 处理活跃用户
        }
    }
}
""",
            "file": "user.cpp",
            "expected_classes": 2,
        },
    }

    def __init__(self):
        """Create a benchmark with an empty ``results`` mapping."""
        self.results = {}

    def run_benchmark(
        self,
        strategy: ChunkStrategy,
        chunk_size: int = 512,
        chunk_overlap: int = 50,
    ) -> Dict[str, Any]:
        """Benchmark a single strategy against every entry in ``TEST_CASES``.

        Prints a per-case report (character count, chunk count, elapsed
        milliseconds) followed by the totals for the strategy.

        Args:
            strategy: The chunking strategy passed to ``chunk_code``.
            chunk_size: Forwarded to ``chunk_code`` as ``chunk_size``.
            chunk_overlap: Forwarded to ``chunk_code`` as ``chunk_overlap``.

        Returns:
            A dict with the keys ``strategy``, ``total_time_ms``,
            ``total_chunks``, ``first_init_time_ms`` and ``details`` (one
            dict per test case with ``name``, ``chunks``, ``time_ms`` and
            ``chars``).
        """
        print(f"\n{'=' * 60}")
        print(f"测试策略: {strategy.value}")
        print("=" * 60)

        total_chunks = 0
        total_time = 0.0
        first_init_time = None
        details = []

        for name, test_case in self.TEST_CASES.items():
            code = test_case["code"]

            # perf_counter is monotonic and high-resolution, which matters
            # for the millisecond-scale durations measured here.
            started = time.perf_counter()
            chunks = chunk_code(
                code,
                test_case["file"],
                strategy=strategy,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
            )
            elapsed = time.perf_counter() - started

            # The first call's duration is reported separately; presumably it
            # includes any one-off initialization cost of the strategy.
            if first_init_time is None:
                first_init_time = elapsed

            chunk_count = len(chunks)
            total_time += elapsed
            total_chunks += chunk_count

            details.append(
                {
                    "name": name,
                    "chunks": chunk_count,
                    "time_ms": elapsed * 1000,
                    "chars": len(code),
                }
            )

            print(f"\n [{name}]")
            print(f" 字符数: {len(code)}")
            print(f" 分块数: {chunk_count}")
            print(f" 耗时: {elapsed * 1000:.2f}ms")

        print(f"\n总耗时: {total_time * 1000:.2f}ms")
        print(f"总块数: {total_chunks}")

        return {
            "strategy": strategy.value,
            "total_time_ms": total_time * 1000,
            "total_chunks": total_chunks,
            # Compare against None so a measured 0.0 is not mistaken for
            # "no measurement".
            "first_init_time_ms": (
                first_init_time * 1000 if first_init_time is not None else 0
            ),
            "details": details,
        }

    def run_all(self) -> List[Dict[str, Any]]:
        """Benchmark every strategy reported by ``get_available_strategies``.

        Prints a banner with the availability flags, runs ``run_benchmark``
        once per strategy, and finishes with ``print_summary``.

        Returns:
            The list of per-strategy result dicts, in the order the
            strategies were reported.
        """
        print("=" * 60)
        print("📊 代码分块性能测试")
        print("=" * 60)

        print(f"\n可用的分块策略: {get_available_strategies()}")
        print(f"langchain 可用: {LANGCHAIN_AVAILABLE}")
        print(f"semantic-text-splitter 可用: {SEMANTIC_AVAILABLE}")

        # Run each available strategy in turn.
        all_results = [
            self.run_benchmark(ChunkStrategy(strategy_name))
            for strategy_name in get_available_strategies()
        ]

        # Print the comparison table.
        self.print_summary(all_results)

        return all_results

    def print_summary(self, results: List[Dict]) -> None:
        """Print a comparison table plus the fastest / most-chunks strategy.

        Args:
            results: Result dicts as returned by ``run_benchmark``. When the
                list is empty only the header and a short notice are printed.
        """
        print("\n" + "=" * 60)
        print("📈 测试结果汇总")
        print("=" * 60)

        # min()/max() raise ValueError on an empty sequence, so bail out
        # early when no strategy produced a result.
        if not results:
            print("\n⚠️ 没有可用的测试结果")
            return

        print(f"\n{'策略':<15} {'总耗时':<12} {'首次初始化':<12} {'分块数':<10}")
        print("-" * 50)

        for result in results:
            print(
                f"{result['strategy']:<15} "
                f"{result['total_time_ms']:>8.2f}ms "
                f"{result['first_init_time_ms']:>8.2f}ms "
                f"{result['total_chunks']:>6}"
            )

        # Pick out the fastest strategy and the one producing most chunks.
        fastest = min(results, key=lambda x: x["total_time_ms"])
        most_chunks = max(results, key=lambda x: x["total_chunks"])

        print(
            f"\n🏆 最快策略: {fastest['strategy']} ({fastest['total_time_ms']:.2f}ms)"
        )
        print(
            f"📦 最多分块: {most_chunks['strategy']} ({most_chunks['total_chunks']} 块)"
        )


def main():
    """Run the full benchmark and print a completion banner."""
    benchmark = ChunkingBenchmark()
    benchmark.run_all()

    print("\n" + "=" * 60)
    print("✅ 测试完成!")
    print("=" * 60)


if __name__ == "__main__":
    main()