feat(chunker): implement semantics-aware code chunking with tree-sitter

commit fdfe2fa97d (parent 9aee24979f)
.gitignore
@@ -1,3 +1,4 @@
+**/__pycache__/
 models/*
 uv.lock
@@ -63,19 +63,20 @@
 - **Python-first**: native Python SDK with thorough type hints
 - **Strong performance**: built on Apache Arrow, fast queries

-### 2.3 Why langchain-text-splitters (rather than chonkie)
+### 2.3 Why tree-sitter

-The original design considered `chonkie` for syntax-aware chunking, but practice showed:
+Use `tree-sitter` for AST-level, semantics-aware chunking:

 | Option | Pros | Cons |
 |------|------|------|
-| chonkie | syntax-aware chunking | needs an extra GPT-2 tokenizer download, slow to initialize |
-| langchain-text-splitters | no extra models, pure rules, fast | not syntax-aware |
+| tree-sitter | AST-aware, semantically complete, never cuts across classes/functions | depends on per-language parsers |
+| langchain-text-splitters | lightweight, simple rules | chunk_size takes priority, so it cuts through semantic units |

-**Final choice: langchain-text-splitters**:
-- For RAG, chunking precision is not the most critical factor
-- The langchain approach is lighter and starts faster
-- Splitting at natural boundaries (functions, classes) works well enough
+**Final choice: tree-sitter**:
+- Code is chunked with complete semantic units (**classes, functions**) as the smallest blocks
+- Units over the token limit are split recursively at the child-node level
+- Supports 40+ programming languages (via tree-sitter-languages)
+- Parsers are loaded via ctypes under the hood, bypassing language-binding incompatibilities

 ### 2.4 Why Qwen3-Embedding-0.6B

@@ -187,7 +188,7 @@
 | Module | Responsibility | Dependencies |
 |------|------|------|
 | `cli.py` | CLI entry point, argument parsing | click |
-| `chunker.py` | code chunking | langchain-text-splitters |
+| `chunker.py` | code chunking | tree-sitter, tree-sitter-languages |
 | `embedder.py` | text embedding | sentence-transformers |
 | `db.py` | vector-database operations | lancedb, pyarrow |
 | `commands/*.py` | command implementations | cli, chunker, embedder, db |

@@ -199,12 +200,13 @@
 ### 4.1 Code chunking algorithm

-**Implementation**: langchain's `RecursiveCharacterTextSplitter`, with syntax-aware chunking for 29 programming languages.
+**Implementation**: AST-level semantic chunking with tree-sitter, covering 40+ programming languages.

-**Chunking parameters**:
-- `chunk_size=2000`: roughly 2000 characters per chunk
-- `chunk_overlap=200`: 200 characters of overlap between chunks, keeping context continuous
-- `separators`: chosen per language, in decreasing semantic priority (class > function > blank-line paragraph > newline > space > character)
+**Chunking strategy**:
+- **Source files**: parsed into an AST with tree-sitter; complete blocks are extracted at semantic nodes such as `class_definition` and `function_definition`
+- **Files without semantic units** (e.g. Markdown): chunked line by line with tiktoken
+- **Oversized semantic units** (>4000 tokens): split recursively at the child-node level, so every chunk stays under the size limit
+- **Parser loading**: loaded directly from `languages.so` via ctypes, bypassing language-binding incompatibilities

 ### 4.2 Embedding strategy
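The oversized-unit fallback described in 4.1 (recurse into child nodes when a semantic unit is too big) can be sketched language-agnostically. The dict nodes below with `text`/`children` fields are hypothetical stand-ins for tree-sitter AST nodes, not the tree-sitter API, and character length stands in for the token count:

```python
def split_node(node, max_len):
    """Return text chunks for `node`; when the node's own text exceeds
    max_len, recurse into its children (the child-level fallback)."""
    if len(node["text"]) <= max_len:
        return [node["text"]]
    chunks = []
    for child in node.get("children", []):
        chunks.extend(split_node(child, max_len))
    return chunks or [node["text"]]  # leaf still too big: emit as-is

tree = {
    "text": "class A:\n    def f(self): ...\n    def g(self): ...",
    "children": [
        {"text": "def f(self): ..."},
        {"text": "def g(self): ..."},
    ],
}
print(split_node(tree, 20))  # the class is too big, so each method becomes a chunk
```

A small unit comes back whole; only units over the limit are broken apart, which is why chunks stay semantically complete whenever possible.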
README.md (21 lines changed)
@@ -11,7 +11,7 @@
 - 🚀 **Zero config**: embedded database, works out of the box
 - 🔒 **Local-first**: all data stays on disk, nothing is uploaded to external servers
 - 🔍 **Semantic search**: query related code snippets in natural language
-- 📦 **Multi-language**: Python, Rust, C++, Go, JavaScript, TypeScript and 29 languages in total
+- 📦 **Multi-language**: Python, Rust, C++, Go, JavaScript, TypeScript and 40+ languages in total
 - ⚡ **Fast**: search latency < 100 ms
 - 🤖 **AI-friendly**: output format designed for AI consumption

@@ -48,17 +48,34 @@ cp opencode-skill/SKILL.md ~/.config/opencode/skills/ocrag/
 # Add files to the knowledge base
 uv run ocrag add ./src/main.py

-# Recursively add a whole directory
+# Recursively add a whole directory (applies .gitignore + binary-file filtering automatically)
 uv run ocrag add ./src/ --recursive

+# Exclude specific files
+uv run ocrag add ./src/ --recursive --exclude "**/__pycache__" "**/test_*.py"
+
+# Only add certain file types
+uv run ocrag add ./src/ --recursive --include "*.py" "*.rs"
+
+# Disable .gitignore filtering
+uv run ocrag add ./src/ --recursive --no-ignore
+
 # Search related code
 uv run ocrag search "how is user authentication implemented"

 # List the files in the knowledge base
 uv run ocrag list

+# Filter the listing with wildcards
+uv run ocrag list "*.py"    # only Python files
+uv run ocrag list "src/**"  # only files under src/
+
 # Remove files
 uv run ocrag remove ./src/main.py
+
+# Batch-remove with wildcards
+uv run ocrag remove "**/test_*.py"  # remove every test file
+uv run ocrag remove "src/**/*.js"   # remove every JS file under src/
 ```

 ### OpenCode integration
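The `--include`/`--exclude` behavior shown above can be sketched as a filter function. This is a simplified stand-in using stdlib `fnmatch` matching, not the project's actual gitignore-style implementation; `keep` and its parameters are illustrative names:

```python
from fnmatch import fnmatchcase

def keep(rel_path, include=None, exclude=None):
    """Decide whether a file survives --include/--exclude filtering.
    Exclude wins first; then, if include patterns exist, the file
    must match at least one of them."""
    if exclude and any(fnmatchcase(rel_path, pat) for pat in exclude):
        return False
    if include and not any(fnmatchcase(rel_path, pat) for pat in include):
        return False
    return True

print(keep("src/main.py", include=["*.py", "*.rs"]))             # True
print(keep("src/__pycache__/m.pyc", exclude=["*__pycache__*"]))  # False
```

Checking excludes before includes matches the usual CLI expectation: an explicitly excluded file stays out even when an include pattern would also match it.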
opencode-skill/SKILL.md
@@ -23,12 +23,19 @@ description: Codebase RAG skill, used to add code to the knowledge base or search …
 2. Once snippets come back, answer the user based on their content
 3. If the results are irrelevant, say nothing was found and suggest adding more code

+### Managing the knowledge base
+- `rag_list` accepts an optional pattern argument for wildcard filtering, e.g. `"*.py"` or `"src/**"`
+- `rag_remove` supports wildcard patterns for batch deletion, e.g. `"**/test_*.py"` removes every test file
+
 ## Examples
 - User: "remember what's in this file" → call `rag_add` with the current file path
 - User: "how is the auth module implemented?" → call `rag_search` query="auth module implementation"
 - User: "add every Python file under ./src to the knowledge base" → call `rag_add` paths=["./src"] recursive=true
+- User: "how many test files are in the knowledge base?" → call `rag_list` pattern="**/test_*.py"
+- User: "delete all the test files" → call `rag_remove` pattern="**/test_*.py"

 ## Notes
 - Search results show the source file of each snippet
 - The top_k parameter adjusts how many results come back
 - Newly added files are searchable immediately
+- remove and list accept fnmatch wildcards: `*` matches any characters, `**` spans path separators
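The fnmatch wildcard semantics noted above can be checked directly. One nuance worth knowing: in Python's stdlib `fnmatch`, a single `*` already matches across `/`, so `**` adds no special recursion of its own, but a pattern like `**/test_*.py` still requires a `/` before the filename:

```python
from fnmatch import fnmatchcase

print(fnmatchcase("src/utils/test_db.py", "**/test_*.py"))  # True: nested test file
print(fnmatchcase("test_db.py", "**/test_*.py"))            # False: no '/' before test_
print(fnmatchcase("src/a/b.py", "src/**"))                  # True: anything under src/
```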
pyproject.toml
@@ -1,19 +1,29 @@
 [project]
-name = "project"
+name = "ocrag"
 version = "0.1.0"
-description = "Add your description here"
+description = "Local code knowledge base RAG plugin for OpenCode"
 readme = "README.md"
 requires-python = ">=3.12"
 dependencies = [
-    "chonkie>=1.6.2",
+    "tree-sitter>=0.21",
+    "tree-sitter-languages>=1.10",
     "click>=8.3.2",
     "lancedb>=0.30.2",
+    "pathspec>=0.12.0",
     "sentence-transformers>=5.4.1",
     "sentencepiece>=0.2.1",
     "tiktoken>=0.12.0",
     "tokenizers>=0.22.2",
+    "tqdm>=4.66.0",
 ]

+[project.scripts]
+ocrag = "ocrag.cli:main"
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
 [dependency-groups]
 dev = [
     "pytest>=9.0.3",
ocrag/chunker.py
@@ -1,47 +1,478 @@
 """
 Code chunking module.

-Uses langchain_text_splitters for language-aware chunking.
-Syntax-aware chunking for 29 programming languages; tiktoken is bundled, no extra downloads.
+Uses tree-sitter for semantics-aware chunking.
+Source code is split along semantic units such as classes and functions, so every
+chunk is a complete semantic block. Non-code files fall back to simple chunking.
 """

+import ctypes
+import os
+import re
+import warnings
 from pathlib import Path
-from typing import List, Dict, Any
+
+warnings.filterwarnings(
+    "ignore",
+    message="int argument support is deprecated",
+    category=DeprecationWarning,
+)
+from typing import List, Dict, Any, Optional, Tuple

 try:
-    from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
+    import tree_sitter
+    from tree_sitter import Parser, Language, Node
 except ImportError:
-    raise ImportError(
-        "Please install langchain-text-splitters: uv pip install langchain-text-splitters"
-    )
+    raise ImportError("Please install tree-sitter: uv pip install tree-sitter")

 LANGUAGE_MAP = {
-    ".py": Language.PYTHON,
-    ".rs": Language.RUST,
-    ".cpp": Language.CPP,
-    ".cc": Language.CPP,
-    ".cxx": Language.CPP,
-    ".go": Language.GO,
-    ".java": Language.JAVA,
-    ".js": Language.JS,
-    ".ts": Language.TS,
-    ".jsx": Language.JS,
-    ".tsx": Language.TS,
-    ".c": Language.C,
-    ".h": Language.C,
-    ".hpp": Language.CPP,
-    ".rb": Language.RUBY,
-    ".swift": Language.SWIFT,
-    ".kt": Language.KOTLIN,
-    ".md": Language.MARKDOWN,
-    ".txt": None,
+    ".py": "python",
+    ".rs": "rust",
+    ".cpp": "cpp",
+    ".cc": "cpp",
+    ".cxx": "cpp",
+    ".go": "go",
+    ".java": "java",
+    ".js": "javascript",
+    ".ts": "typescript",
+    ".jsx": "javascript",
+    ".tsx": "typescript",
+    ".c": "c",
+    ".h": "cpp",
+    ".hpp": "cpp",
+    ".rb": "ruby",
+    ".swift": "swift",
+    ".kt": "kotlin",
+    ".sh": "bash",
+    ".bash": "bash",
+    ".zsh": "bash",
+    ".lua": "lua",
+    ".php": "php",
+    ".scala": "scala",
+    ".cs": "c_sharp",
+    ".r": "r",
+    ".R": "r",
+    ".sql": "sql",
+    ".toml": "toml",
+    ".yaml": "yaml",
+    ".yml": "yaml",
+    ".json": "json",
+    ".xml": "xml",
+    ".html": "html",
+    ".htm": "html",
+    ".css": "css",
+    ".scss": "css",
+    ".less": "css",
+    ".md": "markdown",
+    ".rst": "rst",
+    ".dockerfile": "dockerfile",
+    ".tf": "hcl",
+    ".hs": "haskell",
+    ".ex": "elixir",
+    ".exs": "elixir",
+    ".erl": "erlang",
+    ".ml": "ocaml",
+    ".mli": "ocaml",
+    ".fs": "fsharp",
+    ".fsx": "fsharp",
 }

+SEMANTIC_UNITS = {
+    "python": ["function_definition", "class_definition", "async_function_definition", "decorated_definition"],
+    "rust": ["function_item", "struct_item", "impl_item", "enum_item", "trait_item", "type_item", "macro_invocation", "macro_rule"],
+    "cpp": ["class_specifier", "struct_specifier", "namespace_definition", "function_definition", "method_declaration", "constructor_initializer", "template_declaration"],
+    "c": ["struct_specifier", "union_specifier", "enum_specifier", "function_definition", "type_definition"],
+    "go": ["function_declaration", "method_declaration", "type_declaration", "const_declaration", "var_declaration"],
+    "java": ["class_declaration", "interface_declaration", "enum_declaration", "record_declaration", "method_declaration"],
+    "javascript": ["class_declaration", "function_declaration", "arrow_function", "method_definition"],
+    "typescript": ["class_declaration", "function_declaration", "arrow_function", "method_definition", "interface_declaration", "type_alias_declaration", "enum_declaration"],
+    "ruby": ["class", "module", "def", "method"],
+    "swift": ["class_declaration", "struct_declaration", "enum_declaration", "protocol_declaration", "function_declaration", "method_declaration"],
+    "kotlin": ["class_declaration", "object_declaration", "function_declaration", "method_declaration"],
+    "scala": ["class_definition", "object_definition", "trait_definition", "function_definition", "method_definition"],
+    "c_sharp": ["class_declaration", "struct_declaration", "interface_declaration", "enum_declaration", "method_declaration"],
+    "go_mod": ["import_declaration"],
+    "bash": ["function_definition", "command"],
+    "lua": ["function_declaration", "local_function", "assignment_statement"],
+    "php": ["class_declaration", "interface_declaration", "trait_declaration", "function_definition", "method_declaration"],
+    "sql": ["create_table_statement", "create_index_statement", "create_view_statement", "create_trigger_statement", "procedure_definition", "function_definition"],
+    "toml": ["table"],
+    "yaml": ["block_mapping"],
+    "json": ["pair", "array"],
+    "html": ["element", "script_element", "style_element"],
+    "css": ["rule_set", "at_rule"],
+    "markdown": ["section", "fenced_code_block"],
+    "rst": ["section", "directive"],
+    "dockerfile": ["from_instruction", "run_instruction", "cmd_instruction"],
+    "hcl": ["block", "attribute"],
+    "haskell": ["type_declaration", "data_declaration", "function_definition", "class_declaration", "instance_declaration"],
+    "elixir": ["module_definition", "function_definition", "def", "defp"],
+    "erlang": ["function_definition", "module_definition"],
+    "ocaml": ["value_definition", "type_definition", "module_definition", "class_definition"],
+    "fsharp": ["class_definition", "module_definition", "function_definition", "type_definition"],
+    "r": ["function_definition", "assignment"],
+}

-def detect_language(file_path: str) -> Language | None:
-    """Detect the file's language."""
-    ext = Path(file_path).suffix.lower()
-    return LANGUAGE_MAP.get(ext)
+MAX_CHUNK_CHARS = 8000
+MIN_CHUNK_CHARS = 100
+
+
+class TiktokenCounter:
+    _instance = None
+
+    def __new__(cls):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+            cls._instance._initialized = False
+        return cls._instance
+
+    def __init__(self):
+        if self._initialized:
+            return
+        import tiktoken
+
+        self._enc = tiktoken.encoding_for_model("gpt-4o")
+        self._initialized = True
+
+    def count(self, text: str) -> int:
+        return len(self._enc.encode(text))
+
+
+_token_counter = TiktokenCounter()
+
+
+def _find_languages_so() -> str:
+    import site
+
+    for path in site.getsitepackages() + [site.getusersitepackages()]:
+        so_path = os.path.join(path, "tree_sitter_languages", "languages.so")
+        if os.path.exists(so_path):
+            return so_path
+    for path in site.getsitepackages():
+        if os.path.exists(os.path.join(path, "tree_sitter_languages")):
+            return os.path.join(path, "tree_sitter_languages", "languages.so")
+    cache_dir = os.path.expanduser("~/.cache/uv/archive-v0")
+    if os.path.exists(cache_dir):
+        for entry in os.listdir(cache_dir):
+            entry_path = os.path.join(
+                cache_dir, entry, "tree_sitter_languages", "languages.so"
+            )
+            if os.path.exists(entry_path):
+                return entry_path
+    raise FileNotFoundError(
+        "Could not find languages.so from tree-sitter-languages."
+        " Make sure tree-sitter-languages is installed."
+    )
+
+
+_lib = None
+
+
+def _get_lib() -> ctypes.CDLL:
+    global _lib
+    if _lib is None:
+        so_path = _find_languages_so()
+        _lib = ctypes.CDLL(so_path)
+    return _lib
+
+
+_parsers_cache: Dict[str, Parser] = {}
+
+
+def get_parser(language_name: str) -> Tuple[Parser, Language]:
+    if language_name in _parsers_cache:
+        return _parsers_cache[language_name], _parsers_cache[f"_lang_{language_name}"]
+
+    lib = _get_lib()
+    symbol_name = f"tree_sitter_{language_name}"
+    if not hasattr(lib, symbol_name):
+        raise ValueError(f"Unsupported language: {language_name}")
+
+    func = getattr(lib, symbol_name)
+    func.argtypes = []
+    func.restype = ctypes.c_void_p
+    ptr = func()
+    lang = Language(ptr)  # type: ignore[call-arg]
+    parser = Parser(lang)
+    _parsers_cache[language_name] = parser
+    _parsers_cache[f"_lang_{language_name}"] = lang
+    return parser, lang
+
+
+def _is_semantic_unit(node: Node, semantic_types: List[str]) -> bool:
+    return node.type in semantic_types
+
+
+def _extract_semantic_chunks(
+    node: Node,
+    semantic_types: List[str],
+    source_bytes: bytes,
+    parent_context: str = "",
+) -> List[Tuple[str, str, int, int]]:
+    chunks = []
+    node_type = node.type
+
+    if _is_semantic_unit(node, semantic_types):
+        text = source_bytes[node.start_byte : node.end_byte].decode(
+            "utf-8", errors="replace"
+        )
+        token_count = _token_counter.count(text)
+
+        if token_count > 4000:
+            children_chunks = _split_large_node(
+                node, semantic_types, source_bytes, parent_context
+            )
+            if children_chunks:
+                chunks.extend(children_chunks)
+            else:
+                chunks.append((node_type, text, node.start_byte, node.end_byte))
+        else:
+            chunks.append((node_type, text, node.start_byte, node.end_byte))
+    else:
+        for child in node.children:
+            chunks.extend(
+                _extract_semantic_chunks(
+                    child, semantic_types, source_bytes, parent_context
+                )
+            )
+
+    return chunks
+
+
+def _split_large_node(
+    node: Node,
+    semantic_types: List[str],
+    source_bytes: bytes,
+    parent_context: str,
+) -> List[Tuple[str, str, int, int]]:
+    children_with_semantic = [
+        (child, child.type in semantic_types)
+        for child in node.children
+        if child.type not in ("comment", "preproc_directive", "attribute")
+    ]
+
+    if not children_with_semantic:
+        return []
+
+    results = []
+    current_text_parts = []
+    current_start = children_with_semantic[0][0].start_byte
+
+    for child, is_semantic in children_with_semantic:
+        if is_semantic:
+            if current_text_parts:
+                combined = source_bytes[
+                    current_start : current_text_parts[0][0].start_byte
+                ].decode("utf-8", errors="replace") + "".join(
+                    source_bytes[p[0].start_byte : p[0].end_byte].decode(
+                        "utf-8", errors="replace"
+                    )
+                    for p in current_text_parts
+                )
+                if _token_counter.count(combined) >= MIN_CHUNK_CHARS // 4:
+                    results.append(
+                        (
+                            "sub_chunk",
+                            combined,
+                            current_start,
+                            current_text_parts[0][0].start_byte,
+                        )
+                    )
+                current_text_parts = []
+            current_start = child.start_byte
+
+            child_text = source_bytes[child.start_byte : child.end_byte].decode(
+                "utf-8", errors="replace"
+            )
+            if _token_counter.count(child_text) <= 4000:
+                results.append(
+                    (child.type, child_text, child.start_byte, child.end_byte)
+                )
+            else:
+                nested = _split_large_node(
+                    child, semantic_types, source_bytes, parent_context
+                )
+                results.extend(nested)
+        else:
+            current_text_parts.append((child, child.type))
+
+    if current_text_parts:
+        combined = source_bytes[
+            current_start : current_text_parts[0][0].start_byte
+        ].decode("utf-8", errors="replace") + "".join(
+            source_bytes[p[0].start_byte : p[0].end_byte].decode(
+                "utf-8", errors="replace"
+            )
+            for p in current_text_parts
+        )
+        if _token_counter.count(combined) >= MIN_CHUNK_CHARS // 4:
+            results.append(
+                (
+                    "sub_chunk",
+                    combined,
+                    current_start,
+                    current_text_parts[0][0].start_byte,
+                )
+            )
+
+    return results
+
+
+def _simple_chunk_text(content: str, chunk_size: int = 512) -> List[str]:
+    import tiktoken
+
+    enc = tiktoken.encoding_for_model("gpt-4o")
+    lines = content.split("\n")
+    chunks = []
+    current = []
+    current_tokens = 0
+
+    for line in lines:
+        line_tokens = len(enc.encode(line))
+        if current_tokens + line_tokens > chunk_size and current:
+            chunks.append("\n".join(current))
+            current = [line]
+            current_tokens = line_tokens
+        else:
+            current.append(line)
+            current_tokens += line_tokens
+
+    if current:
+        chunks.append("\n".join(current))
+    return chunks

@@ -50,41 +481,32 @@ def chunk_code(
     chunk_size: int = 512,
     chunk_overlap: int = 50,
 ) -> List[Dict[str, Any]]:
-    """Chunk a code file and return the list of chunks.
-
-    Uses langchain_text_splitters for language-aware chunking,
-    with syntax-aware support for 29 programming languages.
-
-    Args:
-        content: file contents
-        file_path: file path
-        chunk_size: chunk size (tokens)
-        chunk_overlap: overlap size (tokens)
-
-    Returns:
-        a list of chunks, each with text and metadata
-    """
-    language = detect_language(file_path)
-
-    if language is None:
-        # Unsupported language: fall back to simple chunking
-        splitter = RecursiveCharacterTextSplitter(
-            chunk_size=chunk_size,
-            chunk_overlap=chunk_overlap,
-            separators=["\n\n", "\n", " ", ""],
-            length_function=len,
-        )
-        texts = splitter.split_text(content)
-        language_name = "text"
+    ext = Path(file_path).suffix.lower()
+    lang_name = LANGUAGE_MAP.get(ext)
+
+    if not lang_name:
+        texts = _simple_chunk_text(content, chunk_size)
+        lang_display = "text"
     else:
-        # Language-aware chunking
-        splitter = RecursiveCharacterTextSplitter.from_language(
-            language=language,
-            chunk_size=chunk_size,
-            chunk_overlap=chunk_overlap,
-        )
-        texts = splitter.split_text(content)
-        language_name = language.value
+        semantic_types = SEMANTIC_UNITS.get(lang_name, [])
+        source_bytes = content.encode("utf-8")
+
+        try:
+            parser, _ = get_parser(lang_name)
+        except Exception:
+            texts = _simple_chunk_text(content, chunk_size)
+            lang_display = lang_name
+        else:
+            tree = parser.parse(source_bytes)
+            raw_chunks = _extract_semantic_chunks(
+                tree.root_node, semantic_types, source_bytes
+            )
+
+            if not raw_chunks:
+                texts = _simple_chunk_text(content, chunk_size)
+            else:
+                texts = [text for _, text, _, _ in raw_chunks]
+            lang_display = lang_name

     return [
         {

@@ -92,8 +514,13 @@ def chunk_code(
             "metadata": {
                 "source_file": file_path,
                 "chunk_index": i,
-                "language": language_name,
+                "language": lang_display,
             },
         }
         for i, text in enumerate(texts)
     ]
+
+
+def detect_language(file_path: str) -> Optional[str]:
+    ext = Path(file_path).suffix.lower()
+    return LANGUAGE_MAP.get(ext)
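The line-accumulating fallback (`_simple_chunk_text`) can be exercised without tiktoken by swapping in a pluggable cost function; here character count stands in for the token count (a sketch, not the module's actual code):

```python
def simple_chunk_lines(content, chunk_size=20, cost=len):
    """Greedy line packing: start a new chunk once adding a line would
    push the running cost past chunk_size (mirrors _simple_chunk_text,
    with `cost` standing in for the tiktoken encoder)."""
    chunks, current, current_cost = [], [], 0
    for line in content.split("\n"):
        line_cost = cost(line)
        if current_cost + line_cost > chunk_size and current:
            chunks.append("\n".join(current))
            current, current_cost = [line], line_cost
        else:
            current.append(line)
            current_cost += line_cost
    if current:
        chunks.append("\n".join(current))
    return chunks

doc = "alpha\nbeta\ngamma\ndelta"
print(simple_chunk_lines(doc, chunk_size=10))  # → ['alpha\nbeta', 'gamma\ndelta']
```

Note that lines are never split, only grouped, so joining the chunks back with `"\n"` reproduces the original text exactly.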
ocrag/cli.py
@@ -1,4 +1,5 @@
 import click
+from pathlib import Path
 from ocrag.commands import add as add_cmd
 from ocrag.commands import search as search_cmd
 from ocrag.commands import remove as remove_cmd

@@ -13,11 +14,31 @@ def main():
 @main.command()
-@click.argument("paths", nargs=-1, required=True, type=click.Path(exists=True))
+@click.argument("paths", nargs=-1, required=True)
 @click.option("--recursive", "-r", is_flag=True, help="Recurse into directories")
-def add(paths, recursive):
-    """Add files or directories to the knowledge base"""
-    add_cmd(paths, recursive)
+@click.option(
+    "--include", "-i", multiple=True, help="Only add matching files (glob pattern, repeatable)"
+)
+@click.option(
+    "--exclude", "-e", multiple=True, help="Skip matching files (glob pattern, repeatable)"
+)
+@click.option("--no-ignore", is_flag=True, help="Disable .gitignore filtering")
+def add(paths, recursive, include, exclude, no_ignore):
+    """Add files or directories to the knowledge base.
+
+    By default, .gitignore rules are read and matching files are skipped.
+    Directory traversal also skips .git, __pycache__, node_modules, etc.
+    """
+    for p in paths:
+        if not Path(p).exists():
+            raise click.BadParameter(f"Path does not exist: {p}")
+    add_cmd(
+        paths,
+        recursive,
+        include_patterns=list(include) if include else None,
+        exclude_patterns=list(exclude) if exclude else None,
+        use_gitignore=not no_ignore,
+    )


 @main.command()

@@ -29,16 +50,17 @@ def search(query, top_k):
 @main.command()
-@click.argument("path", type=click.Path())
-def remove(path):
-    """Remove the given file from the knowledge base"""
-    remove_cmd(path)
+@click.argument("pattern")
+def remove(pattern):
+    """Remove every file matching the pattern from the knowledge base (wildcards supported)"""
+    remove_cmd(pattern)


-@main.command()
-def list():
-    """List all entries in the knowledge base"""
-    list_cmd()
+@main.command("list")
+@click.argument("pattern", required=False)
+def list_files(pattern):
+    """List knowledge-base entries (optional wildcard filter)"""
+    list_cmd(pattern)


 if __name__ == "__main__":
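click's `multiple=True` collects a repeated flag into a tuple, which is why `--include`/`--exclude` can be given more than once. The same shape in stdlib `argparse`, shown here purely as a stand-in for the project's click-based CLI (`ocrag-add-demo` is a hypothetical program name):

```python
import argparse

parser = argparse.ArgumentParser(prog="ocrag-add-demo")
parser.add_argument("paths", nargs="+")
parser.add_argument("--include", "-i", action="append")  # repeatable, like click's multiple=True
parser.add_argument("--exclude", "-e", action="append")
parser.add_argument("--no-ignore", action="store_true")  # dest becomes no_ignore

ns = parser.parse_args(["./src", "-i", "*.py", "-i", "*.rs", "--no-ignore"])
print(ns.include)    # → ['*.py', '*.rs']
print(ns.no_ignore)  # → True
```

An unset repeatable option stays `None` (click would give an empty tuple), which is why the real `add` command normalizes with `list(include) if include else None`.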
ocrag/commands/add.py
@@ -1,56 +1,143 @@
 import os
+import pathspec
 from pathlib import Path
+from typing import List
+from tqdm import tqdm
 from ocrag.chunker import chunk_code
 from ocrag.embedder import embedder
 from ocrag.db import RagDB
+from ocrag.utils import is_text_file
+from ocrag.gitignore import get_gitignore_matcher


-def collect_files(paths, recursive):
-    """Collect all files to process"""
+def collect_files(
+    paths,
+    recursive: bool,
+    include_patterns: List[str] = None,
+    exclude_patterns: List[str] = None,
+    use_gitignore: bool = True,
+) -> List[Path]:
+    """Collect all files to process.
+
+    Args:
+        paths: list of file or directory paths
+        recursive: whether to recurse into directories
+        include_patterns: only process files matching these glob patterns
+        exclude_patterns: skip files matching these glob patterns
+        use_gitignore: whether to apply .gitignore rules
+    """
     files = []
+    seen = set()
+
+    exclude_spec = None
+    if exclude_patterns:
+        exclude_spec = pathspec.PathSpec.from_lines("gitignore", exclude_patterns)
+
+    include_spec = None
+    if include_patterns:
+        include_spec = pathspec.PathSpec.from_lines("gitignore", include_patterns)
+
     for p in paths:
-        p = Path(p)
+        p = Path(p).resolve()
         if p.is_file():
             files.append(p)
         elif p.is_dir() and recursive:
-            for root, dirs, filenames in os.walk(p):
-                for f in filenames:
-                    files.append(Path(root) / f)
+            matcher = get_gitignore_matcher(p) if use_gitignore else None
+
+            for root_dir, dirs, filenames in os.walk(p):
+                root_dir = Path(root_dir)
+
+                dirs[:] = [
+                    d
+                    for d in dirs
+                    if d
+                    not in (
+                        ".git",
+                        "__pycache__",
+                        ".pytest_cache",
+                        ".mypy_cache",
+                        "node_modules",
+                        ".venv",
+                        "venv",
+                        ".env",
+                        ".tox",
+                        ".eggs",
+                    )
+                    and not d.endswith(".egg-info")
+                ]
+
+                for filename in filenames:
+                    file_path = root_dir / filename
+
+                    if file_path in seen:
+                        continue
+                    seen.add(file_path)
+
+                    rel = file_path.relative_to(p)
+
+                    if not is_text_file(str(file_path)):
+                        continue
+
+                    if matcher and matcher.is_ignored(file_path):
+                        continue
+
+                    if exclude_spec and exclude_spec.match_file(str(rel)):
+                        continue
+
+                    if include_spec and not include_spec.match_file(str(rel)):
+                        continue
+
+                    files.append(file_path)
         elif p.is_dir() and not recursive:
             print(f"Skipping directory {p} (--recursive not given)")

     return files


-def run(paths, recursive):
+def run(
+    paths,
+    recursive: bool,
+    include_patterns: List[str] = None,
+    exclude_patterns: List[str] = None,
+    use_gitignore: bool = True,
+):
     db = RagDB()
-    files = collect_files(paths, recursive)
+    files = collect_files(
+        paths, recursive, include_patterns, exclude_patterns, use_gitignore
+    )
+
+    if not files:
+        print("No files to process")
+        return

     total_chunks = 0
-    for file_path in files:
-        try:
-            content = file_path.read_text(encoding="utf-8")
-            chunks = chunk_code(content, str(file_path))
-            if not chunks:
-                continue
-
-            # batch embedding
-            texts = [c["text"] for c in chunks]
-            vectors = embedder.embed(texts)
-
-            documents = []
-            for chunk, vec in zip(chunks, vectors):
-                documents.append(
-                    {
-                        "text": chunk["text"],
-                        "vector": vec,
-                        "metadata": chunk["metadata"],
-                    }
-                )
-
-            db.add_documents(documents)
-            total_chunks += len(documents)
-            print(f"✅ {file_path} -> {len(documents)} chunks")
-        except Exception as e:
-            print(f"❌ Failed to process {file_path}: {e}")
+    with tqdm(total=len(files), desc="Processing files", unit="file") as pbar:
+        for file_path in files:
+            try:
+                content = file_path.read_text(encoding="utf-8")
+                chunks = chunk_code(content, str(file_path))
+                if not chunks:
+                    pbar.update(1)
+                    continue
+
+                texts = [c["text"] for c in chunks]
+                vectors = embedder.embed(texts)
+
+                documents = []
+                for chunk, vec in zip(chunks, vectors):
+                    documents.append(
+                        {
+                            "text": chunk["text"],
+                            "vector": vec,
+                            "metadata": chunk["metadata"],
+                        }
+                    )
+
+                db.add_documents(documents)
+                total_chunks += len(documents)
+                pbar.set_postfix({"chunks": f"{total_chunks}"})
+            except Exception as e:
+                print(f"\n❌ Failed to process {file_path}: {e}")
+            pbar.update(1)

     print(f"\n📦 Added {total_chunks} chunks in total")
@@ -1,14 +1,15 @@
 from ocrag.db import RagDB


-def run():
+def run(pattern: str = None):
     db = RagDB()
-    sources = db.list_sources()
+    sources = db.list_sources(pattern)

     if not sources:
         print("知识库为空")
         return

-    print("知识库中的文件:")
+    label = f'匹配 "{pattern}" 的' if pattern else ""
+    print(f"{label}知识库中的文件:")
     for i, source in enumerate(sources, 1):
         print(f"{i}. {source}")
@@ -1,9 +1,8 @@
+import fnmatch
 from ocrag.db import RagDB


-def run(path: str):
+def run(pattern: str):
     db = RagDB()
-    # 注意:当前db.py中的delete_by_source方法尚未完全实现
-    # 这里先调用以保持接口一致,后续需要完善db.py的实现
-    db.delete_by_source(path)
-    print(f"已删除 {path} 的所有块")
+    count = db.delete_by_pattern(pattern)
+    print(f'已删除 {count} 个匹配 "{pattern}" 的块')
src/ocrag/db.py
@@ -8,43 +8,51 @@ DB_PATH = Path.home() / ".ocrag" / "data.lance"

 class RagDB:
     def __init__(self, db_path: str = None):
-        self.path = db_path or str(DB_PATH)
+        import os as _os
+
+        self.path = db_path or _os.environ.get("OCRAG_DB_PATH", str(DB_PATH))
         self.conn = lancedb.connect(self.path)
         self._init_table()

     def _init_table(self):
         """如果表不存在则创建"""
-        schema = pa.schema(
-            [
-                ("text", pa.string()),
-                (
-                    "vector",
-                    pa.list_(pa.float32(), 1024),
-                ),  # 1024维 (Qwen3-Embedding-0.6B)
-                ("metadata", pa.string()),  # JSON 字符串
-            ]
-        )
         try:
             self.table = self.conn.open_table("documents")
         except:
-            self.table = self.conn.create_table("documents", schema=schema)
+            empty = pa.table(
+                {
+                    "text": pa.array([], type=pa.string()),
+                    "vector": pa.array([], type=pa.list_(pa.float32(), 1024)),
+                    "metadata": pa.array([], type=pa.string()),
+                }
+            )
+            self.table = self.conn.create_table("documents", empty)

     def add_documents(self, documents: List[Dict[str, Any]]):
         """批量添加文档

         documents: [{"text": str, "vector": List[float], "metadata": dict}, ...]
         """
         import json
+        import numpy as np

-        data = []
+        texts = []
+        vectors = []
+        metadatas = []
         for doc in documents:
-            data.append(
-                {
-                    "text": doc["text"],
-                    "vector": doc["vector"],
-                    "metadata": json.dumps(doc.get("metadata", {})),
-                }
-            )
-        self.table.add(data)
+            texts.append(doc["text"])
+            vec = doc["vector"]
+            if isinstance(vec, list):
+                vec = np.array(vec, dtype=np.float32)
+            vectors.append(vec)
+            metadatas.append(json.dumps(doc.get("metadata", {})))
+
+        arrays = [
+            pa.array(texts),
+            pa.array(vectors, type=pa.list_(pa.float32(), 1024)),
+            pa.array(metadatas),
+        ]
+        batch = pa.record_batch(arrays, names=["text", "vector", "metadata"])
+        self.table.add(batch)

     def search(self, query_vector: List[float], top_k: int = 5) -> List[Dict[str, Any]]:
         import json

@@ -96,9 +104,50 @@ class RagDB:

         return num_deleted

-    def list_sources(self) -> List[str]:
-        """列出所有已添加的源文件路径(去重)"""
-        # 获取所有 metadata,提取 source_file
+    def delete_by_pattern(self, pattern: str) -> int:
+        """删除匹配通配符模式的所有源文件的所有块
+
+        Args:
+            pattern: fnmatch 模式,如 "*.py", "src/**/*.js", "**/test_*.py"
+
+        Returns:
+            删除的块数量
+        """
+        import fnmatch
+        import json
+
+        df = self.table.to_pandas()
+        if df.empty:
+            return 0
+
+        indices_to_keep = []
+        for idx, row in df.iterrows():
+            try:
+                meta = json.loads(row["metadata"])
+                source = meta.get("source_file", "")
+                if not fnmatch.fnmatch(source, pattern):
+                    indices_to_keep.append(idx)
+            except (json.JSONDecodeError, KeyError):
+                indices_to_keep.append(idx)
+
+        num_deleted = len(df) - len(indices_to_keep)
+        if num_deleted == 0:
+            return 0
+
+        df_remaining = df.loc[indices_to_keep]
+        self.conn.drop_table("documents")
+        self.table = self.conn.create_table("documents", df_remaining)
+
+        return num_deleted
+
+    def list_sources(self, pattern: str = None) -> List[str]:
+        """列出所有已添加的源文件路径(去重)
+
+        Args:
+            pattern: 可选的 fnmatch 模式,筛选匹配的文件
+        """
+        import fnmatch
+
         df = self.table.to_pandas()
         if df.empty:
             return []

@@ -108,5 +157,6 @@ class RagDB:
         for meta_str in df["metadata"]:
             meta = json.loads(meta_str)
             if "source_file" in meta:
-                sources.add(meta["source_file"])
+                if pattern is None or fnmatch.fnmatch(meta["source_file"], pattern):
+                    sources.add(meta["source_file"])
         return sorted(sources)
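Both `delete_by_pattern` and `list_sources` above match stored `source_file` paths with `fnmatch`. One subtlety worth keeping in mind: unlike shell globbing, `fnmatch`'s `*` also matches `/`, so `*.py` already matches files in subdirectories, while `**` has no special meaning beyond two consecutive stars. A minimal stdlib sketch (sample paths are hypothetical):

```python
import fnmatch

def filter_sources(sources, pattern):
    # fnmatch translates the pattern to a regex; "*" becomes ".*",
    # which crosses "/" boundaries, unlike pathlib/glob semantics.
    return [s for s in sources if fnmatch.fnmatch(s, pattern)]

sources = ["src/db.py", "src/utils.py", "tests/test_cli.py", "README.md"]

print(filter_sources(sources, "*.py"))
# "**/test_*.py" still requires a literal "/" before "test_", so it only
# matches sources that live inside some directory.
print(filter_sources(sources, "**/test_*.py"))
```

This is why a top-level `remove "*.py"` wipes `.py` chunks from every directory, not just the root.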
@@ -0,0 +1,73 @@
+import os
+import fnmatch
+import pathspec
+from pathlib import Path
+from typing import List, Tuple, Set, Optional
+
+
+class GitignoreMatcher:
+    """多层 .gitignore 解析器
+
+    支持从项目根目录向下遍历,收集所有 .gitignore 文件,
+    每个 .gitignore 中的规则只对该文件所在目录及其子目录生效。
+    """
+
+    def __init__(self, root: Path):
+        self.root = root.resolve()
+        self._specs: List[Tuple[Path, pathspec.PathSpec]] = []
+        self._load()
+
+    def _load(self):
+        """遍历目录树,收集所有 .gitignore 并构建 pathspec"""
+        for dirpath, dirnames, filenames in os.walk(self.root):
+            dirpath = Path(dirpath)
+            gi_path = dirpath / ".gitignore"
+            if gi_path.exists() and gi_path.is_file():
+                try:
+                    lines = gi_path.read_text(
+                        encoding="utf-8", errors="ignore"
+                    ).splitlines()
+                    spec = pathspec.PathSpec.from_lines("gitignore", lines)
+                    self._specs.append((dirpath, spec))
+                except Exception:
+                    pass
+
+            dirnames[:] = [d for d in dirnames if d not in (".git", "__pycache__")]
+
+    def is_ignored(self, file_path: Path) -> bool:
+        """判断文件是否被任一层 .gitignore 忽略"""
+        abs_path = file_path.resolve()
+
+        for dirpath, spec in self._specs:
+            try:
+                rel = abs_path.relative_to(dirpath)
+            except ValueError:
+                continue
+
+            if spec.match_file(str(rel)):
+                return True
+
+        return False
+
+
+def find_project_root(start_path: Path) -> Optional[Path]:
+    """从 start_path 向上查找包含 .gitignore 的目录作为项目根"""
+    current = start_path.resolve()
+    if start_path.is_file():
+        current = current.parent
+
+    while True:
+        if (current / ".gitignore").exists():
+            return current
+        parent = current.parent
+        if parent == current:
+            return None
+        current = parent
+
+
+def get_gitignore_matcher(root: Path) -> Optional[GitignoreMatcher]:
+    """为给定目录创建 GitignoreMatcher"""
+    gi_path = find_project_root(root)
+    if gi_path is None:
+        return None
+    return GitignoreMatcher(gi_path)
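`get_gitignore_matcher` anchors matching at the project root found by `find_project_root`, which walks upward until it sees a `.gitignore`. A minimal stdlib sketch of that upward search (the configurable `marker` parameter and the demo layout are illustrative additions, not part of the module above):

```python
import tempfile
from pathlib import Path

def find_marker_root(start, marker=".gitignore"):
    """Walk upward from `start` until a directory containing `marker`
    is found; return None once the filesystem root is reached.
    Mirrors the upward search that find_project_root performs."""
    current = Path(start).resolve()
    if current.is_file():
        current = current.parent
    while True:
        if (current / marker).exists():
            return current
        if current.parent == current:  # reached the filesystem root
            return None
        current = current.parent

# Demo in a throwaway directory (layout is hypothetical).
root = Path(tempfile.mkdtemp()).resolve()
(root / ".gitignore").write_text("*.pyc\n")
deep = root / "src" / "pkg"
deep.mkdir(parents=True)
print(find_marker_root(deep) == root)  # True
```

Anchoring at the root rather than at each scanned path is what lets a single top-level `.gitignore` govern deeply nested files.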
@@ -33,44 +33,23 @@ def format_size(size_bytes: int) -> str:


 def is_text_file(file_path: str) -> bool:
-    """检查文件是否为文本文件"""
-    text_extensions = {
-        ".py",
-        ".js",
-        ".ts",
-        ".jsx",
-        ".tsx",
-        ".java",
-        ".c",
-        ".cpp",
-        ".h",
-        ".hpp",
-        ".go",
-        ".rs",
-        ".rb",
-        ".php",
-        ".swift",
-        ".kt",
-        ".scala",
-        ".md",
-        ".txt",
-        ".json",
-        ".yaml",
-        ".yml",
-        ".xml",
-        ".html",
-        ".css",
-        ".sql",
-        ".sh",
-        ".bash",
-        ".zsh",
-        ".ps1",
-        ".bat",
-        ".vue",
-        ".svelte",
-    }
-    ext = Path(file_path).suffix.lower()
-    return ext in text_extensions
+    """检查文件是否为文本文件
+
+    通过读取文件前 1KB 内容检测:
+    - 包含 null byte (\\x00) 则判定为二进制
+    - 能够以 UTF-8 解码则判定为文本
+    """
+    try:
+        with open(file_path, "rb") as f:
+            chunk = f.read(1024)
+        if not chunk:
+            return True
+        if b"\x00" in chunk:
+            return False
+        chunk.decode("utf-8")
+        return True
+    except (UnicodeDecodeError, OSError):
+        return False


 def get_project_root() -> Path:
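The rewritten `is_text_file` replaces the old extension whitelist with a content sniff on the first 1 KB. The same heuristic applied to a raw byte sample, as a standalone sketch (`looks_like_text` is an illustrative name, not a function from the codebase):

```python
def looks_like_text(sample: bytes) -> bool:
    """Content-sniffing heuristic: empty samples and UTF-8-decodable
    bytes count as text; a null byte marks the data as binary."""
    if not sample:
        return True          # empty file: treat as text
    if b"\x00" in sample:
        return False         # null byte: almost certainly binary
    try:
        sample.decode("utf-8")
        return True
    except UnicodeDecodeError:
        return False

print(looks_like_text(b"print('hello')"))       # True
print(looks_like_text(b"\x00\x01\x02"))          # False
print(looks_like_text("中文内容".encode("utf-8")))  # True
```

One known edge case of a fixed-size prefix: the 1 KB cut can split a multi-byte UTF-8 sequence at the boundary, making a legitimate text file decode-fail and be misclassified as binary.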
@@ -3,6 +3,7 @@ import os
 from click.testing import CliRunner
 from ocrag.cli import main
 from ocrag.db import RagDB
+from ocrag.commands.add import collect_files


 @pytest.fixture

@@ -12,17 +13,15 @@ def runner():

 @pytest.fixture
 def setup_test_env(tmpdir):
-    # Create test files
     os.makedirs(os.path.join(tmpdir, "test_dir"))
     with open(os.path.join(tmpdir, "test_dir", "file1.py"), "w") as f:
         f.write("def test_func():\n    pass")
     with open(os.path.join(tmpdir, "test_file.py"), "w") as f:
         f.write("print('Hello')")
-    return tmpdir
+    return str(tmpdir)


 def test_add_command(runner, setup_test_env, tmpdir):
-    # Use temporary DB
     db_path = os.path.join(tmpdir, "test_db.lance")
     os.environ["OCRAG_DB_PATH"] = db_path

@@ -33,7 +32,6 @@ def test_add_command(runner, setup_test_env, tmpdir):


 def test_search_command(runner, setup_test_env, tmpdir):
-    # Use temporary DB and add a file first
     db_path = os.path.join(tmpdir, "test_db.lance")
     os.environ["OCRAG_DB_PATH"] = db_path

@@ -42,14 +40,12 @@ def test_search_command(runner, setup_test_env, tmpdir):
     )
     assert add_result.exit_code == 0

-    # Search
     result = runner.invoke(main, ["search", "Hello"])
     assert "Hello" in result.output
     assert "来源: " in result.output


 def test_list_command(runner, setup_test_env, tmpdir):
-    # Use temporary DB and add a file first
     db_path = os.path.join(tmpdir, "test_db.lance")
     os.environ["OCRAG_DB_PATH"] = db_path

@@ -58,6 +54,104 @@ def test_list_command(runner, setup_test_env, tmpdir):
     )
     assert add_result.exit_code == 0

-    # List
     result = runner.invoke(main, ["list"])
     assert "test_file.py" in result.output
+
+
+def test_remove_wildcard(runner, setup_test_env, tmpdir):
+    db_path = os.path.join(tmpdir, "test_db.lance")
+    os.environ["OCRAG_DB_PATH"] = db_path
+
+    runner.invoke(main, ["add", os.path.join(setup_test_env, "test_file.py")])
+    runner.invoke(main, ["add", os.path.join(setup_test_env, "test_dir", "file1.py")])
+
+    result = runner.invoke(main, ["remove", "*.py"])
+    assert result.exit_code == 0
+    assert "已删除" in result.output
+
+    list_result = runner.invoke(main, ["list"])
+    assert "test_file.py" not in list_result.output
+
+
+def test_list_wildcard(runner, setup_test_env, tmpdir):
+    db_path = os.path.join(tmpdir, "test_db.lance")
+    os.environ["OCRAG_DB_PATH"] = db_path
+
+    runner.invoke(main, ["add", os.path.join(setup_test_env, "test_file.py")])
+    runner.invoke(main, ["add", os.path.join(setup_test_env, "test_dir", "file1.py")])
+
+    result = runner.invoke(main, ["list", "*.py"])
+    assert result.exit_code == 0
+    assert "test_file.py" in result.output
+    assert "file1.py" in result.output
+
+
+class TestCollectFiles:
+    def test_exclude_pattern(self, setup_test_env):
+        files = collect_files(
+            [setup_test_env], recursive=True, exclude_patterns=["**/test_*.py"]
+        )
+        names = [f.name for f in files]
+        assert "test_file.py" not in names
+        assert "file1.py" in names
+
+    def test_include_pattern(self, setup_test_env):
+        files = collect_files(
+            [setup_test_env], recursive=True, include_patterns=["*.py"]
+        )
+        names = [f.name for f in files]
+        assert "test_file.py" in names
+        assert "file1.py" in names
+
+    def test_binary_skip(self, tmpdir):
+        with open(os.path.join(tmpdir, "binary.bin"), "wb") as f:
+            f.write(b"\x00\x01\x02binary")
+        with open(os.path.join(tmpdir, "text.py"), "w") as f:
+            f.write("print('hello')")
+
+        files = collect_files([str(tmpdir)], recursive=True)
+        names = [f.name for f in files]
+        assert "binary.bin" not in names
+        assert "text.py" in names
+
+    def test_gitignore(self, tmpdir):
+        os.makedirs(os.path.join(tmpdir, "src"))
+        with open(os.path.join(tmpdir, "src", "main.py"), "w") as f:
+            f.write("print('hello')")
+        with open(os.path.join(tmpdir, "src", "test_main.py"), "w") as f:
+            f.write("def test(): pass")
+        with open(os.path.join(tmpdir, ".gitignore"), "w") as f:
+            f.write("**/test_*.py\n")
+
+        files = collect_files([os.path.join(tmpdir, "src")], recursive=True)
+        names = [f.name for f in files]
+        assert "main.py" in names
+        assert "test_main.py" not in names
+
+    def test_no_ignore_disables_gitignore(self, setup_test_env):
+        with open(os.path.join(setup_test_env, ".gitignore"), "w") as f:
+            f.write("test_file.py\n")
+
+        files = collect_files([setup_test_env], recursive=True, use_gitignore=False)
+        names = [f.name for f in files]
+        assert "test_file.py" in names
+
+    def test_combined_filters(self, tmpdir):
+        os.makedirs(os.path.join(tmpdir, "lib"))
+        with open(os.path.join(tmpdir, "lib", "main.py"), "w") as f:
+            f.write("def main(): pass")
+        with open(os.path.join(tmpdir, "lib", "test_main.py"), "w") as f:
+            f.write("def test(): pass")
+        with open(os.path.join(tmpdir, "lib", "data.bin"), "wb") as f:
+            f.write(b"\x00binary")
+
+        files = collect_files(
+            [os.path.join(tmpdir, "lib")],
+            recursive=True,
+            include_patterns=["*.py"],
+            exclude_patterns=["**/test_*.py"],
+        )
+        names = [f.name for f in files]
+        assert "main.py" in names
+        assert "test_main.py" not in names
+        assert "data.bin" not in names
@@ -0,0 +1,126 @@
+import os
+import pytest
+from pathlib import Path
+from ocrag.gitignore import GitignoreMatcher, get_gitignore_matcher, find_project_root
+from ocrag.utils import is_text_file
+
+
+class TestIsTextFile:
+    def test_python_file_is_text(self, tmp_path):
+        f = tmp_path / "test.py"
+        f.write_text("def foo():\n    pass")
+        assert is_text_file(str(f)) is True
+
+    def test_binary_file_is_not_text(self, tmp_path):
+        f = tmp_path / "test.bin"
+        f.write_bytes(b"\x00\x01\x02\x03binary")
+        assert is_text_file(str(f)) is False
+
+    def test_empty_file_is_text(self, tmp_path):
+        f = tmp_path / "empty.txt"
+        f.write_text("")
+        assert is_text_file(str(f)) is True
+
+    def test_utf8_with_non_ascii_is_text(self, tmp_path):
+        f = tmp_path / "unicode.txt"
+        f.write_text("中文内容测试\n日本語\n한국어")
+        assert is_text_file(str(f)) is True
+
+    def test_invalid_utf8_is_not_text(self, tmp_path):
+        f = tmp_path / "invalid.txt"
+        f.write_bytes(b"hello\x80world")
+        assert is_text_file(str(f)) is False
+
+
+class TestFindProjectRoot:
+    def test_finds_gitignore(self, tmp_path):
+        (tmp_path / ".gitignore").write_text("*.pyc")
+        sub = tmp_path / "src" / "deep"
+        sub.mkdir(parents=True)
+        found = find_project_root(sub)
+        assert found == tmp_path
+
+    def test_no_gitignore_returns_none(self, tmp_path):
+        sub = tmp_path / "src"
+        sub.mkdir(parents=True)
+        found = find_project_root(sub)
+        assert found is None
+
+
+class TestGitignoreMatcher:
+    def test_ignore_pattern(self, tmp_path):
+        (tmp_path / ".gitignore").write_text("*.pyc\n__pycache__/\n")
+        matcher = GitignoreMatcher(tmp_path)
+
+        assert matcher.is_ignored(tmp_path / "file.pyc") is True
+        assert matcher.is_ignored(tmp_path / "file.py") is False
+        assert matcher.is_ignored(tmp_path / "__pycache__" / "cache.pyc") is True
+
+    def test_negation_pattern(self, tmp_path):
+        (tmp_path / ".gitignore").write_text("*.log\n!important.log\n")
+        matcher = GitignoreMatcher(tmp_path)
+
+        assert matcher.is_ignored(tmp_path / "debug.log") is True
+        assert matcher.is_ignored(tmp_path / "important.log") is False
+
+    def test_subdirectory_gitignore(self, tmp_path):
+        (tmp_path / ".gitignore").write_text("*.py\n")
+        sub_dir = tmp_path / "src"
+        sub_dir.mkdir()
+        (sub_dir / ".gitignore").write_text("test_*.py\n")
+        matcher = GitignoreMatcher(tmp_path)
+
+        assert matcher.is_ignored(tmp_path / "main.py") is True
+        assert matcher.is_ignored(tmp_path / "utils.py") is True
+        assert matcher.is_ignored(sub_dir / "helper.py") is True
+        assert matcher.is_ignored(sub_dir / "test_main.py") is True
+
+    def test_anchored_pattern(self, tmp_path):
+        (tmp_path / ".gitignore").write_text("/build\n")
+        matcher = GitignoreMatcher(tmp_path)
+
+        assert matcher.is_ignored(tmp_path / "build" / "output") is True
+        assert matcher.is_ignored(tmp_path / "src" / "build" / "out") is False
+
+    def test_ignore_directory_ending_slash(self, tmp_path):
+        (tmp_path / ".gitignore").write_text("node_modules/\n")
+        matcher = GitignoreMatcher(tmp_path)
+
+        assert matcher.is_ignored(tmp_path / "node_modules" / "package.json") is True
+        assert matcher.is_ignored(tmp_path / "other" / "node_modules" / "file") is True
+
+    def test_get_matcher_returns_none_when_no_gitignore(self, tmp_path):
+        result = get_gitignore_matcher(tmp_path / "nonexistent")
+        assert result is None
+
+    def test_get_matcher_returns_matcher(self, tmp_path):
+        (tmp_path / ".gitignore").write_text("*.pyc")
+        result = get_gitignore_matcher(tmp_path)
+        assert result is not None
+        assert result.is_ignored(tmp_path / "file.pyc") is True
+
+    def test_comments_ignored(self, tmp_path):
+        (tmp_path / ".gitignore").write_text(
+            "# this is a comment\n*.log\n # another comment\n"
+        )
+        matcher = GitignoreMatcher(tmp_path)
+        assert matcher.is_ignored(tmp_path / "debug.log") is True
+
+    def test_whitespace_lines(self, tmp_path):
+        (tmp_path / ".gitignore").write_text(" *.log\n")
+        matcher = GitignoreMatcher(tmp_path)
+        assert matcher.is_ignored(tmp_path / " debug.log") is True
+        assert matcher.is_ignored(tmp_path / "debug.log") is False
+
+    def test_double_star_pattern(self, tmp_path):
+        (tmp_path / ".gitignore").write_text("**/*.pyc\n")
+        matcher = GitignoreMatcher(tmp_path)
+
+        assert matcher.is_ignored(tmp_path / "a.pyc") is True
+        assert matcher.is_ignored(tmp_path / "src" / "a.pyc") is True
+        assert matcher.is_ignored(tmp_path / "src" / "deep" / "a.pyc") is True
+
+    def test_no_gitignore_file(self, tmp_path):
+        matcher = GitignoreMatcher(tmp_path)
+        assert matcher.is_ignored(tmp_path / "file.py") is False
+        assert matcher.is_ignored(tmp_path / "anything") is False