595 lines
24 KiB
Python
595 lines
24 KiB
Python
import math
|
||
import pickle
|
||
from importlib.resources import files
|
||
from pathlib import Path
|
||
from typing import Optional, Union
|
||
|
||
import torch
|
||
import torch.amp as amp
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
import torch.optim as optim
|
||
from loguru import logger
|
||
from modelscope import AutoModel, AutoTokenizer
|
||
from tqdm.autonotebook import tqdm
|
||
|
||
from suinput.dataset import PG
|
||
|
||
from .monitor import TrainingMonitor, send_serverchan_message
|
||
|
||
|
||
def eval_dataloader(path: Optional[Union[str, Path]] = None):
    """Load all pickled evaluation batches from a directory.

    Args:
        path: Directory containing ``*.pkl`` files. Defaults to the
            ``eval_dataset`` resource bundled with this package. The default
            is resolved lazily at call time (the previous eager default
            touched package resources at import time).

    Returns:
        list: One unpickled object per ``*.pkl`` file, in glob order.
    """
    if path is None:
        path = files(__package__) / "eval_dataset"
    # read_bytes() opens and closes each file; the previous
    # `file.open("rb")` handles were never closed (resource leak).
    return [pickle.loads(f.read_bytes()) for f in Path(path).glob("*.pkl")]
|
||
|
||
|
||
# ---------------------------- 注意力池化模块(新增)----------------------------
|
||
# ---------------------------- Attention pooling module ----------------------------
class AttentionPooling(nn.Module):
    """Attention pooling over a token sequence.

    A single linear layer scores each token; scores receive a learnable
    per-token-type bias (0 = text, 1 = pinyin, 2 = user/personalized),
    padding positions are masked out, and the softmax-weighted sum of the
    inputs is returned.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.attn = nn.Linear(hidden_size, 1)
        # Learnable score offsets, one per token type:
        # [text_bias, pinyin_bias, user_bias].
        self.bias = nn.Parameter(torch.zeros(3))

    def forward(self, x, mask=None, token_type_ids=None):
        """Pool ``x`` ([batch, seq_len, hidden]) into [batch, hidden]."""
        logits = self.attn(x).squeeze(-1)  # [batch, seq_len]
        if token_type_ids is not None:
            # Indexing the 3-element bias with the type ids yields a
            # [batch, seq_len] offset added to every token's score.
            logits = logits + self.bias[token_type_ids]
        if mask is not None:
            # Large negative value ~ -inf: padding gets ~zero softmax weight.
            logits = logits.masked_fill(mask == 0, -1e9)
        weights = torch.softmax(logits, dim=-1).unsqueeze(-1)  # [batch, seq_len, 1]
        return (weights * x).sum(dim=1)
|
||
|
||
|
||
# ---------------------------- 残差块 ----------------------------
|
||
# ---------------------------- Residual block ----------------------------
class ResidualBlock(nn.Module):
    """Residual MLP block.

    Two Linear+LayerNorm stages with GELU activation and dropout, a skip
    connection around both, and a final GELU on the sum.
    """

    def __init__(self, dim, dropout_prob=0.3):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.ln1 = nn.LayerNorm(dim)
        self.linear2 = nn.Linear(dim, dim)
        self.ln2 = nn.LayerNorm(dim)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        """Apply the block; input and output shapes are identical."""
        shortcut = x
        out = self.ln1(self.gelu(self.linear1(x)))
        out = self.ln2(self.linear2(out))
        out = self.dropout(out) + shortcut
        return self.gelu(out)
|
||
|
||
|
||
# ---------------------------- 专家网络 ----------------------------
|
||
# ---------------------------- Expert network ----------------------------
class Expert(nn.Module):
    """Single MoE expert.

    Pipeline: input projection -> stack of residual blocks -> two-layer
    output head that widens the representation to
    ``input_dim * output_multiplier``.
    """

    def __init__(
        self,
        input_dim,
        d_model=768,
        num_resblocks=4,
        output_multiplier=2,
        dropout_prob=0.3,
    ):
        super().__init__()
        self.output_dim = input_dim * output_multiplier
        self.linear_in = nn.Linear(input_dim, d_model)
        self.res_blocks = nn.ModuleList(
            ResidualBlock(d_model, dropout_prob) for _ in range(num_resblocks)
        )
        self.output = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout_prob),
            nn.Linear(d_model, self.output_dim),
        )

    def forward(self, x):
        """Map [batch, input_dim] to [batch, input_dim * output_multiplier]."""
        hidden = self.linear_in(x)
        for block in self.res_blocks:
            hidden = block(hidden)
        return self.output(hidden)
|
||
|
||
|
||
# ---------------------------- 主模型(MoE + 硬路由)------------------------
|
||
# ---------------------------- Main model (MoE + hard routing) ------------------------
class MoEModel(nn.Module):
    """Mixture-of-experts character classifier with hard (pinyin-group) routing.

    Pipeline: pretrained StructBERT embeddings -> 4-layer Transformer
    encoder -> attention pooling -> one expert selected per sample by its
    pinyin-group id -> classification head over ``num_classes`` characters.
    """

    def __init__(
        self,
        pretrained_model_name="iic/nlp_structbert_backbone_lite_std",
        num_classes=10018,
        output_multiplier=2,
        d_model=768,
        num_resblocks=4,
        num_domain_experts=12,
        num_shared_experts=1,
    ):
        super().__init__()
        self.output_multiplier = output_multiplier

        # 1. Load pretrained embeddings (Lite backbone: hidden_size=512).
        #    Only the embedding layer and config are kept; the pretrained
        #    encoder itself is NOT reused.
        logger.info(f"Loading backbone: {pretrained_model_name}")
        bert = AutoModel.from_pretrained(pretrained_model_name)
        self.embedding = bert.embeddings
        self.bert_config = bert.config
        self.hidden_size = self.bert_config.hidden_size  # 512 for the Lite backbone
        # Set by the overridden to(); used later to move inputs in eval/predict.
        self.device = None

        # 2. Transformer encoder (fresh, trained from scratch).
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.hidden_size,
            nhead=8,
            dim_feedforward=self.bert_config.intermediate_size,
            dropout=self.bert_config.hidden_dropout_prob,
            activation="gelu",
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)

        # 3. Attention pooling (replaces the previous span pooling).
        self.attn_pool = AttentionPooling(self.hidden_size)

        # 4. Expert system: `num_domain_experts` routed experts plus
        #    `num_shared_experts` fallback expert(s); the shared expert(s)
        #    get higher dropout.
        total_experts = num_domain_experts + num_shared_experts
        self.experts = nn.ModuleList()
        for i in range(total_experts):
            dropout = 0.3 if i < num_domain_experts else 0.4
            self.experts.append(
                Expert(
                    input_dim=self.hidden_size,
                    d_model=d_model,
                    num_resblocks=num_resblocks,
                    output_multiplier=self.output_multiplier,
                    dropout_prob=dropout,
                )
            )

        # Learnable per-expert additive bias, same width as an expert output.
        self.expert_bias = nn.Embedding(
            total_experts, self.hidden_size * self.output_multiplier
        )

        # 5. Classification head.
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.hidden_size * self.output_multiplier),
            nn.Dropout(0.4),
            nn.Linear(self.hidden_size * self.output_multiplier, self.hidden_size * self.output_multiplier * 2),
            nn.GELU(),
            nn.Linear(self.hidden_size * self.output_multiplier * 2, num_classes),
        )
|
||
|
||
    def to(self, device):
        """Override nn.Module.to that also records the target device.

        ``self.device`` is consulted later (model_eval/predict/fit) to move
        input tensors to the same device as the model.
        """
        self.device = device
        return super().to(device)
|
||
|
||
def forward(self, input_ids, attention_mask, token_type_ids, pg):
|
||
"""
|
||
新版 Forward 函数,不再需要 p_start,改用 token_type_ids。
|
||
Args:
|
||
input_ids: [B, L]
|
||
attention_mask: [B, L]
|
||
token_type_ids: [B, L] (0=文本, 1=拼音)
|
||
pg: [B] 拼音组 ID
|
||
"""
|
||
# ----- 1. Embeddings -----
|
||
# 注意:预训练的 embedding 层本身可能已经包含了 token_type_ids 的处理,
|
||
# 但这里我们直接使用它的 embedding,并手动将 token_type_ids 的嵌入加到上面。
|
||
# 由于 bert.embeddings 通常包含 token_type_embeddings,我们可以利用它。
|
||
# 但为简化,我们直接使用 bert.embeddings(input_ids, token_type_ids=token_type_ids)
|
||
# 如果当前 embedding 不支持传入 token_type_ids,可以手动相加:
|
||
# embeddings = self.embedding(input_ids) + self.embedding.token_type_embeddings(token_type_ids)
|
||
# 这里采用更通用的方式:假设 self.embedding 有 token_type_ids 参数
|
||
embeddings = self.embedding(input_ids, token_type_ids=token_type_ids)
|
||
|
||
# ----- 2. Transformer Encoder -----
|
||
padding_mask = attention_mask == 0
|
||
encoded = self.encoder(
|
||
embeddings, src_key_padding_mask=padding_mask
|
||
) # [B, S, H]
|
||
|
||
# ----- 3. 注意力池化(代替原来的 Span Pooling)-----
|
||
# 使用 attention_mask 忽略 padding 位置
|
||
pooled = self.attn_pool(encoded, attention_mask, token_type_ids) # [B, H]
|
||
|
||
# ----- 4. 专家路由(硬路由)-----
|
||
if torch.jit.is_tracing():
|
||
# ONNX 导出模式:batch=1,根据 pg 选择专家
|
||
group_id = pg.item() if torch.is_tensor(pg) else pg
|
||
# 注意:专家索引从 0 开始,确保所有 case 都覆盖且偏置正确
|
||
# 使用字典映射或 if-elif(ONNX 需要静态图,此处保持原样但修正索引错误)
|
||
if group_id == 0:
|
||
expert_out = self.experts[0](pooled) + self.expert_bias(
|
||
torch.tensor(0, device=pooled.device)
|
||
)
|
||
elif group_id == 1:
|
||
expert_out = self.experts[1](pooled) + self.expert_bias(
|
||
torch.tensor(1, device=pooled.device)
|
||
)
|
||
elif group_id == 2:
|
||
expert_out = self.experts[2](pooled) + self.expert_bias(
|
||
torch.tensor(2, device=pooled.device)
|
||
)
|
||
elif group_id == 3:
|
||
expert_out = self.experts[3](pooled) + self.expert_bias(
|
||
torch.tensor(3, device=pooled.device)
|
||
)
|
||
elif group_id == 4:
|
||
expert_out = self.experts[4](pooled) + self.expert_bias(
|
||
torch.tensor(4, device=pooled.device)
|
||
)
|
||
elif group_id == 5:
|
||
expert_out = self.experts[5](pooled) + self.expert_bias(
|
||
torch.tensor(5, device=pooled.device)
|
||
)
|
||
elif group_id == 6:
|
||
expert_out = self.experts[6](pooled) + self.expert_bias(
|
||
torch.tensor(6, device=pooled.device)
|
||
)
|
||
elif group_id == 7:
|
||
expert_out = self.experts[7](pooled) + self.expert_bias(
|
||
torch.tensor(7, device=pooled.device)
|
||
)
|
||
elif group_id == 8:
|
||
expert_out = self.experts[8](pooled) + self.expert_bias(
|
||
torch.tensor(8, device=pooled.device)
|
||
)
|
||
elif group_id == 9:
|
||
expert_out = self.experts[9](pooled) + self.expert_bias(
|
||
torch.tensor(9, device=pooled.device)
|
||
)
|
||
elif group_id == 10:
|
||
expert_out = self.experts[10](pooled) + self.expert_bias(
|
||
torch.tensor(10, device=pooled.device)
|
||
)
|
||
elif group_id == 11:
|
||
expert_out = self.experts[11](pooled) + self.expert_bias(
|
||
torch.tensor(11, device=pooled.device)
|
||
)
|
||
else: # group_id == 12
|
||
expert_out = self.experts[12](pooled) + self.expert_bias(
|
||
torch.tensor(12, device=pooled.device)
|
||
)
|
||
else:
|
||
batch_size = pooled.size(0)
|
||
expert_outputs = torch.stack(
|
||
[e(pooled) for e in self.experts], dim=0
|
||
) # [E, B, D]
|
||
expert_out = expert_outputs[pg, torch.arange(batch_size)] # [B, D]
|
||
bias = self.expert_bias(pg) # [B, D]
|
||
expert_out = expert_out + bias
|
||
|
||
# ----- 5. 分类头 -----
|
||
return self.classifier(expert_out) # [batch, num_classes]
|
||
|
||
def model_eval(self, eval_dataloader, criterion):
|
||
"""
|
||
评估模型在验证集上的性能。
|
||
|
||
Args:
|
||
eval_dataloader (DataLoader): 验证集的数据加载器,每个batch包含以下字段:
|
||
- hint: 包含input_ids、attention_mask和token_type_ids的字典
|
||
- pg: 程序图数据
|
||
- char_id: 字符ID标签
|
||
criterion (callable): 损失函数,用于计算模型输出与标签之间的损失
|
||
|
||
Returns:
|
||
tuple: 包含两个浮点数的元组 (accuracy, avg_loss)
|
||
- accuracy (float): 模型在验证集上的准确率
|
||
- avg_loss (float): 模型在验证集上的平均损失
|
||
|
||
Note:
|
||
该方法会自动将模型切换到评估模式(self.eval()),
|
||
并使用torch.no_grad()上下文管理器来禁用梯度计算,
|
||
以节省内存和计算资源。
|
||
"""
|
||
self.eval()
|
||
total_loss = 0.0
|
||
correct = 0
|
||
total = 0
|
||
|
||
with torch.no_grad():
|
||
for batch in eval_dataloader:
|
||
input_ids = batch["hint"]["input_ids"].to(self.device)
|
||
attention_mask = batch["hint"]["attention_mask"].to(self.device)
|
||
token_type_ids = batch["hint"]["token_type_ids"].to(self.device) # 新增
|
||
pg = batch["pg"].to(self.device)
|
||
labels = batch["char_id"].to(self.device)
|
||
|
||
logits = self(input_ids, attention_mask, token_type_ids, pg)
|
||
loss = criterion(logits, labels)
|
||
|
||
total_loss += loss.item() * labels.size(0)
|
||
|
||
preds = torch.softmax(logits, dim=-1).argmax(dim=-1)
|
||
correct += (preds == labels).sum().item()
|
||
total += labels.size(0)
|
||
|
||
avg_loss = total_loss / total if total > 0 else 0.0
|
||
accuracy = correct / total if total > 0 else 0.0
|
||
return accuracy, avg_loss
|
||
|
||
    def gen_predict_sample(self, text, py, tokenizer=None):
        """Build a model-input sample for prediction.

        Converts text and pinyin into the input format the model expects
        (input_ids, attention_mask, token_type_ids). If no tokenizer is
        supplied and none is cached, a default AutoTokenizer is created.

        Args:
            text (str): Input text, used as the first sentence.
            py (str): Pinyin string, used as the second sentence.
            tokenizer (AutoTokenizer, optional): Tokenizer instance. If None
                and ``self.tokenizer`` does not exist, a default tokenizer is
                created. Defaults to None.

        Returns:
            dict: Model inputs of the form::

                {
                    "hint": {
                        "input_ids": tensor,       # token ids of text + pinyin
                        "attention_mask": tensor,  # attention mask
                        "token_type_ids": tensor,  # sentence-type ids
                    },
                    "pg": tensor,  # pinyin-group id from the first letter
                }

        Notes:
            - The text_pair call form makes the tokenizer emit token_type_ids.
            - The tokenizer must support ``return_token_type_ids=True``.
            - ``max_length`` is fixed at 88, with padding and truncation.
            - The pinyin-group id is derived from the first pinyin letter
              (12 is the fallback group for empty pinyin); refine as needed.
            - ``py[0]`` must be a key of ``PG`` — TODO confirm against the
              PG mapping in suinput.dataset.
        """
        if tokenizer is None and not hasattr(self, "tokenizer"):
            self.tokenizer = AutoTokenizer.from_pretrained(
                "iic/nlp_structbert_backbone_lite_std"
            )
        else:
            # Prefer an explicitly supplied tokenizer; otherwise keep the
            # cached one.
            self.tokenizer = tokenizer or self.tokenizer

        # Pass text and pinyin as a sentence pair so the tokenizer produces
        # token_type_ids distinguishing the two segments.
        encoded = self.tokenizer(
            text,  # text as the first sentence
            py,  # pinyin as the second sentence
            max_length=88,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            return_token_type_ids=True,  # explicitly request token_type_ids
        )

        sample = {
            "hint": {
                "input_ids": encoded["input_ids"],
                "attention_mask": encoded["attention_mask"],
                "token_type_ids": encoded["token_type_ids"],
            },
            "pg": torch.tensor(
                [PG[py[0]] if py != "" else 12]
            ),  # pinyin-group id from the first letter (12 = empty fallback)
        }
        return sample
|
||
|
||
def predict(self, text, py, tokenizer=None):
|
||
"""
|
||
预测函数,自动处理 batch 维度
|
||
|
||
Args:
|
||
text (str or List[str]): 输入文本或文本列表
|
||
py (int or List[int]): 拼音特征,可以是单个值或列表
|
||
tokenizer (object, optional): 分词器对象,用于文本预处理。默认为None
|
||
|
||
Returns:
|
||
torch.Tensor: 预测结果,如果是单个输入则返回一维张量,
|
||
如果是批量输入则返回二维张量
|
||
"""
|
||
self.eval()
|
||
sample = self.gen_predict_sample(text, py, tokenizer)
|
||
input_ids = sample["hint"]["input_ids"]
|
||
attention_mask = sample["hint"]["attention_mask"]
|
||
token_type_ids = sample["hint"]["token_type_ids"]
|
||
pg = sample["pg"]
|
||
|
||
has_batch_dim = input_ids.dim() > 1
|
||
if not has_batch_dim:
|
||
input_ids = input_ids.unsqueeze(0)
|
||
attention_mask = attention_mask.unsqueeze(0)
|
||
token_type_ids = token_type_ids.unsqueeze(0)
|
||
if pg.dim() == 0:
|
||
pg = pg.unsqueeze(0).expand(input_ids.size(0))
|
||
|
||
input_ids = input_ids.to(self.device)
|
||
attention_mask = attention_mask.to(self.device)
|
||
token_type_ids = token_type_ids.to(self.device)
|
||
pg = pg.to(self.device)
|
||
|
||
with torch.no_grad():
|
||
logits = self(input_ids, attention_mask, token_type_ids, pg)
|
||
preds = torch.softmax(logits, dim=-1).argmax(dim=-1)
|
||
|
||
if not has_batch_dim:
|
||
return preds.squeeze(0)
|
||
return preds
|
||
|
||
def fit(
|
||
self,
|
||
train_dataloader,
|
||
eval_dataloader=None,
|
||
monitor: Optional[TrainingMonitor] = None,
|
||
criterion=None,
|
||
optimizer=None,
|
||
num_epochs=1,
|
||
stop_batch=2e5,
|
||
eval_frequency=500,
|
||
grad_accum_steps=1,
|
||
clip_grad_norm=1.0,
|
||
loss_weight=None,
|
||
mixed_precision=True,
|
||
weight_decay=0.1,
|
||
warmup_ratio=0.1,
|
||
label_smoothing=0.15,
|
||
lr=1e-4,
|
||
):
|
||
"""训练函数,调整了输入参数"""
|
||
if self.device is None:
|
||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||
self.to(self.device)
|
||
|
||
self.train()
|
||
|
||
if optimizer is None:
|
||
optimizer = optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)
|
||
|
||
if criterion is None:
|
||
if loss_weight is not None:
|
||
criterion = nn.CrossEntropyLoss(
|
||
weight=loss_weight, label_smoothing=label_smoothing
|
||
)
|
||
else:
|
||
criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
|
||
|
||
scaler = amp.GradScaler(enabled=mixed_precision)
|
||
|
||
total_steps = stop_batch
|
||
warmup_steps = int(total_steps * warmup_ratio)
|
||
logger.info(f"Training Start: Steps={total_steps}, Warmup={warmup_steps}")
|
||
processed_batches = 0
|
||
global_step = 0 # 初始化
|
||
batch_loss_sum = 0.0
|
||
optimizer.zero_grad()
|
||
|
||
for epoch in range(num_epochs):
|
||
for batch_idx, batch in enumerate(
|
||
tqdm(train_dataloader, total=int(stop_batch))
|
||
):
|
||
# LR Schedule
|
||
if processed_batches < warmup_steps:
|
||
current_lr = lr * (processed_batches/ warmup_steps)
|
||
else:
|
||
progress = (processed_batches - warmup_steps) / (
|
||
total_steps - warmup_steps
|
||
)
|
||
current_lr = lr * (0.5 * (1.0 + math.cos(math.pi * progress)))
|
||
for param_group in optimizer.param_groups:
|
||
param_group["lr"] = current_lr
|
||
|
||
# 移动数据(注意:batch 中现在包含 token_type_ids)
|
||
input_ids = batch["hint"]["input_ids"].to(self.device)
|
||
attention_mask = batch["hint"]["attention_mask"].to(self.device)
|
||
token_type_ids = batch["hint"]["token_type_ids"].to(self.device) # 新增
|
||
pg = batch["pg"].to(self.device)
|
||
labels = batch["char_id"].to(self.device)
|
||
|
||
with torch.amp.autocast(
|
||
device_type=self.device.type, enabled=mixed_precision
|
||
):
|
||
logits = self(input_ids, attention_mask, token_type_ids, pg)
|
||
loss = criterion(logits, labels)
|
||
loss = loss / grad_accum_steps
|
||
|
||
scaler.scale(loss).backward()
|
||
|
||
if (processed_batches) % grad_accum_steps == 0:
|
||
scaler.unscale_(optimizer)
|
||
torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
|
||
|
||
scaler.step(optimizer)
|
||
scaler.update()
|
||
optimizer.zero_grad()
|
||
batch_loss_sum += loss.item() * grad_accum_steps
|
||
if global_step % eval_frequency == 0:
|
||
if eval_dataloader:
|
||
self.eval()
|
||
acc, eval_loss = self.model_eval(eval_dataloader, criterion)
|
||
self.train()
|
||
if monitor:
|
||
# 使用 eval_loss 作为监控指标
|
||
monitor.add_step(
|
||
global_step, {"loss": batch_loss_sum / (eval_frequency if global_step > 0 else 1), "acc": acc}
|
||
)
|
||
logger.info(
|
||
f"step: {global_step}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if global_step > 0 else 1):.4f}, current_lr: {current_lr}"
|
||
)
|
||
else:
|
||
logger.info(f"step: {global_step}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if global_step > 0 else 1):.4f}, current_lr: {current_lr}")
|
||
batch_loss_sum = 0.0
|
||
if processed_batches >= stop_batch:
|
||
break
|
||
processed_batches += 1
|
||
global_step += 1
|
||
|
||
# 训练结束发送通知
|
||
try:
|
||
res_acc, res_loss = self.model_eval(eval_dataloader, criterion)
|
||
send_serverchan_message(
|
||
title="训练完成",
|
||
content=f"acc: {res_acc:.4f}, loss: {res_loss:.4f}",
|
||
)
|
||
logger.info(f"训练完成,acc: {res_acc:.4f}, loss: {res_loss:.4f}")
|
||
except Exception as e:
|
||
logger.error(f"发送消息失败: {e}")
|
||
|
||
def load_from_state_dict(self, state_dict_path: Union[str, Path]):
|
||
state_dict = torch.load(
|
||
state_dict_path, weights_only=True, map_location=self.device
|
||
)
|
||
self.load_state_dict(state_dict)
|
||
|
||
def load_from_pretrained_base_model(
|
||
self,
|
||
BaseModel,
|
||
snapshot_path: Union[str, Path],
|
||
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
||
*args,
|
||
**kwargs,
|
||
):
|
||
base_model = BaseModel(*args, **kwargs)
|
||
base_model.load_state_dict(torch.load(snapshot_path, map_location=device))
|
||
self_static_dict = self.state_dict()
|
||
pretrained_dict = base_model.state_dict()
|
||
|
||
freeze_layers = []
|
||
|
||
for key in self_static_dict.keys():
|
||
if key in pretrained_dict.keys():
|
||
if self_static_dict[key].shape == pretrained_dict[key].shape:
|
||
self_static_dict[key] = pretrained_dict[key].to(self.device)
|
||
freeze_layers.append(key)
|
||
self.load_state_dict(self_static_dict)
|
||
for name, param in self.named_parameters():
|
||
if name in freeze_layers:
|
||
param.requires_grad = False
|
||
|
||
# --- ONNX 导出辅助函数 ---
|
||
def export_onnx(self, output_path, dummy_input):
|
||
"""
|
||
dummy_input 应该是一个元组,包含:
|
||
(input_ids, attention_mask, token_type_ids, pg)
|
||
"""
|
||
self.eval()
|
||
input_names = ["input_ids", "attention_mask", "token_type_ids", "pg"]
|
||
output_names = ["logits"]
|
||
|
||
torch.onnx.export(
|
||
self,
|
||
dummy_input,
|
||
output_path,
|
||
input_names=input_names,
|
||
output_names=output_names,
|
||
dynamic_axes={
|
||
"input_ids": {0: "batch_size", 1: "seq_len"},
|
||
"attention_mask": {0: "batch_size", 1: "seq_len"},
|
||
"token_type_ids": {0: "batch_size", 1: "seq_len"},
|
||
"pg": {0: "batch_size"},
|
||
"logits": {0: "batch_size"},
|
||
},
|
||
opset_version=14,
|
||
do_constant_folding=True,
|
||
)
|
||
logger.info(f"ONNX model exported to {output_path}")
|