533 lines
19 KiB
Python
533 lines
19 KiB
Python
import http.server
|
||
import json
|
||
import os
|
||
import socketserver
|
||
import subprocess
|
||
import sys
|
||
import threading
|
||
import time
|
||
import webbrowser
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from typing import Callable, Optional, Union
|
||
from urllib.parse import urlparse
|
||
|
||
import typer
|
||
|
||
# Root Typer application; the sub-commands below register themselves on it.
app = typer.Typer(help="AI模型训练监控工具 - 基于JSON旁路记录法的移动端友好监控方案")

# Try to import the Flask-based monitor. On failure, remember the reason so
# it can be reported later instead of crashing at import time.
try:
    from .flask_monitor import start_flask_server, DEFAULT_STATUS_FILE as FLASK_DEFAULT_STATUS_FILE
except ImportError as import_error:
    FLASK_AVAILABLE = False
    FLASK_IMPORT_ERROR = str(import_error)
else:
    FLASK_AVAILABLE = True
|
||
|
||
|
||
def get_package_dir() -> Path:
    """Return the directory that contains this module file."""
    module_file = Path(__file__)
    return module_file.parent
|
||
|
||
|
||
def check_flask_available() -> bool:
    """Return the module-level FLASK_AVAILABLE flag."""
    return FLASK_AVAILABLE
|
||
|
||
|
||
def start_flask_monitor_server(
    status_file: str,
    port: int,
    host: str,
    open_browser: bool,
    use_wsgi: bool = False,
) -> int:
    """
    Start the Flask monitoring server and block until it exits.

    The absolute status-file path is exported through the
    ``TRAINING_STATUS_FILE`` environment variable before the server starts
    (presumably read by ``flask_monitor`` -- confirm there).

    Args:
        status_file: Path to the training status JSON file.
        port: TCP port to listen on.
        host: Interface/host to bind to.
        open_browser: Open the monitor page in a browser about two seconds
            after start-up.
        use_wsgi: Forwarded to ``start_flask_server`` to select the Waitress
            WSGI server instead of the Flask development server.

    Returns:
        Process exit code: the value returned by ``start_flask_server``,
        0 after Ctrl+C, or 1 when Flask is unavailable or start-up fails.
    """
    if not FLASK_AVAILABLE:
        typer.echo("❌ 错误: Flask未正确导入")
        typer.echo(f"导入错误: {FLASK_IMPORT_ERROR}")
        typer.echo("请安装Flask: pip install flask")
        typer.echo("或在pyproject.toml中添加flask依赖")
        return 1

    # Hand the status-file location over via the environment.
    status_path = os.path.abspath(status_file)
    os.environ["TRAINING_STATUS_FILE"] = status_path

    server_type = "Waitress WSGI" if use_wsgi else "Flask"
    typer.echo(f"🚀 启动训练监控服务 ({server_type}版本)...")
    typer.echo(f"📁 状态文件: {status_path}")
    typer.echo(f"🌐 监控地址: http://{host}:{port}")
    typer.echo(f"📊 API接口: http://{host}:{port}/api/status")

    browser_timer = None
    if open_browser:
        # A wildcard bind address is not a connectable URL on every platform,
        # so point the browser at the loopback name instead.
        browser_host = "localhost" if host in ("0.0.0.0", "::", "") else host
        browser_url = f"http://{browser_host}:{port}"
        # Delay so the server has time to come up. Daemon + cancel() below
        # keep a failed or interrupted start from opening a dead page.
        browser_timer = threading.Timer(2.0, webbrowser.open, args=(browser_url,))
        browser_timer.daemon = True
        browser_timer.start()
        typer.echo("🌐 正在打开浏览器...")

    typer.echo("\n按 Ctrl+C 停止监控服务\n")

    try:
        # start_flask_server is bound at module level whenever FLASK_AVAILABLE
        # is true, so no second import is needed here.
        return start_flask_server(host=host, port=port, debug=False, use_wsgi=use_wsgi)
    except KeyboardInterrupt:
        typer.echo("\n🛑 监控服务已停止")
        return 0
    except Exception as e:
        typer.echo(f"❌ 启动监控服务时出错: {e}")
        return 1
    finally:
        if browser_timer is not None:
            browser_timer.cancel()
