import torch
import torch.nn as nn
import torch.nn.functional as F
from modelscope import AutoModel


# ---------------------------- Attention pooling ----------------------------
class AttentionPooling(nn.Module):
    """Pool a sequence into a single vector using learned per-token attention.

    Each token receives a scalar score from a linear layer. Optionally, a
    learnable offset selected by ``token_type_ids`` is added to the score
    (three types: text, pinyin, user/personalization). Masked positions are
    suppressed before the softmax over the sequence.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.attn = nn.Linear(hidden_size, 1)
        # Three learnable score offsets, indexed by token type:
        # [text_bias, pinyin_bias, user_bias].
        self.bias = nn.Parameter(torch.zeros(3))

    def forward(self, x, mask=None, token_type_ids=None):
        """
        Args:
            x: [batch, seq_len, hidden_size] token representations.
            mask: optional [batch, seq_len]; positions equal to 0/False are
                excluded from the attention.
            token_type_ids: optional [batch, seq_len] integer ids in {0, 1, 2}
                selecting which entry of ``self.bias`` is added to each score.

        Returns:
            [batch, hidden_size] attention-weighted sum of ``x`` over seq_len.
        """
        scores = self.attn(x).squeeze(-1)  # [batch, seq_len]
        if token_type_ids is not None:
            # Indexing the [3] bias vector yields a [batch, seq_len] tensor.
            scores = scores + self.bias[token_type_ids]
        if mask is not None:
            # -1e9 is not representable in float16, and masked_fill raises on
            # overflow (e.g. under autocast), so clamp to the dtype's minimum.
            fill_value = max(-1e9, torch.finfo(scores.dtype).min)
            scores = scores.masked_fill(mask == 0, fill_value)
        weights = torch.softmax(scores, dim=-1)
        return torch.sum(weights.unsqueeze(-1) * x, dim=1)


# ---------------------------- Pinyin LSTM encoder ----------------------------
class PinyinLSTMEncoder(nn.Module):
    """Bidirectional LSTM over pinyin embeddings, returning one vector per position.

    Because the LSTM is bidirectional, every output position sees both the
    preceding and the following context. The concatenated forward/backward
    states (2 * hidden_dim) are projected back to ``input_dim`` and
    layer-normalized.
    """

    def __init__(self, input_dim, hidden_dim=None, num_layers=2, dropout=0.2):
        super().__init__()
        self.input_dim = input_dim
        # Default keeps the concatenated bidirectional width near input_dim.
        self.hidden_dim = hidden_dim if hidden_dim is not None else input_dim // 2
        self.num_layers = num_layers
        self.dropout = dropout
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=self.hidden_dim,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
            # nn.LSTM only applies dropout between stacked layers.
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.proj = nn.Linear(self.hidden_dim * 2, input_dim)
        self.layer_norm = nn.LayerNorm(input_dim)

    def forward(self, x, mask=None):
        """
        Args:
            x: [batch, seq_len, input_dim] pinyin embeddings.
            mask: optional [batch, seq_len] padding mask (True = valid,
                False = padding). The valid length of each row is taken as
                ``mask.sum(dim=1)``, so padding is assumed to be trailing.

        Returns:
            [batch, seq_len, input_dim] per-position pinyin encoding.
        """
        if mask is None:
            output, _ = self.lstm(x)
        else:
            total_len = x.size(1)
            # pack_padded_sequence needs CPU lengths >= 1; all-padding rows
            # are treated as having length 1.
            lengths = mask.sum(dim=1).cpu().clamp(min=1)
            packed = nn.utils.rnn.pack_padded_sequence(
                x, lengths, batch_first=True, enforce_sorted=False
            )
            packed_out, _ = self.lstm(packed)
            # total_length restores the input seq_len so positions line up.
            output, _ = nn.utils.rnn.pad_packed_sequence(
                packed_out, batch_first=True, total_length=total_len
            )
        return self.layer_norm(self.proj(output))


# ---------------------------- Residual block ----------------------------
class ResidualBlock(nn.Module):
    """Two Linear + LayerNorm layers wrapped in a residual connection.

    Computes ``gelu(dropout(ln2(linear2(ln1(gelu(linear1(x)))))) + x)``.
    """

    def __init__(self, dim, dropout_prob=0.3):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.ln1 = nn.LayerNorm(dim)
        self.linear2 = nn.Linear(dim, dim)
        self.ln2 = nn.LayerNorm(dim)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        residual = x
        x = self.ln1(self.gelu(self.linear1(x)))
        x = self.dropout(self.ln2(self.linear2(x)))
        return self.gelu(x + residual)


# ---------------------------- Expert network ----------------------------
class Expert(nn.Module):
    """Feed-forward expert: input projection, a stack of residual blocks, output MLP.

    Maps ``[..., input_dim]`` to ``[..., input_dim * output_multiplier]``.
    """

    def __init__(
        self,
        input_dim,
        d_model=512,
        num_resblocks=4,
        output_multiplier=1,
        dropout_prob=0.3,
    ):
        super().__init__()
        self.output_dim = input_dim * output_multiplier
        self.linear_in = nn.Linear(input_dim, d_model)
        self.res_blocks = nn.ModuleList(
            [ResidualBlock(d_model, dropout_prob) for _ in range(num_resblocks)]
        )
        self.output = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout_prob),
            nn.Linear(d_model, self.output_dim),
        )

    def forward(self, x):
        x = self.linear_in(x)
        for block in self.res_blocks:
            x = block(x)
        return self.output(x)
# ------------------------------------------------------------------
# 1. Context Encoder
# Per README 4.1: 4-layer Transformer, 512-dim, outputs H [1]
# ------------------------------------------------------------------
class ContextEncoder(nn.Module):
    """Encode the text context into H and the pinyin sequence into P.

    Text path: the embedding module of the pretrained ModelScope backbone
    ``iic/nlp_structbert_backbone_lite_std``, plus a learned absolute position
    embedding, then a Transformer encoder and a final LayerNorm.
    Pinyin path: a dedicated embedding table followed by PinyinLSTMEncoder.
    """

    def __init__(
        self, vocab_size, pinyin_vocab_size, dim=512, n_layers=4, n_heads=4, max_len=128
    ):
        super().__init__()
        self.dim = dim

        # NOTE(review): vocab_size is unused; the text vocabulary is whatever the
        # pretrained backbone embeddings were built with. Their output width must
        # equal ``dim`` for the addition with pos_emb in forward — TODO confirm.
        # BERT-style embedding modules usually add their own position embeddings;
        # if this one does, pos_emb is a second positional signal — verify.
        self.text_emb = AutoModel.from_pretrained(
            "iic/nlp_structbert_backbone_lite_std"
        ).embeddings
        self.pinyin_emb = nn.Embedding(pinyin_vocab_size, dim)
        self.pos_emb = nn.Embedding(max_len, dim)
        # Attribute name kept for checkpoint compatibility; despite the name this
        # is the per-position BiLSTM encoder, not a pooling layer.
        self.pinyin_pooling = PinyinLSTMEncoder(dim)

        # Transformer encoder (4 layers, 4 heads by default) [1]
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=n_heads,
            dim_feedforward=dim * 4,
            dropout=0.1,
            batch_first=True,  # tensors are laid out as [B, L, D]
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        # Final LayerNorm for stability.
        self.ln = nn.LayerNorm(dim)

    def forward(self, text_ids, pinyin_ids, mask=None):
        """
        Args:
            text_ids: [batch, seq_len] text token ids; seq_len must not exceed
                the ``max_len`` given at construction.
            pinyin_ids: [batch, pinyin_len] pinyin token ids; id 0 is treated
                as padding.
            mask: optional [batch, seq_len] text padding mask
                (1 = valid, 0 = padding).

        Returns:
            H: [batch, seq_len, dim] contextual text encoding.
            P: [batch, pinyin_len, dim] per-position pinyin encoding.

        Raises:
            ValueError: if the text sequence is longer than the position table.
        """
        text_emb = self.text_emb(text_ids)  # [B, seq_len, dim]
        seq_len = text_emb.size(1)
        max_len = self.pos_emb.num_embeddings
        if seq_len > max_len:
            # Report clearly instead of failing inside the embedding lookup.
            raise ValueError(
                f"text sequence length {seq_len} exceeds max_len={max_len}"
            )
        positions = torch.arange(seq_len, device=text_ids.device).unsqueeze(0)
        x = text_emb + self.pos_emb(positions)

        # nn.TransformerEncoder expects True at positions to be ignored.
        padding_mask = None if mask is None else mask == 0
        H = self.ln(self.transformer(x, src_key_padding_mask=padding_mask))

        pinyin_emb = self.pinyin_emb(pinyin_ids)  # [B, pinyin_len, dim]
        P = self.pinyin_pooling(pinyin_emb, mask=pinyin_ids != 0)
        return H, P
# ------------------------------------------------------------------
# 2. Slot Memory
# Per README 4.2: 8 slots, 3 steps per slot, concatenated + positional encoding [1]
# ------------------------------------------------------------------
class SlotMemory(nn.Module):
    """Embed the flattened slot history and add learned positional embeddings.

    Token id 0 marks an empty step; those positions are replaced with the
    learnable ``start_emb`` vector, which serves as a learnable query.
    """

    def __init__(self, vocab_size, max_slots=8, steps_per_slot=3, dim=512):
        super().__init__()
        self.max_slots = max_slots
        self.steps_per_slot = steps_per_slot
        self.total_steps = max_slots * steps_per_slot  # 24 with the defaults [1]
        # Embedding shared by all history tokens [1]
        self.emb = nn.Embedding(vocab_size, dim)
        # One learnable position vector per step of the flattened sequence [1]
        self.pos_emb = nn.Embedding(self.total_steps, dim)
        # Learnable vector substituted at empty (id 0) steps [1]
        self.start_emb = nn.Parameter(torch.randn(1, 1, dim))

    def forward(self, history_ids):
        """
        Args:
            history_ids: [batch, seq_len] flattened history token ids with
                seq_len <= total_steps; id 0 marks an empty step.

        Returns:
            S: [batch, seq_len, dim] slot sequence representation, in the
            same dtype as the embedding weights [1].

        Raises:
            ValueError: if seq_len exceeds ``total_steps``.
        """
        seq_len = history_ids.size(1)
        if seq_len > self.total_steps:
            raise ValueError(
                f"history length {seq_len} exceeds total_steps={self.total_steps}"
            )
        token_emb = self.emb(history_ids)  # [B, L, dim]
        is_empty = (history_ids == 0).unsqueeze(-1)  # [B, L, 1]
        # torch.where keeps the embedding dtype; blending with a float32 mask
        # would silently upcast half/bfloat16 activations to float32.
        S = torch.where(is_empty, self.start_emb.to(token_emb.dtype), token_emb)
        positions = torch.arange(seq_len, device=history_ids.device)
        # [L, dim] position table broadcasts over the batch dimension.
        return S + self.pos_emb(positions)
交叉注意力融合 (Cross-Attention Fusion) # 槽位同时查询文本上下文和拼音序列 # ------------------------------------------------------------------ class CrossAttentionFusion(nn.Module): """ 双路交叉注意力: - 槽位查询文本上下文 (H) - 槽位查询拼音序列 (P) 通过注意力机制,start_emb 自动学会关注"当前应该预测的拼音位置" """ def __init__(self, dim=512, n_heads=4): super().__init__() self.dim = dim self.n_heads = n_heads self.head_dim = dim // n_heads assert self.head_dim * n_heads == dim self.q_proj = nn.Linear(dim, dim, bias=False) self.k_text_proj = nn.Linear(dim, dim, bias=False) self.v_text_proj = nn.Linear(dim, dim, bias=False) self.k_pinyin_proj = nn.Linear(dim, dim, bias=False) self.v_pinyin_proj = nn.Linear(dim, dim, bias=False) self.out_proj = nn.Linear(dim, dim, bias=False) self.ln = nn.LayerNorm(dim) def forward( self, slots_S, context_H, pinyin_P, context_mask=None, pinyin_mask=None ): """ Args: slots_S: [batch, num_slots, dim] 槽位编码 context_H: [batch, ctx_len, dim] 文本上下文编码 pinyin_P: [batch, pinyin_len, dim] 拼音序列编码 context_mask: [batch, ctx_len] 文本 padding mask (True for padding) pinyin_mask: [batch, pinyin_len] 拼音 padding mask (True for padding) Returns: fused: [batch, num_slots, dim] """ batch_size, num_slots, _ = slots_S.shape ctx_len = context_H.size(1) pinyin_len = pinyin_P.size(1) Q = self.q_proj(slots_S) K_text = self.k_text_proj(context_H) V_text = self.v_text_proj(context_H) K_pinyin = self.k_pinyin_proj(pinyin_P) V_pinyin = self.v_pinyin_proj(pinyin_P) Q = Q.view(batch_size, num_slots, self.n_heads, self.head_dim).transpose(1, 2) K_text = K_text.view( batch_size, ctx_len, self.n_heads, self.head_dim ).transpose(1, 2) V_text = V_text.view( batch_size, ctx_len, self.n_heads, self.head_dim ).transpose(1, 2) K_pinyin = K_pinyin.view( batch_size, pinyin_len, self.n_heads, self.head_dim ).transpose(1, 2) V_pinyin = V_pinyin.view( batch_size, pinyin_len, self.n_heads, self.head_dim ).transpose(1, 2) text_attn_mask = None if context_mask is not None: text_attn_mask = ( context_mask[:, None, None, :] .float() 
.masked_fill(context_mask[:, None, None, :], -1e9) ) pinyin_attn_mask = None if pinyin_mask is not None: pinyin_attn_mask = ( pinyin_mask[:, None, None, :] .float() .masked_fill(pinyin_mask[:, None, None, :], -1e9) ) text_attn = F.scaled_dot_product_attention( Q, K_text, V_text, attn_mask=text_attn_mask ) pinyin_attn = F.scaled_dot_product_attention( Q, K_pinyin, V_pinyin, attn_mask=pinyin_attn_mask ) combined_attn = text_attn + pinyin_attn combined_attn = ( combined_attn.transpose(1, 2) .contiguous() .view(batch_size, num_slots, self.dim) ) fused = self.out_proj(combined_attn) fused = self.ln(fused + slots_S) return fused # ------------------------------------------------------------------ # 4. 专家混合层 (MoE Layer) # 对应 README: 20个专家 [1], 使用 components.py 中的 Expert 类 # ------------------------------------------------------------------ @torch.compiler.allow_in_graph def _sparse_moe_dispatch(x_flat, experts, topk_indices, topk_weights, num_experts): output = torch.zeros_like(x_flat) for e in range(num_experts): mask = topk_indices == e idx, k_idx = mask.nonzero(as_tuple=True) if idx.numel() > 0: w = topk_weights[idx, k_idx].unsqueeze(-1) output.index_add_(0, idx, (w * experts[e](x_flat[idx])).to(output.dtype)) return output class MoELayer(nn.Module): """ moe_mode 支持三种策略: - "all": 计算全部专家,torch.compile 不断裂 (当前默认) - "sparse": 只计算被路由到的专家 (产生 graph break) - "sparse_allow_graph": 稀疏 MoE,通过 allow_in_graph 避免 graph break """ def __init__( self, dim=512, num_experts=10, top_k=3, num_resblocks=8, moe_mode="all" ): super().__init__() self.num_experts = num_experts self.top_k = top_k self.dim = dim self.moe_mode = moe_mode self.experts = nn.ModuleList( [ Expert( input_dim=dim, d_model=dim, num_resblocks=num_resblocks, output_multiplier=1, ) for _ in range(num_experts) ] ) self.gate = nn.Linear(dim, num_experts) def forward(self, x): B, L, D = x.shape num_tokens = B * L x_flat = x.view(num_tokens, D) gates = self.gate(x_flat) topk_weights, topk_indices = torch.topk(gates, 
self.top_k, dim=-1) topk_weights = F.softmax(topk_weights, dim=-1) if self.moe_mode == "all": expert_outputs = torch.stack( [expert(x_flat) for expert in self.experts], dim=1 ) indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, D) selected_outputs = torch.gather(expert_outputs, 1, indices_expanded) weighted_outputs = selected_outputs * topk_weights.unsqueeze(-1) out_flat = weighted_outputs.sum(dim=1) elif self.moe_mode == "sparse": out_flat = torch.zeros_like(x_flat) for e in range(self.num_experts): mask = topk_indices == e idx, k_idx = mask.nonzero(as_tuple=True) if idx.numel() > 0: w = topk_weights[idx, k_idx].unsqueeze(-1) out_flat.index_add_( 0, idx, (w * self.experts[e](x_flat[idx])).to(out_flat.dtype), ) elif self.moe_mode == "sparse_allow_graph": out_flat = _sparse_moe_dispatch( x_flat, self.experts, topk_indices, topk_weights, self.num_experts ) else: raise ValueError(f"Unknown moe_mode: {self.moe_mode}") return out_flat.view(B, L, D)