docs: 将Streamlit替换为Flask以支持移动端监控界面
This commit is contained in:
parent
7ac44a2731
commit
d5daba182a
20
README.md
20
README.md
|
|
@ -564,12 +564,12 @@ tensorboard --logdir ./output/tensorboard
|
||||||
|
|
||||||
### 6.6 基于JSON旁路记录法的移动端监控方案
|
### 6.6 基于JSON旁路记录法的移动端监控方案
|
||||||
|
|
||||||
为了提供移动端友好的训练监控体验,我们实现了基于JSON旁路记录法的监控方案。该方案在保持TensorBoard记录的同时,额外写入一份JSON状态文件,并通过Streamlit提供移动端友好的Web界面。
|
为了提供移动端友好的训练监控体验,我们实现了基于JSON旁路记录法的监控方案。该方案在保持TensorBoard记录的同时,额外写入一份JSON状态文件,并通过Flask提供移动端友好的Web界面。
|
||||||
|
|
||||||
#### 方案特点
|
#### 方案特点
|
||||||
|
|
||||||
**📱 移动端体验**
|
**📱 移动端体验**
|
||||||
- Streamlit自动生成响应式界面,完美适配手机屏幕
|
- Flask模板生成响应式界面,完美适配手机屏幕
|
||||||
- 图表支持双指缩放和滑动操作
|
- 图表支持双指缩放和滑动操作
|
||||||
- 大字体显示核心指标,触控操作便捷
|
- 大字体显示核心指标,触控操作便捷
|
||||||
|
|
||||||
|
|
@ -601,7 +601,7 @@ monitor-training monitor --status-file ./output/training_status.json --port 8080
|
||||||
# 不自动打开浏览器
|
# 不自动打开浏览器
|
||||||
monitor-training monitor --no-browser
|
monitor-training monitor --no-browser
|
||||||
|
|
||||||
# 指定自定义Streamlit脚本
|
# 指定自定义Flask应用脚本
|
||||||
monitor-training monitor --streamlit-script ./custom_monitor.py
|
monitor-training monitor --streamlit-script ./custom_monitor.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -671,13 +671,13 @@ monitor-training serve --host 0.0.0.0 --port 8080
|
||||||
- 支持从现有状态文件恢复,避免数据丢失
|
- 支持从现有状态文件恢复,避免数据丢失
|
||||||
|
|
||||||
**监控端搭建**
|
**监控端搭建**
|
||||||
- 使用Streamlit构建移动端友好的Web界面
|
- 使用Flask构建移动端友好的Web界面
|
||||||
- 采用Plotly图表库,支持触控交互
|
- 采用Plotly图表库,支持触控交互
|
||||||
- 自动刷新机制,实时更新训练状态
|
- 自动刷新机制,实时更新训练状态
|
||||||
|
|
||||||
**命令行工具**
|
**命令行工具**
|
||||||
- 提供 `monitor`、`view`、`check` 三个子命令
|
- 提供 `monitor`、`view`、`check` 三个子命令
|
||||||
- 自动检测Streamlit可用性
|
- 自动检测Flask可用性
|
||||||
- 支持环境变量传递
|
- 支持环境变量传递
|
||||||
|
|
||||||
#### 访问方式
|
#### 访问方式
|
||||||
|
|
@ -709,10 +709,10 @@ http://your-server.com:8501
|
||||||
# GPU服务器启动HTTP服务
|
# GPU服务器启动HTTP服务
|
||||||
monitor-training serve --port 8080 --host 0.0.0.0
|
monitor-training serve --port 8080 --host 0.0.0.0
|
||||||
|
|
||||||
# 本地运行Streamlit监控,从HTTP URL读取数据
|
# 本地运行Flask监控,从HTTP URL读取数据
|
||||||
monitor-training monitor --host 127.0.0.1 --port 8501
|
monitor-training monitor --host 127.0.0.1 --port 8501
|
||||||
|
|
||||||
# 在Streamlit界面输入远程URL:
|
# 在Flask界面输入远程URL:
|
||||||
http://<gpu服务器IP>:8080/training_status.json
|
http://<gpu服务器IP>:8080/training_status.json
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -749,9 +749,9 @@ http://<gpu服务器IP>:8080/training_status.json
|
||||||
**🚀 工作原理**
|
**🚀 工作原理**
|
||||||
1. GPU服务器:训练进程通过原子写入机制更新`training_status.json`文件
|
1. GPU服务器:训练进程通过原子写入机制更新`training_status.json`文件
|
||||||
2. GPU服务器:运行`monitor-training serve`提供HTTP静态文件服务
|
2. GPU服务器:运行`monitor-training serve`提供HTTP静态文件服务
|
||||||
3. 本地机器:运行`monitor-training monitor`启动Streamlit监控界面
|
3. 本地机器:运行`monitor-training monitor`启动Flask监控界面
|
||||||
4. 本地机器:在Streamlit界面输入HTTP URL访问远程数据
|
4. 本地机器:在Flask界面输入HTTP URL访问远程数据
|
||||||
5. Streamlit:通过HTTP轮询获取实时训练数据并展示
|
5. Flask:通过HTTP轮询获取实时训练数据并展示
|
||||||
|
|
||||||
**🛡️ 数据安全**
|
**🛡️ 数据安全**
|
||||||
- 原子写入:先写入临时文件,然后原子重命名,避免读取中断
|
- 原子写入:先写入临时文件,然后原子重命名,避免读取中断
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,916 @@
|
||||||
|
# 输入法预测模型架构设计 (Input Method Prediction Model)
|
||||||
|
|
||||||
|
## 1. 概述
|
||||||
|
本项目旨在构建一个轻量级、高精度的中文输入法预测模型。核心设计理念是通过**结构化槽位记忆**与**交叉注意力机制**,将当前语境(光标前后文本+拼音)与历史输入习惯深度融合。为了在有限的计算资源下保持高表达能力,模型引入了**混合专家网络 (MoE)** 模块。
|
||||||
|
|
||||||
|
## 2. 核心架构流程
|
||||||
|
数据流遵循以下路径:
|
||||||
|
`输入编码` → `Transformer 上下文编码` → `槽位记忆嵌入` → `交叉注意力融合` → `门控+专家混合 (MoE)` → `分类预测` → `束搜索解码`
|
||||||
|
|
||||||
|
### 2.1 输入层设计
|
||||||
|
模型接收三类输入,分别处理以保持语义清晰:
|
||||||
|
1. **当前文本上下文**:包含光标前文本(Prefix)和光标后文本(Suffix)。
|
||||||
|
2. **拼音序列**:与当前文本对应的拼音信息,作为增强特征融入文本编码。
|
||||||
|
3. **历史槽位序列**:最近 N 个历史输入词汇,作为结构化记忆输入。
|
||||||
|
|
||||||
|
### 2.2 模块详解
|
||||||
|
|
||||||
|
#### A. Transformer 编码器 (Context Encoder)
|
||||||
|
负责提取当前语境的深层语义表示。
|
||||||
|
* **输入处理**:将 Prefix、Suffix 及拼音通过 Embedding 层映射。拼音采用**特征叠加**或**独立 Token** 方式融入,避免双流架构的复杂性。
|
||||||
|
* **骨干网络**:使用标准的 Transformer Encoder。
|
||||||
|
* **隐藏层维度**:512 [1]
|
||||||
|
* **Transformer 层数**:4 层(轻量级设计,从头训练) [1]
|
||||||
|
* **注意力头数**:4 头 [1]
|
||||||
|
* **输出**:上下文表示 $H$,形状为 `[batch, L, 512]` [1]。
|
||||||
|
|
||||||
|
#### B. 槽位记忆模块 (Slot Memory)
|
||||||
|
负责将非结构化的历史输入转化为结构化的记忆向量。
|
||||||
|
* **嵌入方式**:历史词汇通过独立的 `Slot Embedding` 查找表映射。
|
||||||
|
* **位置编码**:添加可学习的 `Positional Embedding` 以保留历史输入的时间顺序信息。
|
||||||
|
* **输出**:槽位序列 $S$,形状为 `[batch, Num_Slots, 512]`。
|
||||||
|
|
||||||
|
#### C. 交叉注意力融合 (Cross-Attention Fusion)
|
||||||
|
这是模型的核心创新点,用于动态关联"历史记忆"与"当前语境"。
|
||||||
|
* **Query (Q)**:当前步的槽位序列 $S$(经过位置编码后)。
|
||||||
|
* **Key/Value (K/V)**:Transformer 编码器输出的上下文表示 $H$ [1]。
|
||||||
|
* **机制**:让历史槽位主动关注当前文本语境,捕捉如"在'班级第一名'语境下,'王次香'比'王慈祥'更相关"的逻辑。
|
||||||
|
* **输出**:融合后的特征序列,形状为 `[batch, Num_Slots, 512]`。
|
||||||
|
|
||||||
|
#### D. 门控与专家混合 (Gating + MoE)
|
||||||
|
实际测试表明,移除 MoE 会导致模型性能显著下降,因此该模块对于捕捉复杂分布至关重要。
|
||||||
|
* **专家数量**:20 个专家 [1]。
|
||||||
|
* **门控机制**:根据输入特征动态选择激活部分专家,实现稀疏激活,在增加模型容量的同时控制计算成本。
|
||||||
|
* **输出**:经过专家网络增强后的特征向量。
|
||||||
|
|
||||||
|
#### E. 分类头与解码
|
||||||
|
* **分类预测**:MoE 输出的特征向量通过全连接层映射到词表空间,输出下一个字/词的概率分布。
|
||||||
|
* **解码策略**:推理阶段使用**束搜索 (Beam Search)**,束宽设为 5 [1]。
|
||||||
|
|
||||||
|
## 3. 关键超参数配置
|
||||||
|
|
||||||
|
为确保模型性能与效率的平衡,建议采用以下超参数 [1]:
|
||||||
|
|
||||||
|
| 参数项 | 推荐值 | 说明 |
|
||||||
|
| :--- | :--- | :--- |
|
||||||
|
| **序列长度 (L)** | 128 | 上下文窗口大小 [1] |
|
||||||
|
| **隐藏层维度** | 512 | Embedding 及 Transformer 内部维度 [1] |
|
||||||
|
| **Transformer 层数** | 4 | 轻量级骨干,降低延迟 [1] |
|
||||||
|
| **注意力头数** | 4 | 适配 512 维度的高效配置 [1] |
|
||||||
|
| **专家数量** | 20 | MoE 层中的专家总数,对性能至关重要 [1] |
|
||||||
|
| **束宽 (Beam Width)** | 5 | 推理时平衡速度与准确率 [1] |
|
||||||
|
| **学习率** | 1e-4 ~ 5e-4 | 建议配合 Warmup 策略 [1] |
|
||||||
|
|
||||||
|
## 4. 训练策略
|
||||||
|
|
||||||
|
本模型采用标准的**序列到序列(Seq2Seq)监督学习**范式,直接对目标槽位序列进行逐步预测。
|
||||||
|
|
||||||
|
### 4.1 数据构造与标签
|
||||||
|
* **输入三元组**:训练数据由 `(上下文, 拼音, 目标槽位序列)` 构成 [1]。
|
||||||
|
* **上下文**:光标前后的文本片段。
|
||||||
|
* **拼音**:当前待输入字的拼音序列。
|
||||||
|
* **目标槽位序列**:真实用户输入的文字 ID 序列,作为模型的监督信号 [1]。
|
||||||
|
* **标签处理**:在每一个槽位步(Step),模型需要预测该步对应的真实文字 ID [1]。
|
||||||
|
|
||||||
|
### 4.2 损失函数与优化
|
||||||
|
* **损失函数**:使用 **CrossEntropyLoss** 计算每一步预测结果与真实标签之间的差异 [1]。
|
||||||
|
* **掩码机制**:仅计算非填充位置(Non-padding positions)的损失,忽略无效的时间步 [1]。
|
||||||
|
* **优化器**:采用 **AdamW** 进行参数更新 [1]。
|
||||||
|
|
||||||
|
### 4.3 训练流程细节
|
||||||
|
1. **前向传播**:
|
||||||
|
* 模型接收上下文和拼音,通过 Transformer 编码得到语境表示。
|
||||||
|
* 结合历史槽位记忆,通过交叉注意力和 MoE 模块融合特征。
|
||||||
|
* 分类头输出当前步所有候选字的概率分布。
|
||||||
|
2. **Teacher Forcing**:
|
||||||
|
* 在训练过程中,**强制使用真实的上一槽位输出**作为下一步的输入条件。这意味着模型在训练时始终基于"正确的历史"进行预测,从而快速收敛。
|
||||||
|
3. **反向传播**:
|
||||||
|
* 根据 CrossEntropyLoss [1] 计算梯度,并通过 AdamW [1] 更新模型权重。
|
||||||
|
|
||||||
|
### 4.4 推理与训练的差异
|
||||||
|
* **训练时**:使用 Ground Truth(真实标签)作为槽位输入,确保模型学习到最优的条件概率分布。
|
||||||
|
* **推理时**:由于无法获取真实标签,模型采用**束搜索(Beam Search)** [1]。
|
||||||
|
* **束宽**:默认为 5 [1]。
|
||||||
|
* **候选维护**:每个候选路径独立维护其历史槽位序列及累计概率 [1]。
|
||||||
|
* **终止条件**:当所有槽位填满(如 8×3=24 步)或所有候选分支的最高概率词均为终止符时退出 [1]。
|
||||||
|
|
||||||
|
## 5. Jupyter Lab 训练示例
|
||||||
|
|
||||||
|
以下是在 Jupyter Lab 环境中使用 `trainer.Trainer` 类训练输入法模型的完整示例:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# %% [markdown]
|
||||||
|
# # 输入法模型训练示例
|
||||||
|
# 本笔记本展示如何使用 trainer.Trainer 类训练输入法模型
|
||||||
|
|
||||||
|
# %% [code]
|
||||||
|
# 1. 导入必要的库
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
# 添加项目路径(适应不同的Jupyter Lab运行位置)
|
||||||
|
project_root = Path.cwd()
|
||||||
|
# 检查当前目录是否包含src目录,如果不包含则使用父目录
|
||||||
|
if not (project_root / "src").exists():
|
||||||
|
project_root = project_root.parent
|
||||||
|
sys.path.insert(0, str(project_root)) # 优先搜索项目目录
|
||||||
|
|
||||||
|
# 导入项目模块
|
||||||
|
from src.model.model import InputMethodEngine
|
||||||
|
from src.model.dataset import PinyinInputDataset
|
||||||
|
from src.model.trainer import Trainer, worker_init_fn, collate_fn
|
||||||
|
|
||||||
|
# %% [code]
|
||||||
|
# 2. 配置训练参数
|
||||||
|
config = {
|
||||||
|
# 数据参数
|
||||||
|
"train_data_path": "/path/to/your/train/dataset", # 替换为训练数据集路径
|
||||||
|
"eval_data_path": "/path/to/your/eval/dataset", # 替换为评估数据集路径
|
||||||
|
"output_dir": "./training_output",
|
||||||
|
|
||||||
|
# 模型参数
|
||||||
|
"vocab_size": 10019,
|
||||||
|
"pinyin_vocab_size": 30,
|
||||||
|
"dim": 512,
|
||||||
|
"num_slots": 8,
|
||||||
|
"n_layers": 4,
|
||||||
|
"n_heads": 4,
|
||||||
|
"num_experts": 20,
|
||||||
|
"max_seq_len": 128,
|
||||||
|
|
||||||
|
# 训练参数
|
||||||
|
"batch_size": 64, # 根据GPU内存调整
|
||||||
|
"num_epochs": 10,
|
||||||
|
"learning_rate": 3e-4,
|
||||||
|
"min_learning_rate": 1e-9,
|
||||||
|
"weight_decay": 0.1,
|
||||||
|
"warmup_ratio": 0.1,
|
||||||
|
"label_smoothing": 0.15,
|
||||||
|
"grad_accum_steps": 2, # 梯度累积,模拟更大batch size
|
||||||
|
"clip_grad_norm": 1.0,
|
||||||
|
"eval_frequency": 500, # 每500步评估一次
|
||||||
|
"save_frequency": 2000, # 每2000步保存检查点
|
||||||
|
|
||||||
|
# 高级选项
|
||||||
|
"mixed_precision": True,
|
||||||
|
"use_tensorboard": True,
|
||||||
|
"seed": 42,
|
||||||
|
"max_iter_length": 1024 * 1024 * 128, # 最大迭代长度
|
||||||
|
}
|
||||||
|
|
||||||
|
# %% [code]
|
||||||
|
# 3. 设置随机种子和设备
|
||||||
|
torch.manual_seed(config["seed"])
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed_all(config["seed"])
|
||||||
|
device = torch.device("cuda")
|
||||||
|
print(f"✅ 使用 GPU: {torch.cuda.get_device_name(0)}")
|
||||||
|
else:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
print("⚠️ 使用 CPU 进行训练(建议使用 GPU 以获得更好性能)")
|
||||||
|
|
||||||
|
# %% [code]
|
||||||
|
# 4. 创建数据集和数据加载器
|
||||||
|
print("📊 创建数据集和数据加载器...")
|
||||||
|
|
||||||
|
# 训练数据集
|
||||||
|
train_dataset = PinyinInputDataset(
|
||||||
|
data_path=config["train_data_path"],
|
||||||
|
max_workers=-1, # 自动选择worker数量
|
||||||
|
max_iter_length=config["max_iter_length"],
|
||||||
|
max_seq_length=config["max_seq_len"],
|
||||||
|
text_field="text",
|
||||||
|
py_style_weight=(9, 2, 1),
|
||||||
|
shuffle_buffer_size=5000,
|
||||||
|
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||||
|
)
|
||||||
|
|
||||||
|
# 训练数据加载器
|
||||||
|
train_dataloader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_size=config["batch_size"],
|
||||||
|
num_workers=min(max(1, (os.cpu_count() or 1) - 1), 8), # 合理数量的worker
|
||||||
|
pin_memory=torch.cuda.is_available(),
|
||||||
|
worker_init_fn=worker_init_fn,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
prefetch_factor=32,
|
||||||
|
persistent_workers=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 评估数据集
|
||||||
|
eval_dataset = PinyinInputDataset(
|
||||||
|
data_path=config["eval_data_path"],
|
||||||
|
max_workers=-1,
|
||||||
|
max_iter_length=1024, # 评估集较小
|
||||||
|
max_seq_length=config["max_seq_len"],
|
||||||
|
text_field="text",
|
||||||
|
py_style_weight=(9, 2, 1),
|
||||||
|
shuffle_buffer_size=1000,
|
||||||
|
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||||
|
)
|
||||||
|
|
||||||
|
eval_dataloader = DataLoader(
|
||||||
|
eval_dataset,
|
||||||
|
batch_size=config["batch_size"],
|
||||||
|
num_workers=1,
|
||||||
|
pin_memory=torch.cuda.is_available(),
|
||||||
|
worker_init_fn=worker_init_fn,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
prefetch_factor=32,
|
||||||
|
persistent_workers=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"✅ 数据加载器创建完成")
|
||||||
|
print(f" 训练批次大小: {config['batch_size']}")
|
||||||
|
print(f" 预估训练步数: {config['max_iter_length'] // config['batch_size']}")
|
||||||
|
|
||||||
|
# %% [code]
|
||||||
|
# 5. 创建模型
|
||||||
|
print("🧠 创建输入法模型...")
|
||||||
|
|
||||||
|
model = InputMethodEngine(
|
||||||
|
vocab_size=config["vocab_size"],
|
||||||
|
pinyin_vocab_size=config["pinyin_vocab_size"],
|
||||||
|
dim=config["dim"],
|
||||||
|
num_slots=config["num_slots"],
|
||||||
|
n_layers=config["n_layers"],
|
||||||
|
n_heads=config["n_heads"],
|
||||||
|
num_experts=config["num_experts"],
|
||||||
|
max_seq_len=config["max_seq_len"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# 将模型移动到设备
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
# 计算参数量
|
||||||
|
total_params = sum(p.numel() for p in model.parameters())
|
||||||
|
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
|
||||||
|
print(f"✅ 模型创建完成")
|
||||||
|
print(f" 总参数量: {total_params:,}")
|
||||||
|
print(f" 可训练参数量: {trainable_params:,}")
|
||||||
|
print(f" 模型架构: {config['n_layers']}层Transformer, {config['dim']}维度, {config['num_experts']}个MoE专家")
|
||||||
|
|
||||||
|
# %% [code]
|
||||||
|
# 6. 创建训练器
|
||||||
|
print("⚙️ 创建训练器...")
|
||||||
|
|
||||||
|
# 计算总训练步数
|
||||||
|
total_steps = int(config["max_iter_length"] / config["batch_size"])
|
||||||
|
|
||||||
|
trainer = Trainer(
|
||||||
|
model=model,
|
||||||
|
train_dataloader=train_dataloader,
|
||||||
|
eval_dataloader=eval_dataloader,
|
||||||
|
total_steps=total_steps,
|
||||||
|
output_dir=config["output_dir"],
|
||||||
|
num_epochs=config["num_epochs"],
|
||||||
|
learning_rate=config["learning_rate"],
|
||||||
|
min_learning_rate=config["min_learning_rate"],
|
||||||
|
weight_decay=config["weight_decay"],
|
||||||
|
warmup_ratio=config["warmup_ratio"],
|
||||||
|
label_smoothing=config["label_smoothing"],
|
||||||
|
grad_accum_steps=config["grad_accum_steps"],
|
||||||
|
clip_grad_norm=config["clip_grad_norm"],
|
||||||
|
eval_frequency=config["eval_frequency"],
|
||||||
|
save_frequency=config["save_frequency"],
|
||||||
|
mixed_precision=config["mixed_precision"],
|
||||||
|
use_tensorboard=config["use_tensorboard"],
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"✅ 训练器创建完成")
|
||||||
|
print(f" 总训练步数: {total_steps:,}")
|
||||||
|
print(f" 学习率: {config['learning_rate']:.2e} -> {config['min_learning_rate']:.2e}")
|
||||||
|
print(f" 输出目录: {config['output_dir']}")
|
||||||
|
|
||||||
|
# %% [code]
|
||||||
|
# 7. 开始训练
|
||||||
|
print("🚀 开始训练...")
|
||||||
|
print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 开始训练(可以从检查点恢复训练)
|
||||||
|
trainer.train(resume_from=None) # 设置检查点路径以恢复训练
|
||||||
|
|
||||||
|
print("✅ 训练完成!")
|
||||||
|
print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||||
|
print(f"模型和日志保存在: {config['output_dir']}")
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("⏹️ 训练被用户中断")
|
||||||
|
print("💾 保存当前检查点...")
|
||||||
|
trainer.save_checkpoint("interrupted")
|
||||||
|
print(f"检查点已保存到: {config['output_dir']}/checkpoint_interrupted.pt")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ 训练过程中出现错误: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
# %% [code]
|
||||||
|
# 8. 监控训练进度(如果使用TensorBoard)
|
||||||
|
if config["use_tensorboard"]:
|
||||||
|
print("📈 TensorBoard日志已记录在:")
|
||||||
|
print(f" {config['output_dir']}/tensorboard")
|
||||||
|
print("\n启动TensorBoard查看训练进度:")
|
||||||
|
print(" tensorboard --logdir ./training_output/tensorboard")
|
||||||
|
print("然后在浏览器中打开: http://localhost:6006")
|
||||||
|
|
||||||
|
# %% [code]
|
||||||
|
# 9. 加载训练好的模型进行推理(示例)
|
||||||
|
def load_trained_model(checkpoint_path):
|
||||||
|
"""加载训练好的模型进行检查点"""
|
||||||
|
print(f"📥 加载检查点: {checkpoint_path}")
|
||||||
|
|
||||||
|
# 创建与训练时相同配置的模型
|
||||||
|
loaded_model = InputMethodEngine(
|
||||||
|
vocab_size=config["vocab_size"],
|
||||||
|
pinyin_vocab_size=config["pinyin_vocab_size"],
|
||||||
|
dim=config["dim"],
|
||||||
|
num_slots=config["num_slots"],
|
||||||
|
n_layers=config["n_layers"],
|
||||||
|
n_heads=config["n_heads"],
|
||||||
|
num_experts=config["num_experts"],
|
||||||
|
max_seq_len=config["max_seq_len"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# 加载检查点
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||||
|
loaded_model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
|
loaded_model.to(device)
|
||||||
|
loaded_model.eval()
|
||||||
|
|
||||||
|
print(f"✅ 模型加载完成,训练步数: {checkpoint.get('global_step', 'N/A')}")
|
||||||
|
print(f" 训练损失: {checkpoint.get('train_loss', 'N/A'):.4f}")
|
||||||
|
|
||||||
|
return loaded_model
|
||||||
|
|
||||||
|
# 使用示例(取消注释以使用)
|
||||||
|
# trained_model = load_trained_model("./training_output/checkpoint_final.pt")
|
||||||
|
```
|
||||||
|
|
||||||
|
### 关键说明
|
||||||
|
|
||||||
|
1. **环境要求**:
|
||||||
|
- Python 3.12+
|
||||||
|
- PyTorch 2.10+
|
||||||
|
- 建议使用GPU进行训练
|
||||||
|
- 安装项目依赖:`pip install -e .`
|
||||||
|
|
||||||
|
2. **数据集格式**:
|
||||||
|
- 使用Hugging Face `datasets`格式
|
||||||
|
- 必须包含`text`字段
|
||||||
|
- 支持流式读取(streaming=True)
|
||||||
|
|
||||||
|
3. **训练监控**:
|
||||||
|
- 控制台输出训练进度和指标
|
||||||
|
- TensorBoard记录损失、准确率、学习率等
|
||||||
|
- 定期保存模型检查点
|
||||||
|
|
||||||
|
4. **可调整参数**:
|
||||||
|
- `batch_size`: 根据GPU内存调整
|
||||||
|
- `learning_rate`: 建议在1e-4到5e-4之间
|
||||||
|
- `grad_accum_steps`: 模拟更大batch size
|
||||||
|
- `num_epochs`: 根据数据集大小调整
|
||||||
|
|
||||||
|
5. **故障排除**:
|
||||||
|
- GPU内存不足:减小`batch_size`或增加`grad_accum_steps`
|
||||||
|
- 训练不稳定:降低`learning_rate`或增加`warmup_ratio`
|
||||||
|
- 过拟合:增加`label_smoothing`或使用更大数据集
|
||||||
|
|
||||||
|
## 6. 使用指南
|
||||||
|
|
||||||
|
本项目的训练功能通过命令行工具 `train-model` 提供,支持训练、评估和导出模型。
|
||||||
|
|
||||||
|
### 6.1 安装与准备
|
||||||
|
|
||||||
|
#### 使用 uv(推荐)
|
||||||
|
本项目使用 [`uv`](https://github.com/astral-sh/uv) 作为Python包管理器,它比传统的 pip 更快且更可靠。
|
||||||
|
|
||||||
|
1. **安装 uv**(如果尚未安装):
|
||||||
|
```bash
|
||||||
|
# Linux/macOS
|
||||||
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
|
|
||||||
|
# 或使用 pipx
|
||||||
|
pipx install uv
|
||||||
|
|
||||||
|
# Windows (PowerShell)
|
||||||
|
powershell -c "irm https://astral.sh/uv/install.ps1 | iex"
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **安装项目依赖**:
|
||||||
|
```bash
|
||||||
|
uv pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 使用传统 pip
|
||||||
|
如果不使用 uv,也可以用标准的 pip 安装:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 创建并激活虚拟环境(推荐)
|
||||||
|
python -m venv .venv
|
||||||
|
source .venv/bin/activate # Linux/macOS
|
||||||
|
# .venv\Scripts\activate # Windows
|
||||||
|
|
||||||
|
# 安装依赖
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 验证安装
|
||||||
|
安装完成后,可通过以下命令验证:
|
||||||
|
```bash
|
||||||
|
train-model --help
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.2 数据格式
|
||||||
|
|
||||||
|
训练数据应为Hugging Face数据集格式,支持本地文件或远程数据集仓库。数据集需包含 `text` 字段,并支持流式读取(streaming=True)。
|
||||||
|
|
||||||
|
#### 本地数据集示例
|
||||||
|
```python
|
||||||
|
# dataset.py
|
||||||
|
from datasets import Dataset
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"text": ["这是第一个样本文本。", "这是第二个样本,用于训练输入法模型。"]
|
||||||
|
}
|
||||||
|
dataset = Dataset.from_dict(data)
|
||||||
|
dataset.save_to_disk("./local_dataset")
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 远程数据集示例
|
||||||
|
支持Hugging Face Hub或ModelScope上的数据集:
|
||||||
|
- `huggingface.co/datasets/username/dataset_name`
|
||||||
|
- `modelscope.cn/datasets/username/dataset_name`
|
||||||
|
|
||||||
|
#### 数据格式要求
|
||||||
|
- **必需字段**: `text`(字符串类型,包含中文文本)
|
||||||
|
- **流式读取**: 数据集必须支持 `streaming=True` 参数
|
||||||
|
- **数据量**: 建议至少数百万条文本以获得良好效果
|
||||||
|
|
||||||
|
#### 数据预处理
|
||||||
|
数据集会自动进行以下处理:
|
||||||
|
1. 文本分词和编码
|
||||||
|
2. 拼音转换和编码
|
||||||
|
3. 上下文窗口滑动生成训练样本
|
||||||
|
4. 频率调整(削峰填谷)以平衡高频/低频字词
|
||||||
|
|
||||||
|
### 6.3 基本训练命令
|
||||||
|
|
||||||
|
使用 `train-model train` 命令开始训练:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
train-model train \
|
||||||
|
--train-data-path "path/to/train/dataset" \
|
||||||
|
--eval-data-path "path/to/eval/dataset" \
|
||||||
|
--output-dir "./output" \
|
||||||
|
--batch-size 128 \
|
||||||
|
--num-epochs 10 \
|
||||||
|
--learning-rate 1e-5
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 检查点恢复训练
|
||||||
|
|
||||||
|
要从检查点恢复训练(保持原有的训练状态):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
train-model train \
|
||||||
|
--train-data-path "path/to/train/dataset" \
|
||||||
|
--eval-data-path "path/to/eval/dataset" \
|
||||||
|
--resume-from "./output/checkpoints/latest_checkpoint.pt"
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 重置训练状态
|
||||||
|
|
||||||
|
如果只想加载模型权重,从头开始训练(学习率、epoch等都重新开始):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
train-model train \
|
||||||
|
--train-data-path "path/to/train/dataset" \
|
||||||
|
--eval-data-path "path/to/eval/dataset" \
|
||||||
|
--resume-from "./output/checkpoints/best_model.pt" \
|
||||||
|
--reset-training-state
|
||||||
|
```
|
||||||
|
|
||||||
|
这个功能在以下场景非常有用:
|
||||||
|
- 想要用预训练权重初始化模型,但用新的训练计划重新训练
|
||||||
|
- 需要调整学习率策略或训练时长
|
||||||
|
- 在现有模型基础上进行迁移学习
|
||||||
|
|
||||||
|
#### 学习率建议
|
||||||
|
根据模型架构和超参数配置(4层Transformer,512维度),推荐使用以下学习率范围:
|
||||||
|
- **标准范围**: 1e-4 ~ 5e-4
|
||||||
|
- **配合Warmup策略**:在训练初期逐步提高学习率
|
||||||
|
- **余弦退火**:使用最小学习率 1e-9 进行细调
|
||||||
|
|
||||||
|
### 6.4 参数详解
|
||||||
|
|
||||||
|
#### 数据参数
|
||||||
|
- `--train-data-path`, `-t`: 训练数据集路径(必需)
|
||||||
|
- `--eval-data-path`, `-e`: 评估数据集路径(必需)
|
||||||
|
- `--output-dir`, `-o`: 输出目录(默认:`./output`)
|
||||||
|
- `--max_iter_length`: 最大迭代长度,控制每次训练迭代处理的数据量(默认:134217728)
|
||||||
|
|
||||||
|
#### 模型参数
|
||||||
|
- `--vocab-size`: 词汇表大小(默认:10019)
|
||||||
|
- `--pinyin-vocab-size`: 拼音词汇表大小(默认:30)
|
||||||
|
- `--dim`: 模型维度(默认:512)
|
||||||
|
- `--num-slots`: 历史槽位数量(默认:8)
|
||||||
|
- `--n-layers`: Transformer层数(默认:4)
|
||||||
|
- `--n-heads`: 注意力头数(默认:4)
|
||||||
|
- `--num-experts`: MoE专家数量(默认:20)
|
||||||
|
- `--max-seq-len`: 最大序列长度(默认:128)
|
||||||
|
- `--use-pinyin`: 是否使用拼音特征(默认:False)
|
||||||
|
|
||||||
|
#### 训练参数
|
||||||
|
- `--batch-size`, `-b`: 批次大小(默认:128)
|
||||||
|
- `--num-epochs`: 训练轮数(默认:10)
|
||||||
|
- `--learning-rate`, `-lr`: 学习率(默认:1e-5)
|
||||||
|
- `--min-learning-rate`: 最小学习率(默认:1e-9)
|
||||||
|
- `--weight-decay`: 权重衰减(默认:0.1)
|
||||||
|
- `--warmup-ratio`: 热身步数比例(默认:0.1)
|
||||||
|
- `--label-smoothing`: 标签平滑参数(默认:0.15)
|
||||||
|
- `--grad-accum-steps`: 梯度累积步数(默认:1)
|
||||||
|
- `--clip-grad-norm`: 梯度裁剪范数(默认:1.0)
|
||||||
|
- `--eval-frequency`: 评估频率(默认:500步)
|
||||||
|
- `--save-frequency`: 保存频率(默认:10000步)
|
||||||
|
|
||||||
|
#### 高级选项
|
||||||
|
- `--mixed-precision/--no-mixed-precision`: 是否使用混合精度训练(默认:启用)
|
||||||
|
- `--tensorboard/--no-tensorboard`: 是否使用TensorBoard(默认:启用)
|
||||||
|
- `--resume-from`: 从检查点恢复训练(可选)
|
||||||
|
- `--reset-training-state`: 重置训练状态,只加载模型权重从头开始训练(默认:False)
|
||||||
|
- `--seed`: 随机种子(默认:42)
|
||||||
|
|
||||||
|
### 6.5 监控训练进度
|
||||||
|
|
||||||
|
训练过程中会显示:
|
||||||
|
- 当前训练步数/总步数
|
||||||
|
- 损失值和准确率
|
||||||
|
- 学习率变化
|
||||||
|
- 内存使用情况
|
||||||
|
|
||||||
|
启用TensorBoard后,可使用以下命令查看可视化结果:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
tensorboard --logdir ./output/tensorboard
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.6 基于JSON旁路记录法的移动端监控方案
|
||||||
|
|
||||||
|
为了提供移动端友好的训练监控体验,我们实现了基于JSON旁路记录法的监控方案。该方案在保持TensorBoard记录的同时,额外写入一份JSON状态文件,并通过Streamlit提供移动端友好的Web界面。
|
||||||
|
|
||||||
|
#### 方案特点
|
||||||
|
|
||||||
|
**📱 移动端体验**
|
||||||
|
- Streamlit自动生成响应式界面,完美适配手机屏幕
|
||||||
|
- 图表支持双指缩放和滑动操作
|
||||||
|
- 大字体显示核心指标,触控操作便捷
|
||||||
|
|
||||||
|
**🚀 低耦合架构**
|
||||||
|
- 训练和监控通过文件系统解耦
|
||||||
|
- 监控服务重启不影响训练进程
|
||||||
|
- 训练脚本只需几行代码修改即可支持
|
||||||
|
|
||||||
|
**🔒 安全稳定**
|
||||||
|
- 纯文本JSON文件,无文件锁冲突问题
|
||||||
|
- 读写速度快,稳定性高
|
||||||
|
- 不会影响原有的TensorBoard记录流程
|
||||||
|
|
||||||
|
**📊 实时监控**
|
||||||
|
- 默认每5秒自动刷新数据
|
||||||
|
- 实时显示训练进度和指标趋势
|
||||||
|
- 数据新鲜度状态指示(实时/较新/较旧/陈旧)
|
||||||
|
|
||||||
|
#### 使用方法
|
||||||
|
|
||||||
|
**启动监控服务**
|
||||||
|
```bash
|
||||||
|
# 启动监控服务(默认端口8501)
|
||||||
|
monitor-training monitor
|
||||||
|
|
||||||
|
# 指定状态文件路径和端口
|
||||||
|
monitor-training monitor --status-file ./output/training_status.json --port 8080
|
||||||
|
|
||||||
|
# 不自动打开浏览器
|
||||||
|
monitor-training monitor --no-browser
|
||||||
|
|
||||||
|
# 指定自定义Streamlit脚本
|
||||||
|
monitor-training monitor --streamlit-script ./custom_monitor.py
|
||||||
|
```
|
||||||
|
|
||||||
|
**查看训练状态**
|
||||||
|
```bash
|
||||||
|
# 查看最近10条训练记录
|
||||||
|
monitor-training view
|
||||||
|
|
||||||
|
# 查看最近50条记录(原始JSON格式)
|
||||||
|
monitor-training view --limit 50 --raw
|
||||||
|
|
||||||
|
# 查看指定状态文件
|
||||||
|
monitor-training view /path/to/status.json
|
||||||
|
```
|
||||||
|
|
||||||
|
**检查状态文件**
|
||||||
|
```bash
|
||||||
|
# 检查状态文件状态
|
||||||
|
monitor-training check
|
||||||
|
|
||||||
|
# 检查指定文件
|
||||||
|
monitor-training check ./output/training_status.json
|
||||||
|
```
|
||||||
|
|
||||||
|
**启动HTTP静态文件服务**
|
||||||
|
```bash
|
||||||
|
# 启动HTTP静态文件服务(默认端口8080)
|
||||||
|
monitor-training serve
|
||||||
|
|
||||||
|
# 指定状态文件路径和端口
|
||||||
|
monitor-training serve --status-file ./output/training_status.json --port 8080
|
||||||
|
|
||||||
|
# 禁用CORS支持(默认启用)
|
||||||
|
monitor-training serve --no-cors
|
||||||
|
|
||||||
|
# 指定主机地址
|
||||||
|
monitor-training serve --host 0.0.0.0 --port 8080
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 监控界面功能
|
||||||
|
|
||||||
|
**📊 核心指标看板**
|
||||||
|
- 当前步数、轮次、训练损失、准确率
|
||||||
|
- 评估损失和准确率
|
||||||
|
- 当前学习率、最后更新时间
|
||||||
|
|
||||||
|
**📈 趋势图表**
|
||||||
|
- 损失曲线(训练损失 + 评估损失)
|
||||||
|
- 准确率曲线(训练准确率 + 评估准确率)
|
||||||
|
- 学习率变化图(对数坐标)
|
||||||
|
|
||||||
|
**📋 数据详情**
|
||||||
|
- 完整的训练记录表格
|
||||||
|
- 数据统计信息(总数据点、训练时长、总步数)
|
||||||
|
- 训练进度条
|
||||||
|
|
||||||
|
**⚙️ 配置选项**
|
||||||
|
- 状态文件路径(支持环境变量 `TRAINING_STATUS_FILE`)
|
||||||
|
- 自动刷新间隔(1-30秒可调)
|
||||||
|
- 显示数据点数量(10-1000条可调)
|
||||||
|
|
||||||
|
#### 技术实现
|
||||||
|
|
||||||
|
**训练端改造**
|
||||||
|
- 在 `Trainer.__init__` 中添加 `status_file` 参数
|
||||||
|
- 实现 `_write_training_status()` 方法,在每次评估时写入JSON文件
|
||||||
|
- 支持从现有状态文件恢复,避免数据丢失
|
||||||
|
|
||||||
|
**监控端搭建**
|
||||||
|
- 使用Streamlit构建移动端友好的Web界面
|
||||||
|
- 采用Plotly图表库,支持触控交互
|
||||||
|
- 自动刷新机制,实时更新训练状态
|
||||||
|
|
||||||
|
**命令行工具**
|
||||||
|
- 提供 `monitor`、`view`、`check` 三个子命令
|
||||||
|
- 自动检测Streamlit可用性
|
||||||
|
- 支持环境变量传递
|
||||||
|
|
||||||
|
#### 访问方式
|
||||||
|
|
||||||
|
**本地访问**
|
||||||
|
```bash
|
||||||
|
# 启动监控服务后,通过浏览器访问
|
||||||
|
http://localhost:8501
|
||||||
|
```
|
||||||
|
|
||||||
|
**局域网访问**
|
||||||
|
```bash
|
||||||
|
# 启动服务时指定主机地址
|
||||||
|
monitor-training monitor --host 0.0.0.0 --port 8080
|
||||||
|
|
||||||
|
# 手机浏览器访问(同一局域网)
|
||||||
|
http://192.168.1.100:8080
|
||||||
|
```
|
||||||
|
|
||||||
|
**公网访问**(需端口转发)
|
||||||
|
```bash
|
||||||
|
# 确保服务器防火墙开放对应端口
|
||||||
|
# 通过域名或公网IP访问
|
||||||
|
http://your-server.com:8501
|
||||||
|
```
|
||||||
|
|
||||||
|
**远程HTTP监控**
|
||||||
|
```bash
|
||||||
|
# GPU服务器启动HTTP服务
|
||||||
|
monitor-training serve --port 8080 --host 0.0.0.0
|
||||||
|
|
||||||
|
# 本地运行Streamlit监控,从HTTP URL读取数据
|
||||||
|
monitor-training monitor --host 127.0.0.1 --port 8501
|
||||||
|
|
||||||
|
# 在Streamlit界面输入远程URL:
|
||||||
|
http://<gpu服务器IP>:8080/training_status.json
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 状态文件格式
|
||||||
|
|
||||||
|
状态文件 `training_status.json` 位于训练输出目录,格式如下:
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"step": 100,
|
||||||
|
"epoch": 1,
|
||||||
|
"timestamp": "2024-01-01T12:00:00",
|
||||||
|
"train/loss": 2.345,
|
||||||
|
"train/accuracy": 0.456,
|
||||||
|
"eval/loss": 2.123,
|
||||||
|
"eval/accuracy": 0.512,
|
||||||
|
"train/learning_rate": 0.0001
|
||||||
|
},
|
||||||
|
...
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
#### HTTP静态文件服务与远程监控
|
||||||
|
|
||||||
|
针对GPU服务器只支持HTTP协议(不支持WebSockets)的环境,我们提供了HTTP静态文件服务方案,实现远程训练监控。
|
||||||
|
|
||||||
|
**🔧 技术特点**
|
||||||
|
- 纯HTTP协议,无需WebSockets支持
|
||||||
|
- 原子写入机制,避免读取不完整JSON数据
|
||||||
|
- 自动重试和JSON验证,确保数据完整性
|
||||||
|
- CORS支持,方便跨域访问
|
||||||
|
- 轻量级设计,不影响训练性能
|
||||||
|
|
||||||
|
**🚀 工作原理**
|
||||||
|
1. GPU服务器:训练进程通过原子写入机制更新`training_status.json`文件
|
||||||
|
2. GPU服务器:运行`monitor-training serve`提供HTTP静态文件服务
|
||||||
|
3. 本地机器:运行`monitor-training monitor`启动Streamlit监控界面
|
||||||
|
4. 本地机器:在Streamlit界面输入HTTP URL访问远程数据
|
||||||
|
5. Streamlit:通过HTTP轮询获取实时训练数据并展示
|
||||||
|
|
||||||
|
**🛡️ 数据安全**
|
||||||
|
- 原子写入:先写入临时文件,然后原子重命名,避免读取中断
|
||||||
|
- JSON验证:HTTP服务端验证JSON格式后才返回数据
|
||||||
|
- 临时文件处理:智能识别和读取`.tmp`临时文件
|
||||||
|
- 重试机制:JSON解析失败时自动重试读取
|
||||||
|
|
||||||
|
**🌐 网络要求**
|
||||||
|
- GPU服务器:需要开放HTTP端口(默认8080)
|
||||||
|
- 本地机器:需要能访问GPU服务器的HTTP端口
|
||||||
|
- 网络协议:纯HTTP,兼容防火墙和代理
|
||||||
|
|
||||||
|
#### 注意事项
|
||||||
|
1. 首次监控时如果状态文件不存在,会自动创建空文件
|
||||||
|
2. 需要安装 `plotly` 依赖用于图表绘制:`pip install plotly>=5.0.0`
|
||||||
|
3. 从检查点恢复训练时会自动加载已有的状态数据
|
||||||
|
4. 建议将监控服务与训练服务部署在同一服务器,避免网络延迟
|
||||||
|
5. HTTP服务支持原子写入,避免训练进程写入时读取不完整JSON
|
||||||
|
6. 远程监控需要确保GPU服务器防火墙开放对应HTTP端口
|
||||||
|
7. 建议使用`--host 0.0.0.0`参数使HTTP服务可被远程访问
|
||||||
|
|
||||||
|
|
||||||
|
### 6.7 评估模型(开发中)
|
||||||
|
|
||||||
|
当前评估功能尚在开发中:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
train-model evaluate \
|
||||||
|
--checkpoint "./output/checkpoint_final.pt" \
|
||||||
|
--data-path "path/to/eval/dataset" \
|
||||||
|
--batch-size 32
|
||||||
|
```
|
||||||
|
|
||||||
|
命令将显示"评估功能待实现"的提示信息。该功能计划用于:
|
||||||
|
- 加载训练好的模型检查点
|
||||||
|
- 在评估数据集上计算准确率、困惑度等指标
|
||||||
|
- 生成详细的性能报告
|
||||||
|
|
||||||
|
### 6.8 模型扩容两阶段训练
|
||||||
|
|
||||||
|
当需要增加模型容量(如增加专家数量、修改层结构等)时,可以使用 `expand-and-train` 命令进行两阶段训练:先冻结匹配层训练新增参数,然后全量微调。
|
||||||
|
|
||||||
|
#### 训练策略
|
||||||
|
|
||||||
|
1. **冻结阶段**:只训练形状不匹配的新增参数(如新增的专家、扩容的层等)
|
||||||
|
2. **全量微调阶段**:当验证损失连续 `--frozen-patience` 次不下降时,自动解冻所有层进行全量训练
|
||||||
|
|
||||||
|
#### 基础用法
|
||||||
|
|
||||||
|
```bash
|
||||||
|
train-model expand-and-train \
|
||||||
|
--train-data-path "path/to/train/dataset" \
|
||||||
|
--eval-data-path "path/to/eval/dataset" \
|
||||||
|
--base-model-path "./pretrained/model.pt" \
|
||||||
|
--new-model-spec "model:InputMethodEngine" \
|
||||||
|
--num-experts 40 \
|
||||||
|
--frozen-lr 2e-3 \
|
||||||
|
--full-lr 5e-5 \
|
||||||
|
--frozen-patience 8
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 完整参数示例
|
||||||
|
|
||||||
|
```bash
|
||||||
|
train-model expand-and-train \
|
||||||
|
--train-data-path "path/to/train/dataset" \
|
||||||
|
--eval-data-path "path/to/eval/dataset" \
|
||||||
|
--output-dir "./expansion_output" \
|
||||||
|
--base-model-path "./pretrained/model.pt" \
|
||||||
|
--new-model-spec "custom_model:ExpandedModel" \
|
||||||
|
--vocab-size 10019 \
|
||||||
|
--dim 512 \
|
||||||
|
--num-experts 40 \
|
||||||
|
--frozen-patience 10 \
|
||||||
|
--frozen-lr 1e-3 \
|
||||||
|
--full-lr 1e-4 \
|
||||||
|
--frozen-scheduler cosine \
|
||||||
|
--full-scheduler cosine \
|
||||||
|
--batch-size 128 \
|
||||||
|
--num-epochs 20 \
|
||||||
|
--compile
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 参数详解
|
||||||
|
|
||||||
|
**模型扩容参数**
|
||||||
|
- `--base-model-path`: 预训练基础模型检查点路径(必需)
|
||||||
|
- `--new-model-spec`: 新模型规格,格式:`模块名:类名`,如 `model:InputMethodEngine`(必需)
|
||||||
|
- 支持任意路径的模块导入,模块文件需包含自定义的模型类
|
||||||
|
- 自定义模型类必须是 `InputMethodEngine` 的子类
|
||||||
|
- 示例:`my_model:MyExpandedModel` 对应 `my_model.py` 中的 `MyExpandedModel` 类
|
||||||
|
|
||||||
|
**两阶段训练参数**
|
||||||
|
- `--frozen-patience`: 冻结阶段验证损失连续不下降的评估次数,触发切换到全量微调(默认:10)
|
||||||
|
- `--frozen-lr`: 冻结阶段学习率(默认:1e-3)
|
||||||
|
- `--full-lr`: 全量微调阶段学习率(默认:1e-4)
|
||||||
|
- `--frozen-scheduler`: 冻结阶段学习率调度器,可选 `cosine` 或 `plateau`(默认:`cosine`)
|
||||||
|
- `--full-scheduler`: 全量微调阶段学习率调度器,可选 `cosine` 或 `plateau`(默认:`cosine`)
|
||||||
|
|
||||||
|
**其他参数**
|
||||||
|
- 支持所有 `train` 子命令的通用参数(数据参数、模型参数、训练参数等)
|
||||||
|
- 继承现有的训练基础设施:混合精度训练、TensorBoard日志、checkpoint保存等
|
||||||
|
|
||||||
|
#### 使用场景
|
||||||
|
|
||||||
|
1. **增加专家数量**(20→40)
|
||||||
|
- 冻结效果:~70% 参数可冻结(已有专家权重、注意力层等)
|
||||||
|
- 新增参数:新专家网络、gate层
|
||||||
|
|
||||||
|
2. **增加top_k值**(2→3)
|
||||||
|
- 冻结效果:100% 参数可冻结(仅逻辑变化)
|
||||||
|
- 新增参数:无
|
||||||
|
|
||||||
|
3. **修改专家内部结构**(如增加resblocks)
|
||||||
|
- 冻结效果:~50% 参数可冻结(linear_in/output可冻结)
|
||||||
|
- 新增参数:新增的resblocks层
|
||||||
|
|
||||||
|
4. **增加Transformer层数**(4→5)
|
||||||
|
- 冻结效果:~80% 参数可冻结(前4层可冻结)
|
||||||
|
- 新增参数:新增的第5层
|
||||||
|
|
||||||
|
#### 自定义模型类示例
|
||||||
|
|
||||||
|
```python
|
||||||
|
# my_model.py
|
||||||
|
from model.model import InputMethodEngine
|
||||||
|
|
||||||
|
class MyExpandedModel(InputMethodEngine):
|
||||||
|
def __init__(self, num_experts=40, **kwargs):
|
||||||
|
# 调用父类构造函数,覆盖num_experts参数
|
||||||
|
super().__init__(num_experts=num_experts, **kwargs)
|
||||||
|
# 可以在这里添加额外的层或修改现有层
|
||||||
|
|
||||||
|
# 使用命令
|
||||||
|
# train-model expand-and-train --new-model-spec "my_model:MyExpandedModel" ...
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 注意事项
|
||||||
|
|
||||||
|
1. **模型类要求**:自定义模型类必须是 `InputMethodEngine` 的子类
|
||||||
|
2. **冻结条件**:只有权重形状完全匹配的层才会被冻结
|
||||||
|
3. **性能保持**:MoE层保持"计算所有专家+Top-K选择"方案,确保 `torch.compile` 下的最佳性能
|
||||||
|
4. **阶段切换**:基于评估频率而非epoch,建议适当调高 `--eval-frequency`
|
||||||
|
5. **模块导入**:支持任意路径的模块,通过Python标准导入机制加载
|
||||||
|
|
||||||
|
### 6.9 导出模型(开发中)
|
||||||
|
|
||||||
|
当前导出功能尚在开发中:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
train-model export \
|
||||||
|
--checkpoint "./output/checkpoint_final.pt" \
|
||||||
|
--output "./exported_model.onnx"
|
||||||
|
```
|
||||||
|
|
||||||
|
命令将显示"导出功能待实现"的提示信息。该功能计划用于:
|
||||||
|
- 将PyTorch模型转换为ONNX格式
|
||||||
|
- 支持在不同推理引擎上部署
|
||||||
|
- 提供优化后的推理模型
|
||||||
|
|
||||||
|
## 7. 总结
|
||||||
|
本方案通过**单流 Transformer 编码**结合**结构化槽位交叉注意力**,并引入**20个专家的 MoE 模块** [1],在保证模型轻量(4层 Transformer)的同时,有效利用了历史输入习惯并提升了模型表达上限。相比暴力拼接或双流架构,该设计在工程实现上更优雅,在推理效率上更高效,是轻量级输入法模型的局部最优解。
|
||||||
|
|
@ -0,0 +1,607 @@
|
||||||
|
# 训练指南
|
||||||
|
|
||||||
|
## Jupyter Lab 训练示例
|
||||||
|
|
||||||
|
以下是在 Jupyter Lab 环境中使用 `trainer.Trainer` 类训练输入法模型的完整示例:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# %% [markdown]
|
||||||
|
# # 输入法模型训练示例
|
||||||
|
# 本笔记本展示如何使用 trainer.Trainer 类训练输入法模型
|
||||||
|
|
||||||
|
# %% [code]
|
||||||
|
# 1. 导入必要的库
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
# 添加项目路径(适应不同的Jupyter Lab运行位置)
|
||||||
|
project_root = Path.cwd()
|
||||||
|
# 检查当前目录是否包含src目录,如果不包含则使用父目录
|
||||||
|
if not (project_root / "src").exists():
|
||||||
|
project_root = project_root.parent
|
||||||
|
sys.path.insert(0, str(project_root)) # 优先搜索项目目录
|
||||||
|
|
||||||
|
# 导入项目模块
|
||||||
|
from src.model.model import InputMethodEngine
|
||||||
|
from src.model.dataset import PinyinInputDataset
|
||||||
|
from src.model.trainer import Trainer, worker_init_fn, collate_fn
|
||||||
|
|
||||||
|
# %% [code]
|
||||||
|
# 2. 配置训练参数
|
||||||
|
config = {
|
||||||
|
# 数据参数
|
||||||
|
"train_data_path": "/path/to/your/train/dataset", # 替换为训练数据集路径
|
||||||
|
"eval_data_path": "/path/to/your/eval/dataset", # 替换为评估数据集路径
|
||||||
|
"output_dir": "./training_output",
|
||||||
|
|
||||||
|
# 模型参数
|
||||||
|
"vocab_size": 10019,
|
||||||
|
"pinyin_vocab_size": 30,
|
||||||
|
"dim": 512,
|
||||||
|
"num_slots": 8,
|
||||||
|
"n_layers": 4,
|
||||||
|
"n_heads": 4,
|
||||||
|
"num_experts": 20,
|
||||||
|
"max_seq_len": 128,
|
||||||
|
|
||||||
|
# 训练参数
|
||||||
|
"batch_size": 64, # 根据GPU内存调整
|
||||||
|
"num_epochs": 10,
|
||||||
|
"learning_rate": 3e-4,
|
||||||
|
"min_learning_rate": 1e-9,
|
||||||
|
"weight_decay": 0.1,
|
||||||
|
"warmup_ratio": 0.1,
|
||||||
|
"label_smoothing": 0.15,
|
||||||
|
"grad_accum_steps": 2, # 梯度累积,模拟更大batch size
|
||||||
|
"clip_grad_norm": 1.0,
|
||||||
|
"eval_frequency": 500, # 每500步评估一次
|
||||||
|
"save_frequency": 2000, # 每2000步保存检查点
|
||||||
|
|
||||||
|
# 高级选项
|
||||||
|
"mixed_precision": True,
|
||||||
|
"use_tensorboard": True,
|
||||||
|
"seed": 42,
|
||||||
|
"max_iter_length": 1024 * 1024 * 128, # 最大迭代长度
|
||||||
|
}
|
||||||
|
|
||||||
|
# %% [code]
|
||||||
|
# 3. 设置随机种子和设备
|
||||||
|
torch.manual_seed(config["seed"])
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed_all(config["seed"])
|
||||||
|
device = torch.device("cuda")
|
||||||
|
print(f"✅ 使用 GPU: {torch.cuda.get_device_name(0)}")
|
||||||
|
else:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
print("⚠️ 使用 CPU 进行训练(建议使用 GPU 以获得更好性能)")
|
||||||
|
|
||||||
|
# %% [code]
|
||||||
|
# 4. 创建数据集和数据加载器
|
||||||
|
print("📊 创建数据集和数据加载器...")
|
||||||
|
|
||||||
|
# 训练数据集
|
||||||
|
train_dataset = PinyinInputDataset(
|
||||||
|
data_path=config["train_data_path"],
|
||||||
|
max_workers=-1, # 自动选择worker数量
|
||||||
|
max_iter_length=config["max_iter_length"],
|
||||||
|
max_seq_length=config["max_seq_len"],
|
||||||
|
text_field="text",
|
||||||
|
py_style_weight=(9, 2, 1),
|
||||||
|
shuffle_buffer_size=5000,
|
||||||
|
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||||
|
)
|
||||||
|
|
||||||
|
# 训练数据加载器
|
||||||
|
train_dataloader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_size=config["batch_size"],
|
||||||
|
num_workers=min(max(1, (os.cpu_count() or 1) - 1), 8), # 合理数量的worker
|
||||||
|
pin_memory=torch.cuda.is_available(),
|
||||||
|
worker_init_fn=worker_init_fn,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
prefetch_factor=32,
|
||||||
|
persistent_workers=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 评估数据集
|
||||||
|
eval_dataset = PinyinInputDataset(
|
||||||
|
data_path=config["eval_data_path"],
|
||||||
|
max_workers=-1,
|
||||||
|
max_iter_length=1024, # 评估集较小
|
||||||
|
max_seq_length=config["max_seq_len"],
|
||||||
|
text_field="text",
|
||||||
|
py_style_weight=(9, 2, 1),
|
||||||
|
shuffle_buffer_size=1000,
|
||||||
|
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||||
|
)
|
||||||
|
|
||||||
|
eval_dataloader = DataLoader(
|
||||||
|
eval_dataset,
|
||||||
|
batch_size=config["batch_size"],
|
||||||
|
num_workers=1,
|
||||||
|
pin_memory=torch.cuda.is_available(),
|
||||||
|
worker_init_fn=worker_init_fn,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
prefetch_factor=32,
|
||||||
|
persistent_workers=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"✅ 数据加载器创建完成")
|
||||||
|
print(f" 训练批次大小: {config['batch_size']}")
|
||||||
|
print(f" 预估训练步数: {config['max_iter_length'] // config['batch_size']}")
|
||||||
|
|
||||||
|
# %% [code]
|
||||||
|
# 5. 创建模型
|
||||||
|
print("🧠 创建输入法模型...")
|
||||||
|
|
||||||
|
model = InputMethodEngine(
|
||||||
|
vocab_size=config["vocab_size"],
|
||||||
|
pinyin_vocab_size=config["pinyin_vocab_size"],
|
||||||
|
dim=config["dim"],
|
||||||
|
num_slots=config["num_slots"],
|
||||||
|
n_layers=config["n_layers"],
|
||||||
|
n_heads=config["n_heads"],
|
||||||
|
num_experts=config["num_experts"],
|
||||||
|
max_seq_len=config["max_seq_len"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# 将模型移动到设备
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
# 计算参数量
|
||||||
|
total_params = sum(p.numel() for p in model.parameters())
|
||||||
|
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
|
||||||
|
print(f"✅ 模型创建完成")
|
||||||
|
print(f" 总参数量: {total_params:,}")
|
||||||
|
print(f" 可训练参数量: {trainable_params:,}")
|
||||||
|
print(f" 模型架构: {config['n_layers']}层Transformer, {config['dim']}维度, {config['num_experts']}个MoE专家")
|
||||||
|
|
||||||
|
# %% [code]
|
||||||
|
# 6. 创建训练器
|
||||||
|
print("⚙️ 创建训练器...")
|
||||||
|
|
||||||
|
# 计算总训练步数
|
||||||
|
total_steps = int(config["max_iter_length"] / config["batch_size"])
|
||||||
|
|
||||||
|
trainer = Trainer(
|
||||||
|
model=model,
|
||||||
|
train_dataloader=train_dataloader,
|
||||||
|
eval_dataloader=eval_dataloader,
|
||||||
|
total_steps=total_steps,
|
||||||
|
output_dir=config["output_dir"],
|
||||||
|
num_epochs=config["num_epochs"],
|
||||||
|
learning_rate=config["learning_rate"],
|
||||||
|
min_learning_rate=config["min_learning_rate"],
|
||||||
|
weight_decay=config["weight_decay"],
|
||||||
|
warmup_ratio=config["warmup_ratio"],
|
||||||
|
label_smoothing=config["label_smoothing"],
|
||||||
|
grad_accum_steps=config["grad_accum_steps"],
|
||||||
|
clip_grad_norm=config["clip_grad_norm"],
|
||||||
|
eval_frequency=config["eval_frequency"],
|
||||||
|
save_frequency=config["save_frequency"],
|
||||||
|
mixed_precision=config["mixed_precision"],
|
||||||
|
use_tensorboard=config["use_tensorboard"],
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"✅ 训练器创建完成")
|
||||||
|
print(f" 总训练步数: {total_steps:,}")
|
||||||
|
print(f" 学习率: {config['learning_rate']:.2e} -> {config['min_learning_rate']:.2e}")
|
||||||
|
print(f" 输出目录: {config['output_dir']}")
|
||||||
|
|
||||||
|
# %% [code]
|
||||||
|
# 7. 开始训练
|
||||||
|
print("🚀 开始训练...")
|
||||||
|
print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 开始训练(可以从检查点恢复训练)
|
||||||
|
trainer.train(resume_from=None) # 设置检查点路径以恢复训练
|
||||||
|
|
||||||
|
print("✅ 训练完成!")
|
||||||
|
print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||||
|
print(f"模型和日志保存在: {config['output_dir']}")
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("⏹️ 训练被用户中断")
|
||||||
|
print("💾 保存当前检查点...")
|
||||||
|
trainer.save_checkpoint("interrupted")
|
||||||
|
print(f"检查点已保存到: {config['output_dir']}/checkpoint_interrupted.pt")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ 训练过程中出现错误: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
# %% [code]
|
||||||
|
# 8. 监控训练进度(如果使用TensorBoard)
|
||||||
|
if config["use_tensorboard"]:
|
||||||
|
print("📈 TensorBoard日志已记录在:")
|
||||||
|
print(f" {config['output_dir']}/tensorboard")
|
||||||
|
print("\n启动TensorBoard查看训练进度:")
|
||||||
|
print(" tensorboard --logdir ./training_output/tensorboard")
|
||||||
|
print("然后在浏览器中打开: http://localhost:6006")
|
||||||
|
|
||||||
|
# %% [code]
|
||||||
|
# 9. 加载训练好的模型进行推理(示例)
|
||||||
|
def load_trained_model(checkpoint_path):
|
||||||
|
"""加载训练好的模型进行检查点"""
|
||||||
|
print(f"📥 加载检查点: {checkpoint_path}")
|
||||||
|
|
||||||
|
# 创建与训练时相同配置的模型
|
||||||
|
loaded_model = InputMethodEngine(
|
||||||
|
vocab_size=config["vocab_size"],
|
||||||
|
pinyin_vocab_size=config["pinyin_vocab_size"],
|
||||||
|
dim=config["dim"],
|
||||||
|
num_slots=config["num_slots"],
|
||||||
|
n_layers=config["n_layers"],
|
||||||
|
n_heads=config["n_heads"],
|
||||||
|
num_experts=config["num_experts"],
|
||||||
|
max_seq_len=config["max_seq_len"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# 加载检查点
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||||
|
loaded_model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
|
loaded_model.to(device)
|
||||||
|
loaded_model.eval()
|
||||||
|
|
||||||
|
print(f"✅ 模型加载完成,训练步数: {checkpoint.get('global_step', 'N/A')}")
|
||||||
|
print(f" 训练损失: {checkpoint.get('train_loss', 'N/A'):.4f}")
|
||||||
|
|
||||||
|
return loaded_model
|
||||||
|
|
||||||
|
# 使用示例(取消注释以使用)
|
||||||
|
# trained_model = load_trained_model("./training_output/checkpoint_final.pt")
|
||||||
|
```
|
||||||
|
|
||||||
|
### 关键说明
|
||||||
|
|
||||||
|
1. **环境要求**:
|
||||||
|
- Python 3.12+
|
||||||
|
- PyTorch 2.10+
|
||||||
|
- 建议使用GPU进行训练
|
||||||
|
- 安装项目依赖:`pip install -e .`
|
||||||
|
|
||||||
|
2. **数据集格式**:
|
||||||
|
- 使用Hugging Face `datasets`格式
|
||||||
|
- 必须包含`text`字段
|
||||||
|
- 支持流式读取(streaming=True)
|
||||||
|
|
||||||
|
3. **训练监控**:
|
||||||
|
- 控制台输出训练进度和指标
|
||||||
|
- TensorBoard记录损失、准确率、学习率等
|
||||||
|
- 定期保存模型检查点
|
||||||
|
|
||||||
|
4. **可调整参数**:
|
||||||
|
- `batch_size`: 根据GPU内存调整
|
||||||
|
- `learning_rate`: 建议在1e-4到5e-4之间
|
||||||
|
- `grad_accum_steps`: 模拟更大batch size
|
||||||
|
- `num_epochs`: 根据数据集大小调整
|
||||||
|
|
||||||
|
5. **故障排除**:
|
||||||
|
- GPU内存不足:减小`batch_size`或增加`grad_accum_steps`
|
||||||
|
- 训练不稳定:降低`learning_rate`或增加`warmup_ratio`
|
||||||
|
- 过拟合:增加`label_smoothing`或使用更大数据集
|
||||||
|
|
||||||
|
## 使用指南
|
||||||
|
|
||||||
|
本项目的训练功能通过命令行工具 `train-model` 提供,支持训练、评估和导出模型。
|
||||||
|
|
||||||
|
### 安装与准备
|
||||||
|
|
||||||
|
#### 使用 uv(推荐)
|
||||||
|
本项目使用 [`uv`](https://github.com/astral-sh/uv) 作为Python包管理器,它比传统的 pip 更快且更可靠。
|
||||||
|
|
||||||
|
1. **安装 uv**(如果尚未安装):
|
||||||
|
```bash
|
||||||
|
# Linux/macOS
|
||||||
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
|
|
||||||
|
# 或使用 pipx
|
||||||
|
pipx install uv
|
||||||
|
|
||||||
|
# Windows (PowerShell)
|
||||||
|
powershell -c "irm https://astral.sh/uv/install.ps1 | iex"
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **安装项目依赖**:
|
||||||
|
```bash
|
||||||
|
uv pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 使用传统 pip
|
||||||
|
如果不使用 uv,也可以用标准的 pip 安装:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 创建并激活虚拟环境(推荐)
|
||||||
|
python -m venv .venv
|
||||||
|
source .venv/bin/activate # Linux/macOS
|
||||||
|
# .venv\Scripts\activate # Windows
|
||||||
|
|
||||||
|
# 安装依赖
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 验证安装
|
||||||
|
安装完成后,可通过以下命令验证:
|
||||||
|
```bash
|
||||||
|
train-model --help
|
||||||
|
```
|
||||||
|
|
||||||
|
### 数据格式
|
||||||
|
|
||||||
|
训练数据应为Hugging Face数据集格式,支持本地文件或远程数据集仓库。数据集需包含 `text` 字段,并支持流式读取(streaming=True)。
|
||||||
|
|
||||||
|
#### 本地数据集示例
|
||||||
|
```python
|
||||||
|
# dataset.py
|
||||||
|
from datasets import Dataset
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"text": ["这是第一个样本文本。", "这是第二个样本,用于训练输入法模型。"]
|
||||||
|
}
|
||||||
|
dataset = Dataset.from_dict(data)
|
||||||
|
dataset.save_to_disk("./local_dataset")
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 远程数据集示例
|
||||||
|
支持Hugging Face Hub或ModelScope上的数据集:
|
||||||
|
- `huggingface.co/datasets/username/dataset_name`
|
||||||
|
- `modelscope.cn/datasets/username/dataset_name`
|
||||||
|
|
||||||
|
#### 数据格式要求
|
||||||
|
- **必需字段**: `text`(字符串类型,包含中文文本)
|
||||||
|
- **流式读取**: 数据集必须支持 `streaming=True` 参数
|
||||||
|
- **数据量**: 建议至少数百万条文本以获得良好效果
|
||||||
|
|
||||||
|
#### 数据预处理
|
||||||
|
数据集会自动进行以下处理:
|
||||||
|
1. 文本分词和编码
|
||||||
|
2. 拼音转换和编码
|
||||||
|
3. 上下文窗口滑动生成训练样本
|
||||||
|
4. 频率调整(削峰填谷)以平衡高频/低频字词
|
||||||
|
|
||||||
|
### 基本训练命令
|
||||||
|
|
||||||
|
使用 `train-model train` 命令开始训练:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
train-model train \
|
||||||
|
--train-data-path "path/to/train/dataset" \
|
||||||
|
--eval-data-path "path/to/eval/dataset" \
|
||||||
|
--output-dir "./output" \
|
||||||
|
--batch-size 128 \
|
||||||
|
--num-epochs 10 \
|
||||||
|
--learning-rate 1e-5
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 检查点恢复训练
|
||||||
|
|
||||||
|
要从检查点恢复训练(保持原有的训练状态):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
train-model train \
|
||||||
|
--train-data-path "path/to/train/dataset" \
|
||||||
|
--eval-data-path "path/to/eval/dataset" \
|
||||||
|
--resume-from "./output/checkpoints/latest_checkpoint.pt"
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 重置训练状态
|
||||||
|
|
||||||
|
如果只想加载模型权重,从头开始训练(学习率、epoch等都重新开始):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
train-model train \
|
||||||
|
--train-data-path "path/to/train/dataset" \
|
||||||
|
--eval-data-path "path/to/eval/dataset" \
|
||||||
|
--resume-from "./output/checkpoints/best_model.pt" \
|
||||||
|
--reset-training-state
|
||||||
|
```
|
||||||
|
|
||||||
|
这个功能在以下场景非常有用:
|
||||||
|
- 想要用预训练权重初始化模型,但用新的训练计划重新训练
|
||||||
|
- 需要调整学习率策略或训练时长
|
||||||
|
- 在现有模型基础上进行迁移学习
|
||||||
|
|
||||||
|
#### 学习率建议
|
||||||
|
根据模型架构和超参数配置(4层Transformer,512维度),推荐使用以下学习率范围:
|
||||||
|
- **标准范围**: 1e-4 ~ 5e-4
|
||||||
|
- **配合Warmup策略**:在训练初期逐步提高学习率
|
||||||
|
- **余弦退火**:使用最小学习率 1e-9 进行细调
|
||||||
|
|
||||||
|
### 参数详解
|
||||||
|
|
||||||
|
#### 数据参数
|
||||||
|
- `--train-data-path`, `-t`: 训练数据集路径(必需)
|
||||||
|
- `--eval-data-path`, ` -e`: 评估数据集路径(必需)
|
||||||
|
- `--output-dir`, `-o`: 输出目录(默认:`./output`)
|
||||||
|
- `--max_iter_length`: 最大迭代长度,控制每次训练迭代处理的数据量(默认:134217728)
|
||||||
|
|
||||||
|
#### 模型参数
|
||||||
|
- `--vocab-size`: 词汇表大小(默认:10019)
|
||||||
|
- `--pinyin-vocab-size`: 拼音词汇表大小(默认:30)
|
||||||
|
- `--dim`: 模型维度(默认:512)
|
||||||
|
- `--num-slots`: 历史槽位数量(默认:8)
|
||||||
|
- `--n-layers`: Transformer层数(默认:4)
|
||||||
|
- `--n-heads`: 注意力头数(默认:4)
|
||||||
|
- `--num-experts`: MoE专家数量(默认:20)
|
||||||
|
- `--max-seq-len`: 最大序列长度(默认:128)
|
||||||
|
- `--use-pinyin`: 是否使用拼音特征(默认:False)
|
||||||
|
|
||||||
|
#### 训练参数
|
||||||
|
- `--batch-size`, `-b`: 批次大小(默认:128)
|
||||||
|
- `--num-epochs`: 训练轮数(默认:10)
|
||||||
|
- `--learning-rate`, `-lr`: 学习率(默认:1e-5)
|
||||||
|
- `--min-learning-rate`: 最小学习率(默认:1e-9)
|
||||||
|
- `--weight-decay`: 权重衰减(默认:0.1)
|
||||||
|
- `--warmup-ratio`: 热身步数比例(默认:0.1)
|
||||||
|
- `--label-smoothing`: 标签平滑参数(默认:0.15)
|
||||||
|
- `--grad-accum-steps`: 梯度累积步数(默认:1)
|
||||||
|
- `--clip-grad-norm`: 梯度裁剪范数(默认:1.0)
|
||||||
|
- `--eval-frequency`: 评估频率(默认:500步)
|
||||||
|
- `--save-frequency`: 保存频率(默认:10000步)
|
||||||
|
|
||||||
|
#### 高级选项
|
||||||
|
- `--mixed-precision/--no-mixed-precision`: 是否使用混合精度训练(默认:启用)
|
||||||
|
- `--tensorboard/--no-tensorboard`: 是否使用TensorBoard(默认:启用)
|
||||||
|
- `--resume-from`: 从检查点恢复训练(可选)
|
||||||
|
- `--reset-training-state`: 重置训练状态,只加载模型权重从头开始训练(默认:False)
|
||||||
|
- `--seed`: 随机种子(默认:42)
|
||||||
|
|
||||||
|
### 监控训练进度
|
||||||
|
|
||||||
|
训练过程中会显示:
|
||||||
|
- 当前训练步数/总步数
|
||||||
|
- 损失值和准确率
|
||||||
|
- 学习率变化
|
||||||
|
- 内存使用情况
|
||||||
|
|
||||||
|
启用TensorBoard后,可使用以下命令查看可视化结果:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
tensorboard --logdir ./output/tensorboard
|
||||||
|
```
|
||||||
|
|
||||||
|
### 评估模型(开发中)
|
||||||
|
|
||||||
|
当前评估功能尚在开发中:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
train-model evaluate \
|
||||||
|
--checkpoint "./output/checkpoint_final.pt" \
|
||||||
|
--data-path "path/to/eval/dataset" \
|
||||||
|
--batch-size 32
|
||||||
|
```
|
||||||
|
|
||||||
|
命令将显示"评估功能待实现"的提示信息。该功能计划用于:
|
||||||
|
- 加载训练好的模型检查点
|
||||||
|
- 在评估数据集上计算准确率、困惑度等指标
|
||||||
|
- 生成详细的性能报告
|
||||||
|
|
||||||
|
### 模型扩容两阶段训练
|
||||||
|
|
||||||
|
当需要增加模型容量(如增加专家数量、修改层结构等)时,可以使用 `expand-and-train` 命令进行两阶段训练:先冻结匹配层训练新增参数,然后全量微调。
|
||||||
|
|
||||||
|
#### 训练策略
|
||||||
|
|
||||||
|
1. **冻结阶段**:只训练形状不匹配的新增参数(如新增的专家、扩容的层等)
|
||||||
|
2. **全量微调阶段**:当验证损失连续 `--frozen-patience` 次不下降时,自动解冻所有层进行全量训练
|
||||||
|
|
||||||
|
#### 基础用法
|
||||||
|
|
||||||
|
```bash
|
||||||
|
train-model expand-and-train \
|
||||||
|
--train-data-path "path/to/train/dataset" \
|
||||||
|
--eval-data-path "path/to/eval/dataset" \
|
||||||
|
--base-model-path "./pretrained/model.pt" \
|
||||||
|
--new-model-spec "model:InputMethodEngine" \
|
||||||
|
--num-experts 40 \
|
||||||
|
--frozen-lr 2e-3 \
|
||||||
|
--full-lr 5e-5 \
|
||||||
|
--frozen-patience 8
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 完整参数示例
|
||||||
|
|
||||||
|
```bash
|
||||||
|
train-model expand-and-train \
|
||||||
|
--train-data-path "path/to/train/dataset" \
|
||||||
|
--eval-data-path "path/to/eval/dataset" \
|
||||||
|
--output-dir "./expansion_output" \
|
||||||
|
--base-model-path "./pretrained/model.pt" \
|
||||||
|
--new-model-spec "custom_model:ExpandedModel" \
|
||||||
|
--vocab-size 10019 \
|
||||||
|
--dim 512 \
|
||||||
|
--num-experts 40 \
|
||||||
|
--frozen-patience 10 \
|
||||||
|
--frozen-lr 1e-3 \
|
||||||
|
--full-lr 1e-4 \
|
||||||
|
--frozen-scheduler cosine \
|
||||||
|
--full-scheduler cosine \
|
||||||
|
--batch-size 128 \
|
||||||
|
--num-epochs 20 \
|
||||||
|
--compile
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 参数详解
|
||||||
|
|
||||||
|
**模型扩容参数**
|
||||||
|
- `--base-model-path`: 预训练基础模型检查点路径(必需)
|
||||||
|
- `--new-model-spec`: 新模型规格,格式:`模块名:类名`,如 `model:InputMethodEngine`(必需)
|
||||||
|
- 支持任意路径的模块导入,模块文件需包含自定义的模型类
|
||||||
|
- 自定义模型类必须是 `InputMethodEngine` 的子类
|
||||||
|
- 示例:`my_model:MyExpandedModel` 对应 `my_model.py` 中的 `MyExpandedModel` 类
|
||||||
|
|
||||||
|
**两阶段训练参数**
|
||||||
|
- `--frozen-patience`: 冻结阶段验证损失连续不下降的评估次数,触发切换到全量微调(默认:10)
|
||||||
|
- `--frozen-lr`: 冻结阶段学习率(默认:1e-3)
|
||||||
|
- `--full-lr`: 全量微调阶段学习率(默认:1e-4)
|
||||||
|
- `--frozen-scheduler`: 冻结阶段学习率调度器,可选 `cosine` 或 `plateau`(默认:`cosine`)
|
||||||
|
- `--full-scheduler`: 全量微调阶段学习率调度器,可选 `cosine` 或 `plateau`(默认:`cosine`)
|
||||||
|
|
||||||
|
**其他参数**
|
||||||
|
- 支持所有 `train` 子命令的通用参数(数据参数、模型参数、训练参数等)
|
||||||
|
- 继承现有的训练基础设施:混合精度训练、TensorBoard日志、checkpoint保存等
|
||||||
|
|
||||||
|
#### 使用场景
|
||||||
|
|
||||||
|
1. **增加专家数量**(20→40)
|
||||||
|
- 冻结效果:~70% 参数可冻结(已有专家权重、注意力层等)
|
||||||
|
- 新增参数:新专家网络、gate层
|
||||||
|
|
||||||
|
2. **增加top_k值**(2→3)
|
||||||
|
- 冻结效果:100% 参数可冻结(仅逻辑变化)
|
||||||
|
- 新增参数:无
|
||||||
|
|
||||||
|
3. **修改专家内部结构**(如增加resblocks)
|
||||||
|
- 冻结效果:~50% 参数可冻结(linear_in/output可冻结)
|
||||||
|
- 新增参数:新增的resblocks层
|
||||||
|
|
||||||
|
4. **增加Transformer层数**(4→5)
|
||||||
|
- 冻结效果:~80% 参数可冻结(前4层可冻结)
|
||||||
|
- 新增参数:新增的第5层
|
||||||
|
|
||||||
|
#### 自定义模型类示例
|
||||||
|
|
||||||
|
```python
|
||||||
|
# my_model.py
|
||||||
|
from model.model import InputMethodEngine
|
||||||
|
|
||||||
|
class MyExpandedModel(InputMethodEngine):
|
||||||
|
def __init__(self, num_experts=40, **kwargs):
|
||||||
|
# 调用父类构造函数,覆盖num_experts参数
|
||||||
|
super().__init__(num_experts=num_experts, **kwargs)
|
||||||
|
# 可以在这里添加额外的层或修改现有层
|
||||||
|
|
||||||
|
# 使用命令
|
||||||
|
# train-model expand-and-train --new-model-spec "my_model:MyExpandedModel" ...
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 注意事项
|
||||||
|
|
||||||
|
1. **模型类要求**:自定义模型类必须是 `InputMethodEngine` 的子类
|
||||||
|
2. **冻结条件**:只有权重形状完全匹配的层才会被冻结
|
||||||
|
3. **性能保持**:MoE层保持"计算所有专家+Top-K选择"方案,确保 `torch.compile` 下的最佳性能
|
||||||
|
4. **阶段切换**:基于评估频率而非epoch,建议适当调高 `--eval-frequency`
|
||||||
|
5. **模块导入**:支持任意路径的模块,通过Python标准导入机制加载
|
||||||
|
|
||||||
|
### 导出模型(开发中)
|
||||||
|
|
||||||
|
当前导出功能尚在开发中:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
train-model export \
|
||||||
|
--checkpoint "./output/checkpoint_final.pt" \
|
||||||
|
--output "./exported_model.onnx"
|
||||||
|
```
|
||||||
|
|
||||||
|
命令将显示"导出功能待实现"的提示信息。该功能计划用于:
|
||||||
|
- 将PyTorch模型转换为ONNX格式
|
||||||
|
- 支持在不同推理引擎上部署
|
||||||
|
- 提供优化后的推理模型
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
[{"step": 100, "epoch": 1, "train/loss": 0.1234, "eval/loss": 0.2345}, {"step": 200, "epoch": 1, "train/loss": 0.1134, "eval/loss": 0.2245}]
|
||||||
|
|
@ -17,11 +17,12 @@ dependencies = [
|
||||||
"pypinyin>=0.55.0",
|
"pypinyin>=0.55.0",
|
||||||
"requests>=2.32.5",
|
"requests>=2.32.5",
|
||||||
"rich>=14.3.1",
|
"rich>=14.3.1",
|
||||||
"streamlit>=1.56.0",
|
"flask>=3.1.0",
|
||||||
"tensorboard>=2.20.0",
|
"tensorboard>=2.20.0",
|
||||||
"torch>=2.10.0",
|
"torch>=2.10.0",
|
||||||
"transformers==5.1.0",
|
"transformers==5.1.0",
|
||||||
"typer>=0.21.1",
|
"typer>=0.21.1",
|
||||||
|
"waitress>=3.0.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|
@ -37,6 +38,9 @@ dev = [
|
||||||
"autocommit",
|
"autocommit",
|
||||||
"pytest>=9.0.2",
|
"pytest>=9.0.2",
|
||||||
]
|
]
|
||||||
|
wsgi = [
|
||||||
|
"waitress>=3.0.0",
|
||||||
|
]
|
||||||
|
|
||||||
[tool.uv.sources]
|
[tool.uv.sources]
|
||||||
autocommit = { git = "https://gitea.winkinshly.site/songsenand/autocommit.git" }
|
autocommit = { git = "https://gitea.winkinshly.site/songsenand/autocommit.git" }
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,180 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from flask import Flask, render_template, request, jsonify, send_from_directory
|
||||||
|
|
||||||
|
app = Flask(__name__,
|
||||||
|
template_folder=Path(__file__).parent / 'templates',
|
||||||
|
static_folder=Path(__file__).parent / 'static')
|
||||||
|
|
||||||
|
# 全局配置
|
||||||
|
DEFAULT_STATUS_FILE = "./output/training_status.json"
|
||||||
|
DEFAULT_PORT = 8501
|
||||||
|
DEFAULT_HOST = "0.0.0.0"
|
||||||
|
|
||||||
|
|
||||||
|
def load_training_data(file_path: str, is_url: bool = False) -> list:
|
||||||
|
"""加载训练状态数据"""
|
||||||
|
if is_url:
|
||||||
|
return load_from_url(file_path)
|
||||||
|
else:
|
||||||
|
return load_from_local_file(file_path)
|
||||||
|
|
||||||
|
|
||||||
|
def load_from_url(url: str) -> list:
|
||||||
|
"""从HTTP URL加载数据"""
|
||||||
|
try:
|
||||||
|
import requests
|
||||||
|
except ImportError:
|
||||||
|
raise RuntimeError("requests库未安装,无法从HTTP URL加载数据")
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.get(url, timeout=10)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
if not isinstance(data, list):
|
||||||
|
raise ValueError("远程返回的数据不是列表格式")
|
||||||
|
|
||||||
|
return data
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"从URL加载数据失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def load_from_local_file(file_path: str) -> list:
|
||||||
|
"""从本地文件加载数据"""
|
||||||
|
try:
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
return []
|
||||||
|
|
||||||
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
|
||||||
|
if not isinstance(data, list):
|
||||||
|
raise ValueError("文件内容不是列表格式")
|
||||||
|
|
||||||
|
return data
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return []
|
||||||
|
except Exception:
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def validate_and_clean_data(data: list) -> list:
|
||||||
|
"""验证和清理数据"""
|
||||||
|
if not isinstance(data, list):
|
||||||
|
return []
|
||||||
|
|
||||||
|
cleaned = []
|
||||||
|
for item in data:
|
||||||
|
if isinstance(item, dict) and ('step' in item or 'train/loss' in item or 'timestamp' in item):
|
||||||
|
cleaned.append(item)
|
||||||
|
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/')
|
||||||
|
def index():
|
||||||
|
"""主页"""
|
||||||
|
data_source_type = request.args.get('data_source_type', 'local')
|
||||||
|
data_source = request.args.get('data_source', DEFAULT_STATUS_FILE)
|
||||||
|
refresh_interval = int(request.args.get('refresh_interval', 5))
|
||||||
|
|
||||||
|
return render_template('index.html',
|
||||||
|
data_source_type=data_source_type,
|
||||||
|
data_source=data_source,
|
||||||
|
refresh_interval=refresh_interval)
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/api/status')
|
||||||
|
def api_status():
|
||||||
|
"""API接口,返回训练状态数据"""
|
||||||
|
data_source_type = request.args.get('data_source_type', 'local')
|
||||||
|
data_source = request.args.get('data_source', DEFAULT_STATUS_FILE)
|
||||||
|
|
||||||
|
try:
|
||||||
|
is_url = data_source_type == 'remote'
|
||||||
|
raw_data = load_training_data(data_source, is_url)
|
||||||
|
cleaned_data = validate_and_clean_data(raw_data)
|
||||||
|
|
||||||
|
return jsonify(cleaned_data)
|
||||||
|
except Exception as e:
|
||||||
|
return jsonify({
|
||||||
|
'error': str(e),
|
||||||
|
'data': []
|
||||||
|
}), 500
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/api/info')
|
||||||
|
def api_info():
|
||||||
|
"""获取服务器信息"""
|
||||||
|
return jsonify({
|
||||||
|
'server_time': datetime.now().isoformat(),
|
||||||
|
'default_status_file': DEFAULT_STATUS_FILE,
|
||||||
|
'python_version': sys.version,
|
||||||
|
'working_directory': os.getcwd()
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def start_flask_server(host: str, port: int, debug: bool = False, use_wsgi: bool = False):
|
||||||
|
"""启动Flask服务器"""
|
||||||
|
from flask import cli
|
||||||
|
|
||||||
|
# 禁用Flask的默认启动消息
|
||||||
|
cli.show_server_banner = lambda *args: None
|
||||||
|
|
||||||
|
print(f"🚀 启动训练监控服务 ({'Waitress WSGI' if use_wsgi else 'Flask'})...")
|
||||||
|
print(f"📁 默认状态文件: {os.path.abspath(DEFAULT_STATUS_FILE)}")
|
||||||
|
print(f"🌐 监控地址: http://{host}:{port}")
|
||||||
|
print(f"📊 API接口: http://{host}:{port}/api/status")
|
||||||
|
print("\n按 Ctrl+C 停止监控服务\n")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if use_wsgi:
|
||||||
|
try:
|
||||||
|
import waitress
|
||||||
|
waitress.serve(app, host=host, port=port, threads=4)
|
||||||
|
except ImportError:
|
||||||
|
print("⚠️ waitress未安装,回退到Flask开发服务器")
|
||||||
|
print(" 安装waitress: pip install waitress")
|
||||||
|
app.run(host=host, port=port, debug=debug, threaded=True)
|
||||||
|
else:
|
||||||
|
app.run(host=host, port=port, debug=debug, threaded=True)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n🛑 监控服务已停止")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ 启动监控服务时出错: {e}")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""命令行入口点"""
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
global DEFAULT_STATUS_FILE
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="AI模型训练监控工具 - Flask版本")
|
||||||
|
parser.add_argument('--host', default=DEFAULT_HOST, help=f'监控服务主机地址 (默认: {DEFAULT_HOST})')
|
||||||
|
parser.add_argument('--port', type=int, default=DEFAULT_PORT, help=f'监控服务端口号 (默认: {DEFAULT_PORT})')
|
||||||
|
parser.add_argument('--debug', action='store_true', help='启用调试模式')
|
||||||
|
parser.add_argument('--use-wsgi', action='store_true', help='使用Waitress WSGI服务器替代Flask开发服务器')
|
||||||
|
parser.add_argument('--status-file', default=DEFAULT_STATUS_FILE, help='默认状态文件路径')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# 更新默认状态文件
|
||||||
|
DEFAULT_STATUS_FILE = args.status_file
|
||||||
|
|
||||||
|
return start_flask_server(args.host, args.port, args.debug, args.use_wsgi)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
sys.exit(main())
|
||||||
|
|
@ -16,109 +16,72 @@ import typer
|
||||||
|
|
||||||
app = typer.Typer(help="AI模型训练监控工具 - 基于JSON旁路记录法的移动端友好监控方案")
|
app = typer.Typer(help="AI模型训练监控工具 - 基于JSON旁路记录法的移动端友好监控方案")
|
||||||
|
|
||||||
|
# 尝试导入Flask,如果失败则提供友好错误提示
|
||||||
|
try:
|
||||||
|
from .flask_monitor import start_flask_server, DEFAULT_STATUS_FILE as FLASK_DEFAULT_STATUS_FILE
|
||||||
|
FLASK_AVAILABLE = True
|
||||||
|
except ImportError as e:
|
||||||
|
FLASK_AVAILABLE = False
|
||||||
|
FLASK_IMPORT_ERROR = str(e)
|
||||||
|
|
||||||
|
|
||||||
def get_package_dir() -> Path:
|
def get_package_dir() -> Path:
|
||||||
"""获取包目录路径"""
|
"""获取包目录路径"""
|
||||||
return Path(__file__).parent
|
return Path(__file__).parent
|
||||||
|
|
||||||
|
|
||||||
def check_streamlit_available() -> bool:
|
def check_flask_available() -> bool:
|
||||||
"""检查Streamlit是否可用"""
|
"""检查Flask是否可用"""
|
||||||
try:
|
return FLASK_AVAILABLE
|
||||||
import streamlit
|
|
||||||
|
|
||||||
return True
|
|
||||||
except ImportError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def start_streamlit_server(
|
def start_flask_monitor_server(
|
||||||
status_file: str,
|
status_file: str,
|
||||||
port: int,
|
port: int,
|
||||||
host: str,
|
host: str,
|
||||||
open_browser: bool,
|
open_browser: bool,
|
||||||
streamlit_script: Optional[Union[str, Path]] = None,
|
use_wsgi: bool = False,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""
|
||||||
启动Streamlit服务器
|
启动Flask监控服务器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
status_file: 状态文件路径
|
status_file: 状态文件路径
|
||||||
port: 端口号
|
port: 端口号
|
||||||
host: 主机地址
|
host: 主机地址
|
||||||
open_browser: 是否自动打开浏览器
|
open_browser: 是否自动打开浏览器
|
||||||
streamlit_script: Streamlit脚本路径,如果为None则使用默认
|
use_wsgi: 是否使用Waitress WSGI服务器替代Flask开发服务器
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
进程退出码
|
进程退出码
|
||||||
"""
|
"""
|
||||||
# 确定Streamlit脚本路径
|
if not FLASK_AVAILABLE:
|
||||||
if streamlit_script is None:
|
typer.echo(f"❌ 错误: Flask未正确导入")
|
||||||
# 使用包内的training_monitor.py
|
typer.echo(f"导入错误: {FLASK_IMPORT_ERROR}")
|
||||||
package_dir = get_package_dir()
|
typer.echo("请安装Flask: pip install flask")
|
||||||
script_path = package_dir / "training_monitor.py"
|
typer.echo("或在pyproject.toml中添加flask依赖")
|
||||||
streamlit_script = str(script_path)
|
return 1
|
||||||
if not script_path.exists():
|
|
||||||
typer.echo(f"❌ 错误: 找不到Streamlit脚本: {streamlit_script}")
|
|
||||||
typer.echo("请确保training_monitor.py文件存在。")
|
|
||||||
return 1
|
|
||||||
|
|
||||||
# 设置环境变量,传递状态文件路径
|
# 设置环境变量,传递状态文件路径
|
||||||
env = os.environ.copy()
|
os.environ["TRAINING_STATUS_FILE"] = os.path.abspath(status_file)
|
||||||
env["TRAINING_STATUS_FILE"] = os.path.abspath(status_file)
|
|
||||||
# 配置Streamlit CORS和WebSocket设置
|
server_type = "Waitress WSGI" if use_wsgi else "Flask"
|
||||||
env["STREAMLIT_SERVER_ENABLE_CORS"] = "true"
|
typer.echo(f"🚀 启动训练监控服务 ({server_type}版本)...")
|
||||||
env["STREAMLIT_SERVER_ENABLE_XSRF_PROTECTION"] = "false"
|
|
||||||
env["STREAMLIT_SERVER_ENABLE_WEBSOCKET_COMPRESSION"] = "true"
|
|
||||||
env["STREAMLIT_SERVER_ALLOW_ORIGIN"] = "*"
|
|
||||||
|
|
||||||
# 构建Streamlit命令
|
|
||||||
cmd = [
|
|
||||||
sys.executable,
|
|
||||||
"-m",
|
|
||||||
"streamlit",
|
|
||||||
"run",
|
|
||||||
str(streamlit_script),
|
|
||||||
"--server.port",
|
|
||||||
str(port),
|
|
||||||
"--server.address",
|
|
||||||
host,
|
|
||||||
"--server.headless",
|
|
||||||
"true" if not open_browser else "false",
|
|
||||||
"--theme.base",
|
|
||||||
"light",
|
|
||||||
"--browser.serverAddress",
|
|
||||||
host,
|
|
||||||
"--browser.gatherUsageStats",
|
|
||||||
"false",
|
|
||||||
"--server.enableCORS",
|
|
||||||
"true",
|
|
||||||
"--server.enableXsrfProtection",
|
|
||||||
"false",
|
|
||||||
"--server.enableWebsocketCompression",
|
|
||||||
"true",
|
|
||||||
"--server.maxUploadSize",
|
|
||||||
"200",
|
|
||||||
]
|
|
||||||
|
|
||||||
typer.echo("🚀 启动训练监控服务...")
|
|
||||||
typer.echo(f"📁 状态文件: {os.path.abspath(status_file)}")
|
typer.echo(f"📁 状态文件: {os.path.abspath(status_file)}")
|
||||||
typer.echo(f"🌐 监控地址: http://{host}:{port}")
|
typer.echo(f"🌐 监控地址: http://{host}:{port}")
|
||||||
typer.echo(f"📊 Streamlit脚本: {streamlit_script}")
|
typer.echo(f"📊 API接口: http://{host}:{port}/api/status")
|
||||||
|
|
||||||
if open_browser:
|
if open_browser:
|
||||||
# 等待服务器启动后打开浏览器
|
# 等待服务器启动后打开浏览器
|
||||||
time.sleep(2)
|
threading.Timer(2.0, lambda: webbrowser.open(f"http://{host}:{port}")).start()
|
||||||
webbrowser.open(f"http://{host}:{port}")
|
|
||||||
typer.echo("🌐 正在打开浏览器...")
|
typer.echo("🌐 正在打开浏览器...")
|
||||||
|
|
||||||
typer.echo("\n按 Ctrl+C 停止监控服务\n")
|
typer.echo("\n按 Ctrl+C 停止监控服务\n")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 启动Streamlit进程
|
# 导入并启动Flask服务器
|
||||||
process = subprocess.Popen(cmd, env=env)
|
from .flask_monitor import start_flask_server
|
||||||
process.wait()
|
return start_flask_server(host=host, port=port, debug=False, use_wsgi=use_wsgi)
|
||||||
return process.returncode
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
typer.echo("\n🛑 监控服务已停止")
|
typer.echo("\n🛑 监控服务已停止")
|
||||||
return 0
|
return 0
|
||||||
|
|
@ -138,15 +101,14 @@ def monitor_training(
|
||||||
port: int = typer.Option(8501, "--port", "-p", help="监控服务端口号"),
|
port: int = typer.Option(8501, "--port", "-p", help="监控服务端口号"),
|
||||||
host: str = typer.Option("0.0.0.0", "--host", help="监控服务主机地址"),
|
host: str = typer.Option("0.0.0.0", "--host", help="监控服务主机地址"),
|
||||||
open_browser: bool = typer.Option(False, "--open-browser", help="自动打开浏览器"),
|
open_browser: bool = typer.Option(False, "--open-browser", help="自动打开浏览器"),
|
||||||
streamlit_script: Optional[str] = typer.Option(
|
use_wsgi: bool = typer.Option(False, "--use-wsgi", help="使用Waitress WSGI服务器替代Flask开发服务器"),
|
||||||
None, "--streamlit-script", help="自定义Streamlit脚本路径"
|
|
||||||
),
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
启动AI模型训练监控服务
|
启动AI模型训练监控服务 (Flask版本)
|
||||||
|
|
||||||
基于JSON旁路记录法,提供移动端友好的训练监控界面。
|
基于JSON旁路记录法,提供移动端友好的训练监控界面。
|
||||||
服务启动后,可通过浏览器访问 http://<host>:<port> 查看实时训练指标。
|
服务启动后,可通过浏览器访问 http://<host>:<port> 查看实时训练指标。
|
||||||
|
支持本地文件和远程URL数据源,可在网页界面配置。
|
||||||
"""
|
"""
|
||||||
# 检查状态文件是否存在
|
# 检查状态文件是否存在
|
||||||
if not os.path.exists(status_file):
|
if not os.path.exists(status_file):
|
||||||
|
|
@ -162,20 +124,20 @@ def monitor_training(
|
||||||
json.dump([], f)
|
json.dump([], f)
|
||||||
typer.echo(f"✅ 已创建空状态文件: {status_file}")
|
typer.echo(f"✅ 已创建空状态文件: {status_file}")
|
||||||
|
|
||||||
# 检查Streamlit是否可用
|
# 检查Flask是否可用
|
||||||
if not check_streamlit_available():
|
if not check_flask_available():
|
||||||
typer.echo("❌ 错误: Streamlit未安装")
|
typer.echo("❌ 错误: Flask未正确导入")
|
||||||
typer.echo("请安装Streamlit: pip install streamlit")
|
typer.echo("请安装Flask: pip install flask")
|
||||||
typer.echo("或在pyproject.toml中添加streamlit依赖")
|
typer.echo("或在pyproject.toml中添加flask依赖")
|
||||||
raise typer.Exit(code=1)
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
# 启动Streamlit服务器
|
# 启动Flask服务器
|
||||||
return_code = start_streamlit_server(
|
return_code = start_flask_monitor_server(
|
||||||
status_file=status_file,
|
status_file=status_file,
|
||||||
port=port,
|
port=port,
|
||||||
host=host,
|
host=host,
|
||||||
open_browser=open_browser,
|
open_browser=open_browser,
|
||||||
streamlit_script=streamlit_script,
|
use_wsgi=use_wsgi,
|
||||||
)
|
)
|
||||||
|
|
||||||
raise typer.Exit(code=return_code)
|
raise typer.Exit(code=return_code)
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because one or more lines are too long
|
|
@ -0,0 +1,256 @@
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="zh-CN">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>{% block title %}AI模型训练监控看板{% endblock %}</title>
|
||||||
|
<link rel="icon" href="data:image/svg+xml,<svg xmlns=%22http://www.w3.org/2000/svg%22 viewBox=%220 0 100 100%22><text y=%22.9em%22 font-size=%2290%22>🧠</text></svg>" type="image/svg+xml">
|
||||||
|
<link rel="alternate icon" href="data:image/svg+xml,<svg xmlns=%22http://www.w3.org/2000/svg%22 viewBox=%220 0 100 100%22><text y=%22.9em%22 font-size=%2290%22>🧠</text></svg>" type="image/svg+xml">
|
||||||
|
<link rel="stylesheet" href="{{ url_for('static', filename='css/bulma.css') }}">
|
||||||
|
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css">
|
||||||
|
<script src="{{ url_for('static', filename='js/plotly-2.27.0.min.js') }}" charset="utf-8"></script>
|
||||||
|
<style>
|
||||||
|
@font-face {
|
||||||
|
font-family: 'SourceHanSans';
|
||||||
|
src: url("{{ url_for('static', filename='fonts/SourceHanSansSC-Medium.otf') }}") format('opentype');
|
||||||
|
font-weight: normal;
|
||||||
|
font-style: normal;
|
||||||
|
}
|
||||||
|
|
||||||
|
@font-face {
|
||||||
|
font-family: 'SmileySans';
|
||||||
|
src: url("{{ url_for('static', filename='fonts/SmileySans-Oblique.ttf') }}") format('truetype');
|
||||||
|
font-weight: normal;
|
||||||
|
font-style: italic;
|
||||||
|
}
|
||||||
|
|
||||||
|
:root {
|
||||||
|
--primary-color: #00d1b2;
|
||||||
|
--secondary-color: #3273dc;
|
||||||
|
--success-color: #23d160;
|
||||||
|
--warning-color: #ffdd57;
|
||||||
|
--danger-color: #ff3860;
|
||||||
|
--dark-color: #363636;
|
||||||
|
--light-color: #f5f5f5;
|
||||||
|
}
|
||||||
|
|
||||||
|
body {
|
||||||
|
font-family: 'SourceHanSans', 'SmileySans', 'Segoe UI', 'Microsoft YaHei', sans-serif;
|
||||||
|
background-color: #f8f9fa;
|
||||||
|
min-height: 100vh;
|
||||||
|
}
|
||||||
|
|
||||||
|
.card {
|
||||||
|
border-radius: 12px;
|
||||||
|
box-shadow: 0 4px 20px rgba(0,0,0,0.08);
|
||||||
|
border: none;
|
||||||
|
transition: transform 0.3s ease, box-shadow 0.3s ease;
|
||||||
|
margin-bottom: 1.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.card:hover {
|
||||||
|
transform: translateY(-5px);
|
||||||
|
box-shadow: 0 8px 30px rgba(0,0,0,0.12);
|
||||||
|
}
|
||||||
|
|
||||||
|
.card-header {
|
||||||
|
background-color: transparent;
|
||||||
|
border-bottom: 1px solid #e8e8e8;
|
||||||
|
padding: 1.25rem 1.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.card-content {
|
||||||
|
padding: 1.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.metric-card {
|
||||||
|
background: linear-gradient(135deg, #ffffff 0%, #f8f9fa 100%);
|
||||||
|
border-left: 4px solid var(--primary-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
.metric-value {
|
||||||
|
font-size: 2.5rem;
|
||||||
|
font-weight: 800;
|
||||||
|
color: var(--dark-color);
|
||||||
|
line-height: 1;
|
||||||
|
margin-bottom: 0.25rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.metric-label {
|
||||||
|
font-size: 0.875rem;
|
||||||
|
color: #666;
|
||||||
|
text-transform: uppercase;
|
||||||
|
letter-spacing: 0.5px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.metric-delta {
|
||||||
|
font-size: 0.875rem;
|
||||||
|
font-weight: 600;
|
||||||
|
}
|
||||||
|
|
||||||
|
.positive-delta {
|
||||||
|
color: var(--success-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
.negative-delta {
|
||||||
|
color: var(--danger-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
.chart-container {
|
||||||
|
background: white;
|
||||||
|
border-radius: 12px;
|
||||||
|
padding: 1.5rem;
|
||||||
|
height: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-indicator {
|
||||||
|
display: inline-block;
|
||||||
|
width: 10px;
|
||||||
|
height: 10px;
|
||||||
|
border-radius: 50%;
|
||||||
|
margin-right: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-active {
|
||||||
|
background-color: var(--success-color);
|
||||||
|
box-shadow: 0 0 10px var(--success-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-inactive {
|
||||||
|
background-color: var(--danger-color);
|
||||||
|
box-shadow: 0 0 10px var(--danger-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
.loading {
|
||||||
|
opacity: 0.7;
|
||||||
|
pointer-events: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.footer {
|
||||||
|
background-color: var(--dark-color);
|
||||||
|
color: white;
|
||||||
|
padding: 2rem 1.5rem;
|
||||||
|
margin-top: 3rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
@media (max-width: 768px) {
|
||||||
|
.metric-value {
|
||||||
|
font-size: 2rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.card {
|
||||||
|
margin-bottom: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chart-container {
|
||||||
|
padding: 1rem;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
.control-panel {
|
||||||
|
background: white;
|
||||||
|
border-radius: 12px;
|
||||||
|
padding: 1.5rem;
|
||||||
|
margin-bottom: 1.5rem;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
{% block extra_css %}{% endblock %}
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<nav class="navbar is-primary" role="navigation" aria-label="main navigation">
|
||||||
|
<div class="navbar-brand">
|
||||||
|
<a class="navbar-item" href="{{ url_for('index') }}">
|
||||||
|
<span class="icon is-large">
|
||||||
|
<i class="fas fa-chart-line fa-2x"></i>
|
||||||
|
</span>
|
||||||
|
<span class="title is-4 ml-2">AI模型训练监控看板</span>
|
||||||
|
</a>
|
||||||
|
|
||||||
|
<a role="button" class="navbar-burger" aria-label="menu" aria-expanded="false" data-target="navbarMenu">
|
||||||
|
<span aria-hidden="true"></span>
|
||||||
|
<span aria-hidden="true"></span>
|
||||||
|
<span aria-hidden="true"></span>
|
||||||
|
</a>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div id="navbarMenu" class="navbar-menu">
|
||||||
|
<div class="navbar-end">
|
||||||
|
<div class="navbar-item">
|
||||||
|
<div class="buttons">
|
||||||
|
<button class="button is-light">
|
||||||
|
<span class="icon">
|
||||||
|
<i class="fas fa-sync-alt"></i>
|
||||||
|
</span>
|
||||||
|
<span>刷新数据</span>
|
||||||
|
</button>
|
||||||
|
<a href="{{ url_for('api_status') }}" class="button is-info" target="_blank">
|
||||||
|
<span class="icon">
|
||||||
|
<i class="fas fa-code"></i>
|
||||||
|
</span>
|
||||||
|
<span>API接口</span>
|
||||||
|
</a>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</nav>
|
||||||
|
|
||||||
|
<section class="section">
|
||||||
|
<div class="container">
|
||||||
|
{% block content %}{% endblock %}
|
||||||
|
</div>
|
||||||
|
</section>
|
||||||
|
|
||||||
|
<footer class="footer">
|
||||||
|
<div class="content has-text-centered">
|
||||||
|
<p>
|
||||||
|
<strong>AI模型训练监控系统</strong> - 基于JSON旁路记录法的移动端友好监控方案
|
||||||
|
</p>
|
||||||
|
<p class="mt-2">
|
||||||
|
最后更新: <span id="lastUpdateTime">--:--:--</span> |
|
||||||
|
数据源: <span id="dataSource">{{ data_source if data_source else '未设置' }}</span> |
|
||||||
|
刷新间隔: <span id="refreshInterval">5</span>秒
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</footer>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
document.addEventListener('DOMContentLoaded', function() {
|
||||||
|
const navbarBurger = document.querySelector('.navbar-burger');
|
||||||
|
const navbarMenu = document.querySelector('#navbarMenu');
|
||||||
|
|
||||||
|
if (navbarBurger) {
|
||||||
|
navbarBurger.addEventListener('click', () => {
|
||||||
|
navbarBurger.classList.toggle('is-active');
|
||||||
|
navbarMenu.classList.toggle('is-active');
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
const refreshBtn = document.getElementById('refreshBtn');
|
||||||
|
if (refreshBtn) {
|
||||||
|
refreshBtn.addEventListener('click', () => {
|
||||||
|
refreshBtn.classList.add('is-loading');
|
||||||
|
setTimeout(() => {
|
||||||
|
refreshBtn.classList.remove('is-loading');
|
||||||
|
}, 1000);
|
||||||
|
if (typeof window.refreshData === 'function') {
|
||||||
|
window.refreshData();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function updateTime() {
|
||||||
|
const now = new Date();
|
||||||
|
const timeStr = now.toLocaleTimeString('zh-CN');
|
||||||
|
document.getElementById('lastUpdateTime').textContent = timeStr;
|
||||||
|
}
|
||||||
|
|
||||||
|
updateTime();
|
||||||
|
setInterval(updateTime, 1000);
|
||||||
|
});
|
||||||
|
</script>
|
||||||
|
|
||||||
|
{% block extra_js %}{% endblock %}
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
|
@ -0,0 +1,721 @@
|
||||||
|
{% extends "base.html" %}
|
||||||
|
|
||||||
|
{% block content %}
|
||||||
|
<div class="control-panel">
|
||||||
|
<div class="columns is-mobile is-vcentered">
|
||||||
|
<div class="column">
|
||||||
|
<div class="field">
|
||||||
|
<label class="label">数据源配置</label>
|
||||||
|
<div class="control">
|
||||||
|
<div class="select is-fullwidth">
|
||||||
|
<select id="dataSourceSelect">
|
||||||
|
<option value="local" {% if data_source_type == 'local' %}selected{% endif %}>本地文件</option>
|
||||||
|
<option value="remote" {% if data_source_type == 'remote' %}selected{% endif %}>远程URL</option>
|
||||||
|
</select>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="column">
|
||||||
|
<div class="field" id="localFileField" {% if data_source_type == 'remote' %}style="display: none;"{% endif %}>
|
||||||
|
<label class="label">本地文件路径</label>
|
||||||
|
<div class="control">
|
||||||
|
<input class="input" type="text" id="localFilePath" value="{{ data_source if data_source_type == 'local' else './output/training_status.json' }}" placeholder="./output/training_status.json">
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="field" id="remoteUrlField" {% if data_source_type != 'remote' %}style="display: none;"{% endif %}>
|
||||||
|
<label class="label">远程URL地址</label>
|
||||||
|
<div class="control">
|
||||||
|
<input class="input" type="text" id="remoteUrl" value="{{ data_source if data_source_type == 'remote' else '' }}" placeholder="http://服务器IP:端口/training_status.json">
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="column">
|
||||||
|
<div class="field">
|
||||||
|
<label class="label">刷新间隔(秒)</label>
|
||||||
|
<div class="control">
|
||||||
|
<div class="select is-fullwidth">
|
||||||
|
<select id="refreshIntervalSelect">
|
||||||
|
<option value="1" {% if refresh_interval == 1 %}selected{% endif %}>1秒</option>
|
||||||
|
<option value="2" {% if refresh_interval == 2 %}selected{% endif %}>2秒</option>
|
||||||
|
<option value="5" {% if refresh_interval == 5 %}selected{% endif %}>5秒</option>
|
||||||
|
<option value="10" {% if refresh_interval == 10 %}selected{% endif %}>10秒</option>
|
||||||
|
<option value="30" {% if refresh_interval == 30 %}selected{% endif %}>30秒</option>
|
||||||
|
</select>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="column is-narrow">
|
||||||
|
<div class="field">
|
||||||
|
<label class="label"> </label>
|
||||||
|
<div class="control">
|
||||||
|
<button id="applyConfigBtn" class="button is-primary is-fullwidth">
|
||||||
|
<span class="icon">
|
||||||
|
<i class="fas fa-check"></i>
|
||||||
|
</span>
|
||||||
|
<span>应用配置</span>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="columns is-multiline">
|
||||||
|
<div class="column is-12">
|
||||||
|
<div class="card metric-card">
|
||||||
|
<div class="card-content">
|
||||||
|
<div class="columns is-mobile is-multiline">
|
||||||
|
<div class="column is-6-mobile is-3-tablet">
|
||||||
|
<div class="has-text-centered">
|
||||||
|
<div class="metric-value" id="currentStep">0</div>
|
||||||
|
<div class="metric-label">当前步数</div>
|
||||||
|
<div class="metric-delta" id="stepDelta"></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="column is-6-mobile is-3-tablet">
|
||||||
|
<div class="has-text-centered">
|
||||||
|
<div class="metric-value" id="currentEpoch">0</div>
|
||||||
|
<div class="metric-label">当前轮次</div>
|
||||||
|
<div class="metric-delta" id="epochDelta"></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="column is-6-mobile is-3-tablet">
|
||||||
|
<div class="has-text-centered">
|
||||||
|
<div class="metric-value" id="trainLoss">0.0000</div>
|
||||||
|
<div class="metric-label">训练损失</div>
|
||||||
|
<div class="metric-delta" id="trainLossDelta"></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="column is-6-mobile is-3-tablet">
|
||||||
|
<div class="has-text-centered">
|
||||||
|
<div class="metric-value" id="trainAccuracy">0.0000</div>
|
||||||
|
<div class="metric-label">训练准确率</div>
|
||||||
|
<div class="metric-delta" id="trainAccuracyDelta"></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="column is-6-mobile is-3-tablet">
|
||||||
|
<div class="has-text-centered">
|
||||||
|
<div class="metric-value" id="evalLoss">0.0000</div>
|
||||||
|
<div class="metric-label">评估损失</div>
|
||||||
|
<div class="metric-delta" id="evalLossDelta"></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="column is-6-mobile is-3-tablet">
|
||||||
|
<div class="has-text-centered">
|
||||||
|
<div class="metric-value" id="evalAccuracy">0.0000</div>
|
||||||
|
<div class="metric-label">评估准确率</div>
|
||||||
|
<div class="metric-delta" id="evalAccuracyDelta"></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="column is-6-mobile is-3-tablet">
|
||||||
|
<div class="has-text-centered">
|
||||||
|
<div class="metric-value" id="learningRate">0.00e+0</div>
|
||||||
|
<div class="metric-label">学习率</div>
|
||||||
|
<div class="metric-delta" id="learningRateDelta"></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="column is-6-mobile is-3-tablet">
|
||||||
|
<div class="has-text-centered">
|
||||||
|
<div class="metric-value" id="dataPoints">0</div>
|
||||||
|
<div class="metric-label">数据点数</div>
|
||||||
|
<div class="metric-delta"></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="column is-12">
|
||||||
|
<div class="columns is-multiline">
|
||||||
|
<div class="column is-12-tablet is-6-desktop">
|
||||||
|
<div class="chart-container">
|
||||||
|
<h3 class="title is-5 mb-4">损失曲线</h3>
|
||||||
|
<div id="lossChart"></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="column is-12-tablet is-6-desktop">
|
||||||
|
<div class="chart-container">
|
||||||
|
<h3 class="title is-5 mb-4">准确率曲线</h3>
|
||||||
|
<div id="accuracyChart"></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="column is-12">
|
||||||
|
<div class="chart-container">
|
||||||
|
<h3 class="title is-5 mb-4">学习率变化</h3>
|
||||||
|
<div id="learningRateChart"></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="column is-12">
|
||||||
|
<div class="card">
|
||||||
|
<div class="card-header">
|
||||||
|
<h3 class="title is-5">数据详情</h3>
|
||||||
|
</div>
|
||||||
|
<div class="card-content">
|
||||||
|
<div class="table-container">
|
||||||
|
<table class="table is-fullwidth is-striped is-hoverable">
|
||||||
|
<thead>
|
||||||
|
<tr>
|
||||||
|
<th>步数</th>
|
||||||
|
<th>轮次</th>
|
||||||
|
<th>时间</th>
|
||||||
|
<th>训练损失</th>
|
||||||
|
<th>训练准确率</th>
|
||||||
|
<th>评估损失</th>
|
||||||
|
<th>评估准确率</th>
|
||||||
|
<th>学习率</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody id="dataTableBody">
|
||||||
|
<tr>
|
||||||
|
<td colspan="8" class="has-text-centered">加载数据中...</td>
|
||||||
|
</tr>
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
<nav class="pagination is-centered mt-4" role="navigation" aria-label="pagination">
|
||||||
|
<a class="pagination-previous" id="prevPageBtn" disabled>上一页</a>
|
||||||
|
<a class="pagination-next" id="nextPageBtn" disabled>下一页</a>
|
||||||
|
<ul class="pagination-list">
|
||||||
|
<li><span class="pagination-ellipsis">…</span></li>
|
||||||
|
<li><span id="pageInfo" class="pagination-link is-current">第1页</span></li>
|
||||||
|
<li><span class="pagination-ellipsis">…</span></li>
|
||||||
|
</ul>
|
||||||
|
<div class="field has-addons is-pulled-right">
|
||||||
|
<div class="control">
|
||||||
|
<input class="input" type="number" id="pageSizeInput" min="5" max="50" step="5" value="10" style="width: 80px;">
|
||||||
|
</div>
|
||||||
|
<div class="control">
|
||||||
|
<a class="button is-static">条/页</a>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</nav>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{% endblock %}
|
||||||
|
|
||||||
|
{% block extra_js %}
|
||||||
|
<script>
|
||||||
|
let currentData = [];
|
||||||
|
let refreshTimer = null;
|
||||||
|
let currentRefreshInterval = {{ refresh_interval }};
|
||||||
|
let currentDataSourceType = '{{ data_source_type }}';
|
||||||
|
let currentDataSource = '{{ data_source }}';
|
||||||
|
let lastDataHash = '';
|
||||||
|
let currentPage = 1;
|
||||||
|
let pageSize = 10;
|
||||||
|
let totalPages = 1;
|
||||||
|
let prevPageBtn = null;
|
||||||
|
let nextPageBtn = null;
|
||||||
|
let pageInfo = null;
|
||||||
|
|
||||||
|
document.addEventListener('DOMContentLoaded', function() {
|
||||||
|
const dataSourceSelect = document.getElementById('dataSourceSelect');
|
||||||
|
const localFileField = document.getElementById('localFileField');
|
||||||
|
const remoteUrlField = document.getElementById('remoteUrlField');
|
||||||
|
const refreshIntervalSelect = document.getElementById('refreshIntervalSelect');
|
||||||
|
const applyConfigBtn = document.getElementById('applyConfigBtn');
|
||||||
|
prevPageBtn = document.getElementById('prevPageBtn');
|
||||||
|
nextPageBtn = document.getElementById('nextPageBtn');
|
||||||
|
const pageSizeInput = document.getElementById('pageSizeInput');
|
||||||
|
pageInfo = document.getElementById('pageInfo');
|
||||||
|
|
||||||
|
dataSourceSelect.addEventListener('change', function() {
|
||||||
|
if (this.value === 'local') {
|
||||||
|
localFileField.style.display = 'block';
|
||||||
|
remoteUrlField.style.display = 'none';
|
||||||
|
} else {
|
||||||
|
localFileField.style.display = 'none';
|
||||||
|
remoteUrlField.style.display = 'block';
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
applyConfigBtn.addEventListener('click', function() {
|
||||||
|
const dataSourceType = dataSourceSelect.value;
|
||||||
|
let dataSource = '';
|
||||||
|
|
||||||
|
if (dataSourceType === 'local') {
|
||||||
|
dataSource = document.getElementById('localFilePath').value.trim();
|
||||||
|
if (!dataSource) {
|
||||||
|
showNotification('请输入本地文件路径', 'warning');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
dataSource = document.getElementById('remoteUrl').value.trim();
|
||||||
|
if (!dataSource) {
|
||||||
|
showNotification('请输入远程URL地址', 'warning');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!dataSource.startsWith('http://') && !dataSource.startsWith('https://')) {
|
||||||
|
showNotification('URL必须以http://或https://开头', 'warning');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const refreshInterval = parseInt(refreshIntervalSelect.value);
|
||||||
|
|
||||||
|
currentDataSourceType = dataSourceType;
|
||||||
|
currentDataSource = dataSource;
|
||||||
|
currentRefreshInterval = refreshInterval;
|
||||||
|
|
||||||
|
document.getElementById('refreshInterval').textContent = refreshInterval;
|
||||||
|
document.getElementById('dataSource').textContent = dataSource;
|
||||||
|
|
||||||
|
localStorage.setItem('monitor_config', JSON.stringify({
|
||||||
|
dataSourceType: dataSourceType,
|
||||||
|
dataSource: dataSource,
|
||||||
|
refreshInterval: refreshInterval
|
||||||
|
}));
|
||||||
|
|
||||||
|
showNotification('配置已应用,正在重新加载数据...', 'success');
|
||||||
|
|
||||||
|
// 重置数据状态,因为数据源已改变
|
||||||
|
lastDataHash = '';
|
||||||
|
currentPage = 1;
|
||||||
|
|
||||||
|
clearInterval(refreshTimer);
|
||||||
|
loadData();
|
||||||
|
startRefreshTimer();
|
||||||
|
});
|
||||||
|
|
||||||
|
const savedConfig = localStorage.getItem('monitor_config');
|
||||||
|
if (savedConfig) {
|
||||||
|
try {
|
||||||
|
const config = JSON.parse(savedConfig);
|
||||||
|
if (config.dataSourceType && config.dataSource) {
|
||||||
|
dataSourceSelect.value = config.dataSourceType;
|
||||||
|
if (config.dataSourceType === 'local') {
|
||||||
|
document.getElementById('localFilePath').value = config.dataSource;
|
||||||
|
localFileField.style.display = 'block';
|
||||||
|
remoteUrlField.style.display = 'none';
|
||||||
|
} else {
|
||||||
|
document.getElementById('remoteUrl').value = config.dataSource;
|
||||||
|
localFileField.style.display = 'none';
|
||||||
|
remoteUrlField.style.display = 'block';
|
||||||
|
}
|
||||||
|
refreshIntervalSelect.value = config.refreshInterval || 5;
|
||||||
|
|
||||||
|
currentDataSourceType = config.dataSourceType;
|
||||||
|
currentDataSource = config.dataSource;
|
||||||
|
currentRefreshInterval = config.refreshInterval || 5;
|
||||||
|
|
||||||
|
document.getElementById('refreshInterval').textContent = config.refreshInterval || 5;
|
||||||
|
document.getElementById('dataSource').textContent = config.dataSource;
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
console.error('加载配置失败:', e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 分页事件监听器
|
||||||
|
if (prevPageBtn) {
|
||||||
|
prevPageBtn.addEventListener('click', () => {
|
||||||
|
if (currentPage > 1) {
|
||||||
|
currentPage--;
|
||||||
|
updateDataTable(currentData);
|
||||||
|
updatePagination();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (nextPageBtn) {
|
||||||
|
nextPageBtn.addEventListener('click', () => {
|
||||||
|
if (currentPage < totalPages) {
|
||||||
|
currentPage++;
|
||||||
|
updateDataTable(currentData);
|
||||||
|
updatePagination();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (pageSizeInput) {
|
||||||
|
pageSizeInput.addEventListener('change', () => {
|
||||||
|
const newPageSize = parseInt(pageSizeInput.value);
|
||||||
|
if (newPageSize >= 5 && newPageSize <= 50) {
|
||||||
|
pageSize = newPageSize;
|
||||||
|
currentPage = 1;
|
||||||
|
updateDataTable(currentData);
|
||||||
|
updatePagination();
|
||||||
|
} else {
|
||||||
|
pageSizeInput.value = pageSize;
|
||||||
|
showNotification('每页显示数量应在5到50之间', 'warning');
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
function startRefreshTimer() {
|
||||||
|
clearInterval(refreshTimer);
|
||||||
|
refreshTimer = setInterval(loadData, currentRefreshInterval * 1000);
|
||||||
|
}
|
||||||
|
|
||||||
|
window.refreshData = loadData;
|
||||||
|
|
||||||
|
loadData();
|
||||||
|
startRefreshTimer();
|
||||||
|
});
|
||||||
|
|
||||||
|
function computeDataHash(data) {
|
||||||
|
if (!data || data.length === 0) {
|
||||||
|
return 'empty';
|
||||||
|
}
|
||||||
|
const latest = data[data.length - 1];
|
||||||
|
const latestStep = latest.step || 0;
|
||||||
|
const latestTimestamp = latest.timestamp || '';
|
||||||
|
return `${data.length}:${latestStep}:${latestTimestamp}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
function loadData() {
|
||||||
|
const loadingElement = document.getElementById('dataTableBody');
|
||||||
|
if (loadingElement) {
|
||||||
|
loadingElement.innerHTML = '<tr><td colspan="8" class="has-text-centered">加载数据中...</td></tr>';
|
||||||
|
}
|
||||||
|
|
||||||
|
const params = new URLSearchParams();
|
||||||
|
params.append('data_source_type', currentDataSourceType);
|
||||||
|
params.append('data_source', currentDataSource);
|
||||||
|
|
||||||
|
fetch(`/api/status?${params.toString()}`)
|
||||||
|
.then(response => {
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`HTTP ${response.status}: ${response.statusText}`);
|
||||||
|
}
|
||||||
|
return response.json();
|
||||||
|
})
|
||||||
|
.then(data => {
|
||||||
|
const newHash = computeDataHash(data);
|
||||||
|
if (newHash === lastDataHash) {
|
||||||
|
// 数据无变化,不刷新界面
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
lastDataHash = newHash;
|
||||||
|
currentPage = 1; // 新数据到来,重置到第1页
|
||||||
|
currentData = data;
|
||||||
|
updateMetrics(data);
|
||||||
|
updateCharts(data);
|
||||||
|
updateDataTable(data);
|
||||||
|
showNotification('数据更新成功', 'success');
|
||||||
|
})
|
||||||
|
.catch(error => {
|
||||||
|
console.error('加载数据失败:', error);
|
||||||
|
showNotification(`加载数据失败: ${error.message}`, 'danger');
|
||||||
|
if (loadingElement) {
|
||||||
|
loadingElement.innerHTML = `<tr><td colspan="8" class="has-text-centered has-text-danger">加载失败: ${error.message}</td></tr>`;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function updateMetrics(data) {
|
||||||
|
if (!data || data.length === 0) {
|
||||||
|
resetMetrics();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const latest = data[data.length - 1];
|
||||||
|
const prev = data.length >= 2 ? data[data.length - 2] : null;
|
||||||
|
|
||||||
|
function formatValue(value, format = 'number') {
|
||||||
|
if (value === undefined || value === null) return 'N/A';
|
||||||
|
|
||||||
|
if (format === 'number') {
|
||||||
|
return Number(value).toLocaleString('zh-CN');
|
||||||
|
} else if (format === 'float4') {
|
||||||
|
return Number(value).toFixed(4);
|
||||||
|
} else if (format === 'scientific') {
|
||||||
|
return Number(value).toExponential(2);
|
||||||
|
}
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
function updateMetric(elementId, value, delta = null, format = 'number') {
|
||||||
|
const element = document.getElementById(elementId);
|
||||||
|
if (element) {
|
||||||
|
element.textContent = formatValue(value, format);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (delta !== null) {
|
||||||
|
const deltaElement = document.getElementById(elementId + 'Delta');
|
||||||
|
if (deltaElement) {
|
||||||
|
const deltaNum = Number(delta);
|
||||||
|
if (!isNaN(deltaNum)) {
|
||||||
|
deltaElement.textContent = (deltaNum >= 0 ? '+' : '') + deltaNum.toFixed(4);
|
||||||
|
deltaElement.className = 'metric-delta ' + (deltaNum >= 0 ? 'positive-delta' : 'negative-delta');
|
||||||
|
} else {
|
||||||
|
deltaElement.textContent = '';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
updateMetric('currentStep', latest.step, prev ? latest.step - prev.step : null);
|
||||||
|
updateMetric('currentEpoch', latest.epoch, prev ? latest.epoch - prev.epoch : null);
|
||||||
|
updateMetric('trainLoss', latest['train/loss'], prev ? latest['train/loss'] - prev['train/loss'] : null, 'float4');
|
||||||
|
updateMetric('trainAccuracy', latest['train/accuracy'], prev ? latest['train/accuracy'] - prev['train/accuracy'] : null, 'float4');
|
||||||
|
updateMetric('evalLoss', latest['eval/loss'], prev ? latest['eval/loss'] - prev['eval/loss'] : null, 'float4');
|
||||||
|
updateMetric('evalAccuracy', latest['eval/accuracy'], prev ? latest['eval/accuracy'] - prev['eval/accuracy'] : null, 'float4');
|
||||||
|
updateMetric('learningRate', latest['train/learning_rate'], prev ? latest['train/learning_rate'] - prev['train/learning_rate'] : null, 'scientific');
|
||||||
|
updateMetric('dataPoints', data.length, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
function resetMetrics() {
|
||||||
|
document.getElementById('currentStep').textContent = '0';
|
||||||
|
document.getElementById('currentEpoch').textContent = '0';
|
||||||
|
document.getElementById('trainLoss').textContent = '0.0000';
|
||||||
|
document.getElementById('trainAccuracy').textContent = '0.0000';
|
||||||
|
document.getElementById('evalLoss').textContent = '0.0000';
|
||||||
|
document.getElementById('evalAccuracy').textContent = '0.0000';
|
||||||
|
document.getElementById('learningRate').textContent = '0.00e+0';
|
||||||
|
document.getElementById('dataPoints').textContent = '0';
|
||||||
|
|
||||||
|
const deltaElements = document.querySelectorAll('.metric-delta');
|
||||||
|
deltaElements.forEach(el => el.textContent = '');
|
||||||
|
}
|
||||||
|
|
||||||
|
function updateCharts(data) {
|
||||||
|
if (!data || data.length === 0) {
|
||||||
|
createEmptyChart('lossChart', '损失曲线', '暂无数据');
|
||||||
|
createEmptyChart('accuracyChart', '准确率曲线', '暂无数据');
|
||||||
|
createEmptyChart('learningRateChart', '学习率变化', '暂无数据');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const steps = data.map(d => d.step);
|
||||||
|
|
||||||
|
const lossTrace1 = {
|
||||||
|
x: steps,
|
||||||
|
y: data.map(d => d['train/loss']),
|
||||||
|
mode: 'lines+markers',
|
||||||
|
name: '训练损失',
|
||||||
|
line: { color: '#1f77b4', width: 2 },
|
||||||
|
marker: { size: 4 }
|
||||||
|
};
|
||||||
|
|
||||||
|
const lossTrace2 = {
|
||||||
|
x: steps,
|
||||||
|
y: data.map(d => d['eval/loss']),
|
||||||
|
mode: 'lines+markers',
|
||||||
|
name: '评估损失',
|
||||||
|
line: { color: '#ff7f0e', width: 2, dash: 'dash' },
|
||||||
|
marker: { size: 4 }
|
||||||
|
};
|
||||||
|
|
||||||
|
Plotly.newPlot('lossChart', [lossTrace1, lossTrace2], {
|
||||||
|
title: '损失曲线',
|
||||||
|
xaxis: { title: '训练步数' },
|
||||||
|
yaxis: { title: '损失值' },
|
||||||
|
hovermode: 'x unified',
|
||||||
|
template: 'plotly_white',
|
||||||
|
margin: { l: 50, r: 30, t: 50, b: 50 },
|
||||||
|
legend: { orientation: 'h', y: 1.1 }
|
||||||
|
});
|
||||||
|
|
||||||
|
const accuracyTrace1 = {
|
||||||
|
x: steps,
|
||||||
|
y: data.map(d => d['train/accuracy']),
|
||||||
|
mode: 'lines+markers',
|
||||||
|
name: '训练准确率',
|
||||||
|
line: { color: '#2ca02c', width: 2 },
|
||||||
|
marker: { size: 4 }
|
||||||
|
};
|
||||||
|
|
||||||
|
const accuracyTrace2 = {
|
||||||
|
x: steps,
|
||||||
|
y: data.map(d => d['eval/accuracy']),
|
||||||
|
mode: 'lines+markers',
|
||||||
|
name: '评估准确率',
|
||||||
|
line: { color: '#d62728', width: 2, dash: 'dash' },
|
||||||
|
marker: { size: 4 }
|
||||||
|
};
|
||||||
|
|
||||||
|
Plotly.newPlot('accuracyChart', [accuracyTrace1, accuracyTrace2], {
|
||||||
|
title: '准确率曲线',
|
||||||
|
xaxis: { title: '训练步数' },
|
||||||
|
yaxis: { title: '准确率' },
|
||||||
|
hovermode: 'x unified',
|
||||||
|
template: 'plotly_white',
|
||||||
|
margin: { l: 50, r: 30, t: 50, b: 50 },
|
||||||
|
legend: { orientation: 'h', y: 1.1 }
|
||||||
|
});
|
||||||
|
|
||||||
|
if (data[0]['train/learning_rate'] !== undefined) {
|
||||||
|
const lrTrace = {
|
||||||
|
x: steps,
|
||||||
|
y: data.map(d => d['train/learning_rate']),
|
||||||
|
mode: 'lines+markers',
|
||||||
|
name: '学习率',
|
||||||
|
line: { color: '#9467bd', width: 2 },
|
||||||
|
marker: { size: 4 }
|
||||||
|
};
|
||||||
|
|
||||||
|
Plotly.newPlot('learningRateChart', [lrTrace], {
|
||||||
|
title: '学习率变化',
|
||||||
|
xaxis: { title: '训练步数' },
|
||||||
|
yaxis: { title: '学习率', type: 'log' },
|
||||||
|
hovermode: 'x unified',
|
||||||
|
template: 'plotly_white',
|
||||||
|
margin: { l: 50, r: 30, t: 50, b: 50 }
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
createEmptyChart('learningRateChart', '学习率变化', '暂无学习率数据');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function createEmptyChart(containerId, title, message) {
|
||||||
|
const trace = {
|
||||||
|
x: [0],
|
||||||
|
y: [0],
|
||||||
|
mode: 'markers',
|
||||||
|
marker: { size: 0 },
|
||||||
|
showlegend: false
|
||||||
|
};
|
||||||
|
|
||||||
|
Plotly.newPlot(containerId, [trace], {
|
||||||
|
title: title,
|
||||||
|
xaxis: { showgrid: false, zeroline: false, showticklabels: false },
|
||||||
|
yaxis: { showgrid: false, zeroline: false, showticklabels: false },
|
||||||
|
annotations: [{
|
||||||
|
text: message,
|
||||||
|
xref: 'paper',
|
||||||
|
yref: 'paper',
|
||||||
|
x: 0.5,
|
||||||
|
y: 0.5,
|
||||||
|
showarrow: false,
|
||||||
|
font: { size: 16, color: '#999' }
|
||||||
|
}],
|
||||||
|
margin: { l: 50, r: 30, t: 50, b: 50 }
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function updateDataTable(data) {
|
||||||
|
const tbody = document.getElementById('dataTableBody');
|
||||||
|
if (!tbody) return;
|
||||||
|
|
||||||
|
if (!data || data.length === 0) {
|
||||||
|
tbody.innerHTML = '<tr><td colspan="8" class="has-text-centered">暂无数据</td></tr>';
|
||||||
|
updatePagination();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算分页数据
|
||||||
|
const totalItems = data.length;
|
||||||
|
totalPages = Math.max(1, Math.ceil(totalItems / pageSize));
|
||||||
|
|
||||||
|
// 确保当前页在有效范围内
|
||||||
|
if (currentPage > totalPages) currentPage = totalPages;
|
||||||
|
if (currentPage < 1) currentPage = 1;
|
||||||
|
|
||||||
|
// 计算要显示的数据范围(从最新数据开始分页)
|
||||||
|
const startIndex = Math.max(0, totalItems - currentPage * pageSize);
|
||||||
|
const endIndex = Math.max(0, totalItems - (currentPage - 1) * pageSize);
|
||||||
|
|
||||||
|
// 获取数据并反转顺序(最新的数据在表格顶部)
|
||||||
|
const pageData = data.slice(startIndex, endIndex).reverse();
|
||||||
|
|
||||||
|
let html = '';
|
||||||
|
pageData.forEach(record => {
|
||||||
|
const timestamp = record.timestamp ?
|
||||||
|
new Date(record.timestamp).toLocaleString('zh-CN') :
|
||||||
|
'N/A';
|
||||||
|
|
||||||
|
html += `
|
||||||
|
<tr>
|
||||||
|
<td>${record.step || 'N/A'}</td>
|
||||||
|
<td>${record.epoch || 'N/A'}</td>
|
||||||
|
<td>${timestamp}</td>
|
||||||
|
<td>${record['train/loss'] !== undefined ? Number(record['train/loss']).toFixed(4) : 'N/A'}</td>
|
||||||
|
<td>${record['train/accuracy'] !== undefined ? Number(record['train/accuracy']).toFixed(4) : 'N/A'}</td>
|
||||||
|
<td>${record['eval/loss'] !== undefined ? Number(record['eval/loss']).toFixed(4) : 'N/A'}</td>
|
||||||
|
<td>${record['eval/accuracy'] !== undefined ? Number(record['eval/accuracy']).toFixed(4) : 'N/A'}</td>
|
||||||
|
<td>${record['train/learning_rate'] !== undefined ? Number(record['train/learning_rate']).toExponential(2) : 'N/A'}</td>
|
||||||
|
</tr>
|
||||||
|
`;
|
||||||
|
});
|
||||||
|
|
||||||
|
tbody.innerHTML = html;
|
||||||
|
updatePagination();
|
||||||
|
}
|
||||||
|
|
||||||
|
function updatePagination() {
|
||||||
|
if (!prevPageBtn || !nextPageBtn || !pageInfo) return;
|
||||||
|
|
||||||
|
// 更新按钮状态
|
||||||
|
prevPageBtn.disabled = currentPage <= 1;
|
||||||
|
nextPageBtn.disabled = currentPage >= totalPages;
|
||||||
|
|
||||||
|
// 更新页数信息
|
||||||
|
pageInfo.textContent = `第${currentPage}页 / 共${totalPages}页`;
|
||||||
|
|
||||||
|
// 更新按钮样式
|
||||||
|
if (prevPageBtn.disabled) {
|
||||||
|
prevPageBtn.classList.add('is-disabled');
|
||||||
|
} else {
|
||||||
|
prevPageBtn.classList.remove('is-disabled');
|
||||||
|
}
|
||||||
|
|
||||||
|
if (nextPageBtn.disabled) {
|
||||||
|
nextPageBtn.classList.add('is-disabled');
|
||||||
|
} else {
|
||||||
|
nextPageBtn.classList.remove('is-disabled');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function showNotification(message, type = 'info') {
|
||||||
|
const notification = document.createElement('div');
|
||||||
|
notification.className = `notification is-${type}`;
|
||||||
|
notification.style.cssText = 'position: fixed; bottom: 20px; right: 20px; z-index: 1000; max-width: 350px; transform: translateX(120%); opacity: 0; transition: transform 0.4s cubic-bezier(0.68, -0.55, 0.27, 1.55), opacity 0.4s ease;';
|
||||||
|
|
||||||
|
notification.innerHTML = `<button class="delete"></button>${message}`;
|
||||||
|
notification.querySelector('.delete').addEventListener('click', () => {
|
||||||
|
notification.style.transform = 'translateX(120%)';
|
||||||
|
notification.style.opacity = '0';
|
||||||
|
setTimeout(() => {
|
||||||
|
if (notification.parentNode) {
|
||||||
|
notification.remove();
|
||||||
|
}
|
||||||
|
}, 400);
|
||||||
|
});
|
||||||
|
|
||||||
|
document.body.appendChild(notification);
|
||||||
|
|
||||||
|
// 触发滑动动画
|
||||||
|
setTimeout(() => {
|
||||||
|
notification.style.transform = 'translateX(0)';
|
||||||
|
notification.style.opacity = '1';
|
||||||
|
}, 10);
|
||||||
|
|
||||||
|
setTimeout(() => {
|
||||||
|
if (notification.parentNode) {
|
||||||
|
notification.style.transform = 'translateX(120%)';
|
||||||
|
notification.style.opacity = '0';
|
||||||
|
setTimeout(() => {
|
||||||
|
if (notification.parentNode) {
|
||||||
|
notification.remove();
|
||||||
|
}
|
||||||
|
}, 400);
|
||||||
|
}
|
||||||
|
}, 3000);
|
||||||
|
}
|
||||||
|
</script>
|
||||||
|
{% endblock %}
|
||||||
|
|
@ -1,700 +0,0 @@
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
# import plotly.express as px # 暂时未使用
|
|
||||||
import plotly.graph_objects as go
|
|
||||||
import streamlit as st
|
|
||||||
|
|
||||||
# from plotly.subplots import make_subplots # 暂时未使用
|
|
||||||
|
|
||||||
# 用于HTTP URL支持
|
|
||||||
# requests模块将在load_from_url函数中动态导入
|
|
||||||
|
|
||||||
# 添加项目路径到系统路径,以便导入模型
|
|
||||||
sys.path.append(str(Path(__file__).parent.parent))
|
|
||||||
|
|
||||||
# 设置页面配置 - 移动端友好
|
|
||||||
st.set_page_config(
|
|
||||||
page_title="AI模型训练监控看板",
|
|
||||||
page_icon="📈",
|
|
||||||
layout="wide",
|
|
||||||
initial_sidebar_state="collapsed", # 移动端默认收起侧边栏
|
|
||||||
)
|
|
||||||
|
|
||||||
# 自定义CSS样式,优化移动端体验
|
|
||||||
st.markdown(
|
|
||||||
"""
|
|
||||||
<style>
|
|
||||||
/* 基础响应式设置 */
|
|
||||||
.main > div {
|
|
||||||
padding-left: 1rem;
|
|
||||||
padding-right: 1rem;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* 卡片样式 */
|
|
||||||
.metric-card {
|
|
||||||
background-color: #f0f2f6;
|
|
||||||
border-radius: 10px;
|
|
||||||
padding: 15px;
|
|
||||||
margin-bottom: 15px;
|
|
||||||
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* 指标值大字体 */
|
|
||||||
.metric-value {
|
|
||||||
font-size: 2.5rem !important;
|
|
||||||
font-weight: bold !important;
|
|
||||||
color: #1f77b4 !important;
|
|
||||||
margin-bottom: 5px !important;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* 指标标签 */
|
|
||||||
.metric-label {
|
|
||||||
font-size: 1rem !important;
|
|
||||||
color: #666 !important;
|
|
||||||
margin-top: 0 !important;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* 图表容器 */
|
|
||||||
.plot-container {
|
|
||||||
margin-top: 20px;
|
|
||||||
margin-bottom: 20px;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* 响应式调整 */
|
|
||||||
@media (max-width: 768px) {
|
|
||||||
.metric-value {
|
|
||||||
font-size: 2rem !important;
|
|
||||||
}
|
|
||||||
.stButton > button {
|
|
||||||
width: 100% !important;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* 状态指示器 */
|
|
||||||
.status-indicator {
|
|
||||||
display: inline-block;
|
|
||||||
width: 12px;
|
|
||||||
height: 12px;
|
|
||||||
border-radius: 50%;
|
|
||||||
margin-right: 8px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.status-active {
|
|
||||||
background-color: #00cc00;
|
|
||||||
}
|
|
||||||
|
|
||||||
.status-inactive {
|
|
||||||
background-color: #ff6666;
|
|
||||||
}
|
|
||||||
</style>
|
|
||||||
""",
|
|
||||||
unsafe_allow_html=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_training_data(file_path):
|
|
||||||
"""从本地文件或HTTP URL加载训练状态数据"""
|
|
||||||
# 检查是否是HTTP/HTTPS URL
|
|
||||||
if file_path.startswith(("http://", "https://")):
|
|
||||||
return load_from_url(file_path)
|
|
||||||
else:
|
|
||||||
return load_from_local_file(file_path)
|
|
||||||
|
|
||||||
|
|
||||||
def load_from_url(url):
|
|
||||||
"""从HTTP URL加载数据"""
|
|
||||||
# 动态导入requests模块
|
|
||||||
try:
|
|
||||||
import requests
|
|
||||||
except ImportError:
|
|
||||||
st.error(
|
|
||||||
"requests库未安装,无法从HTTP URL加载数据。请安装:pip install requests"
|
|
||||||
)
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 设置合理的超时
|
|
||||||
response = requests.get(url, timeout=10)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
return convert_to_dataframe(data)
|
|
||||||
|
|
||||||
except requests.exceptions.RequestException as e:
|
|
||||||
st.warning(f"从URL加载数据失败: {e}")
|
|
||||||
return pd.DataFrame()
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
st.warning("远程返回的数据不是有效的JSON格式")
|
|
||||||
return pd.DataFrame()
|
|
||||||
except Exception as e:
|
|
||||||
st.warning(f"从URL加载数据时发生错误: {e}")
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
|
|
||||||
def load_from_local_file(file_path):
|
|
||||||
"""从本地文件加载数据"""
|
|
||||||
try:
|
|
||||||
if not os.path.exists(file_path):
|
|
||||||
st.warning(f"文件不存在: {file_path}")
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
with open(file_path, "r", encoding="utf-8") as f:
|
|
||||||
data = json.load(f)
|
|
||||||
|
|
||||||
return convert_to_dataframe(data)
|
|
||||||
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
return pd.DataFrame()
|
|
||||||
except Exception:
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
|
|
||||||
def convert_to_dataframe(data):
|
|
||||||
"""将数据转换为DataFrame,包含数据验证和清理"""
|
|
||||||
# 检测是否是配置文件(检查是否有典型的配置键)
|
|
||||||
config_keys = [
|
|
||||||
"train_data_path",
|
|
||||||
"eval_data_path",
|
|
||||||
"output_dir",
|
|
||||||
"vocab_size",
|
|
||||||
"batch_size",
|
|
||||||
"num_epochs",
|
|
||||||
]
|
|
||||||
|
|
||||||
if isinstance(data, dict):
|
|
||||||
# 检查是否是配置文件
|
|
||||||
if any(key in data for key in config_keys):
|
|
||||||
st.error("❌ 检测到配置文件,请检查文件路径是否正确")
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
# 如果是单个训练状态字典,包装成列表
|
|
||||||
data = [data]
|
|
||||||
elif isinstance(data, list):
|
|
||||||
# 检查列表中的第一个元素是否是配置文件
|
|
||||||
if data and isinstance(data[0], dict):
|
|
||||||
if any(key in data[0] for key in config_keys):
|
|
||||||
st.error("❌ 检测到配置文件,请检查文件路径是否正确")
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
# 已经是列表,检查是否为空
|
|
||||||
if len(data) == 0:
|
|
||||||
return pd.DataFrame()
|
|
||||||
else:
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
# 确保data是列表(经过上面的处理,它应该是列表)
|
|
||||||
if not isinstance(data, list):
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
# 清理数据:确保所有元素都是字典
|
|
||||||
cleaned_data = []
|
|
||||||
for i, item in enumerate(data):
|
|
||||||
if isinstance(item, dict):
|
|
||||||
# 检查是否是训练状态数据(包含训练指标)
|
|
||||||
if "step" in item or "train/loss" in item or "timestamp" in item:
|
|
||||||
cleaned_data.append(item)
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if len(cleaned_data) == 0:
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
# 使用清理后的数据创建DataFrame
|
|
||||||
try:
|
|
||||||
df = pd.DataFrame(cleaned_data, index=range(len(cleaned_data)))
|
|
||||||
except Exception:
|
|
||||||
try:
|
|
||||||
df = pd.DataFrame.from_records(cleaned_data)
|
|
||||||
except Exception:
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
# 确保时间戳为datetime类型
|
|
||||||
if "timestamp" in df.columns:
|
|
||||||
try:
|
|
||||||
df["timestamp"] = pd.to_datetime(df["timestamp"])
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return df
|
|
||||||
|
|
||||||
|
|
||||||
def create_metric_card(label, value, delta=None, help_text=None):
|
|
||||||
"""创建指标卡片"""
|
|
||||||
col1, col2 = st.columns([3, 1])
|
|
||||||
|
|
||||||
with col1:
|
|
||||||
if delta is not None:
|
|
||||||
st.metric(label=label, value=value, delta=f"{delta:+.4f}")
|
|
||||||
else:
|
|
||||||
st.metric(label=label, value=value)
|
|
||||||
|
|
||||||
if help_text:
|
|
||||||
st.caption(help_text)
|
|
||||||
|
|
||||||
return col1, col2
|
|
||||||
|
|
||||||
|
|
||||||
def create_loss_chart(df):
|
|
||||||
"""创建损失图表"""
|
|
||||||
fig = go.Figure()
|
|
||||||
|
|
||||||
if df.empty:
|
|
||||||
return fig
|
|
||||||
|
|
||||||
# 训练损失
|
|
||||||
if "train/loss" in df.columns:
|
|
||||||
fig.add_trace(
|
|
||||||
go.Scatter(
|
|
||||||
x=df["step"],
|
|
||||||
y=df["train/loss"],
|
|
||||||
mode="lines+markers",
|
|
||||||
name="训练损失",
|
|
||||||
line=dict(color="#1f77b4", width=2),
|
|
||||||
marker=dict(size=4),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 评估损失
|
|
||||||
if "eval/loss" in df.columns:
|
|
||||||
fig.add_trace(
|
|
||||||
go.Scatter(
|
|
||||||
x=df["step"],
|
|
||||||
y=df["eval/loss"],
|
|
||||||
mode="lines+markers",
|
|
||||||
name="评估损失",
|
|
||||||
line=dict(color="#ff7f0e", width=2, dash="dash"),
|
|
||||||
marker=dict(size=4),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
fig.update_layout(
|
|
||||||
title="损失曲线",
|
|
||||||
xaxis_title="训练步数",
|
|
||||||
yaxis_title="损失值",
|
|
||||||
hovermode="x unified",
|
|
||||||
template="plotly_white",
|
|
||||||
height=400,
|
|
||||||
margin=dict(l=40, r=40, t=60, b=40),
|
|
||||||
legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
|
|
||||||
)
|
|
||||||
|
|
||||||
# 添加网格
|
|
||||||
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor="LightGray")
|
|
||||||
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor="LightGray")
|
|
||||||
|
|
||||||
return fig
|
|
||||||
|
|
||||||
|
|
||||||
def create_accuracy_chart(df):
|
|
||||||
"""创建准确率图表"""
|
|
||||||
fig = go.Figure()
|
|
||||||
|
|
||||||
if df.empty:
|
|
||||||
return fig
|
|
||||||
|
|
||||||
# 训练准确率
|
|
||||||
if "train/accuracy" in df.columns:
|
|
||||||
fig.add_trace(
|
|
||||||
go.Scatter(
|
|
||||||
x=df["step"],
|
|
||||||
y=df["train/accuracy"],
|
|
||||||
mode="lines+markers",
|
|
||||||
name="训练准确率",
|
|
||||||
line=dict(color="#2ca02c", width=2),
|
|
||||||
marker=dict(size=4),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 评估准确率
|
|
||||||
if "eval/accuracy" in df.columns:
|
|
||||||
fig.add_trace(
|
|
||||||
go.Scatter(
|
|
||||||
x=df["step"],
|
|
||||||
y=df["eval/accuracy"],
|
|
||||||
mode="lines+markers",
|
|
||||||
name="评估准确率",
|
|
||||||
line=dict(color="#d62728", width=2, dash="dash"),
|
|
||||||
marker=dict(size=4),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
fig.update_layout(
|
|
||||||
title="准确率曲线",
|
|
||||||
xaxis_title="训练步数",
|
|
||||||
yaxis_title="准确率",
|
|
||||||
hovermode="x unified",
|
|
||||||
template="plotly_white",
|
|
||||||
height=400,
|
|
||||||
margin=dict(l=40, r=40, t=60, b=40),
|
|
||||||
legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
|
|
||||||
)
|
|
||||||
|
|
||||||
# 添加网格
|
|
||||||
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor="LightGray")
|
|
||||||
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor="LightGray")
|
|
||||||
|
|
||||||
return fig
|
|
||||||
|
|
||||||
|
|
||||||
def create_learning_rate_chart(df):
|
|
||||||
"""创建学习率图表"""
|
|
||||||
if "train/learning_rate" not in df.columns:
|
|
||||||
return None
|
|
||||||
|
|
||||||
fig = go.Figure()
|
|
||||||
|
|
||||||
fig.add_trace(
|
|
||||||
go.Scatter(
|
|
||||||
x=df["step"],
|
|
||||||
y=df["train/learning_rate"],
|
|
||||||
mode="lines+markers",
|
|
||||||
name="学习率",
|
|
||||||
line=dict(color="#9467bd", width=2),
|
|
||||||
marker=dict(size=4),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
fig.update_layout(
|
|
||||||
title="学习率变化",
|
|
||||||
xaxis_title="训练步数",
|
|
||||||
yaxis_title="学习率",
|
|
||||||
hovermode="x unified",
|
|
||||||
template="plotly_white",
|
|
||||||
height=300,
|
|
||||||
margin=dict(l=40, r=40, t=60, b=40),
|
|
||||||
yaxis_type="log", # 对数坐标,适合学习率
|
|
||||||
)
|
|
||||||
|
|
||||||
# 添加网格
|
|
||||||
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor="LightGray")
|
|
||||||
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor="LightGray")
|
|
||||||
|
|
||||||
return fig
|
|
||||||
|
|
||||||
|
|
||||||
def create_training_summary(df):
|
|
||||||
"""创建训练摘要"""
|
|
||||||
if df.empty:
|
|
||||||
return None
|
|
||||||
|
|
||||||
latest = df.iloc[-1]
|
|
||||||
summary = {}
|
|
||||||
|
|
||||||
# 基础信息
|
|
||||||
summary["当前步数"] = int(latest["step"]) if "step" in latest else 0
|
|
||||||
summary["当前轮次"] = int(latest["epoch"]) if "epoch" in latest else 0
|
|
||||||
|
|
||||||
# 训练指标
|
|
||||||
if "train/loss" in latest:
|
|
||||||
summary["训练损失"] = float(latest["train/loss"])
|
|
||||||
if "train/accuracy" in latest:
|
|
||||||
summary["训练准确率"] = float(latest["train/accuracy"])
|
|
||||||
|
|
||||||
# 评估指标
|
|
||||||
if "eval/loss" in latest:
|
|
||||||
summary["评估损失"] = float(latest["eval/loss"])
|
|
||||||
if "eval/accuracy" in latest:
|
|
||||||
summary["评估准确率"] = float(latest["eval/accuracy"])
|
|
||||||
|
|
||||||
# 学习率
|
|
||||||
if "train/learning_rate" in latest:
|
|
||||||
summary["当前学习率"] = float(latest["train/learning_rate"])
|
|
||||||
|
|
||||||
# 时间信息
|
|
||||||
if "timestamp" in latest:
|
|
||||||
summary["最后更新时间"] = pd.to_datetime(latest["timestamp"]).strftime(
|
|
||||||
"%Y-%m-%d %H:%M:%S"
|
|
||||||
)
|
|
||||||
|
|
||||||
return summary
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""主函数"""
|
|
||||||
# 标题
|
|
||||||
st.title("📈 AI模型训练实时监控看板")
|
|
||||||
st.markdown("基于JSON旁路记录法的移动端友好监控方案")
|
|
||||||
|
|
||||||
# 侧边栏配置
|
|
||||||
with st.sidebar:
|
|
||||||
st.header("⚙️ 监控设置")
|
|
||||||
|
|
||||||
# 文件路径选择
|
|
||||||
default_status_file = os.environ.get(
|
|
||||||
"TRAINING_STATUS_FILE", "./output/training_status.json"
|
|
||||||
)
|
|
||||||
status_file = st.text_input(
|
|
||||||
"状态文件路径或URL",
|
|
||||||
value=default_status_file,
|
|
||||||
help="可以是:\n1. 本地文件路径(如 ./output/training_status.json)\n2. HTTP/HTTPS URL(如 http://服务器IP:端口/training_status.json)",
|
|
||||||
)
|
|
||||||
|
|
||||||
# 刷新间隔
|
|
||||||
refresh_interval = st.slider(
|
|
||||||
"自动刷新间隔(秒)",
|
|
||||||
min_value=1,
|
|
||||||
max_value=30,
|
|
||||||
value=5,
|
|
||||||
help="数据自动刷新间隔时间",
|
|
||||||
)
|
|
||||||
|
|
||||||
# 数据限制
|
|
||||||
max_data_points = st.slider(
|
|
||||||
"显示数据点数量",
|
|
||||||
min_value=10,
|
|
||||||
max_value=1000,
|
|
||||||
value=500,
|
|
||||||
help="图表中显示的最大数据点数量",
|
|
||||||
)
|
|
||||||
|
|
||||||
# 手动刷新按钮
|
|
||||||
if st.button("🔄 手动刷新数据"):
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
st.divider()
|
|
||||||
|
|
||||||
# 状态信息
|
|
||||||
st.subheader("📊 系统状态")
|
|
||||||
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
st.text(f"当前时间: {current_time}")
|
|
||||||
|
|
||||||
if os.path.exists(status_file):
|
|
||||||
file_size = os.path.getsize(status_file)
|
|
||||||
file_mtime = datetime.fromtimestamp(os.path.getmtime(status_file))
|
|
||||||
st.text(f"文件大小: {file_size:,} 字节")
|
|
||||||
st.text(f"最后修改: {file_mtime.strftime('%Y-%m-%d %H:%M:%S')}")
|
|
||||||
st.success("✅ 状态文件正常")
|
|
||||||
else:
|
|
||||||
st.warning("⚠️ 状态文件不存在")
|
|
||||||
|
|
||||||
# 主内容区域
|
|
||||||
# 使用列布局适应移动端
|
|
||||||
col1, col2 = st.columns([1, 1])
|
|
||||||
|
|
||||||
# 加载数据
|
|
||||||
df = load_training_data(status_file)
|
|
||||||
|
|
||||||
if df.empty:
|
|
||||||
st.warning("暂无训练数据,请检查状态文件路径是否正确。")
|
|
||||||
st.info("开始训练后,数据将自动显示在这里。")
|
|
||||||
|
|
||||||
# 显示示例数据格式
|
|
||||||
with st.expander("📝 数据格式示例"):
|
|
||||||
st.code(
|
|
||||||
"""[
|
|
||||||
{
|
|
||||||
"step": 100,
|
|
||||||
"epoch": 1,
|
|
||||||
"timestamp": "2024-01-01T12:00:00",
|
|
||||||
"train/loss": 2.345,
|
|
||||||
"train/accuracy": 0.456,
|
|
||||||
"eval/loss": 2.123,
|
|
||||||
"eval/accuracy": 0.512,
|
|
||||||
"train/learning_rate": 0.0001
|
|
||||||
},
|
|
||||||
...
|
|
||||||
]""",
|
|
||||||
language="json",
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# 数据处理
|
|
||||||
df_display = df.tail(max_data_points).copy()
|
|
||||||
|
|
||||||
# 计算变化量(用于指标卡片)
|
|
||||||
if len(df_display) >= 2:
|
|
||||||
prev = df_display.iloc[-2]
|
|
||||||
latest = df_display.iloc[-1]
|
|
||||||
|
|
||||||
loss_delta = None
|
|
||||||
acc_delta = None
|
|
||||||
lr_delta = None
|
|
||||||
|
|
||||||
if "train/loss" in latest and "train/loss" in prev:
|
|
||||||
loss_delta = float(latest["train/loss"]) - float(prev["train/loss"])
|
|
||||||
|
|
||||||
if "train/accuracy" in latest and "train/accuracy" in prev:
|
|
||||||
acc_delta = float(latest["train/accuracy"]) - float(prev["train/accuracy"])
|
|
||||||
|
|
||||||
if "train/learning_rate" in latest and "train/learning_rate" in prev:
|
|
||||||
lr_delta = float(latest["train/learning_rate"]) - float(
|
|
||||||
prev["train/learning_rate"]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
loss_delta = None
|
|
||||||
acc_delta = None
|
|
||||||
lr_delta = None
|
|
||||||
|
|
||||||
latest = df_display.iloc[-1]
|
|
||||||
|
|
||||||
# 关键指标卡片 - 第一行
|
|
||||||
st.subheader("📊 关键指标")
|
|
||||||
|
|
||||||
metric_cols = st.columns(4)
|
|
||||||
with metric_cols[0]:
|
|
||||||
if "step" in latest:
|
|
||||||
st.metric("当前步数", f"{int(latest['step']):,}")
|
|
||||||
|
|
||||||
with metric_cols[1]:
|
|
||||||
if "epoch" in latest:
|
|
||||||
st.metric("当前轮次", f"{int(latest['epoch'])}")
|
|
||||||
|
|
||||||
with metric_cols[2]:
|
|
||||||
if "train/loss" in latest:
|
|
||||||
st.metric(
|
|
||||||
"训练损失",
|
|
||||||
f"{float(latest['train/loss']):.4f}",
|
|
||||||
delta=f"{loss_delta:+.4f}" if loss_delta is not None else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
with metric_cols[3]:
|
|
||||||
if "train/accuracy" in latest:
|
|
||||||
st.metric(
|
|
||||||
"训练准确率",
|
|
||||||
f"{float(latest['train/accuracy']):.4f}",
|
|
||||||
delta=f"{acc_delta:+.4f}" if acc_delta is not None else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 评估指标 - 第二行
|
|
||||||
eval_cols = st.columns(4)
|
|
||||||
with eval_cols[0]:
|
|
||||||
if "eval/loss" in latest:
|
|
||||||
eval_loss = float(latest["eval/loss"])
|
|
||||||
prev_eval_loss = (
|
|
||||||
float(df_display.iloc[-2]["eval/loss"])
|
|
||||||
if len(df_display) >= 2 and "eval/loss" in df_display.iloc[-2]
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
delta = eval_loss - prev_eval_loss if prev_eval_loss is not None else None
|
|
||||||
st.metric(
|
|
||||||
"评估损失",
|
|
||||||
f"{eval_loss:.4f}",
|
|
||||||
delta=f"{delta:+.4f}" if delta is not None else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
with eval_cols[1]:
|
|
||||||
if "eval/accuracy" in latest:
|
|
||||||
eval_acc = float(latest["eval/accuracy"])
|
|
||||||
prev_eval_acc = (
|
|
||||||
float(df_display.iloc[-2]["eval/accuracy"])
|
|
||||||
if len(df_display) >= 2 and "eval/accuracy" in df_display.iloc[-2]
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
delta = eval_acc - prev_eval_acc if prev_eval_acc is not None else None
|
|
||||||
st.metric(
|
|
||||||
"评估准确率",
|
|
||||||
f"{eval_acc:.4f}",
|
|
||||||
delta=f"{delta:+.4f}" if delta is not None else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
with eval_cols[2]:
|
|
||||||
if "train/learning_rate" in latest:
|
|
||||||
st.metric(
|
|
||||||
"学习率",
|
|
||||||
f"{float(latest['train/learning_rate']):.2e}",
|
|
||||||
delta=f"{lr_delta:+.2e}" if lr_delta is not None else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
with eval_cols[3]:
|
|
||||||
if "timestamp" in latest:
|
|
||||||
timestamp = pd.to_datetime(latest["timestamp"])
|
|
||||||
st.metric("最后更新", timestamp.strftime("%H:%M:%S"))
|
|
||||||
|
|
||||||
st.divider()
|
|
||||||
|
|
||||||
# 图表区域
|
|
||||||
st.subheader("📈 训练曲线")
|
|
||||||
|
|
||||||
# 损失图表
|
|
||||||
loss_fig = create_loss_chart(df_display)
|
|
||||||
st.plotly_chart(loss_fig, width="stretch", config={"responsive": True})
|
|
||||||
|
|
||||||
# 准确率图表
|
|
||||||
acc_fig = create_accuracy_chart(df_display)
|
|
||||||
st.plotly_chart(acc_fig, width="stretch", config={"responsive": True})
|
|
||||||
|
|
||||||
# 学习率图表
|
|
||||||
lr_fig = create_learning_rate_chart(df_display)
|
|
||||||
if lr_fig:
|
|
||||||
st.plotly_chart(lr_fig, width="stretch", config={"responsive": True})
|
|
||||||
|
|
||||||
st.divider()
|
|
||||||
|
|
||||||
# 数据详情
|
|
||||||
with st.expander("📋 数据详情", expanded=False):
|
|
||||||
st.dataframe(
|
|
||||||
df_display,
|
|
||||||
width="stretch",
|
|
||||||
hide_index=True,
|
|
||||||
column_config={
|
|
||||||
"step": st.column_config.NumberColumn("步数", format="%d"),
|
|
||||||
"epoch": st.column_config.NumberColumn("轮次", format="%d"),
|
|
||||||
"timestamp": st.column_config.DatetimeColumn("时间"),
|
|
||||||
"train/loss": st.column_config.NumberColumn("训练损失", format="%.4f"),
|
|
||||||
"train/accuracy": st.column_config.NumberColumn(
|
|
||||||
"训练准确率", format="%.4f"
|
|
||||||
),
|
|
||||||
"eval/loss": st.column_config.NumberColumn("评估损失", format="%.4f"),
|
|
||||||
"eval/accuracy": st.column_config.NumberColumn(
|
|
||||||
"评估准确率", format="%.4f"
|
|
||||||
),
|
|
||||||
"train/learning_rate": st.column_config.NumberColumn(
|
|
||||||
"学习率", format="%.2e"
|
|
||||||
),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# 统计数据
|
|
||||||
st.subheader("📊 统计信息")
|
|
||||||
stats_cols = st.columns(3)
|
|
||||||
|
|
||||||
with stats_cols[0]:
|
|
||||||
st.metric("总数据点", f"{len(df):,}")
|
|
||||||
|
|
||||||
with stats_cols[1]:
|
|
||||||
if not df.empty and "timestamp" in df.columns:
|
|
||||||
start_time = df["timestamp"].min()
|
|
||||||
end_time = df["timestamp"].max()
|
|
||||||
duration = (end_time - start_time).total_seconds() / 3600 # 小时
|
|
||||||
st.metric("训练时长", f"{duration:.2f} 小时")
|
|
||||||
|
|
||||||
with stats_cols[2]:
|
|
||||||
if not df.empty and "step" in df.columns:
|
|
||||||
total_steps = df["step"].max() - df["step"].min()
|
|
||||||
st.metric("总步数", f"{total_steps:,}")
|
|
||||||
|
|
||||||
# 训练进度信息
|
|
||||||
if "step" in df.columns and "epoch" in df.columns:
|
|
||||||
current_step = df["step"].max()
|
|
||||||
current_epoch = df["epoch"].max()
|
|
||||||
|
|
||||||
progress_text = f"训练进度: 第 {current_epoch} 轮,第 {current_step:,} 步"
|
|
||||||
st.progress(min(current_step / (current_step + 1000), 1.0), text=progress_text)
|
|
||||||
|
|
||||||
# 底部状态栏
|
|
||||||
st.divider()
|
|
||||||
|
|
||||||
footer_cols = st.columns([3, 1])
|
|
||||||
with footer_cols[0]:
|
|
||||||
st.caption(
|
|
||||||
f"监控服务运行中 | 数据文件: {status_file} | 最后刷新: {datetime.now().strftime('%H:%M:%S')}"
|
|
||||||
)
|
|
||||||
|
|
||||||
with footer_cols[1]:
|
|
||||||
if st.button("🔄 立即刷新"):
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
# 自动刷新
|
|
||||||
time.sleep(refresh_interval)
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
Loading…
Reference in New Issue