|
||
|
||
|
||
@app.command(name="monitor")
def monitor_training(
    status_file: str = typer.Option(
        "./output/training_status.json",
        "--status-file",
        "-s",
        help="训练状态JSON文件路径",
    ),
    port: int = typer.Option(8501, "--port", "-p", help="监控服务端口号"),
    host: str = typer.Option("0.0.0.0", "--host", help="监控服务主机地址"),
    open_browser: bool = typer.Option(False, "--open-browser", help="自动打开浏览器"),
    use_wsgi: bool = typer.Option(False, "--use-wsgi", help="使用Waitress WSGI服务器替代Flask开发服务器"),
):
    """
    启动AI模型训练监控服务 (Flask版本)

    基于JSON旁路记录法,提供移动端友好的训练监控界面。
    服务启动后,可通过浏览器访问 http://<host>:<port> 查看实时训练指标。
    支持本地文件和远程URL数据源,可在网页界面配置。
    """
    # NOTE: the docstring above is the user-facing --help text rendered by
    # Typer, so it is intentionally left unchanged.
    #
    # Flow: make sure a status file exists (creating an empty JSON list if
    # needed), verify Flask is importable, then run the server and exit with
    # its return code.
    if not os.path.exists(status_file):
        typer.echo(f"⚠️ 警告: 状态文件不存在: {status_file}")
        typer.echo("开始训练后,训练脚本会自动创建此文件。")
        typer.echo("您可以先启动监控服务,然后开始训练。")

        # A bare file name has an empty dirname, and os.makedirs("") raises
        # FileNotFoundError, so only create the directory when there is one.
        parent_dir = os.path.dirname(status_file)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # Seed the file with an empty JSON list.
        with open(status_file, "w", encoding="utf-8") as f:
            json.dump([], f)
        typer.echo(f"✅ 已创建空状态文件: {status_file}")

    if not check_flask_available():
        typer.echo("❌ 错误: Flask未正确导入")
        typer.echo("请安装Flask: pip install flask")
        typer.echo("或在pyproject.toml中添加flask依赖")
        raise typer.Exit(code=1)

    return_code = start_flask_monitor_server(
        status_file=status_file,
        port=port,
        host=host,
        open_browser=open_browser,
        use_wsgi=use_wsgi,
    )

    raise typer.Exit(code=return_code)
|
||
|
||
|
||
@app.command(name="view")
def view_status(
    status_file: str = typer.Argument(
        "./output/training_status.json", help="训练状态JSON文件路径"
    ),
    limit: int = typer.Option(10, "--limit", "-l", help="显示最近的数据条数"),
    raw: bool = typer.Option(False, "--raw", help="显示原始JSON数据"),
):
    """
    查看训练状态文件内容

    快速查看训练状态JSON文件中的最新数据。
    """
    # NOTE: the docstring above is the user-facing --help text rendered by
    # Typer, so it is intentionally left unchanged.
    if not os.path.exists(status_file):
        typer.echo(f"❌ 错误: 状态文件不存在: {status_file}")
        raise typer.Exit(code=1)

    # (record key, display label, format spec), in display order.
    metric_lines = (
        ("train/loss", "训练损失", ".6f"),
        ("train/accuracy", "训练准确率", ".6f"),
        ("eval/loss", "评估损失", ".6f"),
        ("eval/accuracy", "评估准确率", ".6f"),
        ("train/learning_rate", "学习率", ".2e"),
    )

    try:
        with open(status_file, "r", encoding="utf-8") as f:
            data = json.load(f)

        if not data:
            typer.echo("ℹ️ 状态文件为空,暂无训练数据。")
            return

        # Slicing below needs a list; report other JSON roots explicitly
        # instead of surfacing an opaque slicing error.
        if not isinstance(data, list):
            typer.echo("❌ JSON格式不正确: 根元素应为列表")
            raise typer.Exit(code=1)

        typer.echo(f"📊 训练状态文件: {status_file}")
        typer.echo(f"📈 数据总数: {len(data)} 条")
        typer.echo("")

        # A non-positive limit means "show everything". Applied to both
        # output modes so --raw no longer mis-slices on 0 or negative values.
        display_data = data[-limit:] if limit > 0 else data

        if raw:
            typer.echo(json.dumps(display_data, indent=2, ensure_ascii=False))
        else:
            # NOTE(review): the original indentation was lost in extraction.
            # The blank line is assumed to follow every record and the summary
            # is assumed to belong to the formatted branch only -- confirm
            # against the upstream file.
            shown = len(display_data)
            for i, record in enumerate(reversed(display_data)):
                # Newest first; idx counts down within the displayed window.
                idx = shown - i
                typer.echo(f"📌 记录 #{idx} (步数: {record.get('step', 'N/A')})")
                typer.echo(f" 轮次: {record.get('epoch', 'N/A')}")

                if "timestamp" in record:
                    typer.echo(f" 时间: {record['timestamp']}")

                for key, label, spec in metric_lines:
                    if key in record:
                        typer.echo(f" {label}: {format(record[key], spec)}")

                if i < shown - 1:
                    typer.echo(" ---")
                typer.echo("")

            # Short summary of the most recent record.
            latest = data[-1]
            typer.echo("🎯 最新数据摘要:")
            typer.echo(f" 当前步数: {latest.get('step', 'N/A')}")
            typer.echo(f" 当前轮次: {latest.get('epoch', 'N/A')}")
            if "train/loss" in latest:
                typer.echo(f" 训练损失: {latest['train/loss']:.6f}")
            if "eval/loss" in latest:
                typer.echo(f" 评估损失: {latest['eval/loss']:.6f}")

    except typer.Exit:
        # typer.Exit derives from RuntimeError; let it through so the generic
        # handler below does not print a second, misleading error.
        raise
    except json.JSONDecodeError:
        typer.echo("❌ 错误: 状态文件格式不正确,不是有效的JSON")
        typer.echo("可能是文件正在写入中,请稍后再试。")
        raise typer.Exit(code=1)
    except Exception as e:
        typer.echo(f"❌ 错误: 读取状态文件时出错: {e}")
        raise typer.Exit(code=1)
