feat(model): 添加共享残差块并调整池化层输出维度
This commit is contained in:
parent
96706abb93
commit
3bb44f1d73
|
|
@ -120,9 +120,13 @@ class MoEModel(nn.Module):
|
|||
norm_first=True, # Pre-LN,与预训练一致
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
|
||||
self.pooler = nn.AdaptiveAvgPool1d(2)
|
||||
|
||||
self.linear = nn.Linear(self.hidden_size * 2, self.hidden_size)
|
||||
self.shared_resblocks = nn.ModuleList(
|
||||
[ResidualBlock(self.hidden_size, 0.1) for _ in range(6)]
|
||||
)
|
||||
self.pooler = nn.AdaptiveAvgPool1d(1)
|
||||
|
||||
# self.linear = nn.Linear(self.hidden_size, self.hidden_size)
|
||||
|
||||
|
||||
# 3. 专家层:8个领域专家 + 1个共享专家
|
||||
|
|
@ -191,9 +195,12 @@ class MoEModel(nn.Module):
|
|||
) # [B, S, H]
|
||||
|
||||
# ----- 3. 池化量 -----
|
||||
pooled = self.pooler(encoded.transpose(1, 2)) # [B, H, 2]
|
||||
pooled = pooled.flatten(1) # [B, H*2]
|
||||
pooled = self.linear(pooled)
|
||||
for block in self.shared_resblocks:
|
||||
encoded = block(encoded)
|
||||
pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1)
|
||||
# pooled = self.pooler(encoded.transpose(1, 2)) # [B, H, 2]
|
||||
# pooled = pooled.flatten(1) # [B, H*2]
|
||||
# pooled = self.linear(pooled)
|
||||
|
||||
# ----- 4. 专家路由(硬路由)-----
|
||||
if torch.jit.is_tracing():
|
||||
|
|
|
|||
Loading…
Reference in New Issue