feat(model): 添加训练完成通知功能,通过ServerChan发送微信消息

This commit is contained in:
songsenand 2026-02-24 00:10:25 +08:00
parent 019fa2d23d
commit 63efc49aa6
3 changed files with 43 additions and 3 deletions

View File

@ -15,6 +15,7 @@ dependencies = [
"numpy>=2.4.2", "numpy>=2.4.2",
"pandas>=3.0.0", "pandas>=3.0.0",
"pypinyin>=0.55.0", "pypinyin>=0.55.0",
"requests>=2.32.5",
"rich>=14.3.1", "rich>=14.3.1",
"transformers>=5.1.0", "transformers>=5.1.0",
"typer>=0.21.1", "typer>=0.21.1",

View File

@ -12,7 +12,7 @@ from loguru import logger
from modelscope import AutoModel, AutoTokenizer from modelscope import AutoModel, AutoTokenizer
from tqdm.autonotebook import tqdm from tqdm.autonotebook import tqdm
from .monitor import TrainingMonitor from .monitor import TrainingMonitor, send_serverchan_message
from suinput.dataset import PG from suinput.dataset import PG
@ -282,7 +282,7 @@ class MoEModel(nn.Module):
avg_loss: float, 平均损失 avg_loss: float, 平均损失
""" """
if criterion is None: if criterion is None:
criterion = nn.CrossEntropyLoss() criterion = nn.NLLLoss()
self.eval() self.eval()
total_loss = 0.0 total_loss = 0.0
@ -300,7 +300,7 @@ class MoEModel(nn.Module):
# 前向传播 # 前向传播
probs = self(input_ids, attention_mask, pg) probs = self(input_ids, attention_mask, pg)
log_probs = torch.log(probs + 1e-12) log_probs = torch.log(probs + 1e-12)
loss = nn.NLLLoss()(log_probs, labels) loss = criterion(log_probs, labels)
total_loss += loss.item() * labels.size(0) total_loss += loss.item() * labels.size(0)
# 计算准确率 # 计算准确率
@ -541,6 +541,16 @@ class MoEModel(nn.Module):
if processed_batches - 1 >= stop_batch: if processed_batches - 1 >= stop_batch:
break break
global_step += 1 global_step += 1
res_acc, res_loss = self.model_eval(eval_dataloader)
try:
to_wechat_response = send_serverchan_message(
title="训练完成",
content=f"训练完成acc: {res_acc:.4f}, loss: {res_loss:.4f}",
)
logger.info(f"训练完成acc: {res_acc:.4f}, loss: {res_loss:.4f}")
logger.info(f"发送消息: {to_wechat_response}")
except Exception as e:
logger.error(f"发送消息失败: {e}")
def load_from_state_dict(self, state_dict_path: Union[str, Path]): def load_from_state_dict(self, state_dict_path: Union[str, Path]):
state_dict = torch.load( state_dict = torch.load(

View File

@ -1,11 +1,40 @@
import os
from bokeh.io import push_notebook, show, output_notebook from bokeh.io import push_notebook, show, output_notebook
from bokeh.plotting import figure from bokeh.plotting import figure
from bokeh.models import ColumnDataSource, LinearAxis, Range1d from bokeh.models import ColumnDataSource, LinearAxis, Range1d
import pandas as pd import pandas as pd
import numpy as np import numpy as np
import requests
output_notebook() # 在 Jupyter 中必须调用一次 output_notebook() # 在 Jupyter 中必须调用一次
def send_serverchan_message(title, content=""):
"""
向ServerChan发送消息SendKey从环境变量读取
:param title: 消息标题
:param content: 消息内容支持Markdown
:return: 响应JSON
"""
# 从环境变量获取SendKey
send_key = os.environ.get("SERVERCHAN_SEND_KEY")
if not send_key:
raise ValueError("❌ 未找到环境变量 SERVERCHAN_SEND_KEY请检查配置")
url = f"https://sctapi.ftqq.com/{send_key}.send"
data = {
"text": title,
"desp": content
}
response = requests.post(url, data=data)
return response.json()
class TrainingMonitor: class TrainingMonitor:
""" """
实时训练监控图支持任意多个指标自动管理左右 Y 实时训练监控图支持任意多个指标自动管理左右 Y