feat(trainer): 使用 hidden_size 代替 d_model 计算输出维度并添加池化层
This commit is contained in:
parent
d82c80f3a9
commit
f4be47df78
|
|
@ -59,7 +59,7 @@ class Expert(nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.input_dim = input_dim
|
self.input_dim = input_dim
|
||||||
self.d_model = d_model
|
self.d_model = d_model
|
||||||
self.output_dim = d_model * output_multiplier
|
self.output_dim = input_dim * output_multiplier
|
||||||
|
|
||||||
# 输入映射:input_dim -> d_model
|
# 输入映射:input_dim -> d_model
|
||||||
self.linear_in = nn.Linear(input_dim, d_model)
|
self.linear_in = nn.Linear(input_dim, d_model)
|
||||||
|
|
@ -113,6 +113,7 @@ class MoEModel(nn.Module):
|
||||||
norm_first=True, # Pre-LN,与预训练一致
|
norm_first=True, # Pre-LN,与预训练一致
|
||||||
)
|
)
|
||||||
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
|
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
|
||||||
|
self.pooler = nn.AdaptiveAvgPool1d(1)
|
||||||
|
|
||||||
# 3. 专家层:8个领域专家 + 1个共享专家
|
# 3. 专家层:8个领域专家 + 1个共享专家
|
||||||
total_experts = num_domain_experts + num_shared_experts
|
total_experts = num_domain_experts + num_shared_experts
|
||||||
|
|
@ -125,16 +126,16 @@ class MoEModel(nn.Module):
|
||||||
input_dim=self.hidden_size,
|
input_dim=self.hidden_size,
|
||||||
d_model=d_model,
|
d_model=d_model,
|
||||||
num_resblocks=num_resblocks,
|
num_resblocks=num_resblocks,
|
||||||
output_multiplier=2, # 输出维度 = 2 * d_model
|
output_multiplier=2, # 输出维度 = 2 * hidden_size
|
||||||
dropout_prob=dropout_prob,
|
dropout_prob=dropout_prob,
|
||||||
)
|
)
|
||||||
self.experts.append(expert)
|
self.experts.append(expert)
|
||||||
|
|
||||||
# 4. 分类头
|
# 4. 分类头
|
||||||
self.classifier = nn.Sequential(
|
self.classifier = nn.Sequential(
|
||||||
nn.LayerNorm(2 * d_model), # 专家输出维度
|
nn.LayerNorm(2 * self.hidden_size), # 专家输出维度
|
||||||
nn.Dropout(0.1),
|
nn.Dropout(0.1),
|
||||||
nn.Linear(2 * d_model, num_classes),
|
nn.Linear(2 * self.hidden_size, num_classes),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 可选:为领域专家和共享专家设置不同权重衰减(通过优化器实现,此处不处理)
|
# 可选:为领域专家和共享专家设置不同权重衰减(通过优化器实现,此处不处理)
|
||||||
|
|
@ -161,7 +162,7 @@ class MoEModel(nn.Module):
|
||||||
) # [B, S, H]
|
) # [B, S, H]
|
||||||
|
|
||||||
# ----- 3. [CLS] 向量 -----
|
# ----- 3. [CLS] 向量 -----
|
||||||
cls_output = encoded[:, 0, :] # [B, H]
|
pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1)
|
||||||
|
|
||||||
# ----- 4. 专家路由(硬路由)-----
|
# ----- 4. 专家路由(硬路由)-----
|
||||||
if torch.jit.is_tracing():
|
if torch.jit.is_tracing():
|
||||||
|
|
@ -170,23 +171,23 @@ class MoEModel(nn.Module):
|
||||||
group_id = pg.item() if torch.is_tensor(pg) else pg
|
group_id = pg.item() if torch.is_tensor(pg) else pg
|
||||||
|
|
||||||
if group_id == 0:
|
if group_id == 0:
|
||||||
expert_out = self.experts[0](cls_output)
|
expert_out = self.experts[0](pooled)
|
||||||
elif group_id == 1:
|
elif group_id == 1:
|
||||||
expert_out = self.experts[1](cls_output)
|
expert_out = self.experts[1](pooled)
|
||||||
elif group_id == 2:
|
elif group_id == 2:
|
||||||
expert_out = self.experts[2](cls_output)
|
expert_out = self.experts[2](pooled)
|
||||||
elif group_id == 3:
|
elif group_id == 3:
|
||||||
expert_out = self.experts[3](cls_output)
|
expert_out = self.experts[3](pooled)
|
||||||
elif group_id == 4:
|
elif group_id == 4:
|
||||||
expert_out = self.experts[4](cls_output)
|
expert_out = self.experts[4](pooled)
|
||||||
elif group_id == 5:
|
elif group_id == 5:
|
||||||
expert_out = self.experts[5](cls_output)
|
expert_out = self.experts[5](pooled)
|
||||||
elif group_id == 6:
|
elif group_id == 6:
|
||||||
expert_out = self.experts[6](cls_output)
|
expert_out = self.experts[6](pooled)
|
||||||
elif group_id == 7:
|
elif group_id == 7:
|
||||||
expert_out = self.experts[7](cls_output)
|
expert_out = self.experts[7](pooled)
|
||||||
else: # group_id == 8
|
else: # group_id == 8
|
||||||
expert_out = self.experts[8](cls_output)
|
expert_out = self.experts[8](pooled)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# ------------------ 训练 / 普通推理:全量计算 + Gather ------------------
|
# ------------------ 训练 / 普通推理:全量计算 + Gather ------------------
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue