feat(model): 更新模型结构,使用 GELU 激活函数并优化专家网络参数

This commit is contained in:
songsen, 2026-02-25 16:56:09 +08:00
parent db90516fcf
commit 93dced50c7
1 changed files with 190 additions and 147 deletions

View File

@ -1,3 +1,4 @@
import math
import pickle import pickle
from importlib.resources import files from importlib.resources import files
from pathlib import Path from pathlib import Path
@ -12,9 +13,10 @@ from loguru import logger
from modelscope import AutoModel, AutoTokenizer from modelscope import AutoModel, AutoTokenizer
from tqdm.autonotebook import tqdm from tqdm.autonotebook import tqdm
from .monitor import TrainingMonitor, send_serverchan_message
from suinput.dataset import PG from suinput.dataset import PG
from .monitor import TrainingMonitor, send_serverchan_message
def eval_dataloader(path: Union[str, Path] = (files(__package__) / "eval_dataset")):
    """Load every pickled evaluation batch found under *path*.

    Args:
        path: directory containing ``*.pkl`` files. Defaults to the
            package-bundled ``eval_dataset`` resource.

    Returns:
        list: one unpickled object per ``*.pkl`` file (order follows
        ``Path.glob`` and is not guaranteed to be sorted).
    """
    # read_bytes() opens and closes each file itself, so no handles leak
    # (the previous pickle.load(file.open("rb")) never closed them).
    # NOTE: pickle is acceptable here only because the files ship with
    # the package; never point this at untrusted data.
    return [pickle.loads(p.read_bytes()) for p in Path(path).glob("*.pkl")]
@ -22,61 +24,46 @@ def eval_dataloader(path: Union[str, Path] = (files(__package__) / "eval_dataset
# ---------------------------- 残差块 ---------------------------- # ---------------------------- 残差块 ----------------------------
# ---------------------------- Residual block ----------------------------
class ResidualBlock(nn.Module):
    """Residual MLP block: two Linear + LayerNorm stages with GELU and
    Dropout, wrapped by an identity skip connection and a final GELU."""

    def __init__(self, dim, dropout_prob=0.3):
        super().__init__()
        # Attribute names (and creation order) are kept stable so that
        # existing checkpoints load and seeded init stays reproducible.
        self.linear1 = nn.Linear(dim, dim)
        self.ln1 = nn.LayerNorm(dim)
        self.linear2 = nn.Linear(dim, dim)
        self.ln2 = nn.LayerNorm(dim)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        # Stage 1: linear -> GELU -> norm.
        hidden = self.ln1(self.gelu(self.linear1(x)))
        # Stage 2: linear -> norm -> dropout.
        hidden = self.dropout(self.ln2(self.linear2(hidden)))
        # Add the skip connection, then apply the output non-linearity.
        return self.gelu(hidden + x)
# ---------------------------- 专家网络 ---------------------------- # ---------------------------- 专家网络 ----------------------------
# ---------------------------- Expert network ----------------------------
class Expert(nn.Module):
    """One expert: input projection -> stacked residual blocks -> output head.

    Args:
        input_dim: dimensionality of the incoming features.
        d_model: internal hidden width of the expert.
        num_resblocks: number of ResidualBlock layers stacked in the trunk.
        output_multiplier: output width = input_dim * output_multiplier.
        dropout_prob: dropout probability used inside the residual blocks
            and in the output head.
    """

    def __init__(
        self,
        input_dim,
        d_model=768,
        num_resblocks=3,
        output_multiplier=1,
        dropout_prob=0.3,
    ):
        super().__init__()
        self.output_dim = input_dim * output_multiplier

        # Input projection: input_dim -> d_model.
        self.linear_in = nn.Linear(input_dim, d_model)

        # Residual trunk.
        self.res_blocks = nn.ModuleList(
            [ResidualBlock(d_model, dropout_prob) for _ in range(num_resblocks)]
        )

        # Output head: d_model -> output_dim.
        # FIX: nn.GELU accepts no `inplace` argument (unlike nn.ReLU), so
        # nn.GELU(inplace=True) raised TypeError at construction time.
        self.output = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout_prob),
            nn.Linear(d_model, self.output_dim),
        )
