refactor(model): 使用注意力池化替换 span pooling 并支持 token_type_ids

This commit is contained in:
songsenand 2026-02-26 00:48:09 +08:00
parent 93dced50c7
commit 7c90633ebc
8 changed files with 189 additions and 212 deletions

View File

@ -56,7 +56,7 @@ class PinyinInputDataset(IterableDataset):
self,
data_dir: str,
query_engine,
tokenizer_name: str = "iic/nlp_structbert_backbone_tiny_std",
tokenizer_name: str = "iic/nlp_structbert_backbone_lite_std",
max_len: int = 88,
text_field: str = "text",
batch_query_size: int = 1000,
@ -71,7 +71,7 @@ class PinyinInputDataset(IterableDataset):
max_drop_prob: float = 0.8, # 最大丢弃概率
max_repeat_expect: float = 50.0, # 最大重复期望
sample_context_section=[0.90, 0.95, 1],
drop_py_rate: float = 0.30,
drop_py_rate: float = 0,
):
"""
初始化数据集
@ -434,15 +434,6 @@ class PinyinInputDataset(IterableDataset):
# 拼音处理
processed_pinyin = self.process_pinyin_sequence(next_pinyins)
# Tokenize
hint = self.tokenizer(
sampled_context + processed_pinyin,
max_length=self.max_len,
padding="max_length",
truncation=True,
return_tensors="pt",
)
pg = self.pg_groups[processed_pinyin[0]] if processed_pinyin else 12
prob = random.random()
if prob < self.drop_py_rate:
@ -450,6 +441,18 @@ class PinyinInputDataset(IterableDataset):
else:
py = processed_pinyin
# Tokenize
hint = self.tokenizer(
sampled_context,
py,
max_length=self.max_len,
padding="max_length",
truncation=True,
return_tensors="pt",
return_token_type_ids=True,
)
# 生成样本
sample = {
"hint": hint,
@ -592,6 +595,7 @@ def custom_collate_with_txt(batch):
"hint": {
"input_ids": torch.cat([h["input_ids"] for h in hints]),
"attention_mask": torch.cat([h["attention_mask"] for h in hints]),
"token_type_ids": torch.cat([h["token_type_ids"] for h in hints]),
},
"char_id": torch.cat([item["char_id"] for item in batch]),
"char": [item["char"] for item in batch],

View File

@ -18,12 +18,12 @@ if __name__ == "__main__":
dataset = PinyinInputDataset(
data_dir="/home/songsenand/DataSet/data",
query_engine=query_engine,
tokenizer_name="iic/nlp_structbert_backbone_tiny_std",
tokenizer_name="iic/nlp_structbert_backbone_lite_std",
max_len=88,
batch_query_size=300,
shuffle=True,
shuffle_buffer_size=4000,
drop_py_rate=0.7
drop_py_rate=0
)
logger.info("数据集初始化")
dataloader = DataLoader(

View File

@ -22,6 +22,28 @@ def eval_dataloader(path: Union[str, Path] = (files(__package__) / "eval_dataset
return [pickle.load(file.open("rb")) for file in Path(path).glob("*.pkl")]
# ---------------------------- Attention pooling module ----------------------------
class AttentionPooling(nn.Module):
    """Pool a token sequence into a single vector via learned attention.

    A linear layer produces one scalar score per token; scores are
    optionally shifted by a learned per-token-type bias and masked for
    padding, then softmax-normalized and used as weights for a weighted
    sum over the sequence dimension.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # One scalar attention score per token position.
        self.attn = nn.Linear(hidden_size, 1)
        # Three learnable biases: [text_bias, pinyin_bias, user_bias].
        self.bias = nn.Parameter(torch.zeros(3))

    def forward(self, x, mask=None, token_type_ids=None):
        """
        Args:
            x: [batch, seq_len, hidden] token representations.
            mask: [batch, seq_len]; positions equal to 0 are excluded
                from the pooling (e.g. padding).
            token_type_ids: [batch, seq_len] integer ids indexing
                ``self.bias`` (0=text, 1=pinyin, 2=user).

        Returns:
            [batch, hidden] pooled representation.
        """
        scores = self.attn(x).squeeze(-1)  # [batch, seq_len]
        if token_type_ids is not None:
            # Expand the [3] bias to [batch, seq_len] by indexing with the
            # per-token type id, so each segment gets its own score offset.
            scores = scores + self.bias[token_type_ids]
        if mask is not None:
            # Large negative value (not -inf) keeps softmax finite even if
            # a row were ever fully masked.
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = torch.softmax(scores, dim=-1)
        pooled = torch.sum(weights.unsqueeze(-1) * x, dim=1)
        return pooled
