feat: 实现并行化 MoE 层以兼容 torch.compile 和 AMP
This commit is contained in:
parent
e1efcc75a8
commit
6ee28e0aa5
|
|
@ -1,9 +1,78 @@
|
||||||
import torch
import torch.nn as nn
import torch.nn.functional as F

# from model.components import MoELayer
from model.model import InputMethodEngine
|
||||||
|
|
||||||
|
|
||||||
|
class MoELayer(nn.Module):
    """Mixture-of-Experts layer with a dense (all-experts) forward pass.

    Every expert is evaluated for every token, and the top-k expert outputs
    are then selected with ``torch.gather`` — there is no data-dependent
    Python control flow, which keeps the layer compatible with
    ``torch.compile`` and AMP autocasting.

    Args:
        input_dim: feature size of incoming tokens (last dim of ``x``).
        dim: output feature size produced by each expert (``d_model``).
        num_experts: number of parallel experts.
        top_k: number of experts combined per token.
        export_resblocks: residual blocks per expert (name kept for
            backward compatibility; presumably meant "expert_resblocks").
    """

    def __init__(self, input_dim=512, dim=768, num_experts=20, top_k=3, export_resblocks=4):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.dim = dim

        # Expert comes from model.components — NOTE(review): its import is
        # currently commented out at the top of the file; restore it or this
        # constructor raises NameError.
        self.experts = nn.ModuleList(
            [
                Expert(
                    input_dim=input_dim,
                    d_model=dim,
                    num_resblocks=export_resblocks,
                    output_multiplier=1,
                )
                for _ in range(num_experts)
            ]
        )

        # Router: one logit per expert for each token.
        self.gate = nn.Linear(input_dim, num_experts)

    def forward(self, x):
        """Parallelised MoE forward pass (torch.compile / AMP friendly).

        Args:
            x: tensor of shape [batch, seq_len, input_dim].

        Returns:
            Tensor of shape [batch, seq_len, dim].
        """
        B, L, D = x.shape
        num_tokens = B * L

        # reshape (not view) so non-contiguous inputs are also accepted.
        x_flat = x.reshape(num_tokens, D)  # [B*L, input_dim]

        # 1. Router logits.
        gates = self.gate(x_flat)  # [B*L, num_experts]

        # 2. Pick the top-k experts per token and normalise their weights.
        topk_weights, topk_indices = torch.topk(gates, self.top_k, dim=-1)  # [B*L, K]
        topk_weights = F.softmax(topk_weights, dim=-1)  # [B*L, K]

        # 3. Run every expert on every token (dense MoE).  torch.compile
        #    unrolls this list comprehension because num_experts is a
        #    compile-time constant.
        expert_outputs = torch.stack(
            [expert(x_flat) for expert in self.experts], dim=1
        )  # [B*L, num_experts, dim]

        # 4. Gather the chosen experts' outputs.
        #    BUGFIX: the index must be expanded to the *expert output* width,
        #    not the input width D — torch.gather allows a smaller index on
        #    non-gather dims, so the old `expand(-1, -1, D)` silently
        #    truncated every expert output whenever input_dim != dim
        #    (the defaults are 512 vs 768).
        out_dim = expert_outputs.size(-1)
        indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, out_dim)  # [B*L, K, dim]
        selected_outputs = torch.gather(expert_outputs, 1, indices_expanded)  # [B*L, K, dim]

        # 5. Weighted combination of the selected experts.
        weighted_outputs = selected_outputs * topk_weights.unsqueeze(-1)  # [B*L, K, dim]
        out_flat = weighted_outputs.sum(dim=1)  # [B*L, dim]

        # Restore [batch, seq_len, dim] (expert width, which may differ from D).
        return out_flat.reshape(B, L, out_dim)
||||||
class BigExpert(InputMethodEngine):
|
class BigExpert(InputMethodEngine):
|
||||||
def __init__(self, *args, **kw):
|
def __init__(self, *args, **kw):
|
||||||
if "compile" in kw:
|
if "compile" in kw:
|
||||||
|
|
@ -16,9 +85,9 @@ class BigExpert(InputMethodEngine):
|
||||||
if "dim" in kw:
|
if "dim" in kw:
|
||||||
dim = kw["dim"]
|
dim = kw["dim"]
|
||||||
else:
|
else:
|
||||||
dim = 512
|
dim = 768
|
||||||
|
|
||||||
self.moe = MoELayer(dim=dim, num_experts=40, top_k=3)
|
self.moe = MoELayer(input_dim=512, dim=768, num_experts=40, top_k=3)
|
||||||
|
|
||||||
if compile:
|
if compile:
|
||||||
self.forward = torch.compile(
|
self.forward = torch.compile(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue