feat: 重构拼音输入数据集与 MoE 模型结构,优化专家网络配置及评估逻辑
This commit is contained in:
parent
7eb00c6207
commit
134c8a09cf
|
|
@ -1,6 +1,6 @@
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from typing import Any, Dict, List, Tuple, Optional
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -97,28 +97,28 @@ class PinyinInputDataset(IterableDataset):
|
||||||
# 加载拼音分组
|
# 加载拼音分组
|
||||||
self.pg_groups = {
|
self.pg_groups = {
|
||||||
"y": 0,
|
"y": 0,
|
||||||
"k": 0,
|
"z": 1,
|
||||||
"e": 0,
|
"j": 2,
|
||||||
"l": 1,
|
"l": 3,
|
||||||
"w": 1,
|
"s": 4,
|
||||||
"f": 1,
|
"x": 5,
|
||||||
"q": 2,
|
"c": 6,
|
||||||
"a": 2,
|
|
||||||
"s": 2,
|
|
||||||
"x": 3,
|
|
||||||
"b": 3,
|
|
||||||
"r": 3,
|
|
||||||
"o": 4,
|
|
||||||
"m": 4,
|
|
||||||
"z": 4,
|
|
||||||
"g": 5,
|
|
||||||
"n": 5,
|
|
||||||
"c": 5,
|
|
||||||
"t": 6,
|
|
||||||
"p": 6,
|
|
||||||
"d": 6,
|
|
||||||
"j": 7,
|
|
||||||
"h": 7,
|
"h": 7,
|
||||||
|
"d": 8,
|
||||||
|
"b": 9,
|
||||||
|
"q": 10,
|
||||||
|
"g": 11,
|
||||||
|
"t": 12,
|
||||||
|
"m": 13,
|
||||||
|
"p": 14,
|
||||||
|
"w": 15,
|
||||||
|
"f": 16,
|
||||||
|
"k": 17,
|
||||||
|
"n": 18,
|
||||||
|
"r": 19,
|
||||||
|
"a": 19,
|
||||||
|
"e": 18,
|
||||||
|
"o": 17,
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_next_chinese_chars(
|
def get_next_chinese_chars(
|
||||||
|
|
@ -440,9 +440,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
"char_id": torch.tensor([char_info["id"]]),
|
"char_id": torch.tensor([char_info["id"]]),
|
||||||
"char": char,
|
"char": char,
|
||||||
"freq": char_info["freq"],
|
"freq": char_info["freq"],
|
||||||
"pg": torch.tensor(
|
"pg": torch.tensor([self.pg_groups[char_info.pinyin[0]]]),
|
||||||
[self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8]
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# 根据调整因子重复样本
|
# 根据调整因子重复样本
|
||||||
|
|
@ -480,7 +478,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
seed = base_seed + worker_id
|
seed = base_seed + worker_id
|
||||||
random.seed(seed % (2**32))
|
random.seed(seed % (2**32))
|
||||||
np.random.seed(seed % (2**32))
|
np.random.seed(seed % (2**32))
|
||||||
|
|
||||||
batch_samples = []
|
batch_samples = []
|
||||||
for item in self.dataset:
|
for item in self.dataset:
|
||||||
text = item.get(self.text_field, "")
|
text = item.get(self.text_field, "")
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
import pickle
|
import pickle
|
||||||
from importlib.resources import files
|
from importlib.resources import files
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.amp as amp
|
import torch.amp as amp
|
||||||
|
|
@ -12,10 +14,46 @@ from tqdm import tqdm
|
||||||
|
|
||||||
from .monitor import TrainingMonitor
|
from .monitor import TrainingMonitor
|
||||||
|
|
||||||
EVAL_DATALOADER = [
|
|
||||||
pickle.load(file.open("rb"))
|
def eval_dataloader(path: Union[str, Path] = (files(__package__) / "eval_dataset")):
|
||||||
for file in (files(__package__) / "eval_dataset").glob("*.pkl")
|
return [pickle.load(file.open("rb")) for file in Path(path).glob("*.pkl")]
|
||||||
]
|
|
||||||
|
|
||||||
|
def round_to_power_of_two(x):
|
||||||
|
if x < 1:
|
||||||
|
return 0
|
||||||
|
n = x.bit_length()
|
||||||
|
n = min(max(7, n), 9)
|
||||||
|
lower = 1 << (n) # 小于等于x的最大2的幂次
|
||||||
|
upper = lower << 1 # 大于x的最小2的幂次
|
||||||
|
if x - lower < upper - x:
|
||||||
|
return lower
|
||||||
|
else:
|
||||||
|
return upper
|
||||||
|
|
||||||
|
|
||||||
|
EXPORT_HIDE_DIM = {
|
||||||
|
0: 1024,
|
||||||
|
1: 1024,
|
||||||
|
2: 1024,
|
||||||
|
3: 512,
|
||||||
|
4: 512,
|
||||||
|
5: 512,
|
||||||
|
6: 512,
|
||||||
|
7: 512,
|
||||||
|
8: 512,
|
||||||
|
9: 512,
|
||||||
|
10: 512,
|
||||||
|
11: 512,
|
||||||
|
12: 512,
|
||||||
|
13: 512,
|
||||||
|
14: 512,
|
||||||
|
15: 512,
|
||||||
|
16: 512,
|
||||||
|
17: 512,
|
||||||
|
18: 512,
|
||||||
|
19: 256,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------- 残差块 ----------------------------
|
# ---------------------------- 残差块 ----------------------------
|
||||||
|
|
@ -92,8 +130,8 @@ class MoEModel(nn.Module):
|
||||||
output_multiplier=2,
|
output_multiplier=2,
|
||||||
d_model=768,
|
d_model=768,
|
||||||
num_resblocks=4,
|
num_resblocks=4,
|
||||||
num_domain_experts=8,
|
num_domain_experts=20,
|
||||||
num_shared_experts=1,
|
experts_dim=EXPORT_HIDE_DIM,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.output_multiplier = output_multiplier
|
self.output_multiplier = output_multiplier
|
||||||
|
|
@ -104,38 +142,35 @@ class MoEModel(nn.Module):
|
||||||
self.bert_config = bert.config
|
self.bert_config = bert.config
|
||||||
self.hidden_size = self.bert_config.hidden_size # BERT 隐层维度
|
self.hidden_size = self.bert_config.hidden_size # BERT 隐层维度
|
||||||
self.device = None # 将在 to() 调用时设置
|
self.device = None # 将在 to() 调用时设置
|
||||||
|
self.experts_dim = experts_dim
|
||||||
|
|
||||||
# 2. 4 层标准 Transformer Encoder(从 config 读取参数)
|
# 2. 4 层标准 Transformer Encoder(从 config 读取参数)
|
||||||
encoder_layer = nn.TransformerEncoderLayer(
|
encoder_layer = nn.TransformerEncoderLayer(
|
||||||
d_model=self.hidden_size,
|
d_model=self.hidden_size,
|
||||||
nhead=self.bert_config.num_attention_heads,
|
nhead=8,
|
||||||
dim_feedforward=self.bert_config.intermediate_size,
|
dim_feedforward=self.bert_config.intermediate_size,
|
||||||
dropout=self.bert_config.hidden_dropout_prob,
|
dropout=self.bert_config.hidden_dropout_prob,
|
||||||
activation="gelu",
|
activation="gelu",
|
||||||
batch_first=True,
|
batch_first=True,
|
||||||
norm_first=True, # Pre-LN,与预训练一致
|
|
||||||
)
|
)
|
||||||
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
|
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
|
||||||
self.pooler = nn.AdaptiveAvgPool1d(1)
|
self.pooler = nn.AdaptiveAvgPool1d(1)
|
||||||
|
|
||||||
# 3. 专家层:8个领域专家 + 1个共享专家
|
self.total_experts = 20
|
||||||
total_experts = num_domain_experts + num_shared_experts
|
|
||||||
self.experts = nn.ModuleList()
|
self.experts = nn.ModuleList()
|
||||||
|
|
||||||
for i in range(total_experts):
|
for i in range(self.total_experts):
|
||||||
# 领域专家 dropout=0.1,共享专家 dropout=0.2(您指定的更强正则)
|
|
||||||
dropout_prob = 0.1 if i < num_domain_experts else 0.2
|
|
||||||
expert = Expert(
|
expert = Expert(
|
||||||
input_dim=self.hidden_size,
|
input_dim=self.hidden_size,
|
||||||
d_model=d_model,
|
d_model=self.experts_dim[i],
|
||||||
num_resblocks=num_resblocks,
|
num_resblocks=num_resblocks,
|
||||||
output_multiplier=self.output_multiplier, # 输出维度 = 2 * hidden_size
|
output_multiplier=self.output_multiplier, # 输出维度 = 2 * hidden_size
|
||||||
dropout_prob=dropout_prob,
|
dropout_prob=0.1,
|
||||||
)
|
)
|
||||||
self.experts.append(expert)
|
self.experts.append(expert)
|
||||||
|
|
||||||
self.expert_bias = nn.Embedding(
|
self.expert_bias = nn.Embedding(
|
||||||
total_experts, self.output_multiplier * self.hidden_size
|
self.total_experts, self.output_multiplier * self.hidden_size
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4. 分类头
|
# 4. 分类头
|
||||||
|
|
@ -146,11 +181,6 @@ class MoEModel(nn.Module):
|
||||||
self.output_multiplier * self.hidden_size,
|
self.output_multiplier * self.hidden_size,
|
||||||
),
|
),
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
nn.Linear(
|
|
||||||
self.output_multiplier * self.hidden_size,
|
|
||||||
self.output_multiplier * self.hidden_size,
|
|
||||||
),
|
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
nn.Linear(
|
nn.Linear(
|
||||||
self.output_multiplier * self.hidden_size,
|
self.output_multiplier * self.hidden_size,
|
||||||
self.output_multiplier * self.hidden_size * 2,
|
self.output_multiplier * self.hidden_size * 2,
|
||||||
|
|
@ -223,11 +253,54 @@ class MoEModel(nn.Module):
|
||||||
expert_out = self.experts[7](pooled) + self.expert_bias(
|
expert_out = self.experts[7](pooled) + self.expert_bias(
|
||||||
torch.tensor(7, device=pooled.device)
|
torch.tensor(7, device=pooled.device)
|
||||||
)
|
)
|
||||||
else: # group_id == 8
|
elif group_id == 8: # group_id == 8
|
||||||
expert_out = self.experts[8](pooled) + self.expert_bias(
|
expert_out = self.experts[8](pooled) + self.expert_bias(
|
||||||
torch.tensor(8, device=pooled.device)
|
torch.tensor(8, device=pooled.device)
|
||||||
)
|
)
|
||||||
|
elif group_id == 9: # group_id == 9
|
||||||
|
expert_out = self.experts[9](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(9, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 10: # group_id == 10
|
||||||
|
expert_out = self.experts[10](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(10, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 11: # group_id == 11
|
||||||
|
expert_out = self.experts[11](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(11, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 12: # group_id == 12
|
||||||
|
expert_out = self.experts[12](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(12, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 13: # group_id == 13
|
||||||
|
expert_out = self.experts[13](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(13, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 14: # group_id == 14
|
||||||
|
expert_out = self.experts[14](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(14, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 15: # group_id == 15
|
||||||
|
expert_out = self.experts[15](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(15, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 16: # group_id == 16
|
||||||
|
expert_out = self.experts[16](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(16, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 17: # group_id == 17
|
||||||
|
expert_out = self.experts[17](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(17, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 18: # group_id == 18
|
||||||
|
expert_out = self.experts[18](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(18, device=pooled.device)
|
||||||
|
)
|
||||||
|
else: # group_id == 19
|
||||||
|
expert_out = self.experts[19](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(19, device=pooled.device)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
batch_size = pooled.size(0)
|
batch_size = pooled.size(0)
|
||||||
# 并行计算所有专家输出
|
# 并行计算所有专家输出
|
||||||
|
|
@ -387,7 +460,7 @@ class MoEModel(nn.Module):
|
||||||
self,
|
self,
|
||||||
train_dataloader,
|
train_dataloader,
|
||||||
eval_dataloader=None,
|
eval_dataloader=None,
|
||||||
monitor: TrainingMonitor = None,
|
monitor: Optional[TrainingMonitor] = None,
|
||||||
criterion=nn.CrossEntropyLoss(),
|
criterion=nn.CrossEntropyLoss(),
|
||||||
optimizer=None,
|
optimizer=None,
|
||||||
num_epochs=1,
|
num_epochs=1,
|
||||||
|
|
@ -402,11 +475,31 @@ class MoEModel(nn.Module):
|
||||||
训练模型,支持混合精度、梯度累积、学习率调度、实时监控。
|
训练模型,支持混合精度、梯度累积、学习率调度、实时监控。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
... 原有参数 ...
|
train_dataloader: DataLoader
|
||||||
lr_schedule : callable, optional
|
训练数据加载器。
|
||||||
|
eval_dataloader: DataLoader, optional
|
||||||
|
评估数据加载器。
|
||||||
|
monitor: TrainingMonitor, optional
|
||||||
|
训练监控器。
|
||||||
|
criterion: nn.Module, optional
|
||||||
|
损失函数。
|
||||||
|
optimizer: optim.Optimizer, optional
|
||||||
|
优化器。
|
||||||
|
num_epochs: int, optional
|
||||||
|
训练轮数。
|
||||||
|
eval_frequency: int, optional
|
||||||
|
评估频率。
|
||||||
|
grad_accum_steps: int, optional
|
||||||
|
梯度累积步数。
|
||||||
|
clip_grad_norm: float, optional
|
||||||
|
梯度裁剪范数。
|
||||||
|
mixed_precision: bool, optional
|
||||||
|
是否使用混合精度。
|
||||||
|
lr: float, optional
|
||||||
|
初始学习率。
|
||||||
|
lr_schedule : callable, optional
|
||||||
自定义学习率调度函数,接收参数 (processed_batches, optimizer),
|
自定义学习率调度函数,接收参数 (processed_batches, optimizer),
|
||||||
可在内部直接修改 optimizer.param_groups 中的学习率。
|
可在内部直接修改 optimizer.param_groups 中的学习率。
|
||||||
若为 None,则启用内置的固定阈值调度(前1000批 1e-4,之后 6e-6)。
|
|
||||||
"""
|
"""
|
||||||
# 确保模型在正确的设备上
|
# 确保模型在正确的设备上
|
||||||
if self.device is None:
|
if self.device is None:
|
||||||
|
|
@ -471,7 +564,7 @@ class MoEModel(nn.Module):
|
||||||
and global_step % eval_frequency == 0
|
and global_step % eval_frequency == 0
|
||||||
):
|
):
|
||||||
avg_loss = batch_loss_sum / eval_frequency
|
avg_loss = batch_loss_sum / eval_frequency
|
||||||
acc, _ = self.model_eval(eval_dataloader, criterion)
|
acc, eval_loss = self.model_eval(eval_dataloader, criterion)
|
||||||
super().train()
|
super().train()
|
||||||
if monitor is not None:
|
if monitor is not None:
|
||||||
monitor.add_step(
|
monitor.add_step(
|
||||||
|
|
@ -479,7 +572,7 @@ class MoEModel(nn.Module):
|
||||||
{"loss": avg_loss, "acc": acc},
|
{"loss": avg_loss, "acc": acc},
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc:.4f}"
|
f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc:.4f}, eval_loss: {eval_loss:.4f}"
|
||||||
)
|
)
|
||||||
batch_loss_sum = 0.0
|
batch_loss_sum = 0.0
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,595 @@
|
||||||
|
import pickle
|
||||||
|
from importlib.resources import files
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.amp as amp
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.optim as optim
|
||||||
|
from loguru import logger
|
||||||
|
from modelscope import AutoModel
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from .monitor import TrainingMonitor
|
||||||
|
|
||||||
|
|
||||||
|
def eval_dataloader(path: Union[str, Path] = (files(__package__) / "eval_dataset")):
|
||||||
|
return [pickle.load(file.open("rb")) for file in Path(path).glob("*.pkl")]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------- 残差块 ----------------------------
|
||||||
|
class ResidualBlock(nn.Module):
|
||||||
|
def __init__(self, dim, dropout_prob=0.1):
|
||||||
|
super().__init__()
|
||||||
|
self.linear1 = nn.Linear(dim, dim)
|
||||||
|
self.ln1 = nn.LayerNorm(dim)
|
||||||
|
self.linear2 = nn.Linear(dim, dim)
|
||||||
|
self.ln2 = nn.LayerNorm(dim)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
self.dropout = nn.Dropout(dropout_prob)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
residual = x
|
||||||
|
x = self.relu(self.linear1(x))
|
||||||
|
x = self.ln1(x)
|
||||||
|
x = self.linear2(x)
|
||||||
|
x = self.ln2(x)
|
||||||
|
x = self.dropout(x) # 残差前加 Dropout(符合原描述)
|
||||||
|
x = x + residual
|
||||||
|
return self.relu(x)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------- 专家网络 ----------------------------
|
||||||
|
class Expert(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_dim,
|
||||||
|
d_model=1024,
|
||||||
|
num_resblocks=4,
|
||||||
|
output_multiplier=2,
|
||||||
|
dropout_prob=0.1,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
input_dim : BERT 输出的 hidden_size(如 312/768)
|
||||||
|
d_model : 专家内部维度(固定 1024)
|
||||||
|
output_multiplier : 输出维度 = input_dim * output_multiplier
|
||||||
|
dropout_prob : 残差块内 Dropout
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.input_dim = input_dim
|
||||||
|
self.d_model = d_model
|
||||||
|
self.output_dim = input_dim * output_multiplier
|
||||||
|
|
||||||
|
# 输入映射:input_dim -> d_model
|
||||||
|
self.linear_in = nn.Linear(input_dim, d_model)
|
||||||
|
|
||||||
|
# 残差堆叠
|
||||||
|
self.res_blocks = nn.ModuleList(
|
||||||
|
[ResidualBlock(d_model, dropout_prob) for _ in range(num_resblocks)]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 输出映射:d_model -> output_dim
|
||||||
|
self.output = nn.Sequential(
|
||||||
|
nn.Linear(d_model, d_model),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Linear(d_model, self.output_dim),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.linear_in(x)
|
||||||
|
for block in self.res_blocks:
|
||||||
|
x = block(x)
|
||||||
|
return self.output(x)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------- 主模型(MoE + 硬路由)------------------------
|
||||||
|
class MoEModel(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
pretrained_model_name="iic/nlp_structbert_backbone_tiny_std",
|
||||||
|
num_classes=10018,
|
||||||
|
output_multiplier=2,
|
||||||
|
d_model=768,
|
||||||
|
num_resblocks=4,
|
||||||
|
num_domain_experts=23,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.output_multiplier = output_multiplier
|
||||||
|
|
||||||
|
# 1. 加载预训练 BERT,仅保留 embeddings
|
||||||
|
bert = AutoModel.from_pretrained(pretrained_model_name)
|
||||||
|
self.embedding = bert.embeddings
|
||||||
|
self.bert_config = bert.config
|
||||||
|
self.hidden_size = self.bert_config.hidden_size # BERT 隐层维度
|
||||||
|
self.device = None # 将在 to() 调用时设置
|
||||||
|
|
||||||
|
self.linear = nn.Linear(256, d_model)
|
||||||
|
|
||||||
|
# 2. 4 层标准 Transformer Encoder(从 config 读取参数)
|
||||||
|
encoder_layer = nn.TransformerEncoderLayer(
|
||||||
|
d_model=self.hidden_size,
|
||||||
|
nhead=self.bert_config.num_attention_heads,
|
||||||
|
dim_feedforward=self.bert_config.intermediate_size,
|
||||||
|
dropout=self.bert_config.hidden_dropout_prob,
|
||||||
|
activation="gelu",
|
||||||
|
batch_first=True,
|
||||||
|
norm_first=True, # Pre-LN,与预训练一致
|
||||||
|
)
|
||||||
|
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
|
||||||
|
self.pooler = nn.AdaptiveAvgPool1d(1)
|
||||||
|
|
||||||
|
self.total_experts = 23
|
||||||
|
self.experts = nn.ModuleList()
|
||||||
|
|
||||||
|
for i in range(self.total_experts):
|
||||||
|
# 领域专家 dropout=0.1,共享专家 dropout=0.2(您指定的更强正则)
|
||||||
|
expert = Expert(
|
||||||
|
input_dim=self.hidden_size * 2,
|
||||||
|
d_model=d_model,
|
||||||
|
num_resblocks=num_resblocks,
|
||||||
|
output_multiplier=self.output_multiplier, # 输出维度 = 2 * hidden_size
|
||||||
|
dropout_prob=0.1,
|
||||||
|
)
|
||||||
|
self.experts.append(expert)
|
||||||
|
|
||||||
|
self.expert_bias = nn.Embedding(
|
||||||
|
self.total_experts, self.output_multiplier * self.hidden_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. 分类头
|
||||||
|
self.classifier = nn.Sequential(
|
||||||
|
nn.LayerNorm(self.output_multiplier * self.hidden_size),
|
||||||
|
nn.Linear(
|
||||||
|
self.output_multiplier * self.hidden_size,
|
||||||
|
self.output_multiplier * self.hidden_size,
|
||||||
|
),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Linear(
|
||||||
|
self.output_multiplier * self.hidden_size,
|
||||||
|
self.output_multiplier * self.hidden_size,
|
||||||
|
),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Linear(
|
||||||
|
self.output_multiplier * self.hidden_size,
|
||||||
|
self.output_multiplier * self.hidden_size * 2,
|
||||||
|
),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Dropout(0.2),
|
||||||
|
nn.Linear(self.output_multiplier * self.hidden_size * 2, num_classes),
|
||||||
|
)
|
||||||
|
# 可选:为领域专家和共享专家设置不同权重衰减(通过优化器实现,此处不处理)
|
||||||
|
|
||||||
|
def to(self, device):
|
||||||
|
"""重写 to 方法,记录设备"""
|
||||||
|
self.device = device
|
||||||
|
return super().to(device)
|
||||||
|
|
||||||
|
def forward(self, input_ids, attention_mask, pg):
|
||||||
|
"""
|
||||||
|
input_ids : [batch, seq_len]
|
||||||
|
attention_mask: [batch, seq_len] (1 为有效,0 为 padding)
|
||||||
|
pg : group_id,训练时为 [batch] 的 LongTensor,推理导出时为标量 Tensor
|
||||||
|
"""
|
||||||
|
# ----- 1. Embeddings -----
|
||||||
|
embeddings = self.embedding(input_ids) # [B, S, H]
|
||||||
|
|
||||||
|
# ----- 2. Transformer Encoder -----
|
||||||
|
# padding mask: True 表示忽略该位置
|
||||||
|
padding_mask = attention_mask == 0
|
||||||
|
encoded = self.encoder(
|
||||||
|
embeddings, src_key_padding_mask=padding_mask
|
||||||
|
) # [B, S, H]
|
||||||
|
|
||||||
|
# ----- 3. 池化量 -----
|
||||||
|
pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1)
|
||||||
|
|
||||||
|
# ----- 4. 专家路由(硬路由)-----
|
||||||
|
if torch.jit.is_tracing():
|
||||||
|
# ------------------ ONNX 导出模式:条件分支(batch=1)------------------
|
||||||
|
# 此时 pg 为标量 Tensor,转换为 Python int
|
||||||
|
group_id = pg.item() if torch.is_tensor(pg) else pg
|
||||||
|
|
||||||
|
if group_id == 0:
|
||||||
|
expert_out = self.experts[0](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(0, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 1:
|
||||||
|
expert_out = self.experts[1](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(1, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 2:
|
||||||
|
expert_out = self.experts[2](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(2, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 3:
|
||||||
|
expert_out = self.experts[3](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(3, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 4:
|
||||||
|
expert_out = self.experts[4](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(4, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 5:
|
||||||
|
expert_out = self.experts[5](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(5, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 6:
|
||||||
|
expert_out = self.experts[6](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(6, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 7:
|
||||||
|
expert_out = self.experts[7](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(7, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 8: # group_id == 8
|
||||||
|
expert_out = self.experts[8](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(8, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 9: # group_id == 9
|
||||||
|
expert_out = self.experts[9](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(9, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 10: # group_id == 10
|
||||||
|
expert_out = self.experts[10](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(10, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 11: # group_id == 11
|
||||||
|
expert_out = self.experts[11](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(11, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 12: # group_id == 12
|
||||||
|
expert_out = self.experts[12](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(12, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 13: # group_id == 13
|
||||||
|
expert_out = self.experts[13](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(13, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 14: # group_id == 14
|
||||||
|
expert_out = self.experts[14](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(14, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 15: # group_id == 15
|
||||||
|
expert_out = self.experts[15](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(15, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 16: # group_id == 16
|
||||||
|
expert_out = self.experts[16](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(16, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 17: # group_id == 17
|
||||||
|
expert_out = self.experts[17](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(17, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 18: # group_id == 18
|
||||||
|
expert_out = self.experts[18](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(18, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 19: # group_id == 19
|
||||||
|
expert_out = self.experts[19](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(19, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 20: # group_id == 20
|
||||||
|
expert_out = self.experts[20](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(20, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 21: # group_id == 21
|
||||||
|
expert_out = self.experts[21](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(21, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 22: # group_id == 22
|
||||||
|
expert_out = self.experts[22](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(22, device=pooled.device)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
batch_size = pooled.size(0)
|
||||||
|
# 并行计算所有专家输出
|
||||||
|
expert_outputs = torch.stack(
|
||||||
|
[e(pooled) for e in self.experts], dim=0
|
||||||
|
) # [E, B, D]
|
||||||
|
# 根据 pg 索引专家输出
|
||||||
|
expert_out = expert_outputs[pg, torch.arange(batch_size)] # [B, D]
|
||||||
|
# 添加专家偏置
|
||||||
|
bias = self.expert_bias(pg) # [B, D]
|
||||||
|
expert_out = expert_out + bias
|
||||||
|
|
||||||
|
# ----- 5. 分类头 -----
|
||||||
|
logits = self.classifier(expert_out) # [batch, num_classes]
|
||||||
|
if not self.training: # 推理时加 Softmax
|
||||||
|
probs = torch.softmax(logits, dim=-1)
|
||||||
|
return probs
|
||||||
|
return logits
|
||||||
|
|
||||||
|
def model_eval(self, eval_dataloader, criterion=None):
|
||||||
|
"""
|
||||||
|
在验证集上评估模型,返回准确率和平均损失。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
eval_dataloader: DataLoader,提供 'input_ids', 'attention_mask', 'pg', 'char_id'
|
||||||
|
criterion: 损失函数,默认为 CrossEntropyLoss()
|
||||||
|
返回:
|
||||||
|
accuracy: float, 准确率
|
||||||
|
avg_loss: float, 平均损失
|
||||||
|
"""
|
||||||
|
if criterion is None:
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
|
||||||
|
self.eval()
|
||||||
|
total_loss = 0.0
|
||||||
|
correct = 0
|
||||||
|
total = 0
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch in eval_dataloader:
|
||||||
|
# 移动数据到模型设备
|
||||||
|
input_ids = batch["hint"]["input_ids"].to(self.device)
|
||||||
|
attention_mask = batch["hint"]["attention_mask"].to(self.device)
|
||||||
|
pg = batch["pg"].to(self.device)
|
||||||
|
labels = batch["char_id"].to(self.device)
|
||||||
|
|
||||||
|
# 前向传播
|
||||||
|
logits = self(input_ids, attention_mask, pg)
|
||||||
|
loss = criterion(logits, labels)
|
||||||
|
total_loss += loss.item() * labels.size(0)
|
||||||
|
|
||||||
|
# 计算准确率
|
||||||
|
preds = logits.argmax(dim=-1)
|
||||||
|
correct += (preds == labels).sum().item()
|
||||||
|
total += labels.size(0)
|
||||||
|
|
||||||
|
avg_loss = total_loss / total if total > 0 else 0.0
|
||||||
|
accuracy = correct / total if total > 0 else 0.0
|
||||||
|
return accuracy, avg_loss
|
||||||
|
|
||||||
|
def predict(self, sample, debug=False):
|
||||||
|
"""
|
||||||
|
基于 sample 字典进行预测,支持批量/单样本,可选调试打印错误样本信息。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
sample : dict
|
||||||
|
必须包含字段:
|
||||||
|
- 'input_ids' : [batch, seq_len] 或 [seq_len] (单样本)
|
||||||
|
- 'attention_mask': 同上
|
||||||
|
- 'pg' : [batch] 或标量
|
||||||
|
- 'char_id' : [batch] 或标量,真实标签(当 debug=True 时必须提供)
|
||||||
|
调试时(debug=True)必须包含字段:
|
||||||
|
- 'txt' : 字符串列表(batch)或单个字符串
|
||||||
|
- 'char' : 字符串列表(batch)或单个字符串
|
||||||
|
- 'py' : 字符串列表(batch)或单个字符串
|
||||||
|
debug : bool
|
||||||
|
是否打印预测错误的样本信息。若为 True 但 sample 缺少 char_id/txt/char/py,抛出 ValueError。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
preds : torch.Tensor
|
||||||
|
[batch] 预测类别标签(若输入为单样本且无 batch 维度,则返回标量)
|
||||||
|
"""
|
||||||
|
self.eval()
|
||||||
|
|
||||||
|
# ------------------ 1. 提取并规范化输入 ------------------
|
||||||
|
# 判断是否为单样本(input_ids 无 batch 维度)
|
||||||
|
input_ids = sample["input_ids"]
|
||||||
|
attention_mask = sample["attention_mask"]
|
||||||
|
pg = sample["pg"]
|
||||||
|
has_batch_dim = input_ids.dim() > 1
|
||||||
|
|
||||||
|
if not has_batch_dim:
|
||||||
|
input_ids = input_ids.unsqueeze(0)
|
||||||
|
attention_mask = attention_mask.unsqueeze(0)
|
||||||
|
if pg.dim() == 0:
|
||||||
|
pg = pg.unsqueeze(0).expand(input_ids.size(0))
|
||||||
|
|
||||||
|
# ------------------ 2. 移动设备 ------------------
|
||||||
|
input_ids = input_ids.to(self.device)
|
||||||
|
attention_mask = attention_mask.to(self.device)
|
||||||
|
pg = pg.to(self.device)
|
||||||
|
|
||||||
|
# ------------------ 3. 推理 ------------------
|
||||||
|
with torch.no_grad():
|
||||||
|
logits = self(input_ids, attention_mask, pg)
|
||||||
|
preds = torch.softmax(logits, dim=-1).argmax(dim=-1) # [batch]
|
||||||
|
|
||||||
|
# ------------------ 4. 调试打印(错误样本) ------------------
|
||||||
|
if debug:
|
||||||
|
# 检查必需字段
|
||||||
|
required_keys = ["char_id", "txt", "char", "py"]
|
||||||
|
missing = [k for k in required_keys if k not in sample]
|
||||||
|
if missing:
|
||||||
|
raise ValueError(f"debug=True 时 sample 必须包含字段: {missing}")
|
||||||
|
|
||||||
|
# 提取真实标签
|
||||||
|
true_labels = sample["char_id"]
|
||||||
|
if true_labels.dim() == 0:
|
||||||
|
true_labels = true_labels.unsqueeze(0)
|
||||||
|
|
||||||
|
# 移动真实标签到相同设备
|
||||||
|
true_labels = true_labels.to(self.device)
|
||||||
|
|
||||||
|
# 找出预测错误的索引
|
||||||
|
incorrect_mask = preds != true_labels
|
||||||
|
incorrect_indices = torch.where(incorrect_mask)[0]
|
||||||
|
|
||||||
|
if len(incorrect_indices) > 0:
|
||||||
|
print("\n=== 预测错误样本 ===")
|
||||||
|
# 获取调试字段(可能是列表或单个字符串)
|
||||||
|
txts = sample["txt"]
|
||||||
|
chars = sample["char"]
|
||||||
|
pys = sample["py"]
|
||||||
|
|
||||||
|
# 统一转换为列表(如果输入是单个字符串)
|
||||||
|
if isinstance(txts, str):
|
||||||
|
txts = [txts]
|
||||||
|
chars = [chars]
|
||||||
|
pys = [pys]
|
||||||
|
|
||||||
|
for idx in incorrect_indices.cpu().numpy():
|
||||||
|
print(f"样本索引 {idx}:")
|
||||||
|
print(f" Text : {txts[idx]}")
|
||||||
|
print(f" Char : {chars[idx]}")
|
||||||
|
print(f" Pinyin: {pys[idx]}")
|
||||||
|
print(
|
||||||
|
f" 预测标签: {preds[idx].item()}, 真实标签: {true_labels[idx].item()}"
|
||||||
|
)
|
||||||
|
print("===================\n")
|
||||||
|
|
||||||
|
# ------------------ 5. 返回结果(保持与输入维度一致) ------------------
|
||||||
|
if not has_batch_dim:
|
||||||
|
return preds.squeeze(0) # 返回标量
|
||||||
|
return preds
|
||||||
|
|
||||||
|
def fit(
|
||||||
|
self,
|
||||||
|
train_dataloader,
|
||||||
|
eval_dataloader=None,
|
||||||
|
monitor: Optional[TrainingMonitor] = None,
|
||||||
|
criterion=nn.CrossEntropyLoss(),
|
||||||
|
optimizer=None,
|
||||||
|
num_epochs=1,
|
||||||
|
eval_frequency=500,
|
||||||
|
grad_accum_steps=1,
|
||||||
|
clip_grad_norm=1.0,
|
||||||
|
mixed_precision=False,
|
||||||
|
lr=1e-4,
|
||||||
|
lr_schedule=None, # 新增:可选的自定义学习率调度函数
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
训练模型,支持混合精度、梯度累积、学习率调度、实时监控。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
train_dataloader: DataLoader
|
||||||
|
训练数据加载器。
|
||||||
|
eval_dataloader: DataLoader, optional
|
||||||
|
评估数据加载器。
|
||||||
|
monitor: TrainingMonitor, optional
|
||||||
|
训练监控器。
|
||||||
|
criterion: nn.Module, optional
|
||||||
|
损失函数。
|
||||||
|
optimizer: optim.Optimizer, optional
|
||||||
|
优化器。
|
||||||
|
num_epochs: int, optional
|
||||||
|
训练轮数。
|
||||||
|
eval_frequency: int, optional
|
||||||
|
评估频率。
|
||||||
|
grad_accum_steps: int, optional
|
||||||
|
梯度累积步数。
|
||||||
|
clip_grad_norm: float, optional
|
||||||
|
梯度裁剪范数。
|
||||||
|
mixed_precision: bool, optional
|
||||||
|
是否使用混合精度。
|
||||||
|
lr: float, optional
|
||||||
|
初始学习率。
|
||||||
|
lr_schedule : callable, optional
|
||||||
|
自定义学习率调度函数,接收参数 (processed_batches, optimizer),
|
||||||
|
可在内部直接修改 optimizer.param_groups 中的学习率。
|
||||||
|
"""
|
||||||
|
# 确保模型在正确的设备上
|
||||||
|
if self.device is None:
|
||||||
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
self.to(self.device)
|
||||||
|
|
||||||
|
# 切换到训练模式
|
||||||
|
super().train()
|
||||||
|
|
||||||
|
# 默认优化器
|
||||||
|
if optimizer is None:
|
||||||
|
optimizer = optim.AdamW(self.parameters(), lr=lr) # 初始学习率 1e-4
|
||||||
|
|
||||||
|
# 混合精度缩放器
|
||||||
|
scaler = amp.GradScaler(enabled=mixed_precision)
|
||||||
|
|
||||||
|
global_step = 0
|
||||||
|
processed_batches = 0 # 新增:实际处理的 batch 数量计数器
|
||||||
|
batch_loss_sum = 0.0
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
for epoch in range(num_epochs):
|
||||||
|
for batch_idx, batch in enumerate(tqdm(train_dataloader, total=1e6)):
|
||||||
|
# ---------- 更新 batch 计数器 ----------
|
||||||
|
processed_batches += 1
|
||||||
|
|
||||||
|
# ---------- 学习率调度(仅当使用默认优化器且未传入自定义调度函数时)----------
|
||||||
|
if lr_schedule is not None:
|
||||||
|
# 调用用户自定义的调度函数
|
||||||
|
lr_schedule(processed_batches, optimizer)
|
||||||
|
|
||||||
|
# ---------- 移动数据 ----------
|
||||||
|
input_ids = batch["hint"]["input_ids"].to(self.device)
|
||||||
|
attention_mask = batch["hint"]["attention_mask"].to(self.device)
|
||||||
|
pg = batch["pg"].to(self.device)
|
||||||
|
labels = batch["char_id"].to(self.device)
|
||||||
|
|
||||||
|
# 混合精度前向
|
||||||
|
with amp.autocast(
|
||||||
|
device_type=self.device.type, enabled=mixed_precision
|
||||||
|
):
|
||||||
|
logits = self(input_ids, attention_mask, pg)
|
||||||
|
loss = criterion(logits, labels)
|
||||||
|
loss = loss / grad_accum_steps
|
||||||
|
|
||||||
|
# 反向传播
|
||||||
|
scaler.scale(loss).backward()
|
||||||
|
|
||||||
|
# 梯度累积
|
||||||
|
if (batch_idx + 1) % grad_accum_steps == 0:
|
||||||
|
scaler.unscale_(optimizer)
|
||||||
|
torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
global_step += 1
|
||||||
|
original_loss = loss.item() * grad_accum_steps
|
||||||
|
batch_loss_sum += original_loss
|
||||||
|
# 周期性评估(与原代码相同)
|
||||||
|
if (
|
||||||
|
eval_dataloader is not None
|
||||||
|
and global_step % eval_frequency == 0
|
||||||
|
):
|
||||||
|
avg_loss = batch_loss_sum / eval_frequency
|
||||||
|
acc, eval_loss = self.model_eval(eval_dataloader, criterion)
|
||||||
|
super().train()
|
||||||
|
if monitor is not None:
|
||||||
|
monitor.add_step(
|
||||||
|
global_step,
|
||||||
|
{"loss": avg_loss, "acc": acc},
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc:.4f}, eval_loss: {eval_loss:.4f}"
|
||||||
|
)
|
||||||
|
batch_loss_sum = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
# ============================ 使用示例 ============================
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 1. 初始化模型
|
||||||
|
model = MoEModel()
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# 2. 构造 dummy 输入(batch=1,用于导出 ONNX)
|
||||||
|
dummy_input_ids = torch.randint(0, 100, (1, 64)) # [1, 64]
|
||||||
|
dummy_attention_mask = torch.ones_like(dummy_input_ids) # [1, 64]
|
||||||
|
dummy_pg = torch.tensor(3, dtype=torch.long) # 标量 group_id
|
||||||
|
|
||||||
|
# 3. 导出 ONNX(使用条件分支,仅计算一个专家)
|
||||||
|
torch.onnx.export(
|
||||||
|
model,
|
||||||
|
(dummy_input_ids, dummy_attention_mask, dummy_pg),
|
||||||
|
"moe_cpu.onnx",
|
||||||
|
input_names=["input_ids", "attention_mask", "pg"],
|
||||||
|
output_names=["logits"],
|
||||||
|
dynamic_axes={ # 固定 batch=1,可不设 dynamic_axes
|
||||||
|
"input_ids": {0: "batch"},
|
||||||
|
"attention_mask": {0: "batch"},
|
||||||
|
},
|
||||||
|
opset_version=12,
|
||||||
|
do_constant_folding=True,
|
||||||
|
)
|
||||||
|
print("ONNX 导出成功!")
|
||||||
|
|
||||||
|
# 4. 测试训练模式(batch=4)
|
||||||
|
model.train()
|
||||||
|
batch_input_ids = torch.randint(0, 100, (4, 64))
|
||||||
|
batch_attention_mask = torch.ones_like(batch_input_ids)
|
||||||
|
batch_pg = torch.tensor([0, 3, 8, 1], dtype=torch.long) # 不同 group
|
||||||
|
logits = model(batch_input_ids, batch_attention_mask, batch_pg)
|
||||||
|
print("训练模式输出形状:", logits.shape) # [4, 10018]
|
||||||
Loading…
Reference in New Issue