feat: 优化模型输入处理与专家数量,增强训练与推理兼容性
This commit is contained in:
parent
9fad2bf1d4
commit
e91f823d65
|
|
@ -424,8 +424,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
|
|
||||||
# Tokenize
|
# Tokenize
|
||||||
hint = self.tokenizer(
|
hint = self.tokenizer(
|
||||||
sampled_context,
|
sampled_context + processed_pinyin,
|
||||||
processed_pinyin,
|
|
||||||
max_length=self.max_len,
|
max_length=self.max_len,
|
||||||
padding="max_length",
|
padding="max_length",
|
||||||
truncation=True,
|
truncation=True,
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,7 @@ if __name__ == "__main__":
|
||||||
logger.info("数据集初始化")
|
logger.info("数据集初始化")
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size=2,
|
batch_size=1024,
|
||||||
num_workers=1,
|
num_workers=1,
|
||||||
worker_init_fn=worker_init_fn,
|
worker_init_fn=worker_init_fn,
|
||||||
pin_memory=True if torch.cuda.is_available() else False,
|
pin_memory=True if torch.cuda.is_available() else False,
|
||||||
|
|
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -348,13 +348,13 @@ class MoEModel(nn.Module):
|
||||||
labels = batch["char_id"].to(self.device)
|
labels = batch["char_id"].to(self.device)
|
||||||
|
|
||||||
# 前向传播
|
# 前向传播
|
||||||
logits = self(input_ids, attention_mask, pg)
|
probs = self(input_ids, attention_mask, pg)
|
||||||
log_probs = torch.log(probs + 1e-12)
|
log_probs = torch.log(probs + 1e-12)
|
||||||
loss = nn.NLLLoss()(log_probs, labels)
|
loss = nn.NLLLoss()(log_probs, labels)
|
||||||
total_loss += loss.item() * labels.size(0)
|
total_loss += loss.item() * labels.size(0)
|
||||||
|
|
||||||
# 计算准确率
|
# 计算准确率
|
||||||
preds = logits.argmax(dim=-1)
|
preds = probs.argmax(dim=-1)
|
||||||
correct += (preds == labels).sum().item()
|
correct += (preds == labels).sum().item()
|
||||||
total += labels.size(0)
|
total += labels.size(0)
|
||||||
|
|
||||||
|
|
@ -388,8 +388,8 @@ class MoEModel(nn.Module):
|
||||||
|
|
||||||
# ------------------ 1. 提取并规范化输入 ------------------
|
# ------------------ 1. 提取并规范化输入 ------------------
|
||||||
# 判断是否为单样本(input_ids 无 batch 维度)
|
# 判断是否为单样本(input_ids 无 batch 维度)
|
||||||
input_ids = sample["input_ids"]
|
input_ids = sample['hint']["input_ids"]
|
||||||
attention_mask = sample["attention_mask"]
|
attention_mask = sample['hint']["attention_mask"]
|
||||||
pg = sample["pg"]
|
pg = sample["pg"]
|
||||||
has_batch_dim = input_ids.dim() > 1
|
has_batch_dim = input_ids.dim() > 1
|
||||||
|
|
||||||
|
|
@ -577,6 +577,37 @@ class MoEModel(nn.Module):
|
||||||
)
|
)
|
||||||
batch_loss_sum = 0.0
|
batch_loss_sum = 0.0
|
||||||
|
|
||||||
|
def load_from_state_dict(self, state_dict_path: Union[str, Path]):
    """Restore this module's weights from a state dict stored on disk.

    Args:
        state_dict_path: Path of a file written with ``torch.save`` that
            holds a state dict compatible with this module.

    The file is read with ``weights_only=True`` and its tensors are mapped
    onto ``self.device`` before being loaded (strictly) into the module.
    """
    self.load_state_dict(
        torch.load(
            state_dict_path,
            map_location=self.device,
            weights_only=True,
        )
    )
