SUimeModelTraner/src/model/model.py

449 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import math
from pathlib import Path
from typing import Optional, Union
import torch
import torch.nn as nn
import torch.amp as amp
import torch.optim as optim
import torch.nn.functional as F
from loguru import logger
from modelscope import AutoTokenizer
from tqdm.notebook import tqdm
from .components import AttentionPooling, Expert # , ResidualBlock # 假设已实现
class InputMethodEngine(nn.Module):
    """Autoregressive slot-filling head on top of a pretrained text encoder.

    A cross-attention layer reads the encoded text, a gating network routes
    each decoding step to ``top_k`` of ``num_experts`` expert MLPs
    (mixture-of-experts), and a shared classifier head emits logits over the
    output vocabulary.
    """

    def __init__(
        self,
        pretrained_encoder,  # already loaded (and possibly extended) pretrained encoder
        output_vocab_size: int,
        hidden_size: int = 512,  # must match the pretrained encoder's hidden dim
        max_slot_steps: int = 24,
        num_experts: int = 20,
        top_k: int = 3,
        expert_res_blocks: int = 4,
        dropout: float = 0.3,
        use_attention_pooling: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_vocab_size = output_vocab_size
        self.max_slot_steps = max_slot_steps
        self.num_experts = num_experts
        self.top_k = top_k
        self.use_attention_pooling = use_attention_pooling
        # Bug fix: fit() checks `self.device is None`, but this attribute was
        # only ever created by the to() override — calling fit() on a model
        # that was never moved raised AttributeError. Initialize it here.
        self.device = None
        # Pretrained text encoder
        self.encoder = pretrained_encoder
        self.slot_embedding = nn.Embedding(output_vocab_size, hidden_size)
        self.slot_position_embedding = nn.Embedding(max_slot_steps + 1, hidden_size)
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=4,  # could be made configurable
            dropout=dropout,
            batch_first=True,
        )
        if use_attention_pooling:
            self.attention_pooling = AttentionPooling(hidden_size)
        self.gate = nn.Linear(hidden_size, num_experts)
        self.experts = nn.ModuleList(
            [
                Expert(
                    input_dim=hidden_size,
                    d_model=hidden_size,
                    num_resblocks=expert_res_blocks,
                    output_multiplier=1,
                    dropout_prob=dropout,
                )
                for _ in range(num_experts)
            ]
        )
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size, bias=False),
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, output_vocab_size),
        )
        # NOTE(review): _init_weights is not defined in the visible part of
        # this file — presumably implemented further down; confirm it exists.
        self._init_weights()