# ---------------------------- 残差块 ----------------------------
class ResidualBlock(nn.Module):
def __init__(self, dim, dropout_prob=0.3):
@ -35,6 +57,7 @@ class ResidualBlock(nn.Module):
def forward(self, x):
residual = x
# 修复:使用 self.gelu 而不是未定义的 self.relu
x = self.gelu(self.linear1(x))
x = self.ln1(x)
x = self.linear2(x)
@ -62,7 +85,7 @@ class Expert(nn.Module):
)
self.output = nn.Sequential(
nn.Linear(d_model, d_model),
nn.GELU(inplace=True),
nn.GELU(),
nn.Dropout(dropout_prob),
nn.Linear(d_model, self.output_dim),
)
@ -108,7 +131,10 @@ class MoEModel(nn.Module):
)
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
# 3. 专家系统
# 3. 注意力池化(新增)
self.attn_pool = AttentionPooling(self.hidden_size)
# 4. 专家系统
total_experts = num_domain_experts + num_shared_experts
self.experts = nn.ModuleList()
for i in range(total_experts):
@ -127,7 +153,7 @@ class MoEModel(nn.Module):
total_experts, self.hidden_size * self.output_multiplier
)
# 4. 分类头
# 5. 分类头
self.classifier = nn.Sequential(
nn.LayerNorm(self.hidden_size * self.output_multiplier),
nn.Dropout(0.4),
@ -139,80 +165,41 @@ class MoEModel(nn.Module):
self.device = device
return super().to(device)
def forward(self, input_ids, attention_mask, pg, p_start):
def forward(self, input_ids, attention_mask, token_type_ids, pg):
"""
ONNX 兼容的 Forward 函数
新版 Forward 函数不再需要 p_start改用 token_type_ids
Args:
input_ids: [B, L]
attention_mask: [B, L]
token_type_ids: [B, L] (0=文本, 1=拼音)
pg: [B] 拼音组 ID
p_start: [B] 拼音起始索引位置 (整数 Tensor)
"""
# ----- 1. Embeddings -----
embeddings = self.embedding(input_ids)
# 注意:预训练的 embedding 层本身可能已经包含了 token_type_ids 的处理,
# 但这里我们直接使用它的 embedding并手动将 token_type_ids 的嵌入加到上面。
# 由于 bert.embeddings 通常包含 token_type_embeddings我们可以利用它。
# 但为简化,我们直接使用 bert.embeddings(input_ids, token_type_ids=token_type_ids)
# 如果当前 embedding 不支持传入 token_type_ids可以手动相加
# embeddings = self.embedding(input_ids) + self.embedding.token_type_embeddings(token_type_ids)
# 这里采用更通用的方式:假设 self.embedding 有 token_type_ids 参数
embeddings = self.embedding(input_ids, token_type_ids=token_type_ids)
# ----- 2. Transformer Encoder -----
# padding mask: True 表示忽略该位置
padding_mask = attention_mask == 0
encoded = self.encoder(
embeddings, src_key_padding_mask=padding_mask
) # [B, S, H]
# ----- 3. ONNX 兼容的 Span Pooling (向量化实现) -----
"""
思路
我们不能用循环去切片我们要构造一个 Mask 矩阵
目标对于每个样本 i生成一个长度为 L 的向量其中 p_start[i] < index < p_end[i] 的位置为 1其余为 0
步骤
1. 生成位置索引轴indices = [0, 1, 2, ..., L-1] (Shape: [L])
2. 扩展维度以匹配 Batch
indices: [1, L]
p_start: [B, 1]
p_end: [B, 1]
3. 逻辑比较 (Broadcasting)
mask = (indices > p_start) & (indices < p_end)
结果 Shape: [B, L] (Boolean)
4. 应用 Mask
masked_encoded = encoded * mask.unsqueeze(-1)
5. 求和并归一化
sum_vec = masked_encoded.sum(dim=1)
count = mask.sum(dim=1).clamp(min=1) # 防止除零
pooled = sum_vec / count
"""
B, L, H = encoded.shape
device = encoded.device
# 生成位置轴 [0, 1, ..., L-1]
positions = torch.arange(L, device=device).unsqueeze(0) # [1, L]
# 调整 p_start 形状为 [B, 1] 以便广播
p_start_exp = p_start.unsqueeze(1) # [B, 1]
span_mask = positions >= p_start_exp
# 转换为 Float 用于乘法
span_mask_float = span_mask.float() # [B, L]
# 应用 Mask
# encoded: [B, L, H] -> mask: [B, L, 1]
masked_encoded = encoded * span_mask_float.unsqueeze(-1)
# 求和
span_sum = masked_encoded.sum(dim=1) # [B, H]
# 计算有效长度 (防止除以 0)
span_count = span_mask_float.sum(dim=1, keepdim=True).clamp(min=1.0) # [B, 1]
# 平均池化
pooled = span_sum / span_count # [B, H]
# ----- 3. 注意力池化(代替原来的 Span Pooling-----
# 使用 attention_mask 忽略 padding 位置
pooled = self.attn_pool(encoded, attention_mask, token_type_ids) # [B, H]
# ----- 4. 专家路由(硬路由)-----
if torch.jit.is_tracing():
# ------------------ ONNX 导出模式条件分支batch=1------------------
# 此时 pg 为标量 Tensor转换为 Python int
# ONNX 导出模式batch=1根据 pg 选择专家
group_id = pg.item() if torch.is_tensor(pg) else pg
# 注意:专家索引从 0 开始,确保所有 case 都覆盖且偏置正确
# 使用字典映射或 if-elifONNX 需要静态图,此处保持原样但修正索引错误)
if group_id == 0:
expert_out = self.experts[0](pooled) + self.expert_bias(
torch.tensor(0, device=pooled.device)
@ -267,33 +254,36 @@ class MoEModel(nn.Module):
)
else:
batch_size = pooled.size(0)
# 并行计算所有专家输出
expert_outputs = torch.stack(
[e(pooled) for e in self.experts], dim=0
) # [E, B, D]
# 根据 pg 索引专家输出
expert_out = expert_outputs[pg, torch.arange(batch_size)] # [B, D]
# 添加专家偏置
bias = self.expert_bias(pg) # [B, D]
expert_out = expert_out + bias
# ----- 5. 分类头 -----
logits = self.classifier(expert_out) # [batch, num_classes]
if not self.training: # 推理时加 Softmax
probs = torch.softmax(logits, dim=-1)
return probs
return logits
return self.classifier(expert_out) # [batch, num_classes]
def model_eval(self, eval_dataloader, criterion):
"""
在验证集上评估模型返回准确率和平均损失
评估模型在验证集上的性能
参数
eval_dataloader: DataLoader提供 'input_ids', 'attention_mask', 'pg', 'char_id'
criterion: 损失函数默认为 CrossEntropyLoss()
返回
accuracy: float, 准确率
avg_loss: float, 平均损失
Args:
eval_dataloader (DataLoader): 验证集的数据加载器每个batch包含以下字段
- hint: 包含input_idsattention_mask和token_type_ids的字典
- pg: 程序图数据
- char_id: 字符ID标签
criterion (callable): 损失函数用于计算模型输出与标签之间的损失
Returns:
tuple: 包含两个浮点数的元组 (accuracy, avg_loss)
- accuracy (float): 模型在验证集上的准确率
- avg_loss (float): 模型在验证集上的平均损失
Note:
该方法会自动将模型切换到评估模式(self.eval())
并使用torch.no_grad()上下文管理器来禁用梯度计算
以节省内存和计算资源
"""
self.eval()
total_loss = 0.0
@ -302,20 +292,18 @@ class MoEModel(nn.Module):
with torch.no_grad():
for batch in eval_dataloader:
# 移动数据到模型设备
input_ids = batch["hint"]["input_ids"].to(self.device)
attention_mask = batch["hint"]["attention_mask"].to(self.device)
token_type_ids = batch["hint"]["token_type_ids"].to(self.device) # 新增
pg = batch["pg"].to(self.device)
p_start = batch["p_start"].to(self.device)
labels = batch["char_id"].to(self.device)
# 前向传播
probs = self(input_ids, attention_mask, pg)
loss = criterion(log_probs, labels)
logits = self(input_ids, attention_mask, token_type_ids, pg)
loss = criterion(logits, labels)
total_loss += loss.item() * labels.size(0)
# 计算准确率
preds = probs.argmax(dim=-1)
preds = torch.softmax(logits, dim=-1).argmax(dim=-1)
correct += (preds == labels).sum().item()
total += labels.size(0)
@ -327,117 +315,117 @@ class MoEModel(nn.Module):
"""
生成用于预测的样本数据
参数:
text (str): 输入的文本内容
py (list): 与文本对应的拼音列表
tokenizer (PreTrainedTokenizer, optional): 用于文本编码的分词器如果未提供且实例中没有默认分词器
则会自动加载预训练的分词器
该方法将文本和拼音转换为模型所需的输入格式包括input_idsattention_mask和token_type_ids
如果没有提供tokenizer会使用默认的AutoTokenizer
返回:
dict: 包含以下键值的字典
- "hint": 包含编码后的输入特征包括 "input_ids" "attention_mask"
- "pg": 一个张量表示拼音的第一个字符在 PG 映射中的索引
Args:
text (str): 输入文本作为第一句输入
py (str): 拼音字符串作为第二句输入
tokenizer (AutoTokenizer, optional): 分词器实例如果为None且self.tokenizer不存在
则会创建默认的分词器默认为None
功能说明:
1. 如果未提供分词器且实例中不存在默认分词器则从预训练模型加载分词器
2. 使用分词器对输入文本和拼音进行编码设置最大长度为 88并进行填充和截断
3. 构造样本字典包含编码后的输入特征和拼音映射张量
Returns:
dict: 包含模型输入的字典格式为
{
"hint": {
"input_ids": tensor, # 文本和拼音的token IDs
"attention_mask": tensor, # 注意力掩码
"token_type_ids": tensor # 句子类型ID
},
"pg": tensor # 拼音组ID根据拼音首字母生成
}
Notes:
- 使用text_pair参数让分词器自动生成token_type_ids
- 确保分词器支持return_token_type_ids=True
- 最大长度(max_length)设置为88
- 会自动进行padding和truncation处理
- 拼音组ID当前根据拼音首字母生成可根据实际需要改进
"""
# 如果未提供分词器且实例中没有默认分词器,则加载预训练分词器
if tokenizer is None and not hasattr(self, "tokenizer"):
self.tokenizer = AutoTokenizer.from_pretrained(
"iic/nlp_structbert_backbone_tiny_std"
"iic/nlp_structbert_backbone_lite_std"
)
else:
# 使用传入的分词器或实例中的默认分词器
self.tokenizer = tokenizer or self.tokenizer
# 对输入文本和拼音进行编码,生成模型所需的输入格式
hint = self.tokenizer(
text,
py,
# 使用 text_pair 参数让分词器自动生成 token_type_ids
# 注意:确保分词器支持 return_token_type_ids=True
encoded = self.tokenizer(
text, # 文本作为第一句
py, # 拼音作为第二句
max_length=88,
padding="max_length",
truncation=True,
return_tensors="pt",
return_token_type_ids=True, # 显式要求返回 token_type_ids
)
# 构造样本字典
sample = {}
sample["hint"] = {
"input_ids": hint["input_ids"],
"attention_mask": hint["attention_mask"],
sample = {
"hint": {
"input_ids": encoded["input_ids"],
"attention_mask": encoded["attention_mask"],
"token_type_ids": encoded["token_type_ids"], # 新增
},
"pg": torch.tensor(
[PG[py[0]]]
), # 拼音组 ID 仍根据首字母生成(可根据实际需要改进)
}
# 将拼音的第一个字符映射为 PG 中的索引并转换为张量
sample["pg"] = torch.tensor([PG[py[0]]])
sample["p_start"] = torch.tensor([len(text)])
return sample
def predict(self, text, py, tokenizer=None):
"""
基于输入的文本和拼音生成 sample 字典进行预测支持批量/单样本可选调试打印错误样本信息
预测函数自动处理 batch 维度
参数
text : str
输入的文本
py : str
输入的拼音
tokenizer : Tokenizer, optional
用于分词的分词器默认为 None
debug : bool
是否打印预测错误的样本信息
Args:
text (str or List[str]): 输入文本或文本列表
py (int or List[int]): 拼音特征可以是单个值或列表
tokenizer (object, optional): 分词器对象用于文本预处理默认为None
返回
preds : torch.Tensor
[batch] 预测类别标签若输入为单样本且无 batch 维度则返回标量
Returns:
torch.Tensor: 预测结果如果是单个输入则返回一维张量
如果是批量输入则返回二维张量
"""
self.eval() # 将模型设置为评估模式关闭dropout等训练时需要的层
self.eval()
sample = self.gen_predict_sample(text, py, tokenizer)
input_ids = sample["hint"]["input_ids"]
attention_mask = sample["hint"]["attention_mask"]
token_type_ids = sample["hint"]["token_type_ids"]
pg = sample["pg"]
# ------------------ 1. 提取并规范化输入 ------------------
# 判断是否为单样本input_ids 无 batch 维度)
sample = self.gen_predict_sample(text, py, tokenizer) # 生成预测所需的样本数据
input_ids = sample["hint"]["input_ids"] # 获取输入ID
attention_mask = sample["hint"]["attention_mask"] # 获取注意力掩码
pg = sample["pg"] # 获取拼音引导
has_batch_dim = input_ids.dim() > 1 # 判断输入是否有batch维度
# 如果没有batch维度则添加batch维度
has_batch_dim = input_ids.dim() > 1
if not has_batch_dim:
input_ids = input_ids.unsqueeze(0) # 在第0维添加batch维度
attention_mask = attention_mask.unsqueeze(0) # 在第0维添加batch维度
# 如果拼音引导是标量则扩展为与输入ID相同的batch大小
input_ids = input_ids.unsqueeze(0)
attention_mask = attention_mask.unsqueeze(0)
token_type_ids = token_type_ids.unsqueeze(0)
if pg.dim() == 0:
pg = pg.unsqueeze(0).expand(input_ids.size(0)) # 扩展拼音引导的batch维度
pg = pg.unsqueeze(0).expand(input_ids.size(0))
# ------------------ 2. 移动设备 ------------------
# 将输入数据移动到模型所在设备GPU/CPU
input_ids = input_ids.to(self.device)
attention_mask = attention_mask.to(self.device)
token_type_ids = token_type_ids.to(self.device)
pg = pg.to(self.device)
# ------------------ 3. 推理 ------------------
# 使用torch.no_grad()上下文管理器,不计算梯度,节省内存
with torch.no_grad():
logits = self(input_ids, attention_mask, pg) # 前向传播获取logits
preds = torch.softmax(logits, dim=-1).argmax(dim=-1) # [batch]
logits = self(input_ids, attention_mask, token_type_ids, pg)
preds = torch.softmax(logits, dim=-1).argmax(dim=-1)
# ------------------ 4. 返回结果(保持与输入维度一致) ------------------
if not has_batch_dim:
return preds.squeeze(0) # 返回标量
return preds.squeeze(0)
return preds
def fit(
self,
train_dataloader, # 训练数据加载器
eval_dataloader=None, # 评估数据加载器,可选
monitor: Optional[TrainingMonitor] = None, # 训练监控器,用于记录训练过程
criterion=None, # 损失函数
optimizer=None, # 优化器
num_epochs=1, # 训练轮数
stop_batch=1e6, # 最大训练批次数
train_dataloader,
eval_dataloader=None,
monitor: Optional[TrainingMonitor] = None,
criterion=None,
optimizer=None,
num_epochs=1,
stop_batch=2e5,
eval_frequency=500,
grad_accum_steps=1, # 梯度累积步数
clip_grad_norm=1.0, # 梯度裁剪的范数
grad_accum_steps=1,
clip_grad_norm=1.0,
loss_weight=None,
mixed_precision=True,
weight_decay=0.1,
@ -445,25 +433,16 @@ class MoEModel(nn.Module):
label_smoothing=0.15,
lr=1e-4,
):
"""
训练模型支持混合精度梯度累积学习率调度实时监控
参数
# TODO: 添加参数注释
"""
# 确保模型在正确的设备上GPU或CPU
"""训练函数,调整了输入参数"""
if self.device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.to(self.device)
# 切换到训练模式
self.train()
# 默认优化器设置
if optimizer is None:
optimizer = optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)
# 损失函数设置
if criterion is None:
if loss_weight is not None:
criterion = nn.CrossEntropyLoss(
@ -472,13 +451,13 @@ class MoEModel(nn.Module):
else:
criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
# 混合精度缩放器
scaler = amp.GradScaler(enabled=mixed_precision)
total_steps = stop_batch
total_steps = max(stop_batch, 2e5)
warmup_steps = int(total_steps * warmup_ratio)
logger.info(f"Training Start: Steps={total_steps}, Warmup={warmup_steps}")
processed_batches = 0 # 新增:实际处理的 batch 数量计数器
processed_batches = 0
global_step = 0 # 初始化
batch_loss_sum = 0.0
optimizer.zero_grad()
@ -486,8 +465,6 @@ class MoEModel(nn.Module):
for batch_idx, batch in enumerate(
tqdm(train_dataloader, total=int(stop_batch))
):
processed_batches += 1
# LR Schedule
if processed_batches < warmup_steps:
current_lr = lr * (processed_batches/ warmup_steps)
@ -499,26 +476,22 @@ class MoEModel(nn.Module):
for param_group in optimizer.param_groups:
param_group["lr"] = current_lr
# ---------- 移动数据 ----------
# 移动数据注意batch 中现在包含 token_type_ids
input_ids = batch["hint"]["input_ids"].to(self.device)
attention_mask = batch["hint"]["attention_mask"].to(self.device)
token_type_ids = batch["hint"]["token_type_ids"].to(self.device) # 新增
pg = batch["pg"].to(self.device)
p_start = batch["p_start"].to(self.device) # [B]
labels = batch["char_id"].to(self.device)
# 混合精度前向
# Forward
with torch.amp.autocast(
device_type=self.device.type, enabled=mixed_precision
):
logits = self(input_ids, attention_mask, pg, p_start)
logits = self(input_ids, attention_mask, token_type_ids, pg)
loss = criterion(logits, labels)
loss = loss / grad_accum_steps
# 反向传播
scaler.scale(loss).backward()
# 梯度累积
if (processed_batches) % grad_accum_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
@ -536,34 +509,34 @@ class MoEModel(nn.Module):
logger.warning("NaN detected, skipping step.")
optimizer.zero_grad()
batch_loss_sum += loss.item() * grad_accum_steps
# 周期性评估
if eval_dataloader and global_step % eval_frequency == 0:
if global_step % eval_frequency == 0:
if eval_dataloader:
self.eval()
acc, eval_loss = self.model_eval(eval_dataloader, criterion)
if global_step == 0:
avg_loss = eval_loss
self.train()
if monitor:
# 使用 eval_loss 作为监控指标
monitor.add_step(
global_step, {"loss": avg_loss, "acc": acc}
global_step, {"loss": batch_loss_sum, "acc": acc}
)
logger.info(
f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc:.4f}, eval_loss: {eval_loss:.4f}"
f"step: {global_step}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if global_step > 0 else 1):.4f}, current_lr: {current_lr}"
)
logger.info(f"step: {global_step}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if global_step > 0 else 1):.4f}, current_lr: {current_lr}")
batch_loss_sum = 0.0
if processed_batches - 1 >= stop_batch:
if processed_batches >= stop_batch:
break
processed_batches += 1
global_step += 1
# 训练结束发送通知
try:
res_acc, res_loss = self.model_eval(eval_dataloader, criterion)
to_wechat_response = send_serverchan_message(
send_serverchan_message(
title="训练完成",
content=f"训练完成,acc: {res_acc:.4f}, loss: {res_loss:.4f}",
content=f"acc: {res_acc:.4f}, loss: {res_loss:.4f}",
)
logger.info(f"训练完成acc: {res_acc:.4f}, loss: {res_loss:.4f}")
logger.info(f"发送消息: {to_wechat_response}")
except Exception as e:
logger.error(f"发送消息失败: {e}")
@ -601,11 +574,11 @@ class MoEModel(nn.Module):
# --- ONNX 导出辅助函数 ---
def export_onnx(self, output_path, dummy_input):
"""
dummy_input 应该是一个字典或元组包含:
(input_ids, attention_mask, pg, p_start)
dummy_input 应该是一个元组包含:
(input_ids, attention_mask, token_type_ids, pg)
"""
self.eval()
input_names = ["input_ids", "attention_mask", "pg", "p_start"]
input_names = ["input_ids", "attention_mask", "token_type_ids", "pg"]
output_names = ["logits"]
torch.onnx.export(
@ -617,11 +590,11 @@ class MoEModel(nn.Module):
dynamic_axes={
"input_ids": {0: "batch_size", 1: "seq_len"},
"attention_mask": {0: "batch_size", 1: "seq_len"},
"token_type_ids": {0: "batch_size", 1: "seq_len"},
"pg": {0: "batch_size"},
"p_start": {0: "batch_size"},
"logits": {0: "batch_size"},
},
opset_version=14, # 推荐使用 14+ 以支持更好的算子
opset_version=14,
do_constant_folding=True,
)
logger.info(f"ONNX model exported to {output_path}")