298 lines
8.5 KiB
Python
298 lines
8.5 KiB
Python
from typing import Tuple, Dict, List, Optional, Any
|
||
import os
|
||
from pathlib import Path
|
||
import queue
|
||
from collections import deque
|
||
from loguru import logger # 添加导入
|
||
from rich.progress import Progress, TextColumn, BarColumn, TimeElapsedColumn, TaskProgressColumn
|
||
|
||
# Keywords that mark a shell command as dangerous (configurable).
DANGEROUS_COMMANDS = [
    "rm",
    "sudo",
    "chmod",
    "dd",
    "mkfs",
    "> /dev/sda",
    "format",
]

# Optional whitelist; the original intent is that when it is empty only the
# blacklist above is checked.
# NOTE(review): is_dangerous_command does not read this list — confirm whether
# anything else in the project does.
ALLOWED_COMMANDS = []
|
||
|
||
|
||
def is_dangerous_command(cmd: str, dangerous_commands: Optional[List[str]] = None) -> Tuple[bool, str]:
    """
    Decide whether a command string contains a dangerous keyword.

    Matching is case-insensitive. A keyword that begins with a word character
    (letter, digit or underscore) only matches at the start of a word. So
    ``rm`` still matches ``rm -rf``, ``/bin/rm`` and ``rmdir``, but no longer
    matches inside ``transformers`` or ``termcolor``. Likewise, ``dd`` no
    longer matches ``git add`` and ``sudo`` no longer matches ``pseudo``.
    Keywords that begin with punctuation, such as ``"> /dev/sda"``, are
    matched as plain substrings.

    Args:
        cmd: The command string to inspect.
        dangerous_commands: Keywords to look for. When ``None`` (the default),
            the module-level ``DANGEROUS_COMMANDS`` list is used. Empty
            keywords are ignored.

    Returns:
        Tuple[bool, str]: ``(True, reason)`` for the first matching keyword,
        otherwise ``(False, "")``.
    """
    import re

    keywords = DANGEROUS_COMMANDS if dangerous_commands is None else dangerous_commands
    cmd_lower = cmd.lower()
    for danger in keywords:
        danger_lower = danger.lower()
        if not danger_lower:
            # An empty keyword would otherwise match every command.
            continue
        pattern = re.escape(danger_lower)
        if re.match(r"\w", danger_lower):
            # Anchor to a word start so the keyword is not found mid-word
            # (e.g. "rm" inside "transformers"), while prefixes like
            # "rmdir" or "/bin/rm" are still caught.
            pattern = r"(?<!\w)" + pattern
        if re.search(pattern, cmd_lower):
            return True, f"包含危险关键词 '{danger}'"
    return False, ""
