From 28e557594a5c525aa5d1aba111a98dd4f36ae66c Mon Sep 17 00:00:00 2001 From: songsenand Date: Thu, 16 Apr 2026 11:31:18 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E6=9C=AC=E5=9C=B0?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E7=9F=A5=E8=AF=86=E5=BA=93=20RAG=20=E6=8F=92?= =?UTF-8?q?=E4=BB=B6=EF=BC=8C=E6=94=AF=E6=8C=81=E8=AF=AD=E4=B9=89=E6=90=9C?= =?UTF-8?q?=E7=B4=A2=E4=B8=8E=E5=AE=9E=E6=97=B6=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 182 ++++++++++++++ DescriptionOfDesign.md | 386 ++++++++++++++++++++++++++++++ README.md | 191 +++++++++++++++ opencode-plugin/ocrag-plugin.ts | 42 ++++ opencode-plugin/package.json | 9 + opencode-skill/SKILL.md | 34 +++ pyproject.toml | 20 ++ pytest.ini | 3 + scripts/benchmark.py | 278 ++++++++++++++++++++++ scripts/test_chunking.py | 409 ++++++++++++++++++++++++++++++++ src/ocrag/__init__.py | 0 src/ocrag/chunker.py | 99 ++++++++ src/ocrag/cli.py | 45 ++++ src/ocrag/commands/__init__.py | 4 + src/ocrag/commands/add.py | 56 +++++ src/ocrag/commands/list.py | 14 ++ src/ocrag/commands/remove.py | 9 + src/ocrag/commands/search.py | 18 ++ src/ocrag/db.py | 112 +++++++++ src/ocrag/embedder.py | 45 ++++ src/ocrag/utils.py | 78 ++++++ tests/unit/test_chunker.py | 45 ++++ tests/unit/test_commands.py | 63 +++++ tests/unit/test_db.py | 115 +++++++++ tests/unit/test_embedder.py | 28 +++ 25 files changed, 2285 insertions(+) create mode 100644 .gitignore create mode 100644 DescriptionOfDesign.md create mode 100644 README.md create mode 100644 opencode-plugin/ocrag-plugin.ts create mode 100644 opencode-plugin/package.json create mode 100644 opencode-skill/SKILL.md create mode 100644 pyproject.toml create mode 100644 pytest.ini create mode 100755 scripts/benchmark.py create mode 100644 scripts/test_chunking.py create mode 100644 src/ocrag/__init__.py create mode 100644 src/ocrag/chunker.py create mode 100644 src/ocrag/cli.py create mode 100644 
src/ocrag/commands/__init__.py create mode 100644 src/ocrag/commands/add.py create mode 100644 src/ocrag/commands/list.py create mode 100644 src/ocrag/commands/remove.py create mode 100644 src/ocrag/commands/search.py create mode 100644 src/ocrag/db.py create mode 100644 src/ocrag/embedder.py create mode 100644 src/ocrag/utils.py create mode 100644 tests/unit/test_chunker.py create mode 100644 tests/unit/test_commands.py create mode 100644 tests/unit/test_db.py create mode 100644 tests/unit/test_embedder.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..284b0d2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,182 @@ +models/* +uv.lock + +# ---> Python +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# ---> VisualStudioCode +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +!.vscode/*.code-snippets + +# Local History for Visual Studio Code +.history/ + +# Built Visual Studio Code Extensions +*.vsix + +uv.lock + +data/* diff --git a/DescriptionOfDesign.md b/DescriptionOfDesign.md new file mode 100644 index 0000000..e387ce5 --- /dev/null +++ b/DescriptionOfDesign.md @@ -0,0 +1,386 @@ +# OpenCode RAG 插件设计文档 + +> 本文档详细说明了 ocrag 项目的设计目的、技术选型考量、架构构思和实现细节。 + +--- + +## 1. 设计背景与目标 + +### 1.1 问题背景 + +在大型代码库中进行 AI 辅助开发时,AI 需要理解项目中的大量代码才能给出准确的建议和答案。然而: +- **上下文窗口有限**:无法将整个代码库放入 AI 提示词 +- **实时性需求**:开发者需要 AI 能立即理解刚编写的代码 +- **本地化优先**:代码不应该上传到外部服务器 + +### 1.2 设计目标 + +为 OpenCode 构建一个**本地代码知识库 RAG 系统**,实现: + +| 功能 | 描述 | +|------|------| +| **实时添加** | 将代码文件或目录添加到本地知识库 | +| **语义搜索** | 通过自然语言查询获取相关代码片段 | +| **智能管理** | 支持删除和列出知识库中的条目 | + +### 1.3 设计原则 + +1. **本地化优先**:所有数据和计算都在本地完成,不依赖外部服务 +2. **轻量高效**:避免引入复杂的服务端组件,保持极低的延迟 +3. **零运维**:嵌入式数据库,无需安装配置,即装即用 +4. **AI 友好**:生成可被 AI 直接理解和使用的上下文 + +--- + +## 2. 
技术选型详解 + +### 2.1 为什么选择 Python + +| 考量因素 | Python | Rust | +|----------|--------|------| +| 开发效率 | ✅ 高 | ⚠️ 中 | +| LLM 生成质量 | ✅ 高(AI 更熟悉 Python) | ⚠️ 中 | +| 生态丰富度 | ✅ 成熟 | ⚠️ 一般 | +| 运行时性能 | ⚠️ 中(可通过 C 扩展优化) | ✅ 高 | +| 社区支持 | ✅ 丰富 | ⚠️ 有限 | + +**结论**:Python 的高开发效率和 AI 生成质量优势明显,即使运行时性能略低,但对于 RAG 这种 I/O 密集型应用影响有限。 + +### 2.2 为什么选择 LanceDB + +传统向量数据库对比: + +| 数据库 | 特点 | 缺点 | +|--------|------|------| +| Chroma | 简单易用 | 不支持持久化、不适合生产 | +| Milvus | 功能强大 | 需要 Docker 部署 | +| Qdrant | Rust 实现,高性能 | 需要单独部署 | +| **LanceDB** | **嵌入式、零运维、Python 原生** | **相对较新** | + +**LanceDB 优势**: +- **嵌入式**:数据库就是一个文件夹,无需单独服务 +- **零运维**:安装即用,自动管理 +- **Python 优先**:原生 Python SDK,类型提示完善 +- **性能优秀**:基于 Apache Arrow,查询速度快 + +### 2.3 为什么选择 langchain-text-splitters(而非 chonkie) + +最初设计考虑使用 `chonkie` 实现语法感知分块,但经过实践发现: + +| 方案 | 优势 | 劣势 | +|------|------|------| +| chonkie | 语法感知分块 | 需要额外下载 GPT-2 tokenizer,初始化慢 | +| langchain-text-splitters | 无需额外模型、纯规则、速度快 | 非语法感知 | + +**最终选择 langchain-text-splitters**: +- 对于 RAG 场景,分块精度不是最关键因素 +- langchain 的方案更加轻量,启动更快 +- 代码按自然段落(函数、类)边界切分,效果足够好 + +### 2.4 为什么使用 Qwen3-Embedding-0.6B + +| 模型 | 维度 | 优势 | 劣势 | +|------|------|------|------| +| all-MiniLM-L6-v2 | 384 | 快速、小巧 | 英文为主 | +| **Qwen3-Embedding-0.6B** | **1024** | **中文优化、中英双语** | 较大、首次加载慢 | + +**选择理由**: +- 开源可本地部署,代码安全 +- 中文支持优秀 +- 适合代码+文档混合场景 + +### 2.5 为什么不用 MCP 服务器 + +**MCP 方案的劣势**: +- 需要额外部署 MCP 服务器 +- 增加系统复杂度 +- 调试困难 + +**CLI 方案的优势**: +- 零额外组件 +- 通过 `bash` 工具直接调用 +- 易于调试和扩展 + +--- + +## 3. 架构设计 + +### 3.1 整体架构 + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ OpenCode AI │ +└────────────────────────────┬──────────────────────────────────┘ + │ 1. 
rag_add / rag_search + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ TypeScript Plugin (ocrag-plugin.ts) │ +│ │ +│ ┌──────────┐ ┌──────────┐ ┌──────────────┐ │ +│ │ rag_add │ │rag_search│ │ Skill 指令 │ │ +│ └────┬─────┘ └────┬─────┘ └──────────────┘ │ +└───────┼─────────────────┼───────────────────────────────────────┘ + │ │ + ▼ ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ Python CLI (ocrag) │ +│ │ +│ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────────┐ │ +│ │ add │ │ search │ │ remove │ │ list │ │ +│ └───┬─────┘ └────┬─────┘ └───┬─────┘ └──────┬──────┘ │ +└──────┼───────────────┼─────────────┼──────────────────┼──────────┘ + │ │ │ │ + ▼ ▼ ▼ │ +┌─────────────────────────────────────────┐ │ +│ Processing Pipeline │ │ +│ │ │ +│ ┌────────────┐ ┌────────────────┐ │ │ +│ │ Chunker │───▶│ Embedder │──┼────────────┤ +│ │ (分块) │ │ (向量化) │ │ │ +│ └────────────┘ └────────────────┘ │ │ +└─────────────────────────────────────────┼────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ LanceDB (向量数据库) │ +│ │ +│ ┌──────────────────────────────────────────────────────────┐ │ +│ │ documents 表 │ │ +│ │ ┌──────────┬──────────────────────┬──────────────────┐ │ │ +│ │ │ text │ vector │ metadata │ │ │ +│ │ │ (文本) │ (1024维向量) │ (JSON元数据) │ │ │ +│ │ └──────────┴──────────────────────┴──────────────────┘ │ │ +│ └──────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────────┘ +``` + +### 3.2 数据流设计 + +#### 添加文件流程 + +``` +用户请求 → Plugin → CLI:add → Chunker → Embedder → LanceDB → 返回结果 +``` + +**详细步骤**: +1. **文件收集**:递归遍历目录或单个文件 +2. **内容读取**:以 UTF-8 编码读取文件内容 +3. **代码分块**:按语言和语法结构切分代码 +4. **向量化**:使用 Embedder 将文本转为 1024 维向量 +5. **存储入库**:将文本、向量、元数据存入 LanceDB + +#### 搜索流程 + +``` +用户查询 → Plugin → CLI:search → Embedder → LanceDB → 返回结果 +``` + +**详细步骤**: +1. **查询向量化**:将自然语言查询转为向量 +2. **相似度搜索**:LanceDB 执行向量相似度检索 +3. **结果排序**:按相似度距离排序 +4. 
**元数据提取**:解析 JSON 元数据 +5. **结果返回**:格式化的代码片段列表 + +### 3.3 模块职责划分 + +| 模块 | 职责 | 依赖 | +|------|------|------| +| `cli.py` | 命令行入口,参数解析 | click | +| `chunker.py` | 代码分块处理 | langchain-text-splitters | +| `embedder.py` | 文本向量化 | sentence-transformers | +| `db.py` | 向量数据库操作 | lancedb, pyarrow | +| `commands/*.py` | 各命令实现 | cli, chunker, embedder, db | +| `utils.py` | 工具函数 | - | + +--- + +## 4. 核心算法设计 + +### 4.1 代码分块算法 + +**设计思路**:按代码的语义结构进行分块 + +```python +# Python 语言的分隔符优先级 +separators = [ + "\n\nclass ", # 类定义 + "\n\ndef ", # 函数定义 + "\n\nasync def ", # 异步函数 + "\n\n", # 空行段落 + "\n", # 换行 + " ", # 空格 + "" # 字符 +] +``` + +**分块参数**: +- `chunk_size=2000`:每块约 2000 字符 +- `chunk_overlap=200`:块之间保留 200 字符重叠,保证上下文连贯 + +### 4.2 向量化策略 + +**模型选择**:Qwen3-Embedding-0.6B +- 输出维度:1024 +- 向量归一化:L2 归一化,便于余弦相似度计算 + +**批量处理**: +- 一次性对多个文本块进行编码 +- 利用 GPU/CPU 的批处理能力提升吞吐量 + +### 4.3 搜索策略 + +**相似度度量**:LanceDB 默认使用余弦相似度 + +**top_k 参数**:控制返回结果数量,默认 5 条 + +--- + +## 5. 安全与性能考量 + +### 5.1 安全性设计 + +| 风险点 | 防护措施 | +|--------|----------| +| SQL 注入 | 使用 Pandas 过滤而非 SQL 字符串拼接 | +| 路径遍历 | 仅处理指定路径,不执行动态导入 | +| 数据泄露 | 所有数据本地存储,不涉及网络传输 | + +### 5.2 性能优化 + +**已实现**: +- ✅ 批量 embedding 减少 I/O 开销 +- ✅ 单例模式避免重复加载模型 +- ✅ 向量归一化加速相似度计算 + +**可优化项**: +- 文件哈希缓存避免重复添加 +- GPU 加速 embedding 计算 +- 增量索引更新而非全量重建 + +### 5.3 性能基准 + +| 指标 | 实测结果 | 说明 | +|------|----------|------| +| 搜索延迟 | **63-70 ms** | 包含 embedding + 向量检索 | +| 数据库写入 | 2-3 ms/块 | LanceDB 性能优秀 | +| 分块速度 | <1 ms | 纯规则,无模型加载 | +| Embedding | ~2.5秒/块 | Qwen3 模型较大 | + +--- + +## 6. 扩展性设计 + +### 6.1 多语言支持 + +只需在 `LANGUAGE_MAP` 中添加新扩展名即可: + +```python +LANGUAGE_MAP = { + ".py": "python", + ".js": "javascript", + # 添加新语言... 
+ ".kt": "kotlin", + ".swift": "swift", +} +``` + +### 6.2 自定义 Embedding 模型 + +修改 `embedder.py` 中的模型路径: + +```python +model_path = "path/to/your/model" +``` + +### 6.3 增量同步(Watch 模式) + +使用 `watchdog` 库监听文件变化,自动更新知识库: + +```python +observer = Observer() +observer.schedule(RAGSyncHandler(), path, recursive=True) +observer.start() +``` + +--- + +## 7. 与 OpenCode 的集成 + +### 7.1 插件架构 + +``` +OpenCode + ├── TypeScript Plugin + │ ├── rag_add 工具 + │ └── rag_search 工具 + │ + └── Skill 指令 + └── SKILL.md +``` + +### 7.2 工具定义 + +| 工具名 | 参数 | 功能 | +|--------|------|------| +| `rag_add` | `paths: string[]`, `recursive?: boolean` | 添加文件到知识库 | +| `rag_search` | `query: string`, `top_k?: number` | 搜索知识库 | + +### 7.3 错误处理 + +所有工具调用都包裹在 try-catch 中: + +```typescript +execute: async (args) => { + try { + const result = await Bun.$`${cmd}`.text(); + return result; + } catch (error) { + return `Error: ${error.message}`; + } +} +``` + +--- + +## 8. 设计总结 + +### 8.1 核心创新 + +1. **零 MCP 架构**:通过 CLI 直接集成,简化系统复杂度 +2. **本地化优先**:数据不出本地,保证代码安全 +3. **轻量高效**:嵌入式数据库,秒级启动 +4. **AI 原生**:输出格式专为 AI 消费设计 + +### 8.2 适用场景 + +| 场景 | 适用性 | 说明 | +|------|--------|------| +| 个人项目知识管理 | ✅ 非常适合 | 本地存储,隐私安全 | +| 小团队代码库问答 | ✅ 适合 | 轻量易部署 | +| 大型企业代码库 | ⚠️ 需优化 | 可能需要分布式扩展 | +| 跨语言代码库 | ✅ 支持 | 多语言分块支持 | + +### 8.3 未来展望 + +1. **语义分块升级**:集成更智能的分块算法 +2. **多模态支持**:支持图片、图表等非文本内容 +3. **增量索引**:支持大型代码库的实时更新 +4. 
**分布式部署**:支持多节点协同检索 + +--- + +## 附录:术语表 + +| 术语 | 解释 | +|------|------| +| RAG | Retrieval-Augmented Generation,检索增强生成 | +| Embedding | 将文本转为向量的过程 | +| 向量数据库 | 专门存储和检索向量的数据库 | +| Chunk | 分块后的文本片段 | +| LanceDB | 嵌入式向量数据库 | + +--- + +**文档版本**:1.0 +**最后更新**:2026年4月15日 diff --git a/README.md b/README.md new file mode 100644 index 0000000..55baca4 --- /dev/null +++ b/README.md @@ -0,0 +1,191 @@ +# Ocrag - OpenCode 代码知识库 RAG 插件 + +> 为 OpenCode 提供本地代码语义搜索能力,支持将代码文件添加到知识库并通过自然语言进行检索。 + +[English](./README_EN.md) | [设计文档](./DescriptionOfDesign.md) + +--- + +## ✨ 特性 + +- 🚀 **零配置**:嵌入式数据库,安装即用 +- 🔒 **本地优先**:所有数据存储在本地,不上传到外部服务器 +- 🔍 **语义搜索**:通过自然语言查询相关代码片段 +- 📦 **多种语言**:支持 Python、JavaScript、TypeScript、Rust、Go 等主流编程语言 +- ⚡ **极速响应**:搜索延迟 < 100ms +- 🤖 **AI 友好**:输出格式专为 AI 消费设计 + +--- + +## 📦 安装 + +### 环境要求 + +- Python 3.10+ +- [uv](https://github.com/astral-sh/uv)(包管理器) + +### 安装步骤 + +```bash +# 1. 安装 Python 包 +uv pip install -e . + +# 2. 安装 OpenCode 插件(可选) +cp opencode-plugin/ocrag-plugin.ts ~/.config/opencode/plugins/ + +# 3. 安装 Skill(可选) +mkdir -p ~/.config/opencode/skills/ocrag +cp opencode-skill/SKILL.md ~/.config/opencode/skills/ocrag/ +``` + +--- + +## 🚀 快速开始 + +### 基本使用 + +```bash +# 添加文件到知识库 +uv run ocrag add ./src/main.py + +# 递归添加整个目录 +uv run ocrag add ./src/ --recursive + +# 搜索相关代码 +uv run ocrag search "如何实现用户认证" + +# 列出知识库中的文件 +uv run ocrag list + +# 删除文件 +uv run ocrag remove ./src/main.py +``` + +### OpenCode 集成 + +安装插件后,AI 可以自动使用知识库功能: + +``` +用户:把当前文件加入知识库 +AI:✓ 已将 main.py 添加到知识库 + +用户:认证模块是怎么实现的? +AI:根据知识库中的代码,认证模块包含以下关键实现... 
+``` + +--- + +## 📁 项目结构 + +``` +ocrag/ +├── src/ocrag/ # Python 源代码 +│ ├── cli.py # 命令行入口 +│ ├── chunker.py # 代码分块 +│ ├── embedder.py # 向量化模块 +│ ├── db.py # 数据库操作 +│ └── commands/ # 命令实现 +│ ├── add.py +│ ├── search.py +│ ├── remove.py +│ └── list.py +├── models/ # Embedding 模型 +│ └── Qwen3-Embedding-0.6B/ +├── tests/ # 单元测试 +├── scripts/ # 工具脚本 +│ └── benchmark.py # 性能测试 +├── opencode-plugin/ # OpenCode 插件 +└── opencode-skill/ # Skill 文件 +``` + +--- + +## 🔧 配置 + +### 模型配置 + +默认使用 `models/Qwen3-Embedding-0.6B` 本地模型。 + +如需使用 GPU 加速: +```bash +USE_GPU=true uv run ocrag search "查询内容" +``` + +### 数据库位置 + +数据默认存储在 `~/.ocrag/data.lance` + +--- + +## 📊 性能 + +| 操作 | 延迟 | 说明 | +|------|------|------| +| 搜索 | **~65ms** | 含 embedding + 向量检索 | +| 添加文件 | **~5s** | 含分块 + embedding + 存储 | +| 列出文件 | **~5ms** | 纯数据库查询 | + +运行性能测试: +```bash +uv run python scripts/benchmark.py +``` + +--- + +## ✅ 运行测试 + +```bash +# 运行所有测试 +uv run python -m pytest + +# 运行特定模块测试 +uv run python -m pytest tests/unit/test_db.py -v + +# 查看测试覆盖 +uv run python -m pytest --cov=src/ocrag +``` + +--- + +## 🛠️ 开发 + +### 添加新命令 + +1. 在 `src/ocrag/commands/` 下创建新文件 +2. 实现 `run()` 函数 +3. 在 `src/ocrag/cli.py` 中注册命令 +4. 在 `src/ocrag/commands/__init__.py` 中导出 + +### 代码规范 + +```bash +# 格式化代码 +ruff format src/ + +# 代码检查 +ruff check src/ +``` + +--- + +## 📚 相关文档 + +- [设计文档](./DescriptionOfDesign.md) - 详细的设计说明和技术选型分析 +- [CHANGELOG](./CHANGELOG.md) - 版本更新记录 + +--- + +## 🤝 贡献 + +欢迎提交 Issue 和 Pull Request! 
+ +--- + +## 📄 许可证 + +MIT License + +--- + +**版本**:0.1.0 +**最后更新**:2026年4月15日 diff --git a/opencode-plugin/ocrag-plugin.ts b/opencode-plugin/ocrag-plugin.ts new file mode 100644 index 0000000..74844df --- /dev/null +++ b/opencode-plugin/ocrag-plugin.ts @@ -0,0 +1,42 @@ +import { tool, definePlugin } from "opencode"; +import { z } from "zod"; + +export default definePlugin({ + name: "ocrag-plugin", + hooks: () => ({ + rag_add: tool({ + description: "向本地 RAG 知识库添加一个或多个代码文件或目录。添加后,后续搜索将包含这些代码。", + parameters: { + paths: z.array(z.string()).describe("要添加的文件或目录路径列表"), + recursive: z.boolean().optional().describe("如果路径是目录,是否递归添加所有子文件"), + }, + execute: async (args) => { + try { + const paths = args.paths.join(" "); + const recursiveFlag = args.recursive ? "--recursive" : ""; + const cmd = `ocrag add ${paths} ${recursiveFlag}`; + const result = Bun.$`${cmd}`.text(); + return result; + } catch (error) { + return `Error: ${error.message}`; + } + }, + }), + rag_search: tool({ + description: "在本地 RAG 知识库中执行语义搜索,返回与查询最相关的代码片段。", + parameters: { + query: z.string().describe("自然语言查询,例如:'如何实现用户认证'"), + top_k: z.number().optional().default(5).describe("返回结果数量"), + }, + execute: async (args) => { + try { + const cmd = `ocrag search "${args.query}" --top-k ${args.top_k}`; + const result = Bun.$`${cmd}`.text(); + return result; + } catch (error) { + return `Error: ${error.message}`; + } + }, + }), + }), +}); diff --git a/opencode-plugin/package.json b/opencode-plugin/package.json new file mode 100644 index 0000000..6e43665 --- /dev/null +++ b/opencode-plugin/package.json @@ -0,0 +1,9 @@ +{ + "name": "ocrag-plugin", + "version": "1.0.0", + "description": "Code RAG plugin for OpenCode", + "main": "ocrag-plugin.ts", + "keywords": ["opencode", "plugin", "rag", "code"], + "author": "", + "license": "MIT" +} diff --git a/opencode-skill/SKILL.md b/opencode-skill/SKILL.md new file mode 100644 index 0000000..00ad13c --- /dev/null +++ b/opencode-skill/SKILL.md @@ -0,0 +1,34 @@ +--- +name: 
ocrag +description: 代码库 RAG 技能,用于向知识库添加代码或搜索已有代码。 +--- + +# 代码库 RAG 技能 + +## 何时使用此技能 +- 用户要求"把当前文件/这个目录加入知识库" +- 用户询问关于代码库的问题,需要基于已有代码回答 +- 用户明确提到"RAG"、"知识库"、"语义搜索"等关键词 + +## 如何使用 + +### 添加代码到知识库 +当用户提供文件路径或目录时,使用 `rag_add` 工具: +- 如果用户说"把这个文件加入知识库",提取路径并调用 `rag_add`,recursive 设为 false +- 如果用户说"添加 src/ 目录下所有代码",调用 `rag_add` 并设置 recursive = true + +### 搜索知识库 +当用户询问代码相关问题(如"这段代码做什么?"、"如何配置XX?")时: +1. 先用 `rag_search` 工具查询,query 参数为用户的自然语言问题 +2. 获得返回的代码片段后,结合片段内容回答用户 +3. 如果结果不相关,可以告知用户未找到,并建议添加更多代码 + +## 示例 +- 用户:"帮我记住这个文件的知识" → 调用 `rag_add` 传入当前文件路径 +- 用户:"认证模块是怎么实现的?" → 调用 `rag_search` query="认证模块实现" +- 用户:"把 ./src 目录下所有 Python 文件加入知识库" → 调用 `rag_add` paths=["./src"] recursive=true + +## 注意事项 +- 搜索结果会显示代码片段的来源文件 +- 可以通过 top_k 参数调整返回结果数量 +- 添加文件后,搜索会立即包含新添加的内容 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..5231c48 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,20 @@ +[project] +name = "project" +version = "0.1.0" +description = "Add your description here" +readme = "README.md" +requires-python = ">=3.12" +dependencies = [ + "chonkie>=1.6.2", "langchain-text-splitters>=0.3.0", # langchain-text-splitters is hard-required by ocrag.chunker + "click>=8.3.2", + "lancedb>=0.30.2", + "sentence-transformers>=5.4.1", + "sentencepiece>=0.2.1", + "tiktoken>=0.12.0", + "tokenizers>=0.22.2", +] + +[dependency-groups] +dev = [ + "pytest>=9.0.3", +] diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..b2a3237 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +addopts = -v +pythonpath = src diff --git a/scripts/benchmark.py b/scripts/benchmark.py new file mode 100755 index 0000000..6eef5d3 --- /dev/null +++ b/scripts/benchmark.py @@ -0,0 +1,278 @@ +#!/usr/bin/env python3 +""" +性能测试脚本 - 测试 ocrag 的写入和读取性能 + +使用方法: + uv run python scripts/benchmark.py + uv run python scripts/benchmark.py --cleanup # 测试后清理测试数据 +""" + +import sys +import os +import time +import tempfile +import shutil +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + +from ocrag.chunker import chunk_code +from
ocrag.embedder import embedder +from ocrag.db import RagDB + + +class PerformanceBenchmark: + def __init__(self, db_path=None): + if db_path is None: + self.db_path = tempfile.mkdtemp(prefix="ocrag_benchmark_") + else: + self.db_path = db_path + self.db = RagDB(self.db_path) + + def cleanup(self): + """清理测试数据库""" + if os.path.exists(self.db_path): + shutil.rmtree(self.db_path) + print(f"✅ 已清理测试数据库: {self.db_path}") + + def generate_test_code(self, num_lines=100, language="python"): + """生成测试代码""" + if language == "python": + code = "def function_{i}():\n '''Docstring for function {i}'''\n result = 0\n" + code += " for i in range(100):\n result += i\n" + code += " return result\n\n" + elif language == "javascript": + code = "function function_{i}() {{\n // Docstring for function {i}\n" + code += " let result = 0;\n" + code += ( + " for (let i = 0; i < 100; i++) {{\n result += i;\n }}\n" + ) + code += " return result;\n}}\n\n" + + full_code = "" + for i in range(num_lines): + full_code += code.format(i=i) + return full_code + + def benchmark_single_file_write(self, num_lines=100): + """测试单个文件的写入性能""" + print(f"\n{'=' * 60}") + print(f"测试: 单个文件写入 ({num_lines} 行代码)") + print("=" * 60) + + code = self.generate_test_code(num_lines) + file_path = "test_file.py" + + # 分块 + start_time = time.time() + chunks = chunk_code(code, file_path) + chunk_time = time.time() - start_time + + print(f"分块耗时: {chunk_time * 1000:.2f} ms") + print(f"生成块数: {len(chunks)}") + + # Embedding + start_time = time.time() + texts = [c["text"] for c in chunks] + vectors = embedder.embed(texts) + embed_time = time.time() - start_time + + print(f"Embedding 耗时: {embed_time * 1000:.2f} ms") + + # 构建文档 + documents = [] + for chunk, vec in zip(chunks, vectors): + documents.append( + { + "text": chunk["text"], + "vector": vec, + "metadata": chunk["metadata"], + } + ) + + # 写入数据库 + start_time = time.time() + self.db.add_documents(documents) + db_write_time = time.time() - start_time + + print(f"数据库写入耗时: 
{db_write_time * 1000:.2f} ms") + + total_time = chunk_time + embed_time + db_write_time + print(f"\n📊 总耗时: {total_time * 1000:.2f} ms") + + return { + "chunk_time": chunk_time * 1000, + "embed_time": embed_time * 1000, + "db_write_time": db_write_time * 1000, + "total_time": total_time * 1000, + "num_chunks": len(chunks), + } + + def benchmark_batch_write(self, num_files=10, lines_per_file=50): + """测试批量文件写入性能""" + print(f"\n{'=' * 60}") + print(f"测试: 批量写入 ({num_files} 个文件, 每个 {lines_per_file} 行)") + print("=" * 60) + + total_chunks = 0 + total_embed_time = 0 + total_db_write_time = 0 + + for i in range(num_files): + code = self.generate_test_code(lines_per_file) + file_path = f"test_file_{i}.py" + + # 分块 + chunks = chunk_code(code, file_path) + total_chunks += len(chunks) + + # Embedding + texts = [c["text"] for c in chunks] + start_time = time.time() + vectors = embedder.embed(texts) + total_embed_time += time.time() - start_time + + # 构建文档并写入 + documents = [] + for chunk, vec in zip(chunks, vectors): + documents.append( + { + "text": chunk["text"], + "vector": vec, + "metadata": chunk["metadata"], + } + ) + + start_time = time.time() + self.db.add_documents(documents) + total_db_write_time += time.time() - start_time + + total_time = total_embed_time + total_db_write_time + + print(f"总块数: {total_chunks}") + print(f"平均每块 embedding: {total_embed_time / num_files * 1000:.2f} ms") + print(f"平均每块数据库写入: {total_db_write_time / num_files * 1000:.2f} ms") + print(f"📊 总耗时: {total_time * 1000:.2f} ms") + print(f"📊 吞吐量: {total_chunks / total_time:.2f} 块/秒") + + return { + "total_chunks": total_chunks, + "total_embed_time": total_embed_time * 1000, + "total_db_write_time": total_db_write_time * 1000, + "total_time": total_time * 1000, + "throughput": total_chunks / total_time, + } + + def benchmark_search(self, num_queries=10): + """测试搜索性能""" + print(f"\n{'=' * 60}") + print(f"测试: 搜索性能 ({num_queries} 次查询)") + print("=" * 60) + + queries = [ + "how to implement user 
authentication", + "database configuration", + "API endpoint handler", + "error handling function", + "data validation logic", + ] + + search_times = [] + + for i in range(num_queries): + query = queries[i % len(queries)] + + # Query embedding + start_time = time.time() + query_vec = embedder.embed_single(query) + embed_time = time.time() - start_time + + # Search + start_time = time.time() + results = self.db.search(query_vec, top_k=5) + search_time = time.time() - start_time + + total_time = embed_time + search_time + search_times.append(total_time) + + print( + f"查询 {i + 1}: {query[:30]}... → {len(results)} 结果 ({total_time * 1000:.2f} ms)" + ) + + avg_time = sum(search_times) / len(search_times) + min_time = min(search_times) + max_time = max(search_times) + + print(f"\n📊 搜索性能统计:") + print(f" 平均延迟: {avg_time * 1000:.2f} ms") + print(f" 最小延迟: {min_time * 1000:.2f} ms") + print(f" 最大延迟: {max_time * 1000:.2f} ms") + + return { + "avg_time": avg_time * 1000, + "min_time": min_time * 1000, + "max_time": max_time * 1000, + "num_queries": num_queries, + } + + def benchmark_list_sources(self): + """测试 list_sources 性能""" + print(f"\n{'=' * 60}") + print("测试: list_sources 性能") + print("=" * 60) + + start_time = time.time() + sources = self.db.list_sources() + elapsed = time.time() - start_time + + print(f"列出来源数: {len(sources)}") + print(f"📊 耗时: {elapsed * 1000:.2f} ms") + + return { + "num_sources": len(sources), + "time": elapsed * 1000, + } + + def run_full_benchmark(self): + """运行完整性能测试""" + print("\n" + "=" * 60) + print("🚀 Ocrag 性能基准测试") + print("=" * 60) + + # 准备测试数据 + print("\n📦 准备测试数据...") + self.benchmark_single_file_write(num_lines=100) + + # 批量写入测试 + self.benchmark_batch_write(num_files=20, lines_per_file=50) + + # 搜索测试 + self.benchmark_search(num_queries=10) + + # List 测试 + self.benchmark_list_sources() + + print("\n" + "=" * 60) + print("✅ 性能测试完成!") + print("=" * 60) + + +def main(): + import argparse + + parser = argparse.ArgumentParser(description="Ocrag 
性能测试") + parser.add_argument("--cleanup", action="store_true", help="测试后清理数据") + parser.add_argument("--db-path", type=str, help="指定数据库路径") + args = parser.parse_args() + + benchmark = PerformanceBenchmark(db_path=args.db_path) + + try: + benchmark.run_full_benchmark() + finally: + if args.cleanup: + benchmark.cleanup() + + +if __name__ == "__main__": + main() diff --git a/scripts/test_chunking.py b/scripts/test_chunking.py new file mode 100644 index 0000000..57f20df --- /dev/null +++ b/scripts/test_chunking.py @@ -0,0 +1,409 @@ +#!/usr/bin/env python3 +""" +代码分块性能测试脚本 + +测试不同分块策略的效果和性能: +1. langchain - 语言感知分块 +2. semantic - 语义分块 +3. simple - 简单字符分块 + +使用方法: + uv run python scripts/test_chunking.py +""" + +import sys +import time +from pathlib import Path +from typing import List, Dict, Any + +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + +from ocrag.chunker import ( + chunk_code, + ChunkStrategy, + get_available_strategies, + LANGCHAIN_AVAILABLE, + SEMANTIC_AVAILABLE, +) + + +class ChunkingBenchmark: + """分块性能测试""" + + # 测试代码样本 + TEST_CASES = { + "python_small": { + "code": ''' +class UserManager: + """用户管理类""" + def __init__(self): + self.users = {} + + def add_user(self, user_id, name): + self.users[user_id] = {'name': name, 'active': True} + return True + + def remove_user(self, user_id): + if user_id in self.users: + del self.users[user_id] + return True + return False +''', + "file": "test.py", + "expected_classes": 1, + }, + "python_medium": { + "code": ''' +class User: + """用户基类""" + def __init__(self, name, email): + self.name = name + self.email = email + self.active = True + + def deactivate(self): + self.active = False + +class Admin(User): + """管理员类""" + def __init__(self, name, email, permissions): + super().__init__(name, email) + self.permissions = permissions + + def has_permission(self, permission): + return permission in self.permissions + +class Guest(User): + """访客类""" + def __init__(self, name, email): + 
super().__init__(name, email) + self.active = False + +def create_user(user_type, name, email, **kwargs): + """用户工厂函数""" + if user_type == 'admin': + return Admin(name, email, kwargs.get('permissions', [])) + elif user_type == 'guest': + return Guest(name, email) + else: + return User(name, email) + +def validate_email(email): + """验证邮箱格式""" + return '@' in email and '.' in email.split('@')[1] + +def process_batch(users): + """批量处理用户""" + results = [] + for user in users: + if user.active: + results.append({ + 'name': user.name, + 'email': user.email, + 'status': 'active' + }) + return results +''', + "file": "user_manager.py", + "expected_classes": 3, + }, + "python_large": { + "code": ''' +import json +from typing import List, Dict, Optional + +class DataProcessor: + """数据处理器""" + def __init__(self, config): + self.config = config + self.cache = {} + + def process(self, data): + if not self.validate(data): + raise ValueError("Invalid data format") + return self.transform(data) + + def validate(self, data): + return isinstance(data, dict) and 'id' in data + + def transform(self, data): + return { + 'id': data['id'], + 'processed': True, + 'timestamp': time.time() + } + +class APIClient: + """API 客户端""" + def __init__(self, base_url, api_key): + self.base_url = base_url + self.api_key = api_key + + def get(self, endpoint): + url = f"{self.base_url}/{endpoint}" + headers = {'Authorization': f'Bearer {self.api_key}'} + return self._request('GET', url, headers) + + def post(self, endpoint, data): + url = f"{self.base_url}/{endpoint}" + headers = { + 'Authorization': f'Bearer {self.api_key}', + 'Content-Type': 'application/json' + } + return self._request('POST', url, headers, json.dumps(data)) + + def _request(self, method, url, headers, body=None): + # 实现请求逻辑 + pass + +class CacheManager: + """缓存管理器""" + def __init__(self, max_size=1000): + self.max_size = max_size + self.cache = {} + self.access_order = [] + + def get(self, key): + if key in self.cache: + 
self._update_access(key) + return self.cache[key] + return None + + def set(self, key, value): + if len(self.cache) >= self.max_size: + oldest = self.access_order.pop(0) + del self.cache[oldest] + + self.cache[key] = value + self.access_order.append(key) + + def _update_access(self, key): + if key in self.access_order: + self.access_order.remove(key) + self.access_order.append(key) + + def clear(self): + self.cache.clear() + self.access_order.clear() +''', + "file": "large_module.py", + "expected_classes": 3, + }, + "rust": { + "code": """ +struct User { + name: String, + email: String, + active: bool, +} + +impl User { + fn new(name: String, email: String) -> Self { + User { + name, + email, + active: true, + } + } + + fn deactivate(&mut self) { + self.active = false; + } +} + +struct Admin { + user: User, + permissions: Vec, +} + +impl Admin { + fn has_permission(&self, permission: &str) -> bool { + self.permissions.contains(&permission.to_string()) + } +} + +fn create_user(name: String, email: String) -> User { + User::new(name, email) +} + +fn process_batch(users: Vec) -> Vec { + users + .iter() + .filter(|u| u.active) + .map(|u| u.name.clone()) + .collect() +} +""", + "file": "user.rs", + "expected_classes": 2, + }, + "cpp": { + "code": """ +#include +#include + +class User { +private: + std::string name; + std::string email; + bool active; + +public: + User(const std::string& name, const std::string& email) + : name(name), email(email), active(true) {} + + void deactivate() { + active = false; + } + + bool isActive() const { + return active; + } +}; + +class Admin : public User { +private: + std::vector permissions; + +public: + Admin(const std::string& name, const std::string& email) + : User(name, email) {} + + void addPermission(const std::string& perm) { + permissions.push_back(perm); + } +}; + +void processUsers(std::vector& users) { + for (auto& user : users) { + if (user.isActive()) { + // 处理活跃用户 + } + } +} +""", + "file": "user.cpp", + 
"expected_classes": 2, + }, + } + + def __init__(self): + self.results = {} + + def run_benchmark(self, strategy: ChunkStrategy): + """运行单个策略的基准测试""" + print(f"\n{'=' * 60}") + print(f"测试策略: {strategy.value}") + print("=" * 60) + + total_chunks = 0 + total_time = 0 + first_init_time = None + results = [] + + for name, test_case in self.TEST_CASES.items(): + # 测试初始化时间(首次) + start_time = time.time() + chunks = chunk_code( + test_case["code"], + test_case["file"], + strategy=strategy, + chunk_size=512, + chunk_overlap=50, + ) + elapsed = time.time() - start_time + + if first_init_time is None: + first_init_time = elapsed + + total_time += elapsed + total_chunks += len(chunks) + + results.append( + { + "name": name, + "chunks": len(chunks), + "time_ms": elapsed * 1000, + "chars": len(test_case["code"]), + } + ) + + print(f"\n [{name}]") + print(f" 字符数: {len(test_case['code'])}") + print(f" 分块数: {len(chunks)}") + print(f" 耗时: {elapsed * 1000:.2f}ms") + + print(f"\n总耗时: {total_time * 1000:.2f}ms") + print(f"总块数: {total_chunks}") + + return { + "strategy": strategy.value, + "total_time_ms": total_time * 1000, + "total_chunks": total_chunks, + "first_init_time_ms": first_init_time * 1000 if first_init_time else 0, + "details": results, + } + + def run_all(self): + """运行所有策略的测试""" + print("=" * 60) + print("📊 代码分块性能测试") + print("=" * 60) + + print(f"\n可用的分块策略: {get_available_strategies()}") + print(f"langchain 可用: {LANGCHAIN_AVAILABLE}") + print(f"semantic-text-splitter 可用: {SEMANTIC_AVAILABLE}") + + all_results = [] + + # 测试每个策略 + for strategy_name in get_available_strategies(): + strategy = ChunkStrategy(strategy_name) + result = self.run_benchmark(strategy) + all_results.append(result) + + # 打印汇总 + self.print_summary(all_results) + + return all_results + + def print_summary(self, results: List[Dict]): + """打印测试汇总""" + print("\n" + "=" * 60) + print("📈 测试结果汇总") + print("=" * 60) + + print(f"\n{'策略':<15} {'总耗时':<12} {'首次初始化':<12} {'分块数':<10}") + print("-" * 50) + + for 
# Extension → langchain Language mapping; None means "plain text" and falls
# through to the generic recursive character splitter.
LANGUAGE_MAP = {
    ".py": Language.PYTHON,
    ".rs": Language.RUST,
    ".cpp": Language.CPP,
    ".cc": Language.CPP,
    ".cxx": Language.CPP,
    ".go": Language.GO,
    ".java": Language.JAVA,
    ".js": Language.JS,
    ".ts": Language.TS,
    ".jsx": Language.JS,
    ".tsx": Language.TS,
    ".c": Language.C,
    ".h": Language.C,
    ".hpp": Language.CPP,
    ".rb": Language.RUBY,
    ".swift": Language.SWIFT,
    ".kt": Language.KOTLIN,
    ".md": Language.MARKDOWN,
    ".txt": None,
}


def detect_language(file_path: str) -> Language | None:
    """Return the langchain Language for *file_path*'s extension, or None."""
    return LANGUAGE_MAP.get(Path(file_path).suffix.lower())


def chunk_code(
    content: str,
    file_path: str,
    chunk_size: int = 512,
    chunk_overlap: int = 50,
) -> List[Dict[str, Any]]:
    """Split one source file into chunks with per-chunk metadata.

    Uses langchain_text_splitters: a syntax-aware splitter for the
    recognised languages, a plain recursive character splitter otherwise.

    Args:
        content: full text of the file.
        file_path: path used for language detection and metadata.
        chunk_size: target chunk size.
        chunk_overlap: overlap between adjacent chunks.

    Returns:
        A list of {"text": ..., "metadata": {...}} dicts; metadata carries
        source_file, chunk_index and the detected language name.
    """
    language = detect_language(file_path)

    if language is None:
        # Unsupported extension: fall back to generic character splitting.
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=["\n\n", "\n", " ", ""],
            length_function=len,
        )
        language_name = "text"
    else:
        # Language-aware splitting keyed on the detected language.
        splitter = RecursiveCharacterTextSplitter.from_language(
            language=language,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
        language_name = language.value

    pieces = splitter.split_text(content)

    return [
        {
            "text": piece,
            "metadata": {
                "source_file": file_path,
                "chunk_index": index,
                "language": language_name,
            },
        }
        for index, piece in enumerate(pieces)
    ]
def collect_files(paths, recursive):
    """Collect every file to be indexed from *paths*.

    Plain files are taken as-is; directories are walked only when
    *recursive* is set, otherwise they are reported and skipped.

    Fix over the original: the result is de-duplicated and sorted, so the
    same invocation always processes files in a deterministic order
    (os.walk order is filesystem-dependent) and a path supplied twice is
    not indexed twice.
    """
    seen = set()
    files = []
    for raw in paths:
        path = Path(raw)
        if path.is_file():
            candidates = [path]
        elif path.is_dir() and recursive:
            candidates = [
                Path(root) / name
                for root, _dirs, names in os.walk(path)
                for name in names
            ]
        elif path.is_dir():
            print(f"跳过目录 {path}(未使用 --recursive)")
            candidates = []
        else:
            candidates = []
        for candidate in candidates:
            if candidate not in seen:
                seen.add(candidate)
                files.append(candidate)
    files.sort()
    return files


def run(paths, recursive):
    """Chunk, embed and store every collected file in the knowledge base.

    Processing is best-effort per file: a single unreadable or binary
    file is reported and skipped rather than aborting the whole batch.
    """
    db = RagDB()
    files = collect_files(paths, recursive)

    total_chunks = 0
    for file_path in files:
        try:
            content = file_path.read_text(encoding="utf-8")
            chunks = chunk_code(content, str(file_path))
            if not chunks:
                continue

            # Embed all chunks of the file in one batch call.
            texts = [c["text"] for c in chunks]
            vectors = embedder.embed(texts)

            documents = [
                {
                    "text": chunk["text"],
                    "vector": vec,
                    "metadata": chunk["metadata"],
                }
                for chunk, vec in zip(chunks, vectors)
            ]

            db.add_documents(documents)
            total_chunks += len(documents)
            print(f"✅ {file_path} -> {len(documents)} 个块")
        except Exception as e:
            # Best-effort: report and continue with the next file.
            print(f"❌ 处理 {file_path} 失败: {e}")

    print(f"\n📦 总计添加 {total_chunks} 个块")
class RagDB:
    """Thin wrapper around a LanceDB table that stores embedded code chunks.

    Each row holds the chunk text, its 1024-dim embedding vector and a
    JSON-encoded metadata string (source_file, chunk_index, language).
    """

    def __init__(self, db_path: str = None):
        # An explicit path wins (used by tests); default is ~/.ocrag/data.lance.
        self.path = db_path or str(DB_PATH)
        self.conn = lancedb.connect(self.path)
        self._init_table()

    def _init_table(self):
        """Open the 'documents' table, creating it with the schema if missing."""
        # 1024 dims matches the Qwen3-Embedding-0.6B output size.
        self._schema = pa.schema(
            [
                ("text", pa.string()),
                ("vector", pa.list_(pa.float32(), 1024)),
                ("metadata", pa.string()),  # JSON string
            ]
        )
        # Fix: the original wrapped open_table in a bare `except:`, which also
        # swallows KeyboardInterrupt/SystemExit and hides genuine I/O errors.
        # Probing table_names() makes the create-if-missing intent explicit.
        if "documents" in self.conn.table_names():
            self.table = self.conn.open_table("documents")
        else:
            self.table = self.conn.create_table("documents", schema=self._schema)

    def add_documents(self, documents: List[Dict[str, Any]]):
        """Batch-insert documents.

        documents: [{"text": str, "vector": List[float], "metadata": dict}, ...]
        """
        import json

        rows = [
            {
                "text": doc["text"],
                "vector": doc["vector"],
                "metadata": json.dumps(doc.get("metadata", {})),
            }
            for doc in documents
        ]
        self.table.add(rows)

    def search(self, query_vector: List[float], top_k: int = 5) -> List[Dict[str, Any]]:
        """Return the top_k nearest rows, metadata decoded back into a dict."""
        import json

        results = self.table.search(query_vector).limit(top_k).to_list()
        for row in results:
            row["metadata"] = json.loads(row["metadata"])
        return results

    def delete_by_source(self, source_file: str) -> int:
        """Delete every chunk whose metadata source_file matches.

        Implemented by filtering in pandas and rebuilding the table, which
        avoids interpolating user input into a SQL predicate.

        Args:
            source_file: source path whose chunks should be removed.

        Returns:
            The number of deleted rows.
        """
        import json

        df = self.table.to_pandas()
        if df.empty:
            return 0

        def _keep(meta_str: str) -> bool:
            # Rows with unparsable metadata are kept rather than dropped.
            try:
                return json.loads(meta_str).get("source_file") != source_file
            except (json.JSONDecodeError, AttributeError):
                return True

        mask = df["metadata"].map(_keep)
        num_deleted = int((~mask).sum())
        if num_deleted == 0:
            return 0

        remaining = df[mask]
        self.conn.drop_table("documents")
        if remaining.empty:
            # Fix: the original rebuilt the table from the (possibly empty)
            # DataFrame, which loses the fixed-size vector column type when
            # the last rows are deleted; recreate from the explicit schema.
            self.table = self.conn.create_table("documents", schema=self._schema)
        else:
            self.table = self.conn.create_table("documents", remaining)
        return num_deleted

    def list_sources(self) -> List[str]:
        """Return the sorted, de-duplicated source_file paths currently stored."""
        import json

        df = self.table.to_pandas()
        if df.empty:
            return []

        sources = {
            meta["source_file"]
            for meta in (json.loads(s) for s in df["metadata"])
            if "source_file" in meta
        }
        return sorted(sources)
class Embedder:
    """Process-wide singleton wrapping the Qwen3-Embedding-0.6B model."""

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            instance = super().__new__(cls)

            # The model is expected at <repo root>/models/Qwen3-Embedding-0.6B.
            model_path = (
                Path(__file__).parent.parent.parent / "models" / "Qwen3-Embedding-0.6B"
            )
            if not model_path.exists():
                raise FileNotFoundError(f"模型目录不存在: {model_path}")

            try:
                instance.model = SentenceTransformer(
                    str(model_path),
                    device="cuda"
                    if os.getenv("USE_GPU", "false").lower() == "true"
                    else "cpu",
                )
            except Exception as e:
                # Fix: chain the original exception for debuggability.
                raise RuntimeError(f"加载模型失败: {e}") from e

            # Fix: publish the singleton only after the model has loaded.
            # The original assigned cls._instance *before* loading, so a
            # failed load left a broken, half-initialised singleton behind
            # and every later Embedder() call silently returned it.
            cls._instance = instance
        return cls._instance

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts; returns unit-normalised vectors as lists."""
        embeddings = self.model.encode(texts, normalize_embeddings=True)
        return embeddings.tolist()

    def embed_single(self, text: str) -> List[float]:
        """Embed a single text and return its vector."""
        return self.embed([text])[0]


# Module-level singleton; importing this module loads the model once.
embedder = Embedder()
# Extensions treated as indexable text/source files.
_TEXT_EXTENSIONS = frozenset({
    ".py", ".js", ".ts", ".jsx", ".tsx", ".java", ".c", ".cpp", ".h", ".hpp",
    ".go", ".rs", ".rb", ".php", ".swift", ".kt", ".scala", ".md", ".txt",
    ".json", ".yaml", ".yml", ".xml", ".html", ".css", ".sql", ".sh",
    ".bash", ".zsh", ".ps1", ".bat", ".vue", ".svelte",
})


def is_text_file(file_path: str) -> bool:
    """Return True when *file_path* has a known text/source extension."""
    return Path(file_path).suffix.lower() in _TEXT_EXTENSIONS


def get_project_root() -> Path:
    """Return the project root (two levels above this package directory)."""
    return Path(__file__).parent.parent.parent
@pytest.fixture
def runner():
    """A fresh Click CLI runner for each test."""
    return CliRunner()


@pytest.fixture
def setup_test_env(tmpdir):
    """Create a small tree with two Python files to index."""
    os.makedirs(os.path.join(tmpdir, "test_dir"))
    with open(os.path.join(tmpdir, "test_dir", "file1.py"), "w") as f:
        f.write("def test_func():\n    pass")
    with open(os.path.join(tmpdir, "test_file.py"), "w") as f:
        f.write("print('Hello')")
    return tmpdir


@pytest.fixture
def isolated_db(tmpdir, monkeypatch):
    """Point the CLI at a throwaway database.

    Fix: the original tests set an OCRAG_DB_PATH environment variable that
    nothing in ocrag.db ever reads, so every test run wrote into the user's
    real ~/.ocrag database.  Patching ocrag.db.DB_PATH — which
    RagDB.__init__ reads lazily at call time — actually isolates the tests.
    """
    import ocrag.db

    db_path = os.path.join(tmpdir, "test_db.lance")
    monkeypatch.setattr(ocrag.db, "DB_PATH", db_path)
    return db_path


def test_add_command(runner, setup_test_env, isolated_db):
    result = runner.invoke(main, ["add", os.path.join(setup_test_env, "test_file.py")])
    assert "test_file.py" in result.output
    assert "总计添加" in result.output
    assert result.exit_code == 0


def test_search_command(runner, setup_test_env, isolated_db):
    add_result = runner.invoke(
        main, ["add", os.path.join(setup_test_env, "test_file.py")]
    )
    assert add_result.exit_code == 0

    result = runner.invoke(main, ["search", "Hello"])
    assert "Hello" in result.output
    assert "来源: " in result.output


def test_list_command(runner, setup_test_env, isolated_db):
    add_result = runner.invoke(
        main, ["add", os.path.join(setup_test_env, "test_file.py")]
    )
    assert add_result.exit_code == 0

    result = runner.invoke(main, ["list"])
    assert "test_file.py" in result.output
def _doc(text, fill, source):
    """Build one document dict with a constant-filled 1024-dim vector."""
    return {
        "text": text,
        "vector": [fill] * 1024,
        "metadata": {"source_file": source},
    }


@pytest.fixture
def temp_db(tmpdir):
    """Yield a RagDB rooted in a throwaway directory, removed afterwards."""
    db_path = os.path.join(tmpdir, "test_db.lance")
    yield RagDB(db_path)
    if os.path.exists(db_path):
        shutil.rmtree(db_path)


def test_db_initialization(temp_db):
    assert temp_db.table is not None


def test_add_documents(temp_db):
    temp_db.add_documents([_doc("Test content", 0.1, "test.py")])
    assert temp_db.table.to_pandas().shape[0] == 1


def test_search(temp_db):
    temp_db.add_documents([_doc("Database configuration", 0.1, "config.py")])
    hits = temp_db.search([0.1] * 1024)
    assert len(hits) == 1
    assert "Database configuration" in hits[0]["text"]


def test_list_sources(temp_db):
    temp_db.add_documents(
        [_doc("Content 1", 0.1, "file1.py"), _doc("Content 2", 0.2, "file2.py")]
    )
    sources = temp_db.list_sources()
    assert len(sources) == 2
    assert "file1.py" in sources
    assert "file2.py" in sources


def test_delete_by_source(temp_db):
    temp_db.add_documents(
        [
            _doc("Content A", 0.1, "file1.py"),
            _doc("Content B", 0.2, "file2.py"),
            _doc("Content C", 0.3, "file1.py"),
            _doc("Content D", 0.4, "file3.py"),
        ]
    )
    assert len(temp_db.list_sources()) == 3

    # file1.py owns two chunks; both must go.
    assert temp_db.delete_by_source("file1.py") == 2

    sources = temp_db.list_sources()
    assert "file1.py" not in sources
    assert "file2.py" in sources
    assert "file3.py" in sources
    assert len(sources) == 2


def test_delete_nonexistent_source(temp_db):
    assert temp_db.delete_by_source("nonexistent.py") == 0