feat(trainer): 使用 hidden_size 代替 d_model 计算输出维度并添加池化层

This commit is contained in:
songsenand 2026-02-13 15:05:53 +08:00
parent d82c80f3a9
commit f4be47df78
1 changed file with 15 additions and 14 deletions

View File

@ -59,7 +59,7 @@ class Expert(nn.Module):
super().__init__() super().__init__()
self.input_dim = input_dim self.input_dim = input_dim
self.d_model = d_model self.d_model = d_model
self.output_dim = d_model * output_multiplier self.output_dim = input_dim * output_multiplier
# 输入映射input_dim -> d_model # 输入映射input_dim -> d_model
self.linear_in = nn.Linear(input_dim, d_model) self.linear_in = nn.Linear(input_dim, d_model)
@ -113,6 +113,7 @@ class MoEModel(nn.Module):
norm_first=True, # Pre-LN与预训练一致 norm_first=True, # Pre-LN与预训练一致
) )
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4) self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
self.pooler = nn.AdaptiveAvgPool1d(1)
# 3. 专家层8个领域专家 + 1个共享专家 # 3. 专家层8个领域专家 + 1个共享专家
total_experts = num_domain_experts + num_shared_experts total_experts = num_domain_experts + num_shared_experts
@ -125,16 +126,16 @@ class MoEModel(nn.Module):
input_dim=self.hidden_size, input_dim=self.hidden_size,
d_model=d_model, d_model=d_model,
num_resblocks=num_resblocks, num_resblocks=num_resblocks,
output_multiplier=2, # 输出维度 = 2 * d_model output_multiplier=2, # 输出维度 = 2 * hidden_size
dropout_prob=dropout_prob, dropout_prob=dropout_prob,
) )
self.experts.append(expert) self.experts.append(expert)
# 4. 分类头 # 4. 分类头
self.classifier = nn.Sequential( self.classifier = nn.Sequential(
nn.LayerNorm(2 * d_model), # 专家输出维度 nn.LayerNorm(2 * self.hidden_size), # 专家输出维度
nn.Dropout(0.1), nn.Dropout(0.1),
nn.Linear(2 * d_model, num_classes), nn.Linear(2 * self.hidden_size, num_classes),
) )
# 可选:为领域专家和共享专家设置不同权重衰减(通过优化器实现,此处不处理) # 可选:为领域专家和共享专家设置不同权重衰减(通过优化器实现,此处不处理)
@ -161,7 +162,7 @@ class MoEModel(nn.Module):
) # [B, S, H] ) # [B, S, H]
# ----- 3. [CLS] 向量 ----- # ----- 3. [CLS] 向量 -----
cls_output = encoded[:, 0, :] # [B, H] pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1)
# ----- 4. 专家路由(硬路由)----- # ----- 4. 专家路由(硬路由)-----
if torch.jit.is_tracing(): if torch.jit.is_tracing():
@ -170,23 +171,23 @@ class MoEModel(nn.Module):
group_id = pg.item() if torch.is_tensor(pg) else pg group_id = pg.item() if torch.is_tensor(pg) else pg
if group_id == 0: if group_id == 0:
expert_out = self.experts[0](cls_output) expert_out = self.experts[0](pooled)
elif group_id == 1: elif group_id == 1:
expert_out = self.experts[1](cls_output) expert_out = self.experts[1](pooled)
elif group_id == 2: elif group_id == 2:
expert_out = self.experts[2](cls_output) expert_out = self.experts[2](pooled)
elif group_id == 3: elif group_id == 3:
expert_out = self.experts[3](cls_output) expert_out = self.experts[3](pooled)
elif group_id == 4: elif group_id == 4:
expert_out = self.experts[4](cls_output) expert_out = self.experts[4](pooled)
elif group_id == 5: elif group_id == 5:
expert_out = self.experts[5](cls_output) expert_out = self.experts[5](pooled)
elif group_id == 6: elif group_id == 6:
expert_out = self.experts[6](cls_output) expert_out = self.experts[6](pooled)
elif group_id == 7: elif group_id == 7:
expert_out = self.experts[7](cls_output) expert_out = self.experts[7](pooled)
else: # group_id == 8 else: # group_id == 8
expert_out = self.experts[8](cls_output) expert_out = self.experts[8](pooled)
else: else:
# ------------------ 训练 / 普通推理:全量计算 + Gather ------------------ # ------------------ 训练 / 普通推理:全量计算 + Gather ------------------