SUInput/src/trainer/model.py

import pickle
from importlib.resources import files

import torch
import torch.amp as amp
import torch.nn as nn
import torch.optim as optim
from loguru import logger
from modelscope import AutoModel
from tqdm import tqdm

from .monitor import TrainingMonitor

# Pre-pickled evaluation batches bundled with the package.
EVAL_DATALOADER = [
    pickle.load(file.open("rb"))
    for file in (files(__package__) / "eval_dataset").glob("*.pkl")
]

# ---------------------------- Residual block ----------------------------
class ResidualBlock(nn.Module):
    def __init__(self, dim, dropout_prob=0.1):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.ln1 = nn.LayerNorm(dim)
        self.linear2 = nn.Linear(dim, dim)
        self.ln2 = nn.LayerNorm(dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        residual = x
        x = self.relu(self.linear1(x))
        x = self.ln1(x)
        x = self.linear2(x)
        x = self.ln2(x)
        x = self.dropout(x)  # Dropout before the residual add, per the original design
        x = x + residual
        return self.relu(x)

# ---------------------------- Expert network ----------------------------
class Expert(nn.Module):
    def __init__(
        self,
        input_dim,
        d_model=1024,
        num_resblocks=4,
        output_multiplier=2,
        dropout_prob=0.1,
    ):
        """
        input_dim         : hidden_size of the BERT output (e.g. 312/768)
        d_model           : internal expert width (default 1024)
        output_multiplier : output dim = input_dim * output_multiplier
        dropout_prob      : Dropout inside the residual blocks
        """
        super().__init__()
        self.input_dim = input_dim
        self.d_model = d_model
        self.output_dim = input_dim * output_multiplier
        # Input projection: input_dim -> d_model
        self.linear_in = nn.Linear(input_dim, d_model)
        # Stack of residual blocks
        self.res_blocks = nn.ModuleList(
            [ResidualBlock(d_model, dropout_prob) for _ in range(num_resblocks)]
        )
        # Output projection: d_model -> output_dim
        self.output = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(inplace=True),
            nn.Linear(d_model, self.output_dim),
        )

    def forward(self, x):
        x = self.linear_in(x)
        for block in self.res_blocks:
            x = block(x)
        return self.output(x)

# ------------------------ Main model (MoE + hard routing) ------------------------
class MoEModel(nn.Module):
    def __init__(
        self,
        pretrained_model_name="iic/nlp_structbert_backbone_tiny_std",
        num_classes=10018,
        output_multiplier=2,
        d_model=768,
        num_resblocks=4,
        num_domain_experts=8,
        num_shared_experts=1,
    ):
        super().__init__()
        self.output_multiplier = output_multiplier
        # 1. Load the pretrained BERT; keep only its embeddings
        bert = AutoModel.from_pretrained(pretrained_model_name)
        self.embedding = bert.embeddings
        self.bert_config = bert.config
        self.hidden_size = self.bert_config.hidden_size  # BERT hidden width
        self.device = None  # set later by to()
        # 2. A 4-layer standard Transformer encoder (hyperparameters read from the config)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.hidden_size,
            nhead=self.bert_config.num_attention_heads,
            dim_feedforward=self.bert_config.intermediate_size,
            dropout=self.bert_config.hidden_dropout_prob,
            activation="gelu",
            batch_first=True,
            norm_first=True,  # Pre-LN, matching the pretrained model
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
        self.pooler = nn.AdaptiveAvgPool1d(1)
        # 3. Expert layer: 8 domain experts + 1 shared expert
        total_experts = num_domain_experts + num_shared_experts
        self.experts = nn.ModuleList()
        for i in range(total_experts):
            # Domain experts use dropout=0.1; the shared expert uses dropout=0.2
            # (stronger regularization)
            dropout_prob = 0.1 if i < num_domain_experts else 0.2
            expert = Expert(
                input_dim=self.hidden_size,
                d_model=d_model,
                num_resblocks=num_resblocks,
                output_multiplier=self.output_multiplier,  # output dim = 2 * hidden_size
                dropout_prob=dropout_prob,
            )
            self.experts.append(expert)
        self.expert_bias = nn.Embedding(
            total_experts, self.output_multiplier * self.hidden_size
        )
        # 4. Classification head
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.output_multiplier * self.hidden_size),
            nn.Linear(
                self.output_multiplier * self.hidden_size,
                self.output_multiplier * self.hidden_size,
            ),
            nn.ReLU(inplace=True),
            nn.Linear(
                self.output_multiplier * self.hidden_size,
                self.output_multiplier * self.hidden_size,
            ),
            nn.ReLU(inplace=True),
            nn.Linear(
                self.output_multiplier * self.hidden_size,
                self.output_multiplier * self.hidden_size * 2,
            ),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(self.output_multiplier * self.hidden_size * 2, num_classes),
        )
        # Optional: give domain and shared experts different weight decay
        # (done through the optimizer, not handled here)

    def to(self, device):
        """Override to() so the target device is recorded on the module."""
        self.device = device
        return super().to(device)

    def forward(self, input_ids, attention_mask, pg):
        """
        input_ids     : [batch, seq_len]
        attention_mask: [batch, seq_len] (1 = valid token, 0 = padding)
        pg            : group id; a [batch] LongTensor during training, a
                        scalar tensor when exporting for inference
        """
        # ----- 1. Embeddings -----
        embeddings = self.embedding(input_ids)  # [B, S, H]
        # ----- 2. Transformer encoder -----
        # Padding mask: True marks positions to ignore
        padding_mask = attention_mask == 0
        encoded = self.encoder(
            embeddings, src_key_padding_mask=padding_mask
        )  # [B, S, H]
        # ----- 3. Pooling -----
        # Note: this average pooling runs over all positions, padding included.
        pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1)
        # ----- 4. Expert routing (hard routing) -----
        if torch.jit.is_tracing():
            # ------------------ ONNX export mode (batch = 1) ------------------
            # pg is a scalar tensor here; .item() yields a Python int at trace
            # time, so the indexing below bakes exactly one expert into the
            # exported graph (equivalent to an if/elif chain over ids 0..8).
            group_id = pg.item() if torch.is_tensor(pg) else pg
            expert_out = self.experts[group_id](pooled) + self.expert_bias(
                torch.tensor(group_id, device=pooled.device)
            )
        else:
            batch_size = pooled.size(0)
            # Run every expert in parallel
            expert_outputs = torch.stack(
                [e(pooled) for e in self.experts], dim=0
            )  # [E, B, D]
            # Pick each sample's expert output according to pg
            # (index tensors must live on the same device as the outputs)
            expert_out = expert_outputs[
                pg, torch.arange(batch_size, device=pooled.device)
            ]  # [B, D]
            # Add the per-expert bias
            bias = self.expert_bias(pg)  # [B, D]
            expert_out = expert_out + bias
        # ----- 5. Classification head -----
        logits = self.classifier(expert_out)  # [batch, num_classes]
        if torch.jit.is_tracing():
            # Bake Softmax into the exported graph only; Python-side callers
            # get raw logits, which keeps CrossEntropyLoss in model_eval
            # numerically correct (it expects logits, not probabilities).
            return torch.softmax(logits, dim=-1)
        return logits

    def model_eval(self, eval_dataloader, criterion=None):
        """
        Evaluate the model on a validation set; returns accuracy and mean loss.
        Args:
            eval_dataloader: DataLoader yielding 'hint' (with 'input_ids' and
                             'attention_mask'), 'pg', and 'char_id'
            criterion: loss function, defaults to CrossEntropyLoss()
        Returns:
            accuracy: float
            avg_loss: float, mean loss per sample
        """
        if criterion is None:
            criterion = nn.CrossEntropyLoss()
        self.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in eval_dataloader:
                # Move the batch to the model's device
                input_ids = batch["hint"]["input_ids"].to(self.device)
                attention_mask = batch["hint"]["attention_mask"].to(self.device)
                pg = batch["pg"].to(self.device)
                labels = batch["char_id"].to(self.device)
                # Forward pass
                logits = self(input_ids, attention_mask, pg)
                loss = criterion(logits, labels)
                total_loss += loss.item() * labels.size(0)
                # Accuracy
                preds = logits.argmax(dim=-1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        avg_loss = total_loss / total if total > 0 else 0.0
        accuracy = correct / total if total > 0 else 0.0
        return accuracy, avg_loss

    def predict(self, sample, debug=False):
        """
        Predict from a sample dict; supports batched or single samples, with
        optional debug printing of mispredicted samples.
        Args:
            sample : dict
                Required fields:
                  - 'input_ids'     : [batch, seq_len] or [seq_len] (single sample)
                  - 'attention_mask': same shape as input_ids
                  - 'pg'            : [batch] or scalar
                  - 'char_id'       : [batch] or scalar ground-truth labels
                                      (required when debug=True)
                Additional fields required when debug=True:
                  - 'txt'  : list of strings (batch) or a single string
                  - 'char' : list of strings (batch) or a single string
                  - 'py'   : list of strings (batch) or a single string
            debug : bool
                Whether to print mispredicted samples. If True and sample lacks
                char_id/txt/char/py, a ValueError is raised.
        Returns:
            preds : torch.Tensor
                [batch] predicted class labels (a scalar tensor if the input
                had no batch dimension)
        """
        self.eval()
        # ------------------ 1. Extract and normalize inputs ------------------
        # Detect single samples (input_ids without a batch dimension)
        input_ids = sample["input_ids"]
        attention_mask = sample["attention_mask"]
        pg = sample["pg"]
        has_batch_dim = input_ids.dim() > 1
        if not has_batch_dim:
            input_ids = input_ids.unsqueeze(0)
            attention_mask = attention_mask.unsqueeze(0)
        if pg.dim() == 0:
            pg = pg.unsqueeze(0).expand(input_ids.size(0))
        # ------------------ 2. Move to device ------------------
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        pg = pg.to(self.device)
        # ------------------ 3. Inference ------------------
        with torch.no_grad():
            logits = self(input_ids, attention_mask, pg)
            # Softmax is monotonic, so argmax over logits is sufficient
            preds = logits.argmax(dim=-1)  # [batch]
        # ------------------ 4. Debug printing (mispredicted samples) ------------------
        if debug:
            # Check required fields
            required_keys = ["char_id", "txt", "char", "py"]
            missing = [k for k in required_keys if k not in sample]
            if missing:
                raise ValueError(f"debug=True requires sample fields: {missing}")
            # Ground-truth labels
            true_labels = sample["char_id"]
            if true_labels.dim() == 0:
                true_labels = true_labels.unsqueeze(0)
            # Move labels to the same device
            true_labels = true_labels.to(self.device)
            # Indices of mispredictions
            incorrect_mask = preds != true_labels
            incorrect_indices = torch.where(incorrect_mask)[0]
            if len(incorrect_indices) > 0:
                print("\n=== Mispredicted samples ===")
                # Debug fields (lists, or single strings)
                txts = sample["txt"]
                chars = sample["char"]
                pys = sample["py"]
                # Normalize single strings to lists
                if isinstance(txts, str):
                    txts = [txts]
                    chars = [chars]
                    pys = [pys]
                for idx in incorrect_indices.cpu().tolist():
                    print(f"Sample index {idx}:")
                    print(f"  Text  : {txts[idx]}")
                    print(f"  Char  : {chars[idx]}")
                    print(f"  Pinyin: {pys[idx]}")
                    print(
                        f"  predicted: {preds[idx].item()}, true: {true_labels[idx].item()}"
                    )
                print("===================\n")
        # ------------------ 5. Return (matching the input's dimensionality) ------------------
        if not has_batch_dim:
            return preds.squeeze(0)  # scalar tensor
        return preds

    def fit(
        self,
        train_dataloader,
        eval_dataloader=None,
        monitor: TrainingMonitor = None,
        criterion=nn.CrossEntropyLoss(),
        optimizer=None,
        num_epochs=1,
        eval_frequency=500,
        grad_accum_steps=1,
        clip_grad_norm=1.0,
        mixed_precision=False,
        lr=1e-4,
        lr_schedule=None,  # optional custom learning-rate schedule
    ):
        """
        Train the model, with support for mixed precision, gradient
        accumulation, learning-rate scheduling, and live monitoring.
        Args:
            ... (see parameter list above) ...
            lr_schedule : callable, optional
                Custom schedule called as lr_schedule(processed_batches, optimizer);
                it may modify the learning rates in optimizer.param_groups
                directly. If None (and the default optimizer is used), the
                built-in fixed-threshold schedule applies: lr for the first
                1000 batches, 6e-6 afterwards.
        """
        # Make sure the model is on a device
        if self.device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.to(self.device)
        # Switch to training mode
        super().train()
        # Default optimizer; the built-in schedule only applies in this case
        use_default_schedule = optimizer is None
        if optimizer is None:
            optimizer = optim.AdamW(self.parameters(), lr=lr)
        # Mixed-precision gradient scaler
        scaler = amp.GradScaler(enabled=mixed_precision)
        global_step = 0
        processed_batches = 0  # counts every batch actually processed
        batch_loss_sum = 0.0
        optimizer.zero_grad()
        for epoch in range(num_epochs):
            for batch_idx, batch in enumerate(tqdm(train_dataloader, total=1e6)):
                # ---------- Update the batch counter ----------
                processed_batches += 1
                # ---------- Learning-rate schedule ----------
                if lr_schedule is not None:
                    # User-supplied schedule
                    lr_schedule(processed_batches, optimizer)
                elif use_default_schedule:
                    # Built-in fixed-threshold schedule (see docstring):
                    # lr for the first 1000 batches, 6e-6 afterwards
                    new_lr = lr if processed_batches <= 1000 else 6e-6
                    for group in optimizer.param_groups:
                        group["lr"] = new_lr
                # ---------- Move data ----------
                input_ids = batch["hint"]["input_ids"].to(self.device)
                attention_mask = batch["hint"]["attention_mask"].to(self.device)
                pg = batch["pg"].to(self.device)
                labels = batch["char_id"].to(self.device)
                # Mixed-precision forward pass
                with amp.autocast(
                    device_type=self.device.type, enabled=mixed_precision
                ):
                    logits = self(input_ids, attention_mask, pg)
                    loss = criterion(logits, labels)
                    loss = loss / grad_accum_steps
                # Backward pass
                scaler.scale(loss).backward()
                # Gradient accumulation
                if (batch_idx + 1) % grad_accum_steps == 0:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
                    global_step += 1
                    original_loss = loss.item() * grad_accum_steps
                    batch_loss_sum += original_loss
                    # Periodic evaluation
                    if (
                        eval_dataloader is not None
                        and global_step % eval_frequency == 0
                    ):
                        avg_loss = batch_loss_sum / eval_frequency
                        acc, _ = self.model_eval(eval_dataloader, criterion)
                        super().train()
                        if monitor is not None:
                            monitor.add_step(
                                global_step,
                                {"loss": avg_loss, "acc": acc},
                            )
                        logger.info(
                            f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc:.4f}"
                        )
                        batch_loss_sum = 0.0
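
# (Hedged sketch) The comment in MoEModel.__init__ suggests giving domain and
# shared experts different weight decay through the optimizer. This helper is
# one way to build such an optimizer; the decay values and the helper itself
# are illustrative assumptions, not part of the original training code.
def build_optimizer_with_expert_decay(
    model, lr=1e-4, domain_decay=0.01, shared_decay=0.05
):
    num_domain = len(model.experts) - 1  # assumes the default 8 + 1 layout
    domain_params = [
        p for i in range(num_domain) for p in model.experts[i].parameters()
    ]
    shared_params = [
        p
        for i in range(num_domain, len(model.experts))
        for p in model.experts[i].parameters()
    ]
    expert_ids = {id(p) for p in domain_params + shared_params}
    other_params = [p for p in model.parameters() if id(p) not in expert_ids]
    return optim.AdamW(
        [
            {"params": domain_params, "weight_decay": domain_decay},
            {"params": shared_params, "weight_decay": shared_decay},
            {"params": other_params},
        ],
        lr=lr,
    )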

# ============================ Usage example ============================
if __name__ == "__main__":
    # 1. Build the model
    model = MoEModel()
    model.eval()
    # 2. Dummy inputs (batch = 1) for the ONNX export
    dummy_input_ids = torch.randint(0, 100, (1, 64))  # [1, 64]
    dummy_attention_mask = torch.ones_like(dummy_input_ids)  # [1, 64]
    dummy_pg = torch.tensor(3, dtype=torch.long)  # scalar group_id
    # 3. Export to ONNX (the traced branch computes only a single expert)
    torch.onnx.export(
        model,
        (dummy_input_ids, dummy_attention_mask, dummy_pg),
        "moe_cpu.onnx",
        input_names=["input_ids", "attention_mask", "pg"],
        output_names=["probs"],  # the exported graph ends in Softmax
        dynamic_axes={  # batch is fixed to 1, so this could also be omitted
            "input_ids": {0: "batch"},
            "attention_mask": {0: "batch"},
        },
        opset_version=12,
        do_constant_folding=True,
    )
    print("ONNX export succeeded!")
    # 4. Test training mode (batch = 4)
    model.train()
    batch_input_ids = torch.randint(0, 100, (4, 64))
    batch_attention_mask = torch.ones_like(batch_input_ids)
    batch_pg = torch.tensor([0, 3, 8, 1], dtype=torch.long)  # different groups
    logits = model(batch_input_ids, batch_attention_mask, batch_pg)
    print("Training-mode output shape:", logits.shape)  # [4, 10018]