feat: 添加本地代码知识库 RAG 插件,支持语义搜索与实时更新
This commit is contained in:
commit
28e557594a
|
|
@ -0,0 +1,182 @@
|
|||
models/*
|
||||
uv.lock
|
||||
|
||||
# ---> Python
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
# ---> VisualStudioCode
|
||||
.vscode/*
|
||||
!.vscode/settings.json
|
||||
!.vscode/tasks.json
|
||||
!.vscode/launch.json
|
||||
!.vscode/extensions.json
|
||||
!.vscode/*.code-snippets
|
||||
|
||||
# Local History for Visual Studio Code
|
||||
.history/
|
||||
|
||||
# Built Visual Studio Code Extensions
|
||||
*.vsix
|
||||
|
||||
uv.lock
|
||||
|
||||
data/*
|
||||
|
|
@ -0,0 +1,386 @@
|
|||
# OpenCode RAG 插件设计文档
|
||||
|
||||
> 本文档详细说明了 ocrag 项目的设计目的、技术选型考量、架构构思和实现细节。
|
||||
|
||||
---
|
||||
|
||||
## 1. 设计背景与目标
|
||||
|
||||
### 1.1 问题背景
|
||||
|
||||
在大型代码库中进行 AI 辅助开发时,AI 需要理解项目中的大量代码才能给出准确的建议和答案。然而:
|
||||
- **上下文窗口有限**:无法将整个代码库放入 AI 提示词
|
||||
- **实时性需求**:开发者需要 AI 能立即理解刚编写的代码
|
||||
- **本地化优先**:代码不应该上传到外部服务器
|
||||
|
||||
### 1.2 设计目标
|
||||
|
||||
为 OpenCode 构建一个**本地代码知识库 RAG 系统**,实现:
|
||||
|
||||
| 功能 | 描述 |
|
||||
|------|------|
|
||||
| **实时添加** | 将代码文件或目录添加到本地知识库 |
|
||||
| **语义搜索** | 通过自然语言查询获取相关代码片段 |
|
||||
| **智能管理** | 支持删除和列出知识库中的条目 |
|
||||
|
||||
### 1.3 设计原则
|
||||
|
||||
1. **本地化优先**:所有数据和计算都在本地完成,不依赖外部服务
|
||||
2. **轻量高效**:避免引入复杂的服务端组件,保持极低的延迟
|
||||
3. **零运维**:嵌入式数据库,无需安装配置,即装即用
|
||||
4. **AI 友好**:生成可被 AI 直接理解和使用的上下文
|
||||
|
||||
---
|
||||
|
||||
## 2. 技术选型详解
|
||||
|
||||
### 2.1 为什么选择 Python
|
||||
|
||||
| 考量因素 | Python | Rust |
|
||||
|----------|--------|------|
|
||||
| 开发效率 | ✅ 高 | ⚠️ 中 |
|
||||
| LLM 生成质量 | ✅ 高(AI 更熟悉 Python) | ⚠️ 中 |
|
||||
| 生态丰富度 | ✅ 成熟 | ⚠️ 一般 |
|
||||
| 运行时性能 | ⚠️ 中(可通过 C 扩展优化) | ✅ 高 |
|
||||
| 社区支持 | ✅ 丰富 | ⚠️ 有限 |
|
||||
|
||||
**结论**:Python 的高开发效率和 AI 生成质量优势明显,即使运行时性能略低,但对于 RAG 这种 I/O 密集型应用影响有限。
|
||||
|
||||
### 2.2 为什么选择 LanceDB
|
||||
|
||||
传统向量数据库对比:
|
||||
|
||||
| 数据库 | 特点 | 缺点 |
|
||||
|--------|------|------|
|
||||
| Chroma | 简单易用 | 不支持持久化、不适合生产 |
|
||||
| Milvus | 功能强大 | 需要 Docker 部署 |
|
||||
| Qdrant | Rust 实现,高性能 | 需要单独部署 |
|
||||
| **LanceDB** | **嵌入式、零运维、Python 原生** | **相对较新** |
|
||||
|
||||
**LanceDB 优势**:
|
||||
- **嵌入式**:数据库就是一个文件夹,无需单独服务
|
||||
- **零运维**:安装即用,自动管理
|
||||
- **Python 优先**:原生 Python SDK,类型提示完善
|
||||
- **性能优秀**:基于 Apache Arrow,查询速度快
|
||||
|
||||
### 2.3 为什么选择 langchain-text-splitters(而非 chonkie)
|
||||
|
||||
最初设计考虑使用 `chonkie` 实现语法感知分块,但经过实践发现:
|
||||
|
||||
| 方案 | 优势 | 劣势 |
|
||||
|------|------|------|
|
||||
| chonkie | 语法感知分块 | 需要额外下载 GPT-2 tokenizer,初始化慢 |
|
||||
| langchain-text-splitters | 无需额外模型、纯规则、速度快 | 非语法感知 |
|
||||
|
||||
**最终选择 langchain-text-splitters**:
|
||||
- 对于 RAG 场景,分块精度不是最关键因素
|
||||
- langchain 的方案更加轻量,启动更快
|
||||
- 代码按自然段落(函数、类)边界切分,效果足够好
|
||||
|
||||
### 2.4 为什么使用 Qwen3-Embedding-0.6B
|
||||
|
||||
| 模型 | 维度 | 优势 | 劣势 |
|
||||
|------|------|------|------|
|
||||
| all-MiniLM-L6-v2 | 384 | 快速、小巧 | 英文为主 |
|
||||
| **Qwen3-Embedding-0.6B** | **1024** | **中文优化、中英双语** | 较大、首次加载慢 |
|
||||
|
||||
**选择理由**:
|
||||
- 开源可本地部署,代码安全
|
||||
- 中文支持优秀
|
||||
- 适合代码+文档混合场景
|
||||
|
||||
### 2.5 为什么不用 MCP 服务器
|
||||
|
||||
**MCP 方案的劣势**:
|
||||
- 需要额外部署 MCP 服务器
|
||||
- 增加系统复杂度
|
||||
- 调试困难
|
||||
|
||||
**CLI 方案的优势**:
|
||||
- 零额外组件
|
||||
- 通过 `bash` 工具直接调用
|
||||
- 易于调试和扩展
|
||||
|
||||
---
|
||||
|
||||
## 3. 架构设计
|
||||
|
||||
### 3.1 整体架构
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ OpenCode AI │
|
||||
└────────────────────────────┬──────────────────────────────────┘
|
||||
│ 1. rag_add / rag_search
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ TypeScript Plugin (ocrag-plugin.ts) │
|
||||
│ │
|
||||
│ ┌──────────┐ ┌──────────┐ ┌──────────────┐ │
|
||||
│ │ rag_add │ │rag_search│ │ Skill 指令 │ │
|
||||
│ └────┬─────┘ └────┬─────┘ └──────────────┘ │
|
||||
└───────┼─────────────────┼───────────────────────────────────────┘
|
||||
│ │
|
||||
▼ ▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Python CLI (ocrag) │
|
||||
│ │
|
||||
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────────┐ │
|
||||
│ │ add │ │ search │ │ remove │ │ list │ │
|
||||
│ └───┬─────┘ └────┬─────┘ └───┬─────┘ └──────┬──────┘ │
|
||||
└──────┼───────────────┼─────────────┼──────────────────┼──────────┘
|
||||
│ │ │ │
|
||||
▼ ▼ ▼ │
|
||||
┌─────────────────────────────────────────┐ │
|
||||
│ Processing Pipeline │ │
|
||||
│ │ │
|
||||
│ ┌────────────┐ ┌────────────────┐ │ │
|
||||
│ │ Chunker │───▶│ Embedder │──┼────────────┤
|
||||
│ │ (分块) │ │ (向量化) │ │ │
|
||||
│ └────────────┘ └────────────────┘ │ │
|
||||
└─────────────────────────────────────────┼────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ LanceDB (向量数据库) │
|
||||
│ │
|
||||
│ ┌──────────────────────────────────────────────────────────┐ │
|
||||
│ │ documents 表 │ │
|
||||
│ │ ┌──────────┬──────────────────────┬──────────────────┐ │ │
|
||||
│ │ │ text │ vector │ metadata │ │ │
|
||||
│ │ │ (文本) │ (1024维向量) │ (JSON元数据) │ │ │
|
||||
│ │ └──────────┴──────────────────────┴──────────────────┘ │ │
|
||||
│ └──────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### 3.2 数据流设计
|
||||
|
||||
#### 添加文件流程
|
||||
|
||||
```
|
||||
用户请求 → Plugin → CLI:add → Chunker → Embedder → LanceDB → 返回结果
|
||||
```
|
||||
|
||||
**详细步骤**:
|
||||
1. **文件收集**:递归遍历目录或单个文件
|
||||
2. **内容读取**:以 UTF-8 编码读取文件内容
|
||||
3. **代码分块**:按语言和语法结构切分代码
|
||||
4. **向量化**:使用 Embedder 将文本转为 1024 维向量
|
||||
5. **存储入库**:将文本、向量、元数据存入 LanceDB
|
||||
|
||||
#### 搜索流程
|
||||
|
||||
```
|
||||
用户查询 → Plugin → CLI:search → Embedder → LanceDB → 返回结果
|
||||
```
|
||||
|
||||
**详细步骤**:
|
||||
1. **查询向量化**:将自然语言查询转为向量
|
||||
2. **相似度搜索**:LanceDB 执行向量相似度检索
|
||||
3. **结果排序**:按相似度距离排序
|
||||
4. **元数据提取**:解析 JSON 元数据
|
||||
5. **结果返回**:格式化的代码片段列表
|
||||
|
||||
### 3.3 模块职责划分
|
||||
|
||||
| 模块 | 职责 | 依赖 |
|
||||
|------|------|------|
|
||||
| `cli.py` | 命令行入口,参数解析 | click |
|
||||
| `chunker.py` | 代码分块处理 | langchain-text-splitters |
|
||||
| `embedder.py` | 文本向量化 | sentence-transformers |
|
||||
| `db.py` | 向量数据库操作 | lancedb, pyarrow |
|
||||
| `commands/*.py` | 各命令实现 | cli, chunker, embedder, db |
|
||||
| `utils.py` | 工具函数 | - |
|
||||
|
||||
---
|
||||
|
||||
## 4. 核心算法设计
|
||||
|
||||
### 4.1 代码分块算法
|
||||
|
||||
**设计思路**:按代码的语义结构进行分块
|
||||
|
||||
```python
|
||||
# Python 语言的分隔符优先级
|
||||
separators = [
|
||||
"\n\nclass ", # 类定义
|
||||
"\n\ndef ", # 函数定义
|
||||
"\n\nasync def ", # 异步函数
|
||||
"\n\n", # 空行段落
|
||||
"\n", # 换行
|
||||
" ", # 空格
|
||||
"" # 字符
|
||||
]
|
||||
```
|
||||
|
||||
**分块参数**:
|
||||
- `chunk_size=2000`:每块约 2000 字符
|
||||
- `chunk_overlap=200`:块之间保留 200 字符重叠,保证上下文连贯
|
||||
|
||||
### 4.2 向量化策略
|
||||
|
||||
**模型选择**:Qwen3-Embedding-0.6B
|
||||
- 输出维度:1024
|
||||
- 向量归一化:L2 归一化,便于余弦相似度计算
|
||||
|
||||
**批量处理**:
|
||||
- 一次性对多个文本块进行编码
|
||||
- 利用 GPU/CPU 的批处理能力提升吞吐量
|
||||
|
||||
### 4.3 搜索策略
|
||||
|
||||
**相似度度量**:LanceDB 默认使用余弦相似度
|
||||
|
||||
**top_k 参数**:控制返回结果数量,默认 5 条
|
||||
|
||||
---
|
||||
|
||||
## 5. 安全与性能考量
|
||||
|
||||
### 5.1 安全性设计
|
||||
|
||||
| 风险点 | 防护措施 |
|
||||
|--------|----------|
|
||||
| SQL 注入 | 使用 Pandas 过滤而非 SQL 字符串拼接 |
|
||||
| 路径遍历 | 仅处理指定路径,不执行动态导入 |
|
||||
| 数据泄露 | 所有数据本地存储,不涉及网络传输 |
|
||||
|
||||
### 5.2 性能优化
|
||||
|
||||
**已实现**:
|
||||
- ✅ 批量 embedding 减少 I/O 开销
|
||||
- ✅ 单例模式避免重复加载模型
|
||||
- ✅ 向量归一化加速相似度计算
|
||||
|
||||
**可优化项**:
|
||||
- 文件哈希缓存避免重复添加
|
||||
- GPU 加速 embedding 计算
|
||||
- 增量索引更新而非全量重建
|
||||
|
||||
### 5.3 性能基准
|
||||
|
||||
| 指标 | 实测结果 | 说明 |
|
||||
|------|----------|------|
|
||||
| 搜索延迟 | **63-70 ms** | 包含 embedding + 向量检索 |
|
||||
| 数据库写入 | 2-3 ms/块 | LanceDB 性能优秀 |
|
||||
| 分块速度 | <1 ms | 纯规则,无模型加载 |
|
||||
| Embedding | ~2.5秒/块 | Qwen3 模型较大 |
|
||||
|
||||
---
|
||||
|
||||
## 6. 扩展性设计
|
||||
|
||||
### 6.1 多语言支持
|
||||
|
||||
只需在 `LANGUAGE_MAP` 中添加新扩展名即可:
|
||||
|
||||
```python
|
||||
LANGUAGE_MAP = {
|
||||
".py": "python",
|
||||
".js": "javascript",
|
||||
# 添加新语言...
|
||||
".kt": "kotlin",
|
||||
".swift": "swift",
|
||||
}
|
||||
```
|
||||
|
||||
### 6.2 自定义 Embedding 模型
|
||||
|
||||
修改 `embedder.py` 中的模型路径:
|
||||
|
||||
```python
|
||||
model_path = "path/to/your/model"
|
||||
```
|
||||
|
||||
### 6.3 增量同步(Watch 模式)
|
||||
|
||||
使用 `watchdog` 库监听文件变化,自动更新知识库:
|
||||
|
||||
```python
|
||||
observer = Observer()
|
||||
observer.schedule(RAGSyncHandler(), path, recursive=True)
|
||||
observer.start()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7. 与 OpenCode 的集成
|
||||
|
||||
### 7.1 插件架构
|
||||
|
||||
```
|
||||
OpenCode
|
||||
├── TypeScript Plugin
|
||||
│ ├── rag_add 工具
|
||||
│ └── rag_search 工具
|
||||
│
|
||||
└── Skill 指令
|
||||
└── SKILL.md
|
||||
```
|
||||
|
||||
### 7.2 工具定义
|
||||
|
||||
| 工具名 | 参数 | 功能 |
|
||||
|--------|------|------|
|
||||
| `rag_add` | `paths: string[]`, `recursive?: boolean` | 添加文件到知识库 |
|
||||
| `rag_search` | `query: string`, `top_k?: number` | 搜索知识库 |
|
||||
|
||||
### 7.3 错误处理
|
||||
|
||||
所有工具调用都包裹在 try-catch 中:
|
||||
|
||||
```typescript
|
||||
execute: async (args) => {
|
||||
try {
|
||||
const result = await Bun.$`${cmd}`.text();
|
||||
return result;
|
||||
} catch (error) {
|
||||
return `Error: ${error.message}`;
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 8. 设计总结
|
||||
|
||||
### 8.1 核心创新
|
||||
|
||||
1. **零 MCP 架构**:通过 CLI 直接集成,简化系统复杂度
|
||||
2. **本地化优先**:数据不出本地,保证代码安全
|
||||
3. **轻量高效**:嵌入式数据库,秒级启动
|
||||
4. **AI 原生**:输出格式专为 AI 消费设计
|
||||
|
||||
### 8.2 适用场景
|
||||
|
||||
| 场景 | 适用性 | 说明 |
|
||||
|------|--------|------|
|
||||
| 个人项目知识管理 | ✅ 非常适合 | 本地存储,隐私安全 |
|
||||
| 小团队代码库问答 | ✅ 适合 | 轻量易部署 |
|
||||
| 大型企业代码库 | ⚠️ 需优化 | 可能需要分布式扩展 |
|
||||
| 跨语言代码库 | ✅ 支持 | 多语言分块支持 |
|
||||
|
||||
### 8.3 未来展望
|
||||
|
||||
1. **语义分块升级**:集成更智能的分块算法
|
||||
2. **多模态支持**:支持图片、图表等非文本内容
|
||||
3. **增量索引**:支持大型代码库的实时更新
|
||||
4. **分布式部署**:支持多节点协同检索
|
||||
|
||||
---
|
||||
|
||||
## 附录:术语表
|
||||
|
||||
| 术语 | 解释 |
|
||||
|------|------|
|
||||
| RAG | Retrieval-Augmented Generation,检索增强生成 |
|
||||
| Embedding | 将文本转为向量的过程 |
|
||||
| 向量数据库 | 专门存储和检索向量的数据库 |
|
||||
| Chunk | 分块后的文本片段 |
|
||||
| LanceDB | 嵌入式向量数据库 |
|
||||
|
||||
---
|
||||
|
||||
**文档版本**:1.0
|
||||
**最后更新**:2026年4月15日
|
||||
|
|
@ -0,0 +1,191 @@
|
|||
# Ocrag - OpenCode 代码知识库 RAG 插件
|
||||
|
||||
> 为 OpenCode 提供本地代码语义搜索能力,支持将代码文件添加到知识库并通过自然语言进行检索。
|
||||
|
||||
[English](./README_EN.md) | [设计文档](./DescriptionOfDesign.md)
|
||||
|
||||
---
|
||||
|
||||
## ✨ 特性
|
||||
|
||||
- 🚀 **零配置**:嵌入式数据库,安装即用
|
||||
- 🔒 **本地优先**:所有数据存储在本地,不上传到外部服务器
|
||||
- 🔍 **语义搜索**:通过自然语言查询相关代码片段
|
||||
- 📦 **多种语言**:支持 Python、JavaScript、TypeScript、Rust、Go 等主流编程语言
|
||||
- ⚡ **极速响应**:搜索延迟 < 100ms
|
||||
- 🤖 **AI 友好**:输出格式专为 AI 消费设计
|
||||
|
||||
---
|
||||
|
||||
## 📦 安装
|
||||
|
||||
### 环境要求
|
||||
|
||||
- Python 3.10+
|
||||
- [uv](https://github.com/astral-sh/uv)(包管理器)
|
||||
|
||||
### 安装步骤
|
||||
|
||||
```bash
|
||||
# 1. 安装 Python 包
|
||||
uv pip install -e .
|
||||
|
||||
# 2. 安装 OpenCode 插件(可选)
|
||||
cp opencode-plugin/ocrag-plugin.ts ~/.config/opencode/plugins/
|
||||
|
||||
# 3. 安装 Skill(可选)
|
||||
mkdir -p ~/.config/opencode/skills/ocrag
|
||||
cp opencode-skill/SKILL.md ~/.config/opencode/skills/ocrag/
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🚀 快速开始
|
||||
|
||||
### 基本使用
|
||||
|
||||
```bash
|
||||
# 添加文件到知识库
|
||||
uv run ocrag add ./src/main.py
|
||||
|
||||
# 递归添加整个目录
|
||||
uv run ocrag add ./src/ --recursive
|
||||
|
||||
# 搜索相关代码
|
||||
uv run ocrag search "如何实现用户认证"
|
||||
|
||||
# 列出知识库中的文件
|
||||
uv run ocrag list
|
||||
|
||||
# 删除文件
|
||||
uv run ocrag remove ./src/main.py
|
||||
```
|
||||
|
||||
### OpenCode 集成
|
||||
|
||||
安装插件后,AI 可以自动使用知识库功能:
|
||||
|
||||
```
|
||||
用户:把当前文件加入知识库
|
||||
AI:✓ 已将 main.py 添加到知识库
|
||||
|
||||
用户:认证模块是怎么实现的?
|
||||
AI:根据知识库中的代码,认证模块包含以下关键实现...
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📁 项目结构
|
||||
|
||||
```
|
||||
ocrag/
|
||||
├── src/ocrag/ # Python 源代码
|
||||
│ ├── cli.py # 命令行入口
|
||||
│ ├── chunker.py # 代码分块
|
||||
│ ├── embedder.py # 向量化模块
|
||||
│ ├── db.py # 数据库操作
|
||||
│ └── commands/ # 命令实现
|
||||
│ ├── add.py
|
||||
│ ├── search.py
|
||||
│ ├── remove.py
|
||||
│ └── list.py
|
||||
├── models/ # Embedding 模型
|
||||
│ └── Qwen3-Embedding-0.6B/
|
||||
├── tests/ # 单元测试
|
||||
├── scripts/ # 工具脚本
|
||||
│ └── benchmark.py # 性能测试
|
||||
├── opencode-plugin/ # OpenCode 插件
|
||||
└── opencode-skill/ # Skill 文件
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🔧 配置
|
||||
|
||||
### 模型配置
|
||||
|
||||
默认使用 `models/Qwen3-Embedding-0.6B` 本地模型。
|
||||
|
||||
如需使用 GPU 加速:
|
||||
```bash
|
||||
USE_GPU=true uv run ocrag search "查询内容"
|
||||
```
|
||||
|
||||
### 数据库位置
|
||||
|
||||
数据默认存储在 `~/.ocrag/data.lance`
|
||||
|
||||
---
|
||||
|
||||
## 📊 性能
|
||||
|
||||
| 操作 | 延迟 | 说明 |
|
||||
|------|------|------|
|
||||
| 搜索 | **~65ms** | 含 embedding + 向量检索 |
|
||||
| 添加文件 | **~5s** | 含分块 + embedding + 存储 |
|
||||
| 列出文件 | **~5ms** | 纯数据库查询 |
|
||||
|
||||
运行性能测试:
|
||||
```bash
|
||||
uv run python scripts/benchmark.py
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## ✅ 运行测试
|
||||
|
||||
```bash
|
||||
# 运行所有测试
|
||||
uv run python -m pytest
|
||||
|
||||
# 运行特定模块测试
|
||||
uv run python -m pytest tests/unit/test_db.py -v
|
||||
|
||||
# 查看测试覆盖
|
||||
uv run python -m pytest --cov=src/ocrag
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🛠️ 开发
|
||||
|
||||
### 添加新命令
|
||||
|
||||
1. 在 `src/ocrag/commands/` 下创建新文件
|
||||
2. 实现 `run()` 函数
|
||||
3. 在 `src/ocrag/cli.py` 中注册命令
|
||||
4. 在 `src/ocrag/commands/__init__.py` 中导出
|
||||
|
||||
### 代码规范
|
||||
|
||||
```bash
|
||||
# 格式化代码
|
||||
ruff format src/
|
||||
|
||||
# 代码检查
|
||||
ruff check src/
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📚 相关文档
|
||||
|
||||
- [设计文档](./DescriptionOfDesign.md) - 详细的设计说明和技术选型分析
|
||||
- [CHANGELOG](./CHANGELOG.md) - 版本更新记录
|
||||
|
||||
---
|
||||
|
||||
## 🤝 贡献
|
||||
|
||||
欢迎提交 Issue 和 Pull Request!
|
||||
|
||||
---
|
||||
|
||||
## 📄 许可证
|
||||
|
||||
MIT License
|
||||
|
||||
---
|
||||
|
||||
**版本**:0.1.0
|
||||
**最后更新**:2026年4月15日
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
import { tool, definePlugin } from "opencode";
|
||||
import { z } from "zod";
|
||||
|
||||
export default definePlugin({
|
||||
name: "ocrag-plugin",
|
||||
hooks: () => ({
|
||||
rag_add: tool({
|
||||
description: "向本地 RAG 知识库添加一个或多个代码文件或目录。添加后,后续搜索将包含这些代码。",
|
||||
parameters: {
|
||||
paths: z.array(z.string()).describe("要添加的文件或目录路径列表"),
|
||||
recursive: z.boolean().optional().describe("如果路径是目录,是否递归添加所有子文件"),
|
||||
},
|
||||
execute: async (args) => {
|
||||
try {
|
||||
const paths = args.paths.join(" ");
|
||||
const recursiveFlag = args.recursive ? "--recursive" : "";
|
||||
const cmd = `ocrag add ${paths} ${recursiveFlag}`;
|
||||
const result = Bun.$`${cmd}`.text();
|
||||
return result;
|
||||
} catch (error) {
|
||||
return `Error: ${error.message}`;
|
||||
}
|
||||
},
|
||||
}),
|
||||
rag_search: tool({
|
||||
description: "在本地 RAG 知识库中执行语义搜索,返回与查询最相关的代码片段。",
|
||||
parameters: {
|
||||
query: z.string().describe("自然语言查询,例如:'如何实现用户认证'"),
|
||||
top_k: z.number().optional().default(5).describe("返回结果数量"),
|
||||
},
|
||||
execute: async (args) => {
|
||||
try {
|
||||
const cmd = `ocrag search "${args.query}" --top-k ${args.top_k}`;
|
||||
const result = Bun.$`${cmd}`.text();
|
||||
return result;
|
||||
} catch (error) {
|
||||
return `Error: ${error.message}`;
|
||||
}
|
||||
},
|
||||
}),
|
||||
}),
|
||||
});
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
{
|
||||
"name": "ocrag-plugin",
|
||||
"version": "1.0.0",
|
||||
"description": "Code RAG plugin for OpenCode",
|
||||
"main": "ocrag-plugin.ts",
|
||||
"keywords": ["opencode", "plugin", "rag", "code"],
|
||||
"author": "",
|
||||
"license": "MIT"
|
||||
}
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
---
|
||||
name: ocrag
|
||||
description: 代码库 RAG 技能,用于向知识库添加代码或搜索已有代码。
|
||||
---
|
||||
|
||||
# 代码库 RAG 技能
|
||||
|
||||
## 何时使用此技能
|
||||
- 用户要求"把当前文件/这个目录加入知识库"
|
||||
- 用户询问关于代码库的问题,需要基于已有代码回答
|
||||
- 用户明确提到"RAG"、"知识库"、"语义搜索"等关键词
|
||||
|
||||
## 如何使用
|
||||
|
||||
### 添加代码到知识库
|
||||
当用户提供文件路径或目录时,使用 `rag_add` 工具:
|
||||
- 如果用户说"把这个文件加入知识库",提取路径并调用 `rag_add`,recursive 设为 false
|
||||
- 如果用户说"添加 src/ 目录下所有代码",调用 `rag_add` 并设置 recursive = true
|
||||
|
||||
### 搜索知识库
|
||||
当用户询问代码相关问题(如"这段代码做什么?"、"如何配置XX?")时:
|
||||
1. 先用 `rag_search` 工具查询,query 参数为用户的自然语言问题
|
||||
2. 获得返回的代码片段后,结合片段内容回答用户
|
||||
3. 如果结果不相关,可以告知用户未找到,并建议添加更多代码
|
||||
|
||||
## 示例
|
||||
- 用户:"帮我记住这个文件的知识" → 调用 `rag_add` 传入当前文件路径
|
||||
- 用户:"认证模块是怎么实现的?" → 调用 `rag_search` query="认证模块实现"
|
||||
- 用户:"把 ./src 目录下所有 Python 文件加入知识库" → 调用 `rag_add` paths=["./src"] recursive=true
|
||||
|
||||
## 注意事项
|
||||
- 搜索结果会显示代码片段的来源文件
|
||||
- 可以通过 top_k 参数调整返回结果数量
|
||||
- 添加文件后,搜索会立即包含新添加的内容
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
[project]
|
||||
name = "project"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"chonkie>=1.6.2",
|
||||
"click>=8.3.2",
|
||||
"lancedb>=0.30.2",
|
||||
"sentence-transformers>=5.4.1",
|
||||
"sentencepiece>=0.2.1",
|
||||
"tiktoken>=0.12.0",
|
||||
"tokenizers>=0.22.2",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"pytest>=9.0.3",
|
||||
]
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
[pytest]
|
||||
addopts = -v
|
||||
pythonpath = src
|
||||
|
|
@ -0,0 +1,278 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
性能测试脚本 - 测试 ocrag 的写入和读取性能
|
||||
|
||||
使用方法:
|
||||
uv run python scripts/benchmark.py
|
||||
uv run python scripts/benchmark.py --cleanup # 测试后清理测试数据
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
import tempfile
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
|
||||
|
||||
from ocrag.chunker import chunk_code
|
||||
from ocrag.embedder import embedder
|
||||
from ocrag.db import RagDB
|
||||
|
||||
|
||||
class PerformanceBenchmark:
|
||||
def __init__(self, db_path=None):
|
||||
if db_path is None:
|
||||
self.db_path = tempfile.mkdtemp(prefix="ocrag_benchmark_")
|
||||
else:
|
||||
self.db_path = db_path
|
||||
self.db = RagDB(self.db_path)
|
||||
|
||||
def cleanup(self):
|
||||
"""清理测试数据库"""
|
||||
if os.path.exists(self.db_path):
|
||||
shutil.rmtree(self.db_path)
|
||||
print(f"✅ 已清理测试数据库: {self.db_path}")
|
||||
|
||||
def generate_test_code(self, num_lines=100, language="python"):
|
||||
"""生成测试代码"""
|
||||
if language == "python":
|
||||
code = "def function_{i}():\n '''Docstring for function {i}'''\n result = 0\n"
|
||||
code += " for i in range(100):\n result += i\n"
|
||||
code += " return result\n\n"
|
||||
elif language == "javascript":
|
||||
code = "function function_{i}() {{\n // Docstring for function {i}\n"
|
||||
code += " let result = 0;\n"
|
||||
code += (
|
||||
" for (let i = 0; i < 100; i++) {{\n result += i;\n }}\n"
|
||||
)
|
||||
code += " return result;\n}}\n\n"
|
||||
|
||||
full_code = ""
|
||||
for i in range(num_lines):
|
||||
full_code += code.format(i=i)
|
||||
return full_code
|
||||
|
||||
def benchmark_single_file_write(self, num_lines=100):
|
||||
"""测试单个文件的写入性能"""
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"测试: 单个文件写入 ({num_lines} 行代码)")
|
||||
print("=" * 60)
|
||||
|
||||
code = self.generate_test_code(num_lines)
|
||||
file_path = "test_file.py"
|
||||
|
||||
# 分块
|
||||
start_time = time.time()
|
||||
chunks = chunk_code(code, file_path)
|
||||
chunk_time = time.time() - start_time
|
||||
|
||||
print(f"分块耗时: {chunk_time * 1000:.2f} ms")
|
||||
print(f"生成块数: {len(chunks)}")
|
||||
|
||||
# Embedding
|
||||
start_time = time.time()
|
||||
texts = [c["text"] for c in chunks]
|
||||
vectors = embedder.embed(texts)
|
||||
embed_time = time.time() - start_time
|
||||
|
||||
print(f"Embedding 耗时: {embed_time * 1000:.2f} ms")
|
||||
|
||||
# 构建文档
|
||||
documents = []
|
||||
for chunk, vec in zip(chunks, vectors):
|
||||
documents.append(
|
||||
{
|
||||
"text": chunk["text"],
|
||||
"vector": vec,
|
||||
"metadata": chunk["metadata"],
|
||||
}
|
||||
)
|
||||
|
||||
# 写入数据库
|
||||
start_time = time.time()
|
||||
self.db.add_documents(documents)
|
||||
db_write_time = time.time() - start_time
|
||||
|
||||
print(f"数据库写入耗时: {db_write_time * 1000:.2f} ms")
|
||||
|
||||
total_time = chunk_time + embed_time + db_write_time
|
||||
print(f"\n📊 总耗时: {total_time * 1000:.2f} ms")
|
||||
|
||||
return {
|
||||
"chunk_time": chunk_time * 1000,
|
||||
"embed_time": embed_time * 1000,
|
||||
"db_write_time": db_write_time * 1000,
|
||||
"total_time": total_time * 1000,
|
||||
"num_chunks": len(chunks),
|
||||
}
|
||||
|
||||
def benchmark_batch_write(self, num_files=10, lines_per_file=50):
|
||||
"""测试批量文件写入性能"""
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"测试: 批量写入 ({num_files} 个文件, 每个 {lines_per_file} 行)")
|
||||
print("=" * 60)
|
||||
|
||||
total_chunks = 0
|
||||
total_embed_time = 0
|
||||
total_db_write_time = 0
|
||||
|
||||
for i in range(num_files):
|
||||
code = self.generate_test_code(lines_per_file)
|
||||
file_path = f"test_file_{i}.py"
|
||||
|
||||
# 分块
|
||||
chunks = chunk_code(code, file_path)
|
||||
total_chunks += len(chunks)
|
||||
|
||||
# Embedding
|
||||
texts = [c["text"] for c in chunks]
|
||||
start_time = time.time()
|
||||
vectors = embedder.embed(texts)
|
||||
total_embed_time += time.time() - start_time
|
||||
|
||||
# 构建文档并写入
|
||||
documents = []
|
||||
for chunk, vec in zip(chunks, vectors):
|
||||
documents.append(
|
||||
{
|
||||
"text": chunk["text"],
|
||||
"vector": vec,
|
||||
"metadata": chunk["metadata"],
|
||||
}
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
self.db.add_documents(documents)
|
||||
total_db_write_time += time.time() - start_time
|
||||
|
||||
total_time = total_embed_time + total_db_write_time
|
||||
|
||||
print(f"总块数: {total_chunks}")
|
||||
print(f"平均每块 embedding: {total_embed_time / num_files * 1000:.2f} ms")
|
||||
print(f"平均每块数据库写入: {total_db_write_time / num_files * 1000:.2f} ms")
|
||||
print(f"📊 总耗时: {total_time * 1000:.2f} ms")
|
||||
print(f"📊 吞吐量: {total_chunks / total_time:.2f} 块/秒")
|
||||
|
||||
return {
|
||||
"total_chunks": total_chunks,
|
||||
"total_embed_time": total_embed_time * 1000,
|
||||
"total_db_write_time": total_db_write_time * 1000,
|
||||
"total_time": total_time * 1000,
|
||||
"throughput": total_chunks / total_time,
|
||||
}
|
||||
|
||||
def benchmark_search(self, num_queries=10):
|
||||
"""测试搜索性能"""
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"测试: 搜索性能 ({num_queries} 次查询)")
|
||||
print("=" * 60)
|
||||
|
||||
queries = [
|
||||
"how to implement user authentication",
|
||||
"database configuration",
|
||||
"API endpoint handler",
|
||||
"error handling function",
|
||||
"data validation logic",
|
||||
]
|
||||
|
||||
search_times = []
|
||||
|
||||
for i in range(num_queries):
|
||||
query = queries[i % len(queries)]
|
||||
|
||||
# Query embedding
|
||||
start_time = time.time()
|
||||
query_vec = embedder.embed_single(query)
|
||||
embed_time = time.time() - start_time
|
||||
|
||||
# Search
|
||||
start_time = time.time()
|
||||
results = self.db.search(query_vec, top_k=5)
|
||||
search_time = time.time() - start_time
|
||||
|
||||
total_time = embed_time + search_time
|
||||
search_times.append(total_time)
|
||||
|
||||
print(
|
||||
f"查询 {i + 1}: {query[:30]}... → {len(results)} 结果 ({total_time * 1000:.2f} ms)"
|
||||
)
|
||||
|
||||
avg_time = sum(search_times) / len(search_times)
|
||||
min_time = min(search_times)
|
||||
max_time = max(search_times)
|
||||
|
||||
print(f"\n📊 搜索性能统计:")
|
||||
print(f" 平均延迟: {avg_time * 1000:.2f} ms")
|
||||
print(f" 最小延迟: {min_time * 1000:.2f} ms")
|
||||
print(f" 最大延迟: {max_time * 1000:.2f} ms")
|
||||
|
||||
return {
|
||||
"avg_time": avg_time * 1000,
|
||||
"min_time": min_time * 1000,
|
||||
"max_time": max_time * 1000,
|
||||
"num_queries": num_queries,
|
||||
}
|
||||
|
||||
def benchmark_list_sources(self):
|
||||
"""测试 list_sources 性能"""
|
||||
print(f"\n{'=' * 60}")
|
||||
print("测试: list_sources 性能")
|
||||
print("=" * 60)
|
||||
|
||||
start_time = time.time()
|
||||
sources = self.db.list_sources()
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
print(f"列出来源数: {len(sources)}")
|
||||
print(f"📊 耗时: {elapsed * 1000:.2f} ms")
|
||||
|
||||
return {
|
||||
"num_sources": len(sources),
|
||||
"time": elapsed * 1000,
|
||||
}
|
||||
|
||||
def run_full_benchmark(self):
|
||||
"""运行完整性能测试"""
|
||||
print("\n" + "=" * 60)
|
||||
print("🚀 Ocrag 性能基准测试")
|
||||
print("=" * 60)
|
||||
|
||||
# 准备测试数据
|
||||
print("\n📦 准备测试数据...")
|
||||
self.benchmark_single_file_write(num_lines=100)
|
||||
|
||||
# 批量写入测试
|
||||
self.benchmark_batch_write(num_files=20, lines_per_file=50)
|
||||
|
||||
# 搜索测试
|
||||
self.benchmark_search(num_queries=10)
|
||||
|
||||
# List 测试
|
||||
self.benchmark_list_sources()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("✅ 性能测试完成!")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Ocrag 性能测试")
|
||||
parser.add_argument("--cleanup", action="store_true", help="测试后清理数据")
|
||||
parser.add_argument("--db-path", type=str, help="指定数据库路径")
|
||||
args = parser.parse_args()
|
||||
|
||||
benchmark = PerformanceBenchmark(db_path=args.db_path)
|
||||
|
||||
try:
|
||||
benchmark.run_full_benchmark()
|
||||
finally:
|
||||
if args.cleanup:
|
||||
benchmark.cleanup()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,409 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
代码分块性能测试脚本
|
||||
|
||||
测试不同分块策略的效果和性能:
|
||||
1. langchain - 语言感知分块
|
||||
2. semantic - 语义分块
|
||||
3. simple - 简单字符分块
|
||||
|
||||
使用方法:
|
||||
uv run python scripts/test_chunking.py
|
||||
"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
|
||||
|
||||
from ocrag.chunker import (
|
||||
chunk_code,
|
||||
ChunkStrategy,
|
||||
get_available_strategies,
|
||||
LANGCHAIN_AVAILABLE,
|
||||
SEMANTIC_AVAILABLE,
|
||||
)
|
||||
|
||||
|
||||
class ChunkingBenchmark:
|
||||
"""分块性能测试"""
|
||||
|
||||
# 测试代码样本
|
||||
TEST_CASES = {
|
||||
"python_small": {
|
||||
"code": '''
|
||||
class UserManager:
|
||||
"""用户管理类"""
|
||||
def __init__(self):
|
||||
self.users = {}
|
||||
|
||||
def add_user(self, user_id, name):
|
||||
self.users[user_id] = {'name': name, 'active': True}
|
||||
return True
|
||||
|
||||
def remove_user(self, user_id):
|
||||
if user_id in self.users:
|
||||
del self.users[user_id]
|
||||
return True
|
||||
return False
|
||||
''',
|
||||
"file": "test.py",
|
||||
"expected_classes": 1,
|
||||
},
|
||||
"python_medium": {
|
||||
"code": '''
|
||||
class User:
|
||||
"""用户基类"""
|
||||
def __init__(self, name, email):
|
||||
self.name = name
|
||||
self.email = email
|
||||
self.active = True
|
||||
|
||||
def deactivate(self):
|
||||
self.active = False
|
||||
|
||||
class Admin(User):
|
||||
"""管理员类"""
|
||||
def __init__(self, name, email, permissions):
|
||||
super().__init__(name, email)
|
||||
self.permissions = permissions
|
||||
|
||||
def has_permission(self, permission):
|
||||
return permission in self.permissions
|
||||
|
||||
class Guest(User):
|
||||
"""访客类"""
|
||||
def __init__(self, name, email):
|
||||
super().__init__(name, email)
|
||||
self.active = False
|
||||
|
||||
def create_user(user_type, name, email, **kwargs):
|
||||
"""用户工厂函数"""
|
||||
if user_type == 'admin':
|
||||
return Admin(name, email, kwargs.get('permissions', []))
|
||||
elif user_type == 'guest':
|
||||
return Guest(name, email)
|
||||
else:
|
||||
return User(name, email)
|
||||
|
||||
def validate_email(email):
|
||||
"""验证邮箱格式"""
|
||||
return '@' in email and '.' in email.split('@')[1]
|
||||
|
||||
def process_batch(users):
|
||||
"""批量处理用户"""
|
||||
results = []
|
||||
for user in users:
|
||||
if user.active:
|
||||
results.append({
|
||||
'name': user.name,
|
||||
'email': user.email,
|
||||
'status': 'active'
|
||||
})
|
||||
return results
|
||||
''',
|
||||
"file": "user_manager.py",
|
||||
"expected_classes": 3,
|
||||
},
|
||||
"python_large": {
|
||||
"code": '''
|
||||
import json
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
class DataProcessor:
|
||||
"""数据处理器"""
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.cache = {}
|
||||
|
||||
def process(self, data):
|
||||
if not self.validate(data):
|
||||
raise ValueError("Invalid data format")
|
||||
return self.transform(data)
|
||||
|
||||
def validate(self, data):
|
||||
return isinstance(data, dict) and 'id' in data
|
||||
|
||||
def transform(self, data):
|
||||
return {
|
||||
'id': data['id'],
|
||||
'processed': True,
|
||||
'timestamp': time.time()
|
||||
}
|
||||
|
||||
class APIClient:
|
||||
"""API 客户端"""
|
||||
def __init__(self, base_url, api_key):
|
||||
self.base_url = base_url
|
||||
self.api_key = api_key
|
||||
|
||||
def get(self, endpoint):
|
||||
url = f"{self.base_url}/{endpoint}"
|
||||
headers = {'Authorization': f'Bearer {self.api_key}'}
|
||||
return self._request('GET', url, headers)
|
||||
|
||||
def post(self, endpoint, data):
|
||||
url = f"{self.base_url}/{endpoint}"
|
||||
headers = {
|
||||
'Authorization': f'Bearer {self.api_key}',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
return self._request('POST', url, headers, json.dumps(data))
|
||||
|
||||
def _request(self, method, url, headers, body=None):
|
||||
# 实现请求逻辑
|
||||
pass
|
||||
|
||||
class CacheManager:
|
||||
"""缓存管理器"""
|
||||
def __init__(self, max_size=1000):
|
||||
self.max_size = max_size
|
||||
self.cache = {}
|
||||
self.access_order = []
|
||||
|
||||
def get(self, key):
|
||||
if key in self.cache:
|
||||
self._update_access(key)
|
||||
return self.cache[key]
|
||||
return None
|
||||
|
||||
def set(self, key, value):
|
||||
if len(self.cache) >= self.max_size:
|
||||
oldest = self.access_order.pop(0)
|
||||
del self.cache[oldest]
|
||||
|
||||
self.cache[key] = value
|
||||
self.access_order.append(key)
|
||||
|
||||
def _update_access(self, key):
|
||||
if key in self.access_order:
|
||||
self.access_order.remove(key)
|
||||
self.access_order.append(key)
|
||||
|
||||
def clear(self):
|
||||
self.cache.clear()
|
||||
self.access_order.clear()
|
||||
''',
|
||||
"file": "large_module.py",
|
||||
"expected_classes": 3,
|
||||
},
|
||||
"rust": {
|
||||
"code": """
|
||||
struct User {
|
||||
name: String,
|
||||
email: String,
|
||||
active: bool,
|
||||
}
|
||||
|
||||
impl User {
|
||||
fn new(name: String, email: String) -> Self {
|
||||
User {
|
||||
name,
|
||||
email,
|
||||
active: true,
|
||||
}
|
||||
}
|
||||
|
||||
fn deactivate(&mut self) {
|
||||
self.active = false;
|
||||
}
|
||||
}
|
||||
|
||||
struct Admin {
|
||||
user: User,
|
||||
permissions: Vec<String>,
|
||||
}
|
||||
|
||||
impl Admin {
|
||||
fn has_permission(&self, permission: &str) -> bool {
|
||||
self.permissions.contains(&permission.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
fn create_user(name: String, email: String) -> User {
|
||||
User::new(name, email)
|
||||
}
|
||||
|
||||
fn process_batch(users: Vec<User>) -> Vec<String> {
|
||||
users
|
||||
.iter()
|
||||
.filter(|u| u.active)
|
||||
.map(|u| u.name.clone())
|
||||
.collect()
|
||||
}
|
||||
""",
|
||||
"file": "user.rs",
|
||||
"expected_classes": 2,
|
||||
},
|
||||
"cpp": {
|
||||
"code": """
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
class User {
|
||||
private:
|
||||
std::string name;
|
||||
std::string email;
|
||||
bool active;
|
||||
|
||||
public:
|
||||
User(const std::string& name, const std::string& email)
|
||||
: name(name), email(email), active(true) {}
|
||||
|
||||
void deactivate() {
|
||||
active = false;
|
||||
}
|
||||
|
||||
bool isActive() const {
|
||||
return active;
|
||||
}
|
||||
};
|
||||
|
||||
class Admin : public User {
|
||||
private:
|
||||
std::vector<std::string> permissions;
|
||||
|
||||
public:
|
||||
Admin(const std::string& name, const std::string& email)
|
||||
: User(name, email) {}
|
||||
|
||||
void addPermission(const std::string& perm) {
|
||||
permissions.push_back(perm);
|
||||
}
|
||||
};
|
||||
|
||||
void processUsers(std::vector<User>& users) {
|
||||
for (auto& user : users) {
|
||||
if (user.isActive()) {
|
||||
// 处理活跃用户
|
||||
}
|
||||
}
|
||||
}
|
||||
""",
|
||||
"file": "user.cpp",
|
||||
"expected_classes": 2,
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
self.results = {}
|
||||
|
||||
def run_benchmark(self, strategy: ChunkStrategy):
|
||||
"""运行单个策略的基准测试"""
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"测试策略: {strategy.value}")
|
||||
print("=" * 60)
|
||||
|
||||
total_chunks = 0
|
||||
total_time = 0
|
||||
first_init_time = None
|
||||
results = []
|
||||
|
||||
for name, test_case in self.TEST_CASES.items():
|
||||
# 测试初始化时间(首次)
|
||||
start_time = time.time()
|
||||
chunks = chunk_code(
|
||||
test_case["code"],
|
||||
test_case["file"],
|
||||
strategy=strategy,
|
||||
chunk_size=512,
|
||||
chunk_overlap=50,
|
||||
)
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
if first_init_time is None:
|
||||
first_init_time = elapsed
|
||||
|
||||
total_time += elapsed
|
||||
total_chunks += len(chunks)
|
||||
|
||||
results.append(
|
||||
{
|
||||
"name": name,
|
||||
"chunks": len(chunks),
|
||||
"time_ms": elapsed * 1000,
|
||||
"chars": len(test_case["code"]),
|
||||
}
|
||||
)
|
||||
|
||||
print(f"\n [{name}]")
|
||||
print(f" 字符数: {len(test_case['code'])}")
|
||||
print(f" 分块数: {len(chunks)}")
|
||||
print(f" 耗时: {elapsed * 1000:.2f}ms")
|
||||
|
||||
print(f"\n总耗时: {total_time * 1000:.2f}ms")
|
||||
print(f"总块数: {total_chunks}")
|
||||
|
||||
return {
|
||||
"strategy": strategy.value,
|
||||
"total_time_ms": total_time * 1000,
|
||||
"total_chunks": total_chunks,
|
||||
"first_init_time_ms": first_init_time * 1000 if first_init_time else 0,
|
||||
"details": results,
|
||||
}
|
||||
|
||||
def run_all(self):
|
||||
"""运行所有策略的测试"""
|
||||
print("=" * 60)
|
||||
print("📊 代码分块性能测试")
|
||||
print("=" * 60)
|
||||
|
||||
print(f"\n可用的分块策略: {get_available_strategies()}")
|
||||
print(f"langchain 可用: {LANGCHAIN_AVAILABLE}")
|
||||
print(f"semantic-text-splitter 可用: {SEMANTIC_AVAILABLE}")
|
||||
|
||||
all_results = []
|
||||
|
||||
# 测试每个策略
|
||||
for strategy_name in get_available_strategies():
|
||||
strategy = ChunkStrategy(strategy_name)
|
||||
result = self.run_benchmark(strategy)
|
||||
all_results.append(result)
|
||||
|
||||
# 打印汇总
|
||||
self.print_summary(all_results)
|
||||
|
||||
return all_results
|
||||
|
||||
def print_summary(self, results: List[Dict]):
|
||||
"""打印测试汇总"""
|
||||
print("\n" + "=" * 60)
|
||||
print("📈 测试结果汇总")
|
||||
print("=" * 60)
|
||||
|
||||
print(f"\n{'策略':<15} {'总耗时':<12} {'首次初始化':<12} {'分块数':<10}")
|
||||
print("-" * 50)
|
||||
|
||||
for result in results:
|
||||
print(
|
||||
f"{result['strategy']:<15} "
|
||||
f"{result['total_time_ms']:>8.2f}ms "
|
||||
f"{result['first_init_time_ms']:>8.2f}ms "
|
||||
f"{result['total_chunks']:>6}"
|
||||
)
|
||||
|
||||
# 找出最快的策略
|
||||
fastest = min(results, key=lambda x: x["total_time_ms"])
|
||||
most_chunks = max(results, key=lambda x: x["total_chunks"])
|
||||
|
||||
print(
|
||||
f"\n🏆 最快策略: {fastest['strategy']} ({fastest['total_time_ms']:.2f}ms)"
|
||||
)
|
||||
print(
|
||||
f"📦 最多分块: {most_chunks['strategy']} ({most_chunks['total_chunks']} 块)"
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
benchmark = ChunkingBenchmark()
|
||||
results = benchmark.run_all()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("✅ 测试完成!")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,99 @@
|
|||
"""
|
||||
代码分块模块
|
||||
|
||||
使用 langchain_text_splitters 进行语言感知分块。
|
||||
支持 29 种编程语言的语法感知分块,tiktoken 已内置,无需额外下载。
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any
|
||||
|
||||
try:
|
||||
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"请安装 langchain-text-splitters: uv pip install langchain-text-splitters"
|
||||
)
|
||||
|
||||
LANGUAGE_MAP = {
|
||||
".py": Language.PYTHON,
|
||||
".rs": Language.RUST,
|
||||
".cpp": Language.CPP,
|
||||
".cc": Language.CPP,
|
||||
".cxx": Language.CPP,
|
||||
".go": Language.GO,
|
||||
".java": Language.JAVA,
|
||||
".js": Language.JS,
|
||||
".ts": Language.TS,
|
||||
".jsx": Language.JS,
|
||||
".tsx": Language.TS,
|
||||
".c": Language.C,
|
||||
".h": Language.C,
|
||||
".hpp": Language.CPP,
|
||||
".rb": Language.RUBY,
|
||||
".swift": Language.SWIFT,
|
||||
".kt": Language.KOTLIN,
|
||||
".md": Language.MARKDOWN,
|
||||
".txt": None,
|
||||
}
|
||||
|
||||
|
||||
def detect_language(file_path: str) -> Language | None:
|
||||
"""检测文件语言类型"""
|
||||
ext = Path(file_path).suffix.lower()
|
||||
return LANGUAGE_MAP.get(ext)
|
||||
|
||||
|
||||
def chunk_code(
|
||||
content: str,
|
||||
file_path: str,
|
||||
chunk_size: int = 512,
|
||||
chunk_overlap: int = 50,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""将代码文件分块,返回块列表
|
||||
|
||||
使用 langchain_text_splitters 进行语言感知分块。
|
||||
支持 29 种编程语言的语法感知分块。
|
||||
|
||||
Args:
|
||||
content: 文件内容
|
||||
file_path: 文件路径
|
||||
chunk_size: 分块大小(token 数)
|
||||
chunk_overlap: 重叠大小(token 数)
|
||||
|
||||
Returns:
|
||||
分块列表,每个块包含 text 和 metadata
|
||||
"""
|
||||
language = detect_language(file_path)
|
||||
|
||||
if language is None:
|
||||
# 不支持的语言,使用简单分块
|
||||
splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
separators=["\n\n", "\n", " ", ""],
|
||||
length_function=len,
|
||||
)
|
||||
texts = splitter.split_text(content)
|
||||
language_name = "text"
|
||||
else:
|
||||
# 使用语言感知分块
|
||||
splitter = RecursiveCharacterTextSplitter.from_language(
|
||||
language=language,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
texts = splitter.split_text(content)
|
||||
language_name = language.value
|
||||
|
||||
return [
|
||||
{
|
||||
"text": text,
|
||||
"metadata": {
|
||||
"source_file": file_path,
|
||||
"chunk_index": i,
|
||||
"language": language_name,
|
||||
},
|
||||
}
|
||||
for i, text in enumerate(texts)
|
||||
]
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
import click
|
||||
from ocrag.commands import add as add_cmd
|
||||
from ocrag.commands import search as search_cmd
|
||||
from ocrag.commands import remove as remove_cmd
|
||||
from ocrag.commands import list_cmd
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.version_option()
|
||||
def main():
|
||||
"""ocrag - 代码库 RAG 命令行工具"""
|
||||
pass
|
||||
|
||||
|
||||
@main.command()
|
||||
@click.argument("paths", nargs=-1, required=True, type=click.Path(exists=True))
|
||||
@click.option("--recursive", "-r", is_flag=True, help="递归处理目录")
|
||||
def add(paths, recursive):
|
||||
"""向知识库添加文件或目录"""
|
||||
add_cmd(paths, recursive)
|
||||
|
||||
|
||||
@main.command()
|
||||
@click.argument("query")
|
||||
@click.option("--top-k", "-k", default=5, show_default=True, help="返回结果数量")
|
||||
def search(query, top_k):
|
||||
"""语义搜索知识库"""
|
||||
search_cmd(query, top_k)
|
||||
|
||||
|
||||
@main.command()
|
||||
@click.argument("path", type=click.Path())
|
||||
def remove(path):
|
||||
"""从知识库中移除指定文件"""
|
||||
remove_cmd(path)
|
||||
|
||||
|
||||
@main.command()
|
||||
def list():
|
||||
"""列出知识库中的所有条目"""
|
||||
list_cmd()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
from ocrag.commands.add import run as add
|
||||
from ocrag.commands.search import run as search
|
||||
from ocrag.commands.remove import run as remove
|
||||
from ocrag.commands.list import run as list_cmd
|
||||
|
|
@ -0,0 +1,56 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
from ocrag.chunker import chunk_code
|
||||
from ocrag.embedder import embedder
|
||||
from ocrag.db import RagDB
|
||||
|
||||
|
||||
def collect_files(paths, recursive):
|
||||
"""收集所有需要处理的文件"""
|
||||
files = []
|
||||
for p in paths:
|
||||
p = Path(p)
|
||||
if p.is_file():
|
||||
files.append(p)
|
||||
elif p.is_dir() and recursive:
|
||||
for root, dirs, filenames in os.walk(p):
|
||||
for f in filenames:
|
||||
files.append(Path(root) / f)
|
||||
elif p.is_dir() and not recursive:
|
||||
print(f"跳过目录 {p}(未使用 --recursive)")
|
||||
return files
|
||||
|
||||
|
||||
def run(paths, recursive):
|
||||
db = RagDB()
|
||||
files = collect_files(paths, recursive)
|
||||
|
||||
total_chunks = 0
|
||||
for file_path in files:
|
||||
try:
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
chunks = chunk_code(content, str(file_path))
|
||||
if not chunks:
|
||||
continue
|
||||
|
||||
# 批量 embedding
|
||||
texts = [c["text"] for c in chunks]
|
||||
vectors = embedder.embed(texts)
|
||||
|
||||
documents = []
|
||||
for chunk, vec in zip(chunks, vectors):
|
||||
documents.append(
|
||||
{
|
||||
"text": chunk["text"],
|
||||
"vector": vec,
|
||||
"metadata": chunk["metadata"],
|
||||
}
|
||||
)
|
||||
|
||||
db.add_documents(documents)
|
||||
total_chunks += len(documents)
|
||||
print(f"✅ {file_path} -> {len(documents)} 个块")
|
||||
except Exception as e:
|
||||
print(f"❌ 处理 {file_path} 失败: {e}")
|
||||
|
||||
print(f"\n📦 总计添加 {total_chunks} 个块")
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
from ocrag.db import RagDB
|
||||
|
||||
|
||||
def run():
|
||||
db = RagDB()
|
||||
sources = db.list_sources()
|
||||
|
||||
if not sources:
|
||||
print("知识库为空")
|
||||
return
|
||||
|
||||
print("知识库中的文件:")
|
||||
for i, source in enumerate(sources, 1):
|
||||
print(f"{i}. {source}")
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
from ocrag.db import RagDB
|
||||
|
||||
|
||||
def run(path: str):
|
||||
db = RagDB()
|
||||
# 注意:当前db.py中的delete_by_source方法尚未完全实现
|
||||
# 这里先调用以保持接口一致,后续需要完善db.py的实现
|
||||
db.delete_by_source(path)
|
||||
print(f"已删除 {path} 的所有块")
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
from ocrag.embedder import embedder
|
||||
from ocrag.db import RagDB
|
||||
|
||||
|
||||
def run(query: str, top_k: int):
|
||||
db = RagDB()
|
||||
query_vec = embedder.embed_single(query)
|
||||
results = db.search(query_vec, top_k)
|
||||
|
||||
if not results:
|
||||
print("未找到相关结果。")
|
||||
return
|
||||
|
||||
for i, r in enumerate(results, 1):
|
||||
print(f"\n[{i}] 相似度: {r['_distance']:.4f}")
|
||||
print(f"来源: {r['metadata'].get('source_file', 'unknown')}")
|
||||
print(f"内容:\n{r['text']}")
|
||||
print("-" * 80)
|
||||
|
|
@ -0,0 +1,112 @@
|
|||
import lancedb
|
||||
import pyarrow as pa
|
||||
from typing import List, Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
DB_PATH = Path.home() / ".ocrag" / "data.lance"
|
||||
|
||||
|
||||
class RagDB:
|
||||
def __init__(self, db_path: str = None):
|
||||
self.path = db_path or str(DB_PATH)
|
||||
self.conn = lancedb.connect(self.path)
|
||||
self._init_table()
|
||||
|
||||
def _init_table(self):
|
||||
"""如果表不存在则创建"""
|
||||
schema = pa.schema(
|
||||
[
|
||||
("text", pa.string()),
|
||||
(
|
||||
"vector",
|
||||
pa.list_(pa.float32(), 1024),
|
||||
), # 1024维 (Qwen3-Embedding-0.6B)
|
||||
("metadata", pa.string()), # JSON 字符串
|
||||
]
|
||||
)
|
||||
try:
|
||||
self.table = self.conn.open_table("documents")
|
||||
except:
|
||||
self.table = self.conn.create_table("documents", schema=schema)
|
||||
|
||||
def add_documents(self, documents: List[Dict[str, Any]]):
|
||||
"""批量添加文档
|
||||
documents: [{"text": str, "vector": List[float], "metadata": dict}, ...]
|
||||
"""
|
||||
import json
|
||||
|
||||
data = []
|
||||
for doc in documents:
|
||||
data.append(
|
||||
{
|
||||
"text": doc["text"],
|
||||
"vector": doc["vector"],
|
||||
"metadata": json.dumps(doc.get("metadata", {})),
|
||||
}
|
||||
)
|
||||
self.table.add(data)
|
||||
|
||||
def search(self, query_vector: List[float], top_k: int = 5) -> List[Dict[str, Any]]:
|
||||
import json
|
||||
|
||||
results = self.table.search(query_vector).limit(top_k).to_list()
|
||||
for r in results:
|
||||
r["metadata"] = json.loads(r["metadata"])
|
||||
return results
|
||||
|
||||
def delete_by_source(self, source_file: str) -> int:
|
||||
"""删除指定源文件的所有块
|
||||
|
||||
Args:
|
||||
source_file: 要删除的源文件路径
|
||||
|
||||
Returns:
|
||||
删除的块数量
|
||||
"""
|
||||
import json
|
||||
|
||||
# 安全实现:使用 Pandas 过滤来避免 SQL 注入
|
||||
df = self.table.to_pandas()
|
||||
|
||||
if df.empty:
|
||||
return 0
|
||||
|
||||
# 过滤出不包含该源文件的行
|
||||
indices_to_keep = []
|
||||
for idx, row in df.iterrows():
|
||||
try:
|
||||
meta = json.loads(row["metadata"])
|
||||
if meta.get("source_file") != source_file:
|
||||
indices_to_keep.append(idx)
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
# 如果解析失败,保留该行
|
||||
indices_to_keep.append(idx)
|
||||
|
||||
num_deleted = len(df) - len(indices_to_keep)
|
||||
|
||||
if num_deleted == 0:
|
||||
return 0
|
||||
|
||||
# 保留不需要删除的行
|
||||
df_remaining = df.loc[indices_to_keep]
|
||||
|
||||
# 使用 overwrite 模式重建表
|
||||
self.conn.drop_table("documents")
|
||||
self.table = self.conn.create_table("documents", df_remaining)
|
||||
|
||||
return num_deleted
|
||||
|
||||
def list_sources(self) -> List[str]:
|
||||
"""列出所有已添加的源文件路径(去重)"""
|
||||
# 获取所有 metadata,提取 source_file
|
||||
df = self.table.to_pandas()
|
||||
if df.empty:
|
||||
return []
|
||||
import json
|
||||
|
||||
sources = set()
|
||||
for meta_str in df["metadata"]:
|
||||
meta = json.loads(meta_str)
|
||||
if "source_file" in meta:
|
||||
sources.add(meta["source_file"])
|
||||
return sorted(sources)
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
from pathlib import Path
|
||||
from typing import List
|
||||
import os
|
||||
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
|
||||
class Embedder:
|
||||
_instance = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
|
||||
# 计算模型路径
|
||||
model_path = (
|
||||
Path(__file__).parent.parent.parent / "models" / "Qwen3-Embedding-0.6B"
|
||||
)
|
||||
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(f"模型目录不存在: {model_path}")
|
||||
|
||||
try:
|
||||
cls._instance.model = SentenceTransformer(
|
||||
str(model_path),
|
||||
device="cuda"
|
||||
if os.getenv("USE_GPU", "false").lower() == "true"
|
||||
else "cpu",
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"加载模型失败: {e}")
|
||||
|
||||
return cls._instance
|
||||
|
||||
def embed(self, texts: List[str]) -> List[List[float]]:
|
||||
"""返回向量列表,每个向量为 float 列表"""
|
||||
embeddings = self.model.encode(texts, normalize_embeddings=True)
|
||||
return embeddings.tolist()
|
||||
|
||||
def embed_single(self, text: str) -> List[float]:
|
||||
return self.embed([text])[0]
|
||||
|
||||
|
||||
# 全局单例
|
||||
embedder = Embedder()
|
||||
|
|
@ -0,0 +1,78 @@
|
|||
import os
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def get_file_hash(file_path: str) -> str:
|
||||
"""计算文件的MD5哈希值,用于检测文件变化"""
|
||||
md5_hash = hashlib.md5()
|
||||
with open(file_path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(4096), b""):
|
||||
md5_hash.update(chunk)
|
||||
return md5_hash.hexdigest()
|
||||
|
||||
|
||||
def ensure_dir(dir_path: str) -> None:
|
||||
"""确保目录存在,不存在则创建"""
|
||||
Path(dir_path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def get_file_size(file_path: str) -> int:
|
||||
"""获取文件大小(字节)"""
|
||||
return os.path.getsize(file_path)
|
||||
|
||||
|
||||
def format_size(size_bytes: int) -> str:
|
||||
"""格式化文件大小"""
|
||||
for unit in ["B", "KB", "MB", "GB"]:
|
||||
if size_bytes < 1024.0:
|
||||
return f"{size_bytes:.2f} {unit}"
|
||||
size_bytes /= 1024.0
|
||||
return f"{size_bytes:.2f} TB"
|
||||
|
||||
|
||||
def is_text_file(file_path: str) -> bool:
|
||||
"""检查文件是否为文本文件"""
|
||||
text_extensions = {
|
||||
".py",
|
||||
".js",
|
||||
".ts",
|
||||
".jsx",
|
||||
".tsx",
|
||||
".java",
|
||||
".c",
|
||||
".cpp",
|
||||
".h",
|
||||
".hpp",
|
||||
".go",
|
||||
".rs",
|
||||
".rb",
|
||||
".php",
|
||||
".swift",
|
||||
".kt",
|
||||
".scala",
|
||||
".md",
|
||||
".txt",
|
||||
".json",
|
||||
".yaml",
|
||||
".yml",
|
||||
".xml",
|
||||
".html",
|
||||
".css",
|
||||
".sql",
|
||||
".sh",
|
||||
".bash",
|
||||
".zsh",
|
||||
".ps1",
|
||||
".bat",
|
||||
".vue",
|
||||
".svelte",
|
||||
}
|
||||
ext = Path(file_path).suffix.lower()
|
||||
return ext in text_extensions
|
||||
|
||||
|
||||
def get_project_root() -> Path:
|
||||
"""获取项目根目录"""
|
||||
return Path(__file__).parent.parent.parent
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
import pytest
|
||||
from ocrag.chunker import chunk_code
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_python_code():
|
||||
return """def hello():
|
||||
print('Hello, World!')
|
||||
|
||||
class Test:
|
||||
def method(self):
|
||||
pass"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_markdown_content():
|
||||
return """# Title
|
||||
|
||||
Paragraph 1
|
||||
|
||||
## Subtitle
|
||||
|
||||
Paragraph 2"""
|
||||
|
||||
|
||||
def test_chunk_code_python(sample_python_code):
|
||||
chunks = chunk_code(sample_python_code, "test.py")
|
||||
assert len(chunks) > 0
|
||||
assert "def hello" in chunks[0]["text"]
|
||||
assert chunks[0]["metadata"]["language"] == "python"
|
||||
|
||||
|
||||
def test_chunk_code_markdown(sample_markdown_content):
|
||||
chunks = chunk_code(sample_markdown_content, "test.md")
|
||||
assert len(chunks) > 0
|
||||
assert "# Title" in chunks[0]["text"]
|
||||
assert chunks[0]["metadata"]["language"] == "markdown"
|
||||
|
||||
|
||||
def test_chunk_code_text():
|
||||
content = "a" * 10000 # Large content to test text fallback
|
||||
chunks = chunk_code(content, "large.txt")
|
||||
assert len(chunks) > 0
|
||||
assert chunks[0]["metadata"]["source_file"] == "large.txt"
|
||||
assert chunks[0]["metadata"]["language"] == "text"
|
||||
|
|
@ -0,0 +1,63 @@
|
|||
import pytest
|
||||
import os
|
||||
from click.testing import CliRunner
|
||||
from ocrag.cli import main
|
||||
from ocrag.db import RagDB
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
return CliRunner()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_test_env(tmpdir):
|
||||
# Create test files
|
||||
os.makedirs(os.path.join(tmpdir, "test_dir"))
|
||||
with open(os.path.join(tmpdir, "test_dir", "file1.py"), "w") as f:
|
||||
f.write("def test_func():\n pass")
|
||||
with open(os.path.join(tmpdir, "test_file.py"), "w") as f:
|
||||
f.write("print('Hello')")
|
||||
return tmpdir
|
||||
|
||||
|
||||
def test_add_command(runner, setup_test_env, tmpdir):
|
||||
# Use temporary DB
|
||||
db_path = os.path.join(tmpdir, "test_db.lance")
|
||||
os.environ["OCRAG_DB_PATH"] = db_path
|
||||
|
||||
result = runner.invoke(main, ["add", os.path.join(setup_test_env, "test_file.py")])
|
||||
assert "test_file.py" in result.output
|
||||
assert "总计添加" in result.output
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
def test_search_command(runner, setup_test_env, tmpdir):
|
||||
# Use temporary DB and add a file first
|
||||
db_path = os.path.join(tmpdir, "test_db.lance")
|
||||
os.environ["OCRAG_DB_PATH"] = db_path
|
||||
|
||||
add_result = runner.invoke(
|
||||
main, ["add", os.path.join(setup_test_env, "test_file.py")]
|
||||
)
|
||||
assert add_result.exit_code == 0
|
||||
|
||||
# Search
|
||||
result = runner.invoke(main, ["search", "Hello"])
|
||||
assert "Hello" in result.output
|
||||
assert "来源: " in result.output
|
||||
|
||||
|
||||
def test_list_command(runner, setup_test_env, tmpdir):
|
||||
# Use temporary DB and add a file first
|
||||
db_path = os.path.join(tmpdir, "test_db.lance")
|
||||
os.environ["OCRAG_DB_PATH"] = db_path
|
||||
|
||||
add_result = runner.invoke(
|
||||
main, ["add", os.path.join(setup_test_env, "test_file.py")]
|
||||
)
|
||||
assert add_result.exit_code == 0
|
||||
|
||||
# List
|
||||
result = runner.invoke(main, ["list"])
|
||||
assert "test_file.py" in result.output
|
||||
|
|
@ -0,0 +1,115 @@
|
|||
import pytest
|
||||
import os
|
||||
import shutil
|
||||
from ocrag.db import RagDB
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db(tmpdir):
|
||||
db_path = os.path.join(tmpdir, "test_db.lance")
|
||||
yield RagDB(db_path)
|
||||
if os.path.exists(db_path):
|
||||
shutil.rmtree(db_path)
|
||||
|
||||
|
||||
def test_db_initialization(temp_db):
|
||||
assert temp_db.table is not None
|
||||
|
||||
|
||||
def test_add_documents(temp_db):
|
||||
documents = [
|
||||
{
|
||||
"text": "Test content",
|
||||
"vector": [0.1] * 1024,
|
||||
"metadata": {"source_file": "test.py"},
|
||||
}
|
||||
]
|
||||
temp_db.add_documents(documents)
|
||||
assert temp_db.table.to_pandas().shape[0] == 1
|
||||
|
||||
|
||||
def test_search(temp_db):
|
||||
# Add a document
|
||||
documents = [
|
||||
{
|
||||
"text": "Database configuration",
|
||||
"vector": [0.1] * 1024,
|
||||
"metadata": {"source_file": "config.py"},
|
||||
}
|
||||
]
|
||||
temp_db.add_documents(documents)
|
||||
|
||||
# Search
|
||||
results = temp_db.search([0.1] * 1024)
|
||||
assert len(results) == 1
|
||||
assert "Database configuration" in results[0]["text"]
|
||||
|
||||
|
||||
def test_list_sources(temp_db):
|
||||
# Add documents from multiple sources
|
||||
documents = [
|
||||
{
|
||||
"text": "Content 1",
|
||||
"vector": [0.1] * 1024,
|
||||
"metadata": {"source_file": "file1.py"},
|
||||
},
|
||||
{
|
||||
"text": "Content 2",
|
||||
"vector": [0.2] * 1024,
|
||||
"metadata": {"source_file": "file2.py"},
|
||||
},
|
||||
]
|
||||
temp_db.add_documents(documents)
|
||||
|
||||
sources = temp_db.list_sources()
|
||||
assert len(sources) == 2
|
||||
assert "file1.py" in sources
|
||||
assert "file2.py" in sources
|
||||
|
||||
|
||||
def test_delete_by_source(temp_db):
|
||||
# Add documents from multiple sources
|
||||
documents = [
|
||||
{
|
||||
"text": "Content A",
|
||||
"vector": [0.1] * 1024,
|
||||
"metadata": {"source_file": "file1.py"},
|
||||
},
|
||||
{
|
||||
"text": "Content B",
|
||||
"vector": [0.2] * 1024,
|
||||
"metadata": {"source_file": "file2.py"},
|
||||
},
|
||||
{
|
||||
"text": "Content C",
|
||||
"vector": [0.3] * 1024,
|
||||
"metadata": {"source_file": "file1.py"},
|
||||
},
|
||||
{
|
||||
"text": "Content D",
|
||||
"vector": [0.4] * 1024,
|
||||
"metadata": {"source_file": "file3.py"},
|
||||
},
|
||||
]
|
||||
temp_db.add_documents(documents)
|
||||
|
||||
# Verify initial state
|
||||
sources = temp_db.list_sources()
|
||||
assert len(sources) == 3
|
||||
|
||||
# Delete file1.py (should delete 2 chunks)
|
||||
num_deleted = temp_db.delete_by_source("file1.py")
|
||||
assert num_deleted == 2
|
||||
|
||||
# Verify deletion
|
||||
sources = temp_db.list_sources()
|
||||
assert len(sources) == 2
|
||||
assert "file1.py" not in sources
|
||||
assert "file2.py" in sources
|
||||
assert "file3.py" in sources
|
||||
|
||||
|
||||
def test_delete_nonexistent_source(temp_db):
|
||||
# Try to delete a source that doesn't exist
|
||||
num_deleted = temp_db.delete_by_source("nonexistent.py")
|
||||
assert num_deleted == 0
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
import pytest
|
||||
from ocrag.embedder import Embedder
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def embedder():
|
||||
return Embedder()
|
||||
|
||||
|
||||
def test_embedder_singleton(embedder):
|
||||
embedder2 = Embedder()
|
||||
assert embedder is embedder2
|
||||
|
||||
|
||||
def test_embed_single(embedder):
|
||||
vector = embedder.embed_single("Test sentence")
|
||||
assert len(vector) == 1024 # Qwen3-Embedding-0.6B output dimension
|
||||
assert isinstance(vector[0], float)
|
||||
|
||||
|
||||
def test_embed_batch(embedder):
|
||||
vectors = embedder.embed(["Sentence 1", "Sentence 2"])
|
||||
assert len(vectors) == 2
|
||||
assert len(vectors[0]) == 1024
|
||||
assert len(vectors[1]) == 1024
|
||||
assert (
|
||||
vectors[0] != vectors[1]
|
||||
) # Different sentences should have different embeddings
|
||||
Loading…
Reference in New Issue