feat(model): 更新模型结构,使用 GELU 激活函数并优化专家网络参数

This commit is contained in:
songsen, 2026-02-25 16:56:09 +08:00
parent db90516fcf
commit 93dced50c7
1 changed file with 190 additions and 147 deletions

View File

@ -1,3 +1,4 @@
import math
import pickle
from importlib.resources import files
from pathlib import Path
@ -12,9 +13,10 @@ from loguru import logger
from modelscope import AutoModel, AutoTokenizer
from tqdm.autonotebook import tqdm
from .monitor import TrainingMonitor, send_serverchan_message
from suinput.dataset import PG
from .monitor import TrainingMonitor, send_serverchan_message
def eval_dataloader(path: Union[str, Path] = (files(__package__) / "eval_dataset")):
    """Load the pickled evaluation batches stored under *path*.

    Parameters
    ----------
    path : Union[str, Path]
        Directory containing one pickled batch per ``*.pkl`` file.
        Defaults to the ``eval_dataset`` directory shipped inside this package.

    Returns
    -------
    list
        The unpickled batches, in sorted filename order.  Sorting makes the
        evaluation order deterministic across filesystems (raw ``glob``
        order is unspecified).

    NOTE(review): ``pickle.load`` is only safe because these files ship with
    the package; never point *path* at untrusted data.
    """
    batches = []
    for pkl_file in sorted(Path(path).glob("*.pkl")):
        # Context manager closes each handle; the original left every
        # file object opened by ``file.open("rb")`` unclosed.
        with pkl_file.open("rb") as fh:
            batches.append(pickle.load(fh))
    return batches
