# SUInput/src/trainer/model.py

import math
import pickle
from importlib.resources import files
from pathlib import Path
from typing import Optional, Union
import torch
import torch.amp as amp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from loguru import logger
from modelscope import AutoModel, AutoTokenizer
from tqdm.autonotebook import tqdm
from suinput.dataset import PG
from .monitor import TrainingMonitor, send_serverchan_message
def eval_dataloader(path: Union[str, Path, None] = None):
    """Load all pickled evaluation batches found under *path*.

    Args:
        path: Directory containing ``*.pkl`` batch files. Defaults to the
            ``eval_dataset`` directory bundled with this package. The default
            is resolved lazily so merely importing this module never touches
            package resources (the old eager default ran at definition time).

    Returns:
        list: Unpickled batch objects, in sorted filename order so repeated
        runs see the batches in a deterministic sequence.
    """
    if path is None:
        path = files(__package__) / "eval_dataset"
    batches = []
    for pkl_file in sorted(Path(path).glob("*.pkl")):
        # Context manager closes each handle; the previous one-liner leaked
        # an open file per pickle.
        with pkl_file.open("rb") as fh:
            batches.append(pickle.load(fh))
    return batches
# ---------------------------- 注意力池化模块(新增)----------------------------
class AttentionPooling(nn.Module):
    """Attention pooling over the sequence dimension.

    A single linear scorer yields one logit per token; each token type
    (0 = text, 1 = pinyin, 2 = personalization) contributes a learnable
    additive bias to its logit before the softmax, and padding positions
    are forced to near-zero weight via the mask.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.attn = nn.Linear(hidden_size, 1)
        # One learnable score bias per token type: [text, pinyin, user].
        self.bias = nn.Parameter(torch.zeros(3))

    def forward(self, x, mask=None, token_type_ids=None):
        # Per-token attention logits: [batch, seq_len].
        logits = self.attn(x).squeeze(-1)
        if token_type_ids is not None:
            # Indexing the 3-entry bias with the type ids broadcasts it to
            # a [batch, seq_len] additive bias.
            logits = logits + self.bias[token_type_ids]
        if mask is not None:
            # Large negative logit -> ~zero softmax weight on padding.
            logits = logits.masked_fill(mask == 0, -1e9)
        attn_weights = torch.softmax(logits, dim=-1)
        # Weighted sum over the sequence dimension -> [batch, hidden].
        return (attn_weights.unsqueeze(-1) * x).sum(dim=1)
# ---------------------------- 残差块 ----------------------------
class ResidualBlock(nn.Module):
    """Residual MLP block: Linear -> GELU -> LayerNorm, Linear -> LayerNorm,
    dropout, skip connection, then a final GELU on the sum."""

    def __init__(self, dim, dropout_prob=0.3):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.ln1 = nn.LayerNorm(dim)
        self.linear2 = nn.Linear(dim, dim)
        self.ln2 = nn.LayerNorm(dim)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        shortcut = x
        out = self.ln1(self.gelu(self.linear1(x)))
        out = self.ln2(self.linear2(out))
        out = self.dropout(out)
        # Skip connection followed by the closing non-linearity.
        return self.gelu(out + shortcut)
# ---------------------------- 专家网络 ----------------------------
class Expert(nn.Module):
    """MLP expert: input projection to ``d_model``, a stack of residual
    blocks, and a two-layer head that widens the output to
    ``input_dim * output_multiplier`` features."""

    def __init__(
        self,
        input_dim,
        d_model=768,
        num_resblocks=4,
        output_multiplier=2,
        dropout_prob=0.3,
    ):
        super().__init__()
        # Final feature width produced by this expert.
        self.output_dim = input_dim * output_multiplier
        self.linear_in = nn.Linear(input_dim, d_model)
        self.res_blocks = nn.ModuleList(
            [ResidualBlock(d_model, dropout_prob) for _ in range(num_resblocks)]
        )
        self.output = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout_prob),
            nn.Linear(d_model, self.output_dim),
        )

    def forward(self, x):
        hidden = self.linear_in(x)
        for res_block in self.res_blocks:
            hidden = res_block(hidden)
        return self.output(hidden)
# ---------------------------- 主模型MoE + 硬路由)------------------------
class MoEModel(nn.Module):
    """Hard-routed mixture-of-experts classifier.

    Pipeline: pretrained StructBERT-lite embeddings -> 4-layer Transformer
    encoder -> attention pooling -> exactly one expert selected per sample
    by its pinyin-group id (``pg``) -> MLP classification head.
    """

    def __init__(
        self,
        pretrained_model_name="iic/nlp_structbert_backbone_lite_std",
        num_classes=10018,
        output_multiplier=2,
        d_model=768,
        num_resblocks=4,
        num_domain_experts=12,
        num_shared_experts=1,
    ):
        super().__init__()
        self.output_multiplier = output_multiplier
        # 1. Load the pretrained embedding layer (Lite backbone: hidden_size=512).
        logger.info(f"Loading backbone: {pretrained_model_name}")
        bert = AutoModel.from_pretrained(pretrained_model_name)
        self.embedding = bert.embeddings
        self.bert_config = bert.config
        self.hidden_size = self.bert_config.hidden_size  # 512
        # Populated by the overridden .to(); stays None until then.
        self.device = None
        # 2. Transformer encoder stacked on top of the frozen-format embeddings.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.hidden_size,
            nhead=8,
            dim_feedforward=self.bert_config.intermediate_size,
            dropout=self.bert_config.hidden_dropout_prob,
            activation="gelu",
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
        # 3. Attention pooling (replaces the earlier span pooling).
        self.attn_pool = AttentionPooling(self.hidden_size)
        # 4. Expert system: domain experts first, then the shared expert(s);
        #    shared experts get a higher dropout (0.4 vs 0.3).
        total_experts = num_domain_experts + num_shared_experts
        self.experts = nn.ModuleList()
        for i in range(total_experts):
            dropout = 0.3 if i < num_domain_experts else 0.4
            self.experts.append(
                Expert(
                    input_dim=self.hidden_size,
                    d_model=d_model,
                    num_resblocks=num_resblocks,
                    output_multiplier=self.output_multiplier,
                    dropout_prob=dropout,
                )
            )
        # Learnable per-expert additive bias applied to the expert output.
        self.expert_bias = nn.Embedding(
            total_experts, self.hidden_size * self.output_multiplier
        )
        # 5. Classification head.
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.hidden_size * self.output_multiplier),
            nn.Dropout(0.4),
            nn.Linear(self.hidden_size * self.output_multiplier, self.hidden_size * self.output_multiplier * 2),
            nn.GELU(),
            nn.Linear(self.hidden_size * self.output_multiplier * 2, num_classes),
        )

    def to(self, device):
        """Override ``to`` so the chosen device is remembered on the model."""
        self.device = device
        return super().to(device)

    def forward(self, input_ids, attention_mask, token_type_ids, pg):
        """Forward pass; uses token_type_ids instead of the former p_start.

        Args:
            input_ids: [B, L] token ids.
            attention_mask: [B, L]; 1 = real token, 0 = padding.
            token_type_ids: [B, L] (0 = text, 1 = pinyin).
            pg: [B] pinyin-group ids used for hard expert routing.

        Returns:
            [B, num_classes] classification logits.
        """
        # ----- 1. Embeddings -----
        # The pretrained embedding module is assumed to accept
        # token_type_ids directly (it normally owns token_type_embeddings).
        # If a backbone's embedding did not take the argument, the type
        # embeddings could be added manually instead:
        # embeddings = self.embedding(input_ids) + self.embedding.token_type_embeddings(token_type_ids)
        embeddings = self.embedding(input_ids, token_type_ids=token_type_ids)
        # ----- 2. Transformer encoder -----
        # src_key_padding_mask expects True at the positions to ignore.
        padding_mask = attention_mask == 0
        encoded = self.encoder(
            embeddings, src_key_padding_mask=padding_mask
        )  # [B, S, H]
        # ----- 3. Attention pooling (replaces the old span pooling) -----
        # attention_mask suppresses padding positions in the softmax.
        pooled = self.attn_pool(encoded, attention_mask, token_type_ids)  # [B, H]
        # ----- 4. Expert routing (hard routing) -----
        if torch.jit.is_tracing():
            # ONNX export path (batch=1): pick the expert by pg. Because
            # .item() resolves the id during tracing, only the branch that
            # is actually taken gets recorded — the exported graph is
            # specialized to the traced group id.
            group_id = pg.item() if torch.is_tensor(pg) else pg
            # Expert indices start at 0; each case pairs expert i with
            # expert_bias row i.
            if group_id == 0:
                expert_out = self.experts[0](pooled) + self.expert_bias(
                    torch.tensor(0, device=pooled.device)
                )
            elif group_id == 1:
                expert_out = self.experts[1](pooled) + self.expert_bias(
                    torch.tensor(1, device=pooled.device)
                )
            elif group_id == 2:
                expert_out = self.experts[2](pooled) + self.expert_bias(
                    torch.tensor(2, device=pooled.device)
                )
            elif group_id == 3:
                expert_out = self.experts[3](pooled) + self.expert_bias(
                    torch.tensor(3, device=pooled.device)
                )
            elif group_id == 4:
                expert_out = self.experts[4](pooled) + self.expert_bias(
                    torch.tensor(4, device=pooled.device)
                )
            elif group_id == 5:
                expert_out = self.experts[5](pooled) + self.expert_bias(
                    torch.tensor(5, device=pooled.device)
                )
            elif group_id == 6:
                expert_out = self.experts[6](pooled) + self.expert_bias(
                    torch.tensor(6, device=pooled.device)
                )
            elif group_id == 7:
                expert_out = self.experts[7](pooled) + self.expert_bias(
                    torch.tensor(7, device=pooled.device)
                )
            elif group_id == 8:
                expert_out = self.experts[8](pooled) + self.expert_bias(
                    torch.tensor(8, device=pooled.device)
                )
            elif group_id == 9:
                expert_out = self.experts[9](pooled) + self.expert_bias(
                    torch.tensor(9, device=pooled.device)
                )
            elif group_id == 10:
                expert_out = self.experts[10](pooled) + self.expert_bias(
                    torch.tensor(10, device=pooled.device)
                )
            elif group_id == 11:
                expert_out = self.experts[11](pooled) + self.expert_bias(
                    torch.tensor(11, device=pooled.device)
                )
            else:  # group_id == 12 (shared expert)
                expert_out = self.experts[12](pooled) + self.expert_bias(
                    torch.tensor(12, device=pooled.device)
                )
        else:
            batch_size = pooled.size(0)
            # Training/eval path: run every expert, then gather each
            # sample's own expert output (hard routing by pg).
            expert_outputs = torch.stack(
                [e(pooled) for e in self.experts], dim=0
            )  # [E, B, D]
            expert_out = expert_outputs[pg, torch.arange(batch_size)]  # [B, D]
            bias = self.expert_bias(pg)  # [B, D]
            expert_out = expert_out + bias
        # ----- 5. Classification head -----
        return self.classifier(expert_out)  # [batch, num_classes]

    def model_eval(self, eval_dataloader, criterion):
        """Evaluate the model on a validation set.

        Args:
            eval_dataloader: Iterable of batches, each a dict with
                - "hint": dict with input_ids / attention_mask / token_type_ids
                - "pg": pinyin-group ids
                - "char_id": target character-id labels
            criterion (callable): Loss applied to (logits, labels).

        Returns:
            tuple: (accuracy, avg_loss) averaged per-sample over the set;
            both 0.0 when the loader yields no samples.

        Note:
            Switches the model to eval mode and runs under torch.no_grad();
            the caller is responsible for restoring train mode afterwards.
        """
        self.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in eval_dataloader:
                input_ids = batch["hint"]["input_ids"].to(self.device)
                attention_mask = batch["hint"]["attention_mask"].to(self.device)
                token_type_ids = batch["hint"]["token_type_ids"].to(self.device)
                pg = batch["pg"].to(self.device)
                labels = batch["char_id"].to(self.device)
                logits = self(input_ids, attention_mask, token_type_ids, pg)
                loss = criterion(logits, labels)
                # Weight by batch size so the final average is per-sample.
                total_loss += loss.item() * labels.size(0)
                # softmax is monotonic, so this argmax equals argmax(logits).
                preds = torch.softmax(logits, dim=-1).argmax(dim=-1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        avg_loss = total_loss / total if total > 0 else 0.0
        accuracy = correct / total if total > 0 else 0.0
        return accuracy, avg_loss

    def gen_predict_sample(self, text, py, tokenizer=None):
        """Build a model-ready sample from raw text and pinyin.

        Encodes ``text`` (sentence A) and ``py`` (sentence B) into
        input_ids / attention_mask / token_type_ids and derives the
        pinyin-group id from the first pinyin letter via the PG table.

        Args:
            text (str): Context text, used as the first segment.
            py (str): Pinyin string, used as the second segment.
            tokenizer (optional): Tokenizer instance. When None and no
                tokenizer is cached on the model, a default AutoTokenizer
                for the lite StructBERT backbone is created and cached.

        Returns:
            dict: {"hint": {"input_ids", "attention_mask", "token_type_ids"},
                   "pg": tensor holding the pinyin-group id}.

        Notes:
            - text_pair encoding makes the tokenizer emit token_type_ids;
              the tokenizer must support return_token_type_ids=True.
            - Sequences are padded/truncated to max_length=88.
            - An empty ``py`` falls back to group id 12 — presumably the
              shared expert (matches the else-branch index in forward).
        """
        if tokenizer is None and not hasattr(self, "tokenizer"):
            self.tokenizer = AutoTokenizer.from_pretrained(
                "iic/nlp_structbert_backbone_lite_std"
            )
        else:
            # Prefer an explicitly passed tokenizer; otherwise keep the cached one.
            self.tokenizer = tokenizer or self.tokenizer
        # text_pair encoding so the tokenizer generates token_type_ids itself.
        encoded = self.tokenizer(
            text,  # sentence A: the context text
            py,  # sentence B: the pinyin
            max_length=88,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            return_token_type_ids=True,  # explicitly request token_type_ids
        )
        sample = {
            "hint": {
                "input_ids": encoded["input_ids"],
                "attention_mask": encoded["attention_mask"],
                "token_type_ids": encoded["token_type_ids"],
            },
            # Group id from the first pinyin letter (lookup in PG); empty
            # pinyin routes to id 12.
            "pg": torch.tensor(
                [PG[py[0]] if py != "" else 12]
            ),
        }
        return sample

    def predict(self, text, py, tokenizer=None):
        """Predict character ids for (text, pinyin), handling the batch dim.

        Args:
            text (str): Input text passed to gen_predict_sample.
            py (str): Pinyin string matching ``text``.
            tokenizer (optional): Tokenizer forwarded to gen_predict_sample.

        Returns:
            torch.Tensor: Predicted class ids; the batch dimension is
            squeezed away again when the encoded input had none.
        """
        self.eval()
        sample = self.gen_predict_sample(text, py, tokenizer)
        input_ids = sample["hint"]["input_ids"]
        attention_mask = sample["hint"]["attention_mask"]
        token_type_ids = sample["hint"]["token_type_ids"]
        pg = sample["pg"]
        has_batch_dim = input_ids.dim() > 1
        if not has_batch_dim:
            input_ids = input_ids.unsqueeze(0)
            attention_mask = attention_mask.unsqueeze(0)
            token_type_ids = token_type_ids.unsqueeze(0)
        if pg.dim() == 0:
            # Broadcast a scalar group id across the whole batch.
            pg = pg.unsqueeze(0).expand(input_ids.size(0))
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        token_type_ids = token_type_ids.to(self.device)
        pg = pg.to(self.device)
        with torch.no_grad():
            logits = self(input_ids, attention_mask, token_type_ids, pg)
            preds = torch.softmax(logits, dim=-1).argmax(dim=-1)
        if not has_batch_dim:
            return preds.squeeze(0)
        return preds

    def fit(
        self,
        train_dataloader,
        eval_dataloader=None,
        monitor: Optional[TrainingMonitor] = None,
        criterion=None,
        optimizer=None,
        num_epochs=1,
        stop_batch=2e5,
        eval_frequency=500,
        grad_accum_steps=1,
        clip_grad_norm=1.0,
        loss_weight=None,
        mixed_precision=True,
        weight_decay=0.1,
        warmup_ratio=0.1,
        label_smoothing=0.15,
        lr=1e-4,
    ):
        """Train the model with linear warmup + cosine LR decay and optional AMP.

        Args:
            train_dataloader: Batches shaped like those in model_eval.
            eval_dataloader: Optional validation loader; evaluated every
                ``eval_frequency`` steps and once more after training.
            monitor: Optional TrainingMonitor fed loss/acc at each eval.
            criterion: Defaults to CrossEntropyLoss with label smoothing
                (plus per-class ``loss_weight`` when given).
            optimizer: Defaults to AdamW(lr, weight_decay).
            num_epochs / stop_batch: Training ends after ``stop_batch``
                batches or ``num_epochs`` passes, whichever comes first.
            grad_accum_steps: Gradient-accumulation factor.
            clip_grad_norm: Max gradient norm applied before each step.
            mixed_precision: Enables autocast + GradScaler.
        """
        if self.device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.to(self.device)
        self.train()
        if optimizer is None:
            optimizer = optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)
        if criterion is None:
            if loss_weight is not None:
                criterion = nn.CrossEntropyLoss(
                    weight=loss_weight, label_smoothing=label_smoothing
                )
            else:
                criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
        scaler = amp.GradScaler(enabled=mixed_precision)
        # The LR schedule is planned over stop_batch steps, not epochs.
        total_steps = stop_batch
        warmup_steps = int(total_steps * warmup_ratio)
        logger.info(f"Training Start: Steps={total_steps}, Warmup={warmup_steps}")
        processed_batches = 0
        global_step = 0  # counter driving the eval/monitor cadence
        batch_loss_sum = 0.0
        optimizer.zero_grad()
        for epoch in range(num_epochs):
            for batch_idx, batch in enumerate(
                tqdm(train_dataloader, total=int(stop_batch))
            ):
                # LR schedule: linear warmup, then cosine decay to 0.
                if processed_batches < warmup_steps:
                    current_lr = lr * (processed_batches / warmup_steps)
                else:
                    progress = (processed_batches - warmup_steps) / (
                        total_steps - warmup_steps
                    )
                    current_lr = lr * (0.5 * (1.0 + math.cos(math.pi * progress)))
                for param_group in optimizer.param_groups:
                    param_group["lr"] = current_lr
                # Move batch tensors; batches now include token_type_ids.
                input_ids = batch["hint"]["input_ids"].to(self.device)
                attention_mask = batch["hint"]["attention_mask"].to(self.device)
                token_type_ids = batch["hint"]["token_type_ids"].to(self.device)
                pg = batch["pg"].to(self.device)
                labels = batch["char_id"].to(self.device)
                with torch.amp.autocast(
                    device_type=self.device.type, enabled=mixed_precision
                ):
                    logits = self(input_ids, attention_mask, token_type_ids, pg)
                    loss = criterion(logits, labels)
                    # Scale down so accumulated gradients average correctly.
                    loss = loss / grad_accum_steps
                scaler.scale(loss).backward()
                # NOTE(review): this steps when processed_batches % steps == 0,
                # i.e. immediately after the FIRST backward; for
                # grad_accum_steps > 1 the first optimizer step therefore
                # accumulates only one batch — confirm this is intended.
                if (processed_batches) % grad_accum_steps == 0:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
                # Undo the accumulation scaling for logging purposes.
                batch_loss_sum += loss.item() * grad_accum_steps
                if global_step % eval_frequency == 0:
                    if eval_dataloader:
                        self.eval()
                        acc, eval_loss = self.model_eval(eval_dataloader, criterion)
                        self.train()
                        if monitor:
                            # eval-time loss/acc are the monitored metrics.
                            monitor.add_step(
                                global_step, {"loss": batch_loss_sum / (eval_frequency if global_step > 0 else 1), "acc": acc}
                            )
                        logger.info(
                            f"step: {global_step}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if global_step > 0 else 1):.4f}, current_lr: {current_lr}"
                        )
                    else:
                        logger.info(f"step: {global_step}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if global_step > 0 else 1):.4f}, current_lr: {current_lr}")
                    batch_loss_sum = 0.0
                # NOTE(review): this break only exits the inner (batch) loop;
                # the epoch loop keeps going and re-enters — confirm intended.
                if processed_batches >= stop_batch:
                    break
                processed_batches += 1
                global_step += 1
        # Send a completion notification (best effort — failures are logged).
        # NOTE(review): eval_dataloader may be None here; model_eval would
        # then raise, which this except swallows with only an error log.
        try:
            res_acc, res_loss = self.model_eval(eval_dataloader, criterion)
            send_serverchan_message(
                title="训练完成",
                content=f"acc: {res_acc:.4f}, loss: {res_loss:.4f}",
            )
            logger.info(f"训练完成acc: {res_acc:.4f}, loss: {res_loss:.4f}")
        except Exception as e:
            logger.error(f"发送消息失败: {e}")

    def load_from_state_dict(self, state_dict_path: Union[str, Path]):
        """Load a full checkpoint state dict, mapped onto self.device."""
        state_dict = torch.load(
            state_dict_path, weights_only=True, map_location=self.device
        )
        self.load_state_dict(state_dict)

    def load_from_pretrained_base_model(
        self,
        BaseModel,
        snapshot_path: Union[str, Path],
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        *args,
        **kwargs,
    ):
        """Copy shape-matching weights from a pretrained base model, then
        freeze the copied parameters.

        Args:
            BaseModel: Class of the donor model; instantiated with
                ``*args`` / ``**kwargs``.
            snapshot_path: Path to the donor model's state dict.
            device: Map location for loading the snapshot.
                NOTE(review): the default is evaluated once at import time,
                so the CUDA availability check is frozen then.
        """
        base_model = BaseModel(*args, **kwargs)
        base_model.load_state_dict(torch.load(snapshot_path, map_location=device))
        self_static_dict = self.state_dict()
        pretrained_dict = base_model.state_dict()
        freeze_layers = []
        for key in self_static_dict.keys():
            if key in pretrained_dict.keys():
                # Only copy tensors whose shapes match exactly.
                if self_static_dict[key].shape == pretrained_dict[key].shape:
                    # NOTE(review): self.device may still be None if .to()
                    # was never called on this model — confirm callers set it.
                    self_static_dict[key] = pretrained_dict[key].to(self.device)
                    freeze_layers.append(key)
        self.load_state_dict(self_static_dict)
        # Freeze every parameter that was copied from the base model.
        for name, param in self.named_parameters():
            if name in freeze_layers:
                param.requires_grad = False

    # --- ONNX export helper ---
    def export_onnx(self, output_path, dummy_input):
        """Export the model to ONNX.

        Args:
            output_path: Destination ``.onnx`` file path.
            dummy_input: Tuple of
                (input_ids, attention_mask, token_type_ids, pg).
        """
        self.eval()
        input_names = ["input_ids", "attention_mask", "token_type_ids", "pg"]
        output_names = ["logits"]
        torch.onnx.export(
            self,
            dummy_input,
            output_path,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes={
                "input_ids": {0: "batch_size", 1: "seq_len"},
                "attention_mask": {0: "batch_size", 1: "seq_len"},
                "token_type_ids": {0: "batch_size", 1: "seq_len"},
                "pg": {0: "batch_size"},
                "logits": {0: "batch_size"},
            },
            opset_version=14,
            do_constant_folding=True,
        )
        logger.info(f"ONNX model exported to {output_path}")