From 3bb44f1d73952e900ec1654f0f57df3bcceb2600 Mon Sep 17 00:00:00 2001 From: songsenand Date: Sun, 22 Feb 2026 11:02:01 +0800 Subject: [PATCH] =?UTF-8?q?feat(model):=20=E6=B7=BB=E5=8A=A0=E5=85=B1?= =?UTF-8?q?=E4=BA=AB=E6=AE=8B=E5=B7=AE=E5=9D=97=E5=B9=B6=E8=B0=83=E6=95=B4?= =?UTF-8?q?=E6=B1=A0=E5=8C=96=E5=B1=82=E8=BE=93=E5=87=BA=E7=BB=B4=E5=BA=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/trainer/model.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/trainer/model.py b/src/trainer/model.py index af3a24a..1fe3ac6 100644 --- a/src/trainer/model.py +++ b/src/trainer/model.py @@ -120,9 +120,13 @@ class MoEModel(nn.Module): norm_first=True, # Pre-LN,与预训练一致 ) self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4) - self.pooler = nn.AdaptiveAvgPool1d(2) - self.linear = nn.Linear(self.hidden_size * 2, self.hidden_size) + self.shared_resblocks = nn.ModuleList( + [ResidualBlock(self.hidden_size, 0.1) for _ in range(6)] + ) + self.pooler = nn.AdaptiveAvgPool1d(1) + + # self.linear = nn.Linear(self.hidden_size, self.hidden_size) # 3. 专家层:8个领域专家 + 1个共享专家 @@ -191,9 +195,12 @@ class MoEModel(nn.Module): ) # [B, S, H] # ----- 3. 池化量 ----- - pooled = self.pooler(encoded.transpose(1, 2)) # [B, H, 2] - pooled = pooled.flatten(1) # [B, H*2] - pooled = self.linear(pooled) + for block in self.shared_resblocks: + encoded = block(encoded) + pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1) + # pooled = self.pooler(encoded.transpose(1, 2)) # [B, H, 2] + # pooled = pooled.flatten(1) # [B, H*2] + # pooled = self.linear(pooled) # ----- 4. 专家路由(硬路由)----- if torch.jit.is_tracing():