feat(model): 更新模型结构,使用 GELU 激活函数并优化专家网络参数
This commit is contained in:
parent
db90516fcf
commit
93dced50c7
|
|
@ -1,3 +1,4 @@
|
|||
import math
|
||||
import pickle
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
|
|
@ -12,9 +13,10 @@ from loguru import logger
|
|||
from modelscope import AutoModel, AutoTokenizer
|
||||
from tqdm.autonotebook import tqdm
|
||||
|
||||
from .monitor import TrainingMonitor, send_serverchan_message
|
||||
from suinput.dataset import PG
|
||||
|
||||
from .monitor import TrainingMonitor, send_serverchan_message
|
||||
|
||||
|
||||
def eval_dataloader(path: Union[str, Path] = (files(__package__) / "eval_dataset")):
|
||||
return [pickle.load(file.open("rb")) for file in Path(path).glob("*.pkl")]
|
||||
|
|
@ -22,61 +24,46 @@ def eval_dataloader(path: Union[str, Path] = (files(__package__) / "eval_dataset
|
|||
|
||||
# ---------------------------- 残差块 ----------------------------
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, dim, dropout_prob=0.1):
|
||||
def __init__(self, dim, dropout_prob=0.3):
|
||||
super().__init__()
|
||||
self.linear1 = nn.Linear(dim, dim)
|
||||
self.ln1 = nn.LayerNorm(dim)
|
||||
self.linear2 = nn.Linear(dim, dim)
|
||||
self.ln2 = nn.LayerNorm(dim)
|
||||
self.relu = nn.ReLU()
|
||||
self.gelu = nn.GELU()
|
||||
self.dropout = nn.Dropout(dropout_prob)
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
x = self.relu(self.linear1(x))
|
||||
x = self.gelu(self.linear1(x))
|
||||
x = self.ln1(x)
|
||||
x = self.linear2(x)
|
||||
x = self.ln2(x)
|
||||
x = self.dropout(x)
|
||||
x = x + residual
|
||||
return self.relu(x)
|
||||
return self.gelu(x)
|
||||
|
||||
|
||||
# ---------------------------- 专家网络 ----------------------------
|
||||
class Expert(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim, # 输入特征的维度大小
|
||||
d_model=1024, # 模型内部的隐藏层维度,默认为1024
|
||||
num_resblocks=4, # 残差块的数量,默认为4
|
||||
output_multiplier=2, # 输出维度是输入维度的倍数,默认为2倍
|
||||
dropout_prob=0.1, # Dropout层的丢弃概率,默认为0.1
|
||||
input_dim,
|
||||
d_model=768,
|
||||
num_resblocks=3,
|
||||
output_multiplier=1,
|
||||
dropout_prob=0.3,
|
||||
):
|
||||
"""
|
||||
初始化函数,用于构建模型的各个层
|
||||
参数说明:
|
||||
input_dim : 输入维度
|
||||
d_model : 专家内部维度
|
||||
output_multiplier : 输出维度 = input_dim * output_multiplier
|
||||
dropout_prob : 残差块内 Dropout
|
||||
"""
|
||||
super().__init__() # 调用父类的初始化方法
|
||||
self.input_dim = input_dim # 保存输入维度
|
||||
self.d_model = d_model # 保存模型内部维度
|
||||
self.output_dim = input_dim * output_multiplier # 计算并保存输出维度
|
||||
|
||||
# 输入映射:input_dim -> d_model
|
||||
super().__init__()
|
||||
self.output_dim = input_dim * output_multiplier
|
||||
self.linear_in = nn.Linear(input_dim, d_model)
|
||||
|
||||
# 残差堆叠
|
||||
self.res_blocks = nn.ModuleList(
|
||||
[ResidualBlock(d_model, dropout_prob) for _ in range(num_resblocks)]
|
||||
)
|
||||
|
||||
# 输出映射:d_model -> output_dim
|
||||
self.output = nn.Sequential(
|
||||
nn.Linear(d_model, d_model),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.GELU(inplace=True),
|
||||
nn.Dropout(dropout_prob),
|
||||
nn.Linear(d_model, self.output_dim),
|
||||
)
|
||||
|
||||
|
|
@ -91,7 +78,7 @@ class Expert(nn.Module):
|
|||
class MoEModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
pretrained_model_name="iic/nlp_structbert_backbone_tiny_std",
|
||||
pretrained_model_name="iic/nlp_structbert_backbone_lite_std",
|
||||
num_classes=10018,
|
||||
output_multiplier=2,
|
||||
d_model=768,
|
||||
|
|
@ -102,14 +89,15 @@ class MoEModel(nn.Module):
|
|||
super().__init__()
|
||||
self.output_multiplier = output_multiplier
|
||||
|
||||
# 1. 加载预训练 BERT,仅保留 embeddings
|
||||
# 1. 加载预训练 Embedding (Lite: hidden_size=512)
|
||||
logger.info(f"Loading backbone: {pretrained_model_name}")
|
||||
bert = AutoModel.from_pretrained(pretrained_model_name)
|
||||
self.embedding = bert.embeddings
|
||||
self.bert_config = bert.config
|
||||
self.hidden_size = self.bert_config.hidden_size # BERT 隐层维度
|
||||
self.device = None # 将在 to() 调用时设置
|
||||
self.hidden_size = self.bert_config.hidden_size # 512
|
||||
self.device = None
|
||||
|
||||
# 2. 4 层标准 Transformer Encoder(从 config 读取参数)
|
||||
# 2. Transformer Encoder
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=self.hidden_size,
|
||||
nhead=8,
|
||||
|
|
@ -120,63 +108,49 @@ class MoEModel(nn.Module):
|
|||
)
|
||||
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
|
||||
|
||||
# self.shared_resblocks = nn.ModuleList(
|
||||
# [ResidualBlock(self.hidden_size, 0.1) for _ in range(4)]
|
||||
# )
|
||||
self.pooler = nn.AdaptiveAvgPool1d(1)
|
||||
|
||||
# self.linear = nn.Linear(self.hidden_size, self.hidden_size)
|
||||
|
||||
# 3. 专家层:n个领域专家 + 1个共享专家
|
||||
# 3. 专家系统
|
||||
total_experts = num_domain_experts + num_shared_experts
|
||||
self.experts = nn.ModuleList()
|
||||
|
||||
for i in range(total_experts):
|
||||
# 领域专家 dropout=0.1,共享专家 dropout=0.2(您指定的更强正则)
|
||||
dropout_prob = 0.1 if i < num_domain_experts else 0.2
|
||||
expert = Expert(
|
||||
dropout = 0.3 if i < num_domain_experts else 0.4
|
||||
self.experts.append(
|
||||
Expert(
|
||||
input_dim=self.hidden_size,
|
||||
d_model=d_model,
|
||||
num_resblocks=num_resblocks,
|
||||
output_multiplier=self.output_multiplier,
|
||||
dropout_prob=dropout_prob,
|
||||
dropout_prob=dropout,
|
||||
)
|
||||
)
|
||||
self.experts.append(expert)
|
||||
|
||||
self.expert_bias = nn.Embedding(
|
||||
total_experts, self.output_multiplier * self.hidden_size
|
||||
total_experts, self.hidden_size * self.output_multiplier
|
||||
)
|
||||
|
||||
# 4. 分类头
|
||||
self.classifier = nn.Sequential(
|
||||
nn.LayerNorm(self.output_multiplier * self.hidden_size),
|
||||
nn.Linear(
|
||||
self.output_multiplier * self.hidden_size,
|
||||
self.output_multiplier * self.hidden_size,
|
||||
),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(
|
||||
self.output_multiplier * self.hidden_size,
|
||||
self.output_multiplier * self.hidden_size * 2,
|
||||
),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(self.output_multiplier * self.hidden_size * 2, num_classes),
|
||||
nn.LayerNorm(self.hidden_size * self.output_multiplier),
|
||||
nn.Dropout(0.4),
|
||||
nn.Linear(self.hidden_size * self.output_multiplier, num_classes),
|
||||
)
|
||||
# 可选:为领域专家和共享专家设置不同权重衰减(通过优化器实现,此处不处理)
|
||||
|
||||
def to(self, device):
|
||||
"""重写 to 方法,记录设备"""
|
||||
self.device = device
|
||||
return super().to(device)
|
||||
|
||||
def forward(self, input_ids, attention_mask, pg):
|
||||
def forward(self, input_ids, attention_mask, pg, p_start):
|
||||
"""
|
||||
input_ids : [batch, seq_len]
|
||||
attention_mask: [batch, seq_len] (1 为有效,0 为 padding)
|
||||
pg : group_id,训练时为 [batch] 的 LongTensor,推理导出时为标量 Tensor
|
||||
ONNX 兼容的 Forward 函数
|
||||
|
||||
Args:
|
||||
input_ids: [B, L]
|
||||
attention_mask: [B, L]
|
||||
pg: [B] 拼音组 ID
|
||||
p_start: [B] 拼音起始索引位置 (整数 Tensor)
|
||||
"""
|
||||
# ----- 1. Embeddings -----
|
||||
embeddings = self.embedding(input_ids) # [B, S, H]
|
||||
embeddings = self.embedding(input_ids)
|
||||
|
||||
# ----- 2. Transformer Encoder -----
|
||||
# padding mask: True 表示忽略该位置
|
||||
|
|
@ -185,13 +159,53 @@ class MoEModel(nn.Module):
|
|||
embeddings, src_key_padding_mask=padding_mask
|
||||
) # [B, S, H]
|
||||
|
||||
# ----- 3. 池化量 -----
|
||||
# for block in self.shared_resblocks:
|
||||
# encoded = block(encoded)
|
||||
pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1)
|
||||
# pooled = self.pooler(encoded.transpose(1, 2)) # [B, H, 2]
|
||||
# pooled = pooled.flatten(1) # [B, H*2]
|
||||
# pooled = self.linear(pooled)
|
||||
# ----- 3. ONNX 兼容的 Span Pooling (向量化实现) -----
|
||||
"""
|
||||
思路:
|
||||
我们不能用循环去切片。我们要构造一个 Mask 矩阵。
|
||||
目标:对于每个样本 i,生成一个长度为 L 的向量,其中 p_start[i] < index < p_end[i] 的位置为 1,其余为 0。
|
||||
|
||||
步骤:
|
||||
1. 生成位置索引轴:indices = [0, 1, 2, ..., L-1] (Shape: [L])
|
||||
2. 扩展维度以匹配 Batch:
|
||||
indices: [1, L]
|
||||
p_start: [B, 1]
|
||||
p_end: [B, 1]
|
||||
3. 逻辑比较 (Broadcasting):
|
||||
mask = (indices > p_start) & (indices < p_end)
|
||||
结果 Shape: [B, L] (Boolean)
|
||||
4. 应用 Mask:
|
||||
masked_encoded = encoded * mask.unsqueeze(-1)
|
||||
5. 求和并归一化:
|
||||
sum_vec = masked_encoded.sum(dim=1)
|
||||
count = mask.sum(dim=1).clamp(min=1) # 防止除零
|
||||
pooled = sum_vec / count
|
||||
"""
|
||||
B, L, H = encoded.shape
|
||||
device = encoded.device
|
||||
|
||||
# 生成位置轴 [0, 1, ..., L-1]
|
||||
positions = torch.arange(L, device=device).unsqueeze(0) # [1, L]
|
||||
|
||||
# 调整 p_start 形状为 [B, 1] 以便广播
|
||||
p_start_exp = p_start.unsqueeze(1) # [B, 1]
|
||||
span_mask = positions >= p_start_exp
|
||||
|
||||
# 转换为 Float 用于乘法
|
||||
span_mask_float = span_mask.float() # [B, L]
|
||||
|
||||
# 应用 Mask
|
||||
# encoded: [B, L, H] -> mask: [B, L, 1]
|
||||
masked_encoded = encoded * span_mask_float.unsqueeze(-1)
|
||||
|
||||
# 求和
|
||||
span_sum = masked_encoded.sum(dim=1) # [B, H]
|
||||
|
||||
# 计算有效长度 (防止除以 0)
|
||||
span_count = span_mask_float.sum(dim=1, keepdim=True).clamp(min=1.0) # [B, 1]
|
||||
|
||||
# 平均池化
|
||||
pooled = span_sum / span_count # [B, H]
|
||||
|
||||
# ----- 4. 专家路由(硬路由)-----
|
||||
if torch.jit.is_tracing():
|
||||
|
|
@ -245,7 +259,7 @@ class MoEModel(nn.Module):
|
|||
)
|
||||
elif group_id == 11:
|
||||
expert_out = self.experts[11](pooled) + self.expert_bias(
|
||||
torch.tensor(12, device=pooled.device)
|
||||
torch.tensor(11, device=pooled.device)
|
||||
)
|
||||
else: # group_id == 12
|
||||
expert_out = self.experts[12](pooled) + self.expert_bias(
|
||||
|
|
@ -270,7 +284,7 @@ class MoEModel(nn.Module):
|
|||
return probs
|
||||
return logits
|
||||
|
||||
def model_eval(self, eval_dataloader, criterion=None):
|
||||
def model_eval(self, eval_dataloader, criterion):
|
||||
"""
|
||||
在验证集上评估模型,返回准确率和平均损失。
|
||||
|
||||
|
|
@ -281,9 +295,6 @@ class MoEModel(nn.Module):
|
|||
accuracy: float, 准确率
|
||||
avg_loss: float, 平均损失
|
||||
"""
|
||||
if criterion is None:
|
||||
criterion = nn.NLLLoss()
|
||||
|
||||
self.eval()
|
||||
total_loss = 0.0
|
||||
correct = 0
|
||||
|
|
@ -295,11 +306,11 @@ class MoEModel(nn.Module):
|
|||
input_ids = batch["hint"]["input_ids"].to(self.device)
|
||||
attention_mask = batch["hint"]["attention_mask"].to(self.device)
|
||||
pg = batch["pg"].to(self.device)
|
||||
p_start = batch["p_start"].to(self.device)
|
||||
labels = batch["char_id"].to(self.device)
|
||||
|
||||
# 前向传播
|
||||
probs = self(input_ids, attention_mask, pg)
|
||||
log_probs = torch.log(probs + 1e-12)
|
||||
loss = criterion(log_probs, labels)
|
||||
total_loss += loss.item() * labels.size(0)
|
||||
|
||||
|
|
@ -359,6 +370,7 @@ class MoEModel(nn.Module):
|
|||
}
|
||||
# 将拼音的第一个字符映射为 PG 中的索引并转换为张量
|
||||
sample["pg"] = torch.tensor([PG[py[0]]])
|
||||
sample["p_start"] = torch.tensor([len(text)])
|
||||
return sample
|
||||
|
||||
def predict(self, text, py, tokenizer=None):
|
||||
|
|
@ -366,7 +378,12 @@ class MoEModel(nn.Module):
|
|||
基于输入的文本和拼音,生成 sample 字典进行预测,支持批量/单样本,可选调试打印错误样本信息。
|
||||
|
||||
参数:
|
||||
|
||||
text : str
|
||||
输入的文本。
|
||||
py : str
|
||||
输入的拼音。
|
||||
tokenizer : Tokenizer, optional
|
||||
用于分词的分词器,默认为 None。
|
||||
debug : bool
|
||||
是否打印预测错误的样本信息。
|
||||
|
||||
|
|
@ -421,40 +438,18 @@ class MoEModel(nn.Module):
|
|||
eval_frequency=500,
|
||||
grad_accum_steps=1, # 梯度累积步数
|
||||
clip_grad_norm=1.0, # 梯度裁剪的范数
|
||||
mixed_precision=False, # 是否使用混合精度训练
|
||||
loss_weight=None, # 损失权重,用于处理类别不平衡
|
||||
lr=6e-5, # 初始学习率
|
||||
lr_schedule=None, # 新增:可选的自定义学习率调度函数
|
||||
loss_weight=None,
|
||||
mixed_precision=True,
|
||||
weight_decay=0.1,
|
||||
warmup_ratio=0.1,
|
||||
label_smoothing=0.15,
|
||||
lr=1e-4,
|
||||
):
|
||||
"""
|
||||
训练模型,支持混合精度、梯度累积、学习率调度、实时监控。
|
||||
|
||||
参数:
|
||||
train_dataloader: DataLoader
|
||||
训练数据加载器。
|
||||
eval_dataloader: DataLoader, optional
|
||||
评估数据加载器。
|
||||
monitor: TrainingMonitor, optional
|
||||
训练监控器。
|
||||
criterion: nn.Module, optional
|
||||
损失函数。
|
||||
optimizer: optim.Optimizer, optional
|
||||
优化器。
|
||||
num_epochs: int, optional
|
||||
训练轮数。
|
||||
eval_frequency: int, optional
|
||||
评估频率。
|
||||
grad_accum_steps: int, optional
|
||||
梯度累积步数。
|
||||
clip_grad_norm: float, optional
|
||||
梯度裁剪范数。
|
||||
mixed_precision: bool, optional
|
||||
是否使用混合精度。
|
||||
lr: float, optional
|
||||
初始学习率。
|
||||
lr_schedule : callable, optional
|
||||
自定义学习率调度函数,接收参数 (processed_batches, optimizer),
|
||||
可在内部直接修改 optimizer.param_groups 中的学习率。
|
||||
# TODO: 添加参数注释
|
||||
"""
|
||||
# 确保模型在正确的设备上(GPU或CPU)
|
||||
if self.device is None:
|
||||
|
|
@ -462,48 +457,61 @@ class MoEModel(nn.Module):
|
|||
self.to(self.device)
|
||||
|
||||
# 切换到训练模式
|
||||
super().train()
|
||||
self.train()
|
||||
|
||||
# 默认优化器设置
|
||||
if optimizer is None:
|
||||
optimizer = optim.AdamW(self.parameters(), lr=lr)
|
||||
optimizer = optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)
|
||||
|
||||
# 损失函数设置
|
||||
if criterion is None:
|
||||
if loss_weight is not None:
|
||||
criterion = nn.CrossEntropyLoss(weight=loss_weight)
|
||||
criterion = nn.CrossEntropyLoss(
|
||||
weight=loss_weight, label_smoothing=label_smoothing
|
||||
)
|
||||
else:
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
|
||||
|
||||
# 混合精度缩放器
|
||||
scaler = amp.GradScaler(enabled=mixed_precision)
|
||||
|
||||
global_step = 0
|
||||
total_steps = stop_batch
|
||||
warmup_steps = int(total_steps * warmup_ratio)
|
||||
logger.info(f"Training Start: Steps={total_steps}, Warmup={warmup_steps}")
|
||||
processed_batches = 0 # 新增:实际处理的 batch 数量计数器
|
||||
batch_loss_sum = 0.0
|
||||
optimizer.zero_grad()
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
for batch_idx, batch in enumerate(tqdm(train_dataloader, total=stop_batch)):
|
||||
# ---------- 更新 batch 计数器 ----------
|
||||
for batch_idx, batch in enumerate(
|
||||
tqdm(train_dataloader, total=int(stop_batch))
|
||||
):
|
||||
processed_batches += 1
|
||||
|
||||
# ---------- 学习率调度(仅当使用默认优化器且未传入自定义调度函数时)----------
|
||||
if lr_schedule is not None:
|
||||
# 调用用户自定义的调度函数
|
||||
lr_schedule(processed_batches, optimizer)
|
||||
# LR Schedule
|
||||
if processed_batches < warmup_steps:
|
||||
current_lr = lr * (processed_batches / warmup_steps)
|
||||
else:
|
||||
progress = (processed_batches - warmup_steps) / (
|
||||
total_steps - warmup_steps
|
||||
)
|
||||
current_lr = lr * (0.5 * (1.0 + math.cos(math.pi * progress)))
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group["lr"] = current_lr
|
||||
|
||||
# ---------- 移动数据 ----------
|
||||
input_ids = batch["hint"]["input_ids"].to(self.device)
|
||||
attention_mask = batch["hint"]["attention_mask"].to(self.device)
|
||||
pg = batch["pg"].to(self.device)
|
||||
p_start = batch["p_start"].to(self.device) # [B]
|
||||
labels = batch["char_id"].to(self.device)
|
||||
|
||||
# 混合精度前向
|
||||
with amp.autocast(
|
||||
# Forward
|
||||
with torch.amp.autocast(
|
||||
device_type=self.device.type, enabled=mixed_precision
|
||||
):
|
||||
logits = self(input_ids, attention_mask, pg)
|
||||
logits = self(input_ids, attention_mask, pg, p_start)
|
||||
loss = criterion(logits, labels)
|
||||
loss = loss / grad_accum_steps
|
||||
|
||||
|
|
@ -511,28 +519,34 @@ class MoEModel(nn.Module):
|
|||
scaler.scale(loss).backward()
|
||||
|
||||
# 梯度累积
|
||||
if (batch_idx) % grad_accum_steps == 0:
|
||||
if (processed_batches) % grad_accum_steps == 0:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
|
||||
|
||||
has_nan = False
|
||||
for p in self.parameters():
|
||||
if p.grad is not None and torch.isnan(p.grad).any():
|
||||
has_nan = True
|
||||
break
|
||||
|
||||
if not has_nan:
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
else:
|
||||
logger.warning("NaN detected, skipping step.")
|
||||
optimizer.zero_grad()
|
||||
original_loss = loss.item() * grad_accum_steps
|
||||
batch_loss_sum += original_loss
|
||||
# 周期性评估(与原代码相同)
|
||||
if (
|
||||
eval_dataloader is not None
|
||||
and global_step % eval_frequency == 0
|
||||
):
|
||||
avg_loss = batch_loss_sum / eval_frequency
|
||||
batch_loss_sum += loss.item() * grad_accum_steps
|
||||
|
||||
# 周期性评估
|
||||
if eval_dataloader and global_step % eval_frequency == 0:
|
||||
self.eval()
|
||||
acc, eval_loss = self.model_eval(eval_dataloader, criterion)
|
||||
if global_step == 0:
|
||||
avg_loss = eval_loss
|
||||
super().train()
|
||||
if monitor is not None:
|
||||
self.train()
|
||||
if monitor:
|
||||
monitor.add_step(
|
||||
global_step,
|
||||
{"loss": avg_loss, "acc": acc},
|
||||
global_step, {"loss": avg_loss, "acc": acc}
|
||||
)
|
||||
logger.info(
|
||||
f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc:.4f}, eval_loss: {eval_loss:.4f}"
|
||||
|
|
@ -541,8 +555,9 @@ class MoEModel(nn.Module):
|
|||
if processed_batches - 1 >= stop_batch:
|
||||
break
|
||||
global_step += 1
|
||||
res_acc, res_loss = self.model_eval(eval_dataloader)
|
||||
|
||||
try:
|
||||
res_acc, res_loss = self.model_eval(eval_dataloader, criterion)
|
||||
to_wechat_response = send_serverchan_message(
|
||||
title="训练完成",
|
||||
content=f"训练完成,acc: {res_acc:.4f}, loss: {res_loss:.4f}",
|
||||
|
|
@ -582,3 +597,31 @@ class MoEModel(nn.Module):
|
|||
for name, param in self.named_parameters():
|
||||
if name in freeze_layers:
|
||||
param.requires_grad = False
|
||||
|
||||
# --- ONNX 导出辅助函数 ---
|
||||
def export_onnx(self, output_path, dummy_input):
|
||||
"""
|
||||
dummy_input 应该是一个字典或元组,包含:
|
||||
(input_ids, attention_mask, pg, p_start)
|
||||
"""
|
||||
self.eval()
|
||||
input_names = ["input_ids", "attention_mask", "pg", "p_start"]
|
||||
output_names = ["logits"]
|
||||
|
||||
torch.onnx.export(
|
||||
self,
|
||||
dummy_input,
|
||||
output_path,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes={
|
||||
"input_ids": {0: "batch_size", 1: "seq_len"},
|
||||
"attention_mask": {0: "batch_size", 1: "seq_len"},
|
||||
"pg": {0: "batch_size"},
|
||||
"p_start": {0: "batch_size"},
|
||||
"logits": {0: "batch_size"},
|
||||
},
|
||||
opset_version=14, # 推荐使用 14+ 以支持更好的算子
|
||||
do_constant_folding=True,
|
||||
)
|
||||
logger.info(f"ONNX model exported to {output_path}")
|
||||
|
|
|
|||
Loading…
Reference in New Issue