From 96706abb938c852bf85d7931fbb6ca13d4046cb2 Mon Sep 17 00:00:00 2001 From: songsenand Date: Sun, 22 Feb 2026 10:20:17 +0800 Subject: [PATCH] =?UTF-8?q?feat(model):=20=E4=BF=AE=E6=94=B9=E6=B1=A0?= =?UTF-8?q?=E5=8C=96=E5=B1=82=E8=BE=93=E5=87=BA=E7=BB=B4=E5=BA=A6=E5=B9=B6?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E7=BA=BF=E6=80=A7=E5=8F=98=E6=8D=A2=E5=B1=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/trainer/model.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/trainer/model.py b/src/trainer/model.py index eb04e65..af3a24a 100644 --- a/src/trainer/model.py +++ b/src/trainer/model.py @@ -120,7 +120,10 @@ class MoEModel(nn.Module): norm_first=True, # Pre-LN,与预训练一致 ) self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4) - self.pooler = nn.AdaptiveAvgPool1d(1) + self.pooler = nn.AdaptiveAvgPool1d(2) + + self.linear = nn.Linear(self.hidden_size * 2, self.hidden_size) + # 3. 专家层:8个领域专家 + 1个共享专家 total_experts = num_domain_experts + num_shared_experts @@ -138,6 +141,7 @@ class MoEModel(nn.Module): ) self.experts.append(expert) + self.expert_bias = nn.Embedding( total_experts, self.output_multiplier * self.hidden_size ) @@ -187,7 +191,9 @@ class MoEModel(nn.Module): ) # [B, S, H] # ----- 3. 池化量 ----- - pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1) + pooled = self.pooler(encoded.transpose(1, 2)) # [B, H, 2] + pooled = pooled.flatten(1) # [B, H*2] + pooled = self.linear(pooled) # ----- 4. 专家路由(硬路由)----- if torch.jit.is_tracing():