# Source listing metadata (page-viewer residue): 429 lines, 15 KiB, Python
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
from modelscope import AutoModel
|
||
|
||
|
||
# ---------------------------- 注意力池化模块----------------------------
|
||
class AttentionPooling(nn.Module):
    """Attention-weighted pooling over the sequence axis (dim=1).

    A single linear layer produces one score per position. An optional
    learnable per-token-type offset is added to the score, masked positions
    are pushed to a very large negative value, and the softmax of the scores
    weights a sum of the inputs over the sequence axis.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.attn = nn.Linear(hidden_size, 1)
        # Three learnable score offsets selected by token type id:
        # [text_bias, pinyin_bias, user_bias]
        self.bias = nn.Parameter(torch.zeros(3))

    def forward(self, x, mask=None, token_type_ids=None):
        """Pool ``x`` into one vector per batch element.

        Args:
            x: [batch, seq_len, hidden_size] features to pool.
            mask: optional [batch, seq_len] mask; positions where
                ``mask == 0`` are excluded from the pooling.
            token_type_ids: optional [batch, seq_len] integer tensor used to
                index the 3-entry ``bias`` parameter (values must be 0..2).

        Returns:
            [batch, hidden_size] attention-weighted sum of ``x``.
        """
        scores = self.attn(x).squeeze(-1)  # [batch, seq_len]

        if token_type_ids is not None:
            # Indexing the [3] bias with [batch, seq_len] ids yields one
            # offset per token.
            scores = scores + self.bias[token_type_ids]

        if mask is not None:
            # -1e9 is not representable in float16, where masked_fill would
            # raise; clamp the fill to the dtype's range. For float32 (and
            # wider) the fill value is still exactly -1e9.
            fill_value = max(-1e9, torch.finfo(scores.dtype).min)
            scores = scores.masked_fill(mask == 0, fill_value)

        weights = torch.softmax(scores, dim=-1)
        return torch.sum(weights.unsqueeze(-1) * x, dim=1)
|
||
|
||
|
||
# ---------------------------- 拼音LSTM编码器 ----------------------------
|
||
class PinyinLSTMEncoder(nn.Module):
    """Per-position encoder for a pinyin embedding sequence.

    A bidirectional LSTM lets every position see both left and right
    context; its concatenated forward/backward states are projected back to
    ``input_dim`` and layer-normalised.
    """

    def __init__(self, input_dim, hidden_dim=None, num_layers=2, dropout=0.2):
        super().__init__()
        # Default the per-direction width to half the input width so the
        # concatenated bidirectional output is roughly input_dim wide.
        if hidden_dim is None:
            hidden_dim = input_dim // 2

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout

        # nn.LSTM only applies dropout between stacked layers, so it is
        # switched off for a single-layer LSTM.
        inter_layer_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
            dropout=inter_layer_dropout,
        )

        self.proj = nn.Linear(2 * hidden_dim, input_dim)
        self.layer_norm = nn.LayerNorm(input_dim)

    def forward(self, x, mask=None):
        """Encode every position of the pinyin sequence.

        Args:
            x: [batch, seq_len, input_dim] pinyin embeddings.
            mask: optional [batch, seq_len] padding mask (True for valid,
                False for padding). Only the per-row count of valid
                positions is used, and packing treats the first ``count``
                steps as the valid ones.

        Returns:
            [batch, seq_len, input_dim] per-position pinyin encoding.
        """
        if mask is None:
            encoded, _ = self.lstm(x)
        else:
            padded_len = x.size(1)
            # Lengths must live on the CPU for packing; clamp to 1 because a
            # zero-length sequence cannot be packed.
            valid_counts = mask.sum(dim=1).cpu().clamp(min=1)
            packed_in = nn.utils.rnn.pack_padded_sequence(
                x, valid_counts, batch_first=True, enforce_sorted=False
            )
            packed_out, _ = self.lstm(packed_in)
            # total_length restores the original padded width.
            encoded, _ = nn.utils.rnn.pad_packed_sequence(
                packed_out, batch_first=True, total_length=padded_len
            )

        return self.layer_norm(self.proj(encoded))
|
||
|
||
|
||
# ---------------------------- 残差块 ----------------------------
|
||
class ResidualBlock(nn.Module):
    """Two-layer feed-forward block with a skip connection.

    Computes ``gelu(dropout(ln2(linear2(ln1(gelu(linear1(x)))))) + x)``;
    the width ``dim`` is unchanged.
    """

    def __init__(self, dim, dropout_prob=0.3):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.ln1 = nn.LayerNorm(dim)
        self.linear2 = nn.Linear(dim, dim)
        self.ln2 = nn.LayerNorm(dim)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        """Apply the block to ``x`` (last dimension must be ``dim``)."""
        # First half: linear -> GELU -> LayerNorm (the activation is
        # self.gelu; no self.relu exists on this module).
        hidden = self.ln1(self.gelu(self.linear1(x)))
        # Second half: linear -> LayerNorm -> dropout.
        hidden = self.dropout(self.ln2(self.linear2(hidden)))
        # Skip connection followed by a final activation.
        return self.gelu(hidden + x)
|
||
|
||
|
||
# ---------------------------- 专家网络 ----------------------------
|
||
class Expert(nn.Module):
    """Single expert network: input projection, residual stack, output head.

    The input is projected from ``input_dim`` to ``d_model``, passed through
    ``num_resblocks`` ResidualBlock layers, then mapped by a two-layer head
    to ``input_dim * output_multiplier`` features.
    """

    def __init__(
        self,
        input_dim,
        d_model=512,
        num_resblocks=4,
        output_multiplier=1,
        dropout_prob=0.3,
    ):
        super().__init__()
        self.output_dim = input_dim * output_multiplier

        self.linear_in = nn.Linear(input_dim, d_model)

        blocks = []
        for _ in range(num_resblocks):
            blocks.append(ResidualBlock(d_model, dropout_prob))
        self.res_blocks = nn.ModuleList(blocks)

        # Output head: hidden layer with GELU + dropout, then the final
        # projection to output_dim.
        self.output = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout_prob),
            nn.Linear(d_model, self.output_dim),
        )

    def forward(self, x):
        """Map ``x`` (last dimension ``input_dim``) to ``output_dim`` features."""
        hidden = self.linear_in(x)
        for res_block in self.res_blocks:
            hidden = res_block(hidden)
        return self.output(hidden)
|
||
|
||
|
||
# ------------------------------------------------------------------
|
||
# 1. 上下文编码器 (Context Encoder)
|
||
# 对应 README 4.1: 4层 Transformer, 512维, 输出 H [1]
|
||
# ------------------------------------------------------------------
|
||
class ContextEncoder(nn.Module):
    """Encode the text context and the pinyin sequence.

    Text tokens are embedded with the embedding layer of a pretrained
    backbone, summed with learned positional embeddings and passed through a
    Transformer encoder followed by LayerNorm. Pinyin tokens use their own
    embedding table and a PinyinLSTMEncoder.

    Note: ``vocab_size`` is accepted but not used by this class; the text
    embedding comes entirely from the pretrained model.
    """

    def __init__(
        self, vocab_size, pinyin_vocab_size, dim=512, n_layers=4, n_heads=4, max_len=128
    ):
        super().__init__()
        self.dim = dim

        # Embeddings. The text embedding is the `.embeddings` sub-module of
        # the pretrained backbone (loaded at construction time).
        # NOTE(review): its output width is assumed to equal `dim`, since it
        # is added to pos_emb below — confirm against the backbone config.
        backbone = AutoModel.from_pretrained("iic/nlp_structbert_backbone_lite_std")
        self.text_emb = backbone.embeddings
        self.pinyin_emb = nn.Embedding(pinyin_vocab_size, dim)
        self.pos_emb = nn.Embedding(max_len, dim)
        self.pinyin_pooling = PinyinLSTMEncoder(dim)

        # Transformer encoder over the text sequence; batch_first so inputs
        # are [B, L, D].
        layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=n_heads,
            dim_feedforward=4 * dim,
            dropout=0.1,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=n_layers)

        # Final LayerNorm for stability.
        self.ln = nn.LayerNorm(dim)

    def forward(self, text_ids, pinyin_ids, mask=None):
        """Encode text and pinyin inputs.

        Args:
            text_ids: [batch, seq_len] text token ids. seq_len must not
                exceed the ``max_len`` used to size the positional table.
            pinyin_ids: [batch, pinyin_len] pinyin token ids; id 0 is
                treated as padding.
            mask: optional [batch, seq_len] padding mask (1 for valid, 0 for
                padding).

        Returns:
            H: [batch, seq_len, dim] contextual text encoding.
            P: [batch, pinyin_len, dim] per-position pinyin encoding.
        """
        token_vecs = self.text_emb(text_ids)  # [B, seq_len, dim]

        positions = torch.arange(
            token_vecs.size(1), device=text_ids.device
        ).unsqueeze(0)
        encoder_in = token_vecs + self.pos_emb(positions)

        # The Transformer expects True at positions to ignore, i.e. the
        # inverse of the incoming "1 = valid" mask.
        padding_positions = None if mask is None else mask == 0

        H = self.ln(
            self.transformer(encoder_in, src_key_padding_mask=padding_positions)
        )

        pinyin_vecs = self.pinyin_emb(pinyin_ids)  # [B, pinyin_len, dim]
        P = self.pinyin_pooling(pinyin_vecs, mask=pinyin_ids != 0)

        return H, P
|
||
|
||
|
||
# ------------------------------------------------------------------
|
||
# 2. 槽位记忆模块 (Slot Memory)
|
||
# 对应 README 4.2: 8个槽位, 每槽3步, 拼接+位置编码 [1]
|
||
# ------------------------------------------------------------------
|
||
class SlotMemory(nn.Module):
    """Embed the flattened history-slot sequence.

    The history is ``max_slots * steps_per_slot`` token ids laid out as one
    flat sequence (24 steps with the defaults). Each id is embedded, empty
    positions (id 0) are replaced by a learnable start embedding, and a
    learned positional embedding is added.
    """

    def __init__(self, vocab_size, max_slots=8, steps_per_slot=3, dim=512):
        super().__init__()
        self.max_slots = max_slots
        self.steps_per_slot = steps_per_slot
        self.total_steps = max_slots * steps_per_slot

        # Shared embedding table for history tokens.
        self.emb = nn.Embedding(vocab_size, dim)

        # Learnable positional embeddings for the flattened sequence.
        self.pos_emb = nn.Embedding(self.total_steps, dim)

        # Learnable embedding substituted at empty (id == 0) positions.
        self.start_emb = nn.Parameter(torch.randn(1, 1, dim))

    def forward(self, history_ids):
        """Build the slot sequence representation.

        Args:
            history_ids: [batch, steps] flattened history token ids, with
                ``steps <= total_steps`` (the positional table has
                ``total_steps`` rows). Positions holding id 0 use
                ``start_emb`` instead of the token embedding.

        Returns:
            [batch, steps, dim] slot sequence representation, in the same
            dtype as the module parameters.
        """
        S = self.emb(history_ids)  # [B, steps, dim]

        # Select start_emb at empty positions. torch.where keeps the
        # parameter dtype, whereas blending with a float32 0/1 mask silently
        # promoted bf16/fp16 activations to float32.
        is_empty = (history_ids == 0).unsqueeze(-1)  # [B, steps, 1]
        S = torch.where(is_empty, self.start_emb.to(S.dtype), S)

        # [1, steps] position ids broadcast over the batch.
        pos_ids = torch.arange(S.size(1), device=S.device).unsqueeze(0)
        return S + self.pos_emb(pos_ids)
|
||
|
||
|
||
# ------------------------------------------------------------------
|
||
# 3. 交叉注意力融合 (Cross-Attention Fusion)
|
||
# 槽位同时查询文本上下文和拼音序列
|
||
# ------------------------------------------------------------------
|
||
class CrossAttentionFusion(nn.Module):
    """Dual-path cross-attention from slots to text and pinyin.

    The slot representations act as queries against two key/value sources:
      - the text context encoding (H)
      - the pinyin sequence encoding (P)
    Both paths share one query projection; their attention outputs are
    summed, projected, added back to the slots and layer-normalised. The
    design intent is that the start embedding of an empty slot learns to
    attend to the pinyin position that should be predicted next.
    """

    def __init__(self, dim=512, n_heads=4):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        assert self.head_dim * n_heads == dim, "dim must be divisible by n_heads"

        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_text_proj = nn.Linear(dim, dim, bias=False)
        self.v_text_proj = nn.Linear(dim, dim, bias=False)
        self.k_pinyin_proj = nn.Linear(dim, dim, bias=False)
        self.v_pinyin_proj = nn.Linear(dim, dim, bias=False)

        self.out_proj = nn.Linear(dim, dim, bias=False)
        self.ln = nn.LayerNorm(dim)

    def _split_heads(self, t):
        """Reshape [batch, length, dim] to [batch, n_heads, length, head_dim]."""
        batch_size, length, _ = t.shape
        return t.view(batch_size, length, self.n_heads, self.head_dim).transpose(1, 2)

    @staticmethod
    def _additive_mask(padding_mask, dtype):
        """Turn a [batch, length] padding mask into a [batch, 1, 1, length] additive mask.

        Padding positions get a large negative value, valid positions get 0.
        The mask is created in ``dtype`` (the query dtype) and the fill is
        clamped to that dtype's range, so reduced-precision runs do not feed
        a mismatched or overflowing mask to the attention kernel. For
        float32 the fill is exactly -1e9.
        """
        if padding_mask is None:
            return None
        # .bool() lets 0/1 integer or float masks work as well as bool ones.
        is_padding = padding_mask[:, None, None, :].bool()
        fill_value = max(-1e9, torch.finfo(dtype).min)
        zeros = torch.zeros(is_padding.shape, dtype=dtype, device=is_padding.device)
        return zeros.masked_fill(is_padding, fill_value)

    def forward(
        self, slots_S, context_H, pinyin_P, context_mask=None, pinyin_mask=None
    ):
        """Fuse text and pinyin information into the slot representations.

        Args:
            slots_S: [batch, num_slots, dim] slot encoding (queries).
            context_H: [batch, ctx_len, dim] text context encoding.
            pinyin_P: [batch, pinyin_len, dim] pinyin sequence encoding.
            context_mask: optional [batch, ctx_len] text padding mask
                (True / non-zero marks PADDING).
            pinyin_mask: optional [batch, pinyin_len] pinyin padding mask
                (True / non-zero marks PADDING).

        Returns:
            fused: [batch, num_slots, dim]
        """
        batch_size, num_slots, _ = slots_S.shape

        Q = self._split_heads(self.q_proj(slots_S))
        K_text = self._split_heads(self.k_text_proj(context_H))
        V_text = self._split_heads(self.v_text_proj(context_H))
        K_pinyin = self._split_heads(self.k_pinyin_proj(pinyin_P))
        V_pinyin = self._split_heads(self.v_pinyin_proj(pinyin_P))

        # The masks broadcast over heads and query positions.
        text_attn_mask = self._additive_mask(context_mask, Q.dtype)
        pinyin_attn_mask = self._additive_mask(pinyin_mask, Q.dtype)

        text_attn = F.scaled_dot_product_attention(
            Q, K_text, V_text, attn_mask=text_attn_mask
        )
        pinyin_attn = F.scaled_dot_product_attention(
            Q, K_pinyin, V_pinyin, attn_mask=pinyin_attn_mask
        )

        # Sum the two paths, then merge heads back to [batch, num_slots, dim].
        combined_attn = (
            (text_attn + pinyin_attn)
            .transpose(1, 2)
            .contiguous()
            .view(batch_size, num_slots, self.dim)
        )

        # Residual connection around the attention, then LayerNorm.
        return self.ln(self.out_proj(combined_attn) + slots_S)
|
||
|
||
|
||
# ------------------------------------------------------------------
|
||
# 4. 专家混合层 (MoE Layer)
|
||
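# Per README: 20 experts [1]; uses the Expert class (credited to components.py; an Expert class is also defined earlier in this file).
# NOTE(review): MoELayer below defaults to num_experts=10, not 20 — confirm which is intended.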
|
||
# ------------------------------------------------------------------
|
||
|
||
|
||
@torch.compiler.allow_in_graph
|
||
def _sparse_moe_dispatch(x_flat, experts, topk_indices, topk_weights, num_experts):
|
||
output = torch.zeros_like(x_flat)
|
||
for e in range(num_experts):
|
||
mask = topk_indices == e
|
||
idx, k_idx = mask.nonzero(as_tuple=True)
|
||
if idx.numel() > 0:
|
||
w = topk_weights[idx, k_idx].unsqueeze(-1)
|
||
output.index_add_(0, idx, (w * experts[e](x_flat[idx])).to(output.dtype))
|
||
return output
|
||
|
||
|
||
class MoELayer(nn.Module):
    """Top-k mixture-of-experts layer over [B, L, D] inputs.

    A linear gate scores every expert per token; the ``top_k`` highest
    scores are softmax-normalised and used to mix the chosen experts'
    outputs.

    ``moe_mode`` selects how the experts are evaluated:
      - "all": run every expert on every token and gather the selected
        outputs; keeps torch.compile graphs unbroken (default).
      - "sparse": run each expert only on the tokens routed to it
        (causes a graph break under torch.compile).
      - "sparse_allow_graph": sparse routing through
        ``_sparse_moe_dispatch``, which is registered with allow_in_graph to
        avoid the graph break.
    """

    def __init__(
        self, dim=512, num_experts=10, top_k=3, num_resblocks=8, moe_mode="all"
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.dim = dim
        self.moe_mode = moe_mode

        # Every expert maps dim -> dim (output_multiplier=1), which the
        # routing code below relies on when accumulating outputs.
        self.experts = nn.ModuleList(
            [
                Expert(
                    input_dim=dim,
                    d_model=dim,
                    num_resblocks=num_resblocks,
                    output_multiplier=1,
                )
                for _ in range(num_experts)
            ]
        )

        self.gate = nn.Linear(dim, num_experts)

    def forward(self, x):
        """Route each token through its top-k experts.

        Args:
            x: [B, L, D] token features.

        Returns:
            [B, L, D] weighted combination of the selected experts' outputs.

        Raises:
            ValueError: if ``self.moe_mode`` is not one of the supported modes.
        """
        B, L, D = x.shape
        num_tokens = B * L
        # reshape (not view) so non-contiguous inputs, e.g. the result of a
        # transpose, are accepted instead of raising.
        x_flat = x.reshape(num_tokens, D)

        gates = self.gate(x_flat)
        topk_weights, topk_indices = torch.topk(gates, self.top_k, dim=-1)
        # Normalise over the selected experts only.
        topk_weights = F.softmax(topk_weights, dim=-1)

        if self.moe_mode == "all":
            # [num_tokens, num_experts, D]: every expert on every token.
            expert_outputs = torch.stack(
                [expert(x_flat) for expert in self.experts], dim=1
            )
            # Pick the top-k experts' rows for each token.
            indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, D)
            selected_outputs = torch.gather(expert_outputs, 1, indices_expanded)
            weighted_outputs = selected_outputs * topk_weights.unsqueeze(-1)
            out_flat = weighted_outputs.sum(dim=1)

        elif self.moe_mode == "sparse":
            out_flat = torch.zeros_like(x_flat)
            for e in range(self.num_experts):
                # idx: token rows routed to expert e; k_idx: which top-k slot.
                idx, k_idx = (topk_indices == e).nonzero(as_tuple=True)
                if idx.numel() > 0:
                    w = topk_weights[idx, k_idx].unsqueeze(-1)
                    out_flat.index_add_(
                        0,
                        idx,
                        (w * self.experts[e](x_flat[idx])).to(out_flat.dtype),
                    )

        elif self.moe_mode == "sparse_allow_graph":
            out_flat = _sparse_moe_dispatch(
                x_flat, self.experts, topk_indices, topk_weights, self.num_experts
            )

        else:
            raise ValueError(f"Unknown moe_mode: {self.moe_mode}")

        return out_flat.reshape(B, L, D)
|