|
||||||
|
|
||||||
|
def load_from_pretrained_base_model(
    self,
    BaseModel,
    snapshot_path: Union[str, Path],
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    *args,
    **kwargs,
):
    """Copy compatible weights from a pretrained base model and freeze them.

    A ``BaseModel`` instance is built from ``*args``/``**kwargs`` and
    populated from the snapshot file. Every entry of this module's state
    dict that also exists in the base model with an identical shape is
    overwritten with the pretrained tensor. Parameters that received
    pretrained values then get ``requires_grad = False``; all other
    parameters are left untouched and stay trainable.

    Args:
        BaseModel: Callable (normally an ``nn.Module`` subclass) that builds
            the pretrained architecture.
        snapshot_path: File holding the base model's state dict.
        device: Device the snapshot tensors are mapped to while loading.
            The default is evaluated once, when the function is defined.
        *args: Positional arguments forwarded to ``BaseModel``.
        **kwargs: Keyword arguments forwarded to ``BaseModel``.
    """
    base_model = BaseModel(*args, **kwargs)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # matching how load_from_state_dict reads checkpoints.
    base_model.load_state_dict(
        torch.load(snapshot_path, weights_only=True, map_location=device)
    )
    pretrained_dict = base_model.state_dict()

    # Keep only entries present in both models with identical shapes.
    transferable = {
        key: pretrained_dict[key]
        for key, own_tensor in self.state_dict().items()
        if key in pretrained_dict and own_tensor.shape == pretrained_dict[key].shape
    }

    # load_state_dict copies values into the existing parameters (moving them
    # across devices as needed), so no explicit .to(self.device) is required.
    # strict=False because only the matching subset is supplied.
    self.load_state_dict(transferable, strict=False)

    # Freeze exactly the parameters that were taken from the base model.
    for name, param in self.named_parameters():
        if name in transferable:
            param.requires_grad = False
