feat(chunker): implement semantics-aware code chunking with tree-sitter

This commit is contained in:
songsenand 2026-04-16 21:44:45 +08:00
parent 9aee24979f
commit fdfe2fa97d
15 changed files with 1099 additions and 204 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
**/__pycache__/
models/*
uv.lock

View File

@ -63,19 +63,20 @@
- **Python-first**: native Python SDK with complete type hints
- **Excellent performance**: built on Apache Arrow, with fast queries
### 2.3 Why langchain-text-splitters (rather than chonkie)
### 2.3 Why tree-sitter
The original design considered `chonkie` for syntax-aware chunking, but practice showed:
Use `tree-sitter` for AST-level, semantics-aware chunking:
| Option | Pros | Cons |
|------|------|------|
| chonkie | syntax-aware chunking | needs an extra GPT-2 tokenizer download, slow to initialize |
| langchain-text-splitters | no extra models, pure rules, fast | not syntax-aware |
| tree-sitter | AST-aware, semantically complete, never cuts across classes/functions | needs a parser per language |
| langchain-text-splitters | lightweight, simple rules | chunk_size takes priority, so semantic units get cut |
**Final choice: langchain-text-splitters**
- For RAG, chunking precision is not the most critical factor
- The langchain approach is lighter weight and starts up faster
- Splitting code at natural boundaries (functions, classes) works well enough
**Final choice: tree-sitter** (see the sketch below this list)
- Code is chunked with complete semantic units (**classes, functions**) as the smallest blocks
- Units over the token limit are split recursively at the child-node level
- 40+ programming languages supported (via tree-sitter-languages)
- Parsers are loaded through ctypes under the hood, bypassing language-binding compatibility issues
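A minimal sketch of the idea, using the `get_parser` helper from tree-sitter-languages for brevity (the chunker itself loads parsers over ctypes, as noted above):

```python
from tree_sitter_languages import get_parser

source = b"class A:\n    def f(self):\n        return 1\n\ndef g():\n    return 2\n"
tree = get_parser("python").parse(source)

# Top-level classes and functions each become one complete chunk --
# nothing is ever cut in the middle of a definition.
for node in tree.root_node.children:
    if node.type in ("class_definition", "function_definition"):
        print(node.type, source[node.start_byte : node.end_byte].decode())
```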
### 2.4 Why Qwen3-Embedding-0.6B
@ -187,7 +188,7 @@
| Module | Responsibility | Dependencies |
|------|------|------|
| `cli.py` | command-line entry point, argument parsing | click |
| `chunker.py` | code chunking | langchain-text-splitters |
| `chunker.py` | code chunking | tree-sitter, tree-sitter-languages |
| `embedder.py` | text vectorization | sentence-transformers |
| `db.py` | vector database operations | lancedb, pyarrow |
| `commands/*.py` | individual command implementations | cli, chunker, embedder, db |
@ -199,12 +200,13 @@
### 4.1 Code chunking algorithm
**Implementation**: uses langchain's `RecursiveCharacterTextSplitter`, with syntax-aware chunking for 29 programming languages
**Implementation**: uses tree-sitter for AST-level semantic chunking, supporting 40+ programming languages
**Chunking parameters**
- `chunk_size=2000`: roughly 2000 characters per chunk
- `chunk_overlap=200`: 200 characters of overlap between chunks for contextual continuity
- `separators`: chosen automatically per language, in descending semantic priority (class > function > blank-line paragraph > newline > space > character)
**Chunking strategy** (see the usage sketch after this list)
- **Source files**: parse into an AST with tree-sitter and extract complete chunks from semantic nodes such as `class_definition` and `function_definition`
- **Files without semantic units** (e.g. Markdown): chunk line by line using tiktoken token counts
- **Oversized semantic units** (>4000 tokens): split recursively at the child-node level so no chunk exceeds the size limit
- **Parser loading**: load directly from `languages.so` via ctypes, bypassing language-binding compatibility issues
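A usage sketch, with the chunk shape taken from `chunk_code` in `src/ocrag/chunker.py` below (the file path here is just an example):

```python
from ocrag.chunker import chunk_code

source = open("src/ocrag/db.py").read()
for chunk in chunk_code(source, "src/ocrag/db.py"):
    meta = chunk["metadata"]  # {"source_file": ..., "chunk_index": ..., "language": "python"}
    print(meta["chunk_index"], meta["language"], len(chunk["text"]))
```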
### 4.2 Vectorization strategy

View File

@ -11,7 +11,7 @@
- 🚀 **Zero config**: embedded database, works right after install
- 🔒 **Local-first**: all data is stored locally, nothing is uploaded to external servers
- 🔍 **Semantic search**: query for relevant code snippets in natural language
- 📦 **Multi-language**: supports 29 programming languages, including Python, Rust, C++, Go, JavaScript, and TypeScript
- 📦 **Multi-language**: supports 40+ programming languages, including Python, Rust, C++, Go, JavaScript, and TypeScript
- ⚡ **Fast**: search latency < 100 ms
- 🤖 **AI-friendly**: output format designed for AI consumption
@ -48,17 +48,34 @@ cp opencode-skill/SKILL.md ~/.config/opencode/skills/ocrag/
# Add a file to the knowledge base
uv run ocrag add ./src/main.py
# Recursively add a whole directory
# Recursively add a whole directory (.gitignore rules + binary-file filtering applied automatically)
uv run ocrag add ./src/ --recursive
# Exclude specific files
uv run ocrag add ./src/ --recursive --exclude "**/__pycache__" "**/test_*.py"
# Only add files of specific types
uv run ocrag add ./src/ --recursive --include "*.py" "*.rs"
# Disable .gitignore filtering
uv run ocrag add ./src/ --recursive --no-ignore
# Search for relevant code
uv run ocrag search "how is user authentication implemented"
# List the files in the knowledge base
uv run ocrag list
# Filter the list with wildcards
uv run ocrag list "*.py"      # only Python files
uv run ocrag list "src/**"    # only files under src/
# Remove a file
uv run ocrag remove ./src/main.py
# Bulk-remove with wildcards
uv run ocrag remove "**/test_*.py"   # remove all test files
uv run ocrag remove "src/**/*.js"    # remove all JS files under src/
```
### OpenCode integration

View File

@ -23,12 +23,19 @@ description: Codebase RAG skill for adding code to the knowledge base or searching it
2. Once snippets come back, answer the user based on their contents
3. If the results are irrelevant, say nothing was found and suggest adding more code
### Managing the knowledge base
- `rag_list` takes an optional pattern argument for wildcard filtering, e.g. `"*.py"` or `"src/**"`
- `rag_remove` supports wildcard bulk deletion, e.g. `"**/test_*.py"` removes all test files
## Examples
- User: "Remember what's in this file" → call `rag_add` with the current file path
- User: "How is the auth module implemented?" → call `rag_search` query="auth module implementation"
- User: "Add all Python files under ./src to the knowledge base" → call `rag_add` paths=["./src"] recursive=true
- User: "How many test files are in the knowledge base?" → call `rag_list` pattern="**/test_*.py"
- User: "Delete all the test files" → call `rag_remove` pattern="**/test_*.py"
## Notes
- Search results show each snippet's source file
- The number of results can be tuned with the top_k parameter
- Newly added files are searchable immediately
- remove and list accept fnmatch wildcards: `*` matches any characters, `**` spans path separators (see the snippet below)
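For illustration, the fnmatch semantics behind `list`/`remove` (file names here are hypothetical):

```python
import fnmatch

files = ["src/main.py", "src/utils/helper.js", "tests/test_cli.py"]
print([f for f in files if fnmatch.fnmatch(f, "**/test_*.py")])  # ['tests/test_cli.py']
# fnmatch's * is not path-aware, so "*.py" also matches nested paths:
print(fnmatch.fnmatch("src/main.py", "*.py"))  # True
```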

View File

@ -1,19 +1,29 @@
[project]
name = "project"
name = "ocrag"
version = "0.1.0"
description = "Add your description here"
description = "Local code knowledge base RAG plugin for OpenCode"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"chonkie>=1.6.2",
"tree-sitter>=0.21",
"tree-sitter-languages>=1.10",
"click>=8.3.2",
"lancedb>=0.30.2",
"pathspec>=0.12.0",
"sentence-transformers>=5.4.1",
"sentencepiece>=0.2.1",
"tiktoken>=0.12.0",
"tokenizers>=0.22.2",
"tqdm>=4.66.0",
]
[project.scripts]
ocrag = "ocrag.cli:main"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[dependency-groups]
dev = [
"pytest>=9.0.3",

View File

@ -1,47 +1,478 @@
"""
代码分块模块
使用 langchain_text_splitters 进行语言感知分块
支持 29 种编程语言的语法感知分块tiktoken 已内置无需额外下载
使用 tree-sitter 进行语义感知分块
源码按类/函数等语义单元切分保证每个 chunk 都是完整的语义块
非代码文件使用简单分块
"""
import ctypes
import os
import re
import warnings
from pathlib import Path
from typing import List, Dict, Any
warnings.filterwarnings(
"ignore",
message="int argument support is deprecated",
category=DeprecationWarning,
)
from typing import List, Dict, Any, Optional, Tuple
try:
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
import tree_sitter
from tree_sitter import Parser, Language, Node
except ImportError:
raise ImportError(
"请安装 langchain-text-splitters: uv pip install langchain-text-splitters"
)
raise ImportError("请安装 tree-sitter: uv pip install tree-sitter")
LANGUAGE_MAP = {
".py": Language.PYTHON,
".rs": Language.RUST,
".cpp": Language.CPP,
".cc": Language.CPP,
".cxx": Language.CPP,
".go": Language.GO,
".java": Language.JAVA,
".js": Language.JS,
".ts": Language.TS,
".jsx": Language.JS,
".tsx": Language.TS,
".c": Language.C,
".h": Language.C,
".hpp": Language.CPP,
".rb": Language.RUBY,
".swift": Language.SWIFT,
".kt": Language.KOTLIN,
".md": Language.MARKDOWN,
".txt": None,
".py": "python",
".rs": "rust",
".cpp": "cpp",
".cc": "cpp",
".cxx": "cpp",
".go": "go",
".java": "java",
".js": "javascript",
".ts": "typescript",
".jsx": "javascript",
".tsx": "typescript",
".c": "c",
".h": "cpp",
".hpp": "cpp",
".rb": "ruby",
".swift": "swift",
".kt": "kotlin",
".sh": "bash",
".bash": "bash",
".zsh": "bash",
".lua": "lua",
".php": "php",
".scala": "scala",
".cs": "c_sharp",
".r": "r",
".R": "r",
".sql": "sql",
".toml": "toml",
".yaml": "yaml",
".yml": "yaml",
".json": "json",
".xml": "xml",
".html": "html",
".htm": "html",
".css": "css",
".scss": "css",
".less": "css",
".md": "markdown",
".rst": "rst",
".dockerfile": "dockerfile",
".tf": "hcl",
".hs": "haskell",
".ex": "elixir",
".exs": "elixir",
".erl": "erlang",
".ml": "ocaml",
".mli": "ocaml",
".fs": "fsharp",
".fsx": "fsharp",
}
SEMANTIC_UNITS = {
"python": [
"function_definition",
"class_definition",
"async_function_definition",
"decorated_definition",
],
"rust": [
"function_item",
"struct_item",
"impl_item",
"enum_item",
"trait_item",
"type_item",
"macro_invocation",
"macro_rule",
],
"cpp": [
"class_specifier",
"struct_specifier",
"namespace_definition",
"function_definition",
"method_declaration",
"constructor_initializer",
"template_declaration",
],
"c": [
"struct_specifier",
"union_specifier",
"enum_specifier",
"function_definition",
"type_definition",
],
"go": [
"function_declaration",
"method_declaration",
"type_declaration",
"const_declaration",
"var_declaration",
],
"java": [
"class_declaration",
"interface_declaration",
"enum_declaration",
"record_declaration",
"method_declaration",
],
"javascript": [
"class_declaration",
"function_declaration",
"arrow_function",
"method_definition",
],
"typescript": [
"class_declaration",
"function_declaration",
"arrow_function",
"method_definition",
"interface_declaration",
"type_alias_declaration",
"enum_declaration",
],
"ruby": [
"class",
"module",
"def",
"method",
],
"swift": [
"class_declaration",
"struct_declaration",
"enum_declaration",
"protocol_declaration",
"function_declaration",
"method_declaration",
],
"kotlin": [
"class_declaration",
"object_declaration",
"function_declaration",
"method_declaration",
],
"scala": [
"class_definition",
"object_definition",
"trait_definition",
"function_definition",
"method_definition",
],
"c_sharp": [
"class_declaration",
"struct_declaration",
"interface_declaration",
"enum_declaration",
"method_declaration",
],
"go_mod": ["import_declaration"],
"bash": [
"function_definition",
"command",
],
"lua": [
"function_declaration",
"local_function",
"assignment_statement",
],
"php": [
"class_declaration",
"interface_declaration",
"trait_declaration",
"function_definition",
"method_declaration",
],
"sql": [
"create_table_statement",
"create_index_statement",
"create_view_statement",
"create_trigger_statement",
"procedure_definition",
"function_definition",
],
"toml": ["table"],
"yaml": ["block_mapping"],
"json": ["pair", "array"],
"html": [
"element",
"script_element",
"style_element",
],
"css": ["rule_set", "at_rule"],
"markdown": ["section", "fenced_code_block"],
"rst": ["section", "directive"],
"dockerfile": ["from_instruction", "run_instruction", "cmd_instruction"],
"hcl": ["block", "attribute"],
"haskell": [
"type_declaration",
"data_declaration",
"function_definition",
"class_declaration",
"instance_declaration",
],
"elixir": [
"module_definition",
"function_definition",
"def",
"defp",
],
"erlang": [
"function_definition",
"module_definition",
],
"ocaml": [
"value_definition",
"type_definition",
"module_definition",
"class_definition",
],
"fsharp": [
"class_definition",
"module_definition",
"function_definition",
"type_definition",
],
"r": [
"function_definition",
"assignment",
],
}
def detect_language(file_path: str) -> Language | None:
"""检测文件语言类型"""
ext = Path(file_path).suffix.lower()
return LANGUAGE_MAP.get(ext)
MAX_CHUNK_CHARS = 8000
MIN_CHUNK_CHARS = 100
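# Singleton: load the tiktoken encoding once per process and reuse it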
class TiktokenCounter:
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
if self._initialized:
return
import tiktoken
self._enc = tiktoken.encoding_for_model("gpt-4o")
self._initialized = True
def count(self, text: str) -> int:
return len(self._enc.encode(text))
_token_counter = TiktokenCounter()
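# Locate the compiled grammar bundle shipped inside tree-sitter-languages;
# loading it directly via ctypes sidesteps that package's Python-binding
# compatibility constraints (see design doc section 2.3)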
def _find_languages_so() -> str:
import site
for path in site.getsitepackages() + [site.getusersitepackages()]:
so_path = os.path.join(path, "tree_sitter_languages", "languages.so")
if os.path.exists(so_path):
return so_path
for path in site.getsitepackages():
if os.path.exists(os.path.join(path, "tree_sitter_languages")):
return os.path.join(path, "tree_sitter_languages", "languages.so")
cache_dir = os.path.expanduser("~/.cache/uv/archive-v0")
if os.path.exists(cache_dir):
for entry in os.listdir(cache_dir):
entry_path = os.path.join(
cache_dir, entry, "tree_sitter_languages", "languages.so"
)
if os.path.exists(entry_path):
return entry_path
    raise FileNotFoundError(
        "Could not find languages.so from tree-sitter-languages."
        " Make sure tree-sitter-languages is installed."
    )
_lib = None
def _get_lib() -> ctypes.CDLL:
global _lib
if _lib is None:
so_path = _find_languages_so()
_lib = ctypes.CDLL(so_path)
return _lib
_parsers_cache: Dict[str, Tuple[Parser, Language]] = {}
def get_parser(language_name: str) -> Tuple[Parser, Language]:
    """Return a cached (Parser, Language) pair for a tree-sitter grammar name."""
    if language_name in _parsers_cache:
        return _parsers_cache[language_name]
    lib = _get_lib()
    symbol_name = f"tree_sitter_{language_name}"
    if not hasattr(lib, symbol_name):
        raise ValueError(f"Unsupported language: {language_name}")
    func = getattr(lib, symbol_name)
    func.argtypes = []
    func.restype = ctypes.c_void_p
    ptr = func()
    lang = Language(ptr)  # type: ignore[call-arg]
    parser = Parser(lang)
    _parsers_cache[language_name] = (parser, lang)
    return parser, lang
def _is_semantic_unit(node: Node, semantic_types: List[str]) -> bool:
return node.type in semantic_types
def _extract_semantic_chunks(
node: Node,
semantic_types: List[str],
source_bytes: bytes,
parent_context: str = "",
) -> List[Tuple[str, str, int, int]]:
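    """Depth-first walk over the AST: each semantic-unit node becomes one chunk;
    units over the token limit are split at the child level, and non-semantic
    nodes recurse into their children.
    Returns (node_type, text, start_byte, end_byte) tuples."""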
    chunks = []
    node_type = node.type
    if _is_semantic_unit(node, semantic_types):
        text = source_bytes[node.start_byte : node.end_byte].decode(
            "utf-8", errors="replace"
        )
        token_count = _token_counter.count(text)
        if token_count > 4000:
            children_chunks = _split_large_node(
                node, semantic_types, source_bytes, parent_context
            )
            if children_chunks:
                chunks.extend(children_chunks)
            else:
                chunks.append((node_type, text, node.start_byte, node.end_byte))
        else:
            chunks.append((node_type, text, node.start_byte, node.end_byte))
    else:
        for child in node.children:
            chunks.extend(
                _extract_semantic_chunks(
                    child, semantic_types, source_bytes, parent_context
                )
            )
    return chunks
def _split_large_node(
node: Node,
semantic_types: List[str],
source_bytes: bytes,
parent_context: str,
) -> List[Tuple[str, str, int, int]]:
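    """Split a semantic unit that exceeds the token limit at the child-node
    level, merging runs of small non-semantic siblings into "sub_chunk" spans."""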
    children_with_semantic = [
        (child, child.type in semantic_types)
        for child in node.children
        if child.type not in ("comment", "preproc_directive", "attribute")
    ]
    if not children_with_semantic:
        return []
    results = []
    current_text_parts = []
    for child, is_semantic in children_with_semantic:
        if is_semantic:
            if current_text_parts:
                # Flush accumulated non-semantic siblings as one "sub_chunk",
                # slicing the original byte span so text between nodes survives
                span_start = current_text_parts[0][0].start_byte
                span_end = current_text_parts[-1][0].end_byte
                combined = source_bytes[span_start:span_end].decode(
                    "utf-8", errors="replace"
                )
                # MIN_CHUNK_CHARS // 4 approximates a minimum token count (~4 chars/token)
                if _token_counter.count(combined) >= MIN_CHUNK_CHARS // 4:
                    results.append(("sub_chunk", combined, span_start, span_end))
                current_text_parts = []
            child_text = source_bytes[child.start_byte : child.end_byte].decode(
                "utf-8", errors="replace"
            )
            if _token_counter.count(child_text) <= 4000:
                results.append(
                    (child.type, child_text, child.start_byte, child.end_byte)
                )
            else:
                nested = _split_large_node(
                    child, semantic_types, source_bytes, parent_context
                )
                results.extend(nested)
        else:
            current_text_parts.append((child, child.type))
    if current_text_parts:
        span_start = current_text_parts[0][0].start_byte
        span_end = current_text_parts[-1][0].end_byte
        combined = source_bytes[span_start:span_end].decode("utf-8", errors="replace")
        if _token_counter.count(combined) >= MIN_CHUNK_CHARS // 4:
            results.append(("sub_chunk", combined, span_start, span_end))
    return results
def _simple_chunk_text(content: str, chunk_size: int = 512) -> List[str]:
    """Line-based fallback chunking bounded by token count (no semantic units)."""
    lines = content.split("\n")
    chunks = []
    current = []
    current_tokens = 0
    for line in lines:
        line_tokens = _token_counter.count(line)
        if current_tokens + line_tokens > chunk_size and current:
            chunks.append("\n".join(current))
            current = [line]
            current_tokens = line_tokens
        else:
            current.append(line)
            current_tokens += line_tokens
    if current:
        chunks.append("\n".join(current))
    return chunks
def chunk_code(
@ -50,41 +481,32 @@ def chunk_code(
chunk_size: int = 512,
chunk_overlap: int = 50,
) -> List[Dict[str, Any]]:
"""将代码文件分块,返回块列表
ext = Path(file_path).suffix.lower()
lang_name = LANGUAGE_MAP.get(ext)
    Uses langchain_text_splitters for language-aware chunking
    Supports syntax-aware chunking for 29 programming languages
Args:
        content: file content
        file_path: file path
        chunk_size: chunk size in tokens
        chunk_overlap: overlap size in tokens
Returns:
        A list of chunks, each containing text and metadata
"""
language = detect_language(file_path)
if language is None:
        # Unsupported language: fall back to simple chunking
splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
separators=["\n\n", "\n", " ", ""],
length_function=len,
)
texts = splitter.split_text(content)
language_name = "text"
if not lang_name:
texts = _simple_chunk_text(content, chunk_size)
lang_display = "text"
else:
        # Language-aware chunking
splitter = RecursiveCharacterTextSplitter.from_language(
language=language,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
semantic_types = SEMANTIC_UNITS.get(lang_name, [])
source_bytes = content.encode("utf-8")
try:
parser, _ = get_parser(lang_name)
except Exception:
texts = _simple_chunk_text(content, chunk_size)
lang_display = lang_name
else:
tree = parser.parse(source_bytes)
raw_chunks = _extract_semantic_chunks(
tree.root_node, semantic_types, source_bytes
)
texts = splitter.split_text(content)
language_name = language.value
if not raw_chunks:
texts = _simple_chunk_text(content, chunk_size)
else:
texts = [text for _, text, _, _ in raw_chunks]
lang_display = lang_name
return [
{
@ -92,8 +514,13 @@ def chunk_code(
"metadata": {
"source_file": file_path,
"chunk_index": i,
"language": language_name,
"language": lang_display,
},
}
for i, text in enumerate(texts)
]
def detect_language(file_path: str) -> Optional[str]:
    """Detect the tree-sitter grammar name for a file from its extension."""
    ext = Path(file_path).suffix.lower()
    return LANGUAGE_MAP.get(ext)

View File

@ -1,4 +1,5 @@
import click
from pathlib import Path
from ocrag.commands import add as add_cmd
from ocrag.commands import search as search_cmd
from ocrag.commands import remove as remove_cmd
@ -13,11 +14,31 @@ def main():
@main.command()
@click.argument("paths", nargs=-1, required=True, type=click.Path(exists=True))
@click.argument("paths", nargs=-1, required=True)
@click.option("--recursive", "-r", is_flag=True, help="递归处理目录")
def add(paths, recursive):
"""向知识库添加文件或目录"""
add_cmd(paths, recursive)
@click.option(
    "--include", "-i", multiple=True, help="Only add matching files (glob pattern, repeatable)"
)
@click.option(
    "--exclude", "-e", multiple=True, help="Exclude matching files (glob pattern, repeatable)"
)
@click.option("--no-ignore", is_flag=True, help="Disable .gitignore filtering")
def add(paths, recursive, include, exclude, no_ignore):
"""向知识库添加文件或目录
默认会自动读取 .gitignore 规则并跳过匹配的文件
目录遍历时自动跳过 .git__pycache__node_modules 等目录
"""
for p in paths:
if not Path(p).exists():
            raise click.BadParameter(f"Path does not exist: {p}")
add_cmd(
paths,
recursive,
include_patterns=list(include) if include else None,
exclude_patterns=list(exclude) if exclude else None,
use_gitignore=not no_ignore,
)
@main.command()
@ -29,16 +50,17 @@ def search(query, top_k):
@main.command()
@click.argument("path", type=click.Path())
def remove(path):
"""从知识库中移除指定文件"""
remove_cmd(path)
@click.argument("pattern")
def remove(pattern):
"""从知识库中移除匹配模式的所有文件(支持通配符)"""
remove_cmd(pattern)
@main.command()
def list():
"""列出知识库中的所有条目"""
list_cmd()
@main.command("list")
@click.argument("pattern", required=False)
def list_files(pattern):
"""列出知识库中的条目(可选通配符筛选)"""
list_cmd(pattern)
if __name__ == "__main__":

View File

@ -1,39 +1,125 @@
import os
import pathspec
from pathlib import Path
from typing import List, Optional
from tqdm import tqdm
from ocrag.chunker import chunk_code
from ocrag.embedder import embedder
from ocrag.db import RagDB
from ocrag.utils import is_text_file
from ocrag.gitignore import get_gitignore_matcher
def collect_files(paths, recursive):
"""收集所有需要处理的文件"""
def collect_files(
paths,
recursive: bool,
    include_patterns: Optional[List[str]] = None,
    exclude_patterns: Optional[List[str]] = None,
use_gitignore: bool = True,
) -> List[Path]:
"""收集所有需要处理的文件
Args:
paths: 文件或目录路径列表
recursive: 是否递归处理目录
include_patterns: 只处理匹配这些 glob 模式的文件
exclude_patterns: 跳过匹配这些 glob 模式的文件
use_gitignore: 是否应用 .gitignore 规则
"""
files = []
seen = set()
exclude_spec = None
if exclude_patterns:
exclude_spec = pathspec.PathSpec.from_lines("gitignore", exclude_patterns)
include_spec = None
if include_patterns:
include_spec = pathspec.PathSpec.from_lines("gitignore", include_patterns)
for p in paths:
p = Path(p)
p = Path(p).resolve()
if p.is_file():
files.append(p)
elif p.is_dir() and recursive:
for root, dirs, filenames in os.walk(p):
for f in filenames:
files.append(Path(root) / f)
matcher = get_gitignore_matcher(p) if use_gitignore else None
for root_dir, dirs, filenames in os.walk(p):
root_dir = Path(root_dir)
dirs[:] = [
d
for d in dirs
if d
not in (
".git",
"__pycache__",
".pytest_cache",
".mypy_cache",
"node_modules",
".venv",
"venv",
".env",
".tox",
".eggs",
)
and not d.endswith(".egg-info")
]
for filename in filenames:
file_path = root_dir / filename
if file_path in seen:
continue
seen.add(file_path)
rel = file_path.relative_to(p)
if not is_text_file(str(file_path)):
continue
if matcher and matcher.is_ignored(file_path):
continue
if exclude_spec and exclude_spec.match_file(str(rel)):
continue
if include_spec and not include_spec.match_file(str(rel)):
continue
files.append(file_path)
elif p.is_dir() and not recursive:
print(f"跳过目录 {p}(未使用 --recursive")
return files
def run(paths, recursive):
def run(
paths,
recursive: bool,
    include_patterns: Optional[List[str]] = None,
    exclude_patterns: Optional[List[str]] = None,
use_gitignore: bool = True,
):
db = RagDB()
files = collect_files(paths, recursive)
files = collect_files(
paths, recursive, include_patterns, exclude_patterns, use_gitignore
)
if not files:
print("没有找到需要处理的文件")
return
total_chunks = 0
with tqdm(total=len(files), desc="处理文件", unit="file") as pbar:
for file_path in files:
try:
content = file_path.read_text(encoding="utf-8")
chunks = chunk_code(content, str(file_path))
if not chunks:
pbar.update(1)
continue
                # Batch-embed all chunk texts at once
texts = [c["text"] for c in chunks]
vectors = embedder.embed(texts)
@ -49,8 +135,9 @@ def run(paths, recursive):
db.add_documents(documents)
total_chunks += len(documents)
print(f"{file_path} -> {len(documents)} 个块")
pbar.set_postfix({"": f"{total_chunks}"})
except Exception as e:
print(f"❌ 处理 {file_path} 失败: {e}")
print(f"\n❌ 处理 {file_path} 失败: {e}")
pbar.update(1)
print(f"\n📦 总计添加 {total_chunks} 个块")

View File

@ -1,14 +1,15 @@
from typing import Optional
from ocrag.db import RagDB
def run():
def run(pattern: Optional[str] = None):
db = RagDB()
sources = db.list_sources()
sources = db.list_sources(pattern)
if not sources:
print("知识库为空")
return
print("知识库中的文件:")
    label = f' matching "{pattern}"' if pattern else ""
    print(f"Files in the knowledge base{label}:")
for i, source in enumerate(sources, 1):
print(f"{i}. {source}")

View File

@ -1,9 +1,8 @@
from ocrag.db import RagDB
def run(path: str):
def run(pattern: str):
db = RagDB()
    # Note: delete_by_source in db.py is not fully implemented yet
    # Call it here to keep the interface consistent; db.py still needs follow-up work
db.delete_by_source(path)
print(f"已删除 {path} 的所有块")
count = db.delete_by_pattern(pattern)
    print(f'Removed {count} chunks matching "{pattern}"')

View File

@ -8,43 +8,51 @@ DB_PATH = Path.home() / ".ocrag" / "data.lance"
class RagDB:
def __init__(self, db_path: str = None):
self.path = db_path or str(DB_PATH)
import os as _os
self.path = db_path or _os.environ.get("OCRAG_DB_PATH", str(DB_PATH))
self.conn = lancedb.connect(self.path)
self._init_table()
def _init_table(self):
"""如果表不存在则创建"""
schema = pa.schema(
[
("text", pa.string()),
(
"vector",
pa.list_(pa.float32(), 1024),
                ),  # 1024 dims (Qwen3-Embedding-0.6B)
("metadata", pa.string()), # JSON 字符串
]
)
try:
self.table = self.conn.open_table("documents")
        except Exception:
self.table = self.conn.create_table("documents", schema=schema)
empty = pa.table(
{
"text": pa.array([], type=pa.string()),
"vector": pa.array([], type=pa.list_(pa.float32(), 1024)),
"metadata": pa.array([], type=pa.string()),
}
)
self.table = self.conn.create_table("documents", empty)
def add_documents(self, documents: List[Dict[str, Any]]):
"""批量添加文档
documents: [{"text": str, "vector": List[float], "metadata": dict}, ...]
"""
import json
import numpy as np
data = []
texts = []
vectors = []
metadatas = []
for doc in documents:
data.append(
{
"text": doc["text"],
"vector": doc["vector"],
"metadata": json.dumps(doc.get("metadata", {})),
}
)
self.table.add(data)
texts.append(doc["text"])
vec = doc["vector"]
if isinstance(vec, list):
vec = np.array(vec, dtype=np.float32)
vectors.append(vec)
metadatas.append(json.dumps(doc.get("metadata", {})))
arrays = [
pa.array(texts),
pa.array(vectors, type=pa.list_(pa.float32(), 1024)),
pa.array(metadatas),
]
batch = pa.record_batch(arrays, names=["text", "vector", "metadata"])
self.table.add(batch)
def search(self, query_vector: List[float], top_k: int = 5) -> List[Dict[str, Any]]:
import json
@ -96,9 +104,50 @@ class RagDB:
return num_deleted
def list_sources(self) -> List[str]:
"""列出所有已添加的源文件路径(去重)"""
# 获取所有 metadata提取 source_file
def delete_by_pattern(self, pattern: str) -> int:
"""删除匹配通配符模式的所有源文件的所有块
Args:
pattern: fnmatch 模式 "*.py", "src/**/*.js", "**/test_*.py"
Returns:
删除的块数量
"""
import fnmatch
import json
df = self.table.to_pandas()
if df.empty:
return 0
indices_to_keep = []
for idx, row in df.iterrows():
try:
meta = json.loads(row["metadata"])
source = meta.get("source_file", "")
if not fnmatch.fnmatch(source, pattern):
indices_to_keep.append(idx)
except (json.JSONDecodeError, KeyError):
indices_to_keep.append(idx)
num_deleted = len(df) - len(indices_to_keep)
if num_deleted == 0:
return 0
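        # A fnmatch filter over the JSON metadata column can't be expressed as a
        # SQL predicate, so rebuild the table with only the surviving rows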
df_remaining = df.loc[indices_to_keep]
self.conn.drop_table("documents")
self.table = self.conn.create_table("documents", df_remaining)
return num_deleted
def list_sources(self, pattern: str = None) -> List[str]:
"""列出所有已添加的源文件路径(去重)
Args:
pattern: 可选的 fnmatch 模式筛选匹配的文件
"""
import fnmatch
df = self.table.to_pandas()
if df.empty:
return []
@ -108,5 +157,6 @@ class RagDB:
for meta_str in df["metadata"]:
meta = json.loads(meta_str)
if "source_file" in meta:
if pattern is None or fnmatch.fnmatch(meta["source_file"], pattern):
sources.add(meta["source_file"])
return sorted(sources)

73
src/ocrag/gitignore.py Normal file
View File

@ -0,0 +1,73 @@
import os
import pathspec
from pathlib import Path
from typing import List, Tuple, Optional
class GitignoreMatcher:
"""多层 .gitignore 解析器
支持从项目根目录向下遍历收集所有 .gitignore 文件
每个 .gitignore 中的规则只对该文件所在目录及其子目录生效
"""
def __init__(self, root: Path):
self.root = root.resolve()
self._specs: List[Tuple[Path, pathspec.PathSpec]] = []
self._load()
def _load(self):
"""遍历目录树,收集所有 .gitignore 并构建 pathspec"""
for dirpath, dirnames, filenames in os.walk(self.root):
dirpath = Path(dirpath)
gi_path = dirpath / ".gitignore"
if gi_path.exists() and gi_path.is_file():
try:
lines = gi_path.read_text(
encoding="utf-8", errors="ignore"
).splitlines()
spec = pathspec.PathSpec.from_lines("gitignore", lines)
self._specs.append((dirpath, spec))
except Exception:
pass
dirnames[:] = [d for d in dirnames if d not in (".git", "__pycache__")]
def is_ignored(self, file_path: Path) -> bool:
"""判断文件是否被任一层 .gitignore 忽略"""
abs_path = file_path.resolve()
for dirpath, spec in self._specs:
try:
rel = abs_path.relative_to(dirpath)
except ValueError:
continue
if spec.match_file(str(rel)):
return True
return False
def find_project_root(start_path: Path) -> Optional[Path]:
"""从 start_path 向上查找包含 .gitignore 的目录作为项目根"""
current = start_path.resolve()
if start_path.is_file():
current = current.parent
while True:
if (current / ".gitignore").exists():
return current
parent = current.parent
if parent == current:
return None
current = parent
def get_gitignore_matcher(root: Path) -> Optional[GitignoreMatcher]:
"""为给定目录创建 GitignoreMatcher"""
gi_path = find_project_root(root)
if gi_path is None:
return None
return GitignoreMatcher(gi_path)

View File

@ -33,44 +33,23 @@ def format_size(size_bytes: int) -> str:
def is_text_file(file_path: str) -> bool:
"""检查文件是否为文本文件"""
text_extensions = {
".py",
".js",
".ts",
".jsx",
".tsx",
".java",
".c",
".cpp",
".h",
".hpp",
".go",
".rs",
".rb",
".php",
".swift",
".kt",
".scala",
".md",
".txt",
".json",
".yaml",
".yml",
".xml",
".html",
".css",
".sql",
".sh",
".bash",
".zsh",
".ps1",
".bat",
".vue",
".svelte",
}
ext = Path(file_path).suffix.lower()
return ext in text_extensions
"""检查文件是否为文本文件
通过读取文件前 1KB 内容检测
- 包含 null byte (\\x00) 则判定为二进制
- 能够以 UTF-8 解码则判定为文本
"""
try:
with open(file_path, "rb") as f:
chunk = f.read(1024)
if not chunk:
return True
if b"\x00" in chunk:
return False
chunk.decode("utf-8")
return True
except (UnicodeDecodeError, OSError):
return False
def get_project_root() -> Path:

View File

@ -3,6 +3,7 @@ import os
from click.testing import CliRunner
from ocrag.cli import main
from ocrag.db import RagDB
from ocrag.commands.add import collect_files
@pytest.fixture
@ -12,17 +13,15 @@ def runner():
@pytest.fixture
def setup_test_env(tmpdir):
# Create test files
os.makedirs(os.path.join(tmpdir, "test_dir"))
with open(os.path.join(tmpdir, "test_dir", "file1.py"), "w") as f:
f.write("def test_func():\n pass")
with open(os.path.join(tmpdir, "test_file.py"), "w") as f:
f.write("print('Hello')")
return tmpdir
return str(tmpdir)
def test_add_command(runner, setup_test_env, tmpdir):
# Use temporary DB
db_path = os.path.join(tmpdir, "test_db.lance")
os.environ["OCRAG_DB_PATH"] = db_path
@ -33,7 +32,6 @@ def test_add_command(runner, setup_test_env, tmpdir):
def test_search_command(runner, setup_test_env, tmpdir):
# Use temporary DB and add a file first
db_path = os.path.join(tmpdir, "test_db.lance")
os.environ["OCRAG_DB_PATH"] = db_path
@ -42,14 +40,12 @@ def test_search_command(runner, setup_test_env, tmpdir):
)
assert add_result.exit_code == 0
# Search
result = runner.invoke(main, ["search", "Hello"])
assert "Hello" in result.output
assert "来源: " in result.output
def test_list_command(runner, setup_test_env, tmpdir):
# Use temporary DB and add a file first
db_path = os.path.join(tmpdir, "test_db.lance")
os.environ["OCRAG_DB_PATH"] = db_path
@ -58,6 +54,104 @@ def test_list_command(runner, setup_test_env, tmpdir):
)
assert add_result.exit_code == 0
# List
result = runner.invoke(main, ["list"])
assert "test_file.py" in result.output
def test_remove_wildcard(runner, setup_test_env, tmpdir):
db_path = os.path.join(tmpdir, "test_db.lance")
os.environ["OCRAG_DB_PATH"] = db_path
runner.invoke(main, ["add", os.path.join(setup_test_env, "test_file.py")])
runner.invoke(main, ["add", os.path.join(setup_test_env, "test_dir", "file1.py")])
result = runner.invoke(main, ["remove", "*.py"])
assert result.exit_code == 0
assert "已删除" in result.output
list_result = runner.invoke(main, ["list"])
assert "test_file.py" not in list_result.output
def test_list_wildcard(runner, setup_test_env, tmpdir):
db_path = os.path.join(tmpdir, "test_db.lance")
os.environ["OCRAG_DB_PATH"] = db_path
runner.invoke(main, ["add", os.path.join(setup_test_env, "test_file.py")])
runner.invoke(main, ["add", os.path.join(setup_test_env, "test_dir", "file1.py")])
result = runner.invoke(main, ["list", "*.py"])
assert result.exit_code == 0
assert "test_file.py" in result.output
assert "file1.py" in result.output
class TestCollectFiles:
def test_exclude_pattern(self, setup_test_env):
files = collect_files(
[setup_test_env], recursive=True, exclude_patterns=["**/test_*.py"]
)
names = [f.name for f in files]
assert "test_file.py" not in names
assert "file1.py" in names
def test_include_pattern(self, setup_test_env):
files = collect_files(
[setup_test_env], recursive=True, include_patterns=["*.py"]
)
names = [f.name for f in files]
assert "test_file.py" in names
assert "file1.py" in names
def test_binary_skip(self, tmpdir):
with open(os.path.join(tmpdir, "binary.bin"), "wb") as f:
f.write(b"\x00\x01\x02binary")
with open(os.path.join(tmpdir, "text.py"), "w") as f:
f.write("print('hello')")
files = collect_files([str(tmpdir)], recursive=True)
names = [f.name for f in files]
assert "binary.bin" not in names
assert "text.py" in names
def test_gitignore(self, tmpdir):
os.makedirs(os.path.join(tmpdir, "src"))
with open(os.path.join(tmpdir, "src", "main.py"), "w") as f:
f.write("print('hello')")
with open(os.path.join(tmpdir, "src", "test_main.py"), "w") as f:
f.write("def test(): pass")
with open(os.path.join(tmpdir, ".gitignore"), "w") as f:
f.write("**/test_*.py\n")
files = collect_files([os.path.join(tmpdir, "src")], recursive=True)
names = [f.name for f in files]
assert "main.py" in names
assert "test_main.py" not in names
def test_no_ignore_disables_gitignore(self, setup_test_env):
with open(os.path.join(setup_test_env, ".gitignore"), "w") as f:
f.write("test_file.py\n")
files = collect_files([setup_test_env], recursive=True, use_gitignore=False)
names = [f.name for f in files]
assert "test_file.py" in names
def test_combined_filters(self, tmpdir):
os.makedirs(os.path.join(tmpdir, "lib"))
with open(os.path.join(tmpdir, "lib", "main.py"), "w") as f:
f.write("def main(): pass")
with open(os.path.join(tmpdir, "lib", "test_main.py"), "w") as f:
f.write("def test(): pass")
with open(os.path.join(tmpdir, "lib", "data.bin"), "wb") as f:
f.write(b"\x00binary")
files = collect_files(
[os.path.join(tmpdir, "lib")],
recursive=True,
include_patterns=["*.py"],
exclude_patterns=["**/test_*.py"],
)
names = [f.name for f in files]
assert "main.py" in names
assert "test_main.py" not in names
assert "data.bin" not in names

126
tests/unit/test_utils.py Normal file
View File

@ -0,0 +1,126 @@
import os
import pytest
from pathlib import Path
from ocrag.gitignore import GitignoreMatcher, get_gitignore_matcher, find_project_root
from ocrag.utils import is_text_file
class TestIsTextFile:
def test_python_file_is_text(self, tmp_path):
f = tmp_path / "test.py"
f.write_text("def foo():\n pass")
assert is_text_file(str(f)) is True
def test_binary_file_is_not_text(self, tmp_path):
f = tmp_path / "test.bin"
f.write_bytes(b"\x00\x01\x02\x03binary")
assert is_text_file(str(f)) is False
def test_empty_file_is_text(self, tmp_path):
f = tmp_path / "empty.txt"
f.write_text("")
assert is_text_file(str(f)) is True
def test_utf8_with_non_ascii_is_text(self, tmp_path):
f = tmp_path / "unicode.txt"
f.write_text("中文内容测试\n日本語\n한국어")
assert is_text_file(str(f)) is True
def test_invalid_utf8_is_not_text(self, tmp_path):
f = tmp_path / "invalid.txt"
f.write_bytes(b"hello\x80world")
assert is_text_file(str(f)) is False
class TestFindProjectRoot:
def test_finds_gitignore(self, tmp_path):
(tmp_path / ".gitignore").write_text("*.pyc")
sub = tmp_path / "src" / "deep"
sub.mkdir(parents=True)
found = find_project_root(sub)
assert found == tmp_path
def test_no_gitignore_returns_none(self, tmp_path):
sub = tmp_path / "src"
sub.mkdir(parents=True)
found = find_project_root(sub)
assert found is None
class TestGitignoreMatcher:
def test_ignore_pattern(self, tmp_path):
(tmp_path / ".gitignore").write_text("*.pyc\n__pycache__/\n")
matcher = GitignoreMatcher(tmp_path)
assert matcher.is_ignored(tmp_path / "file.pyc") is True
assert matcher.is_ignored(tmp_path / "file.py") is False
assert matcher.is_ignored(tmp_path / "__pycache__" / "cache.pyc") is True
def test_negation_pattern(self, tmp_path):
(tmp_path / ".gitignore").write_text("*.log\n!important.log\n")
matcher = GitignoreMatcher(tmp_path)
assert matcher.is_ignored(tmp_path / "debug.log") is True
assert matcher.is_ignored(tmp_path / "important.log") is False
def test_subdirectory_gitignore(self, tmp_path):
(tmp_path / ".gitignore").write_text("*.py\n")
sub_dir = tmp_path / "src"
sub_dir.mkdir()
(sub_dir / ".gitignore").write_text("test_*.py\n")
matcher = GitignoreMatcher(tmp_path)
assert matcher.is_ignored(tmp_path / "main.py") is True
assert matcher.is_ignored(tmp_path / "utils.py") is True
assert matcher.is_ignored(sub_dir / "helper.py") is True
assert matcher.is_ignored(sub_dir / "test_main.py") is True
def test_anchored_pattern(self, tmp_path):
(tmp_path / ".gitignore").write_text("/build\n")
matcher = GitignoreMatcher(tmp_path)
assert matcher.is_ignored(tmp_path / "build" / "output") is True
assert matcher.is_ignored(tmp_path / "src" / "build" / "out") is False
def test_ignore_directory_ending_slash(self, tmp_path):
(tmp_path / ".gitignore").write_text("node_modules/\n")
matcher = GitignoreMatcher(tmp_path)
assert matcher.is_ignored(tmp_path / "node_modules" / "package.json") is True
assert matcher.is_ignored(tmp_path / "other" / "node_modules" / "file") is True
def test_get_matcher_returns_none_when_no_gitignore(self, tmp_path):
result = get_gitignore_matcher(tmp_path / "nonexistent")
assert result is None
def test_get_matcher_returns_matcher(self, tmp_path):
(tmp_path / ".gitignore").write_text("*.pyc")
result = get_gitignore_matcher(tmp_path)
assert result is not None
assert result.is_ignored(tmp_path / "file.pyc") is True
def test_comments_ignored(self, tmp_path):
(tmp_path / ".gitignore").write_text(
"# this is a comment\n*.log\n # another comment\n"
)
matcher = GitignoreMatcher(tmp_path)
assert matcher.is_ignored(tmp_path / "debug.log") is True
def test_whitespace_lines(self, tmp_path):
(tmp_path / ".gitignore").write_text(" *.log\n")
matcher = GitignoreMatcher(tmp_path)
assert matcher.is_ignored(tmp_path / " debug.log") is True
assert matcher.is_ignored(tmp_path / "debug.log") is False
def test_double_star_pattern(self, tmp_path):
(tmp_path / ".gitignore").write_text("**/*.pyc\n")
matcher = GitignoreMatcher(tmp_path)
assert matcher.is_ignored(tmp_path / "a.pyc") is True
assert matcher.is_ignored(tmp_path / "src" / "a.pyc") is True
assert matcher.is_ignored(tmp_path / "src" / "deep" / "a.pyc") is True
def test_no_gitignore_file(self, tmp_path):
matcher = GitignoreMatcher(tmp_path)
assert matcher.is_ignored(tmp_path / "file.py") is False
assert matcher.is_ignored(tmp_path / "anything") is False