@ -22,61 +24,46 @@ def eval_dataloader(path: Union[str, Path] = (files(__package__) / "eval_dataset
# ---------------------------- 残差块 ----------------------------
class ResidualBlock(nn.Module):
def __init__(self, dim, dropout_prob=0.1):
def __init__(self, dim, dropout_prob=0.3):
super().__init__()
self.linear1 = nn.Linear(dim, dim)
self.ln1 = nn.LayerNorm(dim)
self.linear2 = nn.Linear(dim, dim)
self.ln2 = nn.LayerNorm(dim)
self.relu = nn.ReLU()
self.gelu = nn.GELU()
self.dropout = nn.Dropout(dropout_prob)
def forward(self, x):
residual = x
x = self.relu(self.linear1(x))
x = self.gelu(self.linear1(x))
x = self.ln1(x)
x = self.linear2(x)
x = self.ln2(x)
x = self.dropout(x)
x = x + residual
return self.relu(x)
return self.gelu(x)
# ---------------------------- 专家网络 ----------------------------
class Expert(nn.Module):
def __init__(
self,
input_dim, # 输入特征的维度大小
d_model=1024, # 模型内部的隐藏层维度默认为1024
num_resblocks=4, # 残差块的数量默认为4
output_multiplier=2, # 输出维度是输入维度的倍数默认为2倍
dropout_prob=0.1, # Dropout层的丢弃概率默认为0.1
input_dim,
d_model=768,
num_resblocks=3,
output_multiplier=1,
dropout_prob=0.3,
):
"""
初始化函数用于构建模型的各个层
参数说明
input_dim : 输入维度
d_model : 专家内部维度
output_multiplier : 输出维度 = input_dim * output_multiplier
dropout_prob : 残差块内 Dropout
"""
super().__init__() # 调用父类的初始化方法
self.input_dim = input_dim # 保存输入维度
self.d_model = d_model # 保存模型内部维度
self.output_dim = input_dim * output_multiplier # 计算并保存输出维度
# 输入映射input_dim -> d_model
super().__init__()
self.output_dim = input_dim * output_multiplier
self.linear_in = nn.Linear(input_dim, d_model)
# 残差堆叠
self.res_blocks = nn.ModuleList(
[ResidualBlock(d_model, dropout_prob) for _ in range(num_resblocks)]
)
# 输出映射d_model -> output_dim
self.output = nn.Sequential(
nn.Linear(d_model, d_model),
nn.ReLU(inplace=True),
nn.GELU(inplace=True),
nn.Dropout(dropout_prob),
nn.Linear(d_model, self.output_dim),
)
@ -91,7 +78,7 @@ class Expert(nn.Module):
class MoEModel(nn.Module):
def __init__(
self,
pretrained_model_name="iic/nlp_structbert_backbone_tiny_std",
pretrained_model_name="iic/nlp_structbert_backbone_lite_std",
num_classes=10018,
output_multiplier=2,
d_model=768,
@ -102,14 +89,15 @@ class MoEModel(nn.Module):
super().__init__()
self.output_multiplier = output_multiplier
# 1. 加载预训练 BERT仅保留 embeddings
# 1. 加载预训练 Embedding (Lite: hidden_size=512)
logger.info(f"Loading backbone: {pretrained_model_name}")
bert = AutoModel.from_pretrained(pretrained_model_name)
self.embedding = bert.embeddings
self.bert_config = bert.config
self.hidden_size = self.bert_config.hidden_size # BERT 隐层维度
self.device = None # 将在 to() 调用时设置
self.hidden_size = self.bert_config.hidden_size # 512
self.device = None
# 2. 4 层标准 Transformer Encoder(从 config 读取参数)
# 2. Transformer Encoder
encoder_layer = nn.TransformerEncoderLayer(
d_model=self.hidden_size,
nhead=8,
@ -120,63 +108,49 @@ class MoEModel(nn.Module):
)
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
# self.shared_resblocks = nn.ModuleList(
# [ResidualBlock(self.hidden_size, 0.1) for _ in range(4)]
# )
self.pooler = nn.AdaptiveAvgPool1d(1)
# self.linear = nn.Linear(self.hidden_size, self.hidden_size)
# 3. 专家层n个领域专家 + 1个共享专家
# 3. 专家系统
total_experts = num_domain_experts + num_shared_experts
self.experts = nn.ModuleList()
for i in range(total_experts):
# 领域专家 dropout=0.1,共享专家 dropout=0.2(您指定的更强正则)
dropout_prob = 0.1 if i < num_domain_experts else 0.2
expert = Expert(
dropout = 0.3 if i < num_domain_experts else 0.4
self.experts.append(
Expert(
input_dim=self.hidden_size,
d_model=d_model,
num_resblocks=num_resblocks,
output_multiplier=self.output_multiplier,
dropout_prob=dropout_prob,
dropout_prob=dropout,
)
)
self.experts.append(expert)
self.expert_bias = nn.Embedding(
total_experts, self.output_multiplier * self.hidden_size
total_experts, self.hidden_size * self.output_multiplier
)
# 4. 分类头
self.classifier = nn.Sequential(
nn.LayerNorm(self.output_multiplier * self.hidden_size),
nn.Linear(
self.output_multiplier * self.hidden_size,
self.output_multiplier * self.hidden_size,
),
nn.ReLU(inplace=True),
nn.Linear(
self.output_multiplier * self.hidden_size,
self.output_multiplier * self.hidden_size * 2,
),
nn.ReLU(inplace=True),
nn.Linear(self.output_multiplier * self.hidden_size * 2, num_classes),
nn.LayerNorm(self.hidden_size * self.output_multiplier),
nn.Dropout(0.4),
nn.Linear(self.hidden_size * self.output_multiplier, num_classes),
)
# 可选:为领域专家和共享专家设置不同权重衰减(通过优化器实现,此处不处理)
def to(self, device):
    """Override ``nn.Module.to`` to also remember the target device.

    ``self.device`` is read later (e.g. when moving incoming batches in
    evaluation/training code) so data can be placed on the same device
    as the parameters.
    """
    self.device = device
    return super().to(device)
def forward(self, input_ids, attention_mask, pg):
def forward(self, input_ids, attention_mask, pg, p_start):
"""
input_ids : [batch, seq_len]
attention_mask: [batch, seq_len] (1 为有效0 padding)
pg : group_id训练时为 [batch] LongTensor推理导出时为标量 Tensor
ONNX 兼容的 Forward 函数
Args:
input_ids: [B, L]
attention_mask: [B, L]
pg: [B] 拼音组 ID
p_start: [B] 拼音起始索引位置 (整数 Tensor)
"""
# ----- 1. Embeddings -----
embeddings = self.embedding(input_ids) # [B, S, H]
embeddings = self.embedding(input_ids)
# ----- 2. Transformer Encoder -----
# padding mask: True 表示忽略该位置
@ -185,13 +159,53 @@ class MoEModel(nn.Module):
embeddings, src_key_padding_mask=padding_mask
) # [B, S, H]
# ----- 3. 池化量 -----
# for block in self.shared_resblocks:
# encoded = block(encoded)
pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1)
# pooled = self.pooler(encoded.transpose(1, 2)) # [B, H, 2]
# pooled = pooled.flatten(1) # [B, H*2]
# pooled = self.linear(pooled)
# ----- 3. ONNX 兼容的 Span Pooling (向量化实现) -----
"""
思路
我们不能用循环去切片我们要构造一个 Mask 矩阵
目标对于每个样本 i生成一个长度为 L 的向量其中 p_start[i] < index < p_end[i] 的位置为 1其余为 0
步骤
1. 生成位置索引轴indices = [0, 1, 2, ..., L-1] (Shape: [L])
2. 扩展维度以匹配 Batch
indices: [1, L]
p_start: [B, 1]
p_end: [B, 1]
3. 逻辑比较 (Broadcasting)
mask = (indices > p_start) & (indices < p_end)
结果 Shape: [B, L] (Boolean)
4. 应用 Mask
masked_encoded = encoded * mask.unsqueeze(-1)
5. 求和并归一化
sum_vec = masked_encoded.sum(dim=1)
count = mask.sum(dim=1).clamp(min=1) # 防止除零
pooled = sum_vec / count
"""
B, L, H = encoded.shape
device = encoded.device
# 生成位置轴 [0, 1, ..., L-1]
positions = torch.arange(L, device=device).unsqueeze(0) # [1, L]
# 调整 p_start 形状为 [B, 1] 以便广播
p_start_exp = p_start.unsqueeze(1) # [B, 1]
span_mask = positions >= p_start_exp
# 转换为 Float 用于乘法
span_mask_float = span_mask.float() # [B, L]
# 应用 Mask
# encoded: [B, L, H] -> mask: [B, L, 1]
masked_encoded = encoded * span_mask_float.unsqueeze(-1)
# 求和
span_sum = masked_encoded.sum(dim=1) # [B, H]
# 计算有效长度 (防止除以 0)
span_count = span_mask_float.sum(dim=1, keepdim=True).clamp(min=1.0) # [B, 1]
# 平均池化
pooled = span_sum / span_count # [B, H]
# ----- 4. 专家路由(硬路由)-----
if torch.jit.is_tracing():
@ -245,7 +259,7 @@ class MoEModel(nn.Module):
)
elif group_id == 11:
expert_out = self.experts[11](pooled) + self.expert_bias(
torch.tensor(12, device=pooled.device)
torch.tensor(11, device=pooled.device)
)
else: # group_id == 12
expert_out = self.experts[12](pooled) + self.expert_bias(
@ -270,7 +284,7 @@ class MoEModel(nn.Module):
return probs
return logits
def model_eval(self, eval_dataloader, criterion=None):
def model_eval(self, eval_dataloader, criterion):
"""
在验证集上评估模型返回准确率和平均损失
@ -281,9 +295,6 @@ class MoEModel(nn.Module):
accuracy: float, 准确率
avg_loss: float, 平均损失
"""
if criterion is None:
criterion = nn.NLLLoss()
self.eval()
total_loss = 0.0
correct = 0
@ -295,11 +306,11 @@ class MoEModel(nn.Module):
input_ids = batch["hint"]["input_ids"].to(self.device)
attention_mask = batch["hint"]["attention_mask"].to(self.device)
pg = batch["pg"].to(self.device)
p_start = batch["p_start"].to(self.device)
labels = batch["char_id"].to(self.device)
# 前向传播
probs = self(input_ids, attention_mask, pg)
log_probs = torch.log(probs + 1e-12)
loss = criterion(log_probs, labels)
total_loss += loss.item() * labels.size(0)
@ -359,6 +370,7 @@ class MoEModel(nn.Module):
}
# 将拼音的第一个字符映射为 PG 中的索引并转换为张量
sample["pg"] = torch.tensor([PG[py[0]]])
sample["p_start"] = torch.tensor([len(text)])
return sample
def predict(self, text, py, tokenizer=None):
@ -366,7 +378,12 @@ class MoEModel(nn.Module):
基于输入的文本和拼音生成 sample 字典进行预测支持批量/单样本可选调试打印错误样本信息
参数
text : str
输入的文本
py : str
输入的拼音
tokenizer : Tokenizer, optional
用于分词的分词器默认为 None
debug : bool
是否打印预测错误的样本信息
@ -421,40 +438,18 @@ class MoEModel(nn.Module):
eval_frequency=500,
grad_accum_steps=1, # 梯度累积步数
clip_grad_norm=1.0, # 梯度裁剪的范数
mixed_precision=False, # 是否使用混合精度训练
loss_weight=None, # 损失权重,用于处理类别不平衡
lr=6e-5, # 初始学习率
lr_schedule=None, # 新增:可选的自定义学习率调度函数
loss_weight=None,
mixed_precision=True,
weight_decay=0.1,
warmup_ratio=0.1,
label_smoothing=0.15,
lr=1e-4,
):
"""
训练模型支持混合精度梯度累积学习率调度实时监控
参数
train_dataloader: DataLoader
训练数据加载器
eval_dataloader: DataLoader, optional
评估数据加载器
monitor: TrainingMonitor, optional
训练监控器
criterion: nn.Module, optional
损失函数
optimizer: optim.Optimizer, optional
优化器
num_epochs: int, optional
训练轮数
eval_frequency: int, optional
评估频率
grad_accum_steps: int, optional
梯度累积步数
clip_grad_norm: float, optional
梯度裁剪范数
mixed_precision: bool, optional
是否使用混合精度
lr: float, optional
初始学习率
lr_schedule : callable, optional
自定义学习率调度函数接收参数 (processed_batches, optimizer)
可在内部直接修改 optimizer.param_groups 中的学习率
# TODO: 添加参数注释
"""
# 确保模型在正确的设备上GPU或CPU
if self.device is None:
@ -462,48 +457,61 @@ class MoEModel(nn.Module):
self.to(self.device)
# 切换到训练模式
super().train()
self.train()
# 默认优化器设置
if optimizer is None:
optimizer = optim.AdamW(self.parameters(), lr=lr)
optimizer = optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)
# 损失函数设置
if criterion is None:
if loss_weight is not None:
criterion = nn.CrossEntropyLoss(weight=loss_weight)
criterion = nn.CrossEntropyLoss(
weight=loss_weight, label_smoothing=label_smoothing
)
else:
criterion = nn.CrossEntropyLoss()
criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
# 混合精度缩放器
scaler = amp.GradScaler(enabled=mixed_precision)
global_step = 0
total_steps = stop_batch
warmup_steps = int(total_steps * warmup_ratio)
logger.info(f"Training Start: Steps={total_steps}, Warmup={warmup_steps}")
processed_batches = 0 # 新增:实际处理的 batch 数量计数器
batch_loss_sum = 0.0
optimizer.zero_grad()
for epoch in range(num_epochs):
for batch_idx, batch in enumerate(tqdm(train_dataloader, total=stop_batch)):
# ---------- 更新 batch 计数器 ----------
for batch_idx, batch in enumerate(
tqdm(train_dataloader, total=int(stop_batch))
):
processed_batches += 1
# ---------- 学习率调度(仅当使用默认优化器且未传入自定义调度函数时)----------
if lr_schedule is not None:
# 调用用户自定义的调度函数
lr_schedule(processed_batches, optimizer)
# LR Schedule
if processed_batches < warmup_steps:
current_lr = lr * (processed_batches / warmup_steps)
else:
progress = (processed_batches - warmup_steps) / (
total_steps - warmup_steps
)
current_lr = lr * (0.5 * (1.0 + math.cos(math.pi * progress)))
for param_group in optimizer.param_groups:
param_group["lr"] = current_lr
# ---------- 移动数据 ----------
input_ids = batch["hint"]["input_ids"].to(self.device)
attention_mask = batch["hint"]["attention_mask"].to(self.device)
pg = batch["pg"].to(self.device)
p_start = batch["p_start"].to(self.device) # [B]
labels = batch["char_id"].to(self.device)
# 混合精度前向
with amp.autocast(
# Forward
with torch.amp.autocast(
device_type=self.device.type, enabled=mixed_precision
):
logits = self(input_ids, attention_mask, pg)
logits = self(input_ids, attention_mask, pg, p_start)
loss = criterion(logits, labels)
loss = loss / grad_accum_steps
@ -511,28 +519,34 @@ class MoEModel(nn.Module):
scaler.scale(loss).backward()
# 梯度累积
if (batch_idx) % grad_accum_steps == 0:
if (processed_batches) % grad_accum_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
has_nan = False
for p in self.parameters():
if p.grad is not None and torch.isnan(p.grad).any():
has_nan = True
break
if not has_nan:
scaler.step(optimizer)
scaler.update()
else:
logger.warning("NaN detected, skipping step.")
optimizer.zero_grad()
original_loss = loss.item() * grad_accum_steps
batch_loss_sum += original_loss
# 周期性评估(与原代码相同)
if (
eval_dataloader is not None
and global_step % eval_frequency == 0
):
avg_loss = batch_loss_sum / eval_frequency
batch_loss_sum += loss.item() * grad_accum_steps
# 周期性评估
if eval_dataloader and global_step % eval_frequency == 0:
self.eval()
acc, eval_loss = self.model_eval(eval_dataloader, criterion)
if global_step == 0:
avg_loss = eval_loss
super().train()
if monitor is not None:
self.train()
if monitor:
monitor.add_step(
global_step,
{"loss": avg_loss, "acc": acc},
global_step, {"loss": avg_loss, "acc": acc}
)
logger.info(
f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc:.4f}, eval_loss: {eval_loss:.4f}"
@ -541,8 +555,9 @@ class MoEModel(nn.Module):
if processed_batches - 1 >= stop_batch:
break
global_step += 1
res_acc, res_loss = self.model_eval(eval_dataloader)
try:
res_acc, res_loss = self.model_eval(eval_dataloader, criterion)
to_wechat_response = send_serverchan_message(
title="训练完成",
content=f"训练完成acc: {res_acc:.4f}, loss: {res_loss:.4f}",
@ -582,3 +597,31 @@ class MoEModel(nn.Module):
for name, param in self.named_parameters():
if name in freeze_layers:
param.requires_grad = False
# --- ONNX export helper ---
def export_onnx(self, output_path, dummy_input):
    """Export the model to ONNX with dynamic batch/sequence axes.

    Args:
        output_path: Destination path for the ``.onnx`` file.
        dummy_input: A dict or tuple providing
            ``(input_ids, attention_mask, pg, p_start)`` used for tracing;
            shapes only need to be representative, since the batch and
            sequence axes are declared dynamic below.
    """
    self.eval()
    input_names = ["input_ids", "attention_mask", "pg", "p_start"]
    output_names = ["logits"]
    torch.onnx.export(
        self,
        dummy_input,
        output_path,
        input_names=input_names,
        output_names=output_names,
        # Batch dimension is dynamic for every input/output; sequence
        # length is additionally dynamic for the token-level inputs.
        dynamic_axes={
            "input_ids": {0: "batch_size", 1: "seq_len"},
            "attention_mask": {0: "batch_size", 1: "seq_len"},
            "pg": {0: "batch_size"},
            "p_start": {0: "batch_size"},
            "logits": {0: "batch_size"},
        },
        opset_version=14,  # 14+ recommended for better operator support
        do_constant_folding=True,
    )
    logger.info(f"ONNX model exported to {output_path}")