feat(model): 添加训练完成通知功能,通过ServerChan发送微信消息

This commit is contained in:
songsenand 2026-02-24 00:10:25 +08:00
parent 019fa2d23d
commit 63efc49aa6
3 changed files with 43 additions and 3 deletions

View File

@ -15,6 +15,7 @@ dependencies = [
"numpy>=2.4.2",
"pandas>=3.0.0",
"pypinyin>=0.55.0",
"requests>=2.32.5",
"rich>=14.3.1",
"transformers>=5.1.0",
"typer>=0.21.1",

View File

@ -12,7 +12,7 @@ from loguru import logger
from modelscope import AutoModel, AutoTokenizer
from tqdm.autonotebook import tqdm
from .monitor import TrainingMonitor
from .monitor import TrainingMonitor, send_serverchan_message
from suinput.dataset import PG
@ -282,7 +282,7 @@ class MoEModel(nn.Module):
avg_loss: float, 平均损失
"""
if criterion is None:
criterion = nn.CrossEntropyLoss()
criterion = nn.NLLLoss()
self.eval()
total_loss = 0.0
@ -300,7 +300,7 @@ class MoEModel(nn.Module):
# 前向传播
probs = self(input_ids, attention_mask, pg)
log_probs = torch.log(probs + 1e-12)
loss = nn.NLLLoss()(log_probs, labels)
loss = criterion(log_probs, labels)
total_loss += loss.item() * labels.size(0)
# 计算准确率
@ -541,6 +541,16 @@ class MoEModel(nn.Module):
if processed_batches - 1 >= stop_batch:
break
global_step += 1
res_acc, res_loss = self.model_eval(eval_dataloader)
try:
to_wechat_response = send_serverchan_message(
title="训练完成",
content=f"训练完成acc: {res_acc:.4f}, loss: {res_loss:.4f}",
)
logger.info(f"训练完成acc: {res_acc:.4f}, loss: {res_loss:.4f}")
logger.info(f"发送消息: {to_wechat_response}")
except Exception as e:
logger.error(f"发送消息失败: {e}")
def load_from_state_dict(self, state_dict_path: Union[str, Path]):
state_dict = torch.load(

View File

@ -1,11 +1,40 @@
import os
from bokeh.io import push_notebook, show, output_notebook
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource, LinearAxis, Range1d
import pandas as pd
import numpy as np
import requests
output_notebook() # 在 Jupyter 中必须调用一次
def send_serverchan_message(title, content=""):
"""
向ServerChan发送消息SendKey从环境变量读取
:param title: 消息标题
:param content: 消息内容支持Markdown
:return: 响应JSON
"""
# 从环境变量获取SendKey
send_key = os.environ.get("SERVERCHAN_SEND_KEY")
if not send_key:
raise ValueError("❌ 未找到环境变量 SERVERCHAN_SEND_KEY请检查配置")
url = f"https://sctapi.ftqq.com/{send_key}.send"
data = {
"text": title,
"desp": content
}
response = requests.post(url, data=data)
return response.json()
class TrainingMonitor:
"""
实时训练监控图支持任意多个指标自动管理左右 Y