feat: add a local code knowledge-base RAG plugin with semantic search and real-time updates

commit 28e557594a
songsenand · 2026-04-16 11:31:18 +08:00
25 changed files with 2285 additions and 0 deletions

.gitignore vendored Normal file (+182)

@@ -0,0 +1,182 @@
models/*
uv.lock
# ---> Python
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# ---> VisualStudioCode
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets
# Local History for Visual Studio Code
.history/
# Built Visual Studio Code Extensions
*.vsix
data/*

DescriptionOfDesign.md Normal file (+386)

@@ -0,0 +1,386 @@
# OpenCode RAG Plugin Design Document
> This document details the goals, technology choices, architecture, and implementation of the ocrag project.
---
## 1. Background and Goals
### 1.1 Problem Background
During AI-assisted development in a large codebase, the AI needs to understand a lot of project code to give accurate suggestions and answers. However:
- **Limited context window**: the whole codebase cannot fit into the AI prompt
- **Freshness requirement**: developers need the AI to understand code they just wrote, immediately
- **Local-first**: code should not be uploaded to external servers
### 1.2 Design Goals
Build a **local code knowledge-base RAG system** for OpenCode that provides:
| Feature | Description |
|------|------|
| **Real-time adding** | Add code files or directories to a local knowledge base |
| **Semantic search** | Retrieve relevant code snippets via natural-language queries |
| **Simple management** | Remove and list entries in the knowledge base |
### 1.3 Design Principles
1. **Local-first**: all data and computation stay on the local machine; no external services
2. **Lightweight**: no heavyweight server-side components; keep latency low
3. **Zero-ops**: an embedded database with nothing to install or configure
4. **AI-friendly**: produce context the AI can consume directly
---
## 2. Technology Choices
### 2.1 Why Python
| Criterion | Python | Rust |
|----------|--------|------|
| Development speed | ✅ High | ⚠️ Medium |
| LLM generation quality | ✅ High (AI models are more fluent in Python) | ⚠️ Medium |
| Ecosystem maturity | ✅ Mature | ⚠️ Average |
| Runtime performance | ⚠️ Medium (can be improved with C extensions) | ✅ High |
| Community support | ✅ Rich | ⚠️ Limited |
**Conclusion**: Python's development speed and AI generation quality clearly win; for an I/O-bound workload like RAG, the lower runtime performance matters little.
### 2.2 Why LanceDB
Comparison with common vector databases:
| Database | Strengths | Weaknesses |
|--------|------|------|
| Chroma | Simple to use | Persistence story is weaker; less production-oriented |
| Milvus | Feature-rich | Requires Docker deployment |
| Qdrant | Rust implementation, high performance | Requires a separate server |
| **LanceDB** | **Embedded, zero-ops, Python-native** | **Relatively new** |
**LanceDB advantages**:
- **Embedded**: the database is just a folder; no separate service
- **Zero-ops**: install and go; storage is managed automatically
- **Python-first**: native Python SDK with good type hints
- **Fast**: built on Apache Arrow, with quick queries
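To make the "embedded" point concrete, here is a minimal sketch against the public lancedb Python SDK (the table name and sample data are illustrative):
```python
from pathlib import Path

import lancedb

# The "database" is just a local directory; there is no server process.
db = lancedb.connect(str(Path.home() / ".ocrag" / "data.lance"))
table = db.create_table(
    "documents",
    data=[{"text": "hello", "vector": [0.1] * 1024, "metadata": "{}"}],
)
# Vector search is a single chained call.
hits = table.search([0.1] * 1024).limit(5).to_list()
```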
### 2.3 Why langchain-text-splitters (instead of chonkie)
The original design considered `chonkie` for syntax-aware chunking, but practice showed:
| Option | Strengths | Weaknesses |
|------|------|------|
| chonkie | Syntax-aware chunking | Needs an extra GPT-2 tokenizer download; slow to initialize |
| langchain-text-splitters | No extra model, purely rule-based, fast | Not syntax-aware |
**Final choice: langchain-text-splitters**:
- For RAG, chunking precision is not the dominant factor
- The langchain approach is lighter and starts faster
- Splitting on natural boundaries (functions, classes) is good enough
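A short sketch of the chosen splitter (real langchain-text-splitters API; the sample source string is illustrative):
```python
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter

source_code = "def add(a, b):\n    return a + b\n\nclass Calc:\n    pass\n"

# Rule-based and language-aware: no model download, fast startup.
splitter = RecursiveCharacterTextSplitter.from_language(
    language=Language.PYTHON,
    chunk_size=512,
    chunk_overlap=50,
)
chunks = splitter.split_text(source_code)
```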
### 2.4 Why Qwen3-Embedding-0.6B
| Model | Dimensions | Strengths | Weaknesses |
|------|------|------|------|
| all-MiniLM-L6-v2 | 384 | Fast and small | English-centric |
| **Qwen3-Embedding-0.6B** | **1024** | **Chinese-optimized, Chinese/English bilingual** | Larger; slow first load |
**Reasons**:
- Open source and locally deployable, so code stays private
- Excellent Chinese support
- Good fit for mixed code + documentation content
### 2.5 Why Not an MCP Server
**Drawbacks of the MCP approach**:
- Requires deploying an extra MCP server
- Adds system complexity
- Harder to debug
**Advantages of the CLI approach**:
- Zero extra components
- Invoked directly through the `bash` tool
- Easy to debug and extend
---
## 3. Architecture
### 3.1 Overall Architecture
```
┌─────────────────────────────────────────────────────────────────┐
│ OpenCode AI │
└────────────────────────────┬──────────────────────────────────┘
│ 1. rag_add / rag_search
┌─────────────────────────────────────────────────────────────────┐
│ TypeScript Plugin (ocrag-plugin.ts) │
│ │
│ ┌──────────┐ ┌──────────┐ ┌──────────────┐ │
│  │ rag_add  │  │rag_search│  │ Skill prompt │                   │
│ └────┬─────┘ └────┬─────┘ └──────────────┘ │
└───────┼─────────────────┼───────────────────────────────────────┘
│ │
▼ ▼
┌─────────────────────────────────────────────────────────────────┐
│ Python CLI (ocrag) │
│ │
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────────┐ │
│ │ add │ │ search │ │ remove │ │ list │ │
│ └───┬─────┘ └────┬─────┘ └───┬─────┘ └──────┬──────┘ │
└──────┼───────────────┼─────────────┼──────────────────┼──────────┘
│ │ │ │
▼ ▼ ▼ │
┌─────────────────────────────────────────┐ │
│ Processing Pipeline │ │
│ │ │
│ ┌────────────┐ ┌────────────────┐ │ │
│ │ Chunker │───▶│ Embedder │──┼────────────┤
│  │ (chunking) │    │  (embedding)   │  │            │
│ └────────────┘ └────────────────┘ │ │
└─────────────────────────────────────────┼────────────┘
┌─────────────────────────────────────────────────────────────────┐
│                  LanceDB (vector database)                      │
│ │
│ ┌──────────────────────────────────────────────────────────┐ │
│  │                     documents table                      │   │
│ │ ┌──────────┬──────────────────────┬──────────────────┐ │ │
│ │ │ text │ vector │ metadata │ │ │
│  │  │ (text)   │ (1024-d vector)      │ (JSON metadata)  │  │   │
│ │ └──────────┴──────────────────────┴──────────────────┘ │ │
│ └──────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘
```
### 3.2 Data Flow
#### Add-file flow
```
user request → Plugin → CLI:add → Chunker → Embedder → LanceDB → result
```
**Steps**:
1. **File collection**: walk directories recursively, or take single files
2. **Content reading**: read the file as UTF-8
3. **Code chunking**: split the code by language and syntactic structure
4. **Embedding**: turn each chunk into a 1024-dimensional vector with the Embedder
5. **Storage**: write text, vector, and metadata into LanceDB
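Condensed into code, the add pipeline uses the repository's own modules (the input path is illustrative; see src/ocrag/commands/add.py for the full version):
```python
from pathlib import Path

from ocrag.chunker import chunk_code
from ocrag.db import RagDB
from ocrag.embedder import embedder

db = RagDB()
path = Path("src/main.py")                             # illustrative input file
content = path.read_text(encoding="utf-8")             # step 2: read as UTF-8
chunks = chunk_code(content, str(path))                # step 3: language-aware chunking
vectors = embedder.embed([c["text"] for c in chunks])  # step 4: batched embedding
db.add_documents(                                      # step 5: store text + vector + metadata
    [
        {"text": c["text"], "vector": v, "metadata": c["metadata"]}
        for c, v in zip(chunks, vectors)
    ]
)
```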
#### Search flow
```
user query → Plugin → CLI:search → Embedder → LanceDB → results
```
**Steps**:
1. **Query embedding**: convert the natural-language query into a vector
2. **Similarity search**: LanceDB performs a vector similarity lookup
3. **Ranking**: results are ordered by similarity distance
4. **Metadata extraction**: parse the JSON metadata
5. **Return**: a formatted list of code snippets
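The search path, condensed the same way (the query string is illustrative; see src/ocrag/commands/search.py):
```python
from ocrag.db import RagDB
from ocrag.embedder import embedder

db = RagDB()
query_vec = embedder.embed_single("how is user authentication implemented")  # step 1
for hit in db.search(query_vec, top_k=5):  # steps 2-3: similarity search, sorted by distance
    meta = hit["metadata"]                 # step 4: metadata arrives as a parsed dict
    print(meta.get("source_file"), hit["_distance"], hit["text"][:80])  # step 5
```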
### 3.3 Module Responsibilities
| Module | Responsibility | Dependencies |
|------|------|------|
| `cli.py` | CLI entry point and argument parsing | click |
| `chunker.py` | Code chunking | langchain-text-splitters |
| `embedder.py` | Text embedding | sentence-transformers |
| `db.py` | Vector database operations | lancedb, pyarrow |
| `commands/*.py` | Command implementations | cli, chunker, embedder, db |
| `utils.py` | Utility helpers | - |
---
## 4. Core Algorithms
### 4.1 Code Chunking
**Idea**: split code along its semantic structure (a sketch follows the parameter list below)
```python
# Separator priority for Python
separators = [
    "\n\nclass ",      # class definitions
    "\n\ndef ",        # function definitions
    "\n\nasync def ",  # async functions
    "\n\n",            # blank-line paragraphs
    "\n",              # newlines
    " ",               # spaces
    ""                 # individual characters
]
```
**Chunking parameters**:
- `chunk_size=512`: about 512 characters per chunk (the default in `chunker.py`)
- `chunk_overlap=50`: 50 characters of overlap between adjacent chunks keeps context continuous
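Wired together, the separator list and the two parameters look like this (a sketch; chunker.py actually calls RecursiveCharacterTextSplitter.from_language, which supplies an equivalent per-language separator list):
```python
from langchain_text_splitters import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(
    separators=separators,  # the priority list above
    chunk_size=512,         # measured in characters (length_function=len)
    chunk_overlap=50,
    length_function=len,
)
chunks = splitter.split_text(code)  # `code` is the raw file content
```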
### 4.2 Embedding Strategy
**Model**: Qwen3-Embedding-0.6B
- Output dimension: 1024
- Normalization: L2, so cosine similarity can be computed as a dot product
**Batching**:
- Encode many text chunks in a single call
- Exploit GPU/CPU batch throughput
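A sketch of the embedding step with the real sentence-transformers API (the model path assumes this repository's layout):
```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("models/Qwen3-Embedding-0.6B")
texts = ["def login(user): ...", "how to configure the database"]
# One batched call; normalize_embeddings=True applies L2 normalization,
# so cosine similarity reduces to a dot product.
vectors = model.encode(texts, normalize_embeddings=True)
print(vectors.shape)  # (2, 1024)
```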
### 4.3 Search Strategy
**Similarity metric**: LanceDB defaults to L2 distance; because all vectors are L2-normalized, ranking by L2 distance is equivalent to ranking by cosine similarity (‖a−b‖² = 2 − 2·cos(a,b))
**top_k parameter**: controls how many results are returned; default 5
---
## 5. Security and Performance
### 5.1 Security
| Risk | Mitigation |
|--------|----------|
| SQL injection | Filter with Pandas instead of concatenating SQL strings (see the sketch below) |
| Path traversal | Only the given paths are processed; no dynamic imports |
| Data exfiltration | All data is stored locally; nothing crosses the network |
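The first mitigation in practice, condensed from delete_by_source in db.py (`table` and `source_file` as defined there): rows are matched in Pandas instead of interpolating user input into a filter string:
```python
import json

# Instead of building a filter string from `source_file` (injectable),
# parse each row's metadata and compare values directly.
df = table.to_pandas()
keep = df[
    df["metadata"].apply(lambda m: json.loads(m).get("source_file") != source_file)
]
```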
### 5.2 Performance Optimizations
**Implemented**:
- ✅ Batched embedding to reduce per-call overhead
- ✅ Singleton pattern so the model is loaded only once
- ✅ L2-normalized vectors to speed up similarity computation
**Possible improvements** (a sketch of the first item follows this list):
- File-hash caching to avoid re-adding unchanged files
- GPU-accelerated embedding
- Incremental index updates instead of full rebuilds
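A sketch of the first improvement, built on the existing utils.get_file_hash; the JSON cache file and its location are hypothetical:
```python
import json
from pathlib import Path

from ocrag.utils import get_file_hash

CACHE_FILE = Path.home() / ".ocrag" / "hash_cache.json"  # hypothetical location


def needs_reindex(file_path: str) -> bool:
    """Return False if the file's content hash is unchanged since the last add."""
    cache = json.loads(CACHE_FILE.read_text()) if CACHE_FILE.exists() else {}
    current = get_file_hash(file_path)
    if cache.get(file_path) == current:
        return False
    cache[file_path] = current
    CACHE_FILE.parent.mkdir(parents=True, exist_ok=True)
    CACHE_FILE.write_text(json.dumps(cache))
    return True
```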
### 5.3 Benchmarks
| Metric | Measured | Notes |
|------|----------|------|
| Search latency | **63-70 ms** | embedding + vector lookup |
| Database write | 2-3 ms/chunk | LanceDB performs well |
| Chunking | <1 ms | pure rules, no model loading |
| Embedding | ~2.5 s/chunk | the Qwen3 model is large |
---
## 6. Extensibility
### 6.1 More Languages
Adding a language only requires a new extension in `LANGUAGE_MAP`:
```python
LANGUAGE_MAP = {
    ".py": Language.PYTHON,
    ".js": Language.JS,
    # add new languages...
    ".kt": Language.KOTLIN,
    ".swift": Language.SWIFT,
}
```
### 6.2 Custom Embedding Models
Change the model path in `embedder.py`:
```python
model_path = "path/to/your/model"
```
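In context, the swap happens where Embedder constructs its SentenceTransformer (a condensed view of embedder.py; note that db.py declares a fixed 1024-dimensional vector column, so a model with a different output dimension also requires changing the LanceDB schema):
```python
from sentence_transformers import SentenceTransformer

model_path = "path/to/your/model"  # any sentence-transformers-compatible directory
model = SentenceTransformer(model_path, device="cpu")
# Must stay consistent with the vector column width in db.py (currently 1024).
```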
### 6.3 Incremental Sync (Watch Mode)
Use the `watchdog` library to watch for file changes and update the knowledge base automatically (`RAGSyncHandler` is a `FileSystemEventHandler` subclass left to implement):
```python
from watchdog.observers import Observer

observer = Observer()
observer.schedule(RAGSyncHandler(), path, recursive=True)
observer.start()
```
---
## 7. OpenCode Integration
### 7.1 Plugin Architecture
```
OpenCode
├── TypeScript Plugin
│   ├── rag_add tool
│   └── rag_search tool
└── Skill prompt
    └── SKILL.md
```
### 7.2 Tool Definitions
| Tool | Parameters | Purpose |
|--------|------|------|
| `rag_add` | `paths: string[]`, `recursive?: boolean` | Add files to the knowledge base |
| `rag_search` | `query: string`, `top_k?: number` | Search the knowledge base |
### 7.3 Error Handling
Every tool call is wrapped in try-catch:
```typescript
execute: async (args) => {
  try {
    const result = await Bun.$`${cmd}`.text();
    return result;
  } catch (error) {
    return `Error: ${error.message}`;
  }
}
```
---
## 8. Summary
### 8.1 Core Innovations
1. **Zero-MCP architecture**: direct CLI integration keeps the system simple
2. **Local-first**: data never leaves the machine, keeping code private
3. **Lightweight**: embedded database with second-level startup
4. **AI-native**: output format designed for AI consumption
### 8.2 When to Use It
| Scenario | Fit | Notes |
|------|--------|------|
| Personal project knowledge | ✅ Excellent | local storage, private |
| Small-team codebase Q&A | ✅ Good | lightweight, easy to deploy |
| Large enterprise codebases | ⚠️ Needs tuning | may require distributed scaling |
| Polyglot codebases | ✅ Supported | multi-language chunking |
### 8.3 Future Work
1. **Smarter chunking**: integrate more semantic chunking algorithms
2. **Multimodal support**: images, diagrams, and other non-text content
3. **Incremental indexing**: real-time updates for large codebases
4. **Distributed deployment**: multi-node cooperative retrieval
---
## Appendix: Glossary
| Term | Meaning |
|------|------|
| RAG | Retrieval-Augmented Generation |
| Embedding | The process of converting text into a vector |
| Vector database | A database specialized in storing and searching vectors |
| Chunk | A segment of text produced by chunking |
| LanceDB | An embedded vector database |
---
**Document version**: 1.0
**Last updated**: 2026-04-15

README.md Normal file (+191)

@@ -0,0 +1,191 @@
# Ocrag - Code Knowledge-Base RAG Plugin for OpenCode
> Local semantic code search for OpenCode: add code files to a knowledge base and retrieve them with natural-language queries.
[English](./README_EN.md) | [Design document](./DescriptionOfDesign.md)
---
## ✨ Features
- 🚀 **Zero configuration**: embedded database, works out of the box
- 🔒 **Local-first**: all data stays on your machine; nothing is uploaded to external servers
- 🔍 **Semantic search**: find relevant code snippets with natural-language queries
- 📦 **Many languages**: Python, JavaScript, TypeScript, Rust, Go, and other mainstream languages
- ⚡ **Fast**: search latency < 100 ms
- 🤖 **AI-friendly**: output format designed for AI consumption
---
## 📦 Installation
### Requirements
- Python 3.12+ (matching `requires-python` in `pyproject.toml`)
- [uv](https://github.com/astral-sh/uv) (package manager)
### Steps
```bash
# 1. Install the Python package
uv pip install -e .
# 2. Install the OpenCode plugin (optional)
cp opencode-plugin/ocrag-plugin.ts ~/.config/opencode/plugins/
# 3. Install the skill (optional)
mkdir -p ~/.config/opencode/skills/ocrag
cp opencode-skill/SKILL.md ~/.config/opencode/skills/ocrag/
```
---
## 🚀 Quick Start
### Basic Usage
```bash
# Add a file to the knowledge base
uv run ocrag add ./src/main.py
# Add a whole directory recursively
uv run ocrag add ./src/ --recursive
# Search for relevant code
uv run ocrag search "how is user authentication implemented"
# List the files in the knowledge base
uv run ocrag list
# Remove a file
uv run ocrag remove ./src/main.py
```
### OpenCode Integration
With the plugin installed, the AI can use the knowledge base automatically:
```
User: add the current file to the knowledge base
AI:   ✓ added main.py to the knowledge base
User: how is the auth module implemented?
AI:   based on the code in the knowledge base, the auth module ...
```
---
## 📁 Project Layout
```
ocrag/
├── src/ocrag/              # Python sources
│   ├── cli.py              # CLI entry point
│   ├── chunker.py          # code chunking
│   ├── embedder.py         # embedding
│   ├── db.py               # database operations
│   └── commands/           # command implementations
│       ├── add.py
│       ├── search.py
│       ├── remove.py
│       └── list.py
├── models/                 # embedding models
│   └── Qwen3-Embedding-0.6B/
├── tests/                  # unit tests
├── scripts/                # utility scripts
│   └── benchmark.py        # benchmarks
├── opencode-plugin/        # OpenCode plugin
└── opencode-skill/         # skill file
```
---
## 🔧 Configuration
### Model
The local model in `models/Qwen3-Embedding-0.6B` is used by default.
To enable GPU acceleration:
```bash
USE_GPU=true uv run ocrag search "your query"
```
### Database Location
Data is stored in `~/.ocrag/data.lance` by default; set the `OCRAG_DB_PATH` environment variable to use a different location.
---
## 📊 Performance
| Operation | Latency | Notes |
|------|------|------|
| Search | **~65 ms** | embedding + vector lookup |
| Add file | **~5 s** | chunking + embedding + storage |
| List files | **~5 ms** | pure database query |
Run the benchmark:
```bash
uv run python scripts/benchmark.py
```
---
## ✅ Running Tests
```bash
# Run all tests
uv run python -m pytest
# Run a single module's tests
uv run python -m pytest tests/unit/test_db.py -v
# Check coverage
uv run python -m pytest --cov=src/ocrag
```
---
## 🛠️ Development
### Adding a New Command
1. Create a new file under `src/ocrag/commands/`
2. Implement a `run()` function
3. Register the command in `src/ocrag/cli.py`
4. Export it from `src/ocrag/commands/__init__.py` (a minimal sketch follows)
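A minimal sketch of these steps for a hypothetical `stats` command:
```python
# 1./2. src/ocrag/commands/stats.py -- hypothetical new command module
from ocrag.db import RagDB


def run():
    db = RagDB()
    print(f"Sources in knowledge base: {len(db.list_sources())}")


# 3./4. register in src/ocrag/cli.py after exporting it from
# src/ocrag/commands/__init__.py as `from ocrag.commands.stats import run as stats_cmd`:
#
# @main.command()
# def stats():
#     """Show knowledge base statistics."""
#     stats_cmd()
```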
### Code Style
```bash
# Format the code
ruff format src/
# Lint
ruff check src/
```
---
## 📚 Documentation
- [Design document](./DescriptionOfDesign.md) - detailed design notes and technology-choice analysis
- [CHANGELOG](./CHANGELOG.md) - release history
---
## 🤝 Contributing
Issues and pull requests are welcome!
---
## 📄 License
MIT License
---
**Version**: 0.1.0
**Last updated**: 2026-04-15

opencode-plugin/ocrag-plugin.ts Normal file (+42)

@@ -0,0 +1,42 @@
import { tool, definePlugin } from "opencode";
import { z } from "zod";

export default definePlugin({
  name: "ocrag-plugin",
  hooks: () => ({
    rag_add: tool({
      description:
        "Add one or more code files or directories to the local RAG knowledge base. Later searches will include this code.",
      parameters: {
        paths: z.array(z.string()).describe("file or directory paths to add"),
        recursive: z
          .boolean()
          .optional()
          .describe("if a path is a directory, whether to add all files under it recursively"),
      },
      execute: async (args) => {
        try {
          // Bun Shell escapes interpolated strings and expands arrays into
          // separate arguments, so no manual joining or quoting is needed.
          const flags = args.recursive ? ["--recursive"] : [];
          const result = await Bun.$`ocrag add ${args.paths} ${flags}`.text();
          return result;
        } catch (error) {
          return `Error: ${error instanceof Error ? error.message : String(error)}`;
        }
      },
    }),
    rag_search: tool({
      description:
        "Run a semantic search over the local RAG knowledge base and return the code snippets most relevant to the query.",
      parameters: {
        query: z
          .string()
          .describe("natural-language query, e.g. 'how is user authentication implemented'"),
        top_k: z.number().optional().default(5).describe("number of results to return"),
      },
      execute: async (args) => {
        try {
          // args.query is escaped by Bun Shell, so manual quotes are unnecessary.
          const result = await Bun.$`ocrag search ${args.query} --top-k ${args.top_k}`.text();
          return result;
        } catch (error) {
          return `Error: ${error instanceof Error ? error.message : String(error)}`;
        }
      },
    }),
  }),
});

opencode-plugin/package.json Normal file (+9)

@@ -0,0 +1,9 @@
{
"name": "ocrag-plugin",
"version": "1.0.0",
"description": "Code RAG plugin for OpenCode",
"main": "ocrag-plugin.ts",
"keywords": ["opencode", "plugin", "rag", "code"],
"author": "",
"license": "MIT"
}

opencode-skill/SKILL.md Normal file (+34)

@@ -0,0 +1,34 @@
---
name: ocrag
description: Codebase RAG skill for adding code to the knowledge base or searching existing code.
---
# Codebase RAG Skill
## When to Use This Skill
- The user asks to "add the current file / this directory to the knowledge base"
- The user asks questions about the codebase that should be answered from existing code
- The user explicitly mentions keywords such as "RAG", "knowledge base", or "semantic search"
## How to Use It
### Adding code to the knowledge base
When the user provides a file path or directory, use the `rag_add` tool:
- If the user says "add this file to the knowledge base", extract the path and call `rag_add` with recursive set to false
- If the user says "add all code under the src/ directory", call `rag_add` with recursive = true
### Searching the knowledge base
When the user asks a code question (e.g. "what does this code do?" or "how do I configure X?"):
1. First query with the `rag_search` tool, passing the user's natural-language question as the query parameter
2. Use the returned code snippets to answer the user
3. If the results are irrelevant, tell the user nothing was found and suggest adding more code
## Examples
- User: "remember what's in this file" → call `rag_add` with the current file path
- User: "how is the auth module implemented?" → call `rag_search` with query="auth module implementation"
- User: "add all Python files under ./src to the knowledge base" → call `rag_add` with paths=["./src"], recursive=true
## Notes
- Search results show each snippet's source file
- The number of results can be tuned with the top_k parameter
- Newly added files are included in searches immediately

pyproject.toml Normal file (+20)

@@ -0,0 +1,20 @@
[project]
name = "ocrag"
version = "0.1.0"
description = "Local code knowledge-base RAG plugin for OpenCode"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "chonkie>=1.6.2",
    "click>=8.3.2",
    "lancedb>=0.30.2",
    "langchain-text-splitters>=0.3.0",
    "sentence-transformers>=5.4.1",
    "sentencepiece>=0.2.1",
    "tiktoken>=0.12.0",
    "tokenizers>=0.22.2",
]

[project.scripts]
ocrag = "ocrag.cli:main"

[dependency-groups]
dev = [
    "pytest>=9.0.3",
]

pytest.ini Normal file (+3)

@@ -0,0 +1,3 @@
[pytest]
addopts = -v
pythonpath = src

scripts/benchmark.py Executable file (+278)

@@ -0,0 +1,278 @@
#!/usr/bin/env python3
"""
Benchmark script - measures ocrag write and read performance.

Usage:
    uv run python scripts/benchmark.py
    uv run python scripts/benchmark.py --cleanup  # remove test data afterwards
"""
import sys
import os
import time
import tempfile
import shutil
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

from ocrag.chunker import chunk_code
from ocrag.embedder import embedder
from ocrag.db import RagDB


class PerformanceBenchmark:
    def __init__(self, db_path=None):
        if db_path is None:
            self.db_path = tempfile.mkdtemp(prefix="ocrag_benchmark_")
        else:
            self.db_path = db_path
        self.db = RagDB(self.db_path)

    def cleanup(self):
        """Remove the benchmark database."""
        if os.path.exists(self.db_path):
            shutil.rmtree(self.db_path)
            print(f"✅ Removed benchmark database: {self.db_path}")
    def generate_test_code(self, num_lines=100, language="python"):
        """Generate synthetic test code (one small function per `num_lines` unit)."""
        if language == "python":
            code = "def function_{i}():\n    '''Docstring for function {i}'''\n    result = 0\n"
            code += "    for i in range(100):\n        result += i\n"
            code += "    return result\n\n"
        elif language == "javascript":
            code = "function function_{i}() {{\n    // Docstring for function {i}\n"
            code += "    let result = 0;\n"
            code += (
                "    for (let i = 0; i < 100; i++) {{\n        result += i;\n    }}\n"
            )
            code += "    return result;\n}}\n\n"
        else:
            raise ValueError(f"unsupported language: {language}")
        full_code = ""
        for i in range(num_lines):
            full_code += code.format(i=i)
        return full_code
    def benchmark_single_file_write(self, num_lines=100):
        """Benchmark writing a single file."""
        print(f"\n{'=' * 60}")
        print(f"Test: single-file write ({num_lines} generated functions)")
        print("=" * 60)
        code = self.generate_test_code(num_lines)
        file_path = "test_file.py"
        # Chunking
        start_time = time.time()
        chunks = chunk_code(code, file_path)
        chunk_time = time.time() - start_time
        print(f"Chunking time: {chunk_time * 1000:.2f} ms")
        print(f"Chunks produced: {len(chunks)}")
        # Embedding
        start_time = time.time()
        texts = [c["text"] for c in chunks]
        vectors = embedder.embed(texts)
        embed_time = time.time() - start_time
        print(f"Embedding time: {embed_time * 1000:.2f} ms")
        # Build documents
        documents = []
        for chunk, vec in zip(chunks, vectors):
            documents.append(
                {
                    "text": chunk["text"],
                    "vector": vec,
                    "metadata": chunk["metadata"],
                }
            )
        # Write to the database
        start_time = time.time()
        self.db.add_documents(documents)
        db_write_time = time.time() - start_time
        print(f"Database write time: {db_write_time * 1000:.2f} ms")
        total_time = chunk_time + embed_time + db_write_time
        print(f"\n📊 Total: {total_time * 1000:.2f} ms")
        return {
            "chunk_time": chunk_time * 1000,
            "embed_time": embed_time * 1000,
            "db_write_time": db_write_time * 1000,
            "total_time": total_time * 1000,
            "num_chunks": len(chunks),
        }
    def benchmark_batch_write(self, num_files=10, lines_per_file=50):
        """Benchmark writing a batch of files."""
        print(f"\n{'=' * 60}")
        print(f"Test: batch write ({num_files} files, {lines_per_file} functions each)")
        print("=" * 60)
        total_chunks = 0
        total_embed_time = 0
        total_db_write_time = 0
        for i in range(num_files):
            code = self.generate_test_code(lines_per_file)
            file_path = f"test_file_{i}.py"
            # Chunking
            chunks = chunk_code(code, file_path)
            total_chunks += len(chunks)
            # Embedding
            texts = [c["text"] for c in chunks]
            start_time = time.time()
            vectors = embedder.embed(texts)
            total_embed_time += time.time() - start_time
            # Build documents and write
            documents = []
            for chunk, vec in zip(chunks, vectors):
                documents.append(
                    {
                        "text": chunk["text"],
                        "vector": vec,
                        "metadata": chunk["metadata"],
                    }
                )
            start_time = time.time()
            self.db.add_documents(documents)
            total_db_write_time += time.time() - start_time
        total_time = total_embed_time + total_db_write_time
        print(f"Total chunks: {total_chunks}")
        print(f"Average embedding per file: {total_embed_time / num_files * 1000:.2f} ms")
        print(f"Average database write per file: {total_db_write_time / num_files * 1000:.2f} ms")
        print(f"📊 Total: {total_time * 1000:.2f} ms")
        print(f"📊 Throughput: {total_chunks / total_time:.2f} chunks/s")
        return {
            "total_chunks": total_chunks,
            "total_embed_time": total_embed_time * 1000,
            "total_db_write_time": total_db_write_time * 1000,
            "total_time": total_time * 1000,
            "throughput": total_chunks / total_time,
        }
    def benchmark_search(self, num_queries=10):
        """Benchmark search latency."""
        print(f"\n{'=' * 60}")
        print(f"Test: search performance ({num_queries} queries)")
        print("=" * 60)
        queries = [
            "how to implement user authentication",
            "database configuration",
            "API endpoint handler",
            "error handling function",
            "data validation logic",
        ]
        search_times = []
        for i in range(num_queries):
            query = queries[i % len(queries)]
            # Embed the query
            start_time = time.time()
            query_vec = embedder.embed_single(query)
            embed_time = time.time() - start_time
            # Run the vector search
            start_time = time.time()
            results = self.db.search(query_vec, top_k=5)
            search_time = time.time() - start_time
            total_time = embed_time + search_time
            search_times.append(total_time)
            print(
                f"Query {i + 1}: {query[:30]}... → {len(results)} results ({total_time * 1000:.2f} ms)"
            )
        avg_time = sum(search_times) / len(search_times)
        min_time = min(search_times)
        max_time = max(search_times)
        print("\n📊 Search latency:")
        print(f"  average: {avg_time * 1000:.2f} ms")
        print(f"  min: {min_time * 1000:.2f} ms")
        print(f"  max: {max_time * 1000:.2f} ms")
        return {
            "avg_time": avg_time * 1000,
            "min_time": min_time * 1000,
            "max_time": max_time * 1000,
            "num_queries": num_queries,
        }
    def benchmark_list_sources(self):
        """Benchmark list_sources."""
        print(f"\n{'=' * 60}")
        print("Test: list_sources performance")
        print("=" * 60)
        start_time = time.time()
        sources = self.db.list_sources()
        elapsed = time.time() - start_time
        print(f"Sources listed: {len(sources)}")
        print(f"📊 Time: {elapsed * 1000:.2f} ms")
        return {
            "num_sources": len(sources),
            "time": elapsed * 1000,
        }
    def run_full_benchmark(self):
        """Run the full benchmark suite."""
        print("\n" + "=" * 60)
        print("🚀 Ocrag benchmark")
        print("=" * 60)
        # Seed test data
        print("\n📦 Preparing test data...")
        self.benchmark_single_file_write(num_lines=100)
        # Batch writes
        self.benchmark_batch_write(num_files=20, lines_per_file=50)
        # Search
        self.benchmark_search(num_queries=10)
        # Listing
        self.benchmark_list_sources()
        print("\n" + "=" * 60)
        print("✅ Benchmark finished!")
        print("=" * 60)


def main():
    import argparse

    parser = argparse.ArgumentParser(description="Ocrag benchmark")
    parser.add_argument("--cleanup", action="store_true", help="remove test data afterwards")
    parser.add_argument("--db-path", type=str, help="database path to use")
    args = parser.parse_args()
    benchmark = PerformanceBenchmark(db_path=args.db_path)
    try:
        benchmark.run_full_benchmark()
    finally:
        if args.cleanup:
            benchmark.cleanup()


if __name__ == "__main__":
    main()

scripts/test_chunking.py Normal file (+409)

@@ -0,0 +1,409 @@
#!/usr/bin/env python3
"""
Code-chunking benchmark.

Compares the behaviour and speed of different chunking strategies:
1. langchain - language-aware chunking
2. semantic  - semantic chunking
3. simple    - plain character chunking

Usage:
    uv run python scripts/test_chunking.py

Note: requires a chunker build that exposes the strategy API imported below.
"""
import sys
import time
from pathlib import Path
from typing import List, Dict, Any

sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

from ocrag.chunker import (
    chunk_code,
    ChunkStrategy,
    get_available_strategies,
    LANGCHAIN_AVAILABLE,
    SEMANTIC_AVAILABLE,
)


class ChunkingBenchmark:
    """Chunking benchmark across strategies."""

    # Sample code for the test cases
    TEST_CASES = {
"python_small": {
"code": '''
class UserManager:
"""用户管理类"""
def __init__(self):
self.users = {}
def add_user(self, user_id, name):
self.users[user_id] = {'name': name, 'active': True}
return True
def remove_user(self, user_id):
if user_id in self.users:
del self.users[user_id]
return True
return False
''',
"file": "test.py",
"expected_classes": 1,
},
"python_medium": {
"code": '''
class User:
"""用户基类"""
def __init__(self, name, email):
self.name = name
self.email = email
self.active = True
def deactivate(self):
self.active = False
class Admin(User):
"""管理员类"""
def __init__(self, name, email, permissions):
super().__init__(name, email)
self.permissions = permissions
def has_permission(self, permission):
return permission in self.permissions
class Guest(User):
"""访客类"""
def __init__(self, name, email):
super().__init__(name, email)
self.active = False
def create_user(user_type, name, email, **kwargs):
"""用户工厂函数"""
if user_type == 'admin':
return Admin(name, email, kwargs.get('permissions', []))
elif user_type == 'guest':
return Guest(name, email)
else:
return User(name, email)
def validate_email(email):
"""验证邮箱格式"""
return '@' in email and '.' in email.split('@')[1]
def process_batch(users):
"""批量处理用户"""
results = []
for user in users:
if user.active:
results.append({
'name': user.name,
'email': user.email,
'status': 'active'
})
return results
''',
"file": "user_manager.py",
"expected_classes": 3,
},
"python_large": {
"code": '''
import json
from typing import List, Dict, Optional
class DataProcessor:
"""数据处理器"""
def __init__(self, config):
self.config = config
self.cache = {}
def process(self, data):
if not self.validate(data):
raise ValueError("Invalid data format")
return self.transform(data)
def validate(self, data):
return isinstance(data, dict) and 'id' in data
def transform(self, data):
return {
'id': data['id'],
'processed': True,
'timestamp': time.time()
}
class APIClient:
"""API 客户端"""
def __init__(self, base_url, api_key):
self.base_url = base_url
self.api_key = api_key
def get(self, endpoint):
url = f"{self.base_url}/{endpoint}"
headers = {'Authorization': f'Bearer {self.api_key}'}
return self._request('GET', url, headers)
def post(self, endpoint, data):
url = f"{self.base_url}/{endpoint}"
headers = {
'Authorization': f'Bearer {self.api_key}',
'Content-Type': 'application/json'
}
return self._request('POST', url, headers, json.dumps(data))
def _request(self, method, url, headers, body=None):
# request logic goes here
pass
class CacheManager:
"""缓存管理器"""
def __init__(self, max_size=1000):
self.max_size = max_size
self.cache = {}
self.access_order = []
def get(self, key):
if key in self.cache:
self._update_access(key)
return self.cache[key]
return None
def set(self, key, value):
if len(self.cache) >= self.max_size:
oldest = self.access_order.pop(0)
del self.cache[oldest]
self.cache[key] = value
self.access_order.append(key)
def _update_access(self, key):
if key in self.access_order:
self.access_order.remove(key)
self.access_order.append(key)
def clear(self):
self.cache.clear()
self.access_order.clear()
''',
"file": "large_module.py",
"expected_classes": 3,
},
"rust": {
"code": """
struct User {
name: String,
email: String,
active: bool,
}
impl User {
fn new(name: String, email: String) -> Self {
User {
name,
email,
active: true,
}
}
fn deactivate(&mut self) {
self.active = false;
}
}
struct Admin {
user: User,
permissions: Vec<String>,
}
impl Admin {
fn has_permission(&self, permission: &str) -> bool {
self.permissions.contains(&permission.to_string())
}
}
fn create_user(name: String, email: String) -> User {
User::new(name, email)
}
fn process_batch(users: Vec<User>) -> Vec<String> {
users
.iter()
.filter(|u| u.active)
.map(|u| u.name.clone())
.collect()
}
""",
"file": "user.rs",
"expected_classes": 2,
},
"cpp": {
"code": """
#include <string>
#include <vector>
class User {
private:
std::string name;
std::string email;
bool active;
public:
User(const std::string& name, const std::string& email)
: name(name), email(email), active(true) {}
void deactivate() {
active = false;
}
bool isActive() const {
return active;
}
};
class Admin : public User {
private:
std::vector<std::string> permissions;
public:
Admin(const std::string& name, const std::string& email)
: User(name, email) {}
void addPermission(const std::string& perm) {
permissions.push_back(perm);
}
};
void processUsers(std::vector<User>& users) {
for (auto& user : users) {
if (user.isActive()) {
// handle active users
}
}
}
""",
"file": "user.cpp",
"expected_classes": 2,
},
}
    def __init__(self):
        self.results = {}

    def run_benchmark(self, strategy: ChunkStrategy):
        """Benchmark a single strategy."""
        print(f"\n{'=' * 60}")
        print(f"Strategy: {strategy.value}")
        print("=" * 60)
        total_chunks = 0
        total_time = 0
        first_init_time = None
        results = []
        for name, test_case in self.TEST_CASES.items():
            # Time the call (the first one includes initialization)
            start_time = time.time()
            chunks = chunk_code(
                test_case["code"],
                test_case["file"],
                strategy=strategy,
                chunk_size=512,
                chunk_overlap=50,
            )
            elapsed = time.time() - start_time
            if first_init_time is None:
                first_init_time = elapsed
            total_time += elapsed
            total_chunks += len(chunks)
            results.append(
                {
                    "name": name,
                    "chunks": len(chunks),
                    "time_ms": elapsed * 1000,
                    "chars": len(test_case["code"]),
                }
            )
            print(f"\n  [{name}]")
            print(f"    characters: {len(test_case['code'])}")
            print(f"    chunks: {len(chunks)}")
            print(f"    time: {elapsed * 1000:.2f}ms")
        print(f"\nTotal time: {total_time * 1000:.2f}ms")
        print(f"Total chunks: {total_chunks}")
        return {
            "strategy": strategy.value,
            "total_time_ms": total_time * 1000,
            "total_chunks": total_chunks,
            "first_init_time_ms": first_init_time * 1000 if first_init_time else 0,
            "details": results,
        }
    def run_all(self):
        """Run the benchmark for every available strategy."""
        print("=" * 60)
        print("📊 Code-chunking benchmark")
        print("=" * 60)
        print(f"\nAvailable strategies: {get_available_strategies()}")
        print(f"langchain available: {LANGCHAIN_AVAILABLE}")
        print(f"semantic-text-splitter available: {SEMANTIC_AVAILABLE}")
        all_results = []
        # Benchmark each strategy
        for strategy_name in get_available_strategies():
            strategy = ChunkStrategy(strategy_name)
            result = self.run_benchmark(strategy)
            all_results.append(result)
        # Print the summary
        self.print_summary(all_results)
        return all_results

    def print_summary(self, results: List[Dict]):
        """Print a summary of all results."""
        print("\n" + "=" * 60)
        print("📈 Summary")
        print("=" * 60)
        print(f"\n{'strategy':<15} {'total time':<12} {'first init':<12} {'chunks':<10}")
        print("-" * 50)
        for result in results:
            print(
                f"{result['strategy']:<15} "
                f"{result['total_time_ms']:>8.2f}ms "
                f"{result['first_init_time_ms']:>8.2f}ms "
                f"{result['total_chunks']:>6}"
            )
        # Highlight the fastest strategy and the one producing the most chunks
        fastest = min(results, key=lambda x: x["total_time_ms"])
        most_chunks = max(results, key=lambda x: x["total_chunks"])
        print(f"\n🏆 Fastest strategy: {fastest['strategy']} ({fastest['total_time_ms']:.2f}ms)")
        print(f"📦 Most chunks: {most_chunks['strategy']} ({most_chunks['total_chunks']} chunks)")


def main():
    benchmark = ChunkingBenchmark()
    benchmark.run_all()
    print("\n" + "=" * 60)
    print("✅ Done!")
    print("=" * 60)


if __name__ == "__main__":
    main()

src/ocrag/__init__.py Normal file (empty)

src/ocrag/chunker.py Normal file (+99)

@@ -0,0 +1,99 @@
"""
代码分块模块
使用 langchain_text_splitters 进行语言感知分块
支持 29 种编程语言的语法感知分块tiktoken 已内置无需额外下载
"""
from pathlib import Path
from typing import List, Dict, Any
try:
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
except ImportError:
raise ImportError(
"请安装 langchain-text-splitters: uv pip install langchain-text-splitters"
)
LANGUAGE_MAP = {
    ".py": Language.PYTHON,
    ".rs": Language.RUST,
    ".cpp": Language.CPP,
    ".cc": Language.CPP,
    ".cxx": Language.CPP,
    ".go": Language.GO,
    ".java": Language.JAVA,
    ".js": Language.JS,
    ".ts": Language.TS,
    ".jsx": Language.JS,
    ".tsx": Language.TS,
    ".c": Language.C,
    ".h": Language.C,
    ".hpp": Language.CPP,
    ".rb": Language.RUBY,
    ".swift": Language.SWIFT,
    ".kt": Language.KOTLIN,
    ".md": Language.MARKDOWN,
    ".txt": None,
}


def detect_language(file_path: str) -> Language | None:
    """Detect a file's language from its extension."""
    ext = Path(file_path).suffix.lower()
    return LANGUAGE_MAP.get(ext)
def chunk_code(
    content: str,
    file_path: str,
    chunk_size: int = 512,
    chunk_overlap: int = 50,
) -> List[Dict[str, Any]]:
    """Split a code file into chunks.

    Uses langchain_text_splitters for language-aware chunking; 29 programming
    languages get syntax-aware splitting.

    Args:
        content: file content
        file_path: file path
        chunk_size: chunk size in characters
        chunk_overlap: overlap between chunks in characters

    Returns:
        A list of chunks, each with `text` and `metadata`.
    """
    language = detect_language(file_path)
    if language is None:
        # Unsupported language: fall back to plain recursive splitting
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=["\n\n", "\n", " ", ""],
            length_function=len,
        )
        texts = splitter.split_text(content)
        language_name = "text"
    else:
        # Language-aware splitting
        splitter = RecursiveCharacterTextSplitter.from_language(
            language=language,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
        texts = splitter.split_text(content)
        language_name = language.value
    return [
        {
            "text": text,
            "metadata": {
                "source_file": file_path,
                "chunk_index": i,
                "language": language_name,
            },
        }
        for i, text in enumerate(texts)
    ]

src/ocrag/cli.py Normal file (+45)

@@ -0,0 +1,45 @@
import click

from ocrag.commands import add as add_cmd
from ocrag.commands import search as search_cmd
from ocrag.commands import remove as remove_cmd
from ocrag.commands import list_cmd


@click.group()
@click.version_option()
def main():
    """ocrag - codebase RAG command-line tool"""
    pass


@main.command()
@click.argument("paths", nargs=-1, required=True, type=click.Path(exists=True))
@click.option("--recursive", "-r", is_flag=True, help="recurse into directories")
def add(paths, recursive):
    """Add files or directories to the knowledge base."""
    add_cmd(paths, recursive)


@main.command()
@click.argument("query")
@click.option("--top-k", "-k", default=5, show_default=True, help="number of results")
def search(query, top_k):
    """Semantic search over the knowledge base."""
    search_cmd(query, top_k)


@main.command()
@click.argument("path", type=click.Path())
def remove(path):
    """Remove a file from the knowledge base."""
    remove_cmd(path)


@main.command(name="list")
def list_():
    """List all entries in the knowledge base."""
    list_cmd()


if __name__ == "__main__":
    main()

src/ocrag/commands/__init__.py Normal file (+4)

@@ -0,0 +1,4 @@
from ocrag.commands.add import run as add
from ocrag.commands.search import run as search
from ocrag.commands.remove import run as remove
from ocrag.commands.list import run as list_cmd

src/ocrag/commands/add.py Normal file (+56)

@@ -0,0 +1,56 @@
import os
from pathlib import Path

from ocrag.chunker import chunk_code
from ocrag.embedder import embedder
from ocrag.db import RagDB


def collect_files(paths, recursive):
    """Collect every file that should be processed."""
    files = []
    for p in paths:
        p = Path(p)
        if p.is_file():
            files.append(p)
        elif p.is_dir() and recursive:
            for root, dirs, filenames in os.walk(p):
                for f in filenames:
                    files.append(Path(root) / f)
        elif p.is_dir() and not recursive:
            print(f"Skipping directory {p} (use --recursive to include it)")
    return files


def run(paths, recursive):
    db = RagDB()
    files = collect_files(paths, recursive)
    total_chunks = 0
    for file_path in files:
        try:
            content = file_path.read_text(encoding="utf-8")
            chunks = chunk_code(content, str(file_path))
            if not chunks:
                continue
            # Batch embedding
            texts = [c["text"] for c in chunks]
            vectors = embedder.embed(texts)
            documents = []
            for chunk, vec in zip(chunks, vectors):
                documents.append(
                    {
                        "text": chunk["text"],
                        "vector": vec,
                        "metadata": chunk["metadata"],
                    }
                )
            db.add_documents(documents)
            total_chunks += len(documents)
            print(f"✓ {file_path} -> {len(documents)} chunks")
        except Exception as e:
            print(f"❌ Failed to process {file_path}: {e}")
    print(f"\n📦 Added {total_chunks} chunks in total")

src/ocrag/commands/list.py Normal file (+14)

@@ -0,0 +1,14 @@
from ocrag.db import RagDB


def run():
    db = RagDB()
    sources = db.list_sources()
    if not sources:
        print("The knowledge base is empty")
        return
    print("Files in the knowledge base:")
    for i, source in enumerate(sources, 1):
        print(f"{i}. {source}")

src/ocrag/commands/remove.py Normal file (+9)

@@ -0,0 +1,9 @@
from ocrag.db import RagDB


def run(path: str):
    db = RagDB()
    num_deleted = db.delete_by_source(path)
    print(f"Removed {num_deleted} chunks for {path}")

src/ocrag/commands/search.py Normal file (+18)

@@ -0,0 +1,18 @@
from ocrag.embedder import embedder
from ocrag.db import RagDB


def run(query: str, top_k: int):
    db = RagDB()
    query_vec = embedder.embed_single(query)
    results = db.search(query_vec, top_k)
    if not results:
        print("No matching results.")
        return
    for i, r in enumerate(results, 1):
        # _distance is a distance, not a similarity: smaller means more similar
        print(f"\n[{i}] distance: {r['_distance']:.4f}")
        print(f"Source: {r['metadata'].get('source_file', 'unknown')}")
        print(f"Content:\n{r['text']}")
        print("-" * 80)

src/ocrag/db.py Normal file (+112)

@@ -0,0 +1,112 @@
import json
import os

import lancedb
import pyarrow as pa
from typing import List, Dict, Any
from pathlib import Path

DB_PATH = Path.home() / ".ocrag" / "data.lance"


class RagDB:
    def __init__(self, db_path: str | None = None):
        # Precedence: explicit argument > OCRAG_DB_PATH env var > default path
        self.path = db_path or os.environ.get("OCRAG_DB_PATH") or str(DB_PATH)
        self.conn = lancedb.connect(self.path)
        self._init_table()

    def _init_table(self):
        """Create the table if it does not exist yet."""
        schema = pa.schema(
            [
                ("text", pa.string()),
                (
                    "vector",
                    pa.list_(pa.float32(), 1024),
                ),  # 1024-dim (Qwen3-Embedding-0.6B)
                ("metadata", pa.string()),  # JSON string
            ]
        )
        try:
            self.table = self.conn.open_table("documents")
        except Exception:
            self.table = self.conn.create_table("documents", schema=schema)
    def add_documents(self, documents: List[Dict[str, Any]]):
        """Add documents in a batch.

        documents: [{"text": str, "vector": List[float], "metadata": dict}, ...]
        """
        data = []
        for doc in documents:
            data.append(
                {
                    "text": doc["text"],
                    "vector": doc["vector"],
                    "metadata": json.dumps(doc.get("metadata", {})),
                }
            )
        self.table.add(data)

    def search(self, query_vector: List[float], top_k: int = 5) -> List[Dict[str, Any]]:
        results = self.table.search(query_vector).limit(top_k).to_list()
        for r in results:
            r["metadata"] = json.loads(r["metadata"])
        return results
    def delete_by_source(self, source_file: str) -> int:
        """Delete every chunk that came from the given source file.

        Args:
            source_file: path of the source file to delete

        Returns:
            Number of chunks deleted.
        """
        # Safe approach: filter with Pandas to avoid SQL injection.
        df = self.table.to_pandas()
        if df.empty:
            return 0
        # Keep only the rows that do not belong to this source file.
        indices_to_keep = []
        for idx, row in df.iterrows():
            try:
                meta = json.loads(row["metadata"])
                if meta.get("source_file") != source_file:
                    indices_to_keep.append(idx)
            except (json.JSONDecodeError, KeyError):
                # Keep rows whose metadata fails to parse.
                indices_to_keep.append(idx)
        num_deleted = len(df) - len(indices_to_keep)
        if num_deleted == 0:
            return 0
        df_remaining = df.loc[indices_to_keep]
        # Rebuild the table with the remaining rows.
        self.conn.drop_table("documents")
        self.table = self.conn.create_table("documents", df_remaining)
        return num_deleted

    def list_sources(self) -> List[str]:
        """List the source files that have been added (deduplicated)."""
        df = self.table.to_pandas()
        if df.empty:
            return []
        sources = set()
        for meta_str in df["metadata"]:
            meta = json.loads(meta_str)
            if "source_file" in meta:
                sources.add(meta["source_file"])
        return sorted(sources)

src/ocrag/embedder.py Normal file (+45)

@@ -0,0 +1,45 @@
import os
from pathlib import Path
from typing import List

from sentence_transformers import SentenceTransformer


class Embedder:
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # Resolve the local model path relative to the repo root
            model_path = (
                Path(__file__).parent.parent.parent / "models" / "Qwen3-Embedding-0.6B"
            )
            if not model_path.exists():
                raise FileNotFoundError(f"Model directory not found: {model_path}")
            try:
                cls._instance.model = SentenceTransformer(
                    str(model_path),
                    device="cuda"
                    if os.getenv("USE_GPU", "false").lower() == "true"
                    else "cpu",
                )
            except Exception as e:
                raise RuntimeError(f"Failed to load model: {e}")
        return cls._instance

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Return one vector (a list of floats) per input text."""
        embeddings = self.model.encode(texts, normalize_embeddings=True)
        return embeddings.tolist()

    def embed_single(self, text: str) -> List[float]:
        return self.embed([text])[0]


# Module-level singleton (the model loads on first import)
embedder = Embedder()

src/ocrag/utils.py Normal file (+78)

@@ -0,0 +1,78 @@
import os
import hashlib
from pathlib import Path


def get_file_hash(file_path: str) -> str:
    """Compute a file's MD5 hash, used to detect changes."""
    md5_hash = hashlib.md5()
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            md5_hash.update(chunk)
    return md5_hash.hexdigest()


def ensure_dir(dir_path: str) -> None:
    """Create the directory if it does not exist."""
    Path(dir_path).mkdir(parents=True, exist_ok=True)


def get_file_size(file_path: str) -> int:
    """Return the file size in bytes."""
    return os.path.getsize(file_path)


def format_size(size_bytes: int) -> str:
    """Format a byte count as a human-readable size."""
    for unit in ["B", "KB", "MB", "GB"]:
        if size_bytes < 1024.0:
            return f"{size_bytes:.2f} {unit}"
        size_bytes /= 1024.0
    return f"{size_bytes:.2f} TB"
def is_text_file(file_path: str) -> bool:
    """Return True if the extension marks a known text file."""
    text_extensions = {
".py",
".js",
".ts",
".jsx",
".tsx",
".java",
".c",
".cpp",
".h",
".hpp",
".go",
".rs",
".rb",
".php",
".swift",
".kt",
".scala",
".md",
".txt",
".json",
".yaml",
".yml",
".xml",
".html",
".css",
".sql",
".sh",
".bash",
".zsh",
".ps1",
".bat",
".vue",
".svelte",
    }
    ext = Path(file_path).suffix.lower()
    return ext in text_extensions


def get_project_root() -> Path:
    """Return the repository root directory."""
    return Path(__file__).parent.parent.parent

tests/unit/test_chunker.py Normal file (+45)

@@ -0,0 +1,45 @@
import pytest
from ocrag.chunker import chunk_code
@pytest.fixture
def sample_python_code():
return """def hello():
print('Hello, World!')
class Test:
def method(self):
pass"""
@pytest.fixture
def sample_markdown_content():
return """# Title
Paragraph 1
## Subtitle
Paragraph 2"""
def test_chunk_code_python(sample_python_code):
chunks = chunk_code(sample_python_code, "test.py")
assert len(chunks) > 0
assert "def hello" in chunks[0]["text"]
assert chunks[0]["metadata"]["language"] == "python"
def test_chunk_code_markdown(sample_markdown_content):
chunks = chunk_code(sample_markdown_content, "test.md")
assert len(chunks) > 0
assert "# Title" in chunks[0]["text"]
assert chunks[0]["metadata"]["language"] == "markdown"
def test_chunk_code_text():
content = "a" * 10000 # Large content to test text fallback
chunks = chunk_code(content, "large.txt")
assert len(chunks) > 0
assert chunks[0]["metadata"]["source_file"] == "large.txt"
assert chunks[0]["metadata"]["language"] == "text"

tests/unit/test_cli.py Normal file (+63)

@@ -0,0 +1,63 @@
import pytest
import os
from click.testing import CliRunner
from ocrag.cli import main
from ocrag.db import RagDB
@pytest.fixture
def runner():
return CliRunner()
@pytest.fixture
def setup_test_env(tmpdir):
# Create test files
os.makedirs(os.path.join(tmpdir, "test_dir"))
with open(os.path.join(tmpdir, "test_dir", "file1.py"), "w") as f:
f.write("def test_func():\n pass")
with open(os.path.join(tmpdir, "test_file.py"), "w") as f:
f.write("print('Hello')")
return tmpdir
def test_add_command(runner, setup_test_env, tmpdir):
# Use temporary DB
db_path = os.path.join(tmpdir, "test_db.lance")
os.environ["OCRAG_DB_PATH"] = db_path
result = runner.invoke(main, ["add", os.path.join(setup_test_env, "test_file.py")])
assert "test_file.py" in result.output
assert "总计添加" in result.output
assert result.exit_code == 0
def test_search_command(runner, setup_test_env, tmpdir):
# Use temporary DB and add a file first
db_path = os.path.join(tmpdir, "test_db.lance")
os.environ["OCRAG_DB_PATH"] = db_path
add_result = runner.invoke(
main, ["add", os.path.join(setup_test_env, "test_file.py")]
)
assert add_result.exit_code == 0
# Search
result = runner.invoke(main, ["search", "Hello"])
assert "Hello" in result.output
assert "来源: " in result.output
def test_list_command(runner, setup_test_env, tmpdir):
# Use temporary DB and add a file first
db_path = os.path.join(tmpdir, "test_db.lance")
os.environ["OCRAG_DB_PATH"] = db_path
add_result = runner.invoke(
main, ["add", os.path.join(setup_test_env, "test_file.py")]
)
assert add_result.exit_code == 0
# List
result = runner.invoke(main, ["list"])
assert "test_file.py" in result.output

tests/unit/test_db.py Normal file (+115)

@@ -0,0 +1,115 @@
import pytest
import os
import shutil
from ocrag.db import RagDB
@pytest.fixture
def temp_db(tmpdir):
db_path = os.path.join(tmpdir, "test_db.lance")
yield RagDB(db_path)
if os.path.exists(db_path):
shutil.rmtree(db_path)
def test_db_initialization(temp_db):
assert temp_db.table is not None
def test_add_documents(temp_db):
documents = [
{
"text": "Test content",
"vector": [0.1] * 1024,
"metadata": {"source_file": "test.py"},
}
]
temp_db.add_documents(documents)
assert temp_db.table.to_pandas().shape[0] == 1
def test_search(temp_db):
# Add a document
documents = [
{
"text": "Database configuration",
"vector": [0.1] * 1024,
"metadata": {"source_file": "config.py"},
}
]
temp_db.add_documents(documents)
# Search
results = temp_db.search([0.1] * 1024)
assert len(results) == 1
assert "Database configuration" in results[0]["text"]
def test_list_sources(temp_db):
# Add documents from multiple sources
documents = [
{
"text": "Content 1",
"vector": [0.1] * 1024,
"metadata": {"source_file": "file1.py"},
},
{
"text": "Content 2",
"vector": [0.2] * 1024,
"metadata": {"source_file": "file2.py"},
},
]
temp_db.add_documents(documents)
sources = temp_db.list_sources()
assert len(sources) == 2
assert "file1.py" in sources
assert "file2.py" in sources
def test_delete_by_source(temp_db):
# Add documents from multiple sources
documents = [
{
"text": "Content A",
"vector": [0.1] * 1024,
"metadata": {"source_file": "file1.py"},
},
{
"text": "Content B",
"vector": [0.2] * 1024,
"metadata": {"source_file": "file2.py"},
},
{
"text": "Content C",
"vector": [0.3] * 1024,
"metadata": {"source_file": "file1.py"},
},
{
"text": "Content D",
"vector": [0.4] * 1024,
"metadata": {"source_file": "file3.py"},
},
]
temp_db.add_documents(documents)
# Verify initial state
sources = temp_db.list_sources()
assert len(sources) == 3
# Delete file1.py (should delete 2 chunks)
num_deleted = temp_db.delete_by_source("file1.py")
assert num_deleted == 2
# Verify deletion
sources = temp_db.list_sources()
assert len(sources) == 2
assert "file1.py" not in sources
assert "file2.py" in sources
assert "file3.py" in sources
def test_delete_nonexistent_source(temp_db):
# Try to delete a source that doesn't exist
num_deleted = temp_db.delete_by_source("nonexistent.py")
assert num_deleted == 0

tests/unit/test_embedder.py Normal file (+28)

@@ -0,0 +1,28 @@
import pytest
from ocrag.embedder import Embedder
@pytest.fixture(scope="module")
def embedder():
return Embedder()
def test_embedder_singleton(embedder):
embedder2 = Embedder()
assert embedder is embedder2
def test_embed_single(embedder):
vector = embedder.embed_single("Test sentence")
assert len(vector) == 1024 # Qwen3-Embedding-0.6B output dimension
assert isinstance(vector[0], float)
def test_embed_batch(embedder):
vectors = embedder.embed(["Sentence 1", "Sentence 2"])
assert len(vectors) == 2
assert len(vectors[0]) == 1024
assert len(vectors[1]) == 1024
assert (
vectors[0] != vectors[1]
) # Different sentences should have different embeddings