|
||||||
|
|
||||||
|
|
||||||
# ============================ 使用示例 ============================
|
# ============================ 使用示例 ============================
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,43 @@ def eval_dataloader(path: Union[str, Path] = (files(__package__) / "eval_dataset
|
||||||
return [pickle.load(file.open("rb")) for file in Path(path).glob("*.pkl")]
|
return [pickle.load(file.open("rb")) for file in Path(path).glob("*.pkl")]
|
||||||
|
|
||||||
|
|
||||||
|
def round_to_power_of_two(x, *, min_exp=7, max_exp=9):
    """Round a positive integer to the nearest power of two in a clamped range.

    The candidate pair is ``lower = 2**e`` and ``upper = 2**(e + 1)``, where
    ``e`` is the exponent of the largest power of two not exceeding ``x``,
    clamped to ``[min_exp, max_exp]``. Whichever candidate is closer to ``x``
    is returned; an exact tie goes to ``upper``. With the default bounds the
    result is one of 128, 256, 512 or 1024.

    Args:
        x: Integer to round (must provide ``bit_length()``).
        min_exp: Smallest exponent allowed for ``lower``.
        max_exp: Largest exponent allowed for ``lower``.

    Returns:
        0 when ``x < 1``, otherwise the selected power of two.
    """
    if x < 1:
        return 0
    # bit_length() - 1 is the exponent of the largest power of two <= x.
    # Using bit_length() itself would make `lower` strictly greater than x,
    # so the nearest-power comparison below would never pick `upper` for
    # in-range values.
    exponent = min(max(min_exp, x.bit_length() - 1), max_exp)
    lower = 1 << exponent  # largest power of two <= x (after clamping)
    upper = lower << 1  # next power of two above `lower`
    return lower if x - lower < upper - x else upper
|
||||||
|
|
||||||
|
|
||||||
|
# Width lookup keyed by expert index (0-19). It is the default `experts_dim`
# of MoEModel, which passes `experts_dim[i]` as `d_model` when building
# expert i: indices 0-2 -> 1024, 3-18 -> 512, 19 -> 256.
EXPORT_HIDE_DIM = {
    **dict.fromkeys(range(0, 3), 1024),
    **dict.fromkeys(range(3, 19), 512),
    19: 256,
}
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------- 残差块 ----------------------------
|
# ---------------------------- 残差块 ----------------------------
|
||||||
class ResidualBlock(nn.Module):
|
class ResidualBlock(nn.Module):
|
||||||
def __init__(self, dim, dropout_prob=0.1):
|
def __init__(self, dim, dropout_prob=0.1):
|
||||||
|
|
@ -93,7 +130,8 @@ class MoEModel(nn.Module):
|
||||||
output_multiplier=2,
|
output_multiplier=2,
|
||||||
d_model=768,
|
d_model=768,
|
||||||
num_resblocks=4,
|
num_resblocks=4,
|
||||||
num_domain_experts=23,
|
num_domain_experts=20,
|
||||||
|
experts_dim=EXPORT_HIDE_DIM,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.output_multiplier = output_multiplier
|
self.output_multiplier = output_multiplier
|
||||||
|
|
@ -104,30 +142,27 @@ class MoEModel(nn.Module):
|
||||||
self.bert_config = bert.config
|
self.bert_config = bert.config
|
||||||
self.hidden_size = self.bert_config.hidden_size # BERT 隐层维度
|
self.hidden_size = self.bert_config.hidden_size # BERT 隐层维度
|
||||||
self.device = None # 将在 to() 调用时设置
|
self.device = None # 将在 to() 调用时设置
|
||||||
|
self.experts_dim = experts_dim
|
||||||
self.linear = nn.Linear(256, d_model)
|
|
||||||
|
|
||||||
# 2. 4 层标准 Transformer Encoder(从 config 读取参数)
|
# 2. 4 层标准 Transformer Encoder(从 config 读取参数)
|
||||||
encoder_layer = nn.TransformerEncoderLayer(
|
encoder_layer = nn.TransformerEncoderLayer(
|
||||||
d_model=self.hidden_size,
|
d_model=self.hidden_size,
|
||||||
nhead=self.bert_config.num_attention_heads,
|
nhead=8,
|
||||||
dim_feedforward=self.bert_config.intermediate_size,
|
dim_feedforward=self.bert_config.intermediate_size,
|
||||||
dropout=self.bert_config.hidden_dropout_prob,
|
dropout=self.bert_config.hidden_dropout_prob,
|
||||||
activation="gelu",
|
activation="gelu",
|
||||||
batch_first=True,
|
batch_first=True,
|
||||||
norm_first=True, # Pre-LN,与预训练一致
|
|
||||||
)
|
)
|
||||||
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
|
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
|
||||||
self.pooler = nn.AdaptiveAvgPool1d(1)
|
self.pooler = nn.AdaptiveAvgPool1d(1)
|
||||||
|
|
||||||
self.total_experts = 23
|
self.total_experts = 20
|
||||||
self.experts = nn.ModuleList()
|
self.experts = nn.ModuleList()
|
||||||
|
|
||||||
for i in range(self.total_experts):
|
for i in range(self.total_experts):
|
||||||
# 领域专家 dropout=0.1,共享专家 dropout=0.2(您指定的更强正则)
|
|
||||||
expert = Expert(
|
expert = Expert(
|
||||||
input_dim=self.hidden_size * 2,
|
input_dim=self.hidden_size,
|
||||||
d_model=d_model,
|
d_model=self.experts_dim[i],
|
||||||
num_resblocks=num_resblocks,
|
num_resblocks=num_resblocks,
|
||||||
output_multiplier=self.output_multiplier, # 输出维度 = 2 * hidden_size
|
output_multiplier=self.output_multiplier, # 输出维度 = 2 * hidden_size
|
||||||
dropout_prob=0.1,
|
dropout_prob=0.1,
|
||||||
|
|
@ -146,11 +181,6 @@ class MoEModel(nn.Module):
|
||||||
self.output_multiplier * self.hidden_size,
|
self.output_multiplier * self.hidden_size,
|
||||||
),
|
),
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
nn.Linear(
|
|
||||||
self.output_multiplier * self.hidden_size,
|
|
||||||
self.output_multiplier * self.hidden_size,
|
|
||||||
),
|
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
nn.Linear(
|
nn.Linear(
|
||||||
self.output_multiplier * self.hidden_size,
|
self.output_multiplier * self.hidden_size,
|
||||||
self.output_multiplier * self.hidden_size * 2,
|
self.output_multiplier * self.hidden_size * 2,
|
||||||
|
|
@ -267,22 +297,10 @@ class MoEModel(nn.Module):
|
||||||
expert_out = self.experts[18](pooled) + self.expert_bias(
|
expert_out = self.experts[18](pooled) + self.expert_bias(
|
||||||
torch.tensor(18, device=pooled.device)
|
torch.tensor(18, device=pooled.device)
|
||||||
)
|
)
|
||||||
elif group_id == 19: # group_id == 19
|
else: # group_id == 19
|
||||||
expert_out = self.experts[19](pooled) + self.expert_bias(
|
expert_out = self.experts[19](pooled) + self.expert_bias(
|
||||||
torch.tensor(19, device=pooled.device)
|
torch.tensor(19, device=pooled.device)
|
||||||
)
|
)
|
||||||
elif group_id == 20: # group_id == 20
|
|
||||||
expert_out = self.experts[20](pooled) + self.expert_bias(
|
|
||||||
torch.tensor(20, device=pooled.device)
|
|
||||||
)
|
|
||||||
elif group_id == 21: # group_id == 21
|
|
||||||
expert_out = self.experts[21](pooled) + self.expert_bias(
|
|
||||||
torch.tensor(21, device=pooled.device)
|
|
||||||
)
|
|
||||||
elif group_id == 22: # group_id == 22
|
|
||||||
expert_out = self.experts[22](pooled) + self.expert_bias(
|
|
||||||
torch.tensor(22, device=pooled.device)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
batch_size = pooled.size(0)
|
batch_size = pooled.size(0)
|
||||||
# 并行计算所有专家输出
|
# 并行计算所有专家输出
|
||||||
|
|
@ -330,12 +348,13 @@ class MoEModel(nn.Module):
|
||||||
labels = batch["char_id"].to(self.device)
|
labels = batch["char_id"].to(self.device)
|
||||||
|
|
||||||
# 前向传播
|
# 前向传播
|
||||||
logits = self(input_ids, attention_mask, pg)
|
probs = self(input_ids, attention_mask, pg)
|
||||||
loss = criterion(logits, labels)
|
log_probs = torch.log(probs + 1e-12)
|
||||||
|
loss = nn.NLLLoss()(log_probs, labels)
|
||||||
total_loss += loss.item() * labels.size(0)
|
total_loss += loss.item() * labels.size(0)
|
||||||
|
|
||||||
# 计算准确率
|
# 计算准确率
|
||||||
preds = logits.argmax(dim=-1)
|
preds = probs.argmax(dim=-1)
|
||||||
correct += (preds == labels).sum().item()
|
correct += (preds == labels).sum().item()
|
||||||
total += labels.size(0)
|
total += labels.size(0)
|
||||||
|
|
||||||
|
|
@ -558,38 +577,33 @@ class MoEModel(nn.Module):
|
||||||
)
|
)
|
||||||
batch_loss_sum = 0.0
|
batch_loss_sum = 0.0
|
||||||
|
|
||||||
|
def load_from_state_dict(self, state_dict_path: Union[str, Path]):
|
||||||
# ============================ 使用示例 ============================
|
state_dict = torch.load(
|
||||||
if __name__ == "__main__":
|
state_dict_path, weights_only=True, map_location=self.device
|
||||||
# 1. 初始化模型
|
|
||||||
model = MoEModel()
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
# 2. 构造 dummy 输入(batch=1,用于导出 ONNX)
|
|
||||||
dummy_input_ids = torch.randint(0, 100, (1, 64)) # [1, 64]
|
|
||||||
dummy_attention_mask = torch.ones_like(dummy_input_ids) # [1, 64]
|
|
||||||
dummy_pg = torch.tensor(3, dtype=torch.long) # 标量 group_id
|
|
||||||
|
|
||||||
# 3. 导出 ONNX(使用条件分支,仅计算一个专家)
|
|
||||||
torch.onnx.export(
|
|
||||||
model,
|
|
||||||
(dummy_input_ids, dummy_attention_mask, dummy_pg),
|
|
||||||
"moe_cpu.onnx",
|
|
||||||
input_names=["input_ids", "attention_mask", "pg"],
|
|
||||||
output_names=["logits"],
|
|
||||||
dynamic_axes={ # 固定 batch=1,可不设 dynamic_axes
|
|
||||||
"input_ids": {0: "batch"},
|
|
||||||
"attention_mask": {0: "batch"},
|
|
||||||
},
|
|
||||||
opset_version=12,
|
|
||||||
do_constant_folding=True,
|
|
||||||
)
|
)
|
||||||
print("ONNX 导出成功!")
|
self.model.load_state_dict(state_dict)
|
||||||
|
|
||||||
# 4. 测试训练模式(batch=4)
|
def load_from_pretrained_base_model(
|
||||||
model.train()
|
self,
|
||||||
batch_input_ids = torch.randint(0, 100, (4, 64))
|
BaseModel,
|
||||||
batch_attention_mask = torch.ones_like(batch_input_ids)
|
snapshot_path: Union[str, Path],
|
||||||
batch_pg = torch.tensor([0, 3, 8, 1], dtype=torch.long) # 不同 group
|
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
||||||
logits = model(batch_input_ids, batch_attention_mask, batch_pg)
|
*args,
|
||||||
print("训练模式输出形状:", logits.shape) # [4, 10018]
|
**kwargs,
|
||||||
|
):
|
||||||
|
base_model = BaseModel(*args, **kwargs)
|
||||||
|
base_model.load_state_dict(torch.load(snapshot_path, map_location=device))
|
||||||
|
self_static_dict = self.state_dict()
|
||||||
|
pretrained_dict = base_model.state_dict()
|
||||||
|
|
||||||
|
freeze_layers = []
|
||||||
|
|
||||||
|
for key in self_static_dict.keys():
|
||||||
|
if key in pretrained_dict.keys():
|
||||||
|
if self_static_dict[key].shape == pretrained_dict[key].shape:
|
||||||
|
self_static_dict[key] = pretrained_dict[key].to(self.device)
|
||||||
|
freeze_layers.append(key)
|
||||||
|
self.load_state_dict(self_static_dict)
|
||||||
|
for name, param in self.named_parameters():
|
||||||
|
if name in freeze_layers:
|
||||||
|
param.requires_grad = False
|
||||||
Loading…
Reference in New Issue