From 6ee28e0aa579218892541acd7a39bd7debcaa5dd Mon Sep 17 00:00:00 2001
From: songsenand
Date: Thu, 9 Apr 2026 12:37:34 +0800
Subject: [PATCH] feat: parallelized MoE layer compatible with torch.compile
 and AMP
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 big_expert.py | 81 +++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 78 insertions(+), 3 deletions(-)

diff --git a/big_expert.py b/big_expert.py
index 021c2cd..aa06ed5 100644
--- a/big_expert.py
+++ b/big_expert.py
@@ -1,8 +1,83 @@
 import torch
-from model.components import MoELayer
+import torch.nn as nn
+import torch.nn.functional as F
+
+from model.components import Expert
 from model.model import InputMethodEngine
+
+
+class MoELayer(nn.Module):
+    """Mixture-of-Experts layer with a dense, parallel forward pass.
+
+    Every expert is evaluated in one batched call, so the graph contains no
+    data-dependent control flow and stays compatible with ``torch.compile``
+    and AMP autocast.
+    """
+
+    def __init__(self, input_dim=512, dim=768, num_experts=20, top_k=3, export_resblocks=4):
+        super().__init__()
+        self.num_experts = num_experts
+        self.top_k = top_k
+        self.dim = dim
+
+        # One expert network per slot (Expert comes from model.components).
+        self.experts = nn.ModuleList(
+            [
+                Expert(
+                    input_dim=input_dim,
+                    d_model=dim,
+                    num_resblocks=export_resblocks,
+                    output_multiplier=1,
+                )
+                for _ in range(num_experts)
+            ]
+        )
+
+        # Gating network: scores each token against every expert.
+        self.gate = nn.Linear(input_dim, num_experts)
+
+    def forward(self, x):
+        """Parallel MoE forward pass.
+
+        Args:
+            x: [batch, seq_len, dim]
+        Returns:
+            out: [batch, seq_len, dim]
+        """
+        B, L, D = x.shape
+        num_tokens = B * L
+
+        # Flatten tokens for per-token routing.
+        x_flat = x.view(num_tokens, D)  # [B*L, D]
+
+        # 1. Gating scores.
+        gates = self.gate(x_flat)  # [B*L, num_experts]
+
+        # 2. Select the top-k experts per token.
+        topk_weights, topk_indices = torch.topk(gates, self.top_k, dim=-1)  # [B*L, K]
+
+        # Renormalize the selected scores so the mixture weights sum to 1.
+        topk_weights = F.softmax(topk_weights, dim=-1)  # [B*L, K]
+
+        # 3. Run every expert on every token (no dynamic control flow);
+        #    torch.compile unrolls this loop since num_experts is constant.
+        expert_outputs = torch.stack(
+            [expert(x_flat) for expert in self.experts], dim=1
+        )  # [B*L, num_experts, D]
+
+        # 4. Gather the outputs of each token's selected experts.
+        indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, D)  # [B*L, K, D]
+        selected_outputs = torch.gather(
+            expert_outputs, 1, indices_expanded
+        )  # [B*L, K, D]
+
+        # 5. Weighted sum over the selected experts.
+        weighted_outputs = selected_outputs * topk_weights.unsqueeze(-1)  # [B*L, K, D]
+        out_flat = weighted_outputs.sum(dim=1)  # [B*L, D]
+
+        return out_flat.view(B, L, D)
 
 
 class BigExpert(InputMethodEngine):
     def __init__(self, *args, **kw):
         if "compile" in kw:
@@ -16,9 +91,9 @@ class BigExpert(InputMethodEngine):
         if "dim" in kw:
             dim = kw["dim"]
         else:
-            dim = 512
+            dim = 768
 
-        self.moe = MoELayer(dim=dim, num_experts=40, top_k=3)
+        self.moe = MoELayer(input_dim=512, dim=dim, num_experts=40, top_k=3)
 
         if compile:
             self.forward = torch.compile(