diff --git a/src/trainer/model.py b/src/trainer/model.py index eb04e65..af3a24a 100644 --- a/src/trainer/model.py +++ b/src/trainer/model.py @@ -120,7 +120,10 @@ class MoEModel(nn.Module): norm_first=True, # Pre-LN,与预训练一致 ) self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4) - self.pooler = nn.AdaptiveAvgPool1d(1) + self.pooler = nn.AdaptiveAvgPool1d(2) + + self.linear = nn.Linear(self.hidden_size * 2, self.hidden_size) + # 3. 专家层:8个领域专家 + 1个共享专家 total_experts = num_domain_experts + num_shared_experts @@ -138,6 +141,7 @@ class MoEModel(nn.Module): ) self.experts.append(expert) + self.expert_bias = nn.Embedding( total_experts, self.output_multiplier * self.hidden_size ) @@ -187,7 +191,9 @@ class MoEModel(nn.Module): ) # [B, S, H] # ----- 3. 池化量 ----- - pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1) + pooled = self.pooler(encoded.transpose(1, 2)) # [B, H, 2] + pooled = pooled.flatten(1) # [B, H*2] + pooled = self.linear(pooled) # ----- 4. 专家路由(硬路由)----- if torch.jit.is_tracing():