SUimeModelTraner/src/model/components.py

429 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
from modelscope import AutoModel
# ---------------------------- 注意力池化模块----------------------------
class AttentionPooling(nn.Module):
    """Attention pooling over the sequence dimension.

    Each position receives a scalar score from a linear layer. An optional
    learnable per-segment bias (selected by ``token_type_ids``) is added.
    Positions where ``mask == 0`` are excluded, and the softmax-weighted sum
    of ``x`` over the sequence is returned.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.attn = nn.Linear(hidden_size, 1)
        # Three learnable biases: [text_bias, pinyin_bias, user_bias].
        self.bias = nn.Parameter(torch.zeros(3))

    def forward(self, x, mask=None, token_type_ids=None):
        """
        Args:
            x: [batch, seq_len, hidden_size] token features.
            mask: optional [batch, seq_len]; positions equal to 0 are ignored.
            token_type_ids: optional [batch, seq_len] integer ids in {0, 1, 2}
                choosing which entry of ``self.bias`` is added to each score.
        Returns:
            [batch, hidden_size] pooled features.
        """
        scores = self.attn(x).squeeze(-1)  # [batch, seq_len]
        if token_type_ids is not None:
            # Advanced indexing expands the [3] bias to [batch, seq_len].
            scores = scores + self.bias[token_type_ids]
        if mask is not None:
            # Fill with the most negative finite value of the score dtype.
            # A hard-coded -1e9 is not representable in float16, so
            # masked_fill would raise for half-precision scores.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        weights = torch.softmax(scores, dim=-1)
        return torch.sum(weights.unsqueeze(-1) * x, dim=1)
# ---------------------------- 拼音LSTM编码器 ----------------------------
class PinyinLSTMEncoder(nn.Module):
    """Bidirectional LSTM encoder for pinyin sequences.

    Produces one encoding per input position. Because the LSTM is
    bidirectional, each position sees both left and right context. The
    concatenated forward/backward states are projected back to
    ``input_dim`` and layer-normalised.
    """

    def __init__(self, input_dim, hidden_dim=None, num_layers=2, dropout=0.2):
        super().__init__()
        self.input_dim = input_dim
        if hidden_dim is None:
            hidden_dim = input_dim // 2
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        # nn.LSTM applies dropout only between stacked layers, so a
        # single-layer LSTM gets 0.0.
        inter_layer_dropout = 0.0 if num_layers <= 1 else dropout
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
            dropout=inter_layer_dropout,
        )
        self.proj = nn.Linear(2 * hidden_dim, input_dim)
        self.layer_norm = nn.LayerNorm(input_dim)

    def forward(self, x, mask=None):
        """
        Args:
            x: [batch, seq_len, input_dim] pinyin embeddings.
            mask: optional [batch, seq_len]; truthy marks valid positions.
                Rows are packed by their valid count, so valid positions are
                assumed to form a prefix of each row (right padding).
        Returns:
            [batch, seq_len, input_dim] per-position pinyin encodings.
        """
        if mask is None:
            encoded, _ = self.lstm(x)
        else:
            seq_len = x.size(1)
            # pack_padded_sequence rejects zero lengths, so rows with no valid
            # position are treated as length 1.
            lengths = mask.sum(dim=1).cpu().clamp(min=1)
            packed_in = nn.utils.rnn.pack_padded_sequence(
                x, lengths, batch_first=True, enforce_sorted=False
            )
            packed_out, _ = self.lstm(packed_in)
            encoded, _ = nn.utils.rnn.pad_packed_sequence(
                packed_out, batch_first=True, total_length=seq_len
            )
        return self.layer_norm(self.proj(encoded))
# ---------------------------- 残差块 ----------------------------
class ResidualBlock(nn.Module):
    """Two-layer residual MLP block.

    Computes ``gelu(x + dropout(ln2(linear2(ln1(gelu(linear1(x)))))))``.
    """

    def __init__(self, dim, dropout_prob=0.3):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.ln1 = nn.LayerNorm(dim)
        self.linear2 = nn.Linear(dim, dim)
        self.ln2 = nn.LayerNorm(dim)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        hidden = self.ln1(self.gelu(self.linear1(x)))
        hidden = self.dropout(self.ln2(self.linear2(hidden)))
        return self.gelu(hidden + x)
# ---------------------------- 专家网络 ----------------------------
class Expert(nn.Module):
    """Feed-forward expert: input projection, residual stack, output MLP.

    Maps ``[..., input_dim]`` to ``[..., input_dim * output_multiplier]``.
    """

    def __init__(
        self,
        input_dim,
        d_model=512,
        num_resblocks=4,
        output_multiplier=1,
        dropout_prob=0.3,
    ):
        super().__init__()
        self.output_dim = input_dim * output_multiplier
        self.linear_in = nn.Linear(input_dim, d_model)
        blocks = [ResidualBlock(d_model, dropout_prob) for _ in range(num_resblocks)]
        self.res_blocks = nn.ModuleList(blocks)
        self.output = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout_prob),
            nn.Linear(d_model, self.output_dim),
        )

    def forward(self, x):
        hidden = self.linear_in(x)
        for res_block in self.res_blocks:
            hidden = res_block(hidden)
        return self.output(hidden)
# ------------------------------------------------------------------
# 1. 上下文编码器 (Context Encoder)
# 对应 README 4.1: 4层 Transformer, 512维, 输出 H [1]
# ------------------------------------------------------------------
class ContextEncoder(nn.Module):
    """Encodes the text context and the pinyin sequence.

    Text tokens pass through the embedding layer of the pretrained
    ``iic/nlp_structbert_backbone_lite_std`` backbone. A learned positional
    embedding is added, and the result is contextualised by a Transformer
    encoder (``n_layers`` layers, ``n_heads`` heads, feed-forward width
    ``4 * dim``, dropout 0.1) followed by LayerNorm. Pinyin tokens are
    embedded and encoded per position by ``PinyinLSTMEncoder``.

    Note: ``vocab_size`` is accepted for interface compatibility but is not
    used. The text vocabulary comes from the pretrained embeddings.
    """

    def __init__(
        self, vocab_size, pinyin_vocab_size, dim=512, n_layers=4, n_heads=4, max_len=128
    ):
        super().__init__()
        self.dim = dim
        # Text embeddings come from the pretrained backbone; their output
        # width must match ``dim`` for the addition with pos_emb in forward().
        self.text_emb = AutoModel.from_pretrained(
            "iic/nlp_structbert_backbone_lite_std"
        ).embeddings
        self.pinyin_emb = nn.Embedding(pinyin_vocab_size, dim)
        self.pos_emb = nn.Embedding(max_len, dim)
        self.pinyin_pooling = PinyinLSTMEncoder(dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=n_heads,
            dim_feedforward=dim * 4,
            dropout=0.1,
            batch_first=True,  # inputs are [B, L, D]
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        # Final LayerNorm for stability.
        self.ln = nn.LayerNorm(dim)

    def forward(self, text_ids, pinyin_ids, mask=None):
        """
        Args:
            text_ids: [batch, seq_len] text token ids; seq_len must not exceed
                the ``max_len`` given at construction.
            pinyin_ids: [batch, pinyin_len] pinyin token ids; id 0 is treated
                as padding.
            mask: optional [batch, seq_len] text mask, 1 for valid and 0 for
                padding.
        Returns:
            H: [batch, seq_len, dim] contextualised text encodings.
            P: [batch, pinyin_len, dim] per-position pinyin encodings.
        Raises:
            ValueError: if seq_len exceeds the positional embedding table.
        """
        text_emb = self.text_emb(text_ids)  # [B, seq_len, dim]
        seq_len = text_emb.size(1)
        max_len = self.pos_emb.num_embeddings
        if seq_len > max_len:
            # Fail with a clear message instead of an out-of-range embedding
            # lookup (which surfaces as a device-side assert on CUDA).
            raise ValueError(
                f"text sequence length {seq_len} exceeds max_len={max_len}"
            )
        pos_ids = torch.arange(seq_len, device=text_ids.device).unsqueeze(0)
        x = text_emb + self.pos_emb(pos_ids)
        # nn.TransformerEncoder expects True at padding positions.
        src_mask = None if mask is None else mask == 0
        H = self.ln(self.transformer(x, src_key_padding_mask=src_mask))
        pinyin_emb = self.pinyin_emb(pinyin_ids)  # [B, pinyin_len, dim]
        pinyin_mask = pinyin_ids != 0  # True for valid pinyin positions
        P = self.pinyin_pooling(pinyin_emb, mask=pinyin_mask)  # [B, pinyin_len, dim]
        return H, P
# ------------------------------------------------------------------
# 2. 槽位记忆模块 (Slot Memory)
# 对应 README 4.2: 8个槽位, 每槽3步, 拼接+位置编码 [1]
# ------------------------------------------------------------------
class SlotMemory(nn.Module):
    """Embeds the flattened slot-history sequence.

    All ``max_slots * steps_per_slot`` history steps share one token
    embedding table. Steps whose id is 0 are replaced by a learnable start
    embedding, and a learnable positional embedding is added to every step.
    """

    def __init__(self, vocab_size, max_slots=8, steps_per_slot=3, dim=512):
        super().__init__()
        self.max_slots = max_slots
        self.steps_per_slot = steps_per_slot
        self.total_steps = max_slots * steps_per_slot
        # Token embedding shared by all history steps.
        self.emb = nn.Embedding(vocab_size, dim)
        # One learned position per flattened step.
        self.pos_emb = nn.Embedding(self.total_steps, dim)
        # Learned stand-in for empty (id 0) steps.
        self.start_emb = nn.Parameter(torch.randn(1, 1, dim))

    def forward(self, history_ids):
        """
        Args:
            history_ids: [batch, steps] flattened history token ids, with
                steps <= total_steps; id 0 marks an empty step.
        Returns:
            [batch, steps, dim] slot sequence representation.
        """
        token_vecs = self.emb(history_ids)  # [B, steps, dim]
        is_empty = (history_ids == 0).unsqueeze(-1).float()  # [B, steps, 1]
        blended = token_vecs * (1 - is_empty) + self.start_emb * is_empty
        steps = torch.arange(blended.size(1), device=blended.device)
        positions = steps.unsqueeze(0).expand_as(history_ids)
        return blended + self.pos_emb(positions)
# ------------------------------------------------------------------
# 3. 交叉注意力融合 (Cross-Attention Fusion)
# 槽位同时查询文本上下文和拼音序列
# ------------------------------------------------------------------
class CrossAttentionFusion(nn.Module):
    """Dual cross-attention from slots to the text context and pinyin.

    Slot encodings are projected once into queries that attend to two
    key/value sources:
    - the text context encoding ``H``
    - the per-position pinyin encoding ``P``
    The two attention outputs are summed and projected, then added back to
    the slots with a LayerNorm. The intent is that a slot's start embedding
    learns to attend to the pinyin position it should predict.
    """

    def __init__(self, dim=512, n_heads=4):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        assert self.head_dim * n_heads == dim
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_text_proj = nn.Linear(dim, dim, bias=False)
        self.v_text_proj = nn.Linear(dim, dim, bias=False)
        self.k_pinyin_proj = nn.Linear(dim, dim, bias=False)
        self.v_pinyin_proj = nn.Linear(dim, dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)
        self.ln = nn.LayerNorm(dim)

    def _split_heads(self, t):
        """Reshape [batch, length, dim] to [batch, n_heads, length, head_dim]."""
        batch, length, _ = t.shape
        return t.reshape(batch, length, self.n_heads, self.head_dim).transpose(1, 2)

    @staticmethod
    def _keep_mask(pad_mask):
        """Convert a bool padding mask (True = padding) into an SDPA mask.

        Returns a [batch, 1, 1, length] bool mask where True means "attend",
        or None when ``pad_mask`` is None. A boolean mask works with any
        query dtype, unlike a float32 additive mask, which is rejected for
        half/bfloat16 queries. Rows whose keys are all padding would produce
        NaN, so those rows attend to every key instead.
        """
        if pad_mask is None:
            return None
        keep = ~pad_mask
        keep = keep | ~keep.any(dim=-1, keepdim=True)
        return keep[:, None, None, :]

    def forward(
        self, slots_S, context_H, pinyin_P, context_mask=None, pinyin_mask=None
    ):
        """
        Args:
            slots_S: [batch, num_slots, dim] slot encodings.
            context_H: [batch, ctx_len, dim] text context encodings.
            pinyin_P: [batch, pinyin_len, dim] pinyin sequence encodings.
            context_mask: optional bool [batch, ctx_len], True for padding.
            pinyin_mask: optional bool [batch, pinyin_len], True for padding.
        Returns:
            fused: [batch, num_slots, dim]
        """
        batch_size, num_slots, _ = slots_S.shape
        q = self._split_heads(self.q_proj(slots_S))
        text_out = F.scaled_dot_product_attention(
            q,
            self._split_heads(self.k_text_proj(context_H)),
            self._split_heads(self.v_text_proj(context_H)),
            attn_mask=self._keep_mask(context_mask),
        )
        pinyin_out = F.scaled_dot_product_attention(
            q,
            self._split_heads(self.k_pinyin_proj(pinyin_P)),
            self._split_heads(self.v_pinyin_proj(pinyin_P)),
            attn_mask=self._keep_mask(pinyin_mask),
        )
        # [B, heads, slots, head_dim] -> [B, slots, dim]
        merged = (text_out + pinyin_out).transpose(1, 2).reshape(
            batch_size, num_slots, self.dim
        )
        fused = self.out_proj(merged)
        return self.ln(fused + slots_S)
# ------------------------------------------------------------------
# 4. 专家混合层 (MoE Layer)
# 对应 README: 20个专家 [1], 使用 components.py 中的 Expert 类
# ------------------------------------------------------------------
@torch.compiler.allow_in_graph
def _sparse_moe_dispatch(x_flat, experts, topk_indices, topk_weights, num_experts):
output = torch.zeros_like(x_flat)
for e in range(num_experts):
mask = topk_indices == e
idx, k_idx = mask.nonzero(as_tuple=True)
if idx.numel() > 0:
w = topk_weights[idx, k_idx].unsqueeze(-1)
output.index_add_(0, idx, (w * experts[e](x_flat[idx])).to(output.dtype))
return output
class MoELayer(nn.Module):
    """Token-level mixture-of-experts layer with top-k softmax routing.

    A linear gate scores every expert for each token. The ``top_k`` highest
    logits are softmax-normalised into routing weights, and the output is
    the weighted sum of the selected experts' outputs.

    ``moe_mode`` selects how experts are evaluated:
    - "all": run every expert on every token, then gather the selected
      outputs. There is no data-dependent control flow, so torch.compile
      does not break the graph (default).
    - "sparse": run each expert only on the tokens routed to it. This is
      data-dependent and produces a graph break.
    - "sparse_allow_graph": sparse dispatch through ``_sparse_moe_dispatch``,
      which is registered with ``torch.compiler.allow_in_graph``.
    """

    _MODES = ("all", "sparse", "sparse_allow_graph")

    def __init__(
        self, dim=512, num_experts=10, top_k=3, num_resblocks=8, moe_mode="all"
    ):
        super().__init__()
        # Validate configuration up front instead of failing on first forward.
        if moe_mode not in self._MODES:
            raise ValueError(f"Unknown moe_mode: {moe_mode}")
        if top_k > num_experts:
            raise ValueError(
                f"top_k ({top_k}) cannot exceed num_experts ({num_experts})"
            )
        self.num_experts = num_experts
        self.top_k = top_k
        self.dim = dim
        self.moe_mode = moe_mode
        self.experts = nn.ModuleList(
            [
                Expert(
                    input_dim=dim,
                    d_model=dim,
                    num_resblocks=num_resblocks,
                    output_multiplier=1,
                )
                for _ in range(num_experts)
            ]
        )
        self.gate = nn.Linear(dim, num_experts)

    def _combine_dense(self, x_flat, topk_indices, topk_weights):
        """Run every expert, then gather and weight the top-k outputs."""
        expert_outputs = torch.stack(
            [expert(x_flat) for expert in self.experts], dim=1
        )  # [N, num_experts, D]
        index = topk_indices.unsqueeze(-1).expand(-1, -1, x_flat.size(-1))
        selected = torch.gather(expert_outputs, 1, index)  # [N, top_k, D]
        return (selected * topk_weights.unsqueeze(-1)).sum(dim=1)

    def _combine_sparse(self, x_flat, topk_indices, topk_weights):
        """Run each expert only on its routed tokens and scatter-add results."""
        out_flat = torch.zeros_like(x_flat)
        for e in range(self.num_experts):
            idx, k_idx = (topk_indices == e).nonzero(as_tuple=True)
            if idx.numel() > 0:
                w = topk_weights[idx, k_idx].unsqueeze(-1)
                out_flat.index_add_(
                    0, idx, (w * self.experts[e](x_flat[idx])).to(out_flat.dtype)
                )
        return out_flat

    def forward(self, x):
        """
        Args:
            x: [batch, seq_len, dim] token features (need not be contiguous).
        Returns:
            [batch, seq_len, dim] mixture output.
        Raises:
            ValueError: if ``self.moe_mode`` is not a supported mode.
        """
        B, L, D = x.shape
        # reshape (not view) so non-contiguous inputs such as transposed
        # tensors are accepted.
        x_flat = x.reshape(B * L, D)
        gates = self.gate(x_flat)
        topk_weights, topk_indices = torch.topk(gates, self.top_k, dim=-1)
        # Softmax over the selected logits only: each token's weights sum to 1.
        topk_weights = F.softmax(topk_weights, dim=-1)
        if self.moe_mode == "all":
            out_flat = self._combine_dense(x_flat, topk_indices, topk_weights)
        elif self.moe_mode == "sparse":
            out_flat = self._combine_sparse(x_flat, topk_indices, topk_weights)
        elif self.moe_mode == "sparse_allow_graph":
            out_flat = _sparse_moe_dispatch(
                x_flat, self.experts, topk_indices, topk_weights, self.num_experts
            )
        else:
            raise ValueError(f"Unknown moe_mode: {self.moe_mode}")
        return out_flat.reshape(B, L, D)