|
||
|
||
|
||
def _data_freshness_line(timestamp) -> str:
    """Return the data-freshness status line for an ISO-8601 timestamp."""
    try:
        last_update = datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
        # Subtracting an aware datetime from a naive one raises TypeError, so
        # take "now" in the timestamp's own timezone (None keeps it naive).
        now = datetime.now(last_update.tzinfo)
        diff = (now - last_update).total_seconds()
    except (AttributeError, TypeError, ValueError, OverflowError):
        return " 数据状态: ⚠️ 时间戳格式异常"

    if diff < 60:
        return f" 数据状态: 🟢 实时 (最近 {int(diff)} 秒)"
    if diff < 300:
        return f" 数据状态: 🟡 较新 (最近 {int(diff / 60)} 分钟)"
    if diff < 3600:
        return f" 数据状态: 🟠 较旧 (最近 {int(diff / 60)} 分钟)"
    return f" 数据状态: 🔴 陈旧 (最近 {int(diff / 3600)} 小时)"


@app.command(name="check")
def check_status(
    status_file: str = typer.Argument(
        "./output/training_status.json", help="训练状态JSON文件路径"
    ),
):
    """
    检查训练状态文件

    检查状态文件是否存在、格式是否正确,并显示基本信息。
    """
    # NOTE: the docstring above is the user-facing --help text rendered by
    # Typer, so it is intentionally left unchanged.
    if not os.path.exists(status_file):
        typer.echo(f"❌ 状态文件不存在: {status_file}")
        typer.echo("💡 提示: 开始训练后会自动创建此文件")
        raise typer.Exit(code=1)

    try:
        # File-level facts first: size and last modification time.
        file_size = os.path.getsize(status_file)
        file_mtime = time.ctime(os.path.getmtime(status_file))

        typer.echo(f"✅ 状态文件: {status_file}")
        typer.echo(f" 文件大小: {file_size:,} 字节")
        typer.echo(f" 修改时间: {file_mtime}")

        # Then validate the JSON payload.
        with open(status_file, "r", encoding="utf-8") as f:
            data = json.load(f)

        if not isinstance(data, list):
            typer.echo("❌ JSON格式不正确: 根元素应为列表")
            raise typer.Exit(code=1)

        typer.echo(f"✅ JSON格式正确,包含 {len(data)} 条记录")

        if data:
            latest = data[-1]
            typer.echo(f" 最新步数: {latest.get('step', 'N/A')}")
            typer.echo(f" 最新轮次: {latest.get('epoch', 'N/A')}")

            # Report which of the key metrics the latest record carries.
            metrics = ["train/loss", "eval/loss", "train/accuracy", "eval/accuracy"]
            available_metrics = [m for m in metrics if m in latest]
            typer.echo(
                f" 可用指标: {', '.join(available_metrics) if available_metrics else '无'}"
            )

            if "timestamp" in latest:
                typer.echo(f" 最后更新时间: {latest['timestamp']}")
                typer.echo(_data_freshness_line(latest["timestamp"]))
        else:
            typer.echo(" 数据状态: 文件为空,等待训练数据...")

    except typer.Exit:
        # typer.Exit derives from RuntimeError; without this clause the
        # generic handler below would print a spurious, empty error message.
        raise
    except json.JSONDecodeError:
        typer.echo("❌ JSON格式错误: 文件内容不是有效的JSON")
        typer.echo("可能是文件正在写入中,请稍后再试。")
        raise typer.Exit(code=1)
    except Exception as e:
        typer.echo(f"❌ 检查状态文件时出错: {e}")
        raise typer.Exit(code=1)
