refactor(model): 优化slot权重计算逻辑以提升稳定性

This commit is contained in:
songsenand 2026-05-23 13:42:44 +08:00
parent 53f244de2f
commit 88955bcfdd
1 changed files with 2 additions and 1 deletions

View File

@ -130,8 +130,9 @@ class InputMethodEngine(nn.Module):
batch_size = input_ids.size(0) batch_size = input_ids.size(0)
slot_scores = self.slot_attention(moe_out).squeeze(-1) slot_scores = self.slot_attention(moe_out).squeeze(-1)
slot_scores = slot_scores.masked_fill(slot_mask == 0, -1e9)
slot_weights = torch.softmax(slot_scores, dim=1) slot_weights = torch.softmax(slot_scores, dim=1)
slot_weights = slot_weights * slot_mask
slot_weights = slot_weights / (slot_weights.sum(dim=1, keepdim=True) + 1e-12)
pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1) pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1)
logits = self.classifier(pooled) logits = self.classifier(pooled)