diff --git a/src/trainer/model.py b/src/trainer/model.py index af3a24a..1fe3ac6 100644 --- a/src/trainer/model.py +++ b/src/trainer/model.py @@ -120,9 +120,13 @@ class MoEModel(nn.Module): norm_first=True, # Pre-LN,与预训练一致 ) self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4) - self.pooler = nn.AdaptiveAvgPool1d(2) - self.linear = nn.Linear(self.hidden_size * 2, self.hidden_size) + self.shared_resblocks = nn.ModuleList( + [ResidualBlock(self.hidden_size, 0.1) for _ in range(6)] + ) + self.pooler = nn.AdaptiveAvgPool1d(1) + + # self.linear = nn.Linear(self.hidden_size, self.hidden_size) # 3. 专家层:8个领域专家 + 1个共享专家 @@ -191,9 +195,12 @@ class MoEModel(nn.Module): ) # [B, S, H] # ----- 3. 池化量 ----- - pooled = self.pooler(encoded.transpose(1, 2)) # [B, H, 2] - pooled = pooled.flatten(1) # [B, H*2] - pooled = self.linear(pooled) + for block in self.shared_resblocks: + encoded = block(encoded) + pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1) + # pooled = self.pooler(encoded.transpose(1, 2)) # [B, H, 2] + # pooled = pooled.flatten(1) # [B, H*2] + # pooled = self.linear(pooled) # ----- 4. 专家路由(硬路由)----- if torch.jit.is_tracing():