|
||
|
||
|
||
def read_file(file_path: str) -> str:
    """
    Read an entire text file, decoded as UTF-8.

    Args:
        file_path: Path of the file to read.

    Returns:
        str: The file's contents.

    Raises:
        IOError: If opening, reading or decoding the file fails. The original
            exception is attached as ``__cause__``.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return f.read()
    except Exception as e:
        # Every failure (including UnicodeDecodeError) is surfaced to callers as
        # IOError; chain the original so its type and traceback are kept.
        raise IOError(f"读取文件失败: {file_path}, 错误: {e}") from e
|
||
|
||
|
||
def write_file(file_path: str, content: str) -> None:
    """
    Write text to a file as UTF-8, creating missing parent directories first.

    Any existing file at *file_path* is overwritten.

    Args:
        file_path: Destination path.
        content: Text to write.

    Raises:
        IOError: If creating the parent directories or writing the file fails.
            The original exception is attached as ``__cause__``.
    """
    try:
        path = Path(file_path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, 'w', encoding='utf-8') as f:
            f.write(content)
    except Exception as e:
        # Keep the IOError contract but chain the underlying error.
        raise IOError(f"写入文件失败: {file_path}, 错误: {e}") from e
|
||
|
||
|
||
def ensure_dir(directory: str) -> None:
    """
    Create *directory*, including any missing parents, unless it already exists.

    Args:
        directory: Path of the directory that must exist afterwards.
    """
    # exist_ok=True makes repeated calls a no-op for an existing directory.
    os.makedirs(name=directory, mode=0o777, exist_ok=True)
|
||
|
||
|
||
def safe_join(base_path: str, *paths: str) -> str:
    """
    Join path components onto a base directory, rejecting directory traversal.

    The joined path is normalised with ``os.path.abspath`` and must be the base
    directory itself or lie inside it. Containment is checked per path
    component, so a sibling such as ``/tmp/foobar`` is not accepted for base
    ``/tmp/foo``. Symlinks are not resolved.

    Args:
        base_path: The base directory.
        *paths: Components to append to the base.

    Returns:
        str: The joined absolute path.

    Raises:
        ValueError: If the joined path escapes *base_path*.
    """
    full_path = os.path.abspath(os.path.join(base_path, *paths))
    base_abs = os.path.abspath(base_path)
    try:
        common = os.path.commonpath([base_abs, full_path])
    except ValueError:
        # commonpath raises for paths with no common root (e.g. different
        # drives on Windows); such a path is certainly outside the base.
        common = None
    if common != base_abs:
        raise ValueError(f"路径拼接越界: {full_path} 不在 {base_abs} 下")
    return full_path
|
||
|
||
|
||
def log_error(error: Exception, message: Optional[str] = None, is_fatal: bool = False) -> None:
    """
    Log an error message via loguru.

    The message is logged at CRITICAL level when *is_fatal* is true, otherwise
    at ERROR level. Only the message text is logged; the exception's traceback
    is not.

    Args:
        error: The exception being reported. Its ``str()`` is used when
            *message* is not given.
        message: Optional custom message that replaces ``str(error)``.
        is_fatal: Whether the error is fatal (selects the log level).

    Returns:
        None
    """
    if message is None:
        message = str(error)
    log_msg = f"错误: {message}"
    if is_fatal:
        logger.critical(log_msg)
    else:
        logger.error(log_msg)
|
||
|
||
|
||
def is_fatal_error(error: BaseException) -> bool:
    """
    Decide whether an exception should be treated as fatal.

    The fatal types are SystemExit, KeyboardInterrupt, MemoryError and OSError,
    including every OSError subclass such as FileNotFoundError. IOError is an
    alias of OSError, so the IOError raised by read_file and write_file is also
    classified as fatal.

    Args:
        error: The exception to classify. Annotated as BaseException because
            SystemExit and KeyboardInterrupt do not derive from Exception.

    Returns:
        bool: True if the exception is one of the fatal types, otherwise False.
    """
    fatal_exceptions = (SystemExit, KeyboardInterrupt, MemoryError, OSError)
    return isinstance(error, fatal_exceptions)
|
||
|
||
|
||
def build_dependency_graph(files: List[Dict[str, Any]]) -> Dict[str, List[str]]:
    """
    Build an adjacency-list dependency graph from file descriptors.

    Each entry contributes ``graph[path] = dependencies``. Entries whose
    ``'path'`` is missing or empty are skipped. A missing or ``None``
    ``'dependencies'`` value yields an empty list. If a path appears more than
    once, the last entry wins.

    Args:
        files: File descriptors, each a dict with ``'path'`` and, optionally,
            ``'dependencies'`` keys.

    Returns:
        Dict[str, List[str]]: Maps each file path to the list of paths it
        depends on. The lists are fresh copies, so mutating the graph does not
        modify the input descriptors.
    """
    graph: Dict[str, List[str]] = {}
    for file in files:
        path = file.get('path', '')
        if not path:
            continue
        # Copy rather than alias the caller's list, and normalise None to [].
        graph[path] = list(file.get('dependencies') or [])
    return graph
|
||
|
||
|
||
def compute_in_degrees(graph: Dict[str, List[str]]) -> Dict[str, int]:
    """
    Count, for every node of *graph*, how many dependency edges point at it.

    With ``graph[node] = deps``, each occurrence of a node inside another
    node's dependency list adds one to that node's count; duplicates count
    separately. Dependencies that are not keys of *graph* are ignored.

    Args:
        graph: Dependency graph in adjacency-list form.

    Returns:
        Dict[str, int]: One entry per key of *graph*, in the graph's key order.
    """
    degrees: Dict[str, int] = dict.fromkeys(graph, 0)
    for dependency_list in graph.values():
        for target in dependency_list:
            if target in degrees:
                degrees[target] += 1
    return degrees
|
||
|
||
|
||
class ConcurrentQueueManager:
    """
    Thin wrapper around a thread-safe ``queue.Queue``.

    Intended for handing work items (e.g. code-generation or checking tasks)
    between producers and consumers. The underlying queue is exposed as the
    ``queue`` attribute.
    """

    def __init__(self, maxsize: int = 0):
        """
        Create the underlying FIFO queue.

        Args:
            maxsize: Maximum number of queued items. A value <= 0 (the default
                is 0) means the queue is unbounded.
        """
        self.queue = queue.Queue(maxsize=maxsize)

    def enqueue(self, item: Any) -> None:
        """
        Append *item* to the queue.

        Blocks while the queue is full (only possible when maxsize > 0).

        Args:
            item: The object to enqueue.
        """
        self.queue.put(item, block=True, timeout=None)

    def dequeue(self, block: bool = True, timeout: Optional[float] = None) -> Any:
        """
        Remove and return the oldest item.

        Args:
            block: Wait for an item when the queue is empty.
            timeout: Upper bound in seconds on the wait when *block* is true;
                ``None`` waits indefinitely.

        Returns:
            Any: The dequeued item.

        Raises:
            queue.Empty: If no item is available (non-blocking call, or the
                timeout expired).
        """
        return self.queue.get(block, timeout)

    def is_empty(self) -> bool:
        """
        Report whether the queue currently holds no items.

        Under concurrent use the answer may be stale by the time it is read.

        Returns:
            bool: True when empty, otherwise False.
        """
        return self.queue.empty()

    def size(self) -> int:
        """
        Report the approximate number of queued items.

        Returns:
            int: The current queue size (approximate under concurrent use).
        """
        return self.queue.qsize()
|
||
|
||
|
||
def add_implicit_dependency(file_content: str, current_deps: List[str], implicit_dep_file: str = "src/llm_codegen/utils.py") -> List[str]:
    """
    Add *implicit_dep_file* to a dependency list when the content appears to reference it.

    This is a heuristic. The content counts as referencing the file if it
    contains the file's full path, or (case-insensitively) the file's stem,
    e.g. ``utils`` for ``.../utils.py``. The stem check is a plain substring
    test, so it can also fire on names such as ``string_utils``.

    Args:
        file_content: Source text to scan.
        current_deps: Existing dependency list. It is not modified; ``None``
            is treated as empty.
        implicit_dep_file: Path of the implicit dependency. An empty string
            disables the check.

    Returns:
        List[str]: A new list holding *current_deps*, plus *implicit_dep_file*
        (appended once) when a reference is detected.
    """
    updated_deps = list(current_deps) if current_deps else []
    if not implicit_dep_file:
        return updated_deps

    # Derive the module name from the configured file instead of hard-coding
    # "utils", so a non-default implicit_dep_file is detected by its own name.
    module_stem = Path(implicit_dep_file).stem.lower()
    referenced = implicit_dep_file in file_content or (
        bool(module_stem) and module_stem in file_content.lower()
    )
    if referenced and implicit_dep_file not in updated_deps:
        updated_deps.append(implicit_dep_file)
    return updated_deps
|
||
|
||
|
||
def create_progress_bar(total: int = 100, description: str = "Processing",
                        columns: Optional[List] = None, auto_refresh: bool = True) -> Progress:
    """
    Build a rich ``Progress`` instance with a standard column layout.

    The returned object is neither started nor given any task. Callers control
    it with ``start()``/``stop()`` or a ``with`` block, and add tasks themselves.

    Args:
        total: Accepted but not used by this function.
        description: Accepted but not used by this function.
        columns: Column objects to display. When ``None``, the layout is
            TextColumn (task description), BarColumn, TaskProgressColumn and
            TimeElapsedColumn.
        auto_refresh: Passed through to ``Progress``.

    Returns:
        Progress: The configured (unstarted) progress display.
    """
    # NOTE(review): `total` and `description` look like task settings, but
    # nothing here consumes them — confirm whether callers expect a task to be
    # created by this helper.
    if columns is not None:
        return Progress(*columns, auto_refresh=auto_refresh)
    return Progress(
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        TimeElapsedColumn(),
        auto_refresh=auto_refresh,
    )
|
||
|
||
|
||
def topological_sort(graph: Dict[str, List[str]]) -> List[str]:
    """
    Order the nodes of a dependency graph (Kahn's algorithm) and detect cycles.

    Edges follow ``graph[node] = deps``. Nodes that nothing depends on come
    first, and every node appears before the nodes it depends on (dependents
    first). Dependencies that are not keys of *graph* are ignored and do not
    appear in the result.

    NOTE(review): callers that need dependencies before dependents (e.g. to
    generate a file's dependencies first) must reverse this list — confirm at
    the call sites.

    Args:
        graph: Dependency graph; keys are nodes (file paths) and values are the
            nodes each key depends on.

    Returns:
        List[str]: The ordered nodes.

    Raises:
        ValueError: If a dependency cycle exists. The message lists the nodes
            that could not be ordered: those on a cycle plus the nodes they
            depend on.
    """
    # in_degrees[n] = number of graph nodes that depend on n.
    in_degrees = compute_in_degrees(graph)

    # Start from nodes that nothing depends on.
    zero_degree_queue = deque(node for node, degree in in_degrees.items() if degree == 0)
    sorted_nodes: List[str] = []

    while zero_degree_queue:
        node = zero_degree_queue.popleft()
        sorted_nodes.append(node)
        for neighbor in graph.get(node, []):
            if neighbor in in_degrees:
                in_degrees[neighbor] -= 1
                if in_degrees[neighbor] == 0:
                    zero_degree_queue.append(neighbor)

    # Any node never released to the queue is blocked by a cycle.
    if len(sorted_nodes) != len(graph):
        placed = set(sorted_nodes)
        remaining = [node for node in graph if node not in placed]
        raise ValueError(
            f"检测到循环依赖,排序节点数 {len(sorted_nodes)} 不等于总节点数 {len(graph)},"
            f"未能排序的节点: {remaining}"
        )

    return sorted_nodes
|