410 lines
10 KiB
Python
410 lines
10 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
代码分块性能测试脚本
|
|
|
|
测试不同分块策略的效果和性能:
|
|
1. langchain - 语言感知分块
|
|
2. semantic - 语义分块
|
|
3. simple - 简单字符分块
|
|
|
|
使用方法:
|
|
uv run python scripts/test_chunking.py
|
|
"""
|
|
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
from typing import List, Dict, Any
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
|
|
|
|
from ocrag.chunker import (
|
|
chunk_code,
|
|
ChunkStrategy,
|
|
get_available_strategies,
|
|
LANGCHAIN_AVAILABLE,
|
|
SEMANTIC_AVAILABLE,
|
|
)
|
|
|
|
|
|
class ChunkingBenchmark:
    """Benchmark the available code-chunking strategies on fixed code samples.

    ``run_all`` iterates over every strategy name returned by
    ``get_available_strategies()``, runs each one over all entries of
    ``TEST_CASES`` via ``run_benchmark``, prints per-sample and aggregate
    figures, and returns the per-strategy result dicts.

    The ``chunk_size`` / ``chunk_overlap`` values forwarded to ``chunk_code``
    default to 512 / 50 and can be overridden per instance.
    """

    # Default chunking parameters. Kept as class attributes so subclasses that
    # override __init__ without calling super() still get usable values.
    chunk_size = 512
    chunk_overlap = 50

    # Code samples to benchmark. Each entry provides the source text ("code"),
    # a file name that is passed to chunk_code ("file"), and
    # "expected_classes": the number of class/struct definitions in the sample.
    # "expected_classes" is informational only; nothing in this class checks it.
    TEST_CASES = {
        "python_small": {
            "code": '''
class UserManager:
    """用户管理类"""
    def __init__(self):
        self.users = {}

    def add_user(self, user_id, name):
        self.users[user_id] = {'name': name, 'active': True}
        return True

    def remove_user(self, user_id):
        if user_id in self.users:
            del self.users[user_id]
            return True
        return False
''',
            "file": "test.py",
            "expected_classes": 1,
        },
        "python_medium": {
            "code": '''
class User:
    """用户基类"""
    def __init__(self, name, email):
        self.name = name
        self.email = email
        self.active = True

    def deactivate(self):
        self.active = False

class Admin(User):
    """管理员类"""
    def __init__(self, name, email, permissions):
        super().__init__(name, email)
        self.permissions = permissions

    def has_permission(self, permission):
        return permission in self.permissions

class Guest(User):
    """访客类"""
    def __init__(self, name, email):
        super().__init__(name, email)
        self.active = False

def create_user(user_type, name, email, **kwargs):
    """用户工厂函数"""
    if user_type == 'admin':
        return Admin(name, email, kwargs.get('permissions', []))
    elif user_type == 'guest':
        return Guest(name, email)
    else:
        return User(name, email)

def validate_email(email):
    """验证邮箱格式"""
    return '@' in email and '.' in email.split('@')[1]

def process_batch(users):
    """批量处理用户"""
    results = []
    for user in users:
        if user.active:
            results.append({
                'name': user.name,
                'email': user.email,
                'status': 'active'
            })
    return results
''',
            "file": "user_manager.py",
            "expected_classes": 3,
        },
        "python_large": {
            "code": '''
import json
from typing import List, Dict, Optional

class DataProcessor:
    """数据处理器"""
    def __init__(self, config):
        self.config = config
        self.cache = {}

    def process(self, data):
        if not self.validate(data):
            raise ValueError("Invalid data format")
        return self.transform(data)

    def validate(self, data):
        return isinstance(data, dict) and 'id' in data

    def transform(self, data):
        return {
            'id': data['id'],
            'processed': True,
            'timestamp': time.time()
        }

class APIClient:
    """API 客户端"""
    def __init__(self, base_url, api_key):
        self.base_url = base_url
        self.api_key = api_key

    def get(self, endpoint):
        url = f"{self.base_url}/{endpoint}"
        headers = {'Authorization': f'Bearer {self.api_key}'}
        return self._request('GET', url, headers)

    def post(self, endpoint, data):
        url = f"{self.base_url}/{endpoint}"
        headers = {
            'Authorization': f'Bearer {self.api_key}',
            'Content-Type': 'application/json'
        }
        return self._request('POST', url, headers, json.dumps(data))

    def _request(self, method, url, headers, body=None):
        # 实现请求逻辑
        pass

class CacheManager:
    """缓存管理器"""
    def __init__(self, max_size=1000):
        self.max_size = max_size
        self.cache = {}
        self.access_order = []

    def get(self, key):
        if key in self.cache:
            self._update_access(key)
            return self.cache[key]
        return None

    def set(self, key, value):
        if len(self.cache) >= self.max_size:
            oldest = self.access_order.pop(0)
            del self.cache[oldest]

        self.cache[key] = value
        self.access_order.append(key)

    def _update_access(self, key):
        if key in self.access_order:
            self.access_order.remove(key)
        self.access_order.append(key)

    def clear(self):
        self.cache.clear()
        self.access_order.clear()
''',
            "file": "large_module.py",
            "expected_classes": 3,
        },
        "rust": {
            "code": """
