feat(trainer): 添加残差块以增强模型表达能力

This commit is contained in:
songsenand 2026-02-16 10:26:47 +08:00
parent ab2dbc378b
commit ae414bae6b
1 changed file with 4 additions and 0 deletions

View File

@@ -155,6 +155,7 @@ class MoEModel(nn.Module):
) )
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4) self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
self.pooler = nn.AdaptiveAvgPool1d(1) self.pooler = nn.AdaptiveAvgPool1d(1)
self.res_blocks = nn.ModuleList([ResidualBlock(self.hidden_size) for _ in range(4)])
self.total_experts = 20 self.total_experts = 20
self.experts = nn.ModuleList() self.experts = nn.ModuleList()
@@ -215,6 +216,9 @@ class MoEModel(nn.Module):
# ----- 3. 池化量 ----- # ----- 3. 池化量 -----
pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1) pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1)
for block in self.res_blocks:
pooled = block(pooled)
# ----- 4. 专家路由(硬路由)----- # ----- 4. 专家路由(硬路由)-----
if torch.jit.is_tracing(): if torch.jit.is_tracing():
# ------------------ ONNX 导出模式条件分支batch=1------------------ # ------------------ ONNX 导出模式条件分支batch=1------------------