llmcodegen/src/llm_codegen/utils.py

298 lines
8.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import queue
import re
from collections import deque
from pathlib import Path
from typing import Tuple, Dict, List, Optional, Any

from loguru import logger # 添加导入
from rich.progress import Progress, TextColumn, BarColumn, TimeElapsedColumn, TaskProgressColumn
# Blacklist of keywords that mark a shell command as dangerous (configurable).
DANGEROUS_COMMANDS: List[str] = [
    "rm",
    "sudo",
    "chmod",
    "dd",
    "mkfs",
    "> /dev/sda",
    "format",
]

# Optional whitelist of commands; intended so that when empty only the
# blacklist above is checked.
# NOTE(review): nothing in this module reads ALLOWED_COMMANDS — confirm
# whether any caller actually enforces it.
ALLOWED_COMMANDS: List[str] = []
def is_dangerous_command(cmd: str, dangerous_commands: Optional[List[str]] = None) -> Tuple[bool, str]:
    """
    Decide whether a shell command string looks dangerous.

    Matching is case-insensitive. Keywords made only of word characters
    (e.g. ``rm``, ``dd``, ``format``) must appear as standalone words: they
    may not be glued to other word characters or to ``-``. So ``git add``,
    ``platform`` and ``--format`` no longer trigger ``dd``/``rm``/``format``,
    while ``rm -rf``, ``/bin/rm``, ``ls;rm`` and ``mkfs.ext4`` still do.
    Keywords containing other characters (e.g. ``> /dev/sda``) are matched
    as plain substrings.

    ALLOWED_COMMANDS is not consulted here.

    Args:
        cmd: the command string to inspect.
        dangerous_commands: keywords to look for; defaults to the
            module-level DANGEROUS_COMMANDS when None.

    Returns:
        Tuple[bool, str]: (is_dangerous, reason). The reason names the first
        matching keyword, in list order, and is "" when nothing matched.
    """
    keywords = DANGEROUS_COMMANDS if dangerous_commands is None else dangerous_commands
    cmd_lower = cmd.lower()
    for danger in keywords:
        needle = danger.lower()
        if re.fullmatch(r"\w+", needle):
            # Word keyword: reject matches embedded in a longer word or option
            # name, which plain substring search would accept ("add" ⊃ "dd").
            pattern = rf"(?<![\w-]){re.escape(needle)}(?![\w-])"
            hit = re.search(pattern, cmd_lower) is not None
        else:
            hit = needle in cmd_lower
        if hit:
            return True, f"包含危险关键词 '{danger}'"
    return False, ""
