refactor(MoELayer): 并行化前向传播以兼容 torch.compile 和 AMP
This commit is contained in:
parent
7143896f4d
commit
493bfdec1a
|
|
@@ -303,60 +303,43 @@ class MoELayer(nn.Module):
|
|||
|
||||
def forward(self, x):
|
||||
"""
|
||||
并行化 MoE 前向传播,完全兼容 torch.compile 和 AMP。
|
||||
|
||||
Args:
|
||||
x: [batch, seq_len, dim]
|
||||
Returns:
|
||||
out: [batch, seq_len, dim]
|
||||
"""
|
||||
B, L, D = x.shape
|
||||
num_tokens = B * L
|
||||
|
||||
# 1. Compute Gating Scores
|
||||
gates = self.gate(x) # [B, L, num_experts]
|
||||
# 展平输入以便处理
|
||||
x_flat = x.view(num_tokens, D) # [B*L, D]
|
||||
|
||||
# 2. Select Top-K Experts
|
||||
topk_vals, topk_indices = torch.topk(gates, self.top_k, dim=-1) # [B, L, K]
|
||||
# 1. 计算门控分数
|
||||
gates = self.gate(x_flat) # [B*L, num_experts]
|
||||
|
||||
# Normalize weights for selected experts
|
||||
weights = F.softmax(topk_vals, dim=-1) # [B, L, K]
|
||||
# 2. 选择 Top-K 专家
|
||||
topk_weights, topk_indices = torch.topk(gates, self.top_k, dim=-1) # [B*L, K]
|
||||
|
||||
# 3. Dispatch and Compute
|
||||
# Initialize output
|
||||
out = torch.zeros_like(x)
|
||||
# 归一化权重
|
||||
topk_weights = F.softmax(topk_weights, dim=-1) # [B*L, K]
|
||||
|
||||
# Reshape for easier processing: flatten batch and sequence dimensions
|
||||
x_flat = x.view(-1, D) # [B*L, D]
|
||||
weights_flat = weights.view(-1, self.top_k) # [B*L, K]
|
||||
topk_indices_flat = topk_indices.view(-1, self.top_k) # [B*L, K]
|
||||
# 3. 并行计算所有专家(消除 Python 循环中的动态控制流)
|
||||
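# torch.compile can unroll this list comprehension because the number of experts is fixed at trace time. NOTE: every expert now runs on every token (dense compute with a [B*L, num_experts, D] intermediate), unlike the removed masked per-expert routing.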
|
||||
expert_outputs = torch.stack(
|
||||
[expert(x_flat) for expert in self.experts], dim=1
|
||||
) # [B*L, num_experts, D]
|
||||
|
||||
# For each of the top-k positions
|
||||
for k in range(self.top_k):
|
||||
# Get expert indices and weights for this position
|
||||
expert_indices = topk_indices_flat[:, k] # [B*L]
|
||||
expert_weights = weights_flat[:, k].unsqueeze(-1) # [B*L, 1]
|
||||
# 4. 使用 gather 选择对应专家的输出
|
||||
# 扩展索引以匹配 expert_outputs 的维度 [B*L, num_experts, D]
|
||||
indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, D) # [B*L, K, D]
|
||||
selected_outputs = torch.gather(
|
||||
expert_outputs, 1, indices_expanded
|
||||
) # [B*L, K, D]
|
||||
# 5. 加权求和
|
||||
weighted_outputs = selected_outputs * topk_weights.unsqueeze(-1) # [B*L, K, D]
|
||||
out_flat = weighted_outputs.sum(dim=1) # [B*L, D]
|
||||
|
||||
# Process each expert separately
|
||||
for e_idx in range(self.num_experts):
|
||||
# Mask for tokens assigned to this expert at position k
|
||||
mask = expert_indices == e_idx # [B*L]
|
||||
if not mask.any():
|
||||
continue
|
||||
|
||||
# Extract tokens for this expert
|
||||
x_selected = x_flat[mask] # [N_selected, D]
|
||||
if x_selected.numel() == 0:
|
||||
continue
|
||||
|
||||
# Pass through expert
|
||||
expert_out = self.experts[e_idx](x_selected) # [N_selected, D]
|
||||
|
||||
# Apply expert weights and add to output
|
||||
weighted_out = expert_out * expert_weights[mask]
|
||||
|
||||
# Scatter back to flat output
|
||||
out_flat = out.view(-1, D)
|
||||
out_flat[mask] += weighted_out
|
||||
|
||||
# Reshape back to original shape
|
||||
out = out.view(B, L, D)
|
||||
|
||||
return out
|
||||
# 恢复原始形状
|
||||
return out_flat.view(B, L, D)
|
||||
|
|
|
|||
|
|
@@ -94,9 +94,9 @@ class InputMethodEngine(nn.Module):
|
|||
"""
|
||||
batch_size = input_ids.size(0)
|
||||
|
||||
# 处理 history_slot_ids:若为 [num_slots] 则扩展 batch 维度
|
||||
if history_slot_ids.dim() == 1:
|
||||
history_slot_ids = history_slot_ids.unsqueeze(0).expand(batch_size, -1)
|
||||
# Handle history_slot_ids: reshape to [-1, num_slots]; this is [batch_size, num_slots] only when the input is already batched
|
||||
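# Use view instead of an if-check to avoid a torch.compile graph break. NOTE(review): a 1-D [num_slots] input becomes [1, num_slots] here, not [batch_size, num_slots] as the removed unsqueeze(0).expand(batch_size, -1) produced — confirm callers always pass a batched tensor.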
|
||||
history_slot_ids = history_slot_ids.view(-1, self.num_slots)
|
||||
|
||||
# 1. 上下文编码 -> H [batch, seq_len, dim]
|
||||
# 注意:ContextEncoder.forward 接受 text_ids, pinyin_ids, mask
|
||||
|
|
@@ -111,14 +111,13 @@ class InputMethodEngine(nn.Module):
|
|||
# 4. MoE 处理 -> [batch, num_slots, dim]
|
||||
moe_out = self.moe(fused)
|
||||
|
||||
# 5. 池化与分类:对槽位维度求平均(或使用 mask 池化)
|
||||
# 这里简单平均,若需要忽略 padding 槽位,可根据 history_slot_ids 是否为 0 构造 mask
|
||||
slot_mask = (history_slot_ids != 0).float() # [batch, num_slots]
|
||||
slot_mask = slot_mask.unsqueeze(-1) # [batch, num_slots, 1]
|
||||
pooled = (moe_out * slot_mask).sum(dim=1) / (slot_mask.sum(dim=1) + 1e-8)
|
||||
# 如果所有槽位均为 padding,则降级为全局平均
|
||||
if torch.isnan(pooled).any():
|
||||
pooled = moe_out.mean(dim=1)
|
||||
# 5. 池化与分类:对槽位维度求平均(使用 mask 池化,完全兼容 torch.compile)
|
||||
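# Explicit reshapes remove any broadcasting ambiguity. For finite moe_out the 1e-8 epsilon keeps the division from producing NaN, so a row whose slots are all padding pools to a zero vector (the removed isnan fallback to a plain mean over slots is no longer applied).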
|
||||
batch_size = input_ids.size(0)
|
||||
slot_mask = (history_slot_ids != 0).float().view(batch_size, self.num_slots, 1)
|
||||
numerator = (moe_out * slot_mask).sum(dim=1) # [batch, dim]
|
||||
denominator = slot_mask.view(batch_size, -1).sum(dim=1) + 1e-8 # [batch]
|
||||
pooled = numerator / denominator.unsqueeze(-1) # [batch, dim]
|
||||
|
||||
logits = self.classifier(pooled) # [batch, vocab_size]
|
||||
return logits
|
||||
|
|
|
|||
|
|
@@ -172,8 +172,8 @@ class Trainer:
|
|||
# 设置状态文件
|
||||
self.use_tensorboard = use_tensorboard
|
||||
self.status_file = self.output_dir / status_file
|
||||
# 如果状态文件已存在,则加载已有数据
|
||||
self.training_status_data = self._load_existing_status_data()
|
||||
# 不加载历史数据,直接初始化为空列表以覆盖原有数据
|
||||
self.training_status_data = []
|
||||
|
||||
# 初始化Rich控制台
|
||||
self.console = Console()
|
||||
|
|
@@ -239,13 +239,15 @@ class Trainer:
|
|||
"""
|
||||
self.model.train()
|
||||
|
||||
# 移动数据到设备
|
||||
input_ids = batch["input_ids"].to(self.device)
|
||||
token_type_ids = batch["token_type_ids"].to(self.device)
|
||||
attention_mask = batch["attention_mask"].to(self.device)
|
||||
history_slot_ids = batch["history_slot_ids"].to(self.device)
|
||||
pinyin_ids = batch["pinyin_ids"].to(self.device)
|
||||
labels = batch["labels"].to(self.device).squeeze(-1) # [batch_size]
|
||||
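# Move the batch to the device. non_blocking=True only overlaps the copy with compute when the source tensors live in pinned host memory (e.g. DataLoader(pin_memory=True)); otherwise it behaves like an ordinary synchronous copy.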
|
||||
input_ids = batch["input_ids"].to(self.device, non_blocking=True)
|
||||
token_type_ids = batch["token_type_ids"].to(self.device, non_blocking=True)
|
||||
attention_mask = batch["attention_mask"].to(self.device, non_blocking=True)
|
||||
history_slot_ids = batch["history_slot_ids"].to(self.device, non_blocking=True)
|
||||
pinyin_ids = batch["pinyin_ids"].to(self.device, non_blocking=True)
|
||||
labels = (
|
||||
batch["labels"].to(self.device, non_blocking=True).squeeze(-1)
|
||||
) # [batch_size]
|
||||
|
||||
# 混合精度训练
|
||||
with autocast(device_type=self.device.type, enabled=self.mixed_precision):
|
||||
|
|
|
|||
Loading…
Reference in New Issue