diff --git a/src/trainer/model.py b/src/trainer/model.py index 660d74f..07c301c 100644 --- a/src/trainer/model.py +++ b/src/trainer/model.py @@ -180,9 +180,9 @@ class MoEModel(nn.Module): # ----- 2. Transformer Encoder ----- # padding mask: True 表示忽略该位置 - # padding_mask = attention_mask == 0 + padding_mask = attention_mask == 0 encoded = self.encoder( - embeddings #, src_key_padding_mask=padding_mask + embeddings , src_key_padding_mask=padding_mask ) # [B, S, H] # ----- 3. 池化量 -----