修复输出维度计算错误,使用 d_model 代替 input_dim
This commit is contained in:
parent
0e3418798e
commit
6923870171
|
|
@ -59,7 +59,7 @@ class Expert(nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.input_dim = input_dim
|
self.input_dim = input_dim
|
||||||
self.d_model = d_model
|
self.d_model = d_model
|
||||||
self.output_dim = input_dim * output_multiplier
|
self.output_dim = d_model * output_multiplier
|
||||||
|
|
||||||
# 输入映射:input_dim -> d_model
|
# 输入映射:input_dim -> d_model
|
||||||
self.linear_in = nn.Linear(input_dim, d_model)
|
self.linear_in = nn.Linear(input_dim, d_model)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue