From ae414bae6b780a3aabef1b03da08fda0b9f2fa34 Mon Sep 17 00:00:00 2001 From: songsenand Date: Mon, 16 Feb 2026 10:26:47 +0800 Subject: [PATCH] =?UTF-8?q?feat(trainer):=20=E6=B7=BB=E5=8A=A0=E6=AE=8B?= =?UTF-8?q?=E5=B7=AE=E5=9D=97=E4=BB=A5=E5=A2=9E=E5=BC=BA=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E8=A1=A8=E8=BE=BE=E8=83=BD=E5=8A=9B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/trainer/model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/trainer/model.py b/src/trainer/model.py index a29f5bb..cc97f25 100644 --- a/src/trainer/model.py +++ b/src/trainer/model.py @@ -155,6 +155,7 @@ class MoEModel(nn.Module): ) self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4) self.pooler = nn.AdaptiveAvgPool1d(1) + self.res_blocks = nn.ModuleList([ResidualBlock(self.hidden_size) for _ in range(4)]) self.total_experts = 20 self.experts = nn.ModuleList() @@ -215,6 +216,9 @@ class MoEModel(nn.Module): # ----- 3. 池化量 ----- pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1) + for block in self.res_blocks: + pooled = block(pooled) + # ----- 4. 专家路由(硬路由)----- if torch.jit.is_tracing(): # ------------------ ONNX 导出模式:条件分支(batch=1)------------------