SUimeModelTraner/big_expert.py

99 lines
3.0 KiB
Python

import torch
import torch.nn.functional as F
from torch import nn

# from model.components import MoELayer
from model.components import Expert
from model.model import InputMethodEngine
class MoELayer(nn.Module):
    """Mixture-of-experts layer with top-k gating and dense expert evaluation.

    Every expert runs on every token, so there is no data-dependent routing
    or Python control flow. The outputs of each token's ``top_k``
    highest-scoring experts are then mixed with softmax-normalised gate
    weights. This trades extra compute for a static graph that
    ``torch.compile`` can trace.

    Args:
        input_dim: Size of the input's last dimension. The gate and every
            expert consume vectors of this size.
        dim: ``d_model`` passed to each ``Expert``.
        num_experts: Number of experts.
        top_k: Number of experts mixed per token. Must satisfy
            ``1 <= top_k <= num_experts``.
        export_resblocks: ``num_resblocks`` passed to each ``Expert``. The
            original parameter name is kept for backward compatibility.

    Raises:
        ValueError: If ``top_k`` is outside ``[1, num_experts]``.
    """

    def __init__(
        self,
        input_dim=512,
        dim=768,
        num_experts=20,
        top_k=3,
        export_resblocks=4,
    ):
        super().__init__()
        # Fail at construction time instead of inside torch.topk during the
        # first forward pass.
        if not 1 <= top_k <= num_experts:
            raise ValueError(
                f"top_k must be in [1, num_experts={num_experts}], got {top_k}"
            )
        self.input_dim = input_dim
        self.num_experts = num_experts
        self.top_k = top_k
        self.dim = dim
        # NOTE(review): Expert is not defined in this file; it is presumably
        # model.components.Expert. TODO confirm.
        self.experts = nn.ModuleList(
            [
                Expert(
                    input_dim=input_dim,
                    d_model=dim,
                    num_resblocks=export_resblocks,
                    output_multiplier=1,
                )
                for _ in range(num_experts)
            ]
        )
        # Gating network: one logit per expert for every token.
        self.gate = nn.Linear(input_dim, num_experts)

    def forward(self, x):
        """Route each token through its top-k experts and mix the results.

        Args:
            x: Tensor of shape ``[batch, seq_len, input_dim]``.

        Returns:
            Tensor of shape ``[batch, seq_len, out_dim]``, where ``out_dim``
            is the last dimension of each expert's output. Whether that
            equals ``dim`` or ``input_dim`` depends on ``Expert``. TODO
            confirm.
        """
        batch, seq_len, in_dim = x.shape
        num_tokens = batch * seq_len
        # reshape (unlike view) also accepts non-contiguous inputs, such as
        # transposed or sliced tensors.
        x_flat = x.reshape(num_tokens, in_dim)  # [N, input_dim]

        # 1. Gate logits and top-k selection. The softmax covers only the
        #    selected logits, so each token's k weights sum to 1.
        gate_logits = self.gate(x_flat)  # [N, num_experts]
        topk_logits, topk_indices = torch.topk(gate_logits, self.top_k, dim=-1)
        topk_weights = F.softmax(topk_logits, dim=-1)  # [N, K]

        # 2. Evaluate every expert densely (fixed trip count, graph-friendly).
        expert_outputs = torch.stack(
            [expert(x_flat) for expert in self.experts], dim=1
        )  # [N, num_experts, out_dim]
        out_dim = expert_outputs.size(-1)

        # 3. Pick each token's selected experts. The index must span the
        #    experts' OUTPUT width. torch.gather accepts an index narrower than
        #    the source in non-gather dims, so expanding to the input width
        #    would silently keep only the first input_dim features.
        gather_index = topk_indices.unsqueeze(-1).expand(-1, -1, out_dim)
        selected = torch.gather(expert_outputs, 1, gather_index)  # [N, K, out_dim]

        # 4. Weighted sum over the K selected experts.
        out_flat = (selected * topk_weights.unsqueeze(-1)).sum(dim=1)  # [N, out_dim]
        return out_flat.view(batch, seq_len, out_dim)
class BigExpert(InputMethodEngine):
    """InputMethodEngine variant whose ``moe`` attribute is a 40-expert MoELayer.

    The parent is always constructed with ``compile=False``. If the caller
    passed ``compile=True``, ``self.forward`` is wrapped with ``torch.compile``
    only after ``self.moe`` has been assigned. This is presumably so that
    compilation traces the replacement layer rather than whatever the parent
    built.

    NOTE(review): no code in this class calls ``self.moe``. It presumably
    overrides an attribute of the same name used by
    ``InputMethodEngine.forward``. TODO confirm.
    """

    def __init__(self, *args, **kw):
        # Named use_compile so the builtin compile() is not shadowed.
        use_compile = kw.pop("compile", False)
        kw["compile"] = False
        super().__init__(*args, **kw)

        # Only a keyword `dim` is seen here; a positionally passed value is
        # not picked up.
        dim = kw.get("dim", 768)
        self.moe = MoELayer(input_dim=512, dim=dim, num_experts=40, top_k=3)

        if use_compile:
            self.forward = torch.compile(
                self.forward,
                mode="reduce-overhead",
                fullgraph=False,
                dynamic=False,
            )