feat(model): 添加训练完成通知功能,通过ServerChan发送微信消息
This commit is contained in:
parent
019fa2d23d
commit
63efc49aa6
|
|
@ -15,6 +15,7 @@ dependencies = [
|
|||
"numpy>=2.4.2",
|
||||
"pandas>=3.0.0",
|
||||
"pypinyin>=0.55.0",
|
||||
"requests>=2.32.5",
|
||||
"rich>=14.3.1",
|
||||
"transformers>=5.1.0",
|
||||
"typer>=0.21.1",
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from loguru import logger
|
|||
from modelscope import AutoModel, AutoTokenizer
|
||||
from tqdm.autonotebook import tqdm
|
||||
|
||||
from .monitor import TrainingMonitor
|
||||
from .monitor import TrainingMonitor, send_serverchan_message
|
||||
from suinput.dataset import PG
|
||||
|
||||
|
||||
|
|
@ -282,7 +282,7 @@ class MoEModel(nn.Module):
|
|||
avg_loss: float, 平均损失
|
||||
"""
|
||||
if criterion is None:
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
criterion = nn.NLLLoss()
|
||||
|
||||
self.eval()
|
||||
total_loss = 0.0
|
||||
|
|
@ -300,7 +300,7 @@ class MoEModel(nn.Module):
|
|||
# 前向传播
|
||||
probs = self(input_ids, attention_mask, pg)
|
||||
log_probs = torch.log(probs + 1e-12)
|
||||
loss = nn.NLLLoss()(log_probs, labels)
|
||||
loss = criterion(log_probs, labels)
|
||||
total_loss += loss.item() * labels.size(0)
|
||||
|
||||
# 计算准确率
|
||||
|
|
@ -541,6 +541,16 @@ class MoEModel(nn.Module):
|
|||
if processed_batches - 1 >= stop_batch:
|
||||
break
|
||||
global_step += 1
|
||||
res_acc, res_loss = self.model_eval(eval_dataloader)
|
||||
try:
|
||||
to_wechat_response = send_serverchan_message(
|
||||
title="训练完成",
|
||||
content=f"训练完成,acc: {res_acc:.4f}, loss: {res_loss:.4f}",
|
||||
)
|
||||
logger.info(f"训练完成,acc: {res_acc:.4f}, loss: {res_loss:.4f}")
|
||||
logger.info(f"发送消息: {to_wechat_response}")
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}")
|
||||
|
||||
def load_from_state_dict(self, state_dict_path: Union[str, Path]):
|
||||
state_dict = torch.load(
|
||||
|
|
|
|||
|
|
@ -1,11 +1,40 @@
|
|||
import os
|
||||
|
||||
from bokeh.io import push_notebook, show, output_notebook
|
||||
from bokeh.plotting import figure
|
||||
from bokeh.models import ColumnDataSource, LinearAxis, Range1d
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
output_notebook() # Must be called once when running inside Jupyter
|
||||
|
||||
|
||||
def send_serverchan_message(title, content="", timeout=10.0):
    """
    Send a message through ServerChan (the SendKey is read from the environment).

    The SendKey is taken from the ``SERVERCHAN_SEND_KEY`` environment variable
    and the message is POSTed as form data to ``https://sctapi.ftqq.com``.

    :param title: Message title (sent as the ``text`` form field).
    :param content: Message body (sent as the ``desp`` form field; Markdown
        is supported according to the original documentation of this helper).
    :param timeout: Seconds to wait for the HTTP request before giving up.
        Without a timeout a stalled connection would block the caller
        (e.g. the end of a training run) forever.
    :return: The response body decoded from JSON.
    :raises ValueError: If ``SERVERCHAN_SEND_KEY`` is unset or blank.
    """
    # Read the SendKey from the environment; surrounding whitespace (e.g. a
    # trailing newline from a secrets file) would otherwise corrupt the URL.
    send_key = os.environ.get("SERVERCHAN_SEND_KEY", "").strip()

    if not send_key:
        raise ValueError("❌ 未找到环境变量 SERVERCHAN_SEND_KEY,请检查配置")

    url = f"https://sctapi.ftqq.com/{send_key}.send"

    data = {
        "text": title,
        "desp": content,
    }

    # Bound the request so a network stall raises instead of hanging.
    response = requests.post(url, data=data, timeout=timeout)
    return response.json()
|
||||
|
||||
|
||||
class TrainingMonitor:
|
||||
"""
|
||||
实时训练监控图,支持任意多个指标,自动管理左右 Y 轴。
|
||||
|
|
|
|||
Loading…
Reference in New Issue