172 lines
6.8 KiB
Python
172 lines
6.8 KiB
Python
from bokeh.io import push_notebook, show, output_notebook
|
||
from bokeh.plotting import figure
|
||
from bokeh.models import ColumnDataSource, LinearAxis, Range1d
|
||
import pandas as pd
|
||
import numpy as np
|
||
|
||
output_notebook() # 在 Jupyter 中必须调用一次
|
||
|
||
class TrainingMonitor:
|
||
"""
|
||
实时训练监控图,支持任意多个指标,自动管理左右 Y 轴。
|
||
|
||
参数
|
||
----------
|
||
metrics : list of dict, 可选
|
||
每个指标是一个 dict,必须包含 'name'(数据列名)和 'label'(图例)。
|
||
可选字段:'color' (颜色), 'y_axis' ('left' 或 'right', 默认 'left'),
|
||
'y_range' (手动指定 Y 轴范围,如 [0,1])。
|
||
若为 None,则使用默认的 [loss, acc]。
|
||
title : str, 默认 "训练曲线"
|
||
width : int, 默认 1080
|
||
height : int, 默认 384
|
||
line_width : int, 默认 2
|
||
"""
|
||
def __init__(self, metrics=None, title="训练曲线", width=1080, height=384, line_width=2):
|
||
# 默认指标:loss 左轴,acc 右轴
|
||
if metrics is None:
|
||
metrics = [
|
||
{'name': 'loss', 'label': 'loss', 'color': '#ed5a65', 'y_axis': 'left'},
|
||
{'name': 'acc', 'label': 'accuracy', 'color': '#2b1216', 'y_axis': 'right', 'y_range': [0, 1]}
|
||
]
|
||
self.metrics = metrics
|
||
self.metric_names = [m['name'] for m in metrics]
|
||
|
||
# 初始化数据源,含 step 列 + 各指标列
|
||
self.source = ColumnDataSource(data={'step': []})
|
||
for m in metrics:
|
||
self.source.data[m['name']] = []
|
||
|
||
# 创建图形
|
||
self.p = figure(title=title, width=width, height=height,
|
||
x_axis_label='step', y_axis_label='left axis')
|
||
self.p.extra_y_ranges = {} # 存放右轴
|
||
|
||
# 为每个指标添加线条
|
||
for m in metrics:
|
||
color = m.get('color', None)
|
||
y_axis = m.get('y_axis', 'left')
|
||
legend = m.get('label', m['name'])
|
||
|
||
if y_axis == 'right':
|
||
# 创建右轴(若尚未创建)
|
||
if y_axis not in self.p.extra_y_ranges:
|
||
y_range_name = f'{y_axis}_{m["name"]}' # 唯一命名
|
||
# 使用手动范围或自动计算
|
||
y_range = m.get('y_range', None)
|
||
if y_range is None:
|
||
y_range = Range1d(start=0, end=1) # 占位,稍后自动调整
|
||
else:
|
||
y_range = Range1d(start=y_range[0], end=y_range[1])
|
||
self.p.extra_y_ranges[y_range_name] = y_range
|
||
self.p.add_layout(LinearAxis(y_range_name=y_range_name), 'right')
|
||
else:
|
||
# 复用已创建的右轴,简单起见:每个右轴指标使用独立 y_range_name
|
||
y_range_name = f'right_{m["name"]}'
|
||
self.p.extra_y_ranges[y_range_name] = Range1d(start=0, end=1)
|
||
|
||
self.p.line(x='step', y=m['name'], source=self.source,
|
||
color=color, legend_label=legend,
|
||
y_range_name=y_range_name, line_width=line_width)
|
||
else:
|
||
self.p.line(x='step', y=m['name'], source=self.source,
|
||
color=color, legend_label=legend,
|
||
line_width=line_width)
|
||
|
||
self.p.legend.location = "top_left"
|
||
self.p.legend.click_policy = "hide" # 可点击图例隐藏曲线
|
||
self.handle = show(self.p, notebook_handle=True)
|
||
|
||
def add_step(self, step, values):
|
||
"""
|
||
追加一个 step 的数据。
|
||
|
||
参数
|
||
----------
|
||
step : int or float
|
||
values : dict
|
||
键为指标名,值为该 step 的指标值。
|
||
例如:{'loss': 0.23, 'acc': 0.85}
|
||
"""
|
||
new_data = {'step': [step]}
|
||
for name in self.metric_names:
|
||
new_data[name] = [values.get(name, np.nan)]
|
||
self.source.stream(new_data, rollover=10000) # 保留最多 10000 条,防内存爆炸
|
||
|
||
# 自动调整 Y 轴范围(右轴)
|
||
self._adjust_y_ranges()
|
||
|
||
def add_batch(self, steps, values_matrix):
|
||
"""
|
||
批量添加多个 step 的数据。
|
||
|
||
参数
|
||
----------
|
||
steps : list
|
||
values_matrix : list of dict
|
||
每个元素是与 add_step 格式相同的 dict。
|
||
"""
|
||
new_data = {'step': steps}
|
||
# 初始化各列为空列表
|
||
for name in self.metric_names:
|
||
new_data[name] = []
|
||
# 填充数据
|
||
for vals in values_matrix:
|
||
for name in self.metric_names:
|
||
new_data[name].append(vals.get(name, np.nan))
|
||
self.source.stream(new_data, rollover=10000)
|
||
self._adjust_y_ranges()
|
||
|
||
def _adjust_y_ranges(self):
|
||
"""自动更新右轴的范围(基于当前数据)"""
|
||
df = pd.DataFrame(self.source.data)
|
||
for m in self.metrics:
|
||
if m.get('y_axis') == 'right':
|
||
col = m['name']
|
||
if col in df.columns and not df[col].empty:
|
||
valid = df[col].dropna()
|
||
if len(valid) > 0:
|
||
min_val = valid.min()
|
||
max_val = valid.max()
|
||
pad = (max_val - min_val) * 0.05 # 留5%边距
|
||
if pad == 0:
|
||
pad = 0.1
|
||
y_range_name = f'right_{col}'
|
||
if y_range_name in self.p.extra_y_ranges:
|
||
self.p.extra_y_ranges[y_range_name].start = min_val - pad
|
||
self.p.extra_y_ranges[y_range_name].end = max_val + pad
|
||
push_notebook(handle=self.handle)
|
||
|
||
def clear(self):
|
||
"""清空所有数据"""
|
||
self.source.data = {'step': []}
|
||
for name in self.metric_names:
|
||
self.source.data[name] = []
|
||
push_notebook(handle=self.handle)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
# 初始化监控器(支持自定义指标)
|
||
monitor = TrainingMonitor(
|
||
metrics=[
|
||
{'name': 'loss', 'label': 'train_loss', 'color': 'red', 'y_axis': 'left'},
|
||
{'name': 'acc', 'label': 'train_acc', 'color': 'blue', 'y_axis': 'right', 'y_range': [0, 1]},
|
||
{'name': 'val_loss', 'label': 'val_loss', 'color': 'orange', 'y_axis': 'left'},
|
||
{'name': 'val_acc', 'label': 'val_acc', 'color': 'green', 'y_axis': 'right'},
|
||
],
|
||
title="BERT 训练曲线"
|
||
)
|
||
|
||
# 模拟训练
|
||
for step in range(1, 101):
|
||
train_loss = 1.0 / step
|
||
train_acc = 0.5 + 0.4 * (1 - 1/step)
|
||
val_loss = 1.2 / step
|
||
val_acc = 0.48 + 0.4 * (1 - 1/step)
|
||
|
||
monitor.add_step(step, {
|
||
'loss': train_loss,
|
||
'acc': train_acc,
|
||
'val_loss': val_loss,
|
||
'val_acc': val_acc
|
||
}) |