diff --git a/pyproject.toml b/pyproject.toml index 9a28047..15389ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "numpy>=2.4.2", "pandas>=3.0.0", "pypinyin>=0.55.0", + "requests>=2.32.5", "rich>=14.3.1", "transformers>=5.1.0", "typer>=0.21.1", diff --git a/src/trainer/model.py b/src/trainer/model.py index 98633b9..cabfb6a 100644 --- a/src/trainer/model.py +++ b/src/trainer/model.py @@ -12,7 +12,7 @@ from loguru import logger from modelscope import AutoModel, AutoTokenizer from tqdm.autonotebook import tqdm -from .monitor import TrainingMonitor +from .monitor import TrainingMonitor, send_serverchan_message from suinput.dataset import PG @@ -282,7 +282,7 @@ class MoEModel(nn.Module): avg_loss: float, 平均损失 """ if criterion is None: - criterion = nn.CrossEntropyLoss() + criterion = nn.NLLLoss() self.eval() total_loss = 0.0 @@ -300,7 +300,7 @@ class MoEModel(nn.Module): # 前向传播 probs = self(input_ids, attention_mask, pg) log_probs = torch.log(probs + 1e-12) - loss = nn.NLLLoss()(log_probs, labels) + loss = criterion(log_probs, labels) total_loss += loss.item() * labels.size(0) # 计算准确率 @@ -541,6 +541,16 @@ class MoEModel(nn.Module): if processed_batches - 1 >= stop_batch: break global_step += 1 + res_acc, res_loss = self.model_eval(eval_dataloader) + try: + to_wechat_response = send_serverchan_message( + title="训练完成", + content=f"训练完成,acc: {res_acc:.4f}, loss: {res_loss:.4f}", + ) + logger.info(f"训练完成,acc: {res_acc:.4f}, loss: {res_loss:.4f}") + logger.info(f"发送消息: {to_wechat_response}") + except Exception as e: + logger.error(f"发送消息失败: {e}") def load_from_state_dict(self, state_dict_path: Union[str, Path]): state_dict = torch.load( diff --git a/src/trainer/monitor.py b/src/trainer/monitor.py index d2f4de0..00859d2 100644 --- a/src/trainer/monitor.py +++ b/src/trainer/monitor.py @@ -1,11 +1,40 @@ +import os + from bokeh.io import push_notebook, show, output_notebook from bokeh.plotting import figure from bokeh.models import 
def send_serverchan_message(title, content="", *, timeout=10):
    """
    Send a notification message through the ServerChan HTTP API.

    The SendKey is read from the ``SERVERCHAN_SEND_KEY`` environment
    variable and embedded in the request URL.

    :param title: Message title (sent as the ``text`` form field).
    :param content: Message body (sent as the ``desp`` form field).
    :param timeout: Seconds to wait for the server before giving up; passed
        straight to ``requests.post``. Keyword-only, defaults to 10.
    :return: The response body decoded as JSON.
    :raises ValueError: If ``SERVERCHAN_SEND_KEY`` is unset or empty.
    """
    # Read the SendKey from the environment so it never lives in the code.
    send_key = os.environ.get("SERVERCHAN_SEND_KEY")
    if not send_key:
        raise ValueError("❌ 未找到环境变量 SERVERCHAN_SEND_KEY,请检查配置")

    # NOTE(review): the key is part of the URL, so exceptions raised by the
    # HTTP layer may include it in their message — be careful when logging.
    url = f"https://sctapi.ftqq.com/{send_key}.send"
    payload = {
        "text": title,
        "desp": content,
    }

    # Without a timeout a stalled connection would block the caller forever;
    # this is invoked at the end of a training run, so fail fast instead.
    response = requests.post(url, data=payload, timeout=timeout)
    return response.json()