|
||
|
||
|
||
def create_http_handler(status_file_path: str, enable_cors: bool = True):
    """
    Build a request-handler class that serves only the training status JSON.

    Args:
        status_file_path: Path of the status file to serve. If it is missing,
            ``<status_file_path>.tmp`` is tried instead.
        enable_cors: Add permissive CORS headers to GET/HEAD/OPTIONS
            responses.

    Returns:
        A ``SimpleHTTPRequestHandler`` subclass. It answers ``/`` and
        ``/training_status.json`` with the file content, every other path
        with 404, a missing/empty file with 404 and read or parse failures
        with 500.
    """
    max_retries = 3
    retry_delay = 0.1  # seconds between read attempts
    allowed_paths = ("/", "/training_status.json")

    class TrainingStatusHTTPHandler(http.server.SimpleHTTPRequestHandler):
        def _read_status_payload(self):
            """Return the validated JSON as UTF-8 bytes, or None if unavailable.

            Retries a few times because the file may be mid-write; the error
            from the final attempt is re-raised to the caller.
            """
            temp_file = f"{status_file_path}.tmp"
            for attempt in range(max_retries):
                is_last_attempt = attempt == max_retries - 1

                if os.path.exists(status_file_path):
                    file_to_read = status_file_path
                elif os.path.exists(temp_file):
                    # Only the temp file exists; try reading that instead.
                    file_to_read = temp_file
                else:
                    return None

                try:
                    with open(file_to_read, "r", encoding="utf-8") as f:
                        content = f.read()
                    if content:
                        json.loads(content)  # validation only
                        return content.encode("utf-8")
                except (json.JSONDecodeError, OSError):
                    if is_last_attempt:
                        raise

                # Empty or unreadable content: wait briefly and try again.
                if not is_last_attempt:
                    time.sleep(retry_delay)
            return None

        def _send_cors_headers(self):
            """Emit the CORS headers when CORS support is enabled."""
            if enable_cors:
                self.send_header("Access-Control-Allow-Origin", "*")
                self.send_header("Access-Control-Allow-Methods", "GET, OPTIONS")
                self.send_header("Access-Control-Allow-Headers", "Content-Type")

        def _serve_status(self, include_body: bool):
            """Shared GET/HEAD implementation."""
            # Only the status file may be requested.
            if urlparse(self.path).path not in allowed_paths:
                self.send_error(404, "File not found")
                return

            try:
                payload = self._read_status_payload()
                if payload is None:
                    self.send_error(404, "Status file not found or invalid JSON")
                    return

                self.send_response(200)
                self.send_header("Content-type", "application/json")
                # Content-Length counts bytes, not characters; the two differ
                # as soon as the JSON contains non-ASCII text.
                self.send_header("Content-Length", str(len(payload)))

                # Prevent browsers from caching the status data.
                self.send_header("Cache-Control", "no-cache, no-store, must-revalidate")
                self.send_header("Pragma", "no-cache")
                self.send_header("Expires", "0")

                self._send_cors_headers()
                self.end_headers()

                if include_body:
                    self.wfile.write(payload)

            except Exception as e:
                self.send_error(500, f"Internal server error: {str(e)}")

        def do_GET(self):
            """Serve the status JSON."""
            self._serve_status(include_body=True)

        def do_HEAD(self):
            """Serve headers only.

            Overridden because the inherited implementation maps the request
            path onto the filesystem, which would bypass the path whitelist.
            """
            self._serve_status(include_body=False)

        def do_OPTIONS(self):
            """Answer OPTIONS requests (used for CORS preflight)."""
            self.send_response(200)
            self._send_cors_headers()
            self.end_headers()

        def log_message(self, format, *args):
            """Silence the default per-request logging."""
            pass

    return TrainingStatusHTTPHandler