def encode_text(self, input_ids, token_type_ids, attention_mask):
"""
使用预训练编码器编码文本
注意:预训练模型输出可能包含 last_hidden_state 和 pooler_output
"""
outputs = self.encoder(
input_ids=input_ids,
token_type_ids=token_type_ids,
attention_mask=attention_mask,
return_dict=True,
)
# 取 last_hidden_state [batch, seq_len, hidden]
return outputs.last_hidden_state
def forward_single_step(self, context, slot_seq_emb, slot_seq_mask=None):
"""
单步预测:根据当前槽位序列(已拼接的嵌入),预测下一个文字的概率分布
context: [batch, seq_len, hidden] 文本编码结果
slot_seq_emb: [batch, current_len, hidden] 当前槽位序列的嵌入(已拼接)
slot_seq_mask: [batch, current_len] 有效位置mask1有效
返回: [batch, output_vocab_size] 概率分布
"""
batch_size = slot_seq_emb.size(0)
# 交叉注意力Query是槽位序列通常只取最后一个步的嵌入作为Query但这里我们使用整个序列
# 为了简单我们使用整个序列作为Query然后取最后一个位置的输出因为自回归
# 方法1Query = 最后一个位置的嵌入(单个向量)
last_query = slot_seq_emb[:, -1:, :] # [batch, 1, hidden]
# 交叉注意力
attn_out, _ = self.cross_attention(
query=last_query,
key=context,
value=context,
key_padding_mask=(
context.sum(-1) == 0
), # 忽略填充位置实际应传入attention_mask
) # [batch, 1, hidden]
attn_out = attn_out.squeeze(1) # [batch, hidden]
# 门控网络选择top-k专家
gate_logits = self.gate(attn_out) # [batch, num_experts]
topk_weights, topk_indices = torch.topk(
F.softmax(gate_logits, dim=-1), self.top_k, dim=-1
)
# 归一化权重
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
# 计算专家输出加权和
expert_outputs = torch.zeros_like(attn_out) # [batch, hidden]
for i in range(self.top_k):
expert_idx = topk_indices[:, i] # [batch]
weight = topk_weights[:, i].unsqueeze(1) # [batch, 1]
# 批量获取专家输出(需对每个样本分别处理,或用循环)
# 这里简单循环,实际可优化为 gather
for b in range(batch_size):
expert_out = self.experts[expert_idx[b]](attn_out[b : b + 1])
expert_outputs[b] += weight[b] * expert_out.squeeze(0)
# 分类头
logits = self.classifier(expert_outputs) # [batch, output_vocab_size]
return logits
def forward_train(
    self,
    input_ids,
    token_type_ids,
    attention_mask,
    slot_target_ids,
    slot_target_mask=None,
):
    """Training mode: predict all slot steps in parallel (teacher forcing).

    Args:
        input_ids / token_type_ids / attention_mask: tokenized text batch.
        slot_target_ids: [batch, max_slot_steps] ground-truth labels used as
            teacher-forcing inputs.
        slot_target_mask: [batch, max_slot_steps] step-validity mask
            (1 = the step exists for this sample).

    Returns:
        Per-step logits of shape [batch, max_slot_steps, output_vocab_size].
    """
    batch_size, max_steps = slot_target_ids.shape
    device = input_ids.device
    # 1. Encode the text.
    context = self.encode_text(
        input_ids, token_type_ids, attention_mask
    )  # [b, L, h]
    # 2. Build slot-sequence embeddings (teacher forcing: embed the gold
    #    labels) and add learned position embeddings (positions start at 0).
    # slot_target_ids has shape [b, T]
    slot_emb = self.slot_embedding(slot_target_ids)  # [b, T, h]
    positions = torch.arange(max_steps, device=device).unsqueeze(0)  # [1, T]
    pos_emb = self.slot_position_embedding(positions)  # [1, T, h]
    slot_emb = slot_emb + pos_emb
    # Optionally zero out embeddings at invalid (masked) steps.
    if slot_target_mask is not None:
        slot_emb = slot_emb * slot_target_mask.unsqueeze(-1)
    # 3. Cross-attention: the whole slot sequence is the query, each position
    #    attends independently to the text encoding (no slot-to-slot
    #    interaction, hence no causal mask is needed here).
    attn_out, _ = self.cross_attention(
        query=slot_emb,  # [b, T, h]
        key=context,  # [b, L, h]
        value=context,
        key_padding_mask=(attention_mask == 0),  # ignore text padding
    )  # [b, T, h]
    # 4. Route every slot step through gate + experts + classifier. The steps
    #    share parameters, so batch and step dims are flattened together.
    b, T, h = attn_out.shape
    attn_flat = attn_out.view(b * T, h)  # [b*T, h]
    # Gating network: softmax over experts, keep the top-k per position.
    gate_logits = self.gate(attn_flat)  # [b*T, num_experts]
    gate_probs = F.softmax(gate_logits, dim=-1)
    topk_probs, topk_indices = torch.topk(
        gate_probs, self.top_k, dim=-1
    )  # [b*T, top_k]
    topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)  # renormalize
    # Weighted sum of expert outputs. For each of the k slots, each expert
    # processes (in one batched call) only the positions routed to it.
    expert_out_flat = torch.zeros_like(attn_flat)  # [b*T, h]
    for i in range(self.top_k):
        weight = topk_probs[:, i].unsqueeze(1)  # [b*T, 1]
        idx = topk_indices[:, i]  # [b*T]
        # Loop over experts, not samples; could be further optimized with
        # grouped ops/einsum, kept explicit here for clarity.
        for expert_id in range(self.num_experts):
            mask = idx == expert_id
            if mask.any():
                # Positions assigned to this expert on this top-k slot.
                sub_input = attn_flat[mask]  # [k, h]
                sub_output = self.experts[expert_id](sub_input)  # [k, h]
                expert_out_flat[mask] += weight[mask] * sub_output
    # Shared classifier head, then restore the [b, T, ...] layout.
    logits_flat = self.classifier(expert_out_flat)  # [b*T, output_vocab_size]
    logits = logits_flat.view(b, T, -1)  # [b, T, output_vocab_size]
    return logits
