From 63efc49aa63e81a379c56b62eca870eed79ee2d1 Mon Sep 17 00:00:00 2001 From: songsenand Date: Tue, 24 Feb 2026 00:10:25 +0800 Subject: [PATCH] =?UTF-8?q?feat(model):=20=E6=B7=BB=E5=8A=A0=E8=AE=AD?= =?UTF-8?q?=E7=BB=83=E5=AE=8C=E6=88=90=E9=80=9A=E7=9F=A5=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=8C=E9=80=9A=E8=BF=87ServerChan=E5=8F=91=E9=80=81?= =?UTF-8?q?=E5=BE=AE=E4=BF=A1=E6=B6=88=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 1 + src/trainer/model.py | 16 +++++++++++++--- src/trainer/monitor.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9a28047..15389ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "numpy>=2.4.2", "pandas>=3.0.0", "pypinyin>=0.55.0", + "requests>=2.32.5", "rich>=14.3.1", "transformers>=5.1.0", "typer>=0.21.1", diff --git a/src/trainer/model.py b/src/trainer/model.py index 98633b9..cabfb6a 100644 --- a/src/trainer/model.py +++ b/src/trainer/model.py @@ -12,7 +12,7 @@ from loguru import logger from modelscope import AutoModel, AutoTokenizer from tqdm.autonotebook import tqdm -from .monitor import TrainingMonitor +from .monitor import TrainingMonitor, send_serverchan_message from suinput.dataset import PG @@ -282,7 +282,7 @@ class MoEModel(nn.Module): avg_loss: float, 平均损失 """ if criterion is None: - criterion = nn.CrossEntropyLoss() + criterion = nn.NLLLoss() self.eval() total_loss = 0.0 @@ -300,7 +300,7 @@ class MoEModel(nn.Module): # 前向传播 probs = self(input_ids, attention_mask, pg) log_probs = torch.log(probs + 1e-12) - loss = nn.NLLLoss()(log_probs, labels) + loss = criterion(log_probs, labels) total_loss += loss.item() * labels.size(0) # 计算准确率 @@ -541,6 +541,16 @@ class MoEModel(nn.Module): if processed_batches - 1 >= stop_batch: break global_step += 1 + res_acc, res_loss = self.model_eval(eval_dataloader) + try: + to_wechat_response = send_serverchan_message( + title="训练完成", + content=f"训练完成,acc: {res_acc:.4f}, loss: {res_loss:.4f}", + ) + logger.info(f"训练完成,acc: {res_acc:.4f}, loss: {res_loss:.4f}") + logger.info(f"发送消息: {to_wechat_response}") + except Exception as e: + logger.error(f"发送消息失败: {e}") def load_from_state_dict(self, state_dict_path: Union[str, Path]): state_dict = torch.load( diff --git a/src/trainer/monitor.py b/src/trainer/monitor.py index d2f4de0..00859d2 100644 --- a/src/trainer/monitor.py +++ b/src/trainer/monitor.py @@ -1,11 +1,40 @@ +import os + from bokeh.io import push_notebook, show, output_notebook from bokeh.plotting import figure from bokeh.models import ColumnDataSource, LinearAxis, Range1d import pandas as pd import numpy as np +import requests output_notebook() # 在 Jupyter 中必须调用一次 + +def send_serverchan_message(title, content=""): + """ + 向ServerChan发送消息(SendKey从环境变量读取) + + :param title: 消息标题 + :param content: 消息内容(支持Markdown) + :return: 响应JSON + """ + # 从环境变量获取SendKey + send_key = os.environ.get("SERVERCHAN_SEND_KEY") + + if not send_key: + raise ValueError("❌ 未找到环境变量 SERVERCHAN_SEND_KEY,请检查配置") + + url = f"https://sctapi.ftqq.com/{send_key}.send" + + data = { + "text": title, + "desp": content + } + + response = requests.post(url, data=data) + return response.json() + + class TrainingMonitor: """ 实时训练监控图,支持任意多个指标,自动管理左右 Y 轴。