From 7eb00c6207b7fab71b8d4ce2c6715085a59134a3 Mon Sep 17 00:00:00 2001 From: songsenand Date: Fri, 13 Feb 2026 16:11:35 +0800 Subject: [PATCH] =?UTF-8?q?feat(model):=20=E4=BC=98=E5=8C=96=E4=B8=93?= =?UTF-8?q?=E5=AE=B6=E8=BE=93=E5=87=BA=E7=BB=93=E6=9E=84=E5=B9=B6=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E4=B8=93=E5=AE=B6=E5=81=8F=E7=BD=AE=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/trainer/model.py | 92 +++++++++++++++++++++++++++++++------------- 1 file changed, 65 insertions(+), 27 deletions(-) diff --git a/src/trainer/model.py b/src/trainer/model.py index 3b2e4f1..dbc4715 100644 --- a/src/trainer/model.py +++ b/src/trainer/model.py @@ -89,12 +89,15 @@ class MoEModel(nn.Module): self, pretrained_model_name="iic/nlp_structbert_backbone_tiny_std", num_classes=10018, - d_model=1024, + output_multiplier=2, + d_model=768, num_resblocks=4, num_domain_experts=8, num_shared_experts=1, ): super().__init__() + self.output_multiplier = output_multiplier + # 1. 加载预训练 BERT,仅保留 embeddings bert = AutoModel.from_pretrained(pretrained_model_name) self.embedding = bert.embeddings @@ -126,18 +129,36 @@ class MoEModel(nn.Module): input_dim=self.hidden_size, d_model=d_model, num_resblocks=num_resblocks, - output_multiplier=2, # 输出维度 = 2 * hidden_size + output_multiplier=self.output_multiplier, # 输出维度 = 2 * hidden_size dropout_prob=dropout_prob, ) self.experts.append(expert) - # 4. 分类头 - self.classifier = nn.Sequential( - nn.LayerNorm(2 * self.hidden_size), # 专家输出维度 - nn.Dropout(0.1), - nn.Linear(2 * self.hidden_size, num_classes), + self.expert_bias = nn.Embedding( + total_experts, self.output_multiplier * self.hidden_size ) + # 4. 分类头 + self.classifier = nn.Sequential( + nn.LayerNorm(self.output_multiplier * self.hidden_size), + nn.Linear( + self.output_multiplier * self.hidden_size, + self.output_multiplier * self.hidden_size, + ), + nn.ReLU(inplace=True), + nn.Linear( + self.output_multiplier * self.hidden_size, + self.output_multiplier * self.hidden_size, + ), + nn.ReLU(inplace=True), + nn.Linear( + self.output_multiplier * self.hidden_size, + self.output_multiplier * self.hidden_size * 2, + ), + nn.ReLU(inplace=True), + nn.Dropout(0.2), + nn.Linear(self.output_multiplier * self.hidden_size * 2, num_classes), + ) # 可选:为领域专家和共享专家设置不同权重衰减(通过优化器实现,此处不处理) def to(self, device): @@ -161,7 +182,7 @@ class MoEModel(nn.Module): embeddings, src_key_padding_mask=padding_mask ) # [B, S, H] - # ----- 3. [CLS] 向量 ----- + # ----- 3. 池化量 ----- pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1) # ----- 4. 专家路由(硬路由)----- @@ -171,36 +192,53 @@ class MoEModel(nn.Module): group_id = pg.item() if torch.is_tensor(pg) else pg if group_id == 0: - expert_out = self.experts[0](pooled) + expert_out = self.experts[0](pooled) + self.expert_bias( + torch.tensor(0, device=pooled.device) + ) elif group_id == 1: - expert_out = self.experts[1](pooled) + expert_out = self.experts[1](pooled) + self.expert_bias( + torch.tensor(1, device=pooled.device) + ) elif group_id == 2: - expert_out = self.experts[2](pooled) + expert_out = self.experts[2](pooled) + self.expert_bias( + torch.tensor(2, device=pooled.device) + ) elif group_id == 3: - expert_out = self.experts[3](pooled) + expert_out = self.experts[3](pooled) + self.expert_bias( + torch.tensor(3, device=pooled.device) + ) elif group_id == 4: - expert_out = self.experts[4](pooled) + expert_out = self.experts[4](pooled) + self.expert_bias( + torch.tensor(4, device=pooled.device) + ) elif group_id == 5: - expert_out = self.experts[5](pooled) + expert_out = self.experts[5](pooled) + self.expert_bias( + torch.tensor(5, device=pooled.device) + ) elif group_id == 6: - expert_out = self.experts[6](pooled) + expert_out = self.experts[6](pooled) + self.expert_bias( + torch.tensor(6, device=pooled.device) + ) elif group_id == 7: - expert_out = self.experts[7](pooled) + expert_out = self.experts[7](pooled) + self.expert_bias( + torch.tensor(7, device=pooled.device) + ) else: # group_id == 8 - expert_out = self.experts[8](pooled) + expert_out = self.experts[8](pooled) + self.expert_bias( + torch.tensor(8, device=pooled.device) + ) else: - # ------------------ 训练 / 普通推理:全量计算 + Gather ------------------ - # 此时 pg 为 [batch] 的 LongTensor - batch_size = cls_output.size(0) - # 所有专家并行计算,输出堆叠 + batch_size = pooled.size(0) + # 并行计算所有专家输出 expert_outputs = torch.stack( - [e(cls_output) for e in self.experts], dim=0 - ) # [num_experts, batch, output_dim] - # 根据 pg 索引对应的专家输出 - expert_out = expert_outputs[ - pg, torch.arange(batch_size) - ] # [batch, output_dim] + [e(pooled) for e in self.experts], dim=0 + ) # [E, B, D] + # 根据 pg 索引专家输出 + expert_out = expert_outputs[pg, torch.arange(batch_size)] # [B, D] + # 添加专家偏置 + bias = self.expert_bias(pg) # [B, D] + expert_out = expert_out + bias # ----- 5. 分类头 ----- logits = self.classifier(expert_out) # [batch, num_classes]