def read_file(file_path: str) -> str:
    """
    Read a text file as UTF-8 and return its entire contents.

    Args:
        file_path: path of the file to read.

    Returns:
        str: the file contents.

    Raises:
        IOError: if the file cannot be opened, read or decoded. The
            underlying exception is chained as ``__cause__``.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return f.read()
    except Exception as e:
        # Chain explicitly so the real failure (FileNotFoundError,
        # UnicodeDecodeError, ...) is reported as the direct cause.
        raise IOError(f"读取文件失败: {file_path}, 错误: {e}") from e
def write_file(file_path: str, content: str) -> None:
    """
    Write text to a file as UTF-8, creating missing parent directories.

    An existing file is truncated and overwritten.

    Args:
        file_path: destination path.
        content: text to write.

    Raises:
        IOError: if the parent directory cannot be created or the file cannot
            be written. The underlying exception is chained as ``__cause__``.
    """
    try:
        path = Path(file_path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open('w', encoding='utf-8') as f:
            f.write(content)
    except Exception as e:
        # Chain explicitly so the real failure is reported as the direct cause.
        raise IOError(f"写入文件失败: {file_path}, 错误: {e}") from e
def ensure_dir(directory: str) -> None:
    """
    Create ``directory`` (and any missing parents) unless it already exists.

    Calling it on an existing directory is a no-op.

    Args:
        directory: path of the directory to guarantee.
    """
    target = directory
    os.makedirs(target, mode=0o777, exist_ok=True)
def safe_join(base_path: str, *paths: str) -> str:
    """
    Join path components onto a base directory, refusing results that escape it.

    Both the base and the joined result are made absolute with
    ``os.path.abspath``, which also collapses ``..``. Symlinks are NOT
    resolved. Containment is checked component by component, so a sibling
    whose name merely starts with the base name (``/data/app`` vs
    ``/data/app-evil``) is rejected.

    Args:
        base_path: directory the result must stay inside.
        *paths: components to append. An absolute component discards
            everything before it (``os.path.join`` semantics) and is then
            rejected unless it lies inside the base.

    Returns:
        str: the absolute joined path.

    Raises:
        ValueError: if the joined path lies outside ``base_path``.
    """
    full_path = os.path.abspath(os.path.join(base_path, *paths))
    base_abs = os.path.abspath(base_path)
    try:
        # commonpath compares whole components; a plain startswith() would
        # wrongly accept "/base-evil/x" as being inside "/base".
        inside = os.path.commonpath([base_abs, full_path]) == base_abs
    except ValueError:
        # e.g. paths on different Windows drives: cannot be inside the base.
        inside = False
    if not inside:
        raise ValueError(f"路径拼接越界: {full_path} 不在 {base_abs}")
    return full_path
def log_error(error: Exception, message: Optional[str] = None, is_fatal: bool = False) -> None:
    """
    Log an error through the module's loguru ``logger``.

    Args:
        error: the exception being reported.
        message: optional custom text. When None, ``str(error)`` is used.
        is_fatal: when True the record is emitted at CRITICAL level,
            otherwise at ERROR level.

    Returns:
        None
    """
    text = str(error) if message is None else message
    log_msg = f"错误: {text}"
    if is_fatal:
        logger.critical(log_msg)
    else:
        logger.error(log_msg)
def is_fatal_error(error: Exception) -> bool:
    """
    Classify an exception as fatal or recoverable.

    Fatal types are SystemExit, KeyboardInterrupt, MemoryError and every
    OSError subclass. That includes IOError (an alias of OSError),
    FileNotFoundError, PermissionError, ConnectionError and so on.

    Args:
        error: the exception to classify.

    Returns:
        bool: True for a fatal error, False otherwise.
    """
    fatal_types = (SystemExit, KeyboardInterrupt, MemoryError, OSError)
    return any(isinstance(error, kind) for kind in fatal_types)
def build_dependency_graph(files: List[Dict[str, Any]]) -> Dict[str, List[str]]:
    """
    Build an adjacency-list dependency graph from file descriptors.

    Args:
        files: list of dicts, each optionally holding ``'path'`` and
            ``'dependencies'`` (a list of paths the file depends on).

    Returns:
        Dict[str, List[str]]: maps each file path to the list of paths it
        depends on. Entries with a missing or empty path are skipped. A
        missing or None ``dependencies`` becomes an empty list. If a path
        appears more than once, the last entry wins.
    """
    graph: Dict[str, List[str]] = {}
    for entry in files:
        path = entry.get('path', '')
        if not path:
            continue
        # Copy, so mutating the graph never mutates the caller's input (and
        # vice versa), and normalise None to [] so downstream iteration works.
        graph[path] = list(entry.get('dependencies') or [])
    return graph
def compute_in_degrees(graph: Dict[str, List[str]]) -> Dict[str, int]:
    """
    Count how many times each node is referenced as a dependency.

    ``graph[node]`` lists what ``node`` depends on. A node's in-degree is the
    number of times it appears across all dependency lists, with duplicates
    counted each time. References to paths that are not keys of ``graph``
    are ignored.

    Args:
        graph: dependency graph in adjacency-list form.

    Returns:
        Dict[str, int]: in-degree for every key of ``graph``, in the graph's
        key order.
    """
    degrees = dict.fromkeys(graph, 0)
    referenced = (dep for deps in graph.values() for dep in deps)
    for dep in referenced:
        if dep in degrees:
            degrees[dep] += 1
    return degrees
class ConcurrentQueueManager:
    """
    Thin wrapper around a thread-safe ``queue.Queue`` for passing work items
    between parallel tasks (e.g. code generation or checking).

    The wrapped queue is exposed as the ``queue`` attribute.
    """

    def __init__(self, maxsize: int = 0):
        """
        Create the underlying queue.

        Args:
            maxsize: capacity of the queue; 0 (or any value <= 0) means
                unbounded.
        """
        self.queue = queue.Queue(maxsize=maxsize)

    def enqueue(self, item: Any) -> None:
        """
        Add ``item`` to the queue, blocking while a bounded queue is full.

        Args:
            item: the object to enqueue.
        """
        self.queue.put(item, block=True, timeout=None)

    def dequeue(self, block: bool = True, timeout: Optional[float] = None) -> Any:
        """
        Remove and return the next item.

        Args:
            block: wait for an item when the queue is empty.
            timeout: maximum seconds to wait when blocking (None = forever).

        Returns:
            Any: the dequeued item. ``queue.Empty`` propagates when no item
            is available under the given block/timeout settings.
        """
        return self.queue.get(block, timeout)

    def is_empty(self) -> bool:
        """
        Report whether the queue currently holds no items.

        Under concurrent use the answer may already be stale when returned.

        Returns:
            bool: True when empty.
        """
        return self.queue.empty()

    def size(self) -> int:
        """
        Return the approximate number of queued items.

        Returns:
            int: current queue length.
        """
        return self.queue.qsize()
def add_implicit_dependency(file_content: str, current_deps: List[str], implicit_dep_file: str = "src/llm_codegen/utils.py") -> List[str]:
    """
    Add ``implicit_dep_file`` to a dependency list when the content seems to reference it.

    Detection is a deliberately loose heuristic. The dependency is added
    when the content contains ``implicit_dep_file`` verbatim, or contains
    ``"utils"`` anywhere (case-insensitive). That includes comments and
    unrelated ``utils`` modules.

    Args:
        file_content: source text of the file being analysed.
        current_deps: existing dependency list; it is not modified.
        implicit_dep_file: path of the implicit dependency to add.

    Returns:
        List[str]: a new list with ``implicit_dep_file`` appended when it was
        detected and not already present.
    """
    result = current_deps.copy()
    mentioned = implicit_dep_file in file_content or "utils" in file_content.lower()
    if mentioned and implicit_dep_file not in result:
        result.append(implicit_dep_file)
    return result
def create_progress_bar(total: int = 100, description: str = "Processing",
                        columns: Optional[List] = None, auto_refresh: bool = True) -> Progress:
    """
    Build a rich ``Progress`` display with a standard column layout.

    NOTE(review): ``total`` and ``description`` are accepted but not used.
    No task is added to the returned Progress, so callers presumably call
    ``add_task`` themselves. Confirm before relying on these parameters.

    Args:
        total: intended total task count (currently unused).
        description: intended initial description (currently unused).
        columns: custom column objects. When None, a description column,
            a bar, a percentage column and an elapsed-time column are used.
        auto_refresh: passed through to ``Progress``.

    Returns:
        Progress: an unstarted instance, usable via start()/stop() or as a
        context manager.
    """
    chosen = columns if columns is not None else [
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        TimeElapsedColumn(),
    ]
    return Progress(*chosen, auto_refresh=auto_refresh)
def topological_sort(graph: Dict[str, List[str]]) -> List[str]:
    """
    Order the nodes of a dependency graph with Kahn's algorithm, failing on cycles.

    ``graph[node]`` lists the nodes ``node`` depends on. A node is emitted
    once no un-emitted node still lists it as a dependency, so each node
    appears *before* the nodes it depends on.
    NOTE(review): if callers need dependencies first (e.g. generate a
    shared helper before the files that use it), they must reverse the
    result. Confirm against call sites.
    Dependencies that are not keys of ``graph`` are ignored.

    Args:
        graph: adjacency list mapping each node (file path) to its dependencies.

    Returns:
        List[str]: all nodes of ``graph`` in sorted order.

    Raises:
        ValueError: if a dependency cycle prevents ordering every node.
    """
    remaining = compute_in_degrees(graph)
    ready = deque(node for node, count in remaining.items() if count == 0)
    ordered: List[str] = []
    while ready:
        current = ready.popleft()
        ordered.append(current)
        for dep in graph.get(current, []):
            if dep not in remaining:
                continue  # external dependency, not a node of this graph
            remaining[dep] -= 1
            if remaining[dep] == 0:
                ready.append(dep)
    if len(ordered) != len(graph):
        # Nodes on a cycle, or depended on by one, never reach in-degree 0.
        raise ValueError(f"检测到循环依赖,排序节点数 {len(ordered)} 不等于总节点数 {len(graph)}")
    return ordered