import torch
import torch.nn as nn
import torch.nn.functional as F

# Project-local dependencies.  NOTE(review): `Expert` was referenced but never
# imported in the original file (latent NameError at MoELayer construction);
# the commented-out import suggested it lives in model.components — confirm.
from model.components import Expert  # TODO(review): confirm module path
from model.model import InputMethodEngine


class MoELayer(nn.Module):
    """Mixture-of-Experts layer with a dense (loop-free) top-k dispatch.

    Every expert is evaluated on every token and the top-k outputs are
    gathered and mixed.  This keeps the compute graph static (no
    data-dependent control flow), which makes the layer fully compatible
    with ``torch.compile`` and AMP, at the cost of running all experts.

    Args:
        input_dim: width of the incoming token features.
        dim: expert hidden width (``d_model`` passed to each ``Expert``).
        num_experts: number of parallel experts.
        top_k: number of experts mixed per token.
        export_resblocks: ``num_resblocks`` forwarded to each ``Expert``.
    """

    def __init__(self, input_dim=512, dim=768, num_experts=20, top_k=3,
                 export_resblocks=4):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.dim = dim
        # One expert per slot.  Expert is assumed to map
        # [*, input_dim] -> [*, dim * output_multiplier] — TODO confirm
        # against components.py.
        self.experts = nn.ModuleList(
            [
                Expert(
                    input_dim=input_dim,
                    d_model=dim,
                    num_resblocks=export_resblocks,
                    output_multiplier=1,
                )
                for _ in range(num_experts)
            ]
        )
        # Gating network: scores each token against every expert.
        self.gate = nn.Linear(input_dim, num_experts)

    def forward(self, x):
        """Parallelised MoE forward pass.

        Args:
            x: [batch, seq_len, input_dim]

        Returns:
            out: [batch, seq_len, expert_output_dim]
        """
        B, L, D = x.shape
        num_tokens = B * L

        # Flatten to token-major form for a single batched pass.
        x_flat = x.view(num_tokens, D)  # [B*L, D]

        # 1. Gating scores.
        gates = self.gate(x_flat)  # [B*L, num_experts]

        # 2. Pick top-k experts per token, then normalise their weights so
        #    they sum to 1 for the mixture below.
        topk_weights, topk_indices = torch.topk(gates, self.top_k, dim=-1)
        topk_weights = F.softmax(topk_weights, dim=-1)  # [B*L, K]

        # 3. Run every expert on every token (no Python-level dynamic
        #    control flow; torch.compile unrolls this comprehension since
        #    num_experts is a compile-time constant).
        expert_outputs = torch.stack(
            [expert(x_flat) for expert in self.experts], dim=1
        )  # [B*L, num_experts, E]

        # BUG FIX: the original expanded the gather index with the *input*
        # width D.  Whenever the expert output width differs from input_dim
        # (it does with the defaults: 768 vs 512) that silently truncated
        # each expert's output to its first D columns.  Use the actual
        # output width instead; behaviour is unchanged when E == D.
        E = expert_outputs.size(-1)

        # 4. Gather the K selected experts' outputs per token.
        indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, E)
        selected_outputs = torch.gather(
            expert_outputs, 1, indices_expanded
        )  # [B*L, K, E]

        # 5. Weighted sum over the K selected experts.
        weighted_outputs = selected_outputs * topk_weights.unsqueeze(-1)
        out_flat = weighted_outputs.sum(dim=1)  # [B*L, E]

        # Restore the original batch/sequence layout.
        return out_flat.view(B, L, E)


class BigExpert(InputMethodEngine):
    """``InputMethodEngine`` variant that appends a 40-expert MoE layer and
    can optionally wrap its own forward pass in ``torch.compile``.

    Keyword Args:
        compile: if truthy, compile ``self.forward`` (the combined model,
            MoE included).  The flag is intercepted so the base class is
            always constructed uncompiled.
        dim: expert hidden width for the MoE layer (default 768).
    """

    def __init__(self, *args, **kw):
        # Intercept `compile` (renamed locally — the original shadowed the
        # builtin) so the base class never compiles just its own forward;
        # we compile the combined forward ourselves below.
        use_compile = kw.pop("compile", False)
        kw["compile"] = False
        super().__init__(*args, **kw)

        # BUG FIX: the original read `dim` from kwargs but then hard-coded
        # dim=768 when building the MoE layer, making the kwarg dead.
        # Honour the caller's value; the default is unchanged.
        dim = kw.get("dim", 768)
        self.moe = MoELayer(input_dim=512, dim=dim, num_experts=40, top_k=3)

        if use_compile:
            self.forward = torch.compile(
                self.forward,
                mode="reduce-overhead",
                fullgraph=False,
                dynamic=False,
            )