调整拼音分组与采样逻辑,优化模型结构及专家路由策略

This commit is contained in:
songsenand 2026-02-21 21:55:55 +08:00
parent 917c9f4256
commit 8f58917d13
8 changed files with 60 additions and 137 deletions

View File

@ -97,28 +97,28 @@ class PinyinInputDataset(IterableDataset):
# 加载拼音分组
self.pg_groups = {
"y": 0,
"z": 1,
"j": 2,
"l": 3,
"s": 4,
"x": 5,
"c": 6,
"k": 0,
"e": 0,
"l": 1,
"w": 1,
"f": 1,
"q": 2,
"a": 2,
"s": 2,
"x": 3,
"b": 3,
"r": 3,
"o": 4,
"m": 4,
"z": 4,
"g": 5,
"n": 5,
"c": 5,
"t": 6,
"p": 6,
"d": 6,
"j": 7,
"h": 7,
"d": 8,
"b": 9,
"q": 10,
"g": 11,
"t": 12,
"m": 13,
"p": 14,
"w": 15,
"f": 16,
"k": 17,
"n": 18,
"r": 19,
"a": 19,
"e": 18,
"o": 17,
}
def get_next_chinese_chars(
@ -184,15 +184,15 @@ class PinyinInputDataset(IterableDataset):
# 随机选择采样方式 (各1/3概率)
choice = random.random()
if choice < 0.85:
if choice < 0.3333:
# 方式1: 靠近汉字的54个字符
return context[-54:] if context_len >= 54 else context
elif choice < 0.95:
elif choice < 0.6667:
# 方式2: 随机位置取46个连续字符
if context_len <= 46:
return context
start = random.randint(0, context_len - 46)
return context[start : start + 46] + context[-8:]
return context[start : start + 46]
else:
# 方式3: 12+6×7组合
if context_len < 12:

View File

@ -1,7 +1,6 @@
import pickle
from pathlib import Path
import torch
from loguru import logger
from torch.utils.data import DataLoader
from tqdm import tqdm

View File

@ -77,8 +77,9 @@ class ResidualBlock(nn.Module):
x = x + residual
return self.relu(x)
# ---------------------------- 专家网络 ----------------------------
# ---------------------------- 专家网络 ----------------------------
class Expert(nn.Module):
def __init__(
self,
@ -90,7 +91,7 @@ class Expert(nn.Module):
):
"""
input_dim : 输入维度
d_model : 专家内部维度
d_model : 专家内部维度固定 1024
output_multiplier : 输出维度 = input_dim * output_multiplier
dropout_prob : 残差块内 Dropout
"""
@ -120,8 +121,9 @@ class Expert(nn.Module):
x = block(x)
return self.output(x)
# ---------------------------- 主模型(MoE + 硬路由)------------------------
# ---------------------------- 主模型(MoE + 硬路由)------------------------
class MoEModel(nn.Module):
def __init__(
self,
@ -130,8 +132,8 @@ class MoEModel(nn.Module):
output_multiplier=2,
d_model=768,
num_resblocks=4,
num_domain_experts=20,
experts_dim=EXPORT_HIDE_DIM,
num_domain_experts=8,
num_shared_experts=1,
):
super().__init__()
self.output_multiplier = output_multiplier
@ -142,54 +144,60 @@ class MoEModel(nn.Module):
self.bert_config = bert.config
self.hidden_size = self.bert_config.hidden_size # BERT 隐层维度
self.device = None # 将在 to() 调用时设置
self.experts_dim = experts_dim
# 2. 4 层标准 Transformer Encoder(从 config 读取参数)
encoder_layer = nn.TransformerEncoderLayer(
d_model=self.hidden_size,
nhead=8,
nhead=self.bert_config.num_attention_heads,
dim_feedforward=self.bert_config.intermediate_size,
dropout=self.bert_config.hidden_dropout_prob,
activation="gelu",
batch_first=True,
norm_first=True, # Pre-LN,与预训练一致
)
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
self.res_blocks = nn.ModuleList([ResidualBlock(self.hidden_size) for _ in range(4)])
self.pooler = nn.AdaptiveAvgPool1d(1)
self.pooler = nn.AdaptiveAvgPool1d(2)
self.total_experts = 20
# 3. 专家层(8个领域专家 + 1个共享专家)
total_experts = num_domain_experts + num_shared_experts
self.experts = nn.ModuleList()
for i in range(self.total_experts):
for i in range(total_experts):
# 领域专家 dropout=0.1,共享专家 dropout=0.2(您指定的更强正则)
dropout_prob = 0.1 if i < num_domain_experts else 0.2
expert = Expert(
input_dim=self.hidden_size * 2,
d_model=self.experts_dim[i],
input_dim=self.hidden_size,
d_model=d_model,
num_resblocks=num_resblocks,
output_multiplier=self.output_multiplier, # 输出维度 = 2 * hidden_size
dropout_prob=0.1,
dropout_prob=dropout_prob,
)
self.experts.append(expert)
self.expert_bias = nn.Embedding(
self.total_experts, self.output_multiplier * self.hidden_size * 2
total_experts, self.output_multiplier * self.hidden_size
)
# 4. 分类头
self.classifier = nn.Sequential(
nn.LayerNorm(self.output_multiplier * self.hidden_size * 2),
nn.LayerNorm(self.output_multiplier * self.hidden_size),
nn.Linear(
self.output_multiplier * self.hidden_size * 2,
self.output_multiplier * self.hidden_size * 2,
self.output_multiplier * self.hidden_size,
self.output_multiplier * self.hidden_size,
),
nn.ReLU(inplace=True),
nn.Linear(
self.output_multiplier * self.hidden_size * 2,
self.output_multiplier * self.hidden_size * 4,
self.output_multiplier * self.hidden_size,
self.output_multiplier * self.hidden_size,
),
nn.ReLU(inplace=True),
nn.Dropout(0.1),
nn.Linear(self.output_multiplier * self.hidden_size * 4, num_classes),
nn.Linear(
self.output_multiplier * self.hidden_size,
self.output_multiplier * self.hidden_size * 2,
),
nn.ReLU(inplace=True),
nn.Dropout(0.2),
nn.Linear(self.output_multiplier * self.hidden_size * 2, num_classes),
)
# 可选:为领域专家和共享专家设置不同权重衰减(通过优化器实现,此处不处理)
@ -214,12 +222,8 @@ class MoEModel(nn.Module):
embeddings, src_key_padding_mask=padding_mask
) # [B, S, H]
for block in self.res_blocks:
pooled = block(encoded)
# ----- 3. 池化量 -----
pooled = self.pooler(encoded.transpose(1, 2)) # [B, H, 2]
pooled = pooled.flatten(1) # [B, H*2]
pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1)
# ----- 4. 专家路由(硬路由)-----
if torch.jit.is_tracing():
@ -259,54 +263,10 @@ class MoEModel(nn.Module):
expert_out = self.experts[7](pooled) + self.expert_bias(
torch.tensor(7, device=pooled.device)
)
elif group_id == 8: # group_id == 8
else: # group_id == 8
expert_out = self.experts[8](pooled) + self.expert_bias(
torch.tensor(8, device=pooled.device)
)
elif group_id == 9: # group_id == 9
expert_out = self.experts[9](pooled) + self.expert_bias(
torch.tensor(9, device=pooled.device)
)
elif group_id == 10: # group_id == 10
expert_out = self.experts[10](pooled) + self.expert_bias(
torch.tensor(10, device=pooled.device)
)
elif group_id == 11: # group_id == 11
expert_out = self.experts[11](pooled) + self.expert_bias(
torch.tensor(11, device=pooled.device)
)
elif group_id == 12: # group_id == 12
expert_out = self.experts[12](pooled) + self.expert_bias(
torch.tensor(12, device=pooled.device)
)
elif group_id == 13: # group_id == 13
expert_out = self.experts[13](pooled) + self.expert_bias(
torch.tensor(13, device=pooled.device)
)
elif group_id == 14: # group_id == 14
expert_out = self.experts[14](pooled) + self.expert_bias(
torch.tensor(14, device=pooled.device)
)
elif group_id == 15: # group_id == 15
expert_out = self.experts[15](pooled) + self.expert_bias(
torch.tensor(15, device=pooled.device)
)
elif group_id == 16: # group_id == 16
expert_out = self.experts[16](pooled) + self.expert_bias(
torch.tensor(16, device=pooled.device)
)
elif group_id == 17: # group_id == 17
expert_out = self.experts[17](pooled) + self.expert_bias(
torch.tensor(17, device=pooled.device)
)
elif group_id == 18: # group_id == 18
expert_out = self.experts[18](pooled) + self.expert_bias(
torch.tensor(18, device=pooled.device)
)
else: # group_id == 19
expert_out = self.experts[19](pooled) + self.expert_bias(
torch.tensor(19, device=pooled.device)
)
else:
batch_size = pooled.size(0)
# 并行计算所有专家输出
@ -394,8 +354,8 @@ class MoEModel(nn.Module):
# ------------------ 1. 提取并规范化输入 ------------------
# 判断是否为单样本input_ids 无 batch 维度)
input_ids = sample['hint']["input_ids"]
attention_mask = sample['hint']["attention_mask"]
input_ids = sample["hint"]["input_ids"]
attention_mask = sample["hint"]["attention_mask"]
pg = sample["pg"]
has_batch_dim = input_ids.dim() > 1
@ -521,7 +481,7 @@ class MoEModel(nn.Module):
# 默认优化器
if optimizer is None:
optimizer = optim.AdamW(self.parameters(), lr=lr)
if criterion is None:
if loss_weight is not None:
criterion = nn.CrossEntropyLoss(weight=loss_weight)
@ -625,39 +585,3 @@ class MoEModel(nn.Module):
for name, param in self.named_parameters():
if name in freeze_layers:
param.requires_grad = False
# ============================ 使用示例 ============================
if __name__ == "__main__":
# 1. 初始化模型
model = MoEModel()
model.eval()
# 2. 构造 dummy 输入(batch=1)用于导出 ONNX
dummy_input_ids = torch.randint(0, 100, (1, 64)) # [1, 64]
dummy_attention_mask = torch.ones_like(dummy_input_ids) # [1, 64]
dummy_pg = torch.tensor(3, dtype=torch.long) # 标量 group_id
# 3. 导出 ONNX(使用条件分支仅计算一个专家)
torch.onnx.export(
model,
(dummy_input_ids, dummy_attention_mask, dummy_pg),
"moe_cpu.onnx",
input_names=["input_ids", "attention_mask", "pg"],
output_names=["logits"],
dynamic_axes={ # 固定 batch=1,可不设 dynamic_axes
"input_ids": {0: "batch"},
"attention_mask": {0: "batch"},
},
opset_version=12,
do_constant_folding=True,
)
print("ONNX 导出成功!")
# 4. 测试训练模式batch=4
model.train()
batch_input_ids = torch.randint(0, 100, (4, 64))
batch_attention_mask = torch.ones_like(batch_input_ids)
batch_pg = torch.tensor([0, 3, 8, 1], dtype=torch.long) # 不同 group
logits = model(batch_input_ids, batch_attention_mask, batch_pg)
print("训练模式输出形状:", logits.shape) # [4, 10018]