diff --git a/src/model/model.py b/src/model/model.py index ac50d78..028685a 100644 --- a/src/model/model.py +++ b/src/model/model.py @@ -130,8 +130,9 @@ class InputMethodEngine(nn.Module): batch_size = input_ids.size(0) slot_scores = self.slot_attention(moe_out).squeeze(-1) - slot_scores = slot_scores.masked_fill(slot_mask == 0, -1e9) slot_weights = torch.softmax(slot_scores, dim=1) + slot_weights = slot_weights * slot_mask + slot_weights = slot_weights / (slot_weights.sum(dim=1, keepdim=True) + 1e-12) pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1) logits = self.classifier(pooled)