fix(encoder): 修复 encoder 调用时缺少 src_key_padding_mask 参数
This commit is contained in:
parent
b1f78668dc
commit
db90516fcf
|
|
@ -180,9 +180,9 @@ class MoEModel(nn.Module):
|
||||||
|
|
||||||
# ----- 2. Transformer Encoder -----
|
# ----- 2. Transformer Encoder -----
|
||||||
# padding mask: True 表示忽略该位置
|
# padding mask: True 表示忽略该位置
|
||||||
# padding_mask = attention_mask == 0
|
padding_mask = attention_mask == 0
|
||||||
encoded = self.encoder(
|
encoded = self.encoder(
|
||||||
embeddings #, src_key_padding_mask=padding_mask
|
embeddings , src_key_padding_mask=padding_mask
|
||||||
) # [B, S, H]
|
) # [B, S, H]
|
||||||
|
|
||||||
# ----- 3. 池化量 -----
|
# ----- 3. 池化量 -----
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue