feat(model): 优化专家输出结构并添加专家偏置支持
This commit is contained in:
parent
f4be47df78
commit
7eb00c6207
|
|
@ -89,12 +89,15 @@ class MoEModel(nn.Module):
|
||||||
self,
|
self,
|
||||||
pretrained_model_name="iic/nlp_structbert_backbone_tiny_std",
|
pretrained_model_name="iic/nlp_structbert_backbone_tiny_std",
|
||||||
num_classes=10018,
|
num_classes=10018,
|
||||||
d_model=1024,
|
output_multiplier=2,
|
||||||
|
d_model=768,
|
||||||
num_resblocks=4,
|
num_resblocks=4,
|
||||||
num_domain_experts=8,
|
num_domain_experts=8,
|
||||||
num_shared_experts=1,
|
num_shared_experts=1,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.output_multiplier = output_multiplier
|
||||||
|
|
||||||
# 1. 加载预训练 BERT,仅保留 embeddings
|
# 1. 加载预训练 BERT,仅保留 embeddings
|
||||||
bert = AutoModel.from_pretrained(pretrained_model_name)
|
bert = AutoModel.from_pretrained(pretrained_model_name)
|
||||||
self.embedding = bert.embeddings
|
self.embedding = bert.embeddings
|
||||||
|
|
@ -126,18 +129,36 @@ class MoEModel(nn.Module):
|
||||||
input_dim=self.hidden_size,
|
input_dim=self.hidden_size,
|
||||||
d_model=d_model,
|
d_model=d_model,
|
||||||
num_resblocks=num_resblocks,
|
num_resblocks=num_resblocks,
|
||||||
output_multiplier=2, # 输出维度 = 2 * hidden_size
|
output_multiplier=self.output_multiplier, # 输出维度 = output_multiplier * hidden_size
|
||||||
dropout_prob=dropout_prob,
|
dropout_prob=dropout_prob,
|
||||||
)
|
)
|
||||||
self.experts.append(expert)
|
self.experts.append(expert)
|
||||||
|
|
||||||
# 4. 分类头
|
self.expert_bias = nn.Embedding(
|
||||||
self.classifier = nn.Sequential(
|
total_experts, self.output_multiplier * self.hidden_size
|
||||||
nn.LayerNorm(2 * self.hidden_size), # 专家输出维度
|
|
||||||
nn.Dropout(0.1),
|
|
||||||
nn.Linear(2 * self.hidden_size, num_classes),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 4. 分类头
|
||||||
|
self.classifier = nn.Sequential(
|
||||||
|
nn.LayerNorm(self.output_multiplier * self.hidden_size),
|
||||||
|
nn.Linear(
|
||||||
|
self.output_multiplier * self.hidden_size,
|
||||||
|
self.output_multiplier * self.hidden_size,
|
||||||
|
),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Linear(
|
||||||
|
self.output_multiplier * self.hidden_size,
|
||||||
|
self.output_multiplier * self.hidden_size,
|
||||||
|
),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Linear(
|
||||||
|
self.output_multiplier * self.hidden_size,
|
||||||
|
self.output_multiplier * self.hidden_size * 2,
|
||||||
|
),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Dropout(0.2),
|
||||||
|
nn.Linear(self.output_multiplier * self.hidden_size * 2, num_classes),
|
||||||
|
)
|
||||||
# 可选:为领域专家和共享专家设置不同权重衰减(通过优化器实现,此处不处理)
|
# 可选:为领域专家和共享专家设置不同权重衰减(通过优化器实现,此处不处理)
|
||||||
|
|
||||||
def to(self, device):
|
def to(self, device):
|
||||||
|
|
@ -161,7 +182,7 @@ class MoEModel(nn.Module):
|
||||||
embeddings, src_key_padding_mask=padding_mask
|
embeddings, src_key_padding_mask=padding_mask
|
||||||
) # [B, S, H]
|
) # [B, S, H]
|
||||||
|
|
||||||
# ----- 3. [CLS] 向量 -----
|
# ----- 3. 池化向量 -----
|
||||||
pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1)
|
pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1)
|
||||||
|
|
||||||
# ----- 4. 专家路由(硬路由)-----
|
# ----- 4. 专家路由(硬路由)-----
|
||||||
|
|
@ -171,36 +192,53 @@ class MoEModel(nn.Module):
|
||||||
group_id = pg.item() if torch.is_tensor(pg) else pg
|
group_id = pg.item() if torch.is_tensor(pg) else pg
|
||||||
|
|
||||||
if group_id == 0:
|
if group_id == 0:
|
||||||
expert_out = self.experts[0](pooled)
|
expert_out = self.experts[0](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(0, device=pooled.device)
|
||||||
|
)
|
||||||
elif group_id == 1:
|
elif group_id == 1:
|
||||||
expert_out = self.experts[1](pooled)
|
expert_out = self.experts[1](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(1, device=pooled.device)
|
||||||
|
)
|
||||||
elif group_id == 2:
|
elif group_id == 2:
|
||||||
expert_out = self.experts[2](pooled)
|
expert_out = self.experts[2](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(2, device=pooled.device)
|
||||||
|
)
|
||||||
elif group_id == 3:
|
elif group_id == 3:
|
||||||
expert_out = self.experts[3](pooled)
|
expert_out = self.experts[3](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(3, device=pooled.device)
|
||||||
|
)
|
||||||
elif group_id == 4:
|
elif group_id == 4:
|
||||||
expert_out = self.experts[4](pooled)
|
expert_out = self.experts[4](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(4, device=pooled.device)
|
||||||
|
)
|
||||||
elif group_id == 5:
|
elif group_id == 5:
|
||||||
expert_out = self.experts[5](pooled)
|
expert_out = self.experts[5](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(5, device=pooled.device)
|
||||||
|
)
|
||||||
elif group_id == 6:
|
elif group_id == 6:
|
||||||
expert_out = self.experts[6](pooled)
|
expert_out = self.experts[6](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(6, device=pooled.device)
|
||||||
|
)
|
||||||
elif group_id == 7:
|
elif group_id == 7:
|
||||||
expert_out = self.experts[7](pooled)
|
expert_out = self.experts[7](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(7, device=pooled.device)
|
||||||
|
)
|
||||||
else: # group_id == 8
|
else: # group_id == 8
|
||||||
expert_out = self.experts[8](pooled)
|
expert_out = self.experts[8](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(8, device=pooled.device)
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# ------------------ 训练 / 普通推理:全量计算 + Gather ------------------
|
batch_size = pooled.size(0)
|
||||||
# 此时 pg 为 [batch] 的 LongTensor
|
# 并行计算所有专家输出
|
||||||
batch_size = cls_output.size(0)
|
|
||||||
# 所有专家并行计算,输出堆叠
|
|
||||||
expert_outputs = torch.stack(
|
expert_outputs = torch.stack(
|
||||||
[e(cls_output) for e in self.experts], dim=0
|
[e(pooled) for e in self.experts], dim=0
|
||||||
) # [num_experts, batch, output_dim]
|
) # [E, B, D]
|
||||||
# 根据 pg 索引对应的专家输出
|
# 根据 pg 索引专家输出
|
||||||
expert_out = expert_outputs[
|
expert_out = expert_outputs[pg, torch.arange(batch_size)] # [B, D]
|
||||||
pg, torch.arange(batch_size)
|
# 添加专家偏置
|
||||||
] # [batch, output_dim]
|
bias = self.expert_bias(pg) # [B, D]
|
||||||
|
expert_out = expert_out + bias
|
||||||
|
|
||||||
# ----- 5. 分类头 -----
|
# ----- 5. 分类头 -----
|
||||||
logits = self.classifier(expert_out) # [batch, num_classes]
|
logits = self.classifier(expert_out) # [batch, num_classes]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue