refactor(MoELayer): 并行化前向传播以兼容 torch.compile 和 AMP

This commit is contained in:
songsenand 2026-04-05 23:19:11 +08:00
parent 7143896f4d
commit 493bfdec1a
3 changed files with 48 additions and 64 deletions

View File

@@ -303,60 +303,43 @@ class MoELayer(nn.Module):
def forward(self, x):
"""
并行化 MoE 前向传播,完全兼容 torch.compile 和 AMP
Args:
x: [batch, seq_len, dim]
Returns:
out: [batch, seq_len, dim]
"""
B, L, D = x.shape
num_tokens = B * L
# 1. Compute Gating Scores
gates = self.gate(x) # [B, L, num_experts]
# 展平输入以便处理
x_flat = x.view(num_tokens, D) # [B*L, D]
# 2. Select Top-K Experts
topk_vals, topk_indices = torch.topk(gates, self.top_k, dim=-1) # [B, L, K]
# 1. 计算门控分数
gates = self.gate(x_flat) # [B*L, num_experts]
# Normalize weights for selected experts
weights = F.softmax(topk_vals, dim=-1) # [B, L, K]
# 2. 选择 Top-K 专家
topk_weights, topk_indices = torch.topk(gates, self.top_k, dim=-1) # [B*L, K]
# 3. Dispatch and Compute
# Initialize output
out = torch.zeros_like(x)
# 归一化权重
topk_weights = F.softmax(topk_weights, dim=-1) # [B*L, K]
# Reshape for easier processing: flatten batch and sequence dimensions
x_flat = x.view(-1, D) # [B*L, D]
weights_flat = weights.view(-1, self.top_k) # [B*L, K]
topk_indices_flat = topk_indices.view(-1, self.top_k) # [B*L, K]
# 3. 并行计算所有专家(消除 Python 循环中的动态控制流)
# torch.compile 会展开此列表推导式,因为 num_experts 是编译时常量
expert_outputs = torch.stack(
[expert(x_flat) for expert in self.experts], dim=1
) # [B*L, num_experts, D]
# For each of the top-k positions
for k in range(self.top_k):
# Get expert indices and weights for this position
expert_indices = topk_indices_flat[:, k] # [B*L]
expert_weights = weights_flat[:, k].unsqueeze(-1) # [B*L, 1]
# 4. 使用 gather 选择对应专家的输出
# 扩展索引以匹配 expert_outputs 的维度 [B*L, num_experts, D]
indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, D) # [B*L, K, D]
selected_outputs = torch.gather(
expert_outputs, 1, indices_expanded
) # [B*L, K, D]
# 5. 加权求和
weighted_outputs = selected_outputs * topk_weights.unsqueeze(-1) # [B*L, K, D]
out_flat = weighted_outputs.sum(dim=1) # [B*L, D]
# Process each expert separately
for e_idx in range(self.num_experts):
# Mask for tokens assigned to this expert at position k
mask = expert_indices == e_idx # [B*L]
if not mask.any():
continue
# Extract tokens for this expert
x_selected = x_flat[mask] # [N_selected, D]
if x_selected.numel() == 0:
continue
# Pass through expert
expert_out = self.experts[e_idx](x_selected) # [N_selected, D]
# Apply expert weights and add to output
weighted_out = expert_out * expert_weights[mask]
# Scatter back to flat output
out_flat = out.view(-1, D)
out_flat[mask] += weighted_out
# Reshape back to original shape
out = out.view(B, L, D)
return out
# 恢复原始形状
return out_flat.view(B, L, D)

View File

@@ -94,9 +94,9 @@ class InputMethodEngine(nn.Module):
"""
batch_size = input_ids.size(0)
# 处理 history_slot_ids:若为 [num_slots] 则扩展 batch 维度
if history_slot_ids.dim() == 1:
history_slot_ids = history_slot_ids.unsqueeze(0).expand(batch_size, -1)
# 处理 history_slot_ids:确保为 [batch_size, num_slots]
# 使用 view 替代 if 判断,避免 torch.compile 图断开
history_slot_ids = history_slot_ids.view(-1, self.num_slots)
# 1. 上下文编码 -> H [batch, seq_len, dim]
# 注意:ContextEncoder.forward 接受 text_ids, pinyin_ids, mask
@@ -111,14 +111,13 @@ class InputMethodEngine(nn.Module):
# 4. MoE 处理 -> [batch, num_slots, dim]
moe_out = self.moe(fused)
# 5. 池化与分类:对槽位维度求平均(或使用 mask 池化)
# 这里简单平均,若需要忽略 padding 槽位,可根据 history_slot_ids 是否为 0 构造 mask
slot_mask = (history_slot_ids != 0).float() # [batch, num_slots]
slot_mask = slot_mask.unsqueeze(-1) # [batch, num_slots, 1]
pooled = (moe_out * slot_mask).sum(dim=1) / (slot_mask.sum(dim=1) + 1e-8)
# 如果所有槽位均为 padding,则降级为全局平均
if torch.isnan(pooled).any():
pooled = moe_out.mean(dim=1)
# 5. 池化与分类:对槽位维度求平均(使用 mask 池化,完全兼容 torch.compile)
# 使用显式形状重塑,彻底杜绝广播歧义
batch_size = input_ids.size(0)
slot_mask = (history_slot_ids != 0).float().view(batch_size, self.num_slots, 1)
numerator = (moe_out * slot_mask).sum(dim=1) # [batch, dim]
denominator = slot_mask.view(batch_size, -1).sum(dim=1) + 1e-8 # [batch]
pooled = numerator / denominator.unsqueeze(-1) # [batch, dim]
logits = self.classifier(pooled) # [batch, vocab_size]
return logits

View File

@@ -172,8 +172,8 @@ class Trainer:
# 设置状态文件
self.use_tensorboard = use_tensorboard
self.status_file = self.output_dir / status_file
# 如果状态文件已存在,则加载已有数据
self.training_status_data = self._load_existing_status_data()
# 不加载历史数据,直接初始化为空列表以覆盖原有数据
self.training_status_data = []
# 初始化Rich控制台
self.console = Console()
@@ -239,13 +239,15 @@ class Trainer:
"""
self.model.train()
# 移动数据到设备
input_ids = batch["input_ids"].to(self.device)
token_type_ids = batch["token_type_ids"].to(self.device)
attention_mask = batch["attention_mask"].to(self.device)
history_slot_ids = batch["history_slot_ids"].to(self.device)
pinyin_ids = batch["pinyin_ids"].to(self.device)
labels = batch["labels"].to(self.device).squeeze(-1) # [batch_size]
# 移动数据到设备 (异步传输以提升 GPU 利用率)
input_ids = batch["input_ids"].to(self.device, non_blocking=True)
token_type_ids = batch["token_type_ids"].to(self.device, non_blocking=True)
attention_mask = batch["attention_mask"].to(self.device, non_blocking=True)
history_slot_ids = batch["history_slot_ids"].to(self.device, non_blocking=True)
pinyin_ids = batch["pinyin_ids"].to(self.device, non_blocking=True)
labels = (
batch["labels"].to(self.device, non_blocking=True).squeeze(-1)
) # [batch_size]
# 混合精度训练
with autocast(device_type=self.device.type, enabled=self.mixed_precision):