|
||
|
||
|
||
def start_http_server(
    status_file: str,
    port: int,
    host: str,
    enable_cors: bool = True,
) -> Callable:
    """
    Start the status-file HTTP server in a background daemon thread.

    Args:
        status_file: Path to the training status JSON file.
        port: TCP port to listen on.
        host: Interface/host to bind to.
        enable_cors: Whether responses carry CORS headers.

    Returns:
        A zero-argument function that shuts the server down and closes its
        socket.

    Raises:
        OSError: If the address cannot be bound (for example, port in use).
    """
    status_file_path = os.path.abspath(status_file)
    handler = create_http_handler(status_file_path, enable_cors)

    class _StatusServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
        # Handle each request in its own thread so one slow client (the
        # handler sleeps between read retries) cannot stall the others.
        daemon_threads = True
        # Allow an immediate restart on the same port after shutdown.
        allow_reuse_address = True

    server = _StatusServer((host, port), handler)

    # Serve in the background so the caller keeps control of the main thread.
    server_thread = threading.Thread(target=server.serve_forever, daemon=True)
    server_thread.start()

    typer.echo("🌐 HTTP服务器已启动")
    typer.echo(f" 📁 状态文件: {status_file_path}")
    typer.echo(f" 🔗 访问地址: http://{host}:{port}/training_status.json")
    typer.echo(f" 🌍 CORS支持: {'已启用' if enable_cors else '已禁用'}")
    typer.echo("\n按 Ctrl+C 停止服务器\n")

    def stop_server():
        """Stop serving, release the listening socket and reap the thread."""
        typer.echo("\n🛑 正在停止HTTP服务器...")
        server.shutdown()
        server.server_close()
        server_thread.join(timeout=1.0)
        typer.echo("✅ HTTP服务器已停止")

    return stop_server
|
||
|
||
|
||
@app.command(name="serve")
def serve_status_file(
    status_file: str = typer.Option(
        "./output/training_status.json",
        "--status-file",
        "-s",
        help="训练状态JSON文件路径",
    ),
    port: int = typer.Option(8080, "--port", "-p", help="HTTP服务端口号"),
    host: str = typer.Option("0.0.0.0", "--host", help="HTTP服务主机地址"),
    # Declared as an on/off pair: with only "--cors" and a default of True
    # there was no way to switch CORS off from the command line.
    cors: bool = typer.Option(True, "--cors/--no-cors", help="启用CORS支持"),
):
    """
    启动HTTP服务,提供训练状态JSON文件访问

    启动后可通过 http://<host>:<port>/training_status.json 访问数据
    """
    # NOTE: the docstring above is the user-facing --help text rendered by
    # Typer, so it is intentionally left unchanged.
    import errno

    if not os.path.exists(status_file):
        typer.echo(f"⚠️ 警告: 状态文件不存在: {status_file}")
        typer.echo("开始训练后,训练脚本会自动创建此文件。")
        typer.echo("您可以先启动HTTP服务,然后开始训练。")

        # A bare file name has an empty dirname, and os.makedirs("") raises
        # FileNotFoundError, so only create the directory when there is one.
        parent_dir = os.path.dirname(status_file)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # Seed the file with an empty JSON list.
        with open(status_file, "w", encoding="utf-8") as f:
            json.dump([], f)
        typer.echo(f"✅ 已创建空状态文件: {status_file}")

    try:
        stop_server = start_http_server(
            status_file=status_file,
            port=port,
            host=host,
            enable_cors=cors,
        )

        # The server runs in a background thread; idle here until Ctrl+C.
        try:
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            stop_server()

    except OSError as e:
        # Prefer errno: the message text varies by platform and locale. The
        # substring check is kept as a fallback.
        port_in_use = e.errno == errno.EADDRINUSE or "Address already in use" in str(e)
        if port_in_use:
            typer.echo(f"❌ 错误: 端口 {port} 已被占用")
            typer.echo("请使用其他端口: --port <端口号>")
        else:
            typer.echo(f"❌ 启动HTTP服务器时出错: {e}")
        raise typer.Exit(code=1)
    except Exception as e:
        typer.echo(f"❌ 启动HTTP服务器时出错: {e}")
        raise typer.Exit(code=1)
|
||
|
||
|
||
def main():
    """Entry point: invoke the Typer application."""
    app()


# Run the CLI when this module is executed as a script.
if __name__ == "__main__":
    main()
|