feat(trainer): 使用 hidden_size 代替 d_model 计算输出维度并添加池化层

This commit is contained in:
songsenand 2026-02-13 15:05:53 +08:00
parent d82c80f3a9
commit f4be47df78
1 changed file with 15 additions and 14 deletions

View File

@ -59,7 +59,7 @@ class Expert(nn.Module):
super().__init__()
self.input_dim = input_dim
self.d_model = d_model
self.output_dim = d_model * output_multiplier
self.output_dim = input_dim * output_multiplier
# 输入映射：input_dim -> d_model
self.linear_in = nn.Linear(input_dim, d_model)
@ -113,6 +113,7 @@ class MoEModel(nn.Module):
norm_first=True, # Pre-LN与预训练一致
)
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
self.pooler = nn.AdaptiveAvgPool1d(1)
# 3. 专家层：8个领域专家 + 1个共享专家
total_experts = num_domain_experts + num_shared_experts
@ -125,16 +126,16 @@ class MoEModel(nn.Module):
input_dim=self.hidden_size,
d_model=d_model,
num_resblocks=num_resblocks,
output_multiplier=2, # 输出维度 = 2 * d_model
output_multiplier=2, # 输出维度 = 2 * hidden_size
dropout_prob=dropout_prob,
)
self.experts.append(expert)
# 4. 分类头
self.classifier = nn.Sequential(
nn.LayerNorm(2 * d_model), # 专家输出维度
nn.LayerNorm(2 * self.hidden_size), # 专家输出维度
nn.Dropout(0.1),
nn.Linear(2 * d_model, num_classes),
nn.Linear(2 * self.hidden_size, num_classes),
)
# 可选:为领域专家和共享专家设置不同权重衰减(通过优化器实现,此处不处理)
@ -161,7 +162,7 @@ class MoEModel(nn.Module):
) # [B, S, H]
# ----- 3. [CLS] 向量 -----
cls_output = encoded[:, 0, :] # [B, H]
pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1)
# ----- 4. 专家路由(硬路由)-----
if torch.jit.is_tracing():
@ -170,23 +171,23 @@ class MoEModel(nn.Module):
group_id = pg.item() if torch.is_tensor(pg) else pg
if group_id == 0:
expert_out = self.experts[0](cls_output)
expert_out = self.experts[0](pooled)
elif group_id == 1:
expert_out = self.experts[1](cls_output)
expert_out = self.experts[1](pooled)
elif group_id == 2:
expert_out = self.experts[2](cls_output)
expert_out = self.experts[2](pooled)
elif group_id == 3:
expert_out = self.experts[3](cls_output)
expert_out = self.experts[3](pooled)
elif group_id == 4:
expert_out = self.experts[4](cls_output)
expert_out = self.experts[4](pooled)
elif group_id == 5:
expert_out = self.experts[5](cls_output)
expert_out = self.experts[5](pooled)
elif group_id == 6:
expert_out = self.experts[6](cls_output)
expert_out = self.experts[6](pooled)
elif group_id == 7:
expert_out = self.experts[7](cls_output)
expert_out = self.experts[7](pooled)
else: # group_id == 8
expert_out = self.experts[8](cls_output)
expert_out = self.experts[8](pooled)
else:
# ------------------ 训练 / 普通推理:全量计算 + Gather ------------------