feat(trainer): 添加残差块以增强模型表达能力
This commit is contained in:
parent
ab2dbc378b
commit
ae414bae6b
|
|
@ -155,6 +155,7 @@ class MoEModel(nn.Module):
|
||||||
)
|
)
|
||||||
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
|
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
|
||||||
self.pooler = nn.AdaptiveAvgPool1d(1)
|
self.pooler = nn.AdaptiveAvgPool1d(1)
|
||||||
|
self.res_blocks = nn.ModuleList([ResidualBlock(self.hidden_size) for _ in range(4)])
|
||||||
|
|
||||||
self.total_experts = 20
|
self.total_experts = 20
|
||||||
self.experts = nn.ModuleList()
|
self.experts = nn.ModuleList()
|
||||||
|
|
@ -215,6 +216,9 @@ class MoEModel(nn.Module):
|
||||||
# ----- 3. 池化量 -----
|
# ----- 3. 池化量 -----
|
||||||
pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1)
|
pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1)
|
||||||
|
|
||||||
|
for block in self.res_blocks:
|
||||||
|
pooled = block(pooled)
|
||||||
|
|
||||||
# ----- 4. 专家路由(硬路由)-----
|
# ----- 4. 专家路由(硬路由)-----
|
||||||
if torch.jit.is_tracing():
|
if torch.jit.is_tracing():
|
||||||
# ------------------ ONNX 导出模式:条件分支(batch=1)------------------
|
# ------------------ ONNX 导出模式:条件分支(batch=1)------------------
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue