diff --git a/src/trainer/model.py b/src/trainer/model.py index fcffcca..660d74f 100644 --- a/src/trainer/model.py +++ b/src/trainer/model.py @@ -188,7 +188,7 @@ class MoEModel(nn.Module): # ----- 3. 池化量 ----- # for block in self.shared_resblocks: # encoded = block(encoded) - pooled = self.pooler(embeddings.transpose(1, 2)).squeeze(-1) + pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1) # pooled = self.pooler(encoded.transpose(1, 2)) # [B, H, 2] # pooled = pooled.flatten(1) # [B, H*2] # pooled = self.linear(pooled)