SUInput/src/trainer/model_with_neck.py


import pickle
from importlib.resources import files
from pathlib import Path
from typing import Optional, Union
import torch
import torch.amp as amp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from loguru import logger
from modelscope import AutoModel
from tqdm import tqdm
from .monitor import TrainingMonitor
from .model import (
    EXPORT_HIDE_DIM,
    eval_dataloader,
    ResidualBlock,
    Expert,
)
# ---------------------------- Main model (MoE + hard routing) ----------------------------
class MoEModel(nn.Module):
    def __init__(
        self,
        pretrained_model_name="iic/nlp_structbert_backbone_tiny_std",
        num_classes=10018,
        output_multiplier=2,
        d_model=768,
        num_resblocks=4,
        num_domain_experts=20,
        experts_dim=EXPORT_HIDE_DIM,
    ):
        super().__init__()
        self.output_multiplier = output_multiplier
        # 1. Load the pretrained BERT and keep only its embeddings.
        bert = AutoModel.from_pretrained(pretrained_model_name)
        self.embedding = bert.embeddings
        self.bert_config = bert.config
        self.hidden_size = self.bert_config.hidden_size  # BERT hidden dimension
        self.device = None  # set when to() is called
        self.experts_dim = experts_dim
        # 2. Four standard Transformer encoder layers (sizes read from the BERT config).
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.hidden_size,
            nhead=8,
            dim_feedforward=self.bert_config.intermediate_size,
            dropout=self.bert_config.hidden_dropout_prob,
            activation="gelu",
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
        self.pooler = nn.AdaptiveAvgPool1d(1)
        self.res_blocks = nn.ModuleList(
            [ResidualBlock(self.hidden_size) for _ in range(4)]
        )
        # 3. Domain experts, one per group; experts_dim supplies a hidden dim per expert.
        self.total_experts = num_domain_experts
        self.experts = nn.ModuleList()
        for i in range(self.total_experts):
            expert = Expert(
                input_dim=self.hidden_size,
                d_model=self.experts_dim[i],
                num_resblocks=num_resblocks,
                output_multiplier=self.output_multiplier,  # output dim = output_multiplier * hidden_size
                dropout_prob=0.1,
            )
            self.experts.append(expert)
        self.expert_bias = nn.Embedding(
            self.total_experts, self.output_multiplier * self.hidden_size
        )
        # 4. Classification head
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.LayerNorm(self.output_multiplier * self.hidden_size),
            nn.Linear(
                self.output_multiplier * self.hidden_size,
                self.output_multiplier * self.hidden_size,
            ),
            nn.ReLU(inplace=True),
            nn.Linear(
                self.output_multiplier * self.hidden_size,
                self.output_multiplier * self.hidden_size * 2,
            ),
            nn.ReLU(inplace=True),
            nn.Linear(self.output_multiplier * self.hidden_size * 2, num_classes),
        )
    def to(self, device):
        """Override to() so the target device is recorded on the module."""
        self.device = torch.device(device)  # coerce strings like "cuda" to torch.device
        return super().to(device)
    def forward(self, input_ids, attention_mask, pg):
        """
        input_ids     : [batch, seq_len]
        attention_mask: [batch, seq_len] (1 = valid token, 0 = padding)
        pg            : group id; a [batch] LongTensor during training,
                        a scalar Tensor during inference/export
        """
        # ----- 1. Embeddings -----
        embeddings = self.embedding(input_ids)  # [B, S, H]
        # ----- 2. Transformer encoder -----
        # padding mask: True marks positions to ignore
        padding_mask = attention_mask == 0
        encoded = self.encoder(
            embeddings, src_key_padding_mask=padding_mask
        )  # [B, S, H]
        for block in self.res_blocks:
            encoded = block(encoded)
        # ----- 3. Pooling -----
        pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1)  # [B, H]
        # ----- 4. Expert routing (hard routing) -----
        if torch.jit.is_tracing():
            # ------------- ONNX export mode (batch = 1) -------------
            # pg is a scalar Tensor here; convert it to a Python int so the
            # expert is selected with plain Python indexing at trace time,
            # baking a single expert into the exported graph.
            group_id = int(pg.item()) if torch.is_tensor(pg) else int(pg)
            expert_out = self.experts[group_id](pooled) + self.expert_bias(
                torch.tensor(group_id, device=pooled.device)
            )
        else:
            batch_size = pooled.size(0)
            # Run every expert in parallel, then gather per-sample outputs.
            expert_outputs = torch.stack(
                [e(pooled) for e in self.experts], dim=0
            )  # [E, B, D]
            # Select each sample's expert output according to pg (the arange
            # index must live on the same device as pg for CUDA inputs).
            expert_out = expert_outputs[
                pg, torch.arange(batch_size, device=pg.device)
            ]  # [B, D]
            # Add the per-expert bias.
            bias = self.expert_bias(pg)  # [B, D]
            expert_out = expert_out + bias
        # ----- 5. Classification head -----
        logits = self.classifier(expert_out)  # [batch, num_classes]
        if not self.training:  # apply softmax at inference time
            probs = torch.softmax(logits, dim=-1)
            return probs
        return logits
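    # Illustrative note on the dense-routing gather in the batched (non-tracing)
    # branch above -- a sketch with made-up sizes, not executed by the model:
    # with E=3 experts, B=2 samples, and D=4 output dims,
    #   expert_outputs                                         # [3, 2, 4]
    #   pg = torch.tensor([2, 0])
    #   expert_outputs[pg, torch.arange(2, device=pg.device)]  # [2, 4]
    # picks expert 2's output for sample 0 and expert 0's output for sample 1.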
    def model_eval(self, eval_dataloader, criterion=None):
        """
        Evaluate the model on a validation set and return accuracy and mean loss.

        Args:
            eval_dataloader: DataLoader yielding 'hint' (with 'input_ids' and
                'attention_mask'), 'pg', and 'char_id'.
            criterion: unused; kept for API compatibility. In eval mode forward()
                returns probabilities, so the loss is computed as NLL over
                log-probabilities, which is equivalent to cross-entropy on the
                underlying logits.

        Returns:
            accuracy: float
            avg_loss: float
        """
        self.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in eval_dataloader:
                # Move the batch to the model's device.
                input_ids = batch["hint"]["input_ids"].to(self.device)
                attention_mask = batch["hint"]["attention_mask"].to(self.device)
                pg = batch["pg"].to(self.device)
                labels = batch["char_id"].to(self.device)
                # Forward pass (returns probabilities in eval mode).
                probs = self(input_ids, attention_mask, pg)
                log_probs = torch.log(probs + 1e-12)
                loss = nn.NLLLoss()(log_probs, labels)
                total_loss += loss.item() * labels.size(0)
                # Accuracy
                preds = probs.argmax(dim=-1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        avg_loss = total_loss / total if total > 0 else 0.0
        accuracy = correct / total if total > 0 else 0.0
        return accuracy, avg_loss
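    # Hedged usage sketch (assumes a validation DataLoader named `val_loader`
    # with the batch layout described in the docstring above):
    #   acc, loss = model.model_eval(val_loader)
    #   logger.info(f"val acc={acc:.4f}, loss={loss:.4f}")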
    def predict(self, sample, debug=False):
        """
        Predict from a sample dict; supports batched or single samples and can
        optionally print details for mispredicted samples.

        Args:
            sample : dict
                Required fields:
                    - 'input_ids'     : [batch, seq_len] or [seq_len] (single sample)
                    - 'attention_mask': same shape as input_ids
                    - 'pg'            : [batch] or scalar
                    - 'char_id'       : [batch] or scalar true labels
                                        (required when debug=True)
                Additionally required when debug=True:
                    - 'txt'  : list of strings (batch) or a single string
                    - 'char' : list of strings (batch) or a single string
                    - 'py'   : list of strings (batch) or a single string
            debug : bool
                Whether to print mispredicted samples. If True and sample is
                missing char_id/txt/char/py, a ValueError is raised.

        Returns:
            preds : torch.Tensor
                [batch] predicted class labels (a scalar tensor if the input
                had no batch dimension).
        """
        self.eval()
        # ------------------ 1. Extract and normalize inputs ------------------
        # A single sample is detected by input_ids lacking a batch dimension.
        input_ids = sample["input_ids"]
        attention_mask = sample["attention_mask"]
        pg = sample["pg"]
        has_batch_dim = input_ids.dim() > 1
        if not has_batch_dim:
            input_ids = input_ids.unsqueeze(0)
            attention_mask = attention_mask.unsqueeze(0)
        if pg.dim() == 0:
            pg = pg.unsqueeze(0).expand(input_ids.size(0))
        # ------------------ 2. Move to device ------------------
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        pg = pg.to(self.device)
        # ------------------ 3. Inference ------------------
        with torch.no_grad():
            # forward() already returns probabilities in eval mode, so argmax
            # can be taken directly; no extra softmax is needed.
            probs = self(input_ids, attention_mask, pg)
            preds = probs.argmax(dim=-1)  # [batch]
        # ------------------ 4. Debug printing (mispredicted samples) ------------------
        if debug:
            # Check required fields.
            required_keys = ["char_id", "txt", "char", "py"]
            missing = [k for k in required_keys if k not in sample]
            if missing:
                raise ValueError(f"debug=True requires sample fields: {missing}")
            # Extract true labels.
            true_labels = sample["char_id"]
            if true_labels.dim() == 0:
                true_labels = true_labels.unsqueeze(0)
            # Move true labels to the same device.
            true_labels = true_labels.to(self.device)
            # Find the indices of mispredictions.
            incorrect_mask = preds != true_labels
            incorrect_indices = torch.where(incorrect_mask)[0]
            if len(incorrect_indices) > 0:
                print("\n=== Mispredicted samples ===")
                # Debug fields may be lists or single strings.
                txts = sample["txt"]
                chars = sample["char"]
                pys = sample["py"]
                # Normalize single strings to one-element lists.
                if isinstance(txts, str):
                    txts = [txts]
                    chars = [chars]
                    pys = [pys]
                for idx in incorrect_indices.cpu().numpy():
                    print(f"Sample index {idx}:")
                    print(f"  Text  : {txts[idx]}")
                    print(f"  Char  : {chars[idx]}")
                    print(f"  Pinyin: {pys[idx]}")
                    print(
                        f"  Predicted: {preds[idx].item()}, True: {true_labels[idx].item()}"
                    )
                print("============================\n")
        # ------------------ 5. Return (matching the input's dimensionality) ------------------
        if not has_batch_dim:
            return preds.squeeze(0)  # 0-dim tensor for a single sample
        return preds
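    # Hedged usage sketch; `tokenizer` and the group id value are hypothetical
    # stand-ins, and the field names follow the docstring above:
    #   enc = tokenizer("ni hao", return_tensors="pt")
    #   sample = {
    #       "input_ids": enc["input_ids"].squeeze(0),
    #       "attention_mask": enc["attention_mask"].squeeze(0),
    #       "pg": torch.tensor(3),
    #   }
    #   pred = model.predict(sample)  # 0-dim tensor with the class id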
    def fit(
        self,
        train_dataloader,
        eval_dataloader=None,
        monitor: Optional[TrainingMonitor] = None,
        criterion=None,
        optimizer=None,
        num_epochs=1,
        eval_frequency=500,
        grad_accum_steps=1,
        clip_grad_norm=1.0,
        mixed_precision=False,
        lr=1e-4,
        lr_schedule=None,  # optional user-supplied learning-rate schedule
    ):
        """
        Train the model with support for mixed precision, gradient accumulation,
        learning-rate scheduling, and live monitoring.

        Args:
            train_dataloader: DataLoader
                Training data loader.
            eval_dataloader: DataLoader, optional
                Evaluation data loader.
            monitor: TrainingMonitor, optional
                Training monitor.
            criterion: nn.Module, optional
                Loss function; defaults to nn.CrossEntropyLoss().
            optimizer: optim.Optimizer, optional
                Optimizer; defaults to AdamW with learning rate `lr`.
            num_epochs: int, optional
                Number of training epochs.
            eval_frequency: int, optional
                Evaluate every this many optimizer steps.
            grad_accum_steps: int, optional
                Number of gradient-accumulation micro-batches per optimizer step.
            clip_grad_norm: float, optional
                Gradient-clipping norm.
            mixed_precision: bool, optional
                Whether to use mixed precision.
            lr: float, optional
                Initial learning rate.
            lr_schedule : callable, optional
                Custom schedule called as lr_schedule(processed_batches, optimizer);
                it may modify the learning rates in optimizer.param_groups directly.
        """
        # Make sure the model is on a device.
        if self.device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.to(self.device)
        # Switch to training mode.
        super().train()
        # Default loss and optimizer.
        if criterion is None:
            criterion = nn.CrossEntropyLoss()
        if optimizer is None:
            optimizer = optim.AdamW(self.parameters(), lr=lr)
        # Mixed-precision gradient scaler.
        scaler = amp.GradScaler(enabled=mixed_precision)
        global_step = 0
        processed_batches = 0  # counter of micro-batches actually processed
        batch_loss_sum = 0.0
        optimizer.zero_grad()
        for epoch in range(num_epochs):
            # total=1e6 is a placeholder for streaming datasets without __len__.
            for batch_idx, batch in enumerate(tqdm(train_dataloader, total=1e6)):
                # ---------- Update the batch counter ----------
                processed_batches += 1
                # ---------- Apply the user-supplied LR schedule, if any ----------
                if lr_schedule is not None:
                    lr_schedule(processed_batches, optimizer)
                # ---------- Move data to device ----------
                input_ids = batch["hint"]["input_ids"].to(self.device)
                attention_mask = batch["hint"]["attention_mask"].to(self.device)
                pg = batch["pg"].to(self.device)
                labels = batch["char_id"].to(self.device)
                # Mixed-precision forward pass.
                with amp.autocast(
                    device_type=self.device.type, enabled=mixed_precision
                ):
                    logits = self(input_ids, attention_mask, pg)
                    loss = criterion(logits, labels)
                    loss = loss / grad_accum_steps
                # Backward pass.
                scaler.scale(loss).backward()
                # Gradient accumulation: step only every grad_accum_steps micro-batches.
                if (batch_idx + 1) % grad_accum_steps == 0:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
                    global_step += 1
                    # Recover the unscaled loss of the last micro-batch in the window.
                    original_loss = loss.item() * grad_accum_steps
                    batch_loss_sum += original_loss
                    # Periodic evaluation.
                    if (
                        eval_dataloader is not None
                        and global_step % eval_frequency == 0
                    ):
                        avg_loss = batch_loss_sum / eval_frequency
                        acc, eval_loss = self.model_eval(eval_dataloader, criterion)
                        super().train()  # model_eval() switched the model to eval mode
                        if monitor is not None:
                            monitor.add_step(
                                global_step,
                                {"loss": avg_loss, "acc": acc},
                            )
                        logger.info(
                            f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc:.4f}, eval_loss: {eval_loss:.4f}"
                        )
                        batch_loss_sum = 0.0
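    # A minimal sketch of an lr_schedule callable compatible with fit() above
    # (a hypothetical linear warmup, not part of the original code):
    #
    #   def warmup_schedule(processed_batches, optimizer, warmup=1000, base_lr=1e-4):
    #       scale = min(1.0, processed_batches / warmup)
    #       for group in optimizer.param_groups:
    #           group["lr"] = base_lr * scale
    #
    #   model.fit(train_loader, eval_loader, lr_schedule=warmup_schedule)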
    def load_from_state_dict(self, state_dict_path: Union[str, Path]):
        """Load weights from a serialized state dict on disk."""
        state_dict = torch.load(
            state_dict_path, weights_only=True, map_location=self.device
        )
        self.load_state_dict(state_dict)
    def load_from_pretrained_base_model(
        self,
        BaseModel,
        snapshot_path: Union[str, Path],
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        *args,
        **kwargs,
    ):
        """Copy shape-compatible weights from a pretrained base model and freeze them."""
        base_model = BaseModel(*args, **kwargs)
        base_model.load_state_dict(torch.load(snapshot_path, map_location=device))
        self_state_dict = self.state_dict()
        pretrained_dict = base_model.state_dict()
        freeze_layers = []
        for key in self_state_dict.keys():
            if key in pretrained_dict.keys():
                if self_state_dict[key].shape == pretrained_dict[key].shape:
                    # Copy matching tensors; use the `device` argument rather than
                    # self.device, which may still be None before to() is called.
                    self_state_dict[key] = pretrained_dict[key].to(device)
                    freeze_layers.append(key)
        self.load_state_dict(self_state_dict)
        # Freeze every parameter that was copied from the base model.
        for name, param in self.named_parameters():
            if name in freeze_layers:
                param.requires_grad = False
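

# Minimal smoke-test sketch, not part of the training pipeline. It assumes the
# modelscope backbone can be downloaded and that EXPORT_HIDE_DIM provides one
# hidden dim per expert.
if __name__ == "__main__":
    model = MoEModel()
    model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    model.eval()
    # Two dummy sequences of length 16, hard-routed to experts 0 and 3.
    input_ids = torch.randint(
        0, model.bert_config.vocab_size, (2, 16), device=model.device
    )
    attention_mask = torch.ones(2, 16, dtype=torch.long, device=model.device)
    pg = torch.tensor([0, 3], device=model.device)
    with torch.no_grad():
        probs = model(input_ids, attention_mask, pg)  # [2, num_classes]
    print(probs.shape)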