struct User {
    name: String,
    email: String,
    active: bool,
}

impl User {
    fn new(name: String, email: String) -> Self {
        User {
            name,
            email,
            active: true,
        }
    }

    fn deactivate(&mut self) {
        self.active = false;
    }
}

struct Admin {
    user: User,
    permissions: Vec<String>,
}

impl Admin {
    fn has_permission(&self, permission: &str) -> bool {
        self.permissions.contains(&permission.to_string())
    }
}

fn create_user(name: String, email: String) -> User {
    User::new(name, email)
}

fn process_batch(users: Vec<User>) -> Vec<String> {
    users
        .iter()
        .filter(|u| u.active)
        .map(|u| u.name.clone())
        .collect()
}
""",
            "file": "user.rs",
            "expected_classes": 2,
        },
        "cpp": {
            "code": """
#include <string>
#include <vector>

class User {
private:
    std::string name;
    std::string email;
    bool active;

public:
    User(const std::string& name, const std::string& email)
        : name(name), email(email), active(true) {}

    void deactivate() {
        active = false;
    }

    bool isActive() const {
        return active;
    }
};

class Admin : public User {
private:
    std::vector<std::string> permissions;

public:
    Admin(const std::string& name, const std::string& email)
        : User(name, email) {}

    void addPermission(const std::string& perm) {
        permissions.push_back(perm);
    }
};

