diff --git a/src/trainer/model.py b/src/trainer/model.py index a29f5bb..cc97f25 100644 --- a/src/trainer/model.py +++ b/src/trainer/model.py @@ -155,6 +155,7 @@ class MoEModel(nn.Module): ) self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4) self.pooler = nn.AdaptiveAvgPool1d(1) + self.res_blocks = nn.ModuleList([ResidualBlock(self.hidden_size) for _ in range(4)]) self.total_experts = 20 self.experts = nn.ModuleList() @@ -215,6 +216,9 @@ class MoEModel(nn.Module): # ----- 3. 池化量 ----- pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1) + for block in self.res_blocks: + pooled = block(pooled) + # ----- 4. 专家路由(硬路由)----- if torch.jit.is_tracing(): # ------------------ ONNX 导出模式:条件分支(batch=1)------------------