feat(代码分块): 使用 langchain 语法感知分块,支持 29 种编程语言

This commit is contained in:
songsenand 2026-04-16 11:41:36 +08:00
parent 28e557594a
commit 9aee24979f
3 changed files with 3 additions and 424 deletions

View File

@ -199,24 +199,12 @@
### 4.1 代码分块算法
**设计思路**:按代码的语义结构进行分块
```python
# Python 语言的分隔符优先级
separators = [
"\n\nclass ", # 类定义
"\n\ndef ", # 函数定义
"\n\nasync def ", # 异步函数
"\n\n", # 空行段落
"\n", # 换行
" ", # 空格
"" # 字符
]
```
**实现**:使用 langchain 的 `RecursiveCharacterTextSplitter`,支持 29 种编程语言的语法感知分块。
**分块参数**
- `chunk_size=2000`:每块约 2000 字符
- `chunk_overlap=200`:块之间保留 200 字符重叠,保证上下文连贯
- `separators`:由 langchain 按语言自动确定,并按语义优先级从高到低依次尝试(类 > 函数 > 空行段落 > 换行 > 空格 > 字符)
### 4.2 向量化策略

View File

@ -11,7 +11,7 @@
- 🚀 **零配置**:嵌入式数据库,安装即用
- 🔒 **本地优先**:所有数据存储在本地,不上传到外部服务器
- 🔍 **语义搜索**:通过自然语言查询相关代码片段
- 📦 **多种语言**:支持 Python、JavaScript、TypeScript、Rust、Go 等主流编程语言
- 📦 **多语言支持**:支持 Python、Rust、C++、Go、JavaScript、TypeScript 等 29 种编程语言
- ⚡ **极速响应**:搜索延迟 < 100ms
- 🤖 **AI 友好**:输出格式专为 AI 消费设计

View File

