feat(model): 添加 tensorboard 依赖并重构训练监控逻辑
This commit is contained in:
parent
43c8349d51
commit
4031a668da
|
|
@ -17,6 +17,7 @@ dependencies = [
|
||||||
"pypinyin>=0.55.0",
|
"pypinyin>=0.55.0",
|
||||||
"requests>=2.32.5",
|
"requests>=2.32.5",
|
||||||
"rich>=14.3.1",
|
"rich>=14.3.1",
|
||||||
|
"tensorboard>=2.20.0",
|
||||||
"transformers>=5.1.0",
|
"transformers>=5.1.0",
|
||||||
"typer>=0.21.1",
|
"typer>=0.21.1",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ from tqdm.autonotebook import tqdm
|
||||||
|
|
||||||
from suinput.dataset import PG
|
from suinput.dataset import PG
|
||||||
|
|
||||||
from .monitor import TrainingMonitor, send_serverchan_message
|
from .monitor import TrainingMonitor
|
||||||
|
|
||||||
|
|
||||||
def eval_dataloader(path: Union[str, Path] = (files(__package__) / "eval_dataset")):
|
def eval_dataloader(path: Union[str, Path] = (files(__package__) / "eval_dataset")):
|
||||||
|
|
@ -157,7 +157,10 @@ class MoEModel(nn.Module):
|
||||||
self.classifier = nn.Sequential(
|
self.classifier = nn.Sequential(
|
||||||
nn.LayerNorm(self.hidden_size * self.output_multiplier),
|
nn.LayerNorm(self.hidden_size * self.output_multiplier),
|
||||||
nn.Dropout(0.4),
|
nn.Dropout(0.4),
|
||||||
nn.Linear(self.hidden_size * self.output_multiplier, self.hidden_size * self.output_multiplier * 2),
|
nn.Linear(
|
||||||
|
self.hidden_size * self.output_multiplier,
|
||||||
|
self.hidden_size * self.output_multiplier * 2,
|
||||||
|
),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.Linear(self.hidden_size * self.output_multiplier * 2, num_classes),
|
nn.Linear(self.hidden_size * self.output_multiplier * 2, num_classes),
|
||||||
)
|
)
|
||||||
|
|
@ -462,77 +465,89 @@ class MoEModel(nn.Module):
|
||||||
global_step = 0 # 初始化
|
global_step = 0 # 初始化
|
||||||
batch_loss_sum = 0.0
|
batch_loss_sum = 0.0
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
try:
|
||||||
for epoch in range(num_epochs):
|
for epoch in range(num_epochs):
|
||||||
for batch_idx, batch in enumerate(
|
for batch_idx, batch in enumerate(
|
||||||
tqdm(train_dataloader, total=int(stop_batch))
|
tqdm(train_dataloader, total=int(stop_batch))
|
||||||
):
|
|
||||||
# LR Schedule
|
|
||||||
if processed_batches < warmup_steps:
|
|
||||||
current_lr = lr * (processed_batches/ warmup_steps)
|
|
||||||
else:
|
|
||||||
progress = (processed_batches - warmup_steps) / (
|
|
||||||
total_steps - warmup_steps
|
|
||||||
)
|
|
||||||
current_lr = lr * (0.5 * (1.0 + math.cos(math.pi * progress)))
|
|
||||||
for param_group in optimizer.param_groups:
|
|
||||||
param_group["lr"] = current_lr
|
|
||||||
|
|
||||||
# 移动数据(注意:batch 中现在包含 token_type_ids)
|
|
||||||
input_ids = batch["hint"]["input_ids"].to(self.device)
|
|
||||||
attention_mask = batch["hint"]["attention_mask"].to(self.device)
|
|
||||||
token_type_ids = batch["hint"]["token_type_ids"].to(self.device) # 新增
|
|
||||||
pg = batch["pg"].to(self.device)
|
|
||||||
labels = batch["char_id"].to(self.device)
|
|
||||||
|
|
||||||
with torch.amp.autocast(
|
|
||||||
device_type=self.device.type, enabled=mixed_precision
|
|
||||||
):
|
):
|
||||||
logits = self(input_ids, attention_mask, token_type_ids, pg)
|
# LR Schedule
|
||||||
loss = criterion(logits, labels)
|
if processed_batches < warmup_steps:
|
||||||
loss = loss / grad_accum_steps
|
current_lr = lr * (processed_batches / warmup_steps)
|
||||||
|
else:
|
||||||
|
progress = (processed_batches - warmup_steps) / (
|
||||||
|
total_steps - warmup_steps
|
||||||
|
)
|
||||||
|
current_lr = lr * (0.5 * (1.0 + math.cos(math.pi * progress)))
|
||||||
|
for param_group in optimizer.param_groups:
|
||||||
|
param_group["lr"] = current_lr
|
||||||
|
|
||||||
scaler.scale(loss).backward()
|
# 移动数据(注意:batch 中现在包含 token_type_ids)
|
||||||
|
input_ids = batch["hint"]["input_ids"].to(self.device)
|
||||||
|
attention_mask = batch["hint"]["attention_mask"].to(self.device)
|
||||||
|
token_type_ids = batch["hint"]["token_type_ids"].to(
|
||||||
|
self.device
|
||||||
|
) # 新增
|
||||||
|
pg = batch["pg"].to(self.device)
|
||||||
|
labels = batch["char_id"].to(self.device)
|
||||||
|
|
||||||
if (processed_batches) % grad_accum_steps == 0:
|
with torch.amp.autocast(
|
||||||
scaler.unscale_(optimizer)
|
device_type=self.device.type, enabled=mixed_precision
|
||||||
torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
|
):
|
||||||
|
logits = self(input_ids, attention_mask, token_type_ids, pg)
|
||||||
|
loss = criterion(logits, labels)
|
||||||
|
loss = loss / grad_accum_steps
|
||||||
|
|
||||||
scaler.step(optimizer)
|
scaler.scale(loss).backward()
|
||||||
scaler.update()
|
|
||||||
optimizer.zero_grad()
|
if (processed_batches) % grad_accum_steps == 0:
|
||||||
batch_loss_sum += loss.item() * grad_accum_steps
|
scaler.unscale_(optimizer)
|
||||||
if global_step % eval_frequency == 0:
|
torch.nn.utils.clip_grad_norm_(
|
||||||
if eval_dataloader:
|
self.parameters(), clip_grad_norm
|
||||||
self.eval()
|
)
|
||||||
acc, eval_loss = self.model_eval(eval_dataloader, criterion)
|
|
||||||
self.train()
|
scaler.step(optimizer)
|
||||||
if monitor:
|
scaler.update()
|
||||||
# 使用 eval_loss 作为监控指标
|
optimizer.zero_grad()
|
||||||
monitor.add_step(
|
batch_loss_sum += loss.item() * grad_accum_steps
|
||||||
global_step, {"loss": batch_loss_sum / (eval_frequency if global_step > 0 else 1), "acc": acc}
|
if global_step % eval_frequency == 0:
|
||||||
|
if eval_dataloader:
|
||||||
|
self.eval()
|
||||||
|
acc, eval_loss = self.model_eval(
|
||||||
|
eval_dataloader, criterion
|
||||||
)
|
)
|
||||||
logger.info(
|
self.train()
|
||||||
f"step: {global_step}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if global_step > 0 else 1):.4f}, current_lr: {current_lr}"
|
if monitor:
|
||||||
)
|
# 使用 eval_loss 作为监控指标
|
||||||
else:
|
monitor.add_step(
|
||||||
logger.info(f"step: {global_step}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if global_step > 0 else 1):.4f}, current_lr: {current_lr}")
|
global_step,
|
||||||
batch_loss_sum = 0.0
|
{
|
||||||
if processed_batches >= stop_batch:
|
"train_loss": batch_loss_sum
|
||||||
break
|
/ (
|
||||||
processed_batches += 1
|
eval_frequency if global_step > 0 else 1
|
||||||
global_step += 1
|
),
|
||||||
|
"acc": acc,
|
||||||
|
"loss": eval_loss,
|
||||||
|
"lr": current_lr
|
||||||
|
},
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"step: {global_step}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if global_step > 0 else 1):.4f}, current_lr: {current_lr}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
f"step: {global_step}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if global_step > 0 else 1):.4f}, current_lr: {current_lr}"
|
||||||
|
)
|
||||||
|
batch_loss_sum = 0.0
|
||||||
|
if processed_batches >= stop_batch:
|
||||||
|
break
|
||||||
|
processed_batches += 1
|
||||||
|
global_step += 1
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Training interrupted by user")
|
||||||
|
|
||||||
# 训练结束发送通知
|
# 训练结束发送通知
|
||||||
try:
|
if monitor:
|
||||||
res_acc, res_loss = self.model_eval(eval_dataloader, criterion)
|
monitor.finish()
|
||||||
send_serverchan_message(
|
|
||||||
title="训练完成",
|
|
||||||
content=f"acc: {res_acc:.4f}, loss: {res_loss:.4f}",
|
|
||||||
)
|
|
||||||
logger.info(f"训练完成,acc: {res_acc:.4f}, loss: {res_loss:.4f}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"发送消息失败: {e}")
|
|
||||||
|
|
||||||
def load_from_state_dict(self, state_dict_path: Union[str, Path]):
|
def load_from_state_dict(self, state_dict_path: Union[str, Path]):
|
||||||
state_dict = torch.load(
|
state_dict = torch.load(
|
||||||
|
|
|
||||||
|
|
@ -1,201 +1,343 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from bokeh.io import push_notebook, show, output_notebook
|
|
||||||
from bokeh.plotting import figure
|
|
||||||
from bokeh.models import ColumnDataSource, LinearAxis, Range1d
|
|
||||||
import pandas as pd
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
import requests
|
import requests
|
||||||
|
import pendulum
|
||||||
|
from loguru import logger
|
||||||
|
from bokeh.io import push_notebook, show, output_notebook, save
|
||||||
|
from bokeh.resources import INLINE
|
||||||
|
from bokeh.plotting import figure
|
||||||
|
from bokeh.models import ColumnDataSource, Range1d, LinearAxis
|
||||||
|
|
||||||
output_notebook() # 在 Jupyter 中必须调用一次
|
output_notebook() # 在 Jupyter 中必须调用一次
|
||||||
|
|
||||||
|
|
||||||
def send_serverchan_message(title, content=""):
|
|
||||||
"""
|
|
||||||
向ServerChan发送消息(SendKey从环境变量读取)
|
|
||||||
|
|
||||||
:param title: 消息标题
|
|
||||||
:param content: 消息内容(支持Markdown)
|
|
||||||
:return: 响应JSON
|
|
||||||
"""
|
|
||||||
# 从环境变量获取SendKey
|
|
||||||
send_key = os.environ.get("SERVERCHAN_SEND_KEY")
|
|
||||||
|
|
||||||
if not send_key:
|
|
||||||
raise ValueError("❌ 未找到环境变量 SERVERCHAN_SEND_KEY,请检查配置")
|
|
||||||
|
|
||||||
url = f"https://sctapi.ftqq.com/{send_key}.send"
|
|
||||||
|
|
||||||
data = {
|
|
||||||
"text": title,
|
|
||||||
"desp": content
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.post(url, data=data)
|
|
||||||
return response.json()
|
|
||||||
|
|
||||||
|
|
||||||
class TrainingMonitor:
|
class TrainingMonitor:
|
||||||
"""
|
"""
|
||||||
实时训练监控图,支持任意多个指标,自动管理左右 Y 轴。
|
实时训练监控图,支持 Bokeh 绘图和可选的 TensorBoard 记录。
|
||||||
|
|
||||||
参数
|
参数
|
||||||
----------
|
----------
|
||||||
metrics : list of dict, 可选
|
metrics : list of dict, 可选
|
||||||
每个指标是一个 dict,必须包含 'name'(数据列名)和 'label'(图例)。
|
每个指标是一个 dict,必须包含 'name'(数据列名)和 'label'(图例)。
|
||||||
可选字段:'color' (颜色), 'y_axis' ('left' 或 'right', 默认 'left'),
|
可选字段:
|
||||||
'y_range' (手动指定 Y 轴范围,如 [0,1])。
|
'color' (颜色)
|
||||||
若为 None,则使用默认的 [loss, acc]。
|
'y_axis' ('left' 或 'right', 默认 'left')
|
||||||
title : str, 默认 "训练曲线"
|
'y_range' (手动指定 Y 轴范围,如 [0,1])
|
||||||
width : int, 默认 1080
|
'plot' (bool, 是否在 Bokeh 中绘制,默认 False)
|
||||||
height : int, 默认 384
|
若为 None,则使用默认的两个指标(loss 左轴,acc 右轴),且 plot=True。
|
||||||
line_width : int, 默认 2
|
title : str, 默认 "训练曲线" (仅 Bokeh 使用)
|
||||||
|
width : int, 默认 1080 (仅 Bokeh 使用)
|
||||||
|
height : int, 默认 384 (仅 Bokeh 使用)
|
||||||
|
line_width : int, 默认 2 (仅 Bokeh 使用)
|
||||||
|
tensorboard : bool, 默认 False,是否启用 TensorBoard 记录
|
||||||
|
tensorboard_log_dir : str, 可选,TensorBoard 日志目录(当 tensorboard=True 时使用)
|
||||||
|
send_key_serverchan : str, 可选,ServerChan SendKey。若未提供则从环境变量 SERVERCHAN_SEND_KEY 读取。
|
||||||
|
send_key_autodl : str, 可选,AutoDL SendKey。若未提供则从环境变量 AUTODL_SCKEY 读取。
|
||||||
"""
|
"""
|
||||||
def __init__(self, metrics=None, title="训练曲线", width=1080, height=384, line_width=2):
|
|
||||||
# 默认指标:loss 左轴,acc 右轴
|
def __init__(
|
||||||
|
self,
|
||||||
|
metrics=None,
|
||||||
|
title="训练曲线",
|
||||||
|
width=1080,
|
||||||
|
height=384,
|
||||||
|
line_width=2,
|
||||||
|
tensorboard=False,
|
||||||
|
tensorboard_log_dir="/root/tf-logs",
|
||||||
|
send_key_serverchan=None,
|
||||||
|
send_key_autodl=None,
|
||||||
|
):
|
||||||
|
self.tensorboard = tensorboard
|
||||||
|
self.send_key_serverchan = send_key_serverchan or os.environ.get(
|
||||||
|
"SERVERCHAN_SEND_KEY"
|
||||||
|
)
|
||||||
|
self.send_key_autodl = send_key_autodl or os.environ.get("AUTODL_SCKEY")
|
||||||
|
|
||||||
|
# ---------- 处理指标定义 ----------
|
||||||
if metrics is None:
|
if metrics is None:
|
||||||
metrics = [
|
metrics = [
|
||||||
{'name': 'loss', 'label': 'loss', 'color': '#ed5a65', 'y_axis': 'left'},
|
{
|
||||||
{'name': 'acc', 'label': 'accuracy', 'color': '#2b1216', 'y_axis': 'right', 'y_range': [0, 1]}
|
"name": "loss",
|
||||||
|
"label": "loss",
|
||||||
|
"color": "#ed5a65",
|
||||||
|
"y_axis": "left",
|
||||||
|
"plot": True,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "acc",
|
||||||
|
"label": "accuracy",
|
||||||
|
"color": "#2b1216",
|
||||||
|
"y_axis": "right",
|
||||||
|
"y_range": [0, 1],
|
||||||
|
"plot": True,
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
for m in metrics:
|
||||||
|
m.setdefault("plot", False)
|
||||||
|
|
||||||
self.metrics = metrics
|
self.metrics = metrics
|
||||||
self.metric_names = [m['name'] for m in metrics]
|
self.metric_names = [m["name"] for m in metrics]
|
||||||
|
self.plot_metrics = [m for m in metrics if m["plot"]]
|
||||||
|
self.plot_names = [m["name"] for m in self.plot_metrics]
|
||||||
|
|
||||||
# 初始化数据源,含 step 列 + 各指标列
|
# ---------- Bokeh 初始化 ----------
|
||||||
self.source = ColumnDataSource(data={'step': []})
|
self.source = ColumnDataSource(data={"step": []})
|
||||||
for m in metrics:
|
for m in self.plot_metrics:
|
||||||
self.source.data[m['name']] = []
|
self.source.data[m["name"]] = []
|
||||||
|
|
||||||
# 创建图形
|
self.p = figure(
|
||||||
self.p = figure(title=title, width=width, height=height,
|
title=title,
|
||||||
x_axis_label='step', y_axis_label='left axis')
|
width=width,
|
||||||
self.p.extra_y_ranges = {} # 存放右轴
|
height=height,
|
||||||
|
x_axis_label="step",
|
||||||
|
y_axis_label="left axis",
|
||||||
|
)
|
||||||
|
self.p.extra_y_ranges = {}
|
||||||
|
|
||||||
# 为每个指标添加线条
|
for m in self.plot_metrics:
|
||||||
for m in metrics:
|
color = m.get("color", None)
|
||||||
color = m.get('color', None)
|
y_axis = m.get("y_axis", "left")
|
||||||
y_axis = m.get('y_axis', 'left')
|
legend = m.get("label", m["name"])
|
||||||
legend = m.get('label', m['name'])
|
|
||||||
|
|
||||||
if y_axis == 'right':
|
if y_axis == "right":
|
||||||
# 创建右轴(若尚未创建)
|
y_range_name = f"right_{m['name']}"
|
||||||
if y_axis not in self.p.extra_y_ranges:
|
y_range = m.get("y_range", None)
|
||||||
y_range_name = f'{y_axis}_{m["name"]}' # 唯一命名
|
if y_range is None:
|
||||||
# 使用手动范围或自动计算
|
y_range = Range1d(start=0, end=1)
|
||||||
y_range = m.get('y_range', None)
|
|
||||||
if y_range is None:
|
|
||||||
y_range = Range1d(start=0, end=1) # 占位,稍后自动调整
|
|
||||||
else:
|
|
||||||
y_range = Range1d(start=y_range[0], end=y_range[1])
|
|
||||||
self.p.extra_y_ranges[y_range_name] = y_range
|
|
||||||
self.p.add_layout(LinearAxis(y_range_name=y_range_name), 'right')
|
|
||||||
else:
|
else:
|
||||||
# 复用已创建的右轴,简单起见:每个右轴指标使用独立 y_range_name
|
y_range = Range1d(start=y_range[0], end=y_range[1])
|
||||||
y_range_name = f'right_{m["name"]}'
|
self.p.extra_y_ranges[y_range_name] = y_range
|
||||||
self.p.extra_y_ranges[y_range_name] = Range1d(start=0, end=1)
|
self.p.add_layout(LinearAxis(y_range_name=y_range_name), "right")
|
||||||
|
self.p.line(
|
||||||
self.p.line(x='step', y=m['name'], source=self.source,
|
x="step",
|
||||||
color=color, legend_label=legend,
|
y=m["name"],
|
||||||
y_range_name=y_range_name, line_width=line_width)
|
source=self.source,
|
||||||
|
color=color,
|
||||||
|
legend_label=legend,
|
||||||
|
y_range_name=y_range_name,
|
||||||
|
line_width=line_width,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.p.line(x='step', y=m['name'], source=self.source,
|
self.p.line(
|
||||||
color=color, legend_label=legend,
|
x="step",
|
||||||
line_width=line_width)
|
y=m["name"],
|
||||||
|
source=self.source,
|
||||||
|
color=color,
|
||||||
|
legend_label=legend,
|
||||||
|
line_width=line_width,
|
||||||
|
)
|
||||||
|
|
||||||
self.p.legend.location = "top_left"
|
self.p.legend.location = "top_left"
|
||||||
self.p.legend.click_policy = "hide" # 可点击图例隐藏曲线
|
self.p.legend.click_policy = "hide"
|
||||||
self.handle = show(self.p, notebook_handle=True)
|
self.handle = show(self.p, notebook_handle=True)
|
||||||
|
|
||||||
|
# ---------- TensorBoard 初始化 ----------
|
||||||
|
if self.tensorboard:
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("使用 TensorBoard 需要安装 PyTorch (torch)")
|
||||||
|
"""
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
self.writer = SummaryWriter(log_dir=tensorboard_log_dir)
|
||||||
|
print(f"TensorBoard 日志保存在: {self.writer.log_dir}")
|
||||||
|
print(f"查看命令: tensorboard --logdir={self.writer.log_dir}")
|
||||||
|
|
||||||
|
def send_serverchan_message(self, title, content=""):
|
||||||
|
"""
|
||||||
|
向 ServerChan 发送消息(使用实例的 send_key)
|
||||||
|
|
||||||
|
:param title: 消息标题
|
||||||
|
:param content: 消息内容(支持 Markdown)
|
||||||
|
:return: 响应 JSON,若 send_key_serverchan 无效则返回 None
|
||||||
|
"""
|
||||||
|
if not self.send_key_serverchan:
|
||||||
|
logger.warning("警告: 未提供 ServerChan SendKey,消息未发送。")
|
||||||
|
return None
|
||||||
|
url = f"https://sctapi.ftqq.com/{self.send_key_serverchan}.send"
|
||||||
|
data = {"text": title, "desp": content}
|
||||||
|
try:
|
||||||
|
response = requests.post(url, data=data)
|
||||||
|
return response.json()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"发送 ServerChan 消息失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def send_autodl_message(self, title, name, content=""):
|
||||||
|
"""
|
||||||
|
向 AutoDL 发送消息(使用实例的 send_key_autodl)
|
||||||
|
|
||||||
|
:param title: 消息标题
|
||||||
|
:param name: 消息名称
|
||||||
|
:param content: 消息内容(支持 Markdown)
|
||||||
|
:return: 响应 JSON,若 send_key_autodl 无效则返回 None
|
||||||
|
"""
|
||||||
|
if not self.send_key_autodl:
|
||||||
|
logger.warning("警告: 未提供 AutoDL SendKey,消息未发送。")
|
||||||
|
return None
|
||||||
|
url = "https://www.autodl.com/api/v1/wechat/message/send"
|
||||||
|
header = {"Authorization": self.send_key_autodl}
|
||||||
|
try:
|
||||||
|
resp = requests.post(
|
||||||
|
url,
|
||||||
|
json={
|
||||||
|
"title": title,
|
||||||
|
"name": name,
|
||||||
|
"content": content,
|
||||||
|
},
|
||||||
|
headers=header,
|
||||||
|
)
|
||||||
|
return resp.json()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"发送 AutoDL 消息失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
def add_step(self, step, values):
|
def add_step(self, step, values):
|
||||||
"""
|
"""
|
||||||
追加一个 step 的数据。
|
追加一个 step 的数据。
|
||||||
|
|
||||||
参数
|
参数
|
||||||
----------
|
----------
|
||||||
step : int or float
|
step : int or float
|
||||||
values : dict
|
values : dict
|
||||||
键为指标名,值为该 step 的指标值。
|
键为指标名,值为该 step 的指标值。
|
||||||
例如:{'loss': 0.23, 'acc': 0.85}
|
例如:{'loss': 0.23, 'acc': 0.85, 'val_loss': 0.34}
|
||||||
"""
|
"""
|
||||||
new_data = {'step': [step]}
|
if self.tensorboard:
|
||||||
for name in self.metric_names:
|
for name, value in values.items():
|
||||||
new_data[name] = [values.get(name, np.nan)]
|
self.writer.add_scalar(name, value, step)
|
||||||
self.source.stream(new_data, rollover=10000) # 保留最多 10000 条,防内存爆炸
|
|
||||||
|
|
||||||
# 自动调整 Y 轴范围(右轴)
|
if self.plot_names:
|
||||||
self._adjust_y_ranges()
|
new_data = {"step": [step]}
|
||||||
|
for name in self.plot_names:
|
||||||
def add_batch(self, steps, values_matrix):
|
new_data[name] = [values.get(name, np.nan)]
|
||||||
"""
|
self.source.stream(new_data, rollover=10000)
|
||||||
批量添加多个 step 的数据。
|
self._adjust_y_ranges()
|
||||||
|
|
||||||
参数
|
|
||||||
----------
|
|
||||||
steps : list
|
|
||||||
values_matrix : list of dict
|
|
||||||
每个元素是与 add_step 格式相同的 dict。
|
|
||||||
"""
|
|
||||||
new_data = {'step': steps}
|
|
||||||
# 初始化各列为空列表
|
|
||||||
for name in self.metric_names:
|
|
||||||
new_data[name] = []
|
|
||||||
# 填充数据
|
|
||||||
for vals in values_matrix:
|
|
||||||
for name in self.metric_names:
|
|
||||||
new_data[name].append(vals.get(name, np.nan))
|
|
||||||
self.source.stream(new_data, rollover=10000)
|
|
||||||
self._adjust_y_ranges()
|
|
||||||
|
|
||||||
def _adjust_y_ranges(self):
|
def _adjust_y_ranges(self):
|
||||||
"""自动更新右轴的范围(基于当前数据)"""
|
if not self.plot_metrics:
|
||||||
|
return
|
||||||
df = pd.DataFrame(self.source.data)
|
df = pd.DataFrame(self.source.data)
|
||||||
for m in self.metrics:
|
for m in self.plot_metrics:
|
||||||
if m.get('y_axis') == 'right':
|
if m.get("y_axis") == "right":
|
||||||
col = m['name']
|
col = m["name"]
|
||||||
if col in df.columns and not df[col].empty:
|
if col in df.columns and not df[col].empty:
|
||||||
valid = df[col].dropna()
|
valid = df[col].dropna()
|
||||||
if len(valid) > 0:
|
if len(valid) > 0:
|
||||||
min_val = valid.min()
|
min_val = valid.min()
|
||||||
max_val = valid.max()
|
max_val = valid.max()
|
||||||
pad = (max_val - min_val) * 0.05 # 留5%边距
|
pad = (max_val - min_val) * 0.05
|
||||||
if pad == 0:
|
if pad == 0:
|
||||||
pad = 0.1
|
pad = 0.1
|
||||||
y_range_name = f'right_{col}'
|
y_range_name = f"right_{col}"
|
||||||
if y_range_name in self.p.extra_y_ranges:
|
if y_range_name in self.p.extra_y_ranges:
|
||||||
self.p.extra_y_ranges[y_range_name].start = min_val - pad
|
self.p.extra_y_ranges[y_range_name].start = min_val - pad
|
||||||
self.p.extra_y_ranges[y_range_name].end = max_val + pad
|
self.p.extra_y_ranges[y_range_name].end = max_val + pad
|
||||||
push_notebook(handle=self.handle)
|
push_notebook(handle=self.handle)
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
"""清空所有数据"""
|
"""清空 Bokeh 图表数据(TensorBoard 记录不受影响)"""
|
||||||
self.source.data = {'step': []}
|
self.source.data = {"step": []}
|
||||||
for name in self.metric_names:
|
for name in self.plot_names:
|
||||||
self.source.data[name] = []
|
self.source.data[name] = []
|
||||||
push_notebook(handle=self.handle)
|
push_notebook(handle=self.handle)
|
||||||
|
|
||||||
|
def finish(self, filename="", message=None, send_message=True):
|
||||||
|
"""
|
||||||
|
训练结束时的操作:保存图表、发送 ServerChan 通知。
|
||||||
|
|
||||||
|
参数
|
||||||
|
----------
|
||||||
|
filename : str
|
||||||
|
保存 Bokeh 图表的 HTML 文件路径。
|
||||||
|
message : str, 可选
|
||||||
|
自定义消息内容(Markdown 格式)。若为 None,则自动生成包含最终指标值的摘要。
|
||||||
|
send_message : bool, 默认 True
|
||||||
|
是否发送 ServerChan 消息。
|
||||||
|
"""
|
||||||
|
# 1. 保存图表为 HTML
|
||||||
|
if filename == "":
|
||||||
|
filename = f"training_monitor_{pendulum.now().format('YYYYMMDDHH')}.html"
|
||||||
|
try:
|
||||||
|
save(
|
||||||
|
self.p,
|
||||||
|
filename=filename,
|
||||||
|
title=f"训练监控 - {pendulum.now().format('YYYYMMDDHHmm')}",
|
||||||
|
resources=INLINE,
|
||||||
|
|
||||||
|
)
|
||||||
|
print(f"图表已保存至: {os.path.abspath(filename)}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"保存图表失败: {e}")
|
||||||
|
|
||||||
|
# 2. 发送 ServerChan 消息
|
||||||
|
if send_message:
|
||||||
|
if message is None:
|
||||||
|
# 自动生成摘要:获取最后一个 step 的各指标值
|
||||||
|
data = self.source.data
|
||||||
|
if data["step"]:
|
||||||
|
last_step = data["step"][-1]
|
||||||
|
lines = [f"### 训练完成 (step {last_step})"]
|
||||||
|
for name in self.metric_names:
|
||||||
|
if name in data and len(data[name]) > 0:
|
||||||
|
val = data[name][-1]
|
||||||
|
lines.append(f"- **{name}**: {val:.4f}")
|
||||||
|
message = "\n".join(lines)
|
||||||
|
else:
|
||||||
|
message = "训练完成,但无数据记录。"
|
||||||
|
if not self.send_autodl_message(
|
||||||
|
"训练结束通知", name="训练", content=message
|
||||||
|
):
|
||||||
|
self.send_serverchan_message("训练结束通知", message)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 初始化监控器(支持自定义指标)
|
# 示例:使用 Bokeh + TensorBoard + ServerChan
|
||||||
monitor = TrainingMonitor(
|
monitor = TrainingMonitor(
|
||||||
metrics=[
|
metrics=[
|
||||||
{'name': 'loss', 'label': 'train_loss', 'color': 'red', 'y_axis': 'left'},
|
{"name": "loss", "label": "train_loss", "color": "red", "plot": True},
|
||||||
{'name': 'acc', 'label': 'train_acc', 'color': 'blue', 'y_axis': 'right', 'y_range': [0, 1]},
|
{
|
||||||
{'name': 'val_loss', 'label': 'val_loss', 'color': 'orange', 'y_axis': 'left'},
|
"name": "acc",
|
||||||
{'name': 'val_acc', 'label': 'val_acc', 'color': 'green', 'y_axis': 'right'},
|
"label": "train_acc",
|
||||||
|
"color": "blue",
|
||||||
|
"y_axis": "right",
|
||||||
|
"y_range": [0, 1],
|
||||||
|
"plot": True,
|
||||||
|
},
|
||||||
|
{"name": "val_loss", "label": "val_loss", "color": "orange", "plot": True},
|
||||||
|
{
|
||||||
|
"name": "val_acc",
|
||||||
|
"label": "val_acc",
|
||||||
|
"color": "green",
|
||||||
|
"y_axis": "right",
|
||||||
|
"plot": False,
|
||||||
|
},
|
||||||
],
|
],
|
||||||
title="BERT 训练曲线"
|
title="BERT 训练曲线",
|
||||||
|
tensorboard=True,
|
||||||
|
tensorboard_log_dir="runs/example",
|
||||||
|
send_key_autodl=None, # 若未提供,会尝试从环境变量读取
|
||||||
|
send_key_serverchan=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 模拟训练
|
# 模拟训练
|
||||||
for step in range(1, 101):
|
for step in range(1, 101):
|
||||||
train_loss = 1.0 / step
|
train_loss = 1.0 / step
|
||||||
train_acc = 0.5 + 0.4 * (1 - 1/step)
|
train_acc = 0.5 + 0.4 * (1 - 1 / step)
|
||||||
val_loss = 1.2 / step
|
val_loss = 1.2 / step
|
||||||
val_acc = 0.48 + 0.4 * (1 - 1/step)
|
val_acc = 0.48 + 0.4 * (1 - 1 / step)
|
||||||
|
monitor.add_step(
|
||||||
monitor.add_step(step, {
|
step,
|
||||||
'loss': train_loss,
|
{
|
||||||
'acc': train_acc,
|
"loss": train_loss,
|
||||||
'val_loss': val_loss,
|
"acc": train_acc,
|
||||||
'val_acc': val_acc
|
"val_loss": val_loss,
|
||||||
})
|
"val_acc": val_acc,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# 训练结束:保存图表并发送通知
|
||||||
|
monitor.finish("bert_training.html")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue