diff --git a/src/trainer/model.py b/src/trainer/model.py index 9f7b64e..cb3e8cf 100644 --- a/src/trainer/model.py +++ b/src/trainer/model.py @@ -59,7 +59,7 @@ class Expert(nn.Module): super().__init__() self.input_dim = input_dim self.d_model = d_model - self.output_dim = input_dim * output_multiplier + self.output_dim = d_model * output_multiplier # 输入映射:input_dim -> d_model self.linear_in = nn.Linear(input_dim, d_model)