refactor(MoELayer): 并行化前向传播以兼容 torch.compile 和 AMP

This commit is contained in:
songsen, 2026-04-05 23:19:11 +08:00
parent 7143896f4d
commit 493bfdec1a
3 changed files with 48 additions and 64 deletions

View File

@@ -303,60 +303,43 @@ class MoELayer(nn.Module):
def forward(self, x):
    """Parallelized MoE forward pass, compatible with torch.compile and AMP.

    Every expert is applied to every token and the top-k results are then
    selected with ``gather`` — this removes the per-expert Python loop with
    data-dependent masking that would cause graph breaks under torch.compile.

    Args:
        x: input tensor of shape [batch, seq_len, dim]

    Returns:
        out: output tensor of shape [batch, seq_len, dim]
    """
    B, L, D = x.shape
    num_tokens = B * L
    # Flatten batch and sequence dimensions for token-level routing.
    x_flat = x.view(num_tokens, D)  # [B*L, D]

    # 1. Compute gating scores for every token.
    gates = self.gate(x_flat)  # [B*L, num_experts]

    # 2. Select the top-k experts per token and normalize their weights
    #    with a softmax over the selected logits only.
    topk_weights, topk_indices = torch.topk(gates, self.top_k, dim=-1)  # [B*L, K]
    topk_weights = F.softmax(topk_weights, dim=-1)  # [B*L, K]

    # 3. Run all experts on all tokens in parallel. The list comprehension
    #    has no dynamic control flow; torch.compile unrolls it because the
    #    number of experts is a compile-time constant.
    #    NOTE(review): this trades compute (num_experts/k times more expert
    #    FLOPs than the masked dispatch it replaces) for compile-friendliness.
    expert_outputs = torch.stack(
        [expert(x_flat) for expert in self.experts], dim=1
    )  # [B*L, num_experts, D]

    # 4. Gather each token's selected expert outputs. The index is expanded
    #    to match expert_outputs' trailing feature dimension.
    indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, D)  # [B*L, K, D]
    selected_outputs = torch.gather(expert_outputs, 1, indices_expanded)  # [B*L, K, D]

    # 5. Weighted sum over the k selected experts.
    weighted_outputs = selected_outputs * topk_weights.unsqueeze(-1)  # [B*L, K, D]
    out_flat = weighted_outputs.sum(dim=1)  # [B*L, D]

    # Restore the original [batch, seq_len, dim] shape.
    return out_flat.view(B, L, D)

View File

@@ -94,9 +94,9 @@ class InputMethodEngine(nn.Module):
""" """
batch_size = input_ids.size(0) batch_size = input_ids.size(0)
# 处理 history_slot_ids若为 [num_slots] 则扩展 batch 维度 # 处理 history_slot_ids确保为 [batch_size, num_slots]
if history_slot_ids.dim() == 1: # 使用 view 替代 if 判断,避免 torch.compile 图断开
history_slot_ids = history_slot_ids.unsqueeze(0).expand(batch_size, -1) history_slot_ids = history_slot_ids.view(-1, self.num_slots)
# 1. 上下文编码 -> H [batch, seq_len, dim] # 1. 上下文编码 -> H [batch, seq_len, dim]
# 注意ContextEncoder.forward 接受 text_ids, pinyin_ids, mask # 注意ContextEncoder.forward 接受 text_ids, pinyin_ids, mask
@@ -111,14 +111,13 @@ class InputMethodEngine(nn.Module):
# 4. MoE 处理 -> [batch, num_slots, dim] # 4. MoE 处理 -> [batch, num_slots, dim]
moe_out = self.moe(fused) moe_out = self.moe(fused)
# 5. 池化与分类:对槽位维度求平均(或使用 mask 池化) # 5. 池化与分类:对槽位维度求平均(使用 mask 池化,完全兼容 torch.compile
# 这里简单平均,若需要忽略 padding 槽位,可根据 history_slot_ids 是否为 0 构造 mask # 使用显式形状重塑,彻底杜绝广播歧义
slot_mask = (history_slot_ids != 0).float() # [batch, num_slots] batch_size = input_ids.size(0)
slot_mask = slot_mask.unsqueeze(-1) # [batch, num_slots, 1] slot_mask = (history_slot_ids != 0).float().view(batch_size, self.num_slots, 1)
pooled = (moe_out * slot_mask).sum(dim=1) / (slot_mask.sum(dim=1) + 1e-8) numerator = (moe_out * slot_mask).sum(dim=1) # [batch, dim]
# 如果所有槽位均为 padding则降级为全局平均 denominator = slot_mask.view(batch_size, -1).sum(dim=1) + 1e-8 # [batch]
if torch.isnan(pooled).any(): pooled = numerator / denominator.unsqueeze(-1) # [batch, dim]
pooled = moe_out.mean(dim=1)
logits = self.classifier(pooled) # [batch, vocab_size] logits = self.classifier(pooled) # [batch, vocab_size]
return logits return logits

View File

@@ -172,8 +172,8 @@ class Trainer:
# 设置状态文件 # 设置状态文件
self.use_tensorboard = use_tensorboard self.use_tensorboard = use_tensorboard
self.status_file = self.output_dir / status_file self.status_file = self.output_dir / status_file
# 如果状态文件已存在,则加载已有数据 # 不加载历史数据,直接初始化为空列表以覆盖原有数据
self.training_status_data = self._load_existing_status_data() self.training_status_data = []
# 初始化Rich控制台 # 初始化Rich控制台
self.console = Console() self.console = Console()
@@ -239,13 +239,15 @@ class Trainer:
""" """
self.model.train() self.model.train()
# 移动数据到设备 # 移动数据到设备 (异步传输以提升 GPU 利用率)
input_ids = batch["input_ids"].to(self.device) input_ids = batch["input_ids"].to(self.device, non_blocking=True)
token_type_ids = batch["token_type_ids"].to(self.device) token_type_ids = batch["token_type_ids"].to(self.device, non_blocking=True)
attention_mask = batch["attention_mask"].to(self.device) attention_mask = batch["attention_mask"].to(self.device, non_blocking=True)
history_slot_ids = batch["history_slot_ids"].to(self.device) history_slot_ids = batch["history_slot_ids"].to(self.device, non_blocking=True)
pinyin_ids = batch["pinyin_ids"].to(self.device) pinyin_ids = batch["pinyin_ids"].to(self.device, non_blocking=True)
labels = batch["labels"].to(self.device).squeeze(-1) # [batch_size] labels = (
batch["labels"].to(self.device, non_blocking=True).squeeze(-1)
) # [batch_size]
# 混合精度训练 # 混合精度训练
with autocast(device_type=self.device.type, enabled=self.mixed_precision): with autocast(device_type=self.device.type, enabled=self.mixed_precision):