import math
from pathlib import Path
from typing import Optional, Union

import torch
import torch.nn as nn
import torch.amp as amp
import torch.optim as optim
import torch.nn.functional as F

from loguru import logger
from modelscope import AutoTokenizer
from tqdm.notebook import tqdm

from .components import AttentionPooling, Expert  # , ResidualBlock  # assumed to be implemented


class InputMethodEngine(nn.Module):
    def __init__(
        self,
        pretrained_encoder,  # pretrained encoder, already loaded and extended
        output_vocab_size: int,
        hidden_size: int = 512,  # must match the pretrained model's hidden dimension
        max_slot_steps: int = 24,
        num_experts: int = 20,
        top_k: int = 3,
        expert_res_blocks: int = 4,
        dropout: float = 0.3,
        use_attention_pooling: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_vocab_size = output_vocab_size
        self.max_slot_steps = max_slot_steps
        self.num_experts = num_experts
        self.top_k = top_k
        self.use_attention_pooling = use_attention_pooling
        self.device: Optional[torch.device] = None  # set via .to(); checked in fit()

        # Pretrained encoder
        self.encoder = pretrained_encoder

        self.slot_embedding = nn.Embedding(output_vocab_size, hidden_size)
        self.slot_position_embedding = nn.Embedding(max_slot_steps + 1, hidden_size)

        self.cross_attention = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=4,  # could be made configurable
            dropout=dropout,
            batch_first=True,
        )

        if use_attention_pooling:
            self.attention_pooling = AttentionPooling(hidden_size)

        self.gate = nn.Linear(hidden_size, num_experts)
        self.experts = nn.ModuleList(
            [
                Expert(
                    input_dim=hidden_size,
                    d_model=hidden_size,
                    num_resblocks=expert_res_blocks,
                    output_multiplier=1,
                    dropout_prob=dropout,
                )
                for _ in range(num_experts)
            ]
        )

        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size, bias=False),
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, output_vocab_size),
        )

        self._init_weights()

    def encode_text(self, input_ids, token_type_ids, attention_mask):
        """
        Encode text with the pretrained encoder.
        Note: the pretrained model's output may contain both last_hidden_state and pooler_output.
        """
        outputs = self.encoder(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        # Use last_hidden_state [batch, seq_len, hidden]
        return outputs.last_hidden_state

    def forward_single_step(self, context, slot_seq_emb, slot_seq_mask=None):
        """
        Single-step prediction: given the current slot sequence (already embedded),
        predict the distribution over the next character.
        context: [batch, seq_len, hidden] encoded text
        slot_seq_emb: [batch, current_len, hidden] embeddings of the slot sequence so far
        slot_seq_mask: [batch, current_len] mask of valid positions (1 = valid)
        Returns: [batch, output_vocab_size] logits over the output vocabulary
        """
        batch_size = slot_seq_emb.size(0)
        # Cross-attention: the slot sequence provides the query. Because decoding
        # is autoregressive, only the embedding at the last position is used here.
        last_query = slot_seq_emb[:, -1:, :]  # [batch, 1, hidden]
        attn_out, _ = self.cross_attention(
            query=last_query,
            key=context,
            value=context,
            key_padding_mask=(
                context.sum(-1) == 0
            ),  # skip padded positions; ideally the real attention_mask should be passed in
        )  # [batch, 1, hidden]
        attn_out = attn_out.squeeze(1)  # [batch, hidden]

        # Gating network: select the top-k experts
        gate_logits = self.gate(attn_out)  # [batch, num_experts]
        topk_weights, topk_indices = torch.topk(
            F.softmax(gate_logits, dim=-1), self.top_k, dim=-1
        )
        # Renormalize the top-k weights
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

        # Weighted sum of expert outputs
        expert_outputs = torch.zeros_like(attn_out)  # [batch, hidden]
        for i in range(self.top_k):
            expert_idx = topk_indices[:, i]  # [batch]
            weight = topk_weights[:, i].unsqueeze(1)  # [batch, 1]
            # Experts are applied per sample for clarity; this could be batched
            # with gather/index_select in an optimized version.
            for b in range(batch_size):
                expert_out = self.experts[expert_idx[b]](attn_out[b : b + 1])
                expert_outputs[b] += weight[b] * expert_out.squeeze(0)

        # Classification head
        logits = self.classifier(expert_outputs)  # [batch, output_vocab_size]
        return logits

    def forward_train(
        self,
        input_ids,
        token_type_ids,
        attention_mask,
        slot_target_ids,
        slot_target_mask=None,
    ):
        """
        Training mode: compute predictions for all slot steps in parallel.
        slot_target_ids: [batch, max_slot_steps] ground-truth labels (teacher forcing)
        slot_target_mask: [batch, max_slot_steps] mask of valid steps (1 = step exists)
        Returns: per-step logits [batch, max_slot_steps, output_vocab_size]
        """
        batch_size, max_steps = slot_target_ids.shape
        device = input_ids.device

        # 1. Encode the text
        context = self.encode_text(
            input_ids, token_type_ids, attention_mask
        )  # [b, L, h]

        # 2. Build the slot-sequence embeddings (teacher forcing: built from the gold labels).
        #    Each sample's slot IDs are embedded and positional encodings are added.
        #    slot_target_ids has shape [b, T].
        slot_emb = self.slot_embedding(slot_target_ids)  # [b, T, h]
        # Positional encoding (positions start at 0)
        positions = torch.arange(max_steps, device=device).unsqueeze(0)  # [1, T]
        pos_emb = self.slot_position_embedding(positions)  # [1, T, h]
        slot_emb = slot_emb + pos_emb
        # Optional: zero out invalid positions with the mask
        if slot_target_mask is not None:
            slot_emb = slot_emb * slot_target_mask.unsqueeze(-1)

        # 3. Cross-attention with the whole slot sequence as the query
        #    (each position is predicted independently). Each slot query attends
        #    only to the text encoding, not to the other slots, so no causal mask
        #    is needed; MultiheadAttention with different query and key/value suffices.
        attn_out, _ = self.cross_attention(
            query=slot_emb,  # [b, T, h]
            key=context,  # [b, L, h]
            value=context,
            key_padding_mask=(attention_mask == 0),  # skip padded text positions
        )  # [b, T, h]

        # 4. Route every slot step through gating + experts + classifier.
        #    Parameters are shared across steps, so the batch and step dimensions are merged.
        b, T, h = attn_out.shape
        attn_flat = attn_out.view(b * T, h)  # [b*T, h]

        # Gating network
        gate_logits = self.gate(attn_flat)  # [b*T, num_experts]
        gate_probs = F.softmax(gate_logits, dim=-1)
        topk_probs, topk_indices = torch.topk(
            gate_probs, self.top_k, dim=-1
        )  # [b*T, top_k]
        topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)  # renormalize

        # Expert outputs. Full vectorization is more involved, so the dispatch
        # loops over experts for clarity; a production version could use scatter
        # or einsum-based grouping.
        expert_out_flat = torch.zeros_like(attn_flat)  # [b*T, h]
        for i in range(self.top_k):
            weight = topk_probs[:, i].unsqueeze(1)  # [b*T, 1]
            idx = topk_indices[:, i]  # [b*T]
            # For each expert, gather the rows routed to it and process them in one batch.
            for expert_id in range(self.num_experts):
                mask = idx == expert_id
                if mask.any():
                    sub_input = attn_flat[mask]  # [k, h]
                    sub_output = self.experts[expert_id](sub_input)  # [k, h]
                    expert_out_flat[mask] += weight[mask] * sub_output

        # Classification head
        logits_flat = self.classifier(expert_out_flat)  # [b*T, output_vocab_size]
        logits = logits_flat.view(b, T, -1)  # [b, T, output_vocab_size]

        return logits

    def forward(
        self,
        input_ids,
        token_type_ids,
        attention_mask,
        slot_target_ids=None,
        slot_target_mask=None,
        mode="train",
    ):
        """
        Unified entry point.
        mode='train': teacher forcing; computes all step logits in parallel and returns [b, T, vocab]
        mode='infer': requires an external decoding loop that calls forward_single_step
        """
        if mode == "train":
            return self.forward_train(
                input_ids,
                token_type_ids,
                attention_mask,
                slot_target_ids,
                slot_target_mask,
            )
        else:
            raise NotImplementedError(
                "Inference mode should call forward_single_step directly"
            )

    def to(self, device):
        """Override to() so the target device is recorded on the module."""
        self.device = device
        return super().to(device)

    def fit(
        self,
        train_dataloader,
        eval_dataloader=None,
        monitor=None,
        criterion=None,
        optimizer=None,
        num_epochs=1,
        stop_batch=2e5,
        eval_frequency=500,
        grad_accum_steps=1,
        clip_grad_norm=1.0,
        loss_weight=None,
        mixed_precision=True,
        weight_decay=0.1,
        warmup_ratio=0.1,
        label_smoothing=0.15,
        lr=1e-4,
        lr_schedule=None,
        save_dir=None,
        save_frequency=1000,
    ):
        """
        Train the model with teacher forcing, linear warmup + cosine decay,
        optional mixed precision, and gradient accumulation.
        """

        def default_lr_schedule(_lr, _processed_batches, _stop_batch, _warmup_steps):
            # Linear warmup followed by cosine decay down to zero.
            if _processed_batches < _warmup_steps:
                current_lr = _lr * (_processed_batches / _warmup_steps)
            else:
                progress = (_processed_batches - _warmup_steps) / (
                    _stop_batch - _warmup_steps
                )
                current_lr = _lr * (0.5 * (1.0 + math.cos(math.pi * progress)))
            return current_lr

        if self.device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.to(self.device)
        if lr_schedule is None:
            lr_schedule = default_lr_schedule

        self.train()

        if optimizer is None:
            optimizer = optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)

        # Loss function: ignore_index=-1 because invalid label positions are marked with -1
        if criterion is None:
            if loss_weight is not None:
                criterion = nn.CrossEntropyLoss(
                    weight=loss_weight, label_smoothing=label_smoothing, ignore_index=-1
                )
            else:
                criterion = nn.CrossEntropyLoss(
                    label_smoothing=label_smoothing, ignore_index=-1
                )

        scaler = amp.GradScaler(enabled=mixed_precision)

        total_steps = stop_batch
        warmup_steps = int(total_steps * warmup_ratio)
        logger.info(f"Training Start: Steps={total_steps}, Warmup={warmup_steps}")
        processed_batches = 0
        batch_loss_sum = 0.0
        optimizer.zero_grad()

        try:
            for epoch in range(num_epochs):
                for batch_idx, batch in enumerate(
                    tqdm(train_dataloader, total=int(stop_batch))
                ):
                    # Learning-rate schedule (argument order matches default_lr_schedule)
                    current_lr = lr_schedule(
                        lr, processed_batches, stop_batch, warmup_steps
                    )
                    for param_group in optimizer.param_groups:
                        param_group["lr"] = current_lr

                    # Unpack the batch
                    input_ids = batch["hint"]["input_ids"].to(self.device)
                    attention_mask = batch["hint"]["attention_mask"].to(self.device)
                    token_type_ids = batch["hint"]["token_type_ids"].to(self.device)
                    labels = batch["char_id"].to(self.device)  # [batch, max_slot_steps]

                    # Build slot_target_mask: 1 for valid positions, 0 for invalid ones (label -1)
                    slot_target_mask = (labels != -1).float()  # [batch, max_slot_steps]

                    with torch.amp.autocast(
                        device_type=self.device.type, enabled=mixed_precision
                    ):
                        # Forward pass (training mode)
                        logits = self(
                            input_ids=input_ids,
                            token_type_ids=token_type_ids,
                            attention_mask=attention_mask,
                            slot_target_ids=labels,
                            slot_target_mask=slot_target_mask,
                            mode="train",
                        )  # logits: [batch, max_slot_steps, output_vocab_size]

                        # Loss (padding positions are ignored via ignore_index=-1 on the criterion)
                        loss = criterion(
                            logits.view(-1, self.output_vocab_size), labels.view(-1)
                        )
                        loss = loss / grad_accum_steps

                    scaler.scale(loss).backward()

                    # Gradient-accumulation update
                    if (processed_batches + 1) % grad_accum_steps == 0:
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
                        scaler.step(optimizer)
                        scaler.update()
                        optimizer.zero_grad()
                        batch_loss_sum += loss.item() * grad_accum_steps

                        # Periodic evaluation
                        if processed_batches % eval_frequency == 0:
                            if eval_dataloader:
                                self.eval()
                                acc, eval_loss = self.model_eval(eval_dataloader, criterion)
                                self.train()
                                if monitor:
                                    monitor.add_step(
                                        processed_batches,
                                        {
                                            "train_loss": batch_loss_sum
                                            / (eval_frequency if processed_batches > 0 else 1),
                                            "acc": acc,
                                            "loss": eval_loss,
                                            "lr": current_lr,
                                        },
                                    )
                                logger.info(
                                    f"step: {processed_batches}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, "
                                    f"batch_loss_sum: {batch_loss_sum / (eval_frequency if processed_batches > 0 else 1):.4f}, "
                                    f"current_lr: {current_lr}"
                                )
                            else:
                                logger.info(
                                    f"step: {processed_batches}, batch_loss_sum: {batch_loss_sum / (eval_frequency if processed_batches > 0 else 1):.4f}, "
                                    f"current_lr: {current_lr}"
                                )
                            batch_loss_sum = 0.0

                        processed_batches += 1
                        if processed_batches >= stop_batch:
                            break

                    else:
                        # Not at a gradient-accumulation boundary yet: the loss has already
                        # been backpropagated above, and processed_batches only advances on
                        # optimizer updates, so nothing else happens here.
                        pass

            # Training finished
            if monitor:
                monitor.finish()

        except KeyboardInterrupt:
            logger.info("Training interrupted by user")

    def load_from_state_dict(self, state_dict_path: Union[str, Path]):
        state_dict = torch.load(
            state_dict_path, weights_only=True, map_location=self.device
        )
        self.load_state_dict(state_dict)

    def load_from_pretrained_base_model(
        self,
        BaseModel,
        snapshot_path: Union[str, Path],
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        *args,
        **kwargs,
    ):
        # Instantiate the base model and load its snapshot.
        base_model = BaseModel(*args, **kwargs)
        base_model.load_state_dict(torch.load(snapshot_path, map_location=device))
        self_state_dict = self.state_dict()
        pretrained_dict = base_model.state_dict()

        freeze_layers = []

        # Copy every parameter whose name and shape match, and remember it so it
        # can be frozen afterwards.
        for key in self_state_dict.keys():
            if key in pretrained_dict.keys():
                if self_state_dict[key].shape == pretrained_dict[key].shape:
                    self_state_dict[key] = pretrained_dict[key].to(device)
                    freeze_layers.append(key)
        self.load_state_dict(self_state_dict)
        for name, param in self.named_parameters():
            if name in freeze_layers:
                param.requires_grad = False
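

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the training pipeline).
# It assumes the rest of this module (e.g. `_init_weights`, the `Expert` and
# `AttentionPooling` components) is available, and stands in for the real
# pretrained encoder and dataloader with a dummy encoder and random tensors.
# `DummyEncoder`, the vocabulary size, and the BOS id of 0 below are
# assumptions made for this example; run it as a module (python -m <package>.<this file>)
# so the relative import above resolves.
if __name__ == "__main__":
    from types import SimpleNamespace

    class DummyEncoder(nn.Module):
        """Stand-in for the pretrained encoder: maps token ids to hidden states."""

        def __init__(self, hidden_size: int = 512):
            super().__init__()
            self.embed = nn.Embedding(1000, hidden_size)

        def forward(self, input_ids, token_type_ids=None, attention_mask=None, return_dict=True):
            return SimpleNamespace(last_hidden_state=self.embed(input_ids))  # [b, L, h]

    vocab_size = 6000  # assumed output vocabulary size (e.g. common Chinese characters)
    engine = InputMethodEngine(
        pretrained_encoder=DummyEncoder(512),
        output_vocab_size=vocab_size,
        hidden_size=512,
        max_slot_steps=8,
        num_experts=4,
        top_k=2,
    )

    # Teacher-forced training-mode forward pass on random data.
    b, L, T = 2, 16, 8
    input_ids = torch.randint(0, 1000, (b, L))
    token_type_ids = torch.zeros(b, L, dtype=torch.long)
    attention_mask = torch.ones(b, L, dtype=torch.long)
    slot_targets = torch.randint(0, vocab_size, (b, T))
    logits = engine(
        input_ids,
        token_type_ids,
        attention_mask,
        slot_target_ids=slot_targets,
        mode="train",
    )
    print("train logits:", logits.shape)  # [b, T, vocab_size]

    # Greedy autoregressive decoding via forward_single_step, i.e. the external
    # loop that mode='infer' defers to.
    engine.eval()
    with torch.no_grad():
        context = engine.encode_text(input_ids, token_type_ids, attention_mask)
        slot_ids = torch.zeros(b, 1, dtype=torch.long)  # assumed BOS id = 0
        for _ in range(4):
            positions = torch.arange(slot_ids.size(1)).unsqueeze(0)  # [1, cur_len]
            slot_emb = engine.slot_embedding(slot_ids) + engine.slot_position_embedding(positions)
            step_logits = engine.forward_single_step(context, slot_emb)  # [b, vocab]
            next_id = step_logits.argmax(dim=-1, keepdim=True)  # [b, 1]
            slot_ids = torch.cat([slot_ids, next_id], dim=1)
    print("decoded ids:", slot_ids.tolist())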