diff --git a/src/trainer/model.py b/src/trainer/model.py index 5ce92e3..3b2e4f1 100644 --- a/src/trainer/model.py +++ b/src/trainer/model.py @@ -59,7 +59,7 @@ class Expert(nn.Module): super().__init__() self.input_dim = input_dim self.d_model = d_model - self.output_dim = d_model * output_multiplier + self.output_dim = input_dim * output_multiplier # 输入映射:input_dim -> d_model self.linear_in = nn.Linear(input_dim, d_model) @@ -113,6 +113,7 @@ class MoEModel(nn.Module): norm_first=True, # Pre-LN,与预训练一致 ) self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4) + self.pooler = nn.AdaptiveAvgPool1d(1) # 3. 专家层:8个领域专家 + 1个共享专家 total_experts = num_domain_experts + num_shared_experts @@ -125,16 +126,16 @@ class MoEModel(nn.Module): input_dim=self.hidden_size, d_model=d_model, num_resblocks=num_resblocks, - output_multiplier=2, # 输出维度 = 2 * d_model + output_multiplier=2, # 输出维度 = 2 * hidden_size dropout_prob=dropout_prob, ) self.experts.append(expert) # 4. 分类头 self.classifier = nn.Sequential( - nn.LayerNorm(2 * d_model), # 专家输出维度 + nn.LayerNorm(2 * self.hidden_size), # 专家输出维度 nn.Dropout(0.1), - nn.Linear(2 * d_model, num_classes), + nn.Linear(2 * self.hidden_size, num_classes), ) # 可选:为领域专家和共享专家设置不同权重衰减(通过优化器实现,此处不处理) @@ -161,7 +162,7 @@ class MoEModel(nn.Module): ) # [B, S, H] # ----- 3. [CLS] 向量 ----- - cls_output = encoded[:, 0, :] # [B, H] + pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1) # ----- 4. 专家路由(硬路由)----- if torch.jit.is_tracing(): @@ -170,23 +171,23 @@ class MoEModel(nn.Module): group_id = pg.item() if torch.is_tensor(pg) else pg if group_id == 0: - expert_out = self.experts[0](cls_output) + expert_out = self.experts[0](pooled) elif group_id == 1: - expert_out = self.experts[1](cls_output) + expert_out = self.experts[1](pooled) elif group_id == 2: - expert_out = self.experts[2](cls_output) + expert_out = self.experts[2](pooled) elif group_id == 3: - expert_out = self.experts[3](cls_output) + expert_out = self.experts[3](pooled) elif group_id == 4: - expert_out = self.experts[4](cls_output) + expert_out = self.experts[4](pooled) elif group_id == 5: - expert_out = self.experts[5](cls_output) + expert_out = self.experts[5](pooled) elif group_id == 6: - expert_out = self.experts[6](cls_output) + expert_out = self.experts[6](pooled) elif group_id == 7: - expert_out = self.experts[7](cls_output) + expert_out = self.experts[7](pooled) else: # group_id == 8 - expert_out = self.experts[8](cls_output) + expert_out = self.experts[8](pooled) else: # ------------------ 训练 / 普通推理:全量计算 + Gather ------------------