diff --git a/.gitignore b/.gitignore
index 284b0d2..719b928 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
+**/__pycache__/
 models/*
 uv.lock
diff --git a/DescriptionOfDesign.md b/DescriptionOfDesign.md
index a0ebbb8..aa0d136 100644
--- a/DescriptionOfDesign.md
+++ b/DescriptionOfDesign.md
@@ -63,19 +63,20 @@
 - **Python 优先**:原生 Python SDK,类型提示完善
 - **性能优秀**:基于 Apache Arrow,查询速度快
 
-### 2.3 为什么选择 langchain-text-splitters(而非 chonkie)
+### 2.3 为什么选择 tree-sitter
 
-最初设计考虑使用 `chonkie` 实现语法感知分块,但经过实践发现:
+使用 `tree-sitter` 实现 AST 级别的语义感知分块:
 
 | 方案 | 优势 | 劣势 |
 |------|------|------|
-| chonkie | 语法感知分块 | 需要额外下载 GPT-2 tokenizer,初始化慢 |
-| langchain-text-splitters | 无需额外模型、纯规则、速度快 | 非语法感知 |
+| tree-sitter | AST 感知、语义完整、不跨类/函数切割 | 依赖语言 parser |
+| langchain-text-splitters | 轻量、规则简单 | chunk_size 优先,会切割语义单元 |
 
-**最终选择 langchain-text-splitters**:
-- 对于 RAG 场景,分块精度不是最关键因素
-- langchain 的方案更加轻量,启动更快
-- 代码按自然段落(函数、类)边界切分,效果足够好
+**最终选择 tree-sitter**:
+- 代码以**类、函数**等完整语义单元为最小块
+- 超过 token 限制时,递归在子节点级别拆分
+- 支持 40+ 编程语言(通过 tree-sitter-languages)
+- 底层通过 ctypes 加载 parser,绕过语言绑定兼容性问题
 
 ### 2.4 为什么使用 Qwen3-Embedding-0.6B
 
@@ -187,7 +188,7 @@
 | 模块 | 职责 | 依赖 |
 |------|------|------|
 | `cli.py` | 命令行入口,参数解析 | click |
-| `chunker.py` | 代码分块处理 | langchain-text-splitters |
+| `chunker.py` | 代码分块处理 | tree-sitter, tree-sitter-languages |
 | `embedder.py` | 文本向量化 | sentence-transformers |
 | `db.py` | 向量数据库操作 | lancedb, pyarrow |
 | `commands/*.py` | 各命令实现 | cli, chunker, embedder, db |
@@ -199,12 +200,13 @@
 
 ### 4.1 代码分块算法
 
-**实现**:使用 langchain 的 `RecursiveCharacterTextSplitter`,支持 29 种编程语言的语法感知分块。
+**实现**:使用 tree-sitter 进行 AST 级别的语义分块,支持 40+ 编程语言。
 
-**分块参数**:
-- `chunk_size=2000`:每块约 2000 字符
-- `chunk_overlap=200`:块之间保留 200 字符重叠,保证上下文连贯
-- `separators`:按语言自动确定,按语义优先级(类 > 函数 > 空行段落 > 换行 > 空格 > 字符)递减
+**分块策略**:
+- **源码文件**:使用 tree-sitter 解析为 AST,按 `class_definition`、`function_definition` 等语义节点提取完整块
+- **无对应 parser 的文件**(如纯文本 `.txt`):使用 tiktoken 按行分块
+- **超大语义单元**(>4000 tokens):递归在子节点级别拆分,保证每个块不超过大小限制
+- **Parser 加载**:通过 ctypes 从 `languages.so` 直接加载,绕过语言绑定兼容性问题
 
 ### 4.2 向量化策略
diff --git a/README.md b/README.md
index 4244d5e..91b36fc 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,7 @@
 - 🚀 **零配置**:嵌入式数据库,安装即用
 - 🔒 **本地优先**:所有数据存储在本地,不上传到外部服务器
 - 🔍 **语义搜索**:通过自然语言查询相关代码片段
-- 📦 **多语言支持**:支持 Python、Rust、C++、Go、JavaScript、TypeScript 等 29 种编程语言
+- 📦 **多语言支持**:支持 Python、Rust、C++、Go、JavaScript、TypeScript 等 40+ 种编程语言
 - ⚡ **极速响应**:搜索延迟 < 100ms
 - 🤖 **AI 友好**:输出格式专为 AI 消费设计
@@ -48,17 +48,34 @@ cp opencode-skill/SKILL.md ~/.config/opencode/skills/ocrag/
 # 添加文件到知识库
 uv run ocrag add ./src/main.py
 
-# 递归添加整个目录
+# 递归添加整个目录(自动应用 .gitignore + 二进制文件过滤)
 uv run ocrag add ./src/ --recursive
 
+# 排除特定文件(--exclude 可多次指定)
+uv run ocrag add ./src/ --recursive --exclude "**/__pycache__" --exclude "**/test_*.py"
+
+# 只添加特定类型的文件(--include 可多次指定)
+uv run ocrag add ./src/ --recursive --include "*.py" --include "*.rs"
+
+# 禁用 .gitignore 过滤
+uv run ocrag add ./src/ --recursive --no-ignore
+
 # 搜索相关代码
 uv run ocrag search "如何实现用户认证"
 
 # 列出知识库中的文件
 uv run ocrag list
 
+# 使用通配符筛选列表
+uv run ocrag list "*.py"    # 仅列出 Python 文件
+uv run ocrag list "src/**"  # 仅列出 src 目录下的文件
+
 # 删除文件
 uv run ocrag remove ./src/main.py
+
+# 使用通配符批量删除
+uv run ocrag remove "**/test_*.py"   # 删除所有测试文件
+uv run ocrag remove "src/**/*.js"    # 删除 src 下所有 JS 文件
 ```
 
 ### OpenCode 集成
diff --git a/opencode-skill/SKILL.md b/opencode-skill/SKILL.md
index 00ad13c..33b78ca 100644
--- a/opencode-skill/SKILL.md
+++ b/opencode-skill/SKILL.md
@@ -23,12 +23,19 @@ description: 代码库 RAG 技能,用于向知识库添加代码或搜索已
 2. 获得返回的代码片段后,结合片段内容回答用户
 3. 如果结果不相关,可以告知用户未找到,并建议添加更多代码
 
+### 管理知识库
+- `rag_list` 可选传入 pattern 参数进行通配符筛选,如 `"*.py"` 或 `"src/**"`
+- `rag_remove` 支持通配符模式批量删除,如 `"**/test_*.py"` 删除所有测试文件
+
 ## 示例
 - 用户:"帮我记住这个文件的知识" → 调用 `rag_add` 传入当前文件路径
 - 用户:"认证模块是怎么实现的?" → 调用 `rag_search` query="认证模块实现"
 - 用户:"把 ./src 目录下所有 Python 文件加入知识库" → 调用 `rag_add` paths=["./src"] recursive=true
+- 用户:"知识库里有几个测试文件?" → 调用 `rag_list` pattern="**/test_*.py"
+- 用户:"把测试文件都删了" → 调用 `rag_remove` pattern="**/test_*.py"
 
 ## 注意事项
 - 搜索结果会显示代码片段的来源文件
 - 可以通过 top_k 参数调整返回结果数量
 - 添加文件后,搜索会立即包含新添加的内容
+- remove 和 list 支持 fnmatch 通配符:`*` 匹配任意字符(可跨越路径分隔符),`**` 与 `*` 等价
diff --git a/pyproject.toml b/pyproject.toml
index 5231c48..f0d1853 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,19 +1,29 @@
 [project]
-name = "project"
+name = "ocrag"
 version = "0.1.0"
-description = "Add your description here"
+description = "Local code knowledge base RAG plugin for OpenCode"
 readme = "README.md"
 requires-python = ">=3.12"
 dependencies = [
-    "chonkie>=1.6.2",
+    "tree-sitter>=0.21",
+    "tree-sitter-languages>=1.10",
     "click>=8.3.2",
     "lancedb>=0.30.2",
+    "pathspec>=0.12.0",
     "sentence-transformers>=5.4.1",
     "sentencepiece>=0.2.1",
     "tiktoken>=0.12.0",
     "tokenizers>=0.22.2",
+    "tqdm>=4.66.0",
 ]
+
+[project.scripts]
+ocrag = "ocrag.cli:main"
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
 
 [dependency-groups]
 dev = [
     "pytest>=9.0.3",
diff --git a/src/ocrag/chunker.py b/src/ocrag/chunker.py
index e720fb2..9d91abf 100644
--- a/src/ocrag/chunker.py
+++ b/src/ocrag/chunker.py
@@ -1,47 +1,478 @@
 """
 代码分块模块
 
-使用 langchain_text_splitters 进行语言感知分块。
-支持 29 种编程语言的语法感知分块,tiktoken 已内置,无需额外下载。
+使用 tree-sitter 进行语义感知分块。
+源码按类/函数等语义单元切分,保证每个 chunk 都是完整的语义块。
+非代码文件使用简单分块。
 """
 
+import ctypes
+import os
+import re
+import warnings
 from pathlib import Path
-from typing import List, Dict, Any
+
+warnings.filterwarnings(
+    "ignore",
+    message="int argument support is deprecated",
+    category=DeprecationWarning,
+)
+from typing import List, Dict, Any, Optional, Tuple
 
 try:
-    from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
+    import tree_sitter
+    from tree_sitter import Parser, Language, Node
 except ImportError:
-    raise ImportError(
-        "请安装 langchain-text-splitters: uv pip install langchain-text-splitters"
-    )
+    raise ImportError("请安装 tree-sitter: uv pip install tree-sitter")
 
 
 LANGUAGE_MAP = {
-    ".py": Language.PYTHON,
-    ".rs": Language.RUST,
-    ".cpp": Language.CPP,
-    ".cc": Language.CPP,
-    ".cxx": Language.CPP,
-    ".go": Language.GO,
-    ".java": Language.JAVA,
-    ".js": Language.JS,
-    ".ts": Language.TS,
-    ".jsx": Language.JS,
-    ".tsx": Language.TS,
-    ".c": Language.C,
-    ".h": Language.C,
-    ".hpp": Language.CPP,
-    ".rb": Language.RUBY,
-    ".swift": Language.SWIFT,
-    ".kt": Language.KOTLIN,
-    ".md": Language.MARKDOWN,
-    ".txt": None,
+    ".py": "python",
+    ".rs": "rust",
+    ".cpp": "cpp",
+    ".cc": "cpp",
+    ".cxx": "cpp",
+    ".go": "go",
+    ".java": "java",
+    ".js": "javascript",
+    ".ts": "typescript",
+    ".jsx": "javascript",
+    ".tsx": "typescript",
+    ".c": "c",
+    ".h": "cpp",
+    ".hpp": "cpp",
+    ".rb": "ruby",
+    ".swift": "swift",
+    ".kt": "kotlin",
+    ".sh": "bash",
+    ".bash": "bash",
+    ".zsh": "bash",
+    ".lua": "lua",
+    ".php": "php",
+    ".scala": "scala",
+    ".cs": "c_sharp",
+    ".r": "r",
+    ".R": "r",
+    ".sql": "sql",
+    ".toml": "toml",
+    ".yaml": "yaml",
+    ".yml": "yaml",
+    ".json": "json",
+    ".xml": "xml",
+    ".html": "html",
+    ".htm": "html",
+    ".css": "css",
+    ".scss": "css",
+    ".less": "css",
+    ".md": "markdown",
+    ".rst": "rst",
+    ".dockerfile": "dockerfile",
+    ".tf": "hcl",
+    ".hs": "haskell",
+    ".ex": "elixir",
+    ".exs": "elixir",
+    ".erl": "erlang",
+    ".ml": "ocaml",
+    ".mli": "ocaml",
+    ".fs": "fsharp",
+    ".fsx": "fsharp",
 }
+
+
+SEMANTIC_UNITS = {
+    "python": [
+        "function_definition",
+        "class_definition",
+        "async_function_definition",
+        "decorated_definition",
+    ],
+    "rust": [
+        "function_item",
+        "struct_item",
+        "impl_item",
+        "enum_item",
+        "trait_item",
+        "type_item",
+        "macro_invocation",
+        "macro_rule",
+    ],
+    "cpp": [
+        "class_specifier",
+        "struct_specifier",
+        "namespace_definition",
+        "function_definition",
+        "method_declaration",
+        "constructor_initializer",
+        "template_declaration",
+    ],
+    "c": [
+        "struct_specifier",
+        "union_specifier",
+        "enum_specifier",
+        "function_definition",
+        "type_definition",
+    ],
+    "go": [
+        "function_declaration",
+        "method_declaration",
+        "type_declaration",
+        "const_declaration",
+        "var_declaration",
+    ],
+    "java": [
+        "class_declaration",
+        "interface_declaration",
+        "enum_declaration",
+        "record_declaration",
+        "method_declaration",
+    ],
+    "javascript": [
+        "class_declaration",
+        "function_declaration",
+        "arrow_function",
+        "method_definition",
+    ],
+    "typescript": [
+        "class_declaration",
+        "function_declaration",
+        "arrow_function",
+        "method_definition",
+        "interface_declaration",
+        "type_alias_declaration",
+        "enum_declaration",
+    ],
+    "ruby": [
+        "class",
+        "module",
+        "def",
+        "method",
+    ],
+    "swift": [
+        "class_declaration",
+        "struct_declaration",
+        "enum_declaration",
+        "protocol_declaration",
+        "function_declaration",
+        "method_declaration",
+    ],
+    "kotlin": [
+        "class_declaration",
+        "object_declaration",
+        "function_declaration",
+        "method_declaration",
+    ],
+    "scala": [
+        "class_definition",
+        "object_definition",
+        "trait_definition",
+        "function_definition",
+        "method_definition",
+    ],
+    "c_sharp": [
+        "class_declaration",
+        "struct_declaration",
+        "interface_declaration",
+        "enum_declaration",
+        "method_declaration",
+    ],
+    "go_mod": ["import_declaration"],
+    "bash": [
+        "function_definition",
+        "command",
+    ],
+    "lua": [
+        "function_declaration",
+        "local_function",
+        "assignment_statement",
+    ],
+    "php": [
+        "class_declaration",
+        "interface_declaration",
+        "trait_declaration",
+        "function_definition",
+        "method_declaration",
+    ],
+    "sql": [
+        "create_table_statement",
+        "create_index_statement",
+        "create_view_statement",
+        "create_trigger_statement",
+        "procedure_definition",
+        "function_definition",
+    ],
+    "toml": ["table"],
+    "yaml": ["block_mapping"],
+    "json": ["pair", "array"],
+    "html": [
+        "element",
+        "script_element",
+        "style_element",
+    ],
+    "css": ["rule_set", "at_rule"],
+    "markdown": ["section", "fenced_code_block"],
+    "rst": ["section", "directive"],
+    "dockerfile": ["from_instruction", "run_instruction", "cmd_instruction"],
+    "hcl": ["block", "attribute"],
+    "haskell": [
+        "type_declaration",
+        "data_declaration",
+        "function_definition",
+        "class_declaration",
+        "instance_declaration",
+    ],
+    "elixir": [
+        "module_definition",
+        "function_definition",
+        "def",
+        "defp",
+    ],
+    "erlang": [
+        "function_definition",
+        "module_definition",
+    ],
+    "ocaml": [
+        "value_definition",
+        "type_definition",
+        "module_definition",
+        "class_definition",
+    ],
+    "fsharp": [
+        "class_definition",
+        "module_definition",
+        "function_definition",
+        "type_definition",
+    ],
+    "r": [
+        "function_definition",
+        "assignment",
+    ],
+}
 
 
-def detect_language(file_path: str) -> Language | None:
-    """检测文件语言类型"""
-    ext = Path(file_path).suffix.lower()
-    return LANGUAGE_MAP.get(ext)
+MAX_CHUNK_CHARS = 8000
+MIN_CHUNK_CHARS = 100
+
+
+class TiktokenCounter:
+    _instance = None
+
+    def __new__(cls):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+            cls._instance._initialized = False
+        return cls._instance
+
+    def __init__(self):
+        if self._initialized:
+            return
+        import tiktoken
+
+        self._enc = tiktoken.encoding_for_model("gpt-4o")
+        self._initialized = True
+
+    def count(self, text: str) -> int:
+        return len(self._enc.encode(text))
+
+
+_token_counter = TiktokenCounter()
+
+
+def _find_languages_so() -> str:
+    import site
+
+    for path in site.getsitepackages() + [site.getusersitepackages()]:
+        so_path = os.path.join(path, "tree_sitter_languages", "languages.so")
+        if os.path.exists(so_path):
+            return so_path
+    for path in site.getsitepackages():
+        if os.path.exists(os.path.join(path, "tree_sitter_languages")):
+            return os.path.join(path, "tree_sitter_languages", "languages.so")
+    cache_dir = os.path.expanduser("~/.cache/uv/archive-v0")
+    if os.path.exists(cache_dir):
+        for entry in os.listdir(cache_dir):
+            entry_path = os.path.join(
+                cache_dir, entry, "tree_sitter_languages", "languages.so"
+            )
+            if os.path.exists(entry_path):
+                return entry_path
+    raise FileNotFoundError(
+        "无法找到 tree-sitter-languages 的 languages.so。"
+        " 请确保已安装 tree-sitter-languages。"
+    )
+
+
+_lib = None
+
+
+def _get_lib() -> ctypes.CDLL:
+    global _lib
+    if _lib is None:
+        so_path = _find_languages_so()
+        _lib = ctypes.CDLL(so_path)
+    return _lib
+
+
+_parsers_cache: Dict[str, Parser] = {}
+
+
+def get_parser(language_name: str) -> Tuple[Parser, Language]:
+    if language_name in _parsers_cache:
+        return _parsers_cache[language_name], _parsers_cache[f"_lang_{language_name}"]
+
+    lib = _get_lib()
+    symbol_name = f"tree_sitter_{language_name}"
+    if not hasattr(lib, symbol_name):
+        raise ValueError(f"不支持的语言: {language_name}")
+
+    func = getattr(lib, symbol_name)
+    func.argtypes = []
+    func.restype = ctypes.c_void_p
+    ptr = func()
+    lang = Language(ptr)  # type: ignore[call-arg]
+    parser = Parser(lang)
+    _parsers_cache[language_name] = parser
+    _parsers_cache[f"_lang_{language_name}"] = lang
+    return parser, lang
+
+
+def _is_semantic_unit(node: Node, semantic_types: List[str]) -> bool:
+    return node.type in semantic_types
+
+
+def _extract_semantic_chunks(
+    node: Node,
+    semantic_types: List[str],
+    source_bytes: bytes,
+    parent_context: str = "",
+) -> List[Tuple[str, str, int, int]]:
+    chunks = []
+    node_type = node.type
+
+    if _is_semantic_unit(node, semantic_types):
+        text = source_bytes[node.start_byte : node.end_byte].decode(
+            "utf-8", errors="replace"
+        )
+        token_count = _token_counter.count(text)
+
+        if token_count > 4000:
+            children_chunks = _split_large_node(
+                node, semantic_types, source_bytes, parent_context
+            )
+            if children_chunks:
+                chunks.extend(children_chunks)
+            else:
+                chunks.append((node_type, text, node.start_byte, node.end_byte))
+        else:
+            chunks.append((node_type, text, node.start_byte, node.end_byte))
+    else:
+        for child in node.children:
+            chunks.extend(
+                _extract_semantic_chunks(
+                    child, semantic_types, source_bytes, parent_context
+                )
+            )
+
+    return chunks
+
+
+def _split_large_node(
+    node: Node,
+    semantic_types: List[str],
+    source_bytes: bytes,
+    parent_context: str,
+) -> List[Tuple[str, str, int, int]]:
+    children_with_semantic = [
+        (child, child.type in semantic_types)
+        for child in node.children
+        if child.type not in ("comment", "preproc_directive", "attribute")
+    ]
+
+    if not children_with_semantic:
+        return []
+
+    results = []
+    current_text_parts = []
+    current_start = children_with_semantic[0][0].start_byte
+
+    for child, is_semantic in children_with_semantic:
+        if is_semantic:
+            if current_text_parts:
+                combined = source_bytes[
+                    current_start : current_text_parts[0][0].start_byte
+                ].decode("utf-8", errors="replace") + "".join(
+                    source_bytes[p[0].start_byte : p[0].end_byte].decode(
+                        "utf-8", errors="replace"
+                    )
+                    for p in current_text_parts
+                )
+                if _token_counter.count(combined) >= MIN_CHUNK_CHARS // 4:
+                    results.append(
+                        (
+                            "sub_chunk",
+                            combined,
+                            current_start,
+                            current_text_parts[0][0].start_byte,
+                        )
+                    )
+                current_text_parts = []
+            current_start = child.start_byte
+
+            child_text = source_bytes[child.start_byte : child.end_byte].decode(
+                "utf-8", errors="replace"
+            )
+            if _token_counter.count(child_text) <= 4000:
+                results.append(
+                    (child.type, child_text, child.start_byte, child.end_byte)
+                )
+            else:
+                nested = _split_large_node(
+                    child, semantic_types, source_bytes, parent_context
+                )
+                results.extend(nested)
+        else:
+            current_text_parts.append((child, child.type))
+
+    if current_text_parts:
+        combined = source_bytes[
+            current_start : current_text_parts[0][0].start_byte
+        ].decode("utf-8", errors="replace") + "".join(
+            source_bytes[p[0].start_byte : p[0].end_byte].decode(
+                "utf-8", errors="replace"
+            )
+            for p in current_text_parts
+        )
+        if _token_counter.count(combined) >= MIN_CHUNK_CHARS // 4:
+            results.append(
+                (
+                    "sub_chunk",
+                    combined,
+                    current_start,
+                    current_text_parts[0][0].start_byte,
+                )
+            )
+
+    return results
+
+
+def _simple_chunk_text(content: str, chunk_size: int = 512) -> List[str]:
+    import tiktoken
+
+    enc = tiktoken.encoding_for_model("gpt-4o")
+    lines = content.split("\n")
+    chunks = []
+    current = []
+    current_tokens = 0
+
+    for line in lines:
+        line_tokens = len(enc.encode(line))
+        if current_tokens + line_tokens > chunk_size and current:
+            chunks.append("\n".join(current))
+            current = [line]
+            current_tokens = line_tokens
+        else:
+            current.append(line)
+            current_tokens += line_tokens
+
+    if current:
+        chunks.append("\n".join(current))
+    return chunks
 
 
 def chunk_code(
@@ -50,41 +481,32 @@
     chunk_size: int = 512,
     chunk_overlap: int = 50,
 ) -> List[Dict[str, Any]]:
-    """将代码文件分块,返回块列表
+    ext = Path(file_path).suffix.lower()
+    lang_name = LANGUAGE_MAP.get(ext)
 
-    使用 langchain_text_splitters 进行语言感知分块。
-    支持 29 种编程语言的语法感知分块。
-
-    Args:
-        content: 文件内容
-        file_path: 文件路径
-        chunk_size: 分块大小(token 数)
-        chunk_overlap: 重叠大小(token 数)
-
-    Returns:
-        分块列表,每个块包含 text 和 metadata
-    """
-    language = detect_language(file_path)
-
-    if language is None:
-        # 不支持的语言,使用简单分块
-        splitter = RecursiveCharacterTextSplitter(
-            chunk_size=chunk_size,
-            chunk_overlap=chunk_overlap,
-            separators=["\n\n", "\n", " ", ""],
-            length_function=len,
-        )
-        texts = splitter.split_text(content)
-        language_name = "text"
+    if not lang_name:
+        texts = _simple_chunk_text(content, chunk_size)
+        lang_display = "text"
     else:
-        # 使用语言感知分块
-        splitter = RecursiveCharacterTextSplitter.from_language(
-            language=language,
-            chunk_size=chunk_size,
-            chunk_overlap=chunk_overlap,
-        )
-        texts = splitter.split_text(content)
-        language_name = language.value
+        semantic_types = SEMANTIC_UNITS.get(lang_name, [])
+        source_bytes = content.encode("utf-8")
+
+        try:
+            parser, _ = get_parser(lang_name)
+        except Exception:
+            texts = _simple_chunk_text(content, chunk_size)
+            lang_display = lang_name
+        else:
+            tree = parser.parse(source_bytes)
+            raw_chunks = _extract_semantic_chunks(
+                tree.root_node, semantic_types, source_bytes
+            )
+
+            if not raw_chunks:
+                texts = _simple_chunk_text(content, chunk_size)
+            else:
+                texts = [text for _, text, _, _ in raw_chunks]
+            lang_display = lang_name
 
     return [
         {
@@ -92,8 +514,13 @@
             "metadata": {
                 "source_file": file_path,
                 "chunk_index": i,
-                "language": language_name,
+                "language": lang_display,
             },
         }
         for i, text in enumerate(texts)
     ]
+
+
+def detect_language(file_path: str) -> Optional[str]:
+    ext = Path(file_path).suffix.lower()
+    return LANGUAGE_MAP.get(ext)
diff --git a/src/ocrag/cli.py b/src/ocrag/cli.py
index 8c5826e..842e017 100644
--- a/src/ocrag/cli.py
+++ b/src/ocrag/cli.py
@@ -1,4 +1,5 @@
 import click
+from pathlib import Path
 from ocrag.commands import add as add_cmd
 from ocrag.commands import search as search_cmd
 from ocrag.commands import remove as remove_cmd
@@ -13,11 +14,31 @@ def main():
 
 
 @main.command()
-@click.argument("paths", nargs=-1, required=True, type=click.Path(exists=True))
+@click.argument("paths", nargs=-1, required=True)
 @click.option("--recursive", "-r", is_flag=True, help="递归处理目录")
-def add(paths, recursive):
-    """向知识库添加文件或目录"""
-    add_cmd(paths, recursive)
+@click.option(
+    "--include", "-i", multiple=True, help="只添加匹配的文件(glob 模式,可多次指定)"
+)
+@click.option(
+    "--exclude", "-e", multiple=True, help="排除匹配的文件(glob 模式,可多次指定)"
+)
+@click.option("--no-ignore", is_flag=True, help="禁用 .gitignore 过滤")
+def add(paths, recursive, include, exclude, no_ignore):
+    """向知识库添加文件或目录
+
+    默认会自动读取 .gitignore 规则并跳过匹配的文件。
+    目录遍历时自动跳过 .git、__pycache__、node_modules 等目录。
+    """
+    for p in paths:
+        if not Path(p).exists():
+            raise click.BadParameter(f"路径不存在: {p}")
+    add_cmd(
+        paths,
+        recursive,
+        include_patterns=list(include) if include else None,
+        exclude_patterns=list(exclude) if exclude else None,
+        use_gitignore=not no_ignore,
+    )
 
 
 @main.command()
@@ -29,16 +50,17 @@ def search(query, top_k):
 
 
 @main.command()
-@click.argument("path", type=click.Path())
-def remove(path):
-    """从知识库中移除指定文件"""
-    remove_cmd(path)
+@click.argument("pattern")
+def remove(pattern):
+    """从知识库中移除匹配模式的所有文件(支持通配符)"""
+    remove_cmd(pattern)
 
 
-@main.command()
-def list():
-    """列出知识库中的所有条目"""
-    list_cmd()
+@main.command("list")
+@click.argument("pattern", required=False)
+def list_files(pattern):
+    """列出知识库中的条目(可选通配符筛选)"""
+    list_cmd(pattern)
 
 
 if __name__ == "__main__":
diff --git a/src/ocrag/commands/add.py b/src/ocrag/commands/add.py
index d5d9f99..2069c18 100644
--- a/src/ocrag/commands/add.py
+++ b/src/ocrag/commands/add.py
@@ -1,56 +1,143 @@
 import os
+import pathspec
 from pathlib import Path
+from typing import List
+from tqdm import tqdm
 from ocrag.chunker import chunk_code
 from ocrag.embedder import embedder
 from ocrag.db import RagDB
+from ocrag.utils import is_text_file
+from ocrag.gitignore import get_gitignore_matcher
 
 
-def collect_files(paths, recursive):
-    """收集所有需要处理的文件"""
+def collect_files(
+    paths,
+    recursive: bool,
+    include_patterns: List[str] = None,
+    exclude_patterns: List[str] = None,
+    use_gitignore: bool = True,
+) -> List[Path]:
+    """收集所有需要处理的文件
+
+    Args:
+        paths: 文件或目录路径列表
+        recursive: 是否递归处理目录
+        include_patterns: 只处理匹配这些 glob 模式的文件
+        exclude_patterns: 跳过匹配这些 glob 模式的文件
+        use_gitignore: 是否应用 .gitignore 规则
+    """
     files = []
+    seen = set()
+
+    exclude_spec = None
+    if exclude_patterns:
+        exclude_spec = pathspec.PathSpec.from_lines("gitignore", exclude_patterns)
+
+    include_spec = None
+    if include_patterns:
+        include_spec = pathspec.PathSpec.from_lines("gitignore", include_patterns)
+
     for p in paths:
-        p = Path(p)
+        p = Path(p).resolve()
         if p.is_file():
             files.append(p)
         elif p.is_dir() and recursive:
-            for root, dirs, filenames in os.walk(p):
-                for f in filenames:
-                    files.append(Path(root) / f)
+            matcher = get_gitignore_matcher(p) if use_gitignore else None
+
+            for root_dir, dirs, filenames in os.walk(p):
+                root_dir = Path(root_dir)
+
+                dirs[:] = [
+                    d
+                    for d in dirs
+                    if d
+                    not in (
+                        ".git",
+                        "__pycache__",
+                        ".pytest_cache",
+                        ".mypy_cache",
+                        "node_modules",
+                        ".venv",
+                        "venv",
+                        ".env",
+                        ".tox",
+                        ".eggs",
+                    )
+                    and not d.endswith(".egg-info")
+                ]
+
+                for filename in filenames:
+                    file_path = root_dir / filename
+
+                    if file_path in seen:
+                        continue
+                    seen.add(file_path)
+
+                    rel = file_path.relative_to(p)
+
+                    if not is_text_file(str(file_path)):
+                        continue
+
+                    if matcher and matcher.is_ignored(file_path):
+                        continue
+
+                    if exclude_spec and exclude_spec.match_file(str(rel)):
+                        continue
+
+                    if include_spec and not include_spec.match_file(str(rel)):
+                        continue
+
+                    files.append(file_path)
         elif p.is_dir() and not recursive:
             print(f"跳过目录 {p}(未使用 --recursive)")
+
     return files
 
 
-def run(paths, recursive):
+def run(
+    paths,
+    recursive: bool,
+    include_patterns: List[str] = None,
+    exclude_patterns: List[str] = None,
+    use_gitignore: bool = True,
+):
     db = RagDB()
-    files = collect_files(paths, recursive)
+    files = collect_files(
+        paths, recursive, include_patterns, exclude_patterns, use_gitignore
+    )
+
+    if not files:
+        print("没有找到需要处理的文件")
+        return
 
     total_chunks = 0
-    for file_path in files:
-        try:
-            content = file_path.read_text(encoding="utf-8")
-            chunks = chunk_code(content, str(file_path))
-            if not chunks:
-                continue
+    with tqdm(total=len(files), desc="处理文件", unit="file") as pbar:
+        for file_path in files:
+            try:
+                content = file_path.read_text(encoding="utf-8")
+                chunks = chunk_code(content, str(file_path))
+                if not chunks:
+                    pbar.update(1)
+                    continue
 
-            # 批量 embedding
-            texts = [c["text"] for c in chunks]
-            vectors = embedder.embed(texts)
+                texts = [c["text"] for c in chunks]
+                vectors = embedder.embed(texts)
 
-            documents = []
-            for chunk, vec in zip(chunks, vectors):
-                documents.append(
-                    {
-                        "text": chunk["text"],
-                        "vector": vec,
-                        "metadata": chunk["metadata"],
-                    }
-                )
+                documents = []
+                for chunk, vec in zip(chunks, vectors):
+                    documents.append(
+                        {
+                            "text": chunk["text"],
+                            "vector": vec,
+                            "metadata": chunk["metadata"],
+                        }
+                    )
 
-            db.add_documents(documents)
-            total_chunks += len(documents)
-            print(f"✅ {file_path} -> {len(documents)} 个块")
-        except Exception as e:
-            print(f"❌ 处理 {file_path} 失败: {e}")
+                db.add_documents(documents)
+                total_chunks += len(documents)
+                pbar.set_postfix({"块": f"{total_chunks}"})
+            except Exception as e:
+                print(f"\n❌ 处理 {file_path} 失败: {e}")
+            pbar.update(1)
 
     print(f"\n📦 总计添加 {total_chunks} 个块")
diff --git a/src/ocrag/commands/list.py b/src/ocrag/commands/list.py
index dd18f29..dd0049a 100644
--- a/src/ocrag/commands/list.py
+++ b/src/ocrag/commands/list.py
@@ -1,14 +1,15 @@
 from ocrag.db import RagDB
 
 
-def run():
+def run(pattern: str = None):
     db = RagDB()
-    sources = db.list_sources()
+    sources = db.list_sources(pattern)
 
     if not sources:
         print("知识库为空")
         return
 
-    print("知识库中的文件:")
+    label = f'匹配 "{pattern}" 的' if pattern else ""
+    print(f"{label}知识库中的文件:")
     for i, source in enumerate(sources, 1):
         print(f"{i}. {source}")
diff --git a/src/ocrag/commands/remove.py b/src/ocrag/commands/remove.py
index 992179e..f1f7374 100644
--- a/src/ocrag/commands/remove.py
+++ b/src/ocrag/commands/remove.py
@@ -1,9 +1,7 @@
 from ocrag.db import RagDB
 
 
-def run(path: str):
+def run(pattern: str):
     db = RagDB()
-    # 注意:当前db.py中的delete_by_source方法尚未完全实现
-    # 这里先调用以保持接口一致,后续需要完善db.py的实现
-    db.delete_by_source(path)
-    print(f"已删除 {path} 的所有块")
+    count = db.delete_by_pattern(pattern)
+    print(f'已删除 {count} 个匹配 "{pattern}" 的块')
diff --git a/src/ocrag/db.py b/src/ocrag/db.py
index a920441..31b08d8 100644
--- a/src/ocrag/db.py
+++ b/src/ocrag/db.py
@@ -8,43 +8,51 @@ DB_PATH = Path.home() / ".ocrag" / "data.lance"
 
 class RagDB:
     def __init__(self, db_path: str = None):
-        self.path = db_path or str(DB_PATH)
+        import os as _os
+
+        self.path = db_path or _os.environ.get("OCRAG_DB_PATH", str(DB_PATH))
         self.conn = lancedb.connect(self.path)
         self._init_table()
 
     def _init_table(self):
         """如果表不存在则创建"""
-        schema = pa.schema(
-            [
-                ("text", pa.string()),
-                (
-                    "vector",
-                    pa.list_(pa.float32(), 1024),
-                ),  # 1024维 (Qwen3-Embedding-0.6B)
-                ("metadata", pa.string()),  # JSON 字符串
-            ]
-        )
         try:
             self.table = self.conn.open_table("documents")
         except:
-            self.table = self.conn.create_table("documents", schema=schema)
+            empty = pa.table(
+                {
+                    "text": pa.array([], type=pa.string()),
+                    "vector": pa.array([], type=pa.list_(pa.float32(), 1024)),
+                    "metadata": pa.array([], type=pa.string()),
+                }
+            )
+            self.table = self.conn.create_table("documents", empty)
 
     def add_documents(self, documents: List[Dict[str, Any]]):
         """批量添加文档
 
         documents: [{"text": str, "vector": List[float], "metadata": dict}, ...]
         """
         import json
+        import numpy as np
 
-        data = []
+        texts = []
+        vectors = []
+        metadatas = []
         for doc in documents:
-            data.append(
-                {
-                    "text": doc["text"],
-                    "vector": doc["vector"],
-                    "metadata": json.dumps(doc.get("metadata", {})),
-                }
-            )
-        self.table.add(data)
+            texts.append(doc["text"])
+            vec = doc["vector"]
+            if isinstance(vec, list):
+                vec = np.array(vec, dtype=np.float32)
+            vectors.append(vec)
+            metadatas.append(json.dumps(doc.get("metadata", {})))
+
+        arrays = [
+            pa.array(texts),
+            pa.array(vectors, type=pa.list_(pa.float32(), 1024)),
+            pa.array(metadatas),
+        ]
+        batch = pa.record_batch(arrays, names=["text", "vector", "metadata"])
+        self.table.add(batch)
 
     def search(self, query_vector: List[float], top_k: int = 5) -> List[Dict[str, Any]]:
         import json
@@ -96,9 +104,50 @@
 
         return num_deleted
 
-    def list_sources(self) -> List[str]:
-        """列出所有已添加的源文件路径(去重)"""
-        # 获取所有 metadata,提取 source_file
+    def delete_by_pattern(self, pattern: str) -> int:
+        """删除匹配通配符模式的所有源文件的所有块
+
+        Args:
+            pattern: fnmatch 模式,如 "*.py", "src/**/*.js", "**/test_*.py"
+
+        Returns:
+            删除的块数量
+        """
+        import fnmatch
+        import json
+
+        df = self.table.to_pandas()
+        if df.empty:
+            return 0
+
+        indices_to_keep = []
+        for idx, row in df.iterrows():
+            try:
+                meta = json.loads(row["metadata"])
+                source = meta.get("source_file", "")
+                if not fnmatch.fnmatch(source, pattern):
+                    indices_to_keep.append(idx)
+            except (json.JSONDecodeError, KeyError):
+                indices_to_keep.append(idx)
+
+        num_deleted = len(df) - len(indices_to_keep)
+        if num_deleted == 0:
+            return 0
+
+        df_remaining = df.loc[indices_to_keep]
+        self.conn.drop_table("documents")
+        self.table = self.conn.create_table("documents", df_remaining)
+
+        return num_deleted
+
+    def list_sources(self, pattern: str = None) -> List[str]:
+        """列出所有已添加的源文件路径(去重)
+
+        Args:
+            pattern: 可选的 fnmatch 模式,筛选匹配的文件
+        """
+        import fnmatch
+
         df = self.table.to_pandas()
         if df.empty:
             return []
@@ -108,5 +157,6 @@
         for meta_str in df["metadata"]:
             meta = json.loads(meta_str)
             if "source_file" in meta:
-                sources.add(meta["source_file"])
+                if pattern is None or fnmatch.fnmatch(meta["source_file"], pattern):
+                    sources.add(meta["source_file"])
         return sorted(sources)
diff --git a/src/ocrag/gitignore.py b/src/ocrag/gitignore.py
new file mode 100644
index 0000000..ff49a13
--- /dev/null
+++ b/src/ocrag/gitignore.py
@@ -0,0 +1,72 @@
+import os
+import pathspec
+from pathlib import Path
+from typing import List, Tuple, Optional
+
+
+class GitignoreMatcher:
+    """多层 .gitignore 解析器
+
+    支持从项目根目录向下遍历,收集所有 .gitignore 文件,
+    每个 .gitignore 中的规则只对该文件所在目录及其子目录生效。
+    """
+
+    def __init__(self, root: Path):
+        self.root = root.resolve()
+        self._specs: List[Tuple[Path, pathspec.PathSpec]] = []
+        self._load()
+
+    def _load(self):
+        """遍历目录树,收集所有 .gitignore 并构建 pathspec"""
+        for dirpath, dirnames, filenames in os.walk(self.root):
+            dirpath = Path(dirpath)
+            gi_path = dirpath / ".gitignore"
+            if gi_path.exists() and gi_path.is_file():
+                try:
+                    lines = gi_path.read_text(
+                        encoding="utf-8", errors="ignore"
+                    ).splitlines()
+                    spec = pathspec.PathSpec.from_lines("gitignore", lines)
+                    self._specs.append((dirpath, spec))
+                except Exception:
+                    pass
+
+            dirnames[:] = [d for d in dirnames if d not in (".git", "__pycache__")]
+
+    def is_ignored(self, file_path: Path) -> bool:
+        """判断文件是否被任一层 .gitignore 忽略"""
+        abs_path = file_path.resolve()
+
+        for dirpath, spec in self._specs:
+            try:
+                rel = abs_path.relative_to(dirpath)
+            except ValueError:
+                continue
+
+            if spec.match_file(str(rel)):
+                return True
+
+        return False
+
+
+def find_project_root(start_path: Path) -> Optional[Path]:
+    """从 start_path 向上查找包含 .gitignore 的目录作为项目根"""
+    current = start_path.resolve()
+    if start_path.is_file():
+        current = current.parent
+
+    while True:
+        if (current / ".gitignore").exists():
+            return current
+        parent = current.parent
+        if parent == current:
+            return None
+        current = parent
+
+
+def get_gitignore_matcher(root: Path) -> Optional[GitignoreMatcher]:
+    """为给定目录创建 GitignoreMatcher"""
+    gi_path = find_project_root(root)
+    if gi_path is None:
+        return None
+    return GitignoreMatcher(gi_path)
diff --git a/src/ocrag/utils.py b/src/ocrag/utils.py
index 41f0df8..77e103d 100644
--- a/src/ocrag/utils.py
+++ b/src/ocrag/utils.py
@@ -33,44 +33,23 @@ def format_size(size_bytes: int) -> str:
 
 
 def is_text_file(file_path: str) -> bool:
-    """检查文件是否为文本文件"""
-    text_extensions = {
-        ".py",
-        ".js",
-        ".ts",
-        ".jsx",
-        ".tsx",
-        ".java",
-        ".c",
-        ".cpp",
-        ".h",
-        ".hpp",
-        ".go",
-        ".rs",
-        ".rb",
-        ".php",
-        ".swift",
-        ".kt",
-        ".scala",
-        ".md",
-        ".txt",
-        ".json",
-        ".yaml",
-        ".yml",
-        ".xml",
-        ".html",
-        ".css",
-        ".sql",
-        ".sh",
-        ".bash",
-        ".zsh",
-        ".ps1",
-        ".bat",
-        ".vue",
-        ".svelte",
-    }
-    ext = Path(file_path).suffix.lower()
-    return ext in text_extensions
+    """检查文件是否为文本文件
+
+    通过读取文件前 1KB 内容检测:
+    - 包含 null byte (\\x00) 则判定为二进制
+    - 能够以 UTF-8 解码则判定为文本
+    """
+    try:
+        with open(file_path, "rb") as f:
+            chunk = f.read(1024)
+            if not chunk:
+                return True
+            if b"\x00" in chunk:
+                return False
+            chunk.decode("utf-8")
+            return True
+    except (UnicodeDecodeError, OSError):
+        return False
 
 
 def get_project_root() -> Path:
diff --git a/tests/unit/test_commands.py b/tests/unit/test_commands.py
index 5eef252..fe986a7 100644
--- a/tests/unit/test_commands.py
+++ b/tests/unit/test_commands.py
@@ -3,6 +3,7 @@ import os
 from click.testing import CliRunner
 from ocrag.cli import main
 from ocrag.db import RagDB
+from ocrag.commands.add import collect_files
 
 
 @pytest.fixture
@@ -12,17 +13,15 @@ def runner():
 
 
 @pytest.fixture
 def setup_test_env(tmpdir):
-    # Create test files
     os.makedirs(os.path.join(tmpdir, "test_dir"))
     with open(os.path.join(tmpdir, "test_dir", "file1.py"), "w") as f:
         f.write("def test_func():\n    pass")
     with open(os.path.join(tmpdir, "test_file.py"), "w") as f:
         f.write("print('Hello')")
-    return tmpdir
+    return str(tmpdir)
 
 
 def test_add_command(runner, setup_test_env, tmpdir):
-    # Use temporary DB
     db_path = os.path.join(tmpdir, "test_db.lance")
     os.environ["OCRAG_DB_PATH"] = db_path
 
@@ -33,7 +32,6 @@ def test_add_command(runner, setup_test_env, tmpdir):
 
 
 def test_search_command(runner, setup_test_env, tmpdir):
-    # Use temporary DB and add a file first
     db_path = os.path.join(tmpdir, "test_db.lance")
     os.environ["OCRAG_DB_PATH"] = db_path
 
@@ -42,14 +40,12 @@ def test_search_command(runner, setup_test_env, tmpdir):
     )
     assert add_result.exit_code == 0
 
-    # Search
     result = runner.invoke(main, ["search", "Hello"])
     assert "Hello" in result.output
     assert "来源: " in result.output
 
 
 def test_list_command(runner, setup_test_env, tmpdir):
-    # Use temporary DB and add a file first
     db_path = os.path.join(tmpdir, "test_db.lance")
     os.environ["OCRAG_DB_PATH"] = db_path
 
@@ -58,6 +54,104 @@ def test_list_command(runner, setup_test_env, tmpdir):
     )
     assert add_result.exit_code == 0
 
-    # List
     result = runner.invoke(main, ["list"])
     assert "test_file.py" in result.output
+
+
+def test_remove_wildcard(runner, setup_test_env, tmpdir):
+    db_path = os.path.join(tmpdir, "test_db.lance")
+    os.environ["OCRAG_DB_PATH"] = db_path
+
+    runner.invoke(main, ["add", os.path.join(setup_test_env, "test_file.py")])
+    runner.invoke(main, ["add", os.path.join(setup_test_env, "test_dir", "file1.py")])
+
+    result = runner.invoke(main, ["remove", "*.py"])
+    assert result.exit_code == 0
+    assert "已删除" in result.output
+
+    list_result = runner.invoke(main, ["list"])
+    assert "test_file.py" not in list_result.output
+
+
+def test_list_wildcard(runner, setup_test_env, tmpdir):
+    db_path = os.path.join(tmpdir, "test_db.lance")
+    os.environ["OCRAG_DB_PATH"] = db_path
+
+    runner.invoke(main, ["add", os.path.join(setup_test_env, "test_file.py")])
+    runner.invoke(main, ["add", os.path.join(setup_test_env, "test_dir", "file1.py")])
+
+    result = runner.invoke(main, ["list", "*.py"])
+    assert result.exit_code == 0
+    assert "test_file.py" in result.output
+    assert "file1.py" in result.output
+
+
+class TestCollectFiles:
+    def test_exclude_pattern(self, setup_test_env):
+        files = collect_files(
+            [setup_test_env], recursive=True, exclude_patterns=["**/test_*.py"]
+        )
+        names = [f.name for f in files]
+        assert "test_file.py" not in names
+        assert "file1.py" in names
+
+    def test_include_pattern(self, setup_test_env):
+        files = collect_files(
+            [setup_test_env], recursive=True, include_patterns=["*.py"]
+        )
+        names = [f.name for f in files]
+        assert "test_file.py" in names
+        assert "file1.py" in names
+
+    def test_binary_skip(self, tmpdir):
+        with open(os.path.join(tmpdir, "binary.bin"), "wb") as f:
+            f.write(b"\x00\x01\x02binary")
+        with open(os.path.join(tmpdir, "text.py"), "w") as f:
+            f.write("print('hello')")
+
+        files = collect_files([str(tmpdir)], recursive=True)
+        names = [f.name for f in files]
+        assert "binary.bin" not in names
+        assert "text.py" in names
+
+    def test_gitignore(self, tmpdir):
+        os.makedirs(os.path.join(tmpdir, "src"))
+        with open(os.path.join(tmpdir, "src", "main.py"), "w") as f:
+            f.write("print('hello')")
+        with open(os.path.join(tmpdir, "src", "test_main.py"), "w") as f:
+            f.write("def test(): pass")
+        with open(os.path.join(tmpdir, ".gitignore"), "w") as f:
+            f.write("**/test_*.py\n")
+
+        files = collect_files([os.path.join(tmpdir, "src")], recursive=True)
+        names = [f.name for f in files]
+        assert "main.py" in names
+        assert "test_main.py" not in names
+
+    def test_no_ignore_disables_gitignore(self, setup_test_env):
+        with open(os.path.join(setup_test_env, ".gitignore"), "w") as f:
+            f.write("test_file.py\n")
+
+        files = collect_files([setup_test_env], recursive=True, use_gitignore=False)
+        names = [f.name for f in files]
+        assert "test_file.py" in names
+
+    def test_combined_filters(self, tmpdir):
+        os.makedirs(os.path.join(tmpdir, "lib"))
+        with open(os.path.join(tmpdir, "lib", "main.py"), "w") as f:
+            f.write("def main(): pass")
+        with open(os.path.join(tmpdir, "lib", "test_main.py"), "w") as f:
+            f.write("def test(): pass")
+        with open(os.path.join(tmpdir, "lib", "data.bin"), "wb") as f:
+            f.write(b"\x00binary")
+
+        files = collect_files(
+            [os.path.join(tmpdir, "lib")],
+            recursive=True,
+            include_patterns=["*.py"],
+            exclude_patterns=["**/test_*.py"],
+        )
+        names = [f.name for f in files]
+        assert "main.py" in names
+        assert "test_main.py" not in names
+        assert "data.bin" not in names
diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py
new file mode 100644
index 0000000..9944432
--- /dev/null
+++ b/tests/unit/test_utils.py
@@ -0,0 +1,126 @@
+import os
+import pytest
+from pathlib import Path
+from ocrag.gitignore import GitignoreMatcher, get_gitignore_matcher, find_project_root
+from ocrag.utils import is_text_file
+
+
+class TestIsTextFile:
+    def test_python_file_is_text(self, tmp_path):
+        f = tmp_path / "test.py"
+        f.write_text("def foo():\n    pass")
+        assert is_text_file(str(f)) is True
+
+    def test_binary_file_is_not_text(self, tmp_path):
+        f = tmp_path / "test.bin"
+        f.write_bytes(b"\x00\x01\x02\x03binary")
+        assert is_text_file(str(f)) is False
+
+    def test_empty_file_is_text(self, tmp_path):
+        f = tmp_path / "empty.txt"
+        f.write_text("")
+        assert is_text_file(str(f)) is True
+
+    def test_utf8_with_non_ascii_is_text(self, tmp_path):
+        f = tmp_path / "unicode.txt"
+        f.write_text("中文内容测试\n日本語\n한국어")
+        assert is_text_file(str(f)) is True
+
+    def test_invalid_utf8_is_not_text(self, tmp_path):
+        f = tmp_path / "invalid.txt"
+        f.write_bytes(b"hello\x80world")
+        assert is_text_file(str(f)) is False
+
+
+class TestFindProjectRoot:
+    def test_finds_gitignore(self, tmp_path):
+        (tmp_path / ".gitignore").write_text("*.pyc")
+        sub = tmp_path / "src" / "deep"
+        sub.mkdir(parents=True)
+        found = find_project_root(sub)
+        assert found == tmp_path
+
+    def test_no_gitignore_returns_none(self, tmp_path):
+        sub = tmp_path / "src"
+        sub.mkdir(parents=True)
+        found = find_project_root(sub)
+        assert found is None
+
+
+class TestGitignoreMatcher:
+    def test_ignore_pattern(self, tmp_path):
+        (tmp_path / ".gitignore").write_text("*.pyc\n__pycache__/\n")
+        matcher = GitignoreMatcher(tmp_path)
+
+        assert matcher.is_ignored(tmp_path / "file.pyc") is True
+        assert matcher.is_ignored(tmp_path / "file.py") is False
+        assert matcher.is_ignored(tmp_path / "__pycache__" / "cache.pyc") is True
+
+    def test_negation_pattern(self, tmp_path):
+        (tmp_path / ".gitignore").write_text("*.log\n!important.log\n")
+        matcher = GitignoreMatcher(tmp_path)
+
+        assert matcher.is_ignored(tmp_path / "debug.log") is True
+        assert matcher.is_ignored(tmp_path / "important.log") is False
+
+    def test_subdirectory_gitignore(self, tmp_path):
+        (tmp_path / ".gitignore").write_text("*.py\n")
+        sub_dir = tmp_path / "src"
+        sub_dir.mkdir()
+        (sub_dir / ".gitignore").write_text("test_*.py\n")
+        matcher = GitignoreMatcher(tmp_path)
+
+        assert matcher.is_ignored(tmp_path / "main.py") is True
+        assert matcher.is_ignored(tmp_path / "utils.py") is True
+        assert matcher.is_ignored(sub_dir / "helper.py") is True
+        assert matcher.is_ignored(sub_dir / "test_main.py") is True
+
+    def test_anchored_pattern(self, tmp_path):
+        (tmp_path / ".gitignore").write_text("/build\n")
+        matcher = GitignoreMatcher(tmp_path)
+
+        assert matcher.is_ignored(tmp_path / "build" / "output") is True
+        assert matcher.is_ignored(tmp_path / "src" / "build" / "out") is False
+
+    def test_ignore_directory_ending_slash(self, tmp_path):
+        (tmp_path / ".gitignore").write_text("node_modules/\n")
+        matcher = GitignoreMatcher(tmp_path)
+
+        assert matcher.is_ignored(tmp_path / "node_modules" / "package.json") is True
+        assert matcher.is_ignored(tmp_path / "other" / "node_modules" / "file") is True
+
+    def test_get_matcher_returns_none_when_no_gitignore(self, tmp_path):
+        result = get_gitignore_matcher(tmp_path / "nonexistent")
+        assert result is None
+
+    def test_get_matcher_returns_matcher(self, tmp_path):
+        (tmp_path / ".gitignore").write_text("*.pyc")
+        result = get_gitignore_matcher(tmp_path)
+        assert result is not None
+        assert result.is_ignored(tmp_path / "file.pyc") is True
+
+    def test_comments_ignored(self, tmp_path):
+        (tmp_path / ".gitignore").write_text(
+            "# this is a comment\n*.log\n # another comment\n"
+        )
+        matcher = GitignoreMatcher(tmp_path)
+        assert matcher.is_ignored(tmp_path / "debug.log") is True
+
+    def test_whitespace_lines(self, tmp_path):
+        (tmp_path / ".gitignore").write_text(" *.log\n")
+        matcher = GitignoreMatcher(tmp_path)
+        assert matcher.is_ignored(tmp_path / " debug.log") is True
+        assert matcher.is_ignored(tmp_path / "debug.log") is False
+
+    def test_double_star_pattern(self, tmp_path):
+        (tmp_path / ".gitignore").write_text("**/*.pyc\n")
+        matcher = GitignoreMatcher(tmp_path)
+
+        assert matcher.is_ignored(tmp_path / "a.pyc") is True
+        assert matcher.is_ignored(tmp_path / "src" / "a.pyc") is True
+        assert matcher.is_ignored(tmp_path / "src" / "deep" / "a.pyc") is True
+
+    def test_no_gitignore_file(self, tmp_path):
+        matcher = GitignoreMatcher(tmp_path)
+        assert matcher.is_ignored(tmp_path / "file.py") is False
+        assert matcher.is_ignored(tmp_path / "anything") is False
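Reviewer note (outside the patch): the content-based detection that replaces the extension whitelist in `utils.py` can be sanity-checked in isolation. `looks_like_text` below is a hypothetical standalone copy of the null-byte + UTF-8 heuristic, not an import of `ocrag.utils.is_text_file`:

```python
import tempfile
from pathlib import Path


def looks_like_text(path: str) -> bool:
    """Mirror of the heuristic: a null byte in the first 1 KB means binary;
    UTF-8-decodable content (or an empty file) means text."""
    try:
        with open(path, "rb") as f:
            chunk = f.read(1024)
        if not chunk:
            return True  # empty files are treated as text
        if b"\x00" in chunk:
            return False
        chunk.decode("utf-8")
        return True
    except (UnicodeDecodeError, OSError):
        return False


with tempfile.TemporaryDirectory() as d:
    src = Path(d) / "main.py"
    src.write_text("def main():\n    pass\n# 中文注释\n", encoding="utf-8")
    blob = Path(d) / "model.bin"
    blob.write_bytes(b"\x00\x01\x02payload")

    print(looks_like_text(str(src)))   # True: valid UTF-8, no null bytes
    print(looks_like_text(str(blob)))  # False: null byte in first 1 KB
```

One caveat: reading exactly 1024 bytes can split a multi-byte UTF-8 character at the boundary, so a large text file with non-ASCII content near the cut can be misclassified as binary unless the decode-error position is inspected.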