from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

# Building blocks implemented in components.py
from .components import (
    ContextEncoder,
    CrossAttentionFusion,
    MoELayer,
    SlotMemory,
)


class InputMethodEngine(nn.Module):
    """Input-method (IME) next-token prediction model.

    A lightweight model, designed per the project README, that chains a
    context encoder, a slot memory, cross-attention fusion, a
    mixture-of-experts (MoE) layer, attention pooling over slots and a
    linear classification head.

    Forward inputs:
        input_ids (torch.Tensor): [batch_size, seq_len] text token ids.
        token_type_ids (torch.Tensor): [batch_size, seq_len] segment types
            (e.g. prefix / suffix). Accepted for interface compatibility;
            ``forward`` in this file does not read it.
        attention_mask (torch.Tensor): [batch_size, seq_len] attention mask,
            1 marks a valid position.
        pinyin_ids (torch.Tensor): pinyin token ids; positions equal to 0 are
            masked out in the cross-attention.
        history_slot_ids (torch.Tensor): [batch_size, num_slots] or
            [num_slots] history slot ids. A single row is expanded to the
            batch size of ``input_ids``.

    Forward output:
        logits (torch.Tensor): [batch_size, vocab_size] unnormalised scores
            for the next character (no softmax applied).
    """

    def __init__(
        self,
        vocab_size: int = 10019,
        pinyin_vocab_size: int = 28,
        dim: int = 512,
        num_slots: int = 8,
        n_layers: int = 4,
        n_heads: int = 4,
        num_experts: int = 10,
        max_seq_len: int = 128,
        compile: bool = False,
        moe_mode: str = "all",  # "all" / "sparse" / "sparse_allow_graph"
    ):
        """Build all sub-modules and optionally wrap ``forward`` in torch.compile.

        Args:
            vocab_size: Size of the character vocabulary; also the width of
                the classifier output.
            pinyin_vocab_size: Size of the pinyin vocabulary passed to the
                context encoder.
            dim: Hidden width shared by every sub-module.
            num_slots: Number of history slots held by the slot memory.
            n_layers: Number of layers in the context encoder.
            n_heads: Number of attention heads (encoder and cross-attention).
            num_experts: Number of experts in the MoE layer.
            max_seq_len: Maximum sequence length passed to the context encoder.
            compile: When True, replace ``self.forward`` with a
                ``torch.compile``-d version. (The name shadows the builtin
                but is kept because it is part of the public signature.)
            moe_mode: Routing mode forwarded unchanged to ``MoELayer``.
        """
        super().__init__()
        self.dim = dim
        self.num_slots = num_slots
        self.vocab_size = vocab_size

        # 1. Context encoder.
        # Original author's note: to disable pinyin, pass pinyin_vocab_size=1
        # and keep that embedding fixed at zero.
        self.context_encoder = ContextEncoder(
            vocab_size=vocab_size,
            pinyin_vocab_size=pinyin_vocab_size,
            dim=dim,
            n_layers=n_layers,
            n_heads=n_heads,
            max_len=max_seq_len,
        )

        # 2. Slot memory: num_slots history slots, each slot holding a single
        # step (one word per slot rather than several steps).
        self.slot_memory = SlotMemory(
            vocab_size=vocab_size,
            max_slots=num_slots,
            steps_per_slot=1,
            dim=dim,
        )

        # 3. Cross-attention fusion (per the original note, the variant built
        # on F.scaled_dot_product_attention; see components.py).
        self.cross_attn = CrossAttentionFusion(dim=dim, n_heads=n_heads)

        # 4. Mixture-of-experts layer.
        self.moe = MoELayer(
            dim=dim,
            num_experts=num_experts,
            top_k=3,
            num_resblocks=12,
            moe_mode=moe_mode,
        )

        # 5. Scalar score per slot, used for attention pooling over slots.
        self.slot_attention = nn.Linear(dim, 1)

        # 6. Classification head.
        self.classifier = nn.Linear(dim, vocab_size)

        # Optional torch.compile optimisation.
        if compile:
            # NOTE(review): private inductor API; presumably clears the
            # registered Triton templates so they are not selected during
            # compilation - confirm against the installed torch version.
            from torch._inductor.select_algorithm import TritonTemplate

            TritonTemplate.all_templates.clear()
            self.forward = torch.compile(
                self.forward,
                mode="reduce-overhead",
                fullgraph=False,
                dynamic=False,
            )

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pinyin_ids: torch.Tensor,
        history_slot_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Run the model and return next-character logits.

        Args:
            input_ids: [batch_size, seq_len] text token ids.
            token_type_ids: Unused here; kept for interface compatibility.
            attention_mask: [batch_size, seq_len], 1 marks a valid position.
            pinyin_ids: Pinyin token ids; id 0 is treated as masked.
            history_slot_ids: [batch_size, num_slots] or [num_slots].

        Returns:
            Logits produced by the classifier, last dimension ``vocab_size``.
        """
        batch_size = input_ids.size(0)

        # reshape (not view) so non-contiguous inputs are accepted as well.
        history_slot_ids = history_slot_ids.reshape(-1, self.num_slots)
        # A single row of slots ([num_slots] or [1, num_slots]) is shared by
        # the whole batch, as promised in the class docstring.
        if history_slot_ids.size(0) == 1 and batch_size > 1:
            history_slot_ids = history_slot_ids.expand(batch_size, -1)

        # Encode the text context (H) and pinyin (P), then embed the slots (S).
        H, P = self.context_encoder(input_ids, pinyin_ids, mask=attention_mask)
        S = self.slot_memory(history_slot_ids)

        # Boolean masks that are True where a position must be ignored.
        context_mask = attention_mask == 0
        pinyin_mask = pinyin_ids == 0

        fused = self.cross_attn(
            S, H, P, context_mask=context_mask, pinyin_mask=pinyin_mask
        )
        moe_out = self.moe(fused)

        # Attention pooling: softmax over dim 1 (presumably the slot axis),
        # then a weighted sum collapsing that axis.
        slot_scores = self.slot_attention(moe_out).squeeze(-1)
        slot_weights = torch.softmax(slot_scores, dim=1)
        pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1)

        logits = self.classifier(pooled)
        return logits