# SUInput/src/trainer/model.py
import math
import pickle
from importlib.resources import files
from pathlib import Path
from typing import Optional, Union

import torch
import torch.amp as amp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from loguru import logger
from modelscope import AutoModel, AutoTokenizer
from tqdm.autonotebook import tqdm

from suinput.dataset import PG
from .monitor import TrainingMonitor, send_serverchan_message


def eval_dataloader(path: Union[str, Path] = (files(__package__) / "eval_dataset")):
    """Load pre-pickled evaluation batches from a directory of *.pkl files."""
    return [pickle.load(file.open("rb")) for file in Path(path).glob("*.pkl")]


# ---------------------------- Residual block ----------------------------
class ResidualBlock(nn.Module):
    def __init__(self, dim, dropout_prob=0.3):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.ln1 = nn.LayerNorm(dim)
        self.linear2 = nn.Linear(dim, dim)
        self.ln2 = nn.LayerNorm(dim)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        residual = x
        x = self.gelu(self.linear1(x))
        x = self.ln1(x)
        x = self.linear2(x)
        x = self.ln2(x)
        x = self.dropout(x)
        x = x + residual
        return self.gelu(x)
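
# Illustrative shape check (a sketch, not part of the training pipeline):
# the block maps [B, dim] -> [B, dim], so any number of blocks can be stacked.
#
#   blk = ResidualBlock(dim=768)
#   out = blk(torch.randn(4, 768))   # -> torch.Size([4, 768])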


# ---------------------------- Expert network ----------------------------
class Expert(nn.Module):
    def __init__(
        self,
        input_dim,
        d_model=768,
        num_resblocks=3,
        output_multiplier=1,
        dropout_prob=0.3,
    ):
        super().__init__()
        self.output_dim = input_dim * output_multiplier
        self.linear_in = nn.Linear(input_dim, d_model)
        self.res_blocks = nn.ModuleList(
            [ResidualBlock(d_model, dropout_prob) for _ in range(num_resblocks)]
        )
        self.output = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),  # nn.GELU takes no inplace argument
            nn.Dropout(dropout_prob),
            nn.Linear(d_model, self.output_dim),
        )

    def forward(self, x):
        x = self.linear_in(x)
        for block in self.res_blocks:
            x = block(x)
        return self.output(x)
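
# Illustrative shape check (a sketch): an Expert widens the feature dimension
# by output_multiplier, which the classifier head below relies on.
#
#   expert = Expert(input_dim=512, output_multiplier=2)
#   out = expert(torch.randn(4, 512))   # -> torch.Size([4, 1024])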


# ---------------------- Main model (MoE + hard routing) ----------------------
class MoEModel(nn.Module):
    def __init__(
        self,
        pretrained_model_name="iic/nlp_structbert_backbone_lite_std",
        num_classes=10018,
        output_multiplier=2,
        d_model=768,
        num_resblocks=4,
        num_domain_experts=12,
        num_shared_experts=1,
    ):
        super().__init__()
        self.output_multiplier = output_multiplier
        # 1. Load the pretrained embedding layer (Lite: hidden_size=512)
        logger.info(f"Loading backbone: {pretrained_model_name}")
        bert = AutoModel.from_pretrained(pretrained_model_name)
        self.embedding = bert.embeddings
        self.bert_config = bert.config
        self.hidden_size = self.bert_config.hidden_size  # 512
        self.device = None
        # 2. Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.hidden_size,
            nhead=8,
            dim_feedforward=self.bert_config.intermediate_size,
            dropout=self.bert_config.hidden_dropout_prob,
            activation="gelu",
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
        # 3. Expert system: domain experts get lower dropout than shared ones
        total_experts = num_domain_experts + num_shared_experts
        self.experts = nn.ModuleList()
        for i in range(total_experts):
            dropout = 0.3 if i < num_domain_experts else 0.4
            self.experts.append(
                Expert(
                    input_dim=self.hidden_size,
                    d_model=d_model,
                    num_resblocks=num_resblocks,
                    output_multiplier=self.output_multiplier,
                    dropout_prob=dropout,
                )
            )
        self.expert_bias = nn.Embedding(
            total_experts, self.hidden_size * self.output_multiplier
        )
        # 4. Classification head
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.hidden_size * self.output_multiplier),
            nn.Dropout(0.4),
            nn.Linear(self.hidden_size * self.output_multiplier, num_classes),
        )

    def to(self, device):
        """Override to() so the target device is recorded on the instance."""
        self.device = device
        return super().to(device)

    def forward(self, input_ids, attention_mask, pg, p_start):
        """
        ONNX-compatible forward pass.

        Args:
            input_ids: [B, L]
            attention_mask: [B, L]
            pg: [B] pinyin-group ID
            p_start: [B] start index of the pinyin span (integer tensor)
        """
        # ----- 1. Embeddings -----
        embeddings = self.embedding(input_ids)
        # ----- 2. Transformer encoder -----
        # padding mask: True marks positions to ignore
        padding_mask = attention_mask == 0
        encoded = self.encoder(
            embeddings, src_key_padding_mask=padding_mask
        )  # [B, S, H]
        # ----- 3. ONNX-compatible span pooling (vectorised) -----
        # We cannot slice per sample inside a traced graph, so we build a mask
        # instead. For each sample i we want a length-L vector that is 1 at
        # every valid (non-padding) position at or after p_start[i] -- the
        # pinyin span -- and 0 elsewhere:
        #   1. positions = [0, 1, ..., L-1]                          ([1, L])
        #   2. broadcast against p_start reshaped to [B, 1]
        #   3. span_mask = (positions >= p_start) & attention_mask   ([B, L])
        #   4. masked_encoded = encoded * span_mask.unsqueeze(-1)
        #   5. mean-pool: sum over L, divide by the (clamped) span length
        B, L, H = encoded.shape
        device = encoded.device
        # Position axis [0, 1, ..., L-1]
        positions = torch.arange(L, device=device).unsqueeze(0)  # [1, L]
        # Reshape p_start to [B, 1] for broadcasting
        p_start_exp = p_start.unsqueeze(1)  # [B, 1]
        # Intersect with attention_mask so max_length padding does not
        # dilute the average
        span_mask = (positions >= p_start_exp) & attention_mask.bool()
        # Cast to float for the multiplication
        span_mask_float = span_mask.float()  # [B, L]
        # Apply the mask: encoded [B, L, H] * mask [B, L, 1]
        masked_encoded = encoded * span_mask_float.unsqueeze(-1)
        # Sum over the sequence axis
        span_sum = masked_encoded.sum(dim=1)  # [B, H]
        # Effective span length (clamped to avoid division by zero)
        span_count = span_mask_float.sum(dim=1, keepdim=True).clamp(min=1.0)  # [B, 1]
        # Mean pooling
        pooled = span_sum / span_count  # [B, H]
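
        # Illustrative example (a sketch): with L = 6, p_start = 2 and
        # attention_mask = [1, 1, 1, 1, 1, 0], span_mask is [0, 0, 1, 1, 1, 0],
        # so pooling averages exactly the three span positions 2..4.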
        # ----- 4. Expert routing (hard routing) -----
        if torch.jit.is_tracing():
            # ---------- ONNX export mode: single-branch routing (batch=1) ----------
            # During tracing pg is a scalar tensor; convert it to a Python int.
            # Only the selected expert is baked into the traced graph, which is
            # exactly what an explicit per-expert if/elif chain would achieve.
            group_id = pg.item() if torch.is_tensor(pg) else pg
            group_id = int(group_id)
            expert_out = self.experts[group_id](pooled) + self.expert_bias(
                torch.tensor(group_id, device=pooled.device)
            )
        else:
            batch_size = pooled.size(0)
            # Run every expert in parallel
            expert_outputs = torch.stack(
                [e(pooled) for e in self.experts], dim=0
            )  # [E, B, D]
            # Select each sample's expert output according to pg
            # (the index tensor must live on the same device as pg)
            expert_out = expert_outputs[
                pg, torch.arange(batch_size, device=pooled.device)
            ]  # [B, D]
            # Add the per-expert bias
            bias = self.expert_bias(pg)  # [B, D]
            expert_out = expert_out + bias
        # ----- 5. Classification head -----
        logits = self.classifier(expert_out)  # [batch, num_classes]
        if not self.training:  # apply softmax at inference time
            probs = torch.softmax(logits, dim=-1)
            return probs
        return logits
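
    # Illustrative routing example (a sketch): with 13 experts and
    # pg = tensor([3, 0]), expert_outputs[pg, arange(2)] picks expert 3's
    # output for sample 0 and expert 0's output for sample 1 -- hard routing,
    # no mixing of experts.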

    def model_eval(self, eval_dataloader, criterion):
        """
        Evaluate the model on the validation set; returns accuracy and mean loss.

        Args:
            eval_dataloader: DataLoader yielding 'hint' (with 'input_ids' and
                'attention_mask'), 'pg', 'p_start' and 'char_id'
            criterion: loss function, typically CrossEntropyLoss()

        Returns:
            accuracy: float
            avg_loss: float
        """
        self.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in eval_dataloader:
                # Move the batch to the model's device
                input_ids = batch["hint"]["input_ids"].to(self.device)
                attention_mask = batch["hint"]["attention_mask"].to(self.device)
                pg = batch["pg"].to(self.device)
                p_start = batch["p_start"].to(self.device)
                labels = batch["char_id"].to(self.device)
                # Forward pass. Note: in eval mode forward() returns softmax
                # probabilities, so the loss below is computed on probabilities
                # rather than on raw logits.
                probs = self(input_ids, attention_mask, pg, p_start)
                loss = criterion(probs, labels)
                total_loss += loss.item() * labels.size(0)
                # Accuracy
                preds = probs.argmax(dim=-1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        avg_loss = total_loss / total if total > 0 else 0.0
        accuracy = correct / total if total > 0 else 0.0
        return accuracy, avg_loss

    def gen_predict_sample(self, text, py, tokenizer=None):
        """
        Build a single sample for prediction.

        Args:
            text (str): input text.
            py (list): pinyin list corresponding to the text.
            tokenizer (PreTrainedTokenizer, optional): tokenizer used for
                encoding. If omitted and the instance has no cached tokenizer,
                a pretrained one is loaded automatically.

        Returns:
            dict with:
                - "hint": encoded inputs ("input_ids", "attention_mask")
                - "pg": tensor with the PG index of the first pinyin
                - "p_start": tensor with the start index of the pinyin span

        Notes:
            1. If no tokenizer is supplied and none is cached on the instance,
               one is loaded from the pretrained model.
            2. Text and pinyin are encoded as a pair with max_length=88,
               padded and truncated.
        """
        # Load a pretrained tokenizer if none was supplied or cached
        if tokenizer is None and not hasattr(self, "tokenizer"):
            self.tokenizer = AutoTokenizer.from_pretrained(
                "iic/nlp_structbert_backbone_tiny_std"
            )
        else:
            # Prefer the supplied tokenizer, fall back to the cached one
            self.tokenizer = tokenizer or self.tokenizer
        # Encode text and pinyin into the model's input format
        hint = self.tokenizer(
            text,
            py,
            max_length=88,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Assemble the sample dict
        sample = {}
        sample["hint"] = {
            "input_ids": hint["input_ids"],
            "attention_mask": hint["attention_mask"],
        }
        # Map the first pinyin to its PG index and wrap it in a tensor
        sample["pg"] = torch.tensor([PG[py[0]]])
        sample["p_start"] = torch.tensor([len(text)])
        return sample
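
    # Illustrative usage (a sketch; the text/pinyin pair is made up and "ma"
    # must exist in the PG mapping):
    #
    #   sample = model.gen_predict_sample("你好", ["ma"])
    #   sample["hint"]["input_ids"].shape   # torch.Size([1, 88])
    #   sample["pg"], sample["p_start"]     # tensor([PG["ma"]]), tensor([2])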

    def predict(self, text, py, tokenizer=None):
        """
        Build a sample from text and pinyin and run prediction; handles both
        batched and single-sample inputs.

        Args:
            text : str
                Input text.
            py : list
                Input pinyin list.
            tokenizer : Tokenizer, optional
                Tokenizer to use, defaults to None.

        Returns:
            preds : torch.Tensor
                [batch] predicted class labels (a scalar if the input had no
                batch dimension)
        """
        self.eval()  # evaluation mode: disables dropout and similar layers
        # ------------------ 1. Build and normalise the inputs ------------------
        sample = self.gen_predict_sample(text, py, tokenizer)
        input_ids = sample["hint"]["input_ids"]
        attention_mask = sample["hint"]["attention_mask"]
        pg = sample["pg"]  # pinyin-group routing ID
        p_start = sample["p_start"]  # start index of the pinyin span
        has_batch_dim = input_ids.dim() > 1
        # Add a batch dimension if the input lacks one
        if not has_batch_dim:
            input_ids = input_ids.unsqueeze(0)
            attention_mask = attention_mask.unsqueeze(0)
        # Expand a scalar pg to match the batch size
        if pg.dim() == 0:
            pg = pg.unsqueeze(0).expand(input_ids.size(0))
        # ------------------ 2. Move to the model's device ------------------
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        pg = pg.to(self.device)
        p_start = p_start.to(self.device)
        # ------------------ 3. Inference ------------------
        # no_grad: skip gradient bookkeeping to save memory
        with torch.no_grad():
            # In eval mode forward() already returns softmax probabilities
            probs = self(input_ids, attention_mask, pg, p_start)
            preds = probs.argmax(dim=-1)  # [batch]
        # --------- 4. Return (matching the input's dimensionality) ---------
        if not has_batch_dim:
            return preds.squeeze(0)  # scalar
        return preds
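
    # Illustrative end-to-end call (a sketch; inputs are made up and must be
    # covered by the PG mapping):
    #
    #   model = MoEModel().to(torch.device("cpu"))
    #   char_id = model.predict("你好", ["ma"])   # -> tensor of shape [1]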

    def fit(
        self,
        train_dataloader,  # training data loader
        eval_dataloader=None,  # evaluation data loader, optional
        monitor: Optional[TrainingMonitor] = None,  # training monitor for logging
        criterion=None,  # loss function
        optimizer=None,  # optimizer
        num_epochs=1,  # number of epochs
        stop_batch=1e6,  # maximum number of training batches
        eval_frequency=500,
        grad_accum_steps=1,  # gradient accumulation steps
        clip_grad_norm=1.0,  # gradient clipping norm
        loss_weight=None,
        mixed_precision=True,
        weight_decay=0.1,
        warmup_ratio=0.1,
        label_smoothing=0.15,
        lr=1e-4,
    ):
        """
        Train the model with mixed precision, gradient accumulation, an LR
        schedule (linear warmup + cosine decay) and live monitoring.

        Args:
            # TODO: document the parameters
        """
        # Make sure the model sits on the right device (GPU or CPU)
        if self.device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.device)
        # Switch to training mode
        self.train()
        # Default optimizer
        if optimizer is None:
            optimizer = optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)
        # Default loss function
        if criterion is None:
            if loss_weight is not None:
                criterion = nn.CrossEntropyLoss(
                    weight=loss_weight, label_smoothing=label_smoothing
                )
            else:
                criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
        # Mixed-precision gradient scaler
        scaler = amp.GradScaler(enabled=mixed_precision)
        total_steps = stop_batch
        warmup_steps = int(total_steps * warmup_ratio)
        logger.info(f"Training Start: Steps={total_steps}, Warmup={warmup_steps}")
        processed_batches = 0  # counts every batch actually processed
        batch_loss_sum = 0.0
        optimizer.zero_grad()
        for epoch in range(num_epochs):
            for batch_idx, batch in enumerate(
                tqdm(train_dataloader, total=int(stop_batch))
            ):
                processed_batches += 1
                # LR schedule: linear warmup, then cosine decay
                if processed_batches < warmup_steps:
                    current_lr = lr * (processed_batches / warmup_steps)
                else:
                    progress = (processed_batches - warmup_steps) / (
                        total_steps - warmup_steps
                    )
                    current_lr = lr * (0.5 * (1.0 + math.cos(math.pi * progress)))
                for param_group in optimizer.param_groups:
                    param_group["lr"] = current_lr
                # ---------- Move the batch to the device ----------
                input_ids = batch["hint"]["input_ids"].to(self.device)
                attention_mask = batch["hint"]["attention_mask"].to(self.device)
                pg = batch["pg"].to(self.device)
                p_start = batch["p_start"].to(self.device)  # [B]
                labels = batch["char_id"].to(self.device)
                # Mixed-precision forward pass
                with torch.amp.autocast(
                    device_type=self.device.type, enabled=mixed_precision
                ):
                    logits = self(input_ids, attention_mask, pg, p_start)
                    loss = criterion(logits, labels)
                    loss = loss / grad_accum_steps
                # Backward pass
                scaler.scale(loss).backward()
                # Gradient accumulation
                if processed_batches % grad_accum_steps == 0:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
                    has_nan = False
                    for p in self.parameters():
                        if p.grad is not None and torch.isnan(p.grad).any():
                            has_nan = True
                            break
                    if not has_nan:
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        logger.warning("NaN detected, skipping step.")
                    optimizer.zero_grad()
                batch_loss_sum += loss.item() * grad_accum_steps
                # Periodic evaluation
                if eval_dataloader and processed_batches % eval_frequency == 0:
                    acc, eval_loss = self.model_eval(eval_dataloader, criterion)
                    # Mean training loss over the window since the last eval
                    avg_loss = batch_loss_sum / eval_frequency
                    self.train()  # model_eval switched the model to eval mode
                    if monitor:
                        monitor.add_step(
                            processed_batches, {"loss": avg_loss, "acc": acc}
                        )
                    logger.info(
                        f"step: {processed_batches}, loss: {avg_loss:.4f}, "
                        f"acc: {acc:.4f}, eval_loss: {eval_loss:.4f}"
                    )
                    batch_loss_sum = 0.0
                if processed_batches >= stop_batch:
                    break
        try:
            if eval_dataloader is not None:
                res_acc, res_loss = self.model_eval(eval_dataloader, criterion)
                to_wechat_response = send_serverchan_message(
                    title="Training complete",
                    content=f"Training complete, acc: {res_acc:.4f}, loss: {res_loss:.4f}",
                )
                logger.info(f"Training complete, acc: {res_acc:.4f}, loss: {res_loss:.4f}")
                logger.info(f"Message sent: {to_wechat_response}")
        except Exception as e:
            logger.error(f"Failed to send message: {e}")
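
    # Minimal training-run sketch (loader names are hypothetical; the real
    # DataLoaders must yield dicts with "hint", "pg", "p_start", "char_id"):
    #
    #   model = MoEModel()
    #   model.fit(train_loader, eval_dataloader=eval_dataloader(),
    #             num_epochs=1, stop_batch=10_000, eval_frequency=500)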

    def load_from_state_dict(self, state_dict_path: Union[str, Path]):
        """Load weights from a saved state dict onto the recorded device."""
        state_dict = torch.load(
            state_dict_path, weights_only=True, map_location=self.device
        )
        self.load_state_dict(state_dict)

    def load_from_pretrained_base_model(
        self,
        BaseModel,
        snapshot_path: Union[str, Path],
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        *args,
        **kwargs,
    ):
        """Copy shape-compatible weights from a pretrained base model and
        freeze the copied layers."""
        base_model = BaseModel(*args, **kwargs)
        base_model.load_state_dict(torch.load(snapshot_path, map_location=device))
        self_state_dict = self.state_dict()
        pretrained_dict = base_model.state_dict()
        freeze_layers = []
        for key in self_state_dict.keys():
            if key in pretrained_dict.keys():
                if self_state_dict[key].shape == pretrained_dict[key].shape:
                    self_state_dict[key] = pretrained_dict[key].to(self.device)
                    freeze_layers.append(key)
        self.load_state_dict(self_state_dict)
        for name, param in self.named_parameters():
            if name in freeze_layers:
                param.requires_grad = False
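
    # Illustrative transfer-learning call (SomeBaseModel and the snapshot path
    # are hypothetical placeholders):
    #
    #   model = MoEModel()
    #   model.load_from_pretrained_base_model(SomeBaseModel, "base_snapshot.pt")
    #   # shape-matched layers are now initialised from the snapshot and frozen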

    # --- ONNX export helper ---
    def export_onnx(self, output_path, dummy_input):
        """
        dummy_input should be a tuple (or dict) containing:
        (input_ids, attention_mask, pg, p_start)
        """
        self.eval()
        input_names = ["input_ids", "attention_mask", "pg", "p_start"]
        output_names = ["logits"]
        torch.onnx.export(
            self,
            dummy_input,
            output_path,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes={
                "input_ids": {0: "batch_size", 1: "seq_len"},
                "attention_mask": {0: "batch_size", 1: "seq_len"},
                "pg": {0: "batch_size"},
                "p_start": {0: "batch_size"},
                "logits": {0: "batch_size"},
            },
            opset_version=14,  # 14+ recommended for better operator support
            do_constant_folding=True,
        )
        logger.info(f"ONNX model exported to {output_path}")