feat(chunker): 使用 tree-sitter 实现语义感知代码分块

songsenand 2026-04-16 21:44:45 +08:00
parent 9aee24979f
commit fdfe2fa97d
15 changed files with 1099 additions and 204 deletions

.gitignore vendored

@@ -1,3 +1,4 @@
+**/__pycache__/
 models/*
 uv.lock


@@ -63,19 +63,20 @@
 - **Python 优先**:原生 Python SDK,类型提示完善
 - **性能优秀**:基于 Apache Arrow,查询速度快

-### 2.3 为什么选择 langchain-text-splitters(而非 chonkie)
+### 2.3 为什么选择 tree-sitter

-最初设计考虑使用 `chonkie` 实现语法感知分块,但经过实践发现:
+使用 `tree-sitter` 实现 AST 级别的语义感知分块:

 | 方案 | 优势 | 劣势 |
 |------|------|------|
-| chonkie | 语法感知分块 | 需要额外下载 GPT-2 tokenizer,初始化慢 |
-| langchain-text-splitters | 无需额外模型、纯规则、速度快 | 非语法感知 |
+| tree-sitter | AST 感知、语义完整、不跨类/函数切割 | 依赖语言 parser |
+| langchain-text-splitters | 轻量、规则简单 | chunk_size 优先,会切割语义单元 |

-**最终选择 langchain-text-splitters**:
+**最终选择 tree-sitter**:

-- 对于 RAG 场景,分块精度不是最关键因素
-- langchain 的方案更加轻量,启动更快
-- 代码按自然段落(函数、类)边界切分,效果足够好
+- 代码以**类、函数**等完整语义单元为最小块
+- 超过 token 限制时,递归在子节点级别拆分
+- 支持 40+ 编程语言(通过 tree-sitter-languages)
+- 底层通过 ctypes 加载 parser,绕过语言绑定兼容性问题
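上述 ctypes 加载方式的核心思路可以用几行示意代码说明(`load_language_ptr` 是本文虚构的函数名,仅为演示;实际入口见本次提交中 chunker.py 的 `get_parser`):

```python
import ctypes

def load_language_ptr(so_path: str, language_name: str) -> int:
    """调用共享库中导出的 tree_sitter_<lang>() 符号,取得 TSLanguage 裸指针。"""
    lib = ctypes.CDLL(so_path)  # 文件不存在时抛出 OSError
    func = getattr(lib, f"tree_sitter_{language_name}")
    func.argtypes = []
    func.restype = ctypes.c_void_p  # 返回值按整数指针处理
    return func()
```

拿到裸指针后再交给 `tree_sitter.Language` 包装,即可绕开 tree-sitter-languages 预编译 wheel 与新版 tree-sitter Python 绑定之间的 ABI 不兼容问题。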
### 2.4 为什么使用 Qwen3-Embedding-0.6B
@@ -187,7 +188,7 @@
 | 模块 | 职责 | 依赖 |
 |------|------|------|
 | `cli.py` | 命令行入口,参数解析 | click |
-| `chunker.py` | 代码分块处理 | langchain-text-splitters |
+| `chunker.py` | 代码分块处理 | tree-sitter, tree-sitter-languages |
 | `embedder.py` | 文本向量化 | sentence-transformers |
 | `db.py` | 向量数据库操作 | lancedb, pyarrow |
 | `commands/*.py` | 各命令实现 | cli, chunker, embedder, db |
@@ -199,12 +200,13 @@
 ### 4.1 代码分块算法

-**实现**:使用 langchain 的 `RecursiveCharacterTextSplitter`,支持 29 种编程语言的语法感知分块
+**实现**:使用 tree-sitter 进行 AST 级别的语义分块,支持 40+ 编程语言

-**分块参数**:
+**分块策略**:

-- `chunk_size=2000`:每块约 2000 字符
-- `chunk_overlap=200`:块之间保留 200 字符重叠,保证上下文连贯
-- `separators`:按语言自动确定,按语义优先级(类 > 函数 > 空行段落 > 换行 > 空格 > 字符)递减
+- **源码文件**:使用 tree-sitter 解析为 AST,按 `class_definition`、`function_definition` 等语义节点提取完整块
+- **非语义单元文件**(如 Markdown):使用 tiktoken 按行分块
+- **超大语义单元**(>4000 tokens):递归在子节点级别拆分,保证每个块不超过大小限制
+- **Parser 加载**:通过 ctypes 从 `languages.so` 直接加载,绕过语言绑定兼容性问题
### 4.2 向量化策略


@@ -11,7 +11,7 @@
 - 🚀 **零配置**:嵌入式数据库,安装即用
 - 🔒 **本地优先**:所有数据存储在本地,不上传到外部服务器
 - 🔍 **语义搜索**:通过自然语言查询相关代码片段
-- 📦 **多语言支持**:支持 Python、Rust、C++、Go、JavaScript、TypeScript 等 29 种编程语言
+- 📦 **多语言支持**:支持 Python、Rust、C++、Go、JavaScript、TypeScript 等 40+ 种编程语言
 - ⚡ **极速响应**:搜索延迟 < 100ms
 - 🤖 **AI 友好**:输出格式专为 AI 消费设计
@@ -48,17 +48,34 @@ cp opencode-skill/SKILL.md ~/.config/opencode/skills/ocrag/
 # 添加文件到知识库
 uv run ocrag add ./src/main.py

-# 递归添加整个目录
+# 递归添加整个目录(自动应用 .gitignore + 二进制文件过滤)
 uv run ocrag add ./src/ --recursive

+# 排除特定文件(--exclude 可多次指定)
+uv run ocrag add ./src/ --recursive --exclude "**/__pycache__" --exclude "**/test_*.py"

+# 只添加特定类型的文件(--include 可多次指定)
+uv run ocrag add ./src/ --recursive --include "*.py" --include "*.rs"

+# 禁用 .gitignore 过滤
+uv run ocrag add ./src/ --recursive --no-ignore

 # 搜索相关代码
 uv run ocrag search "如何实现用户认证"

 # 列出知识库中的文件
 uv run ocrag list

+# 使用通配符筛选列表
+uv run ocrag list "*.py"     # 仅列出 Python 文件
+uv run ocrag list "src/**"   # 仅列出 src 目录下的文件

 # 删除文件
 uv run ocrag remove ./src/main.py

+# 使用通配符批量删除
+uv run ocrag remove "**/test_*.py"   # 删除所有测试文件
+uv run ocrag remove "src/**/*.js"    # 删除 src 下所有 JS 文件
 ```

 ### OpenCode 集成


@@ -23,12 +23,19 @@ description: 代码库 RAG 技能,用于向知识库添加代码或搜索已
 2. 获得返回的代码片段后,结合片段内容回答用户
 3. 如果结果不相关,可以告知用户未找到,并建议添加更多代码

+### 管理知识库
+- `rag_list` 可选传入 pattern 参数进行通配符筛选,如 `"*.py"`、`"src/**"`
+- `rag_remove` 支持通配符模式批量删除,如 `"**/test_*.py"` 删除所有测试文件

 ## 示例
 - 用户:"帮我记住这个文件的知识" → 调用 `rag_add` 传入当前文件路径
 - 用户:"认证模块是怎么实现的?" → 调用 `rag_search` query="认证模块实现"
 - 用户:"把 ./src 目录下所有 Python 文件加入知识库" → 调用 `rag_add` paths=["./src"] recursive=true
+- 用户:"知识库里有几个测试文件?" → 调用 `rag_list` pattern="**/test_*.py"
+- 用户:"把测试文件都删了" → 调用 `rag_remove` pattern="**/test_*.py"

 ## 注意事项
 - 搜索结果会显示代码片段的来源文件
 - 可以通过 top_k 参数调整返回结果数量
 - 添加文件后,搜索会立即包含新添加的内容
+- remove 和 list 使用 fnmatch 通配符:`*` 可匹配任意字符(包括路径分隔符),`**` 没有额外的特殊含义
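值得注意的是,标准库 `fnmatch` 与 shell glob 语义不同:`*` 本身就能跨越路径分隔符,`**/` 只是"任意前缀加一个 `/`"。下面用几个与上文一致的模式快速验证(纯演示代码):

```python
import fnmatch

# fnmatch 将模式翻译为正则,`*` 对应 `.*`,可匹配 "/"
print(fnmatch.translate("*.py"))

assert fnmatch.fnmatch("src/utils/main.py", "*.py")       # `*` 跨越目录层级
assert fnmatch.fnmatch("src/test_a.py", "**/test_*.py")   # "**/" 要求前面至少有一个 "/"
assert not fnmatch.fnmatch("test_a.py", "**/test_*.py")   # 顶层文件没有 "/",不匹配
assert fnmatch.fnmatch("src/a/b/x.js", "src/**/*.js")
assert not fnmatch.fnmatch("src/x.js", "src/**/*.js")     # 中间必须再有一层目录
```

因此 `"*.py"` 实际会列出任意深度下的所有 Python 文件,而 `"**/test_*.py"` 匹配不到顶层目录里的测试文件,使用时需要留意。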


@@ -1,19 +1,29 @@
 [project]
-name = "project"
+name = "ocrag"
 version = "0.1.0"
-description = "Add your description here"
+description = "Local code knowledge base RAG plugin for OpenCode"
 readme = "README.md"
 requires-python = ">=3.12"
 dependencies = [
-    "chonkie>=1.6.2",
+    "tree-sitter>=0.21",
+    "tree-sitter-languages>=1.10",
     "click>=8.3.2",
     "lancedb>=0.30.2",
+    "pathspec>=0.12.0",
     "sentence-transformers>=5.4.1",
     "sentencepiece>=0.2.1",
     "tiktoken>=0.12.0",
     "tokenizers>=0.22.2",
+    "tqdm>=4.66.0",
 ]

+[project.scripts]
+ocrag = "ocrag.cli:main"

+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"

 [dependency-groups]
 dev = [
     "pytest>=9.0.3",


@@ -1,47 +1,478 @@
 """
 代码分块模块

-使用 langchain_text_splitters 进行语言感知分块
-支持 29 种编程语言的语法感知分块,tiktoken 已内置无需额外下载
+使用 tree-sitter 进行语义感知分块
+源码按类/函数等语义单元切分,保证每个 chunk 都是完整的语义块
+非代码文件使用简单分块
 """

+import ctypes
+import os
+import re
+import warnings
 from pathlib import Path
-from typing import List, Dict, Any

+warnings.filterwarnings(
+    "ignore",
+    message="int argument support is deprecated",
+    category=DeprecationWarning,
+)

+from typing import List, Dict, Any, Optional, Tuple

 try:
-    from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
+    import tree_sitter
+    from tree_sitter import Parser, Language, Node
 except ImportError:
-    raise ImportError(
-        "请安装 langchain-text-splitters: uv pip install langchain-text-splitters"
-    )
+    raise ImportError("请安装 tree-sitter: uv pip install tree-sitter")
 LANGUAGE_MAP = {
-    ".py": Language.PYTHON,
-    ".rs": Language.RUST,
-    ".cpp": Language.CPP,
-    ".cc": Language.CPP,
-    ".cxx": Language.CPP,
-    ".go": Language.GO,
-    ".java": Language.JAVA,
-    ".js": Language.JS,
-    ".ts": Language.TS,
-    ".jsx": Language.JS,
-    ".tsx": Language.TS,
-    ".c": Language.C,
-    ".h": Language.C,
-    ".hpp": Language.CPP,
-    ".rb": Language.RUBY,
-    ".swift": Language.SWIFT,
-    ".kt": Language.KOTLIN,
-    ".md": Language.MARKDOWN,
-    ".txt": None,
+    ".py": "python",
+    ".rs": "rust",
+    ".cpp": "cpp",
+    ".cc": "cpp",
+    ".cxx": "cpp",
+    ".go": "go",
+    ".java": "java",
+    ".js": "javascript",
+    ".ts": "typescript",
+    ".jsx": "javascript",
+    ".tsx": "typescript",
+    ".c": "c",
+    ".h": "cpp",
+    ".hpp": "cpp",
+    ".rb": "ruby",
+    ".swift": "swift",
+    ".kt": "kotlin",
+    ".sh": "bash",
+    ".bash": "bash",
+    ".zsh": "bash",
+    ".lua": "lua",
+    ".php": "php",
+    ".scala": "scala",
+    ".cs": "c_sharp",
+    ".r": "r",
+    ".R": "r",
+    ".sql": "sql",
+    ".toml": "toml",
+    ".yaml": "yaml",
+    ".yml": "yaml",
+    ".json": "json",
+    ".xml": "xml",
+    ".html": "html",
+    ".htm": "html",
+    ".css": "css",
+    ".scss": "css",
+    ".less": "css",
+    ".md": "markdown",
+    ".rst": "rst",
+    ".dockerfile": "dockerfile",
+    ".tf": "hcl",
+    ".hs": "haskell",
+    ".ex": "elixir",
+    ".exs": "elixir",
+    ".erl": "erlang",
+    ".ml": "ocaml",
+    ".mli": "ocaml",
+    ".fs": "fsharp",
+    ".fsx": "fsharp",
 }
SEMANTIC_UNITS = {
    "python": [
        "function_definition",
        "class_definition",
        "async_function_definition",
        "decorated_definition",
    ],
    "rust": [
        "function_item",
        "struct_item",
        "impl_item",
        "enum_item",
        "trait_item",
        "type_item",
        "macro_invocation",
        "macro_rule",
    ],
    "cpp": [
        "class_specifier",
        "struct_specifier",
        "namespace_definition",
        "function_definition",
        "method_declaration",
        "constructor_initializer",
        "template_declaration",
    ],
    "c": [
        "struct_specifier",
        "union_specifier",
        "enum_specifier",
        "function_definition",
        "type_definition",
    ],
    "go": [
        "function_declaration",
        "method_declaration",
        "type_declaration",
        "const_declaration",
        "var_declaration",
    ],
    "java": [
        "class_declaration",
        "interface_declaration",
        "enum_declaration",
        "record_declaration",
        "method_declaration",
    ],
    "javascript": [
        "class_declaration",
        "function_declaration",
        "arrow_function",
        "method_definition",
    ],
    "typescript": [
        "class_declaration",
        "function_declaration",
        "arrow_function",
        "method_definition",
        "interface_declaration",
        "type_alias_declaration",
        "enum_declaration",
    ],
    "ruby": [
        "class",
        "module",
        "def",
        "method",
    ],
    "swift": [
        "class_declaration",
        "struct_declaration",
        "enum_declaration",
        "protocol_declaration",
        "function_declaration",
        "method_declaration",
    ],
    "kotlin": [
        "class_declaration",
        "object_declaration",
        "function_declaration",
        "method_declaration",
    ],
    "scala": [
        "class_definition",
        "object_definition",
        "trait_definition",
        "function_definition",
        "method_definition",
    ],
    "c_sharp": [
        "class_declaration",
        "struct_declaration",
        "interface_declaration",
        "enum_declaration",
        "method_declaration",
    ],
    "go_mod": ["import_declaration"],
    "bash": [
        "function_definition",
        "command",
    ],
    "lua": [
        "function_declaration",
        "local_function",
        "assignment_statement",
    ],
    "php": [
        "class_declaration",
        "interface_declaration",
        "trait_declaration",
        "function_definition",
        "method_declaration",
    ],
    "sql": [
        "create_table_statement",
        "create_index_statement",
        "create_view_statement",
        "create_trigger_statement",
        "procedure_definition",
        "function_definition",
    ],
    "toml": ["table"],
    "yaml": ["block_mapping"],
    "json": ["pair", "array"],
    "html": [
        "element",
        "script_element",
        "style_element",
    ],
    "css": ["rule_set", "at_rule"],
    "markdown": ["section", "fenced_code_block"],
    "rst": ["section", "directive"],
    "dockerfile": ["from_instruction", "run_instruction", "cmd_instruction"],
    "hcl": ["block", "attribute"],
    "haskell": [
        "type_declaration",
        "data_declaration",
        "function_definition",
        "class_declaration",
        "instance_declaration",
    ],
    "elixir": [
        "module_definition",
        "function_definition",
        "def",
        "defp",
    ],
    "erlang": [
        "function_definition",
        "module_definition",
    ],
    "ocaml": [
        "value_definition",
        "type_definition",
        "module_definition",
        "class_definition",
    ],
    "fsharp": [
        "class_definition",
        "module_definition",
        "function_definition",
        "type_definition",
    ],
    "r": [
        "function_definition",
        "assignment",
    ],
}
-def detect_language(file_path: str) -> Language | None:
-    """检测文件语言类型"""
-    ext = Path(file_path).suffix.lower()
-    return LANGUAGE_MAP.get(ext)
+MAX_CHUNK_CHARS = 8000
+MIN_CHUNK_CHARS = 100
class TiktokenCounter:
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return
        import tiktoken

        self._enc = tiktoken.encoding_for_model("gpt-4o")
        self._initialized = True

    def count(self, text: str) -> int:
        return len(self._enc.encode(text))

_token_counter = TiktokenCounter()

def _find_languages_so() -> str:
    import site

    for path in site.getsitepackages() + [site.getusersitepackages()]:
        so_path = os.path.join(path, "tree_sitter_languages", "languages.so")
        if os.path.exists(so_path):
            return so_path
    for path in site.getsitepackages():
        if os.path.exists(os.path.join(path, "tree_sitter_languages")):
            return os.path.join(path, "tree_sitter_languages", "languages.so")
    cache_dir = os.path.expanduser("~/.cache/uv/archive-v0")
    if os.path.exists(cache_dir):
        for entry in os.listdir(cache_dir):
            entry_path = os.path.join(
                cache_dir, entry, "tree_sitter_languages", "languages.so"
            )
            if os.path.exists(entry_path):
                return entry_path
    raise FileNotFoundError(
        "无法找到 tree-sitter-languages 的 languages.so。"
        " 请确保已安装 tree-sitter-languages。"
    )

_lib = None

def _get_lib() -> ctypes.CDLL:
    global _lib
    if _lib is None:
        so_path = _find_languages_so()
        _lib = ctypes.CDLL(so_path)
    return _lib

_parsers_cache: Dict[str, Parser] = {}

def get_parser(language_name: str) -> Tuple[Parser, Language]:
    if language_name in _parsers_cache:
        return _parsers_cache[language_name], _parsers_cache[f"_lang_{language_name}"]
    lib = _get_lib()
    symbol_name = f"tree_sitter_{language_name}"
    if not hasattr(lib, symbol_name):
        raise ValueError(f"不支持的语言: {language_name}")
    func = getattr(lib, symbol_name)
    func.argtypes = []
    func.restype = ctypes.c_void_p
    ptr = func()
    lang = Language(ptr)  # type: ignore[call-arg]
    parser = Parser(lang)
    _parsers_cache[language_name] = parser
    _parsers_cache[f"_lang_{language_name}"] = lang
    return parser, lang

def _is_semantic_unit(node: Node, semantic_types: List[str]) -> bool:
    return node.type in semantic_types

def _extract_semantic_chunks(
    node: Node,
    semantic_types: List[str],
    source_bytes: bytes,
    parent_context: str = "",
) -> List[Tuple[str, str, int, int]]:
    chunks = []
    node_type = node.type
    if _is_semantic_unit(node, semantic_types):
        text = source_bytes[node.start_byte : node.end_byte].decode(
            "utf-8", errors="replace"
        )
        token_count = _token_counter.count(text)
        if token_count > 4000:
            children_chunks = _split_large_node(
                node, semantic_types, source_bytes, parent_context
            )
            if children_chunks:
                chunks.extend(children_chunks)
            else:
                chunks.append((node_type, text, node.start_byte, node.end_byte))
        else:
            chunks.append((node_type, text, node.start_byte, node.end_byte))
    else:
        for child in node.children:
            chunks.extend(
                _extract_semantic_chunks(
                    child, semantic_types, source_bytes, parent_context
                )
            )
    return chunks

def _split_large_node(
    node: Node,
    semantic_types: List[str],
    source_bytes: bytes,
    parent_context: str,
) -> List[Tuple[str, str, int, int]]:
    children_with_semantic = [
        (child, child.type in semantic_types)
        for child in node.children
        if child.type not in ("comment", "preproc_directive", "attribute")
    ]
    if not children_with_semantic:
        return []

    results = []
    current_text_parts = []
    current_start = children_with_semantic[0][0].start_byte
    for child, is_semantic in children_with_semantic:
        if is_semantic:
            if current_text_parts:
                combined = source_bytes[
                    current_start : current_text_parts[0][0].start_byte
                ].decode("utf-8", errors="replace") + "".join(
                    source_bytes[p[0].start_byte : p[0].end_byte].decode(
                        "utf-8", errors="replace"
                    )
                    for p in current_text_parts
                )
                if _token_counter.count(combined) >= MIN_CHUNK_CHARS // 4:
                    results.append(
                        (
                            "sub_chunk",
                            combined,
                            current_start,
                            current_text_parts[0][0].start_byte,
                        )
                    )
                current_text_parts = []
            current_start = child.start_byte
            child_text = source_bytes[child.start_byte : child.end_byte].decode(
                "utf-8", errors="replace"
            )
            if _token_counter.count(child_text) <= 4000:
                results.append(
                    (child.type, child_text, child.start_byte, child.end_byte)
                )
            else:
                nested = _split_large_node(
                    child, semantic_types, source_bytes, parent_context
                )
                results.extend(nested)
        else:
            current_text_parts.append((child, child.type))
    if current_text_parts:
        combined = source_bytes[
            current_start : current_text_parts[0][0].start_byte
        ].decode("utf-8", errors="replace") + "".join(
            source_bytes[p[0].start_byte : p[0].end_byte].decode(
                "utf-8", errors="replace"
            )
            for p in current_text_parts
        )
        if _token_counter.count(combined) >= MIN_CHUNK_CHARS // 4:
            results.append(
                (
                    "sub_chunk",
                    combined,
                    current_start,
                    current_text_parts[0][0].start_byte,
                )
            )
    return results

def _simple_chunk_text(content: str, chunk_size: int = 512) -> List[str]:
    import tiktoken

    enc = tiktoken.encoding_for_model("gpt-4o")
    lines = content.split("\n")
    chunks = []
    current = []
    current_tokens = 0
    for line in lines:
        line_tokens = len(enc.encode(line))
        if current_tokens + line_tokens > chunk_size and current:
            chunks.append("\n".join(current))
            current = [line]
            current_tokens = line_tokens
        else:
            current.append(line)
            current_tokens += line_tokens
    if current:
        chunks.append("\n".join(current))
    return chunks
@@ -50,41 +481,32 @@ def chunk_code(
     chunk_size: int = 512,
     chunk_overlap: int = 50,
 ) -> List[Dict[str, Any]]:
-    """将代码文件分块,返回块列表
-
-    使用 langchain_text_splitters 进行语言感知分块
-    支持 29 种编程语言的语法感知分块
-
-    Args:
-        content: 文件内容
-        file_path: 文件路径
-        chunk_size: 分块大小(token)
-        chunk_overlap: 重叠大小(token)
-
-    Returns:
-        分块列表,每个块包含 text 和 metadata
-    """
-    language = detect_language(file_path)
-
-    if language is None:
-        # 不支持的语言,使用简单分块
-        splitter = RecursiveCharacterTextSplitter(
-            chunk_size=chunk_size,
-            chunk_overlap=chunk_overlap,
-            separators=["\n\n", "\n", " ", ""],
-            length_function=len,
-        )
-        texts = splitter.split_text(content)
-        language_name = "text"
-    else:
-        # 使用语言感知分块
-        splitter = RecursiveCharacterTextSplitter.from_language(
-            language=language,
-            chunk_size=chunk_size,
-            chunk_overlap=chunk_overlap,
-        )
-        texts = splitter.split_text(content)
-        language_name = language.value
+    ext = Path(file_path).suffix.lower()
+    lang_name = LANGUAGE_MAP.get(ext)
+    if not lang_name:
+        texts = _simple_chunk_text(content, chunk_size)
+        lang_display = "text"
+    else:
+        semantic_types = SEMANTIC_UNITS.get(lang_name, [])
+        source_bytes = content.encode("utf-8")
+
+        try:
+            parser, _ = get_parser(lang_name)
+        except Exception:
+            texts = _simple_chunk_text(content, chunk_size)
+            lang_display = lang_name
+        else:
+            tree = parser.parse(source_bytes)
+            raw_chunks = _extract_semantic_chunks(
+                tree.root_node, semantic_types, source_bytes
+            )
+            if not raw_chunks:
+                texts = _simple_chunk_text(content, chunk_size)
+            else:
+                texts = [text for _, text, _, _ in raw_chunks]
+            lang_display = lang_name

     return [
         {
@@ -92,8 +514,13 @@ def chunk_code(
             "metadata": {
                 "source_file": file_path,
                 "chunk_index": i,
-                "language": language_name,
+                "language": lang_display,
             },
         }
         for i, text in enumerate(texts)
     ]

+def detect_language(file_path: str) -> Optional[str]:
+    ext = Path(file_path).suffix.lower()
+    return LANGUAGE_MAP.get(ext)


@@ -1,4 +1,5 @@
 import click
+from pathlib import Path
 from ocrag.commands import add as add_cmd
 from ocrag.commands import search as search_cmd
 from ocrag.commands import remove as remove_cmd
@@ -13,11 +14,31 @@ def main():

 @main.command()
-@click.argument("paths", nargs=-1, required=True, type=click.Path(exists=True))
+@click.argument("paths", nargs=-1, required=True)
 @click.option("--recursive", "-r", is_flag=True, help="递归处理目录")
-def add(paths, recursive):
-    """向知识库添加文件或目录"""
-    add_cmd(paths, recursive)
+@click.option(
+    "--include", "-i", multiple=True, help="只添加匹配的文件(glob 模式,可多次指定)"
+)
+@click.option(
+    "--exclude", "-e", multiple=True, help="排除匹配的文件(glob 模式,可多次指定)"
+)
+@click.option("--no-ignore", is_flag=True, help="禁用 .gitignore 过滤")
+def add(paths, recursive, include, exclude, no_ignore):
+    """向知识库添加文件或目录
+
+    默认会自动读取 .gitignore 规则并跳过匹配的文件,
+    目录遍历时自动跳过 .git、__pycache__、node_modules 等目录
+    """
+    for p in paths:
+        if not Path(p).exists():
+            raise click.BadParameter(f"路径不存在: {p}")
+    add_cmd(
+        paths,
+        recursive,
+        include_patterns=list(include) if include else None,
+        exclude_patterns=list(exclude) if exclude else None,
+        use_gitignore=not no_ignore,
+    )

 @main.command()
@@ -29,16 +50,17 @@ def search(query, top_k):

 @main.command()
-@click.argument("path", type=click.Path())
-def remove(path):
-    """从知识库中移除指定文件"""
-    remove_cmd(path)
+@click.argument("pattern")
+def remove(pattern):
+    """从知识库中移除匹配模式的所有文件(支持通配符)"""
+    remove_cmd(pattern)

-@main.command()
-def list():
-    """列出知识库中的所有条目"""
-    list_cmd()
+@main.command("list")
+@click.argument("pattern", required=False)
+def list_files(pattern):
+    """列出知识库中的条目(可选通配符筛选)"""
+    list_cmd(pattern)

 if __name__ == "__main__":


@@ -1,56 +1,143 @@
 import os
+import pathspec
 from pathlib import Path
+from typing import List
+from tqdm import tqdm
 from ocrag.chunker import chunk_code
 from ocrag.embedder import embedder
 from ocrag.db import RagDB
+from ocrag.utils import is_text_file
+from ocrag.gitignore import get_gitignore_matcher

-def collect_files(paths, recursive):
-    """收集所有需要处理的文件"""
+def collect_files(
+    paths,
+    recursive: bool,
+    include_patterns: List[str] = None,
+    exclude_patterns: List[str] = None,
+    use_gitignore: bool = True,
+) -> List[Path]:
+    """收集所有需要处理的文件
+
+    Args:
+        paths: 文件或目录路径列表
+        recursive: 是否递归处理目录
+        include_patterns: 只处理匹配这些 glob 模式的文件
+        exclude_patterns: 跳过匹配这些 glob 模式的文件
+        use_gitignore: 是否应用 .gitignore 规则
+    """
     files = []
+    seen = set()
+
+    exclude_spec = None
+    if exclude_patterns:
+        exclude_spec = pathspec.PathSpec.from_lines("gitignore", exclude_patterns)
+    include_spec = None
+    if include_patterns:
+        include_spec = pathspec.PathSpec.from_lines("gitignore", include_patterns)
+
     for p in paths:
-        p = Path(p)
+        p = Path(p).resolve()
         if p.is_file():
             files.append(p)
         elif p.is_dir() and recursive:
-            for root, dirs, filenames in os.walk(p):
-                for f in filenames:
-                    files.append(Path(root) / f)
+            matcher = get_gitignore_matcher(p) if use_gitignore else None
+
+            for root_dir, dirs, filenames in os.walk(p):
+                root_dir = Path(root_dir)
+                dirs[:] = [
+                    d
+                    for d in dirs
+                    if d
+                    not in (
+                        ".git",
+                        "__pycache__",
+                        ".pytest_cache",
+                        ".mypy_cache",
+                        "node_modules",
+                        ".venv",
+                        "venv",
+                        ".env",
+                        ".tox",
+                        ".eggs",
+                    )
+                    and not d.endswith(".egg-info")
+                ]
+                for filename in filenames:
+                    file_path = root_dir / filename
+                    if file_path in seen:
+                        continue
+                    seen.add(file_path)
+                    rel = file_path.relative_to(p)
+                    if not is_text_file(str(file_path)):
+                        continue
+                    if matcher and matcher.is_ignored(file_path):
+                        continue
+                    if exclude_spec and exclude_spec.match_file(str(rel)):
+                        continue
+                    if include_spec and not include_spec.match_file(str(rel)):
+                        continue
+                    files.append(file_path)
         elif p.is_dir() and not recursive:
             print(f"跳过目录 {p}(未使用 --recursive)")
     return files

-def run(paths, recursive):
+def run(
+    paths,
+    recursive: bool,
+    include_patterns: List[str] = None,
+    exclude_patterns: List[str] = None,
+    use_gitignore: bool = True,
+):
     db = RagDB()
-    files = collect_files(paths, recursive)
+    files = collect_files(
+        paths, recursive, include_patterns, exclude_patterns, use_gitignore
+    )
+    if not files:
+        print("没有找到需要处理的文件")
+        return

     total_chunks = 0
-    for file_path in files:
-        try:
-            content = file_path.read_text(encoding="utf-8")
-            chunks = chunk_code(content, str(file_path))
-            if not chunks:
-                continue
-
-            # 批量 embedding
-            texts = [c["text"] for c in chunks]
-            vectors = embedder.embed(texts)
-
-            documents = []
-            for chunk, vec in zip(chunks, vectors):
-                documents.append(
-                    {
-                        "text": chunk["text"],
-                        "vector": vec,
-                        "metadata": chunk["metadata"],
-                    }
-                )
-            db.add_documents(documents)
-            total_chunks += len(documents)
-            print(f"✅ {file_path} -> {len(documents)} 个块")
-        except Exception as e:
-            print(f"❌ 处理 {file_path} 失败: {e}")
+    with tqdm(total=len(files), desc="处理文件", unit="file") as pbar:
+        for file_path in files:
+            try:
+                content = file_path.read_text(encoding="utf-8")
+                chunks = chunk_code(content, str(file_path))
+                if not chunks:
+                    pbar.update(1)
+                    continue
+
+                texts = [c["text"] for c in chunks]
+                vectors = embedder.embed(texts)
+
+                documents = []
+                for chunk, vec in zip(chunks, vectors):
+                    documents.append(
+                        {
+                            "text": chunk["text"],
+                            "vector": vec,
+                            "metadata": chunk["metadata"],
+                        }
+                    )
+                db.add_documents(documents)
+                total_chunks += len(documents)
+                pbar.set_postfix({"": f"{total_chunks}"})
+            except Exception as e:
+                print(f"\n❌ 处理 {file_path} 失败: {e}")
+            pbar.update(1)

     print(f"\n📦 总计添加 {total_chunks} 个块")


@@ -1,14 +1,15 @@
 from ocrag.db import RagDB

-def run():
+def run(pattern: str = None):
     db = RagDB()
-    sources = db.list_sources()
+    sources = db.list_sources(pattern)

     if not sources:
         print("知识库为空")
         return

-    print("知识库中的文件:")
+    label = f'匹配 "{pattern}" 的' if pattern else ""
+    print(f"{label}知识库中的文件:")
     for i, source in enumerate(sources, 1):
         print(f"{i}. {source}")


@@ -1,9 +1,8 @@
+import fnmatch
 from ocrag.db import RagDB

-def run(path: str):
+def run(pattern: str):
     db = RagDB()
-    # 注意:当前 db.py 中的 delete_by_source 方法尚未完全实现
-    # 这里先调用以保持接口一致,后续需要完善 db.py 的实现
-    db.delete_by_source(path)
-    print(f"已删除 {path} 的所有块")
+    count = db.delete_by_pattern(pattern)
+    print(f'已删除 {count} 个匹配 "{pattern}" 的块')


@@ -8,43 +8,51 @@ DB_PATH = Path.home() / ".ocrag" / "data.lance"

 class RagDB:
     def __init__(self, db_path: str = None):
-        self.path = db_path or str(DB_PATH)
+        import os as _os
+
+        self.path = db_path or _os.environ.get("OCRAG_DB_PATH", str(DB_PATH))
         self.conn = lancedb.connect(self.path)
         self._init_table()

     def _init_table(self):
         """如果表不存在则创建"""
-        schema = pa.schema(
-            [
-                ("text", pa.string()),
-                (
-                    "vector",
-                    pa.list_(pa.float32(), 1024),
-                ),  # 1024 维 (Qwen3-Embedding-0.6B)
-                ("metadata", pa.string()),  # JSON 字符串
-            ]
-        )
         try:
             self.table = self.conn.open_table("documents")
         except:
-            self.table = self.conn.create_table("documents", schema=schema)
+            empty = pa.table(
+                {
+                    "text": pa.array([], type=pa.string()),
+                    "vector": pa.array([], type=pa.list_(pa.float32(), 1024)),
+                    "metadata": pa.array([], type=pa.string()),
+                }
+            )
+            self.table = self.conn.create_table("documents", empty)

     def add_documents(self, documents: List[Dict[str, Any]]):
         """批量添加文档

         documents: [{"text": str, "vector": List[float], "metadata": dict}, ...]
         """
         import json
+        import numpy as np

-        data = []
+        texts = []
+        vectors = []
+        metadatas = []
         for doc in documents:
-            data.append(
-                {
-                    "text": doc["text"],
-                    "vector": doc["vector"],
-                    "metadata": json.dumps(doc.get("metadata", {})),
-                }
-            )
-        self.table.add(data)
+            texts.append(doc["text"])
+            vec = doc["vector"]
+            if isinstance(vec, list):
+                vec = np.array(vec, dtype=np.float32)
+            vectors.append(vec)
+            metadatas.append(json.dumps(doc.get("metadata", {})))
+
+        arrays = [
+            pa.array(texts),
+            pa.array(vectors, type=pa.list_(pa.float32(), 1024)),
+            pa.array(metadatas),
+        ]
+        batch = pa.record_batch(arrays, names=["text", "vector", "metadata"])
+        self.table.add(batch)

     def search(self, query_vector: List[float], top_k: int = 5) -> List[Dict[str, Any]]:
         import json
@@ -96,9 +104,50 @@ class RagDB:
         return num_deleted

-    def list_sources(self) -> List[str]:
-        """列出所有已添加的源文件路径(去重)"""
-        # 获取所有 metadata,提取 source_file
+    def delete_by_pattern(self, pattern: str) -> int:
+        """删除匹配通配符模式的所有源文件的所有块
+
+        Args:
+            pattern: fnmatch 模式,如 "*.py", "src/**/*.js", "**/test_*.py"
+
+        Returns:
+            删除的块数量
+        """
+        import fnmatch
+        import json
+
+        df = self.table.to_pandas()
+        if df.empty:
+            return 0
+
+        indices_to_keep = []
+        for idx, row in df.iterrows():
+            try:
+                meta = json.loads(row["metadata"])
+                source = meta.get("source_file", "")
+                if not fnmatch.fnmatch(source, pattern):
+                    indices_to_keep.append(idx)
+            except (json.JSONDecodeError, KeyError):
+                indices_to_keep.append(idx)
+
+        num_deleted = len(df) - len(indices_to_keep)
+        if num_deleted == 0:
+            return 0
+
+        df_remaining = df.loc[indices_to_keep]
+        self.conn.drop_table("documents")
+        self.table = self.conn.create_table("documents", df_remaining)
+        return num_deleted
+
+    def list_sources(self, pattern: str = None) -> List[str]:
+        """列出所有已添加的源文件路径(去重)
+
+        Args:
+            pattern: 可选的 fnmatch 模式,筛选匹配的文件
+        """
+        import fnmatch
+
         df = self.table.to_pandas()
         if df.empty:
             return []
@@ -108,5 +157,6 @@ class RagDB:
         for meta_str in df["metadata"]:
             meta = json.loads(meta_str)
             if "source_file" in meta:
-                sources.add(meta["source_file"])
+                if pattern is None or fnmatch.fnmatch(meta["source_file"], pattern):
+                    sources.add(meta["source_file"])
         return sorted(sources)

src/ocrag/gitignore.py Normal file

@@ -0,0 +1,73 @@
import os
import fnmatch
import pathspec
from pathlib import Path
from typing import List, Tuple, Set, Optional

class GitignoreMatcher:
    """多层 .gitignore 解析器

    支持从项目根目录向下遍历收集所有 .gitignore 文件,
    每个 .gitignore 中的规则只对该文件所在目录及其子目录生效
    """

    def __init__(self, root: Path):
        self.root = root.resolve()
        self._specs: List[Tuple[Path, pathspec.PathSpec]] = []
        self._load()

    def _load(self):
        """遍历目录树,收集所有 .gitignore 并构建 pathspec"""
        for dirpath, dirnames, filenames in os.walk(self.root):
            dirpath = Path(dirpath)
            gi_path = dirpath / ".gitignore"
            if gi_path.exists() and gi_path.is_file():
                try:
                    lines = gi_path.read_text(
                        encoding="utf-8", errors="ignore"
                    ).splitlines()
                    spec = pathspec.PathSpec.from_lines("gitignore", lines)
                    self._specs.append((dirpath, spec))
                except Exception:
                    pass
            dirnames[:] = [d for d in dirnames if d not in (".git", "__pycache__")]

    def is_ignored(self, file_path: Path) -> bool:
        """判断文件是否被任一层 .gitignore 忽略"""
        abs_path = file_path.resolve()
        for dirpath, spec in self._specs:
            try:
                rel = abs_path.relative_to(dirpath)
            except ValueError:
                continue
            if spec.match_file(str(rel)):
                return True
        return False

def find_project_root(start_path: Path) -> Optional[Path]:
    """从 start_path 向上查找包含 .gitignore 的目录作为项目根"""
    current = start_path.resolve()
    if start_path.is_file():
        current = current.parent
    while True:
        if (current / ".gitignore").exists():
            return current
        parent = current.parent
        if parent == current:
            return None
        current = parent

def get_gitignore_matcher(root: Path) -> Optional[GitignoreMatcher]:
    """为给定目录创建 GitignoreMatcher"""
    gi_path = find_project_root(root)
    if gi_path is None:
        return None
    return GitignoreMatcher(gi_path)


@@ -33,44 +33,23 @@ def format_size(size_bytes: int) -> str:

 def is_text_file(file_path: str) -> bool:
-    """检查文件是否为文本文件"""
-    text_extensions = {
-        ".py",
-        ".js",
-        ".ts",
-        ".jsx",
-        ".tsx",
-        ".java",
-        ".c",
-        ".cpp",
-        ".h",
-        ".hpp",
-        ".go",
-        ".rs",
-        ".rb",
-        ".php",
-        ".swift",
-        ".kt",
-        ".scala",
-        ".md",
-        ".txt",
-        ".json",
-        ".yaml",
-        ".yml",
-        ".xml",
-        ".html",
-        ".css",
-        ".sql",
-        ".sh",
-        ".bash",
-        ".zsh",
-        ".ps1",
-        ".bat",
-        ".vue",
-        ".svelte",
-    }
-    ext = Path(file_path).suffix.lower()
-    return ext in text_extensions
+    """检查文件是否为文本文件
+
+    通过读取文件前 1KB 内容检测:
+    - 包含 null byte (\\x00) 则判定为二进制
+    - 能够以 UTF-8 解码则判定为文本
+    """
+    try:
+        with open(file_path, "rb") as f:
+            chunk = f.read(1024)
+        if not chunk:
+            return True
+        if b"\x00" in chunk:
+            return False
+        chunk.decode("utf-8")
+        return True
+    except (UnicodeDecodeError, OSError):
+        return False

 def get_project_root() -> Path:


@@ -3,6 +3,7 @@ import os
 from click.testing import CliRunner
 from ocrag.cli import main
 from ocrag.db import RagDB
+from ocrag.commands.add import collect_files
@@ -12,17 +13,15 @@ def runner():
 @pytest.fixture
 def setup_test_env(tmpdir):
-    # Create test files
     os.makedirs(os.path.join(tmpdir, "test_dir"))
     with open(os.path.join(tmpdir, "test_dir", "file1.py"), "w") as f:
         f.write("def test_func():\n pass")
     with open(os.path.join(tmpdir, "test_file.py"), "w") as f:
         f.write("print('Hello')")
-    return tmpdir
+    return str(tmpdir)
 def test_add_command(runner, setup_test_env, tmpdir):
-    # Use temporary DB
     db_path = os.path.join(tmpdir, "test_db.lance")
     os.environ["OCRAG_DB_PATH"] = db_path
@@ -33,7 +32,6 @@ def test_add_command(runner, setup_test_env, tmpdir):
 def test_search_command(runner, setup_test_env, tmpdir):
-    # Use temporary DB and add a file first
     db_path = os.path.join(tmpdir, "test_db.lance")
     os.environ["OCRAG_DB_PATH"] = db_path
@@ -42,14 +40,12 @@ def test_search_command(runner, setup_test_env, tmpdir):
     )
     assert add_result.exit_code == 0
-    # Search
     result = runner.invoke(main, ["search", "Hello"])
     assert "Hello" in result.output
     assert "来源: " in result.output
 def test_list_command(runner, setup_test_env, tmpdir):
-    # Use temporary DB and add a file first
     db_path = os.path.join(tmpdir, "test_db.lance")
     os.environ["OCRAG_DB_PATH"] = db_path
@@ -58,6 +54,104 @@ def test_list_command(runner, setup_test_env, tmpdir):
     )
     assert add_result.exit_code == 0
-    # List
     result = runner.invoke(main, ["list"])
     assert "test_file.py" in result.output
def test_remove_wildcard(runner, setup_test_env, tmpdir):
    db_path = os.path.join(tmpdir, "test_db.lance")
    os.environ["OCRAG_DB_PATH"] = db_path
    runner.invoke(main, ["add", os.path.join(setup_test_env, "test_file.py")])
    runner.invoke(main, ["add", os.path.join(setup_test_env, "test_dir", "file1.py")])
    result = runner.invoke(main, ["remove", "*.py"])
    assert result.exit_code == 0
    assert "已删除" in result.output
    list_result = runner.invoke(main, ["list"])
    assert "test_file.py" not in list_result.output


def test_list_wildcard(runner, setup_test_env, tmpdir):
    db_path = os.path.join(tmpdir, "test_db.lance")
    os.environ["OCRAG_DB_PATH"] = db_path
    runner.invoke(main, ["add", os.path.join(setup_test_env, "test_file.py")])
    runner.invoke(main, ["add", os.path.join(setup_test_env, "test_dir", "file1.py")])
    result = runner.invoke(main, ["list", "*.py"])
    assert result.exit_code == 0
    assert "test_file.py" in result.output
    assert "file1.py" in result.output


class TestCollectFiles:
    def test_exclude_pattern(self, setup_test_env):
        files = collect_files(
            [setup_test_env], recursive=True, exclude_patterns=["**/test_*.py"]
        )
        names = [f.name for f in files]
        assert "test_file.py" not in names
        assert "file1.py" in names

    def test_include_pattern(self, setup_test_env):
        files = collect_files(
            [setup_test_env], recursive=True, include_patterns=["*.py"]
        )
        names = [f.name for f in files]
        assert "test_file.py" in names
        assert "file1.py" in names

    def test_binary_skip(self, tmpdir):
        with open(os.path.join(tmpdir, "binary.bin"), "wb") as f:
            f.write(b"\x00\x01\x02binary")
        with open(os.path.join(tmpdir, "text.py"), "w") as f:
            f.write("print('hello')")
        files = collect_files([str(tmpdir)], recursive=True)
        names = [f.name for f in files]
        assert "binary.bin" not in names
        assert "text.py" in names

    def test_gitignore(self, tmpdir):
        os.makedirs(os.path.join(tmpdir, "src"))
        with open(os.path.join(tmpdir, "src", "main.py"), "w") as f:
            f.write("print('hello')")
        with open(os.path.join(tmpdir, "src", "test_main.py"), "w") as f:
            f.write("def test(): pass")
        with open(os.path.join(tmpdir, ".gitignore"), "w") as f:
            f.write("**/test_*.py\n")
        files = collect_files([os.path.join(tmpdir, "src")], recursive=True)
        names = [f.name for f in files]
        assert "main.py" in names
        assert "test_main.py" not in names

    def test_no_ignore_disables_gitignore(self, setup_test_env):
        with open(os.path.join(setup_test_env, ".gitignore"), "w") as f:
            f.write("test_file.py\n")
        files = collect_files([setup_test_env], recursive=True, use_gitignore=False)
        names = [f.name for f in files]
        assert "test_file.py" in names

    def test_combined_filters(self, tmpdir):
        os.makedirs(os.path.join(tmpdir, "lib"))
        with open(os.path.join(tmpdir, "lib", "main.py"), "w") as f:
            f.write("def main(): pass")
        with open(os.path.join(tmpdir, "lib", "test_main.py"), "w") as f:
            f.write("def test(): pass")
        with open(os.path.join(tmpdir, "lib", "data.bin"), "wb") as f:
            f.write(b"\x00binary")
        files = collect_files(
            [os.path.join(tmpdir, "lib")],
            recursive=True,
            include_patterns=["*.py"],
            exclude_patterns=["**/test_*.py"],
        )
        names = [f.name for f in files]
        assert "main.py" in names
        assert "test_main.py" not in names
        assert "data.bin" not in names
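The tests above pin down the filtering behavior of `collect_files` (include globs, exclude globs, binary skip, gitignore). The helper below is a hypothetical, simplified re-creation of that pipeline for illustration — the real implementation lives in ocrag/commands/add.py, handles gitignore, and the name `collect_files_sketch` is made up:

```python
import fnmatch
import tempfile
from pathlib import Path

def collect_files_sketch(root, include=None, exclude=None):
    # Hypothetical stand-in for ocrag's collect_files: walk the tree,
    # apply include globs (default: keep everything), then exclude globs,
    # then skip binary files via a NUL-byte sniff of the first 1 KB.
    # No gitignore handling here — see the real implementation for that.
    picked = []
    for p in sorted(Path(root).rglob("*")):
        if not p.is_file():
            continue
        rel = p.relative_to(root).as_posix()
        if include and not any(fnmatch.fnmatch(p.name, g) for g in include):
            continue
        if exclude and any(
            fnmatch.fnmatch(rel, g) or fnmatch.fnmatch(p.name, g) for g in exclude
        ):
            continue
        if b"\x00" in p.read_bytes()[:1024]:
            continue  # binary sniff, same idea as is_text_file
        picked.append(p)
    return picked

with tempfile.TemporaryDirectory() as d:
    lib = Path(d) / "lib"
    lib.mkdir()
    (lib / "main.py").write_text("def main(): pass")
    (lib / "test_main.py").write_text("def test(): pass")
    (lib / "data.bin").write_bytes(b"\x00binary")
    names = [p.name for p in collect_files_sketch(d, exclude=["test_*.py"])]
    print(names)  # ['main.py']
```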

tests/unit/test_utils.py Normal file

@@ -0,0 +1,126 @@
import os
import pytest
from pathlib import Path

from ocrag.gitignore import GitignoreMatcher, get_gitignore_matcher, find_project_root
from ocrag.utils import is_text_file


class TestIsTextFile:
    def test_python_file_is_text(self, tmp_path):
        f = tmp_path / "test.py"
        f.write_text("def foo():\n pass")
        assert is_text_file(str(f)) is True

    def test_binary_file_is_not_text(self, tmp_path):
        f = tmp_path / "test.bin"
        f.write_bytes(b"\x00\x01\x02\x03binary")
        assert is_text_file(str(f)) is False

    def test_empty_file_is_text(self, tmp_path):
        f = tmp_path / "empty.txt"
        f.write_text("")
        assert is_text_file(str(f)) is True

    def test_utf8_with_non_ascii_is_text(self, tmp_path):
        f = tmp_path / "unicode.txt"
        f.write_text("中文内容测试\n日本語\n한국어")
        assert is_text_file(str(f)) is True

    def test_invalid_utf8_is_not_text(self, tmp_path):
        f = tmp_path / "invalid.txt"
        f.write_bytes(b"hello\x80world")
        assert is_text_file(str(f)) is False


class TestFindProjectRoot:
    def test_finds_gitignore(self, tmp_path):
        (tmp_path / ".gitignore").write_text("*.pyc")
        sub = tmp_path / "src" / "deep"
        sub.mkdir(parents=True)
        found = find_project_root(sub)
        assert found == tmp_path

    def test_no_gitignore_returns_none(self, tmp_path):
        sub = tmp_path / "src"
        sub.mkdir(parents=True)
        found = find_project_root(sub)
        assert found is None


class TestGitignoreMatcher:
    def test_ignore_pattern(self, tmp_path):
        (tmp_path / ".gitignore").write_text("*.pyc\n__pycache__/\n")
        matcher = GitignoreMatcher(tmp_path)
        assert matcher.is_ignored(tmp_path / "file.pyc") is True
        assert matcher.is_ignored(tmp_path / "file.py") is False
        assert matcher.is_ignored(tmp_path / "__pycache__" / "cache.pyc") is True

    def test_negation_pattern(self, tmp_path):
        (tmp_path / ".gitignore").write_text("*.log\n!important.log\n")
        matcher = GitignoreMatcher(tmp_path)
        assert matcher.is_ignored(tmp_path / "debug.log") is True
        assert matcher.is_ignored(tmp_path / "important.log") is False

    def test_subdirectory_gitignore(self, tmp_path):
        (tmp_path / ".gitignore").write_text("*.py\n")
        sub_dir = tmp_path / "src"
        sub_dir.mkdir()
        (sub_dir / ".gitignore").write_text("test_*.py\n")
        matcher = GitignoreMatcher(tmp_path)
        assert matcher.is_ignored(tmp_path / "main.py") is True
        assert matcher.is_ignored(tmp_path / "utils.py") is True
        assert matcher.is_ignored(sub_dir / "helper.py") is True
        assert matcher.is_ignored(sub_dir / "test_main.py") is True

    def test_anchored_pattern(self, tmp_path):
        (tmp_path / ".gitignore").write_text("/build\n")
        matcher = GitignoreMatcher(tmp_path)
        assert matcher.is_ignored(tmp_path / "build" / "output") is True
        assert matcher.is_ignored(tmp_path / "src" / "build" / "out") is False

    def test_ignore_directory_ending_slash(self, tmp_path):
        (tmp_path / ".gitignore").write_text("node_modules/\n")
        matcher = GitignoreMatcher(tmp_path)
        assert matcher.is_ignored(tmp_path / "node_modules" / "package.json") is True
        assert matcher.is_ignored(tmp_path / "other" / "node_modules" / "file") is True

    def test_get_matcher_returns_none_when_no_gitignore(self, tmp_path):
        result = get_gitignore_matcher(tmp_path / "nonexistent")
        assert result is None

    def test_get_matcher_returns_matcher(self, tmp_path):
        (tmp_path / ".gitignore").write_text("*.pyc")
        result = get_gitignore_matcher(tmp_path)
        assert result is not None
        assert result.is_ignored(tmp_path / "file.pyc") is True

    def test_comments_ignored(self, tmp_path):
        (tmp_path / ".gitignore").write_text(
            "# this is a comment\n*.log\n # another comment\n"
        )
        matcher = GitignoreMatcher(tmp_path)
        assert matcher.is_ignored(tmp_path / "debug.log") is True

    def test_whitespace_lines(self, tmp_path):
        (tmp_path / ".gitignore").write_text(" *.log\n")
        matcher = GitignoreMatcher(tmp_path)
        assert matcher.is_ignored(tmp_path / " debug.log") is True
        assert matcher.is_ignored(tmp_path / "debug.log") is False

    def test_double_star_pattern(self, tmp_path):
        (tmp_path / ".gitignore").write_text("**/*.pyc\n")
        matcher = GitignoreMatcher(tmp_path)
        assert matcher.is_ignored(tmp_path / "a.pyc") is True
        assert matcher.is_ignored(tmp_path / "src" / "a.pyc") is True
        assert matcher.is_ignored(tmp_path / "src" / "deep" / "a.pyc") is True

    def test_no_gitignore_file(self, tmp_path):
        matcher = GitignoreMatcher(tmp_path)
        assert matcher.is_ignored(tmp_path / "file.py") is False
        assert matcher.is_ignored(tmp_path / "anything") is False
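The matcher tests above exercise last-match-wins evaluation with `!` negation, the core of gitignore semantics. A simplified illustration of that rule (made-up helper, not ocrag's `GitignoreMatcher`; no anchoring, directory-only, or `**` handling):

```python
import fnmatch

def gitignore_match(patterns, rel_path):
    # Last matching pattern wins; a leading "!" negates (re-includes).
    # Globs are tried against both the full relative path and the basename.
    ignored = False
    basename = rel_path.rsplit("/", 1)[-1]
    for pat in patterns:
        negated = pat.startswith("!")
        body = pat[1:] if negated else pat
        if fnmatch.fnmatch(rel_path, body) or fnmatch.fnmatch(basename, body):
            ignored = not negated
    return ignored

print(gitignore_match(["*.log", "!important.log"], "debug.log"))      # True
print(gitignore_match(["*.log", "!important.log"], "important.log"))  # False
```

This mirrors the behavior pinned down by `test_negation_pattern`: `debug.log` stays ignored while `important.log` is re-included by the later negation.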