@ -91,7 +78,7 @@ class Expert(nn.Module):
class MoEModel(nn.Module): class MoEModel(nn.Module):
def __init__( def __init__(
self, self,
pretrained_model_name="iic/nlp_structbert_backbone_tiny_std", pretrained_model_name="iic/nlp_structbert_backbone_lite_std",
num_classes=10018, num_classes=10018,
output_multiplier=2, output_multiplier=2,
d_model=768, d_model=768,
@ -102,14 +89,15 @@ class MoEModel(nn.Module):
super().__init__() super().__init__()
self.output_multiplier = output_multiplier self.output_multiplier = output_multiplier
# 1. 加载预训练 BERT仅保留 embeddings # 1. 加载预训练 Embedding (Lite: hidden_size=512)
logger.info(f"Loading backbone: {pretrained_model_name}")
bert = AutoModel.from_pretrained(pretrained_model_name) bert = AutoModel.from_pretrained(pretrained_model_name)
self.embedding = bert.embeddings self.embedding = bert.embeddings
self.bert_config = bert.config self.bert_config = bert.config
self.hidden_size = self.bert_config.hidden_size # BERT 隐层维度 self.hidden_size = self.bert_config.hidden_size # 512
self.device = None # 将在 to() 调用时设置 self.device = None
# 2. 4 层标准 Transformer Encoder(从 config 读取参数) # 2. Transformer Encoder
encoder_layer = nn.TransformerEncoderLayer( encoder_layer = nn.TransformerEncoderLayer(
d_model=self.hidden_size, d_model=self.hidden_size,
nhead=8, nhead=8,
@ -120,78 +108,104 @@ class MoEModel(nn.Module):
) )
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4) self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
# self.shared_resblocks = nn.ModuleList( # 3. 专家系统
# [ResidualBlock(self.hidden_size, 0.1) for _ in range(4)]
# )
self.pooler = nn.AdaptiveAvgPool1d(1)
# self.linear = nn.Linear(self.hidden_size, self.hidden_size)
# 3. 专家层n个领域专家 + 1个共享专家
total_experts = num_domain_experts + num_shared_experts total_experts = num_domain_experts + num_shared_experts
self.experts = nn.ModuleList() self.experts = nn.ModuleList()
for i in range(total_experts): for i in range(total_experts):
# 领域专家 dropout=0.1,共享专家 dropout=0.2(您指定的更强正则) dropout = 0.3 if i < num_domain_experts else 0.4
dropout_prob = 0.1 if i < num_domain_experts else 0.2 self.experts.append(
expert = Expert( Expert(
input_dim=self.hidden_size, input_dim=self.hidden_size,
d_model=d_model, d_model=d_model,
num_resblocks=num_resblocks, num_resblocks=num_resblocks,
output_multiplier=self.output_multiplier, output_multiplier=self.output_multiplier,
dropout_prob=dropout_prob, dropout_prob=dropout,
)
) )
self.experts.append(expert)
self.expert_bias = nn.Embedding( self.expert_bias = nn.Embedding(
total_experts, self.output_multiplier * self.hidden_size total_experts, self.hidden_size * self.output_multiplier
) )
# 4. 分类头 # 4. 分类头
self.classifier = nn.Sequential( self.classifier = nn.Sequential(
nn.LayerNorm(self.output_multiplier * self.hidden_size), nn.LayerNorm(self.hidden_size * self.output_multiplier),
nn.Linear( nn.Dropout(0.4),
self.output_multiplier * self.hidden_size, nn.Linear(self.hidden_size * self.output_multiplier, num_classes),
self.output_multiplier * self.hidden_size,
),
nn.ReLU(inplace=True),
nn.Linear(
self.output_multiplier * self.hidden_size,
self.output_multiplier * self.hidden_size * 2,
),
nn.ReLU(inplace=True),
nn.Linear(self.output_multiplier * self.hidden_size * 2, num_classes),
) )
# 可选:为领域专家和共享专家设置不同权重衰减(通过优化器实现,此处不处理)
def to(self, device):
    """Override nn.Module.to so the target device is remembered.

    self.device is read later (e.g. in train()/model_eval()) to move
    incoming batches onto the same device as the parameters.
    """
    self.device = device
    return super().to(device)
def forward(self, input_ids, attention_mask, pg): def forward(self, input_ids, attention_mask, pg, p_start):
""" """
input_ids : [batch, seq_len] ONNX 兼容的 Forward 函数
attention_mask: [batch, seq_len] (1 为有效0 padding)
pg : group_id训练时为 [batch] LongTensor推理导出时为标量 Tensor Args:
input_ids: [B, L]
attention_mask: [B, L]
pg: [B] 拼音组 ID
p_start: [B] 拼音起始索引位置 (整数 Tensor)
""" """
# ----- 1. Embeddings ----- # ----- 1. Embeddings -----
embeddings = self.embedding(input_ids) # [B, S, H] embeddings = self.embedding(input_ids)
# ----- 2. Transformer Encoder ----- # ----- 2. Transformer Encoder -----
# padding mask: True 表示忽略该位置 # padding mask: True 表示忽略该位置
padding_mask = attention_mask == 0 padding_mask = attention_mask == 0
encoded = self.encoder( encoded = self.encoder(
embeddings , src_key_padding_mask=padding_mask embeddings, src_key_padding_mask=padding_mask
) # [B, S, H] ) # [B, S, H]
# ----- 3. 池化量 ----- # ----- 3. ONNX 兼容的 Span Pooling (向量化实现) -----
# for block in self.shared_resblocks: """
# encoded = block(encoded) 思路
pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1) 我们不能用循环去切片我们要构造一个 Mask 矩阵
# pooled = self.pooler(encoded.transpose(1, 2)) # [B, H, 2] 目标对于每个样本 i生成一个长度为 L 的向量其中 p_start[i] < index < p_end[i] 的位置为 1其余为 0
# pooled = pooled.flatten(1) # [B, H*2]
# pooled = self.linear(pooled) 步骤
1. 生成位置索引轴indices = [0, 1, 2, ..., L-1] (Shape: [L])
2. 扩展维度以匹配 Batch
indices: [1, L]
p_start: [B, 1]
p_end: [B, 1]
3. 逻辑比较 (Broadcasting)
mask = (indices > p_start) & (indices < p_end)
结果 Shape: [B, L] (Boolean)
4. 应用 Mask
masked_encoded = encoded * mask.unsqueeze(-1)
5. 求和并归一化
sum_vec = masked_encoded.sum(dim=1)
count = mask.sum(dim=1).clamp(min=1) # 防止除零
pooled = sum_vec / count
"""
B, L, H = encoded.shape
device = encoded.device
# 生成位置轴 [0, 1, ..., L-1]
positions = torch.arange(L, device=device).unsqueeze(0) # [1, L]
# 调整 p_start 形状为 [B, 1] 以便广播
p_start_exp = p_start.unsqueeze(1) # [B, 1]
span_mask = positions >= p_start_exp
# 转换为 Float 用于乘法
span_mask_float = span_mask.float() # [B, L]
# 应用 Mask
# encoded: [B, L, H] -> mask: [B, L, 1]
masked_encoded = encoded * span_mask_float.unsqueeze(-1)
# 求和
span_sum = masked_encoded.sum(dim=1) # [B, H]
# 计算有效长度 (防止除以 0)
span_count = span_mask_float.sum(dim=1, keepdim=True).clamp(min=1.0) # [B, 1]
# 平均池化
pooled = span_sum / span_count # [B, H]
# ----- 4. 专家路由(硬路由)----- # ----- 4. 专家路由(硬路由)-----
if torch.jit.is_tracing(): if torch.jit.is_tracing():
@ -245,7 +259,7 @@ class MoEModel(nn.Module):
) )
elif group_id == 11: elif group_id == 11:
expert_out = self.experts[11](pooled) + self.expert_bias( expert_out = self.experts[11](pooled) + self.expert_bias(
torch.tensor(12, device=pooled.device) torch.tensor(11, device=pooled.device)
) )
else: # group_id == 12 else: # group_id == 12
expert_out = self.experts[12](pooled) + self.expert_bias( expert_out = self.experts[12](pooled) + self.expert_bias(
@ -270,7 +284,7 @@ class MoEModel(nn.Module):
return probs return probs
return logits return logits
def model_eval(self, eval_dataloader, criterion=None): def model_eval(self, eval_dataloader, criterion):
""" """
在验证集上评估模型返回准确率和平均损失 在验证集上评估模型返回准确率和平均损失
@ -281,9 +295,6 @@ class MoEModel(nn.Module):
accuracy: float, 准确率 accuracy: float, 准确率
avg_loss: float, 平均损失 avg_loss: float, 平均损失
""" """
if criterion is None:
criterion = nn.NLLLoss()
self.eval() self.eval()
total_loss = 0.0 total_loss = 0.0
correct = 0 correct = 0
@ -295,11 +306,11 @@ class MoEModel(nn.Module):
input_ids = batch["hint"]["input_ids"].to(self.device) input_ids = batch["hint"]["input_ids"].to(self.device)
attention_mask = batch["hint"]["attention_mask"].to(self.device) attention_mask = batch["hint"]["attention_mask"].to(self.device)
pg = batch["pg"].to(self.device) pg = batch["pg"].to(self.device)
p_start = batch["p_start"].to(self.device)
labels = batch["char_id"].to(self.device) labels = batch["char_id"].to(self.device)
# 前向传播 # 前向传播
probs = self(input_ids, attention_mask, pg) probs = self(input_ids, attention_mask, pg)
log_probs = torch.log(probs + 1e-12)
loss = criterion(log_probs, labels) loss = criterion(log_probs, labels)
total_loss += loss.item() * labels.size(0) total_loss += loss.item() * labels.size(0)
@ -359,6 +370,7 @@ class MoEModel(nn.Module):
} }
# 将拼音的第一个字符映射为 PG 中的索引并转换为张量 # 将拼音的第一个字符映射为 PG 中的索引并转换为张量
sample["pg"] = torch.tensor([PG[py[0]]]) sample["pg"] = torch.tensor([PG[py[0]]])
sample["p_start"] = torch.tensor([len(text)])
return sample return sample
def predict(self, text, py, tokenizer=None): def predict(self, text, py, tokenizer=None):
@ -366,7 +378,12 @@ class MoEModel(nn.Module):
基于输入的文本和拼音生成 sample 字典进行预测支持批量/单样本可选调试打印错误样本信息 基于输入的文本和拼音生成 sample 字典进行预测支持批量/单样本可选调试打印错误样本信息
参数 参数
text : str
输入的文本
py : str
输入的拼音
tokenizer : Tokenizer, optional
用于分词的分词器默认为 None
debug : bool debug : bool
是否打印预测错误的样本信息 是否打印预测错误的样本信息
@ -421,40 +438,18 @@ class MoEModel(nn.Module):
eval_frequency=500, eval_frequency=500,
grad_accum_steps=1, # 梯度累积步数 grad_accum_steps=1, # 梯度累积步数
clip_grad_norm=1.0, # 梯度裁剪的范数 clip_grad_norm=1.0, # 梯度裁剪的范数
mixed_precision=False, # 是否使用混合精度训练 loss_weight=None,
loss_weight=None, # 损失权重,用于处理类别不平衡 mixed_precision=True,
lr=6e-5, # 初始学习率 weight_decay=0.1,
lr_schedule=None, # 新增:可选的自定义学习率调度函数 warmup_ratio=0.1,
label_smoothing=0.15,
lr=1e-4,
): ):
""" """
训练模型支持混合精度梯度累积学习率调度实时监控 训练模型支持混合精度梯度累积学习率调度实时监控
参数 参数
train_dataloader: DataLoader # TODO: 添加参数注释
训练数据加载器
eval_dataloader: DataLoader, optional
评估数据加载器
monitor: TrainingMonitor, optional
训练监控器
criterion: nn.Module, optional
损失函数
optimizer: optim.Optimizer, optional
优化器
num_epochs: int, optional
训练轮数
eval_frequency: int, optional
评估频率
grad_accum_steps: int, optional
梯度累积步数
clip_grad_norm: float, optional
梯度裁剪范数
mixed_precision: bool, optional
是否使用混合精度
lr: float, optional
初始学习率
lr_schedule : callable, optional
自定义学习率调度函数接收参数 (processed_batches, optimizer)
可在内部直接修改 optimizer.param_groups 中的学习率
""" """
# 确保模型在正确的设备上GPU或CPU # 确保模型在正确的设备上GPU或CPU
if self.device is None: if self.device is None:
@ -462,48 +457,61 @@ class MoEModel(nn.Module):
self.to(self.device) self.to(self.device)
# 切换到训练模式 # 切换到训练模式
super().train() self.train()
# 默认优化器设置 # 默认优化器设置
if optimizer is None: if optimizer is None:
optimizer = optim.AdamW(self.parameters(), lr=lr) optimizer = optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)
# 损失函数设置 # 损失函数设置
if criterion is None: if criterion is None:
if loss_weight is not None: if loss_weight is not None:
criterion = nn.CrossEntropyLoss(weight=loss_weight) criterion = nn.CrossEntropyLoss(
weight=loss_weight, label_smoothing=label_smoothing
)
else: else:
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
# 混合精度缩放器 # 混合精度缩放器
scaler = amp.GradScaler(enabled=mixed_precision) scaler = amp.GradScaler(enabled=mixed_precision)
global_step = 0 total_steps = stop_batch
warmup_steps = int(total_steps * warmup_ratio)
logger.info(f"Training Start: Steps={total_steps}, Warmup={warmup_steps}")
processed_batches = 0 # 新增:实际处理的 batch 数量计数器 processed_batches = 0 # 新增:实际处理的 batch 数量计数器
batch_loss_sum = 0.0 batch_loss_sum = 0.0
optimizer.zero_grad() optimizer.zero_grad()
for epoch in range(num_epochs): for epoch in range(num_epochs):
for batch_idx, batch in enumerate(tqdm(train_dataloader, total=stop_batch)): for batch_idx, batch in enumerate(
# ---------- 更新 batch 计数器 ---------- tqdm(train_dataloader, total=int(stop_batch))
):
processed_batches += 1 processed_batches += 1
# ---------- 学习率调度(仅当使用默认优化器且未传入自定义调度函数时)---------- # LR Schedule
if lr_schedule is not None: if processed_batches < warmup_steps:
# 调用用户自定义的调度函数 current_lr = lr * (processed_batches / warmup_steps)
lr_schedule(processed_batches, optimizer) else:
progress = (processed_batches - warmup_steps) / (
total_steps - warmup_steps
)
current_lr = lr * (0.5 * (1.0 + math.cos(math.pi * progress)))
for param_group in optimizer.param_groups:
param_group["lr"] = current_lr
# ---------- 移动数据 ---------- # ---------- 移动数据 ----------
input_ids = batch["hint"]["input_ids"].to(self.device) input_ids = batch["hint"]["input_ids"].to(self.device)
attention_mask = batch["hint"]["attention_mask"].to(self.device) attention_mask = batch["hint"]["attention_mask"].to(self.device)
pg = batch["pg"].to(self.device) pg = batch["pg"].to(self.device)
p_start = batch["p_start"].to(self.device) # [B]
labels = batch["char_id"].to(self.device) labels = batch["char_id"].to(self.device)
# 混合精度前向 # 混合精度前向
with amp.autocast( # Forward
with torch.amp.autocast(
device_type=self.device.type, enabled=mixed_precision device_type=self.device.type, enabled=mixed_precision
): ):
logits = self(input_ids, attention_mask, pg) logits = self(input_ids, attention_mask, pg, p_start)
loss = criterion(logits, labels) loss = criterion(logits, labels)
loss = loss / grad_accum_steps loss = loss / grad_accum_steps
@ -511,38 +519,45 @@ class MoEModel(nn.Module):
scaler.scale(loss).backward() scaler.scale(loss).backward()
# 梯度累积 # 梯度累积
if (batch_idx) % grad_accum_steps == 0: if (processed_batches) % grad_accum_steps == 0:
scaler.unscale_(optimizer) scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm) torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
scaler.step(optimizer)
scaler.update() has_nan = False
for p in self.parameters():
if p.grad is not None and torch.isnan(p.grad).any():
has_nan = True
break
if not has_nan:
scaler.step(optimizer)
scaler.update()
else:
logger.warning("NaN detected, skipping step.")
optimizer.zero_grad() optimizer.zero_grad()
original_loss = loss.item() * grad_accum_steps batch_loss_sum += loss.item() * grad_accum_steps
batch_loss_sum += original_loss
# 周期性评估(与原代码相同) # 周期性评估
if ( if eval_dataloader and global_step % eval_frequency == 0:
eval_dataloader is not None self.eval()
and global_step % eval_frequency == 0
):
avg_loss = batch_loss_sum / eval_frequency
acc, eval_loss = self.model_eval(eval_dataloader, criterion) acc, eval_loss = self.model_eval(eval_dataloader, criterion)
if global_step == 0: if global_step == 0:
avg_loss = eval_loss avg_loss = eval_loss
super().train() self.train()
if monitor is not None: if monitor:
monitor.add_step( monitor.add_step(
global_step, global_step, {"loss": avg_loss, "acc": acc}
{"loss": avg_loss, "acc": acc},
)
logger.info(
f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc:.4f}, eval_loss: {eval_loss:.4f}"
) )
logger.info(
f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc:.4f}, eval_loss: {eval_loss:.4f}"
)
batch_loss_sum = 0.0 batch_loss_sum = 0.0
if processed_batches - 1 >= stop_batch: if processed_batches - 1 >= stop_batch:
break break
global_step += 1 global_step += 1
res_acc, res_loss = self.model_eval(eval_dataloader)
try: try:
res_acc, res_loss = self.model_eval(eval_dataloader, criterion)
to_wechat_response = send_serverchan_message( to_wechat_response = send_serverchan_message(
title="训练完成", title="训练完成",
content=f"训练完成acc: {res_acc:.4f}, loss: {res_loss:.4f}", content=f"训练完成acc: {res_acc:.4f}, loss: {res_loss:.4f}",
@ -582,3 +597,31 @@ class MoEModel(nn.Module):
for name, param in self.named_parameters(): for name, param in self.named_parameters():
if name in freeze_layers: if name in freeze_layers:
param.requires_grad = False param.requires_grad = False
# --- ONNX 导出辅助函数 ---
# --- ONNX export helper ---
def export_onnx(self, output_path, dummy_input):
    """Export the model to an ONNX file at *output_path*.

    dummy_input should be a tuple (or dict) of example inputs matching
    forward(): (input_ids, attention_mask, pg, p_start).
    """
    self.eval()  # disable dropout so the traced graph is deterministic
    dynamic_axes = {
        "input_ids": {0: "batch_size", 1: "seq_len"},
        "attention_mask": {0: "batch_size", 1: "seq_len"},
        "pg": {0: "batch_size"},
        "p_start": {0: "batch_size"},
        "logits": {0: "batch_size"},
    }
    torch.onnx.export(
        self,
        dummy_input,
        output_path,
        input_names=["input_ids", "attention_mask", "pg", "p_start"],
        output_names=["logits"],
        dynamic_axes=dynamic_axes,
        opset_version=14,  # 14+ recommended for better operator support
        do_constant_folding=True,
    )
    logger.info(f"ONNX model exported to {output_path}")