refactor(model): 优化slot权重计算逻辑以提升稳定性
This commit is contained in:
parent
53f244de2f
commit
88955bcfdd
|
|
@@ -130,8 +130,9 @@ class InputMethodEngine(nn.Module):
|
||||||
|
|
||||||
batch_size = input_ids.size(0)
|
batch_size = input_ids.size(0)
|
||||||
slot_scores = self.slot_attention(moe_out).squeeze(-1)
|
slot_scores = self.slot_attention(moe_out).squeeze(-1)
|
||||||
slot_scores = slot_scores.masked_fill(slot_mask == 0, -1e9)
|
|
||||||
slot_weights = torch.softmax(slot_scores, dim=1)
|
slot_weights = torch.softmax(slot_scores, dim=1)
|
||||||
|
slot_weights = slot_weights * slot_mask
|
||||||
|
slot_weights = slot_weights / (slot_weights.sum(dim=1, keepdim=True) + 1e-12)
|
||||||
pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1)
|
pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1)
|
||||||
|
|
||||||
logits = self.classifier(pooled)
|
logits = self.classifier(pooled)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue