调整拼音分组与采样逻辑,优化模型结构及专家路由策略

This commit is contained in:
songsenand 2026-02-21 21:55:55 +08:00
parent 917c9f4256
commit 8f58917d13
8 changed files with 60 additions and 137 deletions

View File

@ -97,28 +97,28 @@ class PinyinInputDataset(IterableDataset):
# 加载拼音分组 # 加载拼音分组
self.pg_groups = { self.pg_groups = {
"y": 0, "y": 0,
"z": 1, "k": 0,
"j": 2, "e": 0,
"l": 3, "l": 1,
"s": 4, "w": 1,
"x": 5, "f": 1,
"c": 6, "q": 2,
"a": 2,
"s": 2,
"x": 3,
"b": 3,
"r": 3,
"o": 4,
"m": 4,
"z": 4,
"g": 5,
"n": 5,
"c": 5,
"t": 6,
"p": 6,
"d": 6,
"j": 7,
"h": 7, "h": 7,
"d": 8,
"b": 9,
"q": 10,
"g": 11,
"t": 12,
"m": 13,
"p": 14,
"w": 15,
"f": 16,
"k": 17,
"n": 18,
"r": 19,
"a": 19,
"e": 18,
"o": 17,
} }
def get_next_chinese_chars( def get_next_chinese_chars(
@ -184,15 +184,15 @@ class PinyinInputDataset(IterableDataset):
# 随机选择采样方式 (各1/3概率) # 随机选择采样方式 (各1/3概率)
choice = random.random() choice = random.random()
if choice < 0.85: if choice < 0.3333:
# 方式1: 靠近汉字的54个字符 # 方式1: 靠近汉字的54个字符
return context[-54:] if context_len >= 54 else context return context[-54:] if context_len >= 54 else context
elif choice < 0.95: elif choice < 0.6667:
# 方式2: 随机位置取46个连续字符 # 方式2: 随机位置取46个连续字符
if context_len <= 46: if context_len <= 46:
return context return context
start = random.randint(0, context_len - 46) start = random.randint(0, context_len - 46)
return context[start : start + 46] + context[-8:] return context[start : start + 46]
else: else:
# 方式3: 12+6×7组合 # 方式3: 12+6×7组合
if context_len < 12: if context_len < 12:

View File

@ -1,7 +1,6 @@
import pickle import pickle
from pathlib import Path from pathlib import Path
import torch
from loguru import logger from loguru import logger
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tqdm import tqdm from tqdm import tqdm

View File

@ -77,8 +77,9 @@ class ResidualBlock(nn.Module):
x = x + residual x = x + residual
return self.relu(x) return self.relu(x)
# ---------------------------- 专家网络 ----------------------------
# ---------------------------- 专家网络 ----------------------------
class Expert(nn.Module): class Expert(nn.Module):
def __init__( def __init__(
self, self,
@ -90,7 +91,7 @@ class Expert(nn.Module):
): ):
""" """
input_dim : 输入维度 input_dim : 输入维度
d_model : 专家内部维度 d_model : 专家内部维度(固定 1024)
output_multiplier : 输出维度 = input_dim * output_multiplier output_multiplier : 输出维度 = input_dim * output_multiplier
dropout_prob : 残差块内 Dropout dropout_prob : 残差块内 Dropout
""" """
@ -120,8 +121,9 @@ class Expert(nn.Module):
x = block(x) x = block(x)
return self.output(x) return self.output(x)
# ---------------------------- 主模型(MoE + 硬路由)------------------------
# ---------------------------- 主模型(MoE + 硬路由)------------------------
class MoEModel(nn.Module): class MoEModel(nn.Module):
def __init__( def __init__(
self, self,
@ -130,8 +132,8 @@ class MoEModel(nn.Module):
output_multiplier=2, output_multiplier=2,
d_model=768, d_model=768,
num_resblocks=4, num_resblocks=4,
num_domain_experts=20, num_domain_experts=8,
experts_dim=EXPORT_HIDE_DIM, num_shared_experts=1,
): ):
super().__init__() super().__init__()
self.output_multiplier = output_multiplier self.output_multiplier = output_multiplier
@ -142,54 +144,60 @@ class MoEModel(nn.Module):
self.bert_config = bert.config self.bert_config = bert.config
self.hidden_size = self.bert_config.hidden_size # BERT 隐层维度 self.hidden_size = self.bert_config.hidden_size # BERT 隐层维度
self.device = None # 将在 to() 调用时设置 self.device = None # 将在 to() 调用时设置
self.experts_dim = experts_dim
# 2. 4 层标准 Transformer Encoder(从 config 读取参数) # 2. 4 层标准 Transformer Encoder(从 config 读取参数)
encoder_layer = nn.TransformerEncoderLayer( encoder_layer = nn.TransformerEncoderLayer(
d_model=self.hidden_size, d_model=self.hidden_size,
nhead=8, nhead=self.bert_config.num_attention_heads,
dim_feedforward=self.bert_config.intermediate_size, dim_feedforward=self.bert_config.intermediate_size,
dropout=self.bert_config.hidden_dropout_prob, dropout=self.bert_config.hidden_dropout_prob,
activation="gelu", activation="gelu",
batch_first=True, batch_first=True,
norm_first=True, # Pre-LN,与预训练一致
) )
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4) self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
self.res_blocks = nn.ModuleList([ResidualBlock(self.hidden_size) for _ in range(4)]) self.pooler = nn.AdaptiveAvgPool1d(1)
self.pooler = nn.AdaptiveAvgPool1d(2) # 3. 专家层:8个领域专家 + 1个共享专家
total_experts = num_domain_experts + num_shared_experts
self.total_experts = 20
self.experts = nn.ModuleList() self.experts = nn.ModuleList()
for i in range(self.total_experts): for i in range(total_experts):
# 领域专家 dropout=0.1,共享专家 dropout=0.2(您指定的更强正则)
dropout_prob = 0.1 if i < num_domain_experts else 0.2
expert = Expert( expert = Expert(
input_dim=self.hidden_size * 2, input_dim=self.hidden_size,
d_model=self.experts_dim[i], d_model=d_model,
num_resblocks=num_resblocks, num_resblocks=num_resblocks,
output_multiplier=self.output_multiplier, # 输出维度 = 2 * hidden_size output_multiplier=self.output_multiplier, # 输出维度 = 2 * hidden_size
dropout_prob=0.1, dropout_prob=dropout_prob,
) )
self.experts.append(expert) self.experts.append(expert)
self.expert_bias = nn.Embedding( self.expert_bias = nn.Embedding(
self.total_experts, self.output_multiplier * self.hidden_size * 2 total_experts, self.output_multiplier * self.hidden_size
) )
# 4. 分类头 # 4. 分类头
self.classifier = nn.Sequential( self.classifier = nn.Sequential(
nn.LayerNorm(self.output_multiplier * self.hidden_size * 2), nn.LayerNorm(self.output_multiplier * self.hidden_size),
nn.Linear( nn.Linear(
self.output_multiplier * self.hidden_size * 2, self.output_multiplier * self.hidden_size,
self.output_multiplier * self.hidden_size * 2, self.output_multiplier * self.hidden_size,
), ),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Linear( nn.Linear(
self.output_multiplier * self.hidden_size * 2, self.output_multiplier * self.hidden_size,
self.output_multiplier * self.hidden_size * 4, self.output_multiplier * self.hidden_size,
), ),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Dropout(0.1), nn.Linear(
nn.Linear(self.output_multiplier * self.hidden_size * 4, num_classes), self.output_multiplier * self.hidden_size,
self.output_multiplier * self.hidden_size * 2,
),
nn.ReLU(inplace=True),
nn.Dropout(0.2),
nn.Linear(self.output_multiplier * self.hidden_size * 2, num_classes),
) )
# 可选:为领域专家和共享专家设置不同权重衰减(通过优化器实现,此处不处理) # 可选:为领域专家和共享专家设置不同权重衰减(通过优化器实现,此处不处理)
@ -214,12 +222,8 @@ class MoEModel(nn.Module):
embeddings, src_key_padding_mask=padding_mask embeddings, src_key_padding_mask=padding_mask
) # [B, S, H] ) # [B, S, H]
for block in self.res_blocks:
pooled = block(encoded)
# ----- 3. 池化量 ----- # ----- 3. 池化量 -----
pooled = self.pooler(encoded.transpose(1, 2)) # [B, H, 2] pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1)
pooled = pooled.flatten(1) # [B, H*2]
# ----- 4. 专家路由(硬路由)----- # ----- 4. 专家路由(硬路由)-----
if torch.jit.is_tracing(): if torch.jit.is_tracing():
@ -259,54 +263,10 @@ class MoEModel(nn.Module):
expert_out = self.experts[7](pooled) + self.expert_bias( expert_out = self.experts[7](pooled) + self.expert_bias(
torch.tensor(7, device=pooled.device) torch.tensor(7, device=pooled.device)
) )
elif group_id == 8: # group_id == 8 else: # group_id == 8
expert_out = self.experts[8](pooled) + self.expert_bias( expert_out = self.experts[8](pooled) + self.expert_bias(
torch.tensor(8, device=pooled.device) torch.tensor(8, device=pooled.device)
) )
elif group_id == 9: # group_id == 9
expert_out = self.experts[9](pooled) + self.expert_bias(
torch.tensor(9, device=pooled.device)
)
elif group_id == 10: # group_id == 10
expert_out = self.experts[10](pooled) + self.expert_bias(
torch.tensor(10, device=pooled.device)
)
elif group_id == 11: # group_id == 11
expert_out = self.experts[11](pooled) + self.expert_bias(
torch.tensor(11, device=pooled.device)
)
elif group_id == 12: # group_id == 12
expert_out = self.experts[12](pooled) + self.expert_bias(
torch.tensor(12, device=pooled.device)
)
elif group_id == 13: # group_id == 13
expert_out = self.experts[13](pooled) + self.expert_bias(
torch.tensor(13, device=pooled.device)
)
elif group_id == 14: # group_id == 14
expert_out = self.experts[14](pooled) + self.expert_bias(
torch.tensor(14, device=pooled.device)
)
elif group_id == 15: # group_id == 15
expert_out = self.experts[15](pooled) + self.expert_bias(
torch.tensor(15, device=pooled.device)
)
elif group_id == 16: # group_id == 16
expert_out = self.experts[16](pooled) + self.expert_bias(
torch.tensor(16, device=pooled.device)
)
elif group_id == 17: # group_id == 17
expert_out = self.experts[17](pooled) + self.expert_bias(
torch.tensor(17, device=pooled.device)
)
elif group_id == 18: # group_id == 18
expert_out = self.experts[18](pooled) + self.expert_bias(
torch.tensor(18, device=pooled.device)
)
else: # group_id == 19
expert_out = self.experts[19](pooled) + self.expert_bias(
torch.tensor(19, device=pooled.device)
)
else: else:
batch_size = pooled.size(0) batch_size = pooled.size(0)
# 并行计算所有专家输出 # 并行计算所有专家输出
@ -394,8 +354,8 @@ class MoEModel(nn.Module):
# ------------------ 1. 提取并规范化输入 ------------------ # ------------------ 1. 提取并规范化输入 ------------------
# 判断是否为单样本input_ids 无 batch 维度) # 判断是否为单样本input_ids 无 batch 维度)
input_ids = sample['hint']["input_ids"] input_ids = sample["hint"]["input_ids"]
attention_mask = sample['hint']["attention_mask"] attention_mask = sample["hint"]["attention_mask"]
pg = sample["pg"] pg = sample["pg"]
has_batch_dim = input_ids.dim() > 1 has_batch_dim = input_ids.dim() > 1
@ -521,7 +481,7 @@ class MoEModel(nn.Module):
# 默认优化器 # 默认优化器
if optimizer is None: if optimizer is None:
optimizer = optim.AdamW(self.parameters(), lr=lr) optimizer = optim.AdamW(self.parameters(), lr=lr)
if criterion is None: if criterion is None:
if loss_weight is not None: if loss_weight is not None:
criterion = nn.CrossEntropyLoss(weight=loss_weight) criterion = nn.CrossEntropyLoss(weight=loss_weight)
@ -625,39 +585,3 @@ class MoEModel(nn.Module):
for name, param in self.named_parameters(): for name, param in self.named_parameters():
if name in freeze_layers: if name in freeze_layers:
param.requires_grad = False param.requires_grad = False
# ============================ 使用示例 ============================
if __name__ == "__main__":
# 1. 初始化模型
model = MoEModel()
model.eval()
# 2. 构造 dummy 输入batch=1用于导出 ONNX
dummy_input_ids = torch.randint(0, 100, (1, 64)) # [1, 64]
dummy_attention_mask = torch.ones_like(dummy_input_ids) # [1, 64]
dummy_pg = torch.tensor(3, dtype=torch.long) # 标量 group_id
# 3. 导出 ONNX使用条件分支仅计算一个专家
torch.onnx.export(
model,
(dummy_input_ids, dummy_attention_mask, dummy_pg),
"moe_cpu.onnx",
input_names=["input_ids", "attention_mask", "pg"],
output_names=["logits"],
dynamic_axes={ # 固定 batch=1可不设 dynamic_axes
"input_ids": {0: "batch"},
"attention_mask": {0: "batch"},
},
opset_version=12,
do_constant_folding=True,
)
print("ONNX 导出成功!")
# 4. 测试训练模式batch=4
model.train()
batch_input_ids = torch.randint(0, 100, (4, 64))
batch_attention_mask = torch.ones_like(batch_input_ids)
batch_pg = torch.tensor([0, 3, 8, 1], dtype=torch.long) # 不同 group
logits = model(batch_input_ids, batch_attention_mask, batch_pg)
print("训练模式输出形状:", logits.shape) # [4, 10018]