def forward(
self,
input_ids,
token_type_ids,
attention_mask,
slot_target_ids=None,
slot_target_mask=None,
mode="train",
):
"""
统一接口
mode='train': 使用 teacher forcing 并行计算所有步 logits返回 [b, T, vocab]
mode='infer': 需结合外部循环,使用 forward_single_step
"""
if mode == "train":
return self.forward_train(
input_ids,
token_type_ids,
attention_mask,
slot_target_ids,
slot_target_mask,
)
else:
raise NotImplementedError(
"Inference mode should call forward_single_step directly"
)
def to(self, device):
"""重写 to 方法,记录设备"""
self.device = device
return super().to(device)
def fit(
    self,
    train_dataloader,
    eval_dataloader=None,
    monitor=None,
    criterion=None,
    optimizer=None,
    num_epochs=1,
    stop_batch=2e5,
    eval_frequency=500,
    grad_accum_steps=1,
    clip_grad_norm=1.0,
    loss_weight=None,
    mixed_precision=True,
    weight_decay=0.1,
    warmup_ratio=0.1,
    label_smoothing=0.15,
    lr=1e-4,
    lr_schedule=None,
    save_dir=None,
    save_frequency=1000,
):
    """Train the model with teacher forcing.

    Args:
        train_dataloader: yields batches with ``batch["hint"]`` (tokenized
            text: input_ids / attention_mask / token_type_ids) and
            ``batch["char_id"]`` ([batch, max_slot_steps] labels, -1 = pad).
        eval_dataloader: optional; when set, ``self.model_eval`` is called
            every ``eval_frequency`` batches.
        monitor: optional metrics sink with ``add_step``/``finish``.
        criterion: loss; defaults to CrossEntropyLoss(ignore_index=-1).
        optimizer: defaults to AdamW over all parameters.
        lr_schedule: callable ``(lr, processed_batches, stop_batch,
            warmup_steps) -> float``; defaults to linear warmup + cosine decay.
        save_dir / save_frequency: accepted but currently unused in the
            visible code (checkpointing not implemented here).
    """

    def default_lr_schedule(_lr, _processed_batches, _stop_batch, _warmup_steps):
        # Linear warmup, then cosine decay to zero over the remaining steps.
        if _processed_batches < _warmup_steps:
            return _lr * (_processed_batches / _warmup_steps)
        progress = (_processed_batches - _warmup_steps) / (
            _stop_batch - _warmup_steps
        )
        return _lr * (0.5 * (1.0 + math.cos(math.pi * progress)))

    # Robustness fix: self.device may not exist if to() was never called.
    if getattr(self, "device", None) is None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.device)
    if lr_schedule is None:
        lr_schedule = default_lr_schedule
    self.train()
    if optimizer is None:
        optimizer = optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)
    # Labels mark invalid slots with -1, so the loss must ignore that index.
    # (weight=None is equivalent to omitting the argument.)
    if criterion is None:
        criterion = nn.CrossEntropyLoss(
            weight=loss_weight, label_smoothing=label_smoothing, ignore_index=-1
        )
    scaler = amp.GradScaler(enabled=mixed_precision)
    total_steps = stop_batch
    warmup_steps = int(total_steps * warmup_ratio)
    logger.info(f"Training Start: Steps={total_steps}, Warmup={warmup_steps}")
    processed_batches = 0
    batch_loss_sum = 0.0
    optimizer.zero_grad()
    try:
        for epoch in range(num_epochs):
            for batch in tqdm(train_dataloader, total=int(stop_batch)):
                # Bug fix: the schedule used to be called as
                # lr_schedule(lr, stop_batch, processed_batches, ...), i.e.
                # with step and total swapped relative to the default
                # schedule's signature, which broke warmup and cosine decay.
                current_lr = lr_schedule(
                    lr, processed_batches, stop_batch, warmup_steps
                )
                for param_group in optimizer.param_groups:
                    param_group["lr"] = current_lr
                # Unpack the batch onto the training device.
                input_ids = batch["hint"]["input_ids"].to(self.device)
                attention_mask = batch["hint"]["attention_mask"].to(self.device)
                token_type_ids = batch["hint"]["token_type_ids"].to(self.device)
                labels = batch["char_id"].to(self.device)  # [batch, max_slot_steps]
                # Validity mask: 1 for real slots, 0 where the label is -1.
                slot_target_mask = (labels != -1).float()
                with torch.amp.autocast(
                    device_type=self.device.type, enabled=mixed_precision
                ):
                    logits = self(
                        input_ids=input_ids,
                        token_type_ids=token_type_ids,
                        attention_mask=attention_mask,
                        slot_target_ids=labels,
                        slot_target_mask=slot_target_mask,
                        mode="train",
                    )  # [batch, max_slot_steps, output_vocab_size]
                    # Padded positions are dropped via ignore_index=-1.
                    loss = criterion(
                        logits.view(-1, self.output_vocab_size), labels.view(-1)
                    )
                loss = loss / grad_accum_steps
                scaler.scale(loss).backward()
                batch_loss_sum += loss.item() * grad_accum_steps
                # Optimizer step every grad_accum_steps batches.
                if (processed_batches + 1) % grad_accum_steps == 0:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
                # Bug fix: evaluation, the counter increment, and the stop
                # check used to live inside the accumulation branch above,
                # so with grad_accum_steps > 1 processed_batches never
                # advanced and the optimizer never stepped. They now run
                # unconditionally for every batch.
                if processed_batches % eval_frequency == 0:
                    denom = eval_frequency if processed_batches > 0 else 1
                    if eval_dataloader:
                        self.eval()
                        acc, eval_loss = self.model_eval(eval_dataloader, criterion)
                        self.train()
                        if monitor:
                            monitor.add_step(
                                processed_batches,
                                {
                                    "train_loss": batch_loss_sum / denom,
                                    "acc": acc,
                                    "loss": eval_loss,
                                    "lr": current_lr,
                                },
                            )
                        logger.info(
                            f"step: {processed_batches}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, "
                            f"batch_loss_sum: {batch_loss_sum / denom:.4f}, "
                            f"current_lr: {current_lr}"
                        )
                    else:
                        logger.info(
                            f"step: {processed_batches}, batch_loss_sum: {batch_loss_sum / denom:.4f}, "
                            f"current_lr: {current_lr}"
                        )
                    batch_loss_sum = 0.0
                processed_batches += 1
                if processed_batches >= stop_batch:
                    break
            # Bug fix: the stop condition only broke the inner loop, so each
            # additional epoch processed extra batches past stop_batch.
            if processed_batches >= stop_batch:
                break
        # Notify the monitor that training finished normally.
        if monitor:
            monitor.finish()
    except KeyboardInterrupt:
        logger.info("Training interrupted by user")
def load_from_state_dict(self, state_dict_path: Union[str, Path]):
state_dict = torch.load(
state_dict_path, weights_only=True, map_location=self.device
)
self.load_state_dict(state_dict)
def load_from_pretrained_base_model(
self,
BaseModel,
snapshot_path: Union[str, Path],
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
*args,
**kwargs,
):
base_model = BaseModel(*args, **kwargs)
base_model.load_state_dict(torch.load(snapshot_path, map_location=device))
self_static_dict = self.state_dict()
pretrained_dict = base_model.state_dict()
freeze_layers = []
for key in self_static_dict.keys():
if key in pretrained_dict.keys():
if self_static_dict[key].shape == pretrained_dict[key].shape:
self_static_dict[key] = pretrained_dict[key].to(self.device)
freeze_layers.append(key)
self.load_state_dict(self_static_dict)
for name, param in self.named_parameters():
if name in freeze_layers:
param.requires_grad = False