From 3da8ae8876e553e1322f227779ac3b25fc11374c Mon Sep 17 00:00:00 2001 From: songsenand Date: Mon, 6 Apr 2026 22:53:15 +0800 Subject: [PATCH] =?UTF-8?q?feat(docs):=20=E6=B7=BB=E5=8A=A0HTTP=E9=9D=99?= =?UTF-8?q?=E6=80=81=E6=96=87=E4=BB=B6=E6=9C=8D=E5=8A=A1=E4=B8=8E=E8=BF=9C?= =?UTF-8?q?=E7=A8=8B=E7=9B=91=E6=8E=A7=E8=AF=B4=E6=98=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 60 ++++ output/training_status.json | 1 + pyproject.toml | 1 - src/model/components.py | 71 ++--- src/model/components_backup.py | 364 ++++++++++++++++++++++++ src/model/model.py | 4 +- src/model/trainer.py | 42 +-- test_compile.py | 487 +++++++++++++++++++++++++++++++++ uv.lock | 15 - 9 files changed, 942 insertions(+), 103 deletions(-) create mode 100644 output/training_status.json create mode 100644 src/model/components_backup.py create mode 100644 test_compile.py diff --git a/README.md b/README.md index ec7f7d8..b42dd85 100644 --- a/README.md +++ b/README.md @@ -626,6 +626,21 @@ monitor-training check monitor-training check ./output/training_status.json ``` +**启动HTTP静态文件服务** +```bash +# 启动HTTP静态文件服务(默认端口8080) +monitor-training serve + +# 指定状态文件路径和端口 +monitor-training serve --status-file ./output/training_status.json --port 8080 + +# 禁用CORS支持(默认启用) +monitor-training serve --no-cors + +# 指定主机地址 +monitor-training serve --host 0.0.0.0 --port 8080 +``` + #### 监控界面功能 **📊 核心指标看板** @@ -689,6 +704,18 @@ http://192.168.1.100:8080 http://your-server.com:8501 ``` +**远程HTTP监控** +```bash +# GPU服务器启动HTTP服务 +monitor-training serve --port 8080 --host 0.0.0.0 + +# 本地运行Streamlit监控,从HTTP URL读取数据 +monitor-training monitor --host 127.0.0.1 --port 8501 + +# 在Streamlit界面输入远程URL: +http://:8080/training_status.json +``` + #### 状态文件格式 状态文件 `training_status.json` 位于训练输出目录,格式如下: @@ -708,11 +735,44 @@ http://your-server.com:8501 ] ``` +#### HTTP静态文件服务与远程监控 + +针对GPU服务器只支持HTTP协议(不支持WebSockets)的环境,我们提供了HTTP静态文件服务方案,实现远程训练监控。 + +**🔧 技术特点** +- 纯HTTP协议,无需WebSockets支持 +- 原子写入机制,避免读取不完整JSON数据 +- 自动重试和JSON验证,确保数据完整性 +- CORS支持,方便跨域访问 +- 轻量级设计,不影响训练性能 + +**🚀 工作原理** +1. GPU服务器:训练进程通过原子写入机制更新`training_status.json`文件 +2. GPU服务器:运行`monitor-training serve`提供HTTP静态文件服务 +3. 本地机器:运行`monitor-training monitor`启动Streamlit监控界面 +4. 本地机器:在Streamlit界面输入HTTP URL访问远程数据 +5. Streamlit:通过HTTP轮询获取实时训练数据并展示 + +**🛡️ 数据安全** +- 原子写入:先写入临时文件,然后原子重命名,避免读取中断 +- JSON验证:HTTP服务端验证JSON格式后才返回数据 +- 临时文件处理:智能识别和读取`.tmp`临时文件 +- 重试机制:JSON解析失败时自动重试读取 + +**🌐 网络要求** +- GPU服务器:需要开放HTTP端口(默认8080) +- 本地机器:需要能访问GPU服务器的HTTP端口 +- 网络协议:纯HTTP,兼容防火墙和代理 + #### 注意事项 1. 首次监控时如果状态文件不存在,会自动创建空文件 2. 需要安装 `plotly` 依赖用于图表绘制:`pip install plotly>=5.0.0` 3. 从检查点恢复训练时会自动加载已有的状态数据 4. 建议将监控服务与训练服务部署在同一服务器,避免网络延迟 +5. HTTP服务支持原子写入,避免训练进程写入时读取不完整JSON +6. 远程监控需要确保GPU服务器防火墙开放对应HTTP端口 +7. 
建议使用`--host 0.0.0.0`参数使HTTP服务可被远程访问 + ### 6.7 评估模型(开发中) diff --git a/output/training_status.json b/output/training_status.json new file mode 100644 index 0000000..0637a08 --- /dev/null +++ b/output/training_status.json @@ -0,0 +1 @@ +[] \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index f5d8fa2..64b509b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,6 @@ dependencies = [ "streamlit>=1.56.0", "tensorboard>=2.20.0", "torch>=2.10.0", - "torchdata>=0.11.0", "transformers==5.1.0", "typer>=0.21.1", ] diff --git a/src/model/components.py b/src/model/components.py index dd87fc0..d8988f9 100644 --- a/src/model/components.py +++ b/src/model/components.py @@ -246,7 +246,7 @@ class CrossAttentionFusion(nn.Module): bool_mask = context_mask == 0 # [batch, ctx_len] bool_mask = bool_mask[:, None, None, :] # [batch, 1, 1, ctx_len] # Convert to float mask where True (padding) becomes -inf - attn_mask = bool_mask.float().masked_fill(bool_mask, float("-inf")) + attn_mask = bool_mask.float().masked_fill(bool_mask, -1e9) # Scaled dot-product attention attn_output = F.scaled_dot_product_attention( @@ -311,54 +311,35 @@ class MoELayer(nn.Module): out: [batch, seq_len, dim] """ B, L, D = x.shape + num_tokens = B * L - # 1. Compute Gating Scores - gates = self.gate(x) # [B, L, num_experts] + # 展平输入以便处理 + x_flat = x.view(num_tokens, D) # [B*L, D] - # 2. Select Top-K Experts - topk_vals, topk_indices = torch.topk(gates, self.top_k, dim=-1) # [B, L, K] + # 1. 计算门控分数 + gates = self.gate(x_flat) # [B*L, num_experts] - # Normalize weights for selected experts - weights = F.softmax(topk_vals, dim=-1) # [B, L, K] + # 2. 选择 Top-K 专家 + topk_weights, topk_indices = torch.topk(gates, self.top_k, dim=-1) # [B*L, K] - # 3. Dispatch and Compute - # Initialize output - out = torch.zeros_like(x) + # 归一化权重 + topk_weights = F.softmax(topk_weights, dim=-1) # [B*L, K] - # Reshape for easier processing: flatten batch and sequence dimensions - x_flat = x.view(-1, D) # [B*L, D] - weights_flat = weights.view(-1, self.top_k) # [B*L, K] - topk_indices_flat = topk_indices.view(-1, self.top_k) # [B*L, K] + # 3. 并行计算所有专家(消除 Python 循环中的动态控制流) + # torch.compile 会展开此列表推导式,因为 num_experts 是编译时常量 + expert_outputs = torch.stack( + [expert(x_flat) for expert in self.experts], dim=1 + ) # [B*L, num_experts, D] - # For each of the top-k positions - for k in range(self.top_k): - # Get expert indices and weights for this position - expert_indices = topk_indices_flat[:, k] # [B*L] - expert_weights = weights_flat[:, k].unsqueeze(-1) # [B*L, 1] + # 4. 使用 gather 选择对应专家的输出 + # 扩展索引以匹配 expert_outputs 的维度 [B*L, num_experts, D] + indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, D) # [B*L, K, D] + selected_outputs = torch.gather( + expert_outputs, 1, indices_expanded + ) # [B*L, K, D] + # 5. 
加权求和 + weighted_outputs = selected_outputs * topk_weights.unsqueeze(-1) # [B*L, K, D] + out_flat = weighted_outputs.sum(dim=1) # [B*L, D] - # Process each expert separately - for e_idx in range(self.num_experts): - # Mask for tokens assigned to this expert at position k - mask = expert_indices == e_idx # [B*L] - if not mask.any(): - continue - - # Extract tokens for this expert - x_selected = x_flat[mask] # [N_selected, D] - if x_selected.numel() == 0: - continue - - # Pass through expert - expert_out = self.experts[e_idx](x_selected) # [N_selected, D] - - # Apply expert weights and add to output - weighted_out = expert_out * expert_weights[mask] - - # Scatter back to flat output - out_flat = out.view(-1, D) - out_flat[mask] += weighted_out - - # Reshape back to original shape - out = out.view(B, L, D) - - return out + # 恢复原始形状 + return out_flat.view(B, L, D) diff --git a/src/model/components_backup.py b/src/model/components_backup.py new file mode 100644 index 0000000..dd87fc0 --- /dev/null +++ b/src/model/components_backup.py @@ -0,0 +1,364 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from modelscope import AutoModel + + +# ---------------------------- 注意力池化模块---------------------------- +class AttentionPooling(nn.Module): + def __init__(self, hidden_size): + super().__init__() + self.attn = nn.Linear(hidden_size, 1) + # 三个可学习偏置:文本、拼音、个性化 + self.bias = nn.Parameter(torch.zeros(3)) # [text_bias, pinyin_bias, user_bias] + + def forward(self, x, mask=None, token_type_ids=None): + scores = self.attn(x).squeeze(-1) # [batch, seq_len] + if token_type_ids is not None: + # 根据 token_type_ids 添加对应偏置 + # bias 形状 [3],通过索引扩展为 [batch, seq_len] + bias_per_token = self.bias[token_type_ids] # [batch, seq_len] + scores = scores + bias_per_token + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e9) + weights = torch.softmax(scores, dim=-1) + pooled = torch.sum(weights.unsqueeze(-1) * x, dim=1) + return pooled + + +# ---------------------------- 残差块 ---------------------------- +class ResidualBlock(nn.Module): + def __init__(self, dim, dropout_prob=0.3): + super().__init__() + self.linear1 = nn.Linear(dim, dim) + self.ln1 = nn.LayerNorm(dim) + self.linear2 = nn.Linear(dim, dim) + self.ln2 = nn.LayerNorm(dim) + self.gelu = nn.GELU() + self.dropout = nn.Dropout(dropout_prob) + + def forward(self, x): + residual = x + # 修复:使用 self.gelu 而不是未定义的 self.relu + x = self.gelu(self.linear1(x)) + x = self.ln1(x) + x = self.linear2(x) + x = self.ln2(x) + x = self.dropout(x) + x = x + residual + return self.gelu(x) + + +# ---------------------------- 专家网络 ---------------------------- +class Expert(nn.Module): + def __init__( + self, + input_dim, + d_model=512, + num_resblocks=4, + output_multiplier=1, + dropout_prob=0.3, + ): + super().__init__() + self.output_dim = input_dim * output_multiplier + self.linear_in = nn.Linear(input_dim, d_model) + self.res_blocks = nn.ModuleList( + [ResidualBlock(d_model, dropout_prob) for _ in range(num_resblocks)] + ) + self.output = nn.Sequential( + nn.Linear(d_model, d_model), + nn.GELU(), + nn.Dropout(dropout_prob), + nn.Linear(d_model, self.output_dim), + ) + + def forward(self, x): + x = self.linear_in(x) + for block in self.res_blocks: + x = block(x) + return self.output(x) + + +# ------------------------------------------------------------------ +# 1. 
上下文编码器 (Context Encoder) +# 对应 README 4.1: 4层 Transformer, 512维, 输出 H [1] +# ------------------------------------------------------------------ +class ContextEncoder(nn.Module): + def __init__( + self, vocab_size, pinyin_vocab_size, dim=512, n_layers=4, n_heads=4, max_len=128 + ): + super().__init__() + self.dim = dim + + # Embeddings + self.text_emb = AutoModel.from_pretrained( + "iic/nlp_structbert_backbone_lite_std" + ).embeddings + self.pinyin_emb = nn.Embedding(pinyin_vocab_size, dim) + self.pos_emb = nn.Embedding(max_len, dim) + self.pinyin_pooling = AttentionPooling(dim) + + # Transformer Encoder (4 layers, 4 heads) [1] + encoder_layer = nn.TransformerEncoderLayer( + d_model=dim, + nhead=n_heads, + dim_feedforward=dim * 4, + dropout=0.1, + batch_first=True, # 方便处理 [B, L, D] + ) + self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers) + + # LayerNorm for stability + self.ln = nn.LayerNorm(dim) + + def forward(self, text_ids, pinyin_ids, mask=None): + """ + Args: + text_ids: [batch, seq_len] + pinyin_ids: [batch, seq_len] (假设已对齐,若不对齐需预处理) + mask: [batch, seq_len] optional padding mask + Returns: + H: [batch, seq_len, 512] Context representation [1] + """ + # 1. Embed text + text_emb = self.text_emb(text_ids) # [B, 128, dim] + + # 2. Embed and pool pinyin to global feature + pinyin_emb = self.pinyin_emb(pinyin_ids) # [B, 24, dim] + # 方式1:Attention Pooling(推荐) + pinyin_global = self.pinyin_pooling( + pinyin_emb, mask=None + ) # [B, dim] # 1. Embedding Fusion: Text + Pinyin + Position + + # Broadcast pinyin to all text positions + pinyin_global = pinyin_global.unsqueeze(1) # [B, 1, dim] + pinyin_broadcast = pinyin_global.expand_as(text_emb) # [B, 128, dim] + + # 策略:拼音作为增强特征叠加到文本上,符合轻量级设计 + x = text_emb + pinyin_broadcast + + seq_len = x.size(1) + pos_ids = ( + torch.arange(seq_len, device=x.device).unsqueeze(0).expand_as(text_ids) + ) + x += self.pos_emb(pos_ids) + + # 2. Transformer Encoding + # src_key_padding_mask expects True for padding positions + if mask is not None: + # Convert 0/1 mask to bool mask where True is padding + src_mask = mask == 0 + else: + src_mask = None + + H = self.transformer(x, src_key_padding_mask=src_mask) + return self.ln(H) + + +# ------------------------------------------------------------------ +# 2. 槽位记忆模块 (Slot Memory) +# 对应 README 4.2: 8个槽位, 每槽3步, 拼接+位置编码 [1] +# ------------------------------------------------------------------ +class SlotMemory(nn.Module): + def __init__(self, vocab_size, max_slots=8, steps_per_slot=3, dim=512): + super().__init__() + self.max_slots = max_slots + self.steps_per_slot = steps_per_slot + self.total_steps = max_slots * steps_per_slot # 24 steps [1] + + # Shared embedding layer for history tokens [1] + self.emb = nn.Embedding(vocab_size, dim) + + # Learnable positional embeddings for the flattened sequence [1] + self.pos_emb = nn.Embedding(self.total_steps, dim) + + # Start token embedding for empty slots [1] + self.start_emb = nn.Parameter(torch.randn(1, 1, dim)) + + def forward(self, history_ids): + """ + Args: + history_ids: [batch, total_steps] + Flattened sequence of history tokens. + Empty positions should be filled with a special PAD or handled via mask. 
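+                         示意(假设按槽位主序展开,仅作说明):第 i 个槽位的
+                         3 个历史步占据展平序列的第 3*i ~ 3*i+2 个位置,
+                         例如 history_ids[:, 0:3] 对应槽位 0。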
+ Returns: + S: [batch, total_steps, 512] Slot sequence representation [1] + """ + # Embed history tokens + S = self.emb(history_ids) # [B, 24, 512] + + # Add positional embeddings + pos_ids = ( + torch.arange(S.size(1), device=S.device).unsqueeze(0).expand_as(history_ids) + ) + S += self.pos_emb(pos_ids) + + return S + + +# ------------------------------------------------------------------ +# 3. 交叉注意力融合 (Cross-Attention Fusion) +# 对应 README: Query=Slots, Key/Value=Context H [1] +# ------------------------------------------------------------------ +class CrossAttentionFusion(nn.Module): + def __init__(self, dim=512, n_heads=4): + super().__init__() + self.dim = dim + self.n_heads = n_heads + self.head_dim = dim // n_heads + assert self.head_dim * n_heads == dim, "dim must be divisible by n_heads" + + # Linear projections for Q, K, V + self.q_proj = nn.Linear(dim, dim, bias=False) + self.k_proj = nn.Linear(dim, dim, bias=False) + self.v_proj = nn.Linear(dim, dim, bias=False) + self.out_proj = nn.Linear(dim, dim, bias=False) + self.ln = nn.LayerNorm(dim) + + def forward(self, slots_S, context_H, slot_mask=None, context_mask=None): + """ + Args: + slots_S: [batch, num_slots_steps, dim] Query + context_H: [batch, ctx_len, dim] Key/Value + slot_mask: [batch, num_slots_steps] Optional (not used in scaled_dot_product_attention) + context_mask: [batch, ctx_len] Optional padding mask + Returns: + Fused: [batch, num_slots_steps, dim] + """ + batch_size, num_slots, _ = slots_S.shape + _, ctx_len, _ = context_H.shape + + # Project queries, keys, values + Q = self.q_proj(slots_S) # [batch, num_slots, dim] + K = self.k_proj(context_H) # [batch, ctx_len, dim] + V = self.v_proj(context_H) # [batch, ctx_len, dim] + + # Reshape for multi-head attention: [batch, seq_len, n_heads, head_dim] -> [batch, n_heads, seq_len, head_dim] + Q = Q.view(batch_size, num_slots, self.n_heads, self.head_dim).transpose(1, 2) + K = K.view(batch_size, ctx_len, self.n_heads, self.head_dim).transpose(1, 2) + V = V.view(batch_size, ctx_len, self.n_heads, self.head_dim).transpose(1, 2) + + # Prepare attention mask if context_mask is provided + attn_mask = None + if context_mask is not None: + # context_mask: [batch, ctx_len] where 0 means padding + # Convert to bool mask and reshape for broadcasting + bool_mask = context_mask == 0 # [batch, ctx_len] + bool_mask = bool_mask[:, None, None, :] # [batch, 1, 1, ctx_len] + # Convert to float mask where True (padding) becomes -inf + attn_mask = bool_mask.float().masked_fill(bool_mask, float("-inf")) + + # Scaled dot-product attention + attn_output = F.scaled_dot_product_attention( + Q, + K, + V, + attn_mask=attn_mask, + dropout_p=0.0, # no dropout + ) + + # Reshape back: [batch, n_heads, num_slots, head_dim] -> [batch, num_slots, dim] + attn_output = ( + attn_output.transpose(1, 2) + .contiguous() + .view(batch_size, num_slots, self.dim) + ) + + # Project back + fused = self.out_proj(attn_output) + + # Residual connection and layer norm + fused = self.ln(fused + slots_S) + + return fused + + +# ------------------------------------------------------------------ +# 4. 
专家混合层 (MoE Layer) +# 对应 README: 20个专家 [1], 使用 components.py 中的 Expert 类 +# ------------------------------------------------------------------ +class MoELayer(nn.Module): + def __init__(self, dim=512, num_experts=20, top_k=2, export_resblocks=4): + super().__init__() + self.num_experts = num_experts + self.top_k = top_k + self.dim = dim + + # Import Expert from your existing components + # Assuming Expert class is defined as in components.py [2] + self.experts = nn.ModuleList( + [ + Expert( + input_dim=dim, + d_model=dim, + num_resblocks=export_resblocks, + output_multiplier=1, + ) + for _ in range(num_experts) + ] + ) + + # Gating Network [2] + self.gate = nn.Linear(dim, num_experts) + + def forward(self, x): + """ + 并行化 MoE 前向传播,完全兼容 torch.compile 和 AMP。 + + Args: + x: [batch, seq_len, dim] + Returns: + out: [batch, seq_len, dim] + """ + B, L, D = x.shape + + # 1. Compute Gating Scores + gates = self.gate(x) # [B, L, num_experts] + + # 2. Select Top-K Experts + topk_vals, topk_indices = torch.topk(gates, self.top_k, dim=-1) # [B, L, K] + + # Normalize weights for selected experts + weights = F.softmax(topk_vals, dim=-1) # [B, L, K] + + # 3. Dispatch and Compute + # Initialize output + out = torch.zeros_like(x) + + # Reshape for easier processing: flatten batch and sequence dimensions + x_flat = x.view(-1, D) # [B*L, D] + weights_flat = weights.view(-1, self.top_k) # [B*L, K] + topk_indices_flat = topk_indices.view(-1, self.top_k) # [B*L, K] + + # For each of the top-k positions + for k in range(self.top_k): + # Get expert indices and weights for this position + expert_indices = topk_indices_flat[:, k] # [B*L] + expert_weights = weights_flat[:, k].unsqueeze(-1) # [B*L, 1] + + # Process each expert separately + for e_idx in range(self.num_experts): + # Mask for tokens assigned to this expert at position k + mask = expert_indices == e_idx # [B*L] + if not mask.any(): + continue + + # Extract tokens for this expert + x_selected = x_flat[mask] # [N_selected, D] + if x_selected.numel() == 0: + continue + + # Pass through expert + expert_out = self.experts[e_idx](x_selected) # [N_selected, D] + + # Apply expert weights and add to output + weighted_out = expert_out * expert_weights[mask] + + # Scatter back to flat output + out_flat = out.view(-1, D) + out_flat[mask] += weighted_out + + # Reshape back to original shape + out = out.view(B, L, D) + + return out diff --git a/src/model/model.py b/src/model/model.py index 4ff9a46..5aafe4c 100644 --- a/src/model/model.py +++ b/src/model/model.py @@ -79,7 +79,9 @@ class InputMethodEngine(nn.Module): # 开启 torch.compile 优化 (如果请求) if compile: - self.forward = torch.compile(self.forward) + self.forward = torch.compile( + self.forward, mode="reduce-overhead", fullgraph=True + ) def forward( self, diff --git a/src/model/trainer.py b/src/model/trainer.py index cf679a5..88a789f 100644 --- a/src/model/trainer.py +++ b/src/model/trainer.py @@ -1,8 +1,6 @@ import json import math -import os import random -import tempfile from datetime import datetime from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -29,16 +27,6 @@ from torch.amp.grad_scaler import GradScaler from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter -# Try to import DataLoader2 for better streaming dataset support -try: - from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService - - DATA_LOADER2_AVAILABLE = True -except ImportError: - DATA_LOADER2_AVAILABLE = False - DataLoader2 = None - 
MultiProcessingReadingService = None - from .dataset import PinyinInputDataset # 导入模型和数据 @@ -770,34 +758,6 @@ def create_dataloader( Returns: 数据加载器实例 """ - if ( - DATA_LOADER2_AVAILABLE - and DataLoader2 is not None - and MultiProcessingReadingService is not None - ): - try: - # DataLoader2配置,针对流式数据集优化 - reading_service = MultiProcessingReadingService( - num_workers=num_workers, - prefetch_factor=2, # 减少预取以避免内存问题 - persistent_workers=True, - pin_memory=pin_memory, - worker_init_fn=worker_init_fn, - ) - - dataloader = DataLoader2( - dataset, - reading_service=reading_service, - batch_size=batch_size, - collate_fn=collate_fn, - shuffle=shuffle, - ) - logger.info(f"✅ 使用DataLoader2创建数据加载器,worker数量: {num_workers}") - return dataloader - except Exception as e: - logger.warning(f"⚠️ DataLoader2创建失败: {e},回退到标准DataLoader") - - # 回退到标准DataLoader logger.info(f"📊 使用标准DataLoader,worker数量: {num_workers}") dataloader = DataLoader( dataset, @@ -1012,7 +972,7 @@ def train( max_seq_length=max_seq_len, text_field="text", py_style_weight=(9, 2, 1), - shuffle_buffer_size=1000, + shuffle_buffer_size=50000, length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, ) diff --git a/test_compile.py b/test_compile.py new file mode 100644 index 0000000..dfa1d98 --- /dev/null +++ b/test_compile.py @@ -0,0 +1,487 @@ +#!/usr/bin/env python3 +""" +测试 torch.compile 兼容性的脚本 +用于验证 components.py 中的修改是否与 torch.compile 兼容 +""" + +import os +import sys +import time + +import torch +import torch.nn as nn + +# 添加 src 目录到路径 +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from src.model.components import ( + AttentionPooling, + ContextEncoder, + CrossAttentionFusion, + Expert, + MoELayer, + ResidualBlock, + SlotMemory, +) + + +def test_attention_pooling(): + """测试 AttentionPooling 模块""" + print("=" * 60) + print("测试 AttentionPooling 模块") + + batch_size = 4 + seq_len = 10 + hidden_size = 512 + + # 创建模块 + attn_pool = AttentionPooling(hidden_size) + + # 测试数据 + x = torch.randn(batch_size, seq_len, hidden_size) + mask = torch.ones(batch_size, seq_len, dtype=torch.long) + mask[0, 5:] = 0 # 第一个样本的部分位置mask + token_type_ids = torch.randint(0, 3, (batch_size, seq_len)) + + # 测试未编译版本 + output = attn_pool(x, mask=mask, token_type_ids=token_type_ids) + print(f"✓ AttentionPooling 输出形状: {output.shape}") + assert output.shape == (batch_size, hidden_size) + + # 测试编译版本 + try: + compiled_attn_pool = torch.compile(attn_pool, mode="reduce-overhead") + compiled_output = compiled_attn_pool( + x, mask=mask, token_type_ids=token_type_ids + ) + + # 检查输出是否一致 + diff = torch.abs(output - compiled_output).max().item() + print(f"✓ 编译前后输出差异: {diff:.6f}") + assert diff < 1e-4, f"输出差异过大: {diff}" + print("✓ AttentionPooling 通过 torch.compile 测试") + except Exception as e: + print(f"⚠ AttentionPooling torch.compile 测试失败: {e}") + raise + + +def test_moe_layer(): + """测试 MoELayer 模块""" + print("\n" + "=" * 60) + print("测试 MoELayer 模块") + + batch_size = 2 + seq_len = 8 + dim = 512 + num_experts = 20 + top_k = 2 + + # 创建模块 + moe = MoELayer(dim=dim, num_experts=num_experts, top_k=top_k) + + # 测试数据 + x = torch.randn(batch_size, seq_len, dim) + + # 测试未编译版本 + output = moe(x) + print(f"✓ MoELayer 输出形状: {output.shape}") + assert output.shape == (batch_size, seq_len, dim) + + # 检查门控权重 + gates = moe.gate(x.view(-1, dim)) + topk_vals, topk_indices = torch.topk(gates, top_k, dim=-1) + print(f"✓ 门控权重形状: {gates.shape}, top-k 索引形状: {topk_indices.shape}") + + # 测试编译版本 + try: + compiled_moe = torch.compile(moe, mode="reduce-overhead") + + # 
+        # 预热(多次前向触发编译)
+        for _ in range(3):
+            _ = compiled_moe(x)
+
+        # Expert 含 Dropout:切换到 eval 模式再对比,避免随机性导致误报
+        moe.eval()
+        with torch.no_grad():
+            output = moe(x)
+            compiled_output = compiled_moe(x)
+
+        # 检查输出是否一致
+        diff = torch.abs(output - compiled_output).max().item()
+        print(f"✓ 编译前后输出差异: {diff:.6f}")
+        assert diff < 1e-4, f"输出差异过大: {diff}"
+        print("✓ MoELayer 通过 torch.compile 测试")
+
+        # 性能测试
+        n_iter = 50
+        print(f"\n性能测试 ({n_iter} 次迭代):")
+
+        # 未编译版本
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+        start = time.time()
+        for _ in range(n_iter):
+            _ = moe(x)
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+        base_time = time.time() - start
+
+        # 编译版本
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+        start = time.time()
+        for _ in range(n_iter):
+            _ = compiled_moe(x)
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+        compiled_time = time.time() - start
+
+        print(f"  未编译: {base_time:.4f} 秒")
+        print(f"  已编译: {compiled_time:.4f} 秒")
+        if compiled_time > 0:
+            speedup = base_time / compiled_time
+            print(f"  加速比: {speedup:.2f}x")
+
+    except Exception as e:
+        print(f"⚠ MoELayer torch.compile 测试失败: {e}")
+        import traceback
+
+        traceback.print_exc()
+        raise
+
+
+def test_cross_attention_fusion():
+    """测试 CrossAttentionFusion 模块"""
+    print("\n" + "=" * 60)
+    print("测试 CrossAttentionFusion 模块")
+
+    batch_size = 2
+    num_slots = 8
+    ctx_len = 16
+    dim = 512
+
+    # 创建模块
+    cross_attn = CrossAttentionFusion(dim=dim, n_heads=4)
+
+    # 测试数据
+    slots_S = torch.randn(batch_size, num_slots, dim)
+    context_H = torch.randn(batch_size, ctx_len, dim)
+    context_mask = torch.ones(batch_size, ctx_len, dtype=torch.long)
+    context_mask[0, 10:] = 0  # 第一个样本的部分位置mask
+
+    # 测试未编译版本
+    output = cross_attn(slots_S, context_H, context_mask=context_mask)
+    print(f"✓ CrossAttentionFusion 输出形状: {output.shape}")
+    assert output.shape == (batch_size, num_slots, dim)
+
+    # 测试编译版本
+    try:
+        compiled_cross_attn = torch.compile(cross_attn, mode="reduce-overhead")
+        compiled_output = compiled_cross_attn(
+            slots_S, context_H, context_mask=context_mask
+        )
+
+        # 检查输出是否一致
+        diff = torch.abs(output - compiled_output).max().item()
+        print(f"✓ 编译前后输出差异: {diff:.6f}")
+        assert diff < 1e-4, f"输出差异过大: {diff}"
+        print("✓ CrossAttentionFusion 通过 torch.compile 测试")
+    except Exception as e:
+        print(f"⚠ CrossAttentionFusion torch.compile 测试失败: {e}")
+        raise
+
+
+def test_context_encoder():
+    """测试 ContextEncoder 模块"""
+    print("\n" + "=" * 60)
+    print("测试 ContextEncoder 模块")
+
+    batch_size = 2
+    seq_len = 32
+    vocab_size = 1000
+    pinyin_vocab_size = 30
+    dim = 512
+
+    # 创建模块
+    context_encoder = ContextEncoder(
+        vocab_size=vocab_size,
+        pinyin_vocab_size=pinyin_vocab_size,
+        dim=dim,
+        n_layers=2,  # 测试时减少层数
+        n_heads=4,
+        max_len=seq_len,
+    )
+    # Transformer 层含 dropout=0.1:切换到 eval 模式,保证前后输出可比
+    context_encoder.eval()
+
+    # 测试数据
+    text_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
+    pinyin_ids = torch.randint(
+        0, pinyin_vocab_size, (batch_size, 24)
+    )  # pinyin_ids固定长度24
+    mask = torch.ones(batch_size, seq_len, dtype=torch.long)
+    mask[0, 20:] = 0  # 第一个样本的部分位置mask
+
+    # 测试未编译版本
+    output = context_encoder(text_ids, pinyin_ids, mask=mask)
+    print(f"✓ ContextEncoder 输出形状: {output.shape}")
+    assert output.shape == (batch_size, seq_len, dim)
+
+    # 测试编译版本(ContextEncoder包含外部模型,可能不完全兼容)
+    try:
+        compiled_context_encoder = torch.compile(
+            context_encoder, mode="reduce-overhead"
+        )
+        compiled_output = compiled_context_encoder(text_ids, pinyin_ids, mask=mask)
+
+        # 检查输出是否一致
+        diff = torch.abs(output - compiled_output).max().item()
+        print(f"✓ 编译前后输出差异: {diff:.6f}")
+        if diff < 1e-3:  # ContextEncoder 可能精度要求稍低
+            print("✓ ContextEncoder 通过 torch.compile 测试")
+        else:
+            print(f"⚠ ContextEncoder 输出差异较大: {diff:.6f} (可能因外部模型)")
+    except Exception as e:
+        print(
+            f"⚠ ContextEncoder torch.compile 测试失败: {e} (可能是正常现象,因为包含外部模型)"
+        )
+
+
+def test_full_model_compile():
+    """测试完整模型编译"""
+    print("\n" + "=" * 60)
+    print("测试完整模型编译")
+
+    # 导入完整模型
+    from src.model.model import InputMethodEngine
+
+    batch_size = 2
+    vocab_size = 10019
+    pinyin_vocab_size = 30
+    dim = 512
+    num_slots = 8
+    seq_len = 32
+
+    # 创建模型
+    model = InputMethodEngine(
+        vocab_size=vocab_size,
+        pinyin_vocab_size=pinyin_vocab_size,
+        dim=dim,
+        num_slots=num_slots,
+        n_layers=2,  # 测试时减少层数
+        n_heads=4,
+        num_experts=20,
+        max_seq_len=seq_len,
+        compile=False,  # 手动控制编译
+    )
+    # 关闭 dropout 等训练期随机性,保证与编译版本可比
+    model.eval()
+
+    # 测试数据
+    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
+    token_type_ids = torch.randint(0, 2, (batch_size, seq_len))
+    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
+    attention_mask[0, 20:] = 0
+    pinyin_ids = torch.randint(0, pinyin_vocab_size, (batch_size, 24))
+    history_slot_ids = torch.randint(0, vocab_size, (batch_size, num_slots))
+
+    # 测试未编译版本
+    with torch.no_grad():
+        output = model(
+            input_ids=input_ids,
+            token_type_ids=token_type_ids,
+            attention_mask=attention_mask,
+            pinyin_ids=pinyin_ids,
+            history_slot_ids=history_slot_ids,
+        )
+
+    print(f"✓ 完整模型输出形状: {output.shape}")
+    assert output.shape == (batch_size, vocab_size)
+
+    # 测试编译版本
+    try:
+        # 创建一个新模型用于编译测试
+        model_for_compile = InputMethodEngine(
+            vocab_size=vocab_size,
+            pinyin_vocab_size=pinyin_vocab_size,
+            dim=dim,
+            num_slots=num_slots,
+            n_layers=2,
+            n_heads=4,
+            num_experts=20,
+            max_seq_len=seq_len,
+            compile=False,
+        )
+        # 复用基准模型的权重:否则下面对比的是两组不同的随机初始化
+        model_for_compile.load_state_dict(model.state_dict())
+        model_for_compile.eval()
+
+        # 手动编译
+        compiled_model = torch.compile(model_for_compile, mode="reduce-overhead")
+
+        # 预热
+        for _ in range(3):
+            _ = compiled_model(
+                input_ids=input_ids,
+                token_type_ids=token_type_ids,
+                attention_mask=attention_mask,
+                pinyin_ids=pinyin_ids,
+                history_slot_ids=history_slot_ids,
+            )
+
+        with torch.no_grad():
+            compiled_output = compiled_model(
+                input_ids=input_ids,
+                token_type_ids=token_type_ids,
+                attention_mask=attention_mask,
+                pinyin_ids=pinyin_ids,
+                history_slot_ids=history_slot_ids,
+            )
+
+        # 检查输出是否一致
+        diff = torch.abs(output - compiled_output).max().item()
+        print(f"✓ 编译前后输出差异: {diff:.6f}")
+
+        # 完整模型精度要求可能较低
+        if diff < 1e-3:
+            print("✓ 完整模型通过 torch.compile 测试")
+        else:
+            print(f"⚠ 完整模型输出差异较大: {diff:.6f}")
+
+        # 性能测试
+        n_iter = 30
+        print(f"\n完整模型性能测试 ({n_iter} 次迭代):")
+
+        # 未编译版本
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+        start = time.time()
+        with torch.no_grad():
+            for _ in range(n_iter):
+                _ = model(
+                    input_ids=input_ids,
+                    token_type_ids=token_type_ids,
+                    attention_mask=attention_mask,
+                    pinyin_ids=pinyin_ids,
+                    history_slot_ids=history_slot_ids,
+                )
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+        base_time = time.time() - start
+
+        # 编译版本
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+        start = time.time()
+        with torch.no_grad():
+            for _ in range(n_iter):
+                _ = compiled_model(
+                    input_ids=input_ids,
+                    token_type_ids=token_type_ids,
+                    attention_mask=attention_mask,
+                    pinyin_ids=pinyin_ids,
+                    history_slot_ids=history_slot_ids,
+                )
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+        compiled_time = time.time() - start
+
+        print(f"  未编译: {base_time:.4f} 秒")
+        print(f"  已编译: {compiled_time:.4f} 秒")
+        if compiled_time > 0:
+            speedup = base_time / compiled_time
+            print(f"  加速比: {speedup:.2f}x")
+
+    except Exception as e:
+        print(f"⚠ 完整模型 torch.compile 测试失败: {e}")
+        import traceback
+
+        traceback.print_exc()
+
+
+def check_compile_issues():
+    """检查可能导致编译问题的代码模式"""
+    print("\n" + "=" * 60)
+    print("检查编译问题")
+
+    issues = []
+
+    # 检查 components.py 中的潜在问题
+    with open("src/model/components.py", "r") as f:
+        content = f.read()
+
+    # 检查 float("-inf")(同时匹配单、双引号两种写法)
+    if "float('-inf')" in content or 'float("-inf")' in content:
+        issues.append("❌ 发现 float('-inf') / float(\"-inf\"),应替换为 -1e9")
+
+    # 检查 .item() 调用
+    if ".item()" in content and "def forward" in content:
+        # 统计 forward 方法中的 .item() 调用
+        lines = content.split("\n")
+        in_forward = False
+        item_calls = []
+        for i, line in enumerate(lines):
+            if "def forward" in line:
+                in_forward = True
+            elif in_forward and "def " in line and "def forward" not in line:
+                in_forward = False
+            if in_forward and ".item()" in line:
+                item_calls.append((i + 1, line.strip()))
+
+        if item_calls:
+            issues.append(
+                f"❌ 在 forward 方法中发现 {len(item_calls)} 个 .item() 调用:"
+            )
+            for line_num, line in item_calls[:3]:  # 显示前3个
+                issues.append(f"   第 {line_num} 行: {line}")
+
+    # 检查动态控制流
+    dynamic_patterns = [
+        "if not mask.any():",
+        "if mask is not None:",
+        "if token_mask.any():",
+        "continue",
+    ]
+    for pattern in dynamic_patterns:
+        if pattern in content:
+            # 注意:这里只是全文子串匹配,并未精确限定在 forward 方法内
+            issues.append(f"⚠ 发现动态控制流: {pattern}")
+
+    if not issues:
+        print("✓ 未发现明显的编译问题")
+    else:
+        print("发现以下可能影响编译的问题:")
+        for issue in issues:
+            print(issue)
+
+
+def main():
+    """主测试函数"""
+    print("=" * 70)
+    print("torch.compile 兼容性测试")
+    print("=" * 70)
+
+    # 设置设备
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print(f"使用设备: {device}")
+    print(f"PyTorch 版本: {torch.__version__}")
+
+    # 检查 torch.compile 是否可用
+    if hasattr(torch, "compile"):
+        print("✓ torch.compile 可用")
+    else:
+        print("❌ torch.compile 不可用,需要 PyTorch 2.0+")
+        return
+
+    # 运行测试
+    try:
+        # 首先检查代码问题
+        check_compile_issues()
+
+        # 模块测试
+        test_attention_pooling()
+        test_moe_layer()
+        test_cross_attention_fusion()
+        test_context_encoder()
+
+        # 完整模型测试
+        test_full_model_compile()
+
+        print("\n" + "=" * 70)
+        print("✅ 所有测试完成!")
+
+    except Exception as e:
+        print(f"\n❌ 测试失败: {e}")
+        import traceback
+
+        traceback.print_exc()
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/uv.lock b/uv.lock
index 9611c2b..ffe8c7b 100644
--- a/uv.lock
+++ b/uv.lock
@@ -2305,7 +2305,6 @@ dependencies = [
     { name = "streamlit" },
     { name = "tensorboard" },
     { name = "torch" },
-    { name = "torchdata" },
     { name = "transformers" },
     { name = "typer" },
 ]
@@ -2333,7 +2332,6 @@ requires-dist = [
     { name = "streamlit", specifier = ">=1.56.0" },
     { name = "tensorboard", specifier = ">=2.20.0" },
     { name = "torch", specifier = ">=2.10.0" },
-    { name = "torchdata", specifier = ">=0.11.0" },
     { name = "transformers", specifier = "==5.1.0" },
     { name = "typer", specifier = ">=0.21.1" },
 ]
@@ -2473,19 +2471,6 @@ wheels = [
     { url = "https://mirrors.aliyun.com/pypi/packages/cf/bf/c8d12a2c86dbfd7f40fb2f56fbf5a505ccf2d9ce131eb559dfc7c51e1a04/torch-2.11.0-cp314-cp314t-win_amd64.whl", hash = "sha256:b2a43985ff5ef6ddd923bbcf99943e5f58059805787c5c9a2622bf05ca2965b0" },
 ]
 
-[[package]]
-name = "torchdata"
-version = "0.11.0"
-source = { registry = "https://mirrors.aliyun.com/pypi/simple/" }
-dependencies = [
-    { name = "requests" },
-    { name = "torch" },
-    { name = "urllib3" },
-]
-wheels = [
-    { url = "https://mirrors.aliyun.com/pypi/packages/95/d4/af694ef718aedbe95a72760ab9ff7a6a7a44ace2d7f70c27bfeb67c5c503/torchdata-0.11.0-py3-none-any.whl", hash = "sha256:52b940fbbe0e00fb21cabddf528449d1bec5bfb0d0823b7487b15f951658ee33" },
-]
-
 [[package]]
 name = "tornado"
 version = "6.5.5"
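
下面给出 README 中「原子写入 + 容错读取」模式的一个最小示意(假设性草图:`write_status_atomic` / `read_status_with_retry` 为示意用函数名,并非本补丁中项目的实际 API):

```python
import json
import os
import tempfile
import time
from pathlib import Path


def write_status_atomic(status: list, path: Path) -> None:
    """先写临时文件,再原子重命名到目标路径,读取方不会看到半截 JSON。"""
    # 临时文件建在目标目录下,保证 os.replace 不跨文件系统、重命名是原子的
    fd, tmp_name = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(status, f, ensure_ascii=False)
            f.flush()
            os.fsync(f.fileno())  # 确保数据落盘后再重命名
        os.replace(tmp_name, path)  # POSIX 与 Windows 上均为原子替换
    except BaseException:
        os.unlink(tmp_name)
        raise


def read_status_with_retry(path: Path, retries: int = 3) -> list:
    """读取并校验 JSON;遇到不完整或非法内容时短暂退避后重试。"""
    for attempt in range(retries):
        try:
            with open(path, encoding="utf-8") as f:
                return json.load(f)  # json.load 本身即完成格式校验
        except (json.JSONDecodeError, FileNotFoundError):
            if attempt == retries - 1:
                raise
            time.sleep(0.1 * (attempt + 1))  # 退避后重试
    return []
```

HTTP 服务端在返回 `training_status.json` 前所做的 JSON 校验与 `.tmp` 临时文件处理,可以复用上述读取逻辑;Streamlit 端通过 HTTP 轮询远程文件时,采用同样的重试策略即可。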