From 493bfdec1a8e4a1dbfe6626ea1e9c0284669f967 Mon Sep 17 00:00:00 2001 From: songsenand Date: Sun, 5 Apr 2026 23:19:11 +0800 Subject: [PATCH] =?UTF-8?q?refactor(MoELayer):=20=E5=B9=B6=E8=A1=8C?= =?UTF-8?q?=E5=8C=96=E5=89=8D=E5=90=91=E4=BC=A0=E6=92=AD=E4=BB=A5=E5=85=BC?= =?UTF-8?q?=E5=AE=B9=20torch.compile=20=E5=92=8C=20AMP?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/model/components.py | 71 ++++++++++++++++------------------------- src/model/model.py | 21 ++++++------ src/model/trainer.py | 20 ++++++------ 3 files changed, 48 insertions(+), 64 deletions(-) diff --git a/src/model/components.py b/src/model/components.py index ab2e344..b5ed7a5 100644 --- a/src/model/components.py +++ b/src/model/components.py @@ -303,60 +303,43 @@ class MoELayer(nn.Module): def forward(self, x): """ + 并行化 MoE 前向传播,完全兼容 torch.compile 和 AMP。 + Args: x: [batch, seq_len, dim] Returns: out: [batch, seq_len, dim] """ B, L, D = x.shape + num_tokens = B * L - # 1. Compute Gating Scores - gates = self.gate(x) # [B, L, num_experts] + # 展平输入以便处理 + x_flat = x.view(num_tokens, D) # [B*L, D] - # 2. Select Top-K Experts - topk_vals, topk_indices = torch.topk(gates, self.top_k, dim=-1) # [B, L, K] + # 1. 计算门控分数 + gates = self.gate(x_flat) # [B*L, num_experts] - # Normalize weights for selected experts - weights = F.softmax(topk_vals, dim=-1) # [B, L, K] + # 2. 选择 Top-K 专家 + topk_weights, topk_indices = torch.topk(gates, self.top_k, dim=-1) # [B*L, K] - # 3. Dispatch and Compute - # Initialize output - out = torch.zeros_like(x) + # 归一化权重 + topk_weights = F.softmax(topk_weights, dim=-1) # [B*L, K] - # Reshape for easier processing: flatten batch and sequence dimensions - x_flat = x.view(-1, D) # [B*L, D] - weights_flat = weights.view(-1, self.top_k) # [B*L, K] - topk_indices_flat = topk_indices.view(-1, self.top_k) # [B*L, K] + # 3. 并行计算所有专家(消除 Python 循环中的动态控制流) + # torch.compile 会展开此列表推导式,因为 num_experts 是编译时常量 + expert_outputs = torch.stack( + [expert(x_flat) for expert in self.experts], dim=1 + ) # [B*L, num_experts, D] - # For each of the top-k positions - for k in range(self.top_k): - # Get expert indices and weights for this position - expert_indices = topk_indices_flat[:, k] # [B*L] - expert_weights = weights_flat[:, k].unsqueeze(-1) # [B*L, 1] + # 4. 使用 gather 选择对应专家的输出 + # 扩展索引以匹配 expert_outputs 的维度 [B*L, num_experts, D] + indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, D) # [B*L, K, D] + selected_outputs = torch.gather( + expert_outputs, 1, indices_expanded + ) # [B*L, K, D] + # 5. 加权求和 + weighted_outputs = selected_outputs * topk_weights.unsqueeze(-1) # [B*L, K, D] + out_flat = weighted_outputs.sum(dim=1) # [B*L, D] - # Process each expert separately - for e_idx in range(self.num_experts): - # Mask for tokens assigned to this expert at position k - mask = expert_indices == e_idx # [B*L] - if not mask.any(): - continue - - # Extract tokens for this expert - x_selected = x_flat[mask] # [N_selected, D] - if x_selected.numel() == 0: - continue - - # Pass through expert - expert_out = self.experts[e_idx](x_selected) # [N_selected, D] - - # Apply expert weights and add to output - weighted_out = expert_out * expert_weights[mask] - - # Scatter back to flat output - out_flat = out.view(-1, D) - out_flat[mask] += weighted_out - - # Reshape back to original shape - out = out.view(B, L, D) - - return out + # 恢复原始形状 + return out_flat.view(B, L, D) diff --git a/src/model/model.py b/src/model/model.py index ba9d492..4ff9a46 100644 --- a/src/model/model.py +++ b/src/model/model.py @@ -94,9 +94,9 @@ class InputMethodEngine(nn.Module): """ batch_size = input_ids.size(0) - # 处理 history_slot_ids:若为 [num_slots] 则扩展 batch 维度 - if history_slot_ids.dim() == 1: - history_slot_ids = history_slot_ids.unsqueeze(0).expand(batch_size, -1) + # 处理 history_slot_ids:确保为 [batch_size, num_slots] + # 使用 view 替代 if 判断,避免 torch.compile 图断开 + history_slot_ids = history_slot_ids.view(-1, self.num_slots) # 1. 上下文编码 -> H [batch, seq_len, dim] # 注意:ContextEncoder.forward 接受 text_ids, pinyin_ids, mask @@ -111,14 +111,13 @@ class InputMethodEngine(nn.Module): # 4. MoE 处理 -> [batch, num_slots, dim] moe_out = self.moe(fused) - # 5. 池化与分类:对槽位维度求平均(或使用 mask 池化) - # 这里简单平均,若需要忽略 padding 槽位,可根据 history_slot_ids 是否为 0 构造 mask - slot_mask = (history_slot_ids != 0).float() # [batch, num_slots] - slot_mask = slot_mask.unsqueeze(-1) # [batch, num_slots, 1] - pooled = (moe_out * slot_mask).sum(dim=1) / (slot_mask.sum(dim=1) + 1e-8) - # 如果所有槽位均为 padding,则降级为全局平均 - if torch.isnan(pooled).any(): - pooled = moe_out.mean(dim=1) + # 5. 池化与分类:对槽位维度求平均(使用 mask 池化,完全兼容 torch.compile) + # 使用显式形状重塑,彻底杜绝广播歧义 + batch_size = input_ids.size(0) + slot_mask = (history_slot_ids != 0).float().view(batch_size, self.num_slots, 1) + numerator = (moe_out * slot_mask).sum(dim=1) # [batch, dim] + denominator = slot_mask.view(batch_size, -1).sum(dim=1) + 1e-8 # [batch] + pooled = numerator / denominator.unsqueeze(-1) # [batch, dim] logits = self.classifier(pooled) # [batch, vocab_size] return logits diff --git a/src/model/trainer.py b/src/model/trainer.py index a411788..02f8449 100644 --- a/src/model/trainer.py +++ b/src/model/trainer.py @@ -172,8 +172,8 @@ class Trainer: # 设置状态文件 self.use_tensorboard = use_tensorboard self.status_file = self.output_dir / status_file - # 如果状态文件已存在,则加载已有数据 - self.training_status_data = self._load_existing_status_data() + # 不加载历史数据,直接初始化为空列表以覆盖原有数据 + self.training_status_data = [] # 初始化Rich控制台 self.console = Console() @@ -239,13 +239,15 @@ class Trainer: """ self.model.train() - # 移动数据到设备 - input_ids = batch["input_ids"].to(self.device) - token_type_ids = batch["token_type_ids"].to(self.device) - attention_mask = batch["attention_mask"].to(self.device) - history_slot_ids = batch["history_slot_ids"].to(self.device) - pinyin_ids = batch["pinyin_ids"].to(self.device) - labels = batch["labels"].to(self.device).squeeze(-1) # [batch_size] + # 移动数据到设备 (异步传输以提升 GPU 利用率) + input_ids = batch["input_ids"].to(self.device, non_blocking=True) + token_type_ids = batch["token_type_ids"].to(self.device, non_blocking=True) + attention_mask = batch["attention_mask"].to(self.device, non_blocking=True) + history_slot_ids = batch["history_slot_ids"].to(self.device, non_blocking=True) + pinyin_ids = batch["pinyin_ids"].to(self.device, non_blocking=True) + labels = ( + batch["labels"].to(self.device, non_blocking=True).squeeze(-1) + ) # [batch_size] # 混合精度训练 with autocast(device_type=self.device.type, enabled=self.mixed_precision):