调整拼音分组与采样逻辑,优化模型结构及专家路由策略
This commit is contained in:
parent
917c9f4256
commit
8f58917d13
|
|
@ -97,28 +97,28 @@ class PinyinInputDataset(IterableDataset):
|
|||
# 加载拼音分组
|
||||
self.pg_groups = {
|
||||
"y": 0,
|
||||
"z": 1,
|
||||
"j": 2,
|
||||
"l": 3,
|
||||
"s": 4,
|
||||
"x": 5,
|
||||
"c": 6,
|
||||
"k": 0,
|
||||
"e": 0,
|
||||
"l": 1,
|
||||
"w": 1,
|
||||
"f": 1,
|
||||
"q": 2,
|
||||
"a": 2,
|
||||
"s": 2,
|
||||
"x": 3,
|
||||
"b": 3,
|
||||
"r": 3,
|
||||
"o": 4,
|
||||
"m": 4,
|
||||
"z": 4,
|
||||
"g": 5,
|
||||
"n": 5,
|
||||
"c": 5,
|
||||
"t": 6,
|
||||
"p": 6,
|
||||
"d": 6,
|
||||
"j": 7,
|
||||
"h": 7,
|
||||
"d": 8,
|
||||
"b": 9,
|
||||
"q": 10,
|
||||
"g": 11,
|
||||
"t": 12,
|
||||
"m": 13,
|
||||
"p": 14,
|
||||
"w": 15,
|
||||
"f": 16,
|
||||
"k": 17,
|
||||
"n": 18,
|
||||
"r": 19,
|
||||
"a": 19,
|
||||
"e": 18,
|
||||
"o": 17,
|
||||
}
|
||||
|
||||
def get_next_chinese_chars(
|
||||
|
|
@ -184,15 +184,15 @@ class PinyinInputDataset(IterableDataset):
|
|||
# 随机选择采样方式 (各1/3概率)
|
||||
choice = random.random()
|
||||
|
||||
if choice < 0.85:
|
||||
if choice < 0.3333:
|
||||
# 方式1: 靠近汉字的54个字符
|
||||
return context[-54:] if context_len >= 54 else context
|
||||
elif choice < 0.95:
|
||||
elif choice < 0.6667:
|
||||
# 方式2: 随机位置取46个连续字符
|
||||
if context_len <= 46:
|
||||
return context
|
||||
start = random.randint(0, context_len - 46)
|
||||
return context[start : start + 46] + context[-8:]
|
||||
return context[start : start + 46]
|
||||
else:
|
||||
# 方式3: 12+6×7组合
|
||||
if context_len < 12:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
|
|
|||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -77,8 +77,9 @@ class ResidualBlock(nn.Module):
|
|||
x = x + residual
|
||||
return self.relu(x)
|
||||
|
||||
# ---------------------------- 专家网络 ----------------------------
|
||||
|
||||
|
||||
# ---------------------------- 专家网络 ----------------------------
|
||||
class Expert(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -90,7 +91,7 @@ class Expert(nn.Module):
|
|||
):
|
||||
"""
|
||||
input_dim : 输入维度
|
||||
d_model : 专家内部维度
|
||||
d_model : 专家内部维度(固定 1024)
|
||||
output_multiplier : 输出维度 = input_dim * output_multiplier
|
||||
dropout_prob : 残差块内 Dropout
|
||||
"""
|
||||
|
|
@ -120,8 +121,9 @@ class Expert(nn.Module):
|
|||
x = block(x)
|
||||
return self.output(x)
|
||||
|
||||
# ---------------------------- 主模型(MoE + 硬路由)------------------------
|
||||
|
||||
|
||||
# ---------------------------- 主模型(MoE + 硬路由)------------------------
|
||||
class MoEModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -130,8 +132,8 @@ class MoEModel(nn.Module):
|
|||
output_multiplier=2,
|
||||
d_model=768,
|
||||
num_resblocks=4,
|
||||
num_domain_experts=20,
|
||||
experts_dim=EXPORT_HIDE_DIM,
|
||||
num_domain_experts=8,
|
||||
num_shared_experts=1,
|
||||
):
|
||||
super().__init__()
|
||||
self.output_multiplier = output_multiplier
|
||||
|
|
@ -142,54 +144,60 @@ class MoEModel(nn.Module):
|
|||
self.bert_config = bert.config
|
||||
self.hidden_size = self.bert_config.hidden_size # BERT 隐层维度
|
||||
self.device = None # 将在 to() 调用时设置
|
||||
self.experts_dim = experts_dim
|
||||
|
||||
# 2. 4 层标准 Transformer Encoder(从 config 读取参数)
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=self.hidden_size,
|
||||
nhead=8,
|
||||
nhead=self.bert_config.num_attention_heads,
|
||||
dim_feedforward=self.bert_config.intermediate_size,
|
||||
dropout=self.bert_config.hidden_dropout_prob,
|
||||
activation="gelu",
|
||||
batch_first=True,
|
||||
norm_first=True, # Pre-LN,与预训练一致
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
|
||||
self.res_blocks = nn.ModuleList([ResidualBlock(self.hidden_size) for _ in range(4)])
|
||||
self.pooler = nn.AdaptiveAvgPool1d(1)
|
||||
|
||||
self.pooler = nn.AdaptiveAvgPool1d(2)
|
||||
|
||||
self.total_experts = 20
|
||||
# 3. 专家层:8个领域专家 + 1个共享专家
|
||||
total_experts = num_domain_experts + num_shared_experts
|
||||
self.experts = nn.ModuleList()
|
||||
|
||||
for i in range(self.total_experts):
|
||||
for i in range(total_experts):
|
||||
# 领域专家 dropout=0.1,共享专家 dropout=0.2(您指定的更强正则)
|
||||
dropout_prob = 0.1 if i < num_domain_experts else 0.2
|
||||
expert = Expert(
|
||||
input_dim=self.hidden_size * 2,
|
||||
d_model=self.experts_dim[i],
|
||||
input_dim=self.hidden_size,
|
||||
d_model=d_model,
|
||||
num_resblocks=num_resblocks,
|
||||
output_multiplier=self.output_multiplier, # 输出维度 = 2 * hidden_size
|
||||
dropout_prob=0.1,
|
||||
dropout_prob=dropout_prob,
|
||||
)
|
||||
self.experts.append(expert)
|
||||
|
||||
self.expert_bias = nn.Embedding(
|
||||
self.total_experts, self.output_multiplier * self.hidden_size * 2
|
||||
total_experts, self.output_multiplier * self.hidden_size
|
||||
)
|
||||
|
||||
# 4. 分类头
|
||||
self.classifier = nn.Sequential(
|
||||
nn.LayerNorm(self.output_multiplier * self.hidden_size * 2),
|
||||
nn.LayerNorm(self.output_multiplier * self.hidden_size),
|
||||
nn.Linear(
|
||||
self.output_multiplier * self.hidden_size * 2,
|
||||
self.output_multiplier * self.hidden_size * 2,
|
||||
self.output_multiplier * self.hidden_size,
|
||||
self.output_multiplier * self.hidden_size,
|
||||
),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(
|
||||
self.output_multiplier * self.hidden_size * 2,
|
||||
self.output_multiplier * self.hidden_size * 4,
|
||||
self.output_multiplier * self.hidden_size,
|
||||
self.output_multiplier * self.hidden_size,
|
||||
),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(self.output_multiplier * self.hidden_size * 4, num_classes),
|
||||
nn.Linear(
|
||||
self.output_multiplier * self.hidden_size,
|
||||
self.output_multiplier * self.hidden_size * 2,
|
||||
),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(self.output_multiplier * self.hidden_size * 2, num_classes),
|
||||
)
|
||||
# 可选:为领域专家和共享专家设置不同权重衰减(通过优化器实现,此处不处理)
|
||||
|
||||
|
|
@ -214,12 +222,8 @@ class MoEModel(nn.Module):
|
|||
embeddings, src_key_padding_mask=padding_mask
|
||||
) # [B, S, H]
|
||||
|
||||
for block in self.res_blocks:
|
||||
pooled = block(encoded)
|
||||
|
||||
# ----- 3. 池化量 -----
|
||||
pooled = self.pooler(encoded.transpose(1, 2)) # [B, H, 2]
|
||||
pooled = pooled.flatten(1) # [B, H*2]
|
||||
pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1)
|
||||
|
||||
# ----- 4. 专家路由(硬路由)-----
|
||||
if torch.jit.is_tracing():
|
||||
|
|
@ -259,54 +263,10 @@ class MoEModel(nn.Module):
|
|||
expert_out = self.experts[7](pooled) + self.expert_bias(
|
||||
torch.tensor(7, device=pooled.device)
|
||||
)
|
||||
elif group_id == 8: # group_id == 8
|
||||
else: # group_id == 8
|
||||
expert_out = self.experts[8](pooled) + self.expert_bias(
|
||||
torch.tensor(8, device=pooled.device)
|
||||
)
|
||||
elif group_id == 9: # group_id == 9
|
||||
expert_out = self.experts[9](pooled) + self.expert_bias(
|
||||
torch.tensor(9, device=pooled.device)
|
||||
)
|
||||
elif group_id == 10: # group_id == 10
|
||||
expert_out = self.experts[10](pooled) + self.expert_bias(
|
||||
torch.tensor(10, device=pooled.device)
|
||||
)
|
||||
elif group_id == 11: # group_id == 11
|
||||
expert_out = self.experts[11](pooled) + self.expert_bias(
|
||||
torch.tensor(11, device=pooled.device)
|
||||
)
|
||||
elif group_id == 12: # group_id == 12
|
||||
expert_out = self.experts[12](pooled) + self.expert_bias(
|
||||
torch.tensor(12, device=pooled.device)
|
||||
)
|
||||
elif group_id == 13: # group_id == 13
|
||||
expert_out = self.experts[13](pooled) + self.expert_bias(
|
||||
torch.tensor(13, device=pooled.device)
|
||||
)
|
||||
elif group_id == 14: # group_id == 14
|
||||
expert_out = self.experts[14](pooled) + self.expert_bias(
|
||||
torch.tensor(14, device=pooled.device)
|
||||
)
|
||||
elif group_id == 15: # group_id == 15
|
||||
expert_out = self.experts[15](pooled) + self.expert_bias(
|
||||
torch.tensor(15, device=pooled.device)
|
||||
)
|
||||
elif group_id == 16: # group_id == 16
|
||||
expert_out = self.experts[16](pooled) + self.expert_bias(
|
||||
torch.tensor(16, device=pooled.device)
|
||||
)
|
||||
elif group_id == 17: # group_id == 17
|
||||
expert_out = self.experts[17](pooled) + self.expert_bias(
|
||||
torch.tensor(17, device=pooled.device)
|
||||
)
|
||||
elif group_id == 18: # group_id == 18
|
||||
expert_out = self.experts[18](pooled) + self.expert_bias(
|
||||
torch.tensor(18, device=pooled.device)
|
||||
)
|
||||
else: # group_id == 19
|
||||
expert_out = self.experts[19](pooled) + self.expert_bias(
|
||||
torch.tensor(19, device=pooled.device)
|
||||
)
|
||||
else:
|
||||
batch_size = pooled.size(0)
|
||||
# 并行计算所有专家输出
|
||||
|
|
@ -394,8 +354,8 @@ class MoEModel(nn.Module):
|
|||
|
||||
# ------------------ 1. 提取并规范化输入 ------------------
|
||||
# 判断是否为单样本(input_ids 无 batch 维度)
|
||||
input_ids = sample['hint']["input_ids"]
|
||||
attention_mask = sample['hint']["attention_mask"]
|
||||
input_ids = sample["hint"]["input_ids"]
|
||||
attention_mask = sample["hint"]["attention_mask"]
|
||||
pg = sample["pg"]
|
||||
has_batch_dim = input_ids.dim() > 1
|
||||
|
||||
|
|
@ -625,39 +585,3 @@ class MoEModel(nn.Module):
|
|||
for name, param in self.named_parameters():
|
||||
if name in freeze_layers:
|
||||
param.requires_grad = False
|
||||
|
||||
|
||||
# ============================ 使用示例 ============================
|
||||
if __name__ == "__main__":
|
||||
# 1. 初始化模型
|
||||
model = MoEModel()
|
||||
model.eval()
|
||||
|
||||
# 2. 构造 dummy 输入(batch=1,用于导出 ONNX)
|
||||
dummy_input_ids = torch.randint(0, 100, (1, 64)) # [1, 64]
|
||||
dummy_attention_mask = torch.ones_like(dummy_input_ids) # [1, 64]
|
||||
dummy_pg = torch.tensor(3, dtype=torch.long) # 标量 group_id
|
||||
|
||||
# 3. 导出 ONNX(使用条件分支,仅计算一个专家)
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(dummy_input_ids, dummy_attention_mask, dummy_pg),
|
||||
"moe_cpu.onnx",
|
||||
input_names=["input_ids", "attention_mask", "pg"],
|
||||
output_names=["logits"],
|
||||
dynamic_axes={ # 固定 batch=1,可不设 dynamic_axes
|
||||
"input_ids": {0: "batch"},
|
||||
"attention_mask": {0: "batch"},
|
||||
},
|
||||
opset_version=12,
|
||||
do_constant_folding=True,
|
||||
)
|
||||
print("ONNX 导出成功!")
|
||||
|
||||
# 4. 测试训练模式(batch=4)
|
||||
model.train()
|
||||
batch_input_ids = torch.randint(0, 100, (4, 64))
|
||||
batch_attention_mask = torch.ones_like(batch_input_ids)
|
||||
batch_pg = torch.tensor([0, 3, 8, 1], dtype=torch.long) # 不同 group
|
||||
logits = model(batch_input_ids, batch_attention_mask, batch_pg)
|
||||
print("训练模式输出形状:", logits.shape) # [4, 10018]
|
||||
|
|
|
|||
Loading…
Reference in New Issue