void processUsers(std::vector<User>& users) {
    for (auto& user : users) {
        if (user.isActive()) {
            // 处理活跃用户
        }
    }
}
""",
            "file": "user.cpp",
            "expected_classes": 2,
        },
    }

    def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50):
        """Create a benchmark.

        Args:
            chunk_size: Value forwarded as ``chunk_size`` to ``chunk_code``.
            chunk_overlap: Value forwarded as ``chunk_overlap`` to ``chunk_code``.
        """
        # Strategy value -> summary dict; filled in by run_benchmark.
        self.results: Dict[str, Dict[str, Any]] = {}
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def run_benchmark(self, strategy: "ChunkStrategy") -> Dict[str, Any]:
        """Chunk every sample in ``TEST_CASES`` with one strategy and time it.

        Args:
            strategy: The chunking strategy to benchmark.

        Returns:
            A dict with keys ``"strategy"`` (``strategy.value``),
            ``"total_time_ms"``, ``"total_chunks"``, ``"first_init_time_ms"``
            (elapsed time of the first sample, reported as initialization time)
            and ``"details"`` (one dict per sample with ``"name"``,
            ``"chunks"``, ``"time_ms"``, ``"chars"``). The same dict is also
            stored in ``self.results`` under ``strategy.value``.
        """
        print(f"\n{'=' * 60}")
        print(f"测试策略: {strategy.value}")
        print("=" * 60)

        total_chunks = 0
        total_time = 0.0
        first_init_time = None
        details = []

        for name, test_case in self.TEST_CASES.items():
            code = test_case["code"]

            # perf_counter is monotonic and high-resolution; time.time() follows
            # the wall clock (which can jump) and is too coarse on some
            # platforms for millisecond-scale measurements.
            start = time.perf_counter()
            chunks = chunk_code(
                code,
                test_case["file"],
                strategy=strategy,
                chunk_size=self.chunk_size,
                chunk_overlap=self.chunk_overlap,
            )
            elapsed = time.perf_counter() - start

            # The first call for a strategy is reported separately as the
            # "initialization" time: it may include one-off setup cost.
            if first_init_time is None:
                first_init_time = elapsed

            total_time += elapsed
            total_chunks += len(chunks)

            details.append(
                {
                    "name": name,
                    "chunks": len(chunks),
                    "time_ms": elapsed * 1000,
                    "chars": len(code),
                }
            )

            print(f"\n [{name}]")
            print(f" 字符数: {len(code)}")
            print(f" 分块数: {len(chunks)}")
            print(f" 耗时: {elapsed * 1000:.2f}ms")

        print(f"\n总耗时: {total_time * 1000:.2f}ms")
        print(f"总块数: {total_chunks}")

        summary = {
            "strategy": strategy.value,
            "total_time_ms": total_time * 1000,
            "total_chunks": total_chunks,
            "first_init_time_ms": (
                first_init_time * 1000 if first_init_time is not None else 0.0
            ),
            "details": details,
        }
        self.results[strategy.value] = summary
        return summary

    def run_all(self) -> List[Dict[str, Any]]:
        """Benchmark every available strategy, then print a summary table.

        Returns:
            One result dict per strategy (see ``run_benchmark``), in the order
            returned by ``get_available_strategies()``.
        """
        print("=" * 60)
        print("📊 代码分块性能测试")
        print("=" * 60)

        strategies = get_available_strategies()
        print(f"\n可用的分块策略: {strategies}")
        print(f"langchain 可用: {LANGCHAIN_AVAILABLE}")
        print(f"semantic-text-splitter 可用: {SEMANTIC_AVAILABLE}")

        all_results = []
        # Benchmark each strategy in turn.
        for strategy_name in strategies:
            all_results.append(self.run_benchmark(ChunkStrategy(strategy_name)))

        self.print_summary(all_results)
        return all_results

    def print_summary(self, results: List[Dict]) -> None:
        """Print a comparison table plus the fastest / most-chunked strategy.

        Args:
            results: Result dicts as returned by ``run_benchmark``. An empty
                list prints a notice instead of the table (``min``/``max`` would
                otherwise raise ``ValueError``).
        """
        print("\n" + "=" * 60)
        print("📈 测试结果汇总")
        print("=" * 60)

        if not results:
            print("\n(没有可汇总的测试结果)")
            return

        print(f"\n{'策略':<15} {'总耗时':<12} {'首次初始化':<12} {'分块数':<10}")
        print("-" * 50)

        for result in results:
            print(
                f"{result['strategy']:<15} "
                f"{result['total_time_ms']:>8.2f}ms "
                f"{result['first_init_time_ms']:>8.2f}ms "
                f"{result['total_chunks']:>6}"
            )

        # Identify the fastest strategy and the one producing the most chunks.
        fastest = min(results, key=lambda r: r["total_time_ms"])
        most_chunks = max(results, key=lambda r: r["total_chunks"])

        print(
            f"\n🏆 最快策略: {fastest['strategy']} ({fastest['total_time_ms']:.2f}ms)"
        )
        print(
            f"📦 最多分块: {most_chunks['strategy']} ({most_chunks['total_chunks']} 块)"
        )
|
|
|
|
|
|
def main():
    """Script entry point: benchmark all available chunking strategies."""
    # run_all prints its own per-strategy output and summary; the returned
    # results are not needed here.
    ChunkingBenchmark().run_all()

    print("\n" + "=" * 60)
    print("✅ 测试完成!")
    print("=" * 60)


if __name__ == "__main__":
    main()
|