From fd913748ca116d9bb6a7054453f78872dfb41964 Mon Sep 17 00:00:00 2001 From: songsenand Date: Sun, 15 Feb 2026 00:08:44 +0800 Subject: [PATCH] =?UTF-8?q?=E8=B0=83=E6=95=B4=E6=AE=8B=E5=B7=AE=E5=9D=97?= =?UTF-8?q?=E5=92=8C=E5=88=86=E7=B1=BB=E5=A4=B4=E7=9A=84=20dropout=20?= =?UTF-8?q?=E6=A6=82=E7=8E=87=EF=BC=8C=E5=B9=B6=E6=96=B0=E5=A2=9E=E6=AE=8B?= =?UTF-8?q?=E5=B7=AE=E6=A8=A1=E5=9D=97=E5=88=B0=20MoE=20=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/trainer/model_with_neck.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/trainer/model_with_neck.py b/src/trainer/model_with_neck.py index f08819f..ceec383 100644 --- a/src/trainer/model_with_neck.py +++ b/src/trainer/model_with_neck.py @@ -58,7 +58,7 @@ EXPORT_HIDE_DIM = { # ---------------------------- 残差块 ---------------------------- class ResidualBlock(nn.Module): - def __init__(self, dim, dropout_prob=0.1): + def __init__(self, dim, dropout_prob=0.0): super().__init__() self.linear1 = nn.Linear(dim, dim) self.ln1 = nn.LayerNorm(dim) @@ -73,7 +73,7 @@ class ResidualBlock(nn.Module): x = self.ln1(x) x = self.linear2(x) x = self.ln2(x) - x = self.dropout(x) # 残差前加 Dropout(符合原描述) + x = self.dropout(x) x = x + residual return self.relu(x) @@ -86,7 +86,7 @@ class Expert(nn.Module): d_model=1024, num_resblocks=4, output_multiplier=2, - dropout_prob=0.1, + dropout_prob=0.0, ): """ input_dim : BERT 输出的 hidden_size(如 312/768) @@ -156,6 +156,8 @@ class MoEModel(nn.Module): self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4) self.pooler = nn.AdaptiveAvgPool1d(1) + self.res_blocks = nn.ModuleList([ResidualBlock(self.hidden_size) for _ in range(4)]) + self.total_experts = 20 self.experts = nn.ModuleList() @@ -175,6 +177,7 @@ class MoEModel(nn.Module): # 4. 分类头 self.classifier = nn.Sequential( + nn.Dropout(0.2), nn.LayerNorm(self.output_multiplier * self.hidden_size), nn.Linear( self.output_multiplier * self.hidden_size, @@ -186,10 +189,8 @@ class MoEModel(nn.Module): self.output_multiplier * self.hidden_size * 2, ), nn.ReLU(inplace=True), - nn.Dropout(0.2), nn.Linear(self.output_multiplier * self.hidden_size * 2, num_classes), ) - # 可选:为领域专家和共享专家设置不同权重衰减(通过优化器实现,此处不处理) def to(self, device): """重写 to 方法,记录设备""" @@ -212,6 +213,9 @@ class MoEModel(nn.Module): embeddings, src_key_padding_mask=padding_mask ) # [B, S, H] + for block in self.res_blocks: + encoded = block(encoded) + # ----- 3. 池化量 ----- pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1)