From 88955bcfdd18482c0d2119c9440a1d5c634ac861 Mon Sep 17 00:00:00 2001 From: songsenand Date: Sat, 23 May 2026 13:42:44 +0800 Subject: [PATCH] =?UTF-8?q?refactor(model):=20=E4=BC=98=E5=8C=96slot?= =?UTF-8?q?=E6=9D=83=E9=87=8D=E8=AE=A1=E7=AE=97=E9=80=BB=E8=BE=91=E4=BB=A5?= =?UTF-8?q?=E6=8F=90=E5=8D=87=E7=A8=B3=E5=AE=9A=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/model/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/model/model.py b/src/model/model.py index ac50d78..028685a 100644 --- a/src/model/model.py +++ b/src/model/model.py @@ -130,8 +130,9 @@ class InputMethodEngine(nn.Module): batch_size = input_ids.size(0) slot_scores = self.slot_attention(moe_out).squeeze(-1) - slot_scores = slot_scores.masked_fill(slot_mask == 0, -1e9) slot_weights = torch.softmax(slot_scores, dim=1) + slot_weights = slot_weights * slot_mask + slot_weights = slot_weights / (slot_weights.sum(dim=1, keepdim=True) + 1e-12) pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1) logits = self.classifier(pooled)