fix(encoder): 修复 encoder 调用时缺少 src_key_padding_mask 参数
This commit is contained in:
parent
b1f78668dc
commit
db90516fcf
|
|
@ -180,9 +180,9 @@ class MoEModel(nn.Module):
|
|||
|
||||
# ----- 2. Transformer Encoder -----
|
||||
# padding mask: True 表示忽略该位置
|
||||
# padding_mask = attention_mask == 0
|
||||
padding_mask = attention_mask == 0
|
||||
encoded = self.encoder(
|
||||
embeddings #, src_key_padding_mask=padding_mask
|
||||
embeddings , src_key_padding_mask=padding_mask
|
||||
) # [B, S, H]
|
||||
|
||||
# ----- 3. 池化量 -----
|
||||
|
|
|
|||
Loading…
Reference in New Issue