feat(trainer): 添加残差块以增强模型表达能力
This commit is contained in:
parent
ab2dbc378b
commit
ae414bae6b
|
|
@ -155,6 +155,7 @@ class MoEModel(nn.Module):
|
|||
)
|
||||
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
|
||||
self.pooler = nn.AdaptiveAvgPool1d(1)
|
||||
self.res_blocks = nn.ModuleList([ResidualBlock(self.hidden_size) for _ in range(4)])
|
||||
|
||||
self.total_experts = 20
|
||||
self.experts = nn.ModuleList()
|
||||
|
|
@ -215,6 +216,9 @@ class MoEModel(nn.Module):
|
|||
# ----- 3. 池化量 -----
|
||||
pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1)
|
||||
|
||||
for block in self.res_blocks:
|
||||
pooled = block(pooled)
|
||||
|
||||
# ----- 4. 专家路由(硬路由)-----
|
||||
if torch.jit.is_tracing():
|
||||
# ------------------ ONNX 导出模式:条件分支(batch=1)------------------
|
||||
|
|
|
|||
Loading…
Reference in New Issue