@ -1,409 +0,0 @@
#!/usr/bin/env python3
"""
代码分块性能测试脚本
测试不同分块策略的效果和性能
1. langchain - 语言感知分块
2. semantic - 语义分块
3. simple - 简单字符分块
使用方法:
uv run python scripts/test_chunking.py
"""
import sys
import time
from pathlib import Path
from typing import List, Dict, Any
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
from ocrag.chunker import (
chunk_code,
ChunkStrategy,
get_available_strategies,
LANGCHAIN_AVAILABLE,
SEMANTIC_AVAILABLE,
)
class ChunkingBenchmark:
    """Benchmark harness that times each chunking strategy on fixed code samples.

    Every strategy is run over all entries of ``TEST_CASES`` through
    ``chunk_code``; per-case and aggregate timings and chunk counts are
    printed to stdout and returned as plain dictionaries.
    """

    # Sample sources keyed by case name. Each entry holds the source text
    # ("code"), the file name handed to ``chunk_code`` ("file") and an
    # "expected_classes" count. No method of this class reads
    # "expected_classes".
    # NOTE(review): the sample bodies carry no leading indentation as seen
    # here - confirm this matches the intended samples.
    TEST_CASES = {
        "python_small": {
            "code": '''
class UserManager:
"""用户管理类"""
def __init__(self):
self.users = {}
def add_user(self, user_id, name):
self.users[user_id] = {'name': name, 'active': True}
return True
def remove_user(self, user_id):
if user_id in self.users:
del self.users[user_id]
return True
return False
''',
            "file": "test.py",
            "expected_classes": 1,
        },
        "python_medium": {
            "code": '''
class User:
"""用户基类"""
def __init__(self, name, email):
self.name = name
self.email = email
self.active = True
def deactivate(self):
self.active = False
class Admin(User):
"""管理员类"""
def __init__(self, name, email, permissions):
super().__init__(name, email)
self.permissions = permissions
def has_permission(self, permission):
return permission in self.permissions
class Guest(User):
"""访客类"""
def __init__(self, name, email):
super().__init__(name, email)
self.active = False
def create_user(user_type, name, email, **kwargs):
"""用户工厂函数"""
if user_type == 'admin':
return Admin(name, email, kwargs.get('permissions', []))
elif user_type == 'guest':
return Guest(name, email)
else:
return User(name, email)
def validate_email(email):
"""验证邮箱格式"""
return '@' in email and '.' in email.split('@')[1]
def process_batch(users):
"""批量处理用户"""
results = []
for user in users:
if user.active:
results.append({
'name': user.name,
'email': user.email,
'status': 'active'
})
return results
''',
            "file": "user_manager.py",
            "expected_classes": 3,
        },
        "python_large": {
            "code": '''
import json
from typing import List, Dict, Optional
class DataProcessor:
"""数据处理器"""
def __init__(self, config):
self.config = config
self.cache = {}
def process(self, data):
if not self.validate(data):
raise ValueError("Invalid data format")
return self.transform(data)
def validate(self, data):
return isinstance(data, dict) and 'id' in data
def transform(self, data):
return {
'id': data['id'],
'processed': True,
'timestamp': time.time()
}
class APIClient:
"""API 客户端"""
def __init__(self, base_url, api_key):
self.base_url = base_url
self.api_key = api_key
def get(self, endpoint):
url = f"{self.base_url}/{endpoint}"
headers = {'Authorization': f'Bearer {self.api_key}'}
return self._request('GET', url, headers)
def post(self, endpoint, data):
url = f"{self.base_url}/{endpoint}"
headers = {
'Authorization': f'Bearer {self.api_key}',
'Content-Type': 'application/json'
}
return self._request('POST', url, headers, json.dumps(data))
def _request(self, method, url, headers, body=None):
# 实现请求逻辑
pass
class CacheManager:
"""缓存管理器"""
def __init__(self, max_size=1000):
self.max_size = max_size
self.cache = {}
self.access_order = []
def get(self, key):
if key in self.cache:
self._update_access(key)
return self.cache[key]
return None
def set(self, key, value):
if len(self.cache) >= self.max_size:
oldest = self.access_order.pop(0)
del self.cache[oldest]
self.cache[key] = value
self.access_order.append(key)
def _update_access(self, key):
if key in self.access_order:
self.access_order.remove(key)
self.access_order.append(key)
def clear(self):
self.cache.clear()
self.access_order.clear()
''',
            "file": "large_module.py",
            "expected_classes": 3,
        },
        "rust": {
            "code": """
struct User {
name: String,
email: String,
active: bool,
}
impl User {
fn new(name: String, email: String) -> Self {
User {
name,
email,
active: true,
}
}
fn deactivate(&mut self) {
self.active = false;
}
}
struct Admin {
user: User,
permissions: Vec<String>,
}
impl Admin {
fn has_permission(&self, permission: &str) -> bool {
self.permissions.contains(&permission.to_string())
}
}
fn create_user(name: String, email: String) -> User {
User::new(name, email)
}
fn process_batch(users: Vec<User>) -> Vec<String> {
users
.iter()
.filter(|u| u.active)
.map(|u| u.name.clone())
.collect()
}
""",
            "file": "user.rs",
            "expected_classes": 2,
        },
        "cpp": {
            "code": """
#include <string>
#include <vector>
class User {
private:
std::string name;
std::string email;
bool active;
public:
User(const std::string& name, const std::string& email)
: name(name), email(email), active(true) {}
void deactivate() {
active = false;
}
bool isActive() const {
return active;
}
};
class Admin : public User {
private:
std::vector<std::string> permissions;
public:
Admin(const std::string& name, const std::string& email)
: User(name, email) {}
void addPermission(const std::string& perm) {
permissions.push_back(perm);
}
};
void processUsers(std::vector<User>& users) {
for (auto& user : users) {
if (user.isActive()) {
// 处理活跃用户
}
}
}
""",
            "file": "user.cpp",
            "expected_classes": 2,
        },
    }

    def __init__(self):
        """Create a benchmark with an empty ``results`` mapping."""
        # Kept for backward compatibility; no method of this class writes to it.
        self.results = {}

    def run_benchmark(
        self,
        strategy: ChunkStrategy,
        *,
        chunk_size: int = 512,
        chunk_overlap: int = 50,
    ) -> Dict[str, Any]:
        """Run one strategy over every sample in ``TEST_CASES`` and report it.

        Args:
            strategy: Strategy forwarded to ``chunk_code``; its ``value`` is
                used as the label in the printed output and the result.
            chunk_size: Forwarded to ``chunk_code`` as ``chunk_size``.
            chunk_overlap: Forwarded to ``chunk_code`` as ``chunk_overlap``.

        Returns:
            A dict with the keys ``strategy``, ``total_time_ms``,
            ``total_chunks``, ``first_init_time_ms`` and ``details`` (one
            entry per sample with ``name``, ``chunks``, ``time_ms`` and
            ``chars``).
        """
        print(f"\n{'=' * 60}")
        print(f"测试策略: {strategy.value}")
        print("=" * 60)

        total_chunks = 0
        total_time = 0.0
        first_init_time = None
        details = []

        for name, test_case in self.TEST_CASES.items():
            code = test_case["code"]

            # perf_counter is monotonic and high-resolution, unlike the
            # wall clock, so short runs are measured reliably.
            started = time.perf_counter()
            chunks = chunk_code(
                code,
                test_case["file"],
                strategy=strategy,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
            )
            elapsed = time.perf_counter() - started

            # The first sample's time is reported separately; presumably it
            # includes one-time initialisation of the strategy - TODO confirm.
            if first_init_time is None:
                first_init_time = elapsed

            chunk_count = len(chunks)
            total_time += elapsed
            total_chunks += chunk_count
            details.append(
                {
                    "name": name,
                    "chunks": chunk_count,
                    "time_ms": elapsed * 1000,
                    "chars": len(code),
                }
            )

            print(f"\n [{name}]")
            print(f" 字符数: {len(code)}")
            print(f" 分块数: {chunk_count}")
            print(f" 耗时: {elapsed * 1000:.2f}ms")

        print(f"\n总耗时: {total_time * 1000:.2f}ms")
        print(f"总块数: {total_chunks}")

        return {
            "strategy": strategy.value,
            "total_time_ms": total_time * 1000,
            "total_chunks": total_chunks,
            # None only when TEST_CASES is empty; report 0 in that case.
            "first_init_time_ms": (
                first_init_time * 1000 if first_init_time is not None else 0
            ),
            "details": details,
        }

    def run_all(self) -> List[Dict[str, Any]]:
        """Benchmark every strategy from ``get_available_strategies()``.

        Prints the availability flags, runs ``run_benchmark`` once per
        strategy name, prints the summary table and returns the list of
        per-strategy result dicts in the order they were run.
        """
        print("=" * 60)
        print("📊 代码分块性能测试")
        print("=" * 60)
        print(f"\n可用的分块策略: {get_available_strategies()}")
        print(f"langchain 可用: {LANGCHAIN_AVAILABLE}")
        print(f"semantic-text-splitter 可用: {SEMANTIC_AVAILABLE}")

        all_results = [
            self.run_benchmark(ChunkStrategy(strategy_name))
            for strategy_name in get_available_strategies()
        ]

        self.print_summary(all_results)
        return all_results

    def print_summary(self, results: List[Dict[str, Any]]) -> None:
        """Print a table of per-strategy totals plus the fastest/most-chunks lines.

        Args:
            results: Result dicts as returned by ``run_benchmark``. May be
                empty, in which case only the table header and a notice are
                printed.
        """
        print("\n" + "=" * 60)
        print("📈 测试结果汇总")
        print("=" * 60)
        print(f"\n{'策略':<15} {'总耗时':<12} {'首次初始化':<12} {'分块数':<10}")
        print("-" * 50)

        # min()/max() raise ValueError on an empty sequence, so stop here
        # when no strategy produced a result.
        if not results:
            print("\n⚠️ 没有可汇总的测试结果")
            return

        for result in results:
            print(
                f"{result['strategy']:<15} "
                f"{result['total_time_ms']:>8.2f}ms "
                f"{result['first_init_time_ms']:>8.2f}ms "
                f"{result['total_chunks']:>6}"
            )

        fastest = min(results, key=lambda item: item["total_time_ms"])
        most_chunks = max(results, key=lambda item: item["total_chunks"])
        print(
            f"\n🏆 最快策略: {fastest['strategy']} ({fastest['total_time_ms']:.2f}ms)"
        )
        print(
            f"📦 最多分块: {most_chunks['strategy']} ({most_chunks['total_chunks']} 块)"
        )
def main() -> None:
    """Run the benchmark for every available strategy, then print a completion banner."""
    # The returned result list is not needed here; run_all already prints
    # the per-strategy details and the summary.
    ChunkingBenchmark().run_all()
    print("\n" + "=" * 60)
    print("✅ 测试完成!")
    print("=" * 60)


# Run the benchmark only when executed as a script, not on import.
if __name__ == "__main__":
    main()