调整拼音分组与采样逻辑,优化模型结构及专家路由策略
This commit is contained in:
parent
917c9f4256
commit
8f58917d13
|
|
@ -97,28 +97,28 @@ class PinyinInputDataset(IterableDataset):
|
||||||
# 加载拼音分组
|
# 加载拼音分组
|
||||||
self.pg_groups = {
|
self.pg_groups = {
|
||||||
"y": 0,
|
"y": 0,
|
||||||
"z": 1,
|
"k": 0,
|
||||||
"j": 2,
|
"e": 0,
|
||||||
"l": 3,
|
"l": 1,
|
||||||
"s": 4,
|
"w": 1,
|
||||||
"x": 5,
|
"f": 1,
|
||||||
"c": 6,
|
"q": 2,
|
||||||
|
"a": 2,
|
||||||
|
"s": 2,
|
||||||
|
"x": 3,
|
||||||
|
"b": 3,
|
||||||
|
"r": 3,
|
||||||
|
"o": 4,
|
||||||
|
"m": 4,
|
||||||
|
"z": 4,
|
||||||
|
"g": 5,
|
||||||
|
"n": 5,
|
||||||
|
"c": 5,
|
||||||
|
"t": 6,
|
||||||
|
"p": 6,
|
||||||
|
"d": 6,
|
||||||
|
"j": 7,
|
||||||
"h": 7,
|
"h": 7,
|
||||||
"d": 8,
|
|
||||||
"b": 9,
|
|
||||||
"q": 10,
|
|
||||||
"g": 11,
|
|
||||||
"t": 12,
|
|
||||||
"m": 13,
|
|
||||||
"p": 14,
|
|
||||||
"w": 15,
|
|
||||||
"f": 16,
|
|
||||||
"k": 17,
|
|
||||||
"n": 18,
|
|
||||||
"r": 19,
|
|
||||||
"a": 19,
|
|
||||||
"e": 18,
|
|
||||||
"o": 17,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_next_chinese_chars(
|
def get_next_chinese_chars(
|
||||||
|
|
@ -184,15 +184,15 @@ class PinyinInputDataset(IterableDataset):
|
||||||
# 随机选择采样方式 (各1/3概率)
|
# 随机选择采样方式 (各1/3概率)
|
||||||
choice = random.random()
|
choice = random.random()
|
||||||
|
|
||||||
if choice < 0.85:
|
if choice < 0.3333:
|
||||||
# 方式1: 靠近汉字的54个字符
|
# 方式1: 靠近汉字的54个字符
|
||||||
return context[-54:] if context_len >= 54 else context
|
return context[-54:] if context_len >= 54 else context
|
||||||
elif choice < 0.95:
|
elif choice < 0.6667:
|
||||||
# 方式2: 随机位置取46个连续字符
|
# 方式2: 随机位置取46个连续字符
|
||||||
if context_len <= 46:
|
if context_len <= 46:
|
||||||
return context
|
return context
|
||||||
start = random.randint(0, context_len - 46)
|
start = random.randint(0, context_len - 46)
|
||||||
return context[start : start + 46] + context[-8:]
|
return context[start : start + 46]
|
||||||
else:
|
else:
|
||||||
# 方式3: 12+6×7组合
|
# 方式3: 12+6×7组合
|
||||||
if context_len < 12:
|
if context_len < 12:
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
import pickle
|
import pickle
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -77,8 +77,9 @@ class ResidualBlock(nn.Module):
|
||||||
x = x + residual
|
x = x + residual
|
||||||
return self.relu(x)
|
return self.relu(x)
|
||||||
|
|
||||||
|
# ---------------------------- 专家网络 ----------------------------
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------- 专家网络 ----------------------------
|
|
||||||
class Expert(nn.Module):
|
class Expert(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -90,7 +91,7 @@ class Expert(nn.Module):
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
input_dim : 输入维度
|
input_dim : 输入维度
|
||||||
d_model : 专家内部维度
|
d_model : 专家内部维度(固定 1024)
|
||||||
output_multiplier : 输出维度 = input_dim * output_multiplier
|
output_multiplier : 输出维度 = input_dim * output_multiplier
|
||||||
dropout_prob : 残差块内 Dropout
|
dropout_prob : 残差块内 Dropout
|
||||||
"""
|
"""
|
||||||
|
|
@ -120,8 +121,9 @@ class Expert(nn.Module):
|
||||||
x = block(x)
|
x = block(x)
|
||||||
return self.output(x)
|
return self.output(x)
|
||||||
|
|
||||||
|
# ---------------------------- 主模型(MoE + 硬路由)------------------------
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------- 主模型(MoE + 硬路由)------------------------
|
|
||||||
class MoEModel(nn.Module):
|
class MoEModel(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -130,8 +132,8 @@ class MoEModel(nn.Module):
|
||||||
output_multiplier=2,
|
output_multiplier=2,
|
||||||
d_model=768,
|
d_model=768,
|
||||||
num_resblocks=4,
|
num_resblocks=4,
|
||||||
num_domain_experts=20,
|
num_domain_experts=8,
|
||||||
experts_dim=EXPORT_HIDE_DIM,
|
num_shared_experts=1,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.output_multiplier = output_multiplier
|
self.output_multiplier = output_multiplier
|
||||||
|
|
@ -142,54 +144,60 @@ class MoEModel(nn.Module):
|
||||||
self.bert_config = bert.config
|
self.bert_config = bert.config
|
||||||
self.hidden_size = self.bert_config.hidden_size # BERT 隐层维度
|
self.hidden_size = self.bert_config.hidden_size # BERT 隐层维度
|
||||||
self.device = None # 将在 to() 调用时设置
|
self.device = None # 将在 to() 调用时设置
|
||||||
self.experts_dim = experts_dim
|
|
||||||
|
|
||||||
# 2. 4 层标准 Transformer Encoder(从 config 读取参数)
|
# 2. 4 层标准 Transformer Encoder(从 config 读取参数)
|
||||||
encoder_layer = nn.TransformerEncoderLayer(
|
encoder_layer = nn.TransformerEncoderLayer(
|
||||||
d_model=self.hidden_size,
|
d_model=self.hidden_size,
|
||||||
nhead=8,
|
nhead=self.bert_config.num_attention_heads,
|
||||||
dim_feedforward=self.bert_config.intermediate_size,
|
dim_feedforward=self.bert_config.intermediate_size,
|
||||||
dropout=self.bert_config.hidden_dropout_prob,
|
dropout=self.bert_config.hidden_dropout_prob,
|
||||||
activation="gelu",
|
activation="gelu",
|
||||||
batch_first=True,
|
batch_first=True,
|
||||||
|
norm_first=True, # Pre-LN,与预训练一致
|
||||||
)
|
)
|
||||||
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
|
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
|
||||||
self.res_blocks = nn.ModuleList([ResidualBlock(self.hidden_size) for _ in range(4)])
|
self.pooler = nn.AdaptiveAvgPool1d(1)
|
||||||
|
|
||||||
self.pooler = nn.AdaptiveAvgPool1d(2)
|
# 3. 专家层:8个领域专家 + 1个共享专家
|
||||||
|
total_experts = num_domain_experts + num_shared_experts
|
||||||
self.total_experts = 20
|
|
||||||
self.experts = nn.ModuleList()
|
self.experts = nn.ModuleList()
|
||||||
|
|
||||||
for i in range(self.total_experts):
|
for i in range(total_experts):
|
||||||
|
# 领域专家 dropout=0.1,共享专家 dropout=0.2(您指定的更强正则)
|
||||||
|
dropout_prob = 0.1 if i < num_domain_experts else 0.2
|
||||||
expert = Expert(
|
expert = Expert(
|
||||||
input_dim=self.hidden_size * 2,
|
input_dim=self.hidden_size,
|
||||||
d_model=self.experts_dim[i],
|
d_model=d_model,
|
||||||
num_resblocks=num_resblocks,
|
num_resblocks=num_resblocks,
|
||||||
output_multiplier=self.output_multiplier, # 输出维度 = 2 * hidden_size
|
output_multiplier=self.output_multiplier, # 输出维度 = 2 * hidden_size
|
||||||
dropout_prob=0.1,
|
dropout_prob=dropout_prob,
|
||||||
)
|
)
|
||||||
self.experts.append(expert)
|
self.experts.append(expert)
|
||||||
|
|
||||||
self.expert_bias = nn.Embedding(
|
self.expert_bias = nn.Embedding(
|
||||||
self.total_experts, self.output_multiplier * self.hidden_size * 2
|
total_experts, self.output_multiplier * self.hidden_size
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4. 分类头
|
# 4. 分类头
|
||||||
self.classifier = nn.Sequential(
|
self.classifier = nn.Sequential(
|
||||||
nn.LayerNorm(self.output_multiplier * self.hidden_size * 2),
|
nn.LayerNorm(self.output_multiplier * self.hidden_size),
|
||||||
nn.Linear(
|
nn.Linear(
|
||||||
self.output_multiplier * self.hidden_size * 2,
|
self.output_multiplier * self.hidden_size,
|
||||||
self.output_multiplier * self.hidden_size * 2,
|
self.output_multiplier * self.hidden_size,
|
||||||
),
|
),
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
nn.Linear(
|
nn.Linear(
|
||||||
self.output_multiplier * self.hidden_size * 2,
|
self.output_multiplier * self.hidden_size,
|
||||||
self.output_multiplier * self.hidden_size * 4,
|
self.output_multiplier * self.hidden_size,
|
||||||
),
|
),
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
nn.Dropout(0.1),
|
nn.Linear(
|
||||||
nn.Linear(self.output_multiplier * self.hidden_size * 4, num_classes),
|
self.output_multiplier * self.hidden_size,
|
||||||
|
self.output_multiplier * self.hidden_size * 2,
|
||||||
|
),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Dropout(0.2),
|
||||||
|
nn.Linear(self.output_multiplier * self.hidden_size * 2, num_classes),
|
||||||
)
|
)
|
||||||
# 可选:为领域专家和共享专家设置不同权重衰减(通过优化器实现,此处不处理)
|
# 可选:为领域专家和共享专家设置不同权重衰减(通过优化器实现,此处不处理)
|
||||||
|
|
||||||
|
|
@ -214,12 +222,8 @@ class MoEModel(nn.Module):
|
||||||
embeddings, src_key_padding_mask=padding_mask
|
embeddings, src_key_padding_mask=padding_mask
|
||||||
) # [B, S, H]
|
) # [B, S, H]
|
||||||
|
|
||||||
for block in self.res_blocks:
|
|
||||||
pooled = block(encoded)
|
|
||||||
|
|
||||||
# ----- 3. 池化量 -----
|
# ----- 3. 池化量 -----
|
||||||
pooled = self.pooler(encoded.transpose(1, 2)) # [B, H, 2]
|
pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1)
|
||||||
pooled = pooled.flatten(1) # [B, H*2]
|
|
||||||
|
|
||||||
# ----- 4. 专家路由(硬路由)-----
|
# ----- 4. 专家路由(硬路由)-----
|
||||||
if torch.jit.is_tracing():
|
if torch.jit.is_tracing():
|
||||||
|
|
@ -259,54 +263,10 @@ class MoEModel(nn.Module):
|
||||||
expert_out = self.experts[7](pooled) + self.expert_bias(
|
expert_out = self.experts[7](pooled) + self.expert_bias(
|
||||||
torch.tensor(7, device=pooled.device)
|
torch.tensor(7, device=pooled.device)
|
||||||
)
|
)
|
||||||
elif group_id == 8: # group_id == 8
|
else: # group_id == 8
|
||||||
expert_out = self.experts[8](pooled) + self.expert_bias(
|
expert_out = self.experts[8](pooled) + self.expert_bias(
|
||||||
torch.tensor(8, device=pooled.device)
|
torch.tensor(8, device=pooled.device)
|
||||||
)
|
)
|
||||||
elif group_id == 9: # group_id == 9
|
|
||||||
expert_out = self.experts[9](pooled) + self.expert_bias(
|
|
||||||
torch.tensor(9, device=pooled.device)
|
|
||||||
)
|
|
||||||
elif group_id == 10: # group_id == 10
|
|
||||||
expert_out = self.experts[10](pooled) + self.expert_bias(
|
|
||||||
torch.tensor(10, device=pooled.device)
|
|
||||||
)
|
|
||||||
elif group_id == 11: # group_id == 11
|
|
||||||
expert_out = self.experts[11](pooled) + self.expert_bias(
|
|
||||||
torch.tensor(11, device=pooled.device)
|
|
||||||
)
|
|
||||||
elif group_id == 12: # group_id == 12
|
|
||||||
expert_out = self.experts[12](pooled) + self.expert_bias(
|
|
||||||
torch.tensor(12, device=pooled.device)
|
|
||||||
)
|
|
||||||
elif group_id == 13: # group_id == 13
|
|
||||||
expert_out = self.experts[13](pooled) + self.expert_bias(
|
|
||||||
torch.tensor(13, device=pooled.device)
|
|
||||||
)
|
|
||||||
elif group_id == 14: # group_id == 14
|
|
||||||
expert_out = self.experts[14](pooled) + self.expert_bias(
|
|
||||||
torch.tensor(14, device=pooled.device)
|
|
||||||
)
|
|
||||||
elif group_id == 15: # group_id == 15
|
|
||||||
expert_out = self.experts[15](pooled) + self.expert_bias(
|
|
||||||
torch.tensor(15, device=pooled.device)
|
|
||||||
)
|
|
||||||
elif group_id == 16: # group_id == 16
|
|
||||||
expert_out = self.experts[16](pooled) + self.expert_bias(
|
|
||||||
torch.tensor(16, device=pooled.device)
|
|
||||||
)
|
|
||||||
elif group_id == 17: # group_id == 17
|
|
||||||
expert_out = self.experts[17](pooled) + self.expert_bias(
|
|
||||||
torch.tensor(17, device=pooled.device)
|
|
||||||
)
|
|
||||||
elif group_id == 18: # group_id == 18
|
|
||||||
expert_out = self.experts[18](pooled) + self.expert_bias(
|
|
||||||
torch.tensor(18, device=pooled.device)
|
|
||||||
)
|
|
||||||
else: # group_id == 19
|
|
||||||
expert_out = self.experts[19](pooled) + self.expert_bias(
|
|
||||||
torch.tensor(19, device=pooled.device)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
batch_size = pooled.size(0)
|
batch_size = pooled.size(0)
|
||||||
# 并行计算所有专家输出
|
# 并行计算所有专家输出
|
||||||
|
|
@ -394,8 +354,8 @@ class MoEModel(nn.Module):
|
||||||
|
|
||||||
# ------------------ 1. 提取并规范化输入 ------------------
|
# ------------------ 1. 提取并规范化输入 ------------------
|
||||||
# 判断是否为单样本(input_ids 无 batch 维度)
|
# 判断是否为单样本(input_ids 无 batch 维度)
|
||||||
input_ids = sample['hint']["input_ids"]
|
input_ids = sample["hint"]["input_ids"]
|
||||||
attention_mask = sample['hint']["attention_mask"]
|
attention_mask = sample["hint"]["attention_mask"]
|
||||||
pg = sample["pg"]
|
pg = sample["pg"]
|
||||||
has_batch_dim = input_ids.dim() > 1
|
has_batch_dim = input_ids.dim() > 1
|
||||||
|
|
||||||
|
|
@ -625,39 +585,3 @@ class MoEModel(nn.Module):
|
||||||
for name, param in self.named_parameters():
|
for name, param in self.named_parameters():
|
||||||
if name in freeze_layers:
|
if name in freeze_layers:
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
|
||||||
|
|
||||||
# ============================ 使用示例 ============================
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# 1. 初始化模型
|
|
||||||
model = MoEModel()
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
# 2. 构造 dummy 输入(batch=1,用于导出 ONNX)
|
|
||||||
dummy_input_ids = torch.randint(0, 100, (1, 64)) # [1, 64]
|
|
||||||
dummy_attention_mask = torch.ones_like(dummy_input_ids) # [1, 64]
|
|
||||||
dummy_pg = torch.tensor(3, dtype=torch.long) # 标量 group_id
|
|
||||||
|
|
||||||
# 3. 导出 ONNX(使用条件分支,仅计算一个专家)
|
|
||||||
torch.onnx.export(
|
|
||||||
model,
|
|
||||||
(dummy_input_ids, dummy_attention_mask, dummy_pg),
|
|
||||||
"moe_cpu.onnx",
|
|
||||||
input_names=["input_ids", "attention_mask", "pg"],
|
|
||||||
output_names=["logits"],
|
|
||||||
dynamic_axes={ # 固定 batch=1,可不设 dynamic_axes
|
|
||||||
"input_ids": {0: "batch"},
|
|
||||||
"attention_mask": {0: "batch"},
|
|
||||||
},
|
|
||||||
opset_version=12,
|
|
||||||
do_constant_folding=True,
|
|
||||||
)
|
|
||||||
print("ONNX 导出成功!")
|
|
||||||
|
|
||||||
# 4. 测试训练模式(batch=4)
|
|
||||||
model.train()
|
|
||||||
batch_input_ids = torch.randint(0, 100, (4, 64))
|
|
||||||
batch_attention_mask = torch.ones_like(batch_input_ids)
|
|
||||||
batch_pg = torch.tensor([0, 3, 8, 1], dtype=torch.long) # 不同 group
|
|
||||||
logits = model(batch_input_ids, batch_attention_mask, batch_pg)
|
|
||||||
print("训练模式输出形状:", logits.shape) # [4, 10018]
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue