feat: 实现并行化 MoE 层以兼容 torch.compile 和 AMP
This commit is contained in:
parent
e1efcc75a8
commit
6ee28e0aa5
|
|
@ -1,9 +1,78 @@
|
|||
import torch
|
||||
|
||||
from model.components import MoELayer
|
||||
# from model.components import MoELayer
|
||||
from model.model import InputMethodEngine
|
||||
|
||||
|
||||
class MoELayer(nn.Module):
|
||||
def __init__(self, input_dim=512, dim=768, num_experts=20, top_k=3, export_resblocks=4):
|
||||
super().__init__()
|
||||
self.num_experts = num_experts
|
||||
self.top_k = top_k
|
||||
self.dim = dim
|
||||
|
||||
# Import Expert from your existing components
|
||||
# Assuming Expert class is defined as in components.py [2]
|
||||
self.experts = nn.ModuleList(
|
||||
[
|
||||
Expert(
|
||||
input_dim=input_dim,
|
||||
d_model=dim,
|
||||
num_resblocks=export_resblocks,
|
||||
output_multiplier=1,
|
||||
)
|
||||
for _ in range(num_experts)
|
||||
]
|
||||
)
|
||||
|
||||
# Gating Network [2]
|
||||
self.gate = nn.Linear(input_dim, num_experts)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
并行化 MoE 前向传播,完全兼容 torch.compile 和 AMP。
|
||||
|
||||
Args:
|
||||
x: [batch, seq_len, dim]
|
||||
Returns:
|
||||
out: [batch, seq_len, dim]
|
||||
"""
|
||||
B, L, D = x.shape
|
||||
num_tokens = B * L
|
||||
|
||||
# 展平输入以便处理
|
||||
x_flat = x.view(num_tokens, D) # [B*L, D]
|
||||
|
||||
# 1. 计算门控分数
|
||||
gates = self.gate(x_flat) # [B*L, num_experts]
|
||||
|
||||
# 2. 选择 Top-K 专家
|
||||
topk_weights, topk_indices = torch.topk(gates, self.top_k, dim=-1) # [B*L, K]
|
||||
|
||||
# 归一化权重
|
||||
topk_weights = F.softmax(topk_weights, dim=-1) # [B*L, K]
|
||||
|
||||
# 3. 并行计算所有专家(消除 Python 循环中的动态控制流)
|
||||
# torch.compile 会展开此列表推导式,因为 num_experts 是编译时常量
|
||||
expert_outputs = torch.stack(
|
||||
[expert(x_flat) for expert in self.experts], dim=1
|
||||
) # [B*L, num_experts, D]
|
||||
|
||||
# 4. 使用 gather 选择对应专家的输出
|
||||
# 扩展索引以匹配 expert_outputs 的维度 [B*L, num_experts, D]
|
||||
indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, D) # [B*L, K, D]
|
||||
selected_outputs = torch.gather(
|
||||
expert_outputs, 1, indices_expanded
|
||||
) # [B*L, K, D]
|
||||
# 5. 加权求和
|
||||
weighted_outputs = selected_outputs * topk_weights.unsqueeze(-1) # [B*L, K, D]
|
||||
out_flat = weighted_outputs.sum(dim=1) # [B*L, D]
|
||||
|
||||
# 恢复原始形状
|
||||
return out_flat.view(B, L, D)
|
||||
|
||||
|
||||
|
||||
class BigExpert(InputMethodEngine):
|
||||
def __init__(self, *args, **kw):
|
||||
if "compile" in kw:
|
||||
|
|
@ -16,9 +85,9 @@ class BigExpert(InputMethodEngine):
|
|||
if "dim" in kw:
|
||||
dim = kw["dim"]
|
||||
else:
|
||||
dim = 512
|
||||
dim = 768
|
||||
|
||||
self.moe = MoELayer(dim=dim, num_experts=40, top_k=3)
|
||||
self.moe = MoELayer(input_dim=512, dim=768, num_experts=40, top_k=3)
|
||||
|
||||
if compile:
|
||||
self.forward = torch.compile(
|
||||
|
|
|
|||
Loading…
Reference in New Issue