feat(model): 添加训练完成通知功能,通过ServerChan发送微信消息
This commit is contained in:
parent
019fa2d23d
commit
63efc49aa6
|
|
@ -15,6 +15,7 @@ dependencies = [
|
||||||
"numpy>=2.4.2",
|
"numpy>=2.4.2",
|
||||||
"pandas>=3.0.0",
|
"pandas>=3.0.0",
|
||||||
"pypinyin>=0.55.0",
|
"pypinyin>=0.55.0",
|
||||||
|
"requests>=2.32.5",
|
||||||
"rich>=14.3.1",
|
"rich>=14.3.1",
|
||||||
"transformers>=5.1.0",
|
"transformers>=5.1.0",
|
||||||
"typer>=0.21.1",
|
"typer>=0.21.1",
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ from loguru import logger
|
||||||
from modelscope import AutoModel, AutoTokenizer
|
from modelscope import AutoModel, AutoTokenizer
|
||||||
from tqdm.autonotebook import tqdm
|
from tqdm.autonotebook import tqdm
|
||||||
|
|
||||||
from .monitor import TrainingMonitor
|
from .monitor import TrainingMonitor, send_serverchan_message
|
||||||
from suinput.dataset import PG
|
from suinput.dataset import PG
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -282,7 +282,7 @@ class MoEModel(nn.Module):
|
||||||
avg_loss: float, 平均损失
|
avg_loss: float, 平均损失
|
||||||
"""
|
"""
|
||||||
if criterion is None:
|
if criterion is None:
|
||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.NLLLoss()
|
||||||
|
|
||||||
self.eval()
|
self.eval()
|
||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
|
|
@ -300,7 +300,7 @@ class MoEModel(nn.Module):
|
||||||
# 前向传播
|
# 前向传播
|
||||||
probs = self(input_ids, attention_mask, pg)
|
probs = self(input_ids, attention_mask, pg)
|
||||||
log_probs = torch.log(probs + 1e-12)
|
log_probs = torch.log(probs + 1e-12)
|
||||||
loss = nn.NLLLoss()(log_probs, labels)
|
loss = criterion(log_probs, labels)
|
||||||
total_loss += loss.item() * labels.size(0)
|
total_loss += loss.item() * labels.size(0)
|
||||||
|
|
||||||
# 计算准确率
|
# 计算准确率
|
||||||
|
|
@ -541,6 +541,16 @@ class MoEModel(nn.Module):
|
||||||
if processed_batches - 1 >= stop_batch:
|
if processed_batches - 1 >= stop_batch:
|
||||||
break
|
break
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
res_acc, res_loss = self.model_eval(eval_dataloader)
|
||||||
|
try:
|
||||||
|
to_wechat_response = send_serverchan_message(
|
||||||
|
title="训练完成",
|
||||||
|
content=f"训练完成,acc: {res_acc:.4f}, loss: {res_loss:.4f}",
|
||||||
|
)
|
||||||
|
logger.info(f"训练完成,acc: {res_acc:.4f}, loss: {res_loss:.4f}")
|
||||||
|
logger.info(f"发送消息: {to_wechat_response}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"发送消息失败: {e}")
|
||||||
|
|
||||||
def load_from_state_dict(self, state_dict_path: Union[str, Path]):
|
def load_from_state_dict(self, state_dict_path: Union[str, Path]):
|
||||||
state_dict = torch.load(
|
state_dict = torch.load(
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,40 @@
|
||||||
|
import os
|
||||||
|
|
||||||
from bokeh.io import push_notebook, show, output_notebook
|
from bokeh.io import push_notebook, show, output_notebook
|
||||||
from bokeh.plotting import figure
|
from bokeh.plotting import figure
|
||||||
from bokeh.models import ColumnDataSource, LinearAxis, Range1d
|
from bokeh.models import ColumnDataSource, LinearAxis, Range1d
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import requests
|
||||||
|
|
||||||
output_notebook() # 在 Jupyter 中必须调用一次
|
output_notebook() # 在 Jupyter 中必须调用一次
|
||||||
|
|
||||||
|
|
||||||
|
def send_serverchan_message(title, content=""):
|
||||||
|
"""
|
||||||
|
向ServerChan发送消息(SendKey从环境变量读取)
|
||||||
|
|
||||||
|
:param title: 消息标题
|
||||||
|
:param content: 消息内容(支持Markdown)
|
||||||
|
:return: 响应JSON
|
||||||
|
"""
|
||||||
|
# 从环境变量获取SendKey
|
||||||
|
send_key = os.environ.get("SERVERCHAN_SEND_KEY")
|
||||||
|
|
||||||
|
if not send_key:
|
||||||
|
raise ValueError("❌ 未找到环境变量 SERVERCHAN_SEND_KEY,请检查配置")
|
||||||
|
|
||||||
|
url = f"https://sctapi.ftqq.com/{send_key}.send"
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"text": title,
|
||||||
|
"desp": content
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(url, data=data)
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
class TrainingMonitor:
|
class TrainingMonitor:
|
||||||
"""
|
"""
|
||||||
实时训练监控图,支持任意多个指标,自动管理左右 Y 轴。
|
实时训练监控图,支持任意多个指标,自动管理左右 Y 轴。
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue