feat(docs): Add HTTP static file serving and remote monitoring documentation

This commit is contained in:
parent 2f0166c8ce
commit 3da8ae8876

README.md (60)
@ -626,6 +626,21 @@ monitor-training check

```bash
monitor-training check ./output/training_status.json
```

**Start the HTTP static file server**

```bash
# Start the HTTP static file server (default port 8080)
monitor-training serve

# Specify the status file path and port
monitor-training serve --status-file ./output/training_status.json --port 8080

# Disable CORS support (enabled by default)
monitor-training serve --no-cors

# Bind to a specific host address
monitor-training serve --host 0.0.0.0 --port 8080
```

#### Monitoring UI features

**📊 Core metrics dashboard**
@ -689,6 +704,18 @@ http://192.168.1.100:8080

```
http://your-server.com:8501
```

**Remote HTTP monitoring**

```bash
# Start the HTTP server on the GPU server
monitor-training serve --port 8080 --host 0.0.0.0

# Run the Streamlit monitor locally, reading data from the HTTP URL
monitor-training monitor --host 127.0.0.1 --port 8501

# Enter the remote URL in the Streamlit UI:
http://<gpu-server-ip>:8080/training_status.json
```
#### Status file format

The status file `training_status.json` lives in the training output directory, in the following format:

@ -708,11 +735,44 @@ http://your-server.com:8501

#### HTTP static file serving and remote monitoring

For GPU server environments that only support plain HTTP (no WebSockets), we provide an HTTP static file serving scheme that enables remote training monitoring.

**🔧 Technical highlights**

- Plain HTTP only; no WebSockets support required
- Atomic writes, so readers never see partially written JSON
- Automatic retries and JSON validation guarantee data integrity
- CORS support for easy cross-origin access
- Lightweight design that does not affect training performance

**🚀 How it works**

1. GPU server: the training process updates `training_status.json` via atomic writes
2. GPU server: `monitor-training serve` exposes the file over HTTP
3. Local machine: `monitor-training monitor` launches the Streamlit monitoring UI
4. Local machine: enter the HTTP URL in the Streamlit UI to reach the remote data
5. Streamlit: polls over HTTP to fetch and display live training data

**🛡️ Data safety**

- Atomic writes: write to a temporary file first, then atomically rename, so reads are never interrupted mid-write
- JSON validation: the HTTP server validates the JSON format before returning data
- Temporary-file handling: `.tmp` temporary files are detected and read correctly
- Retry mechanism: reads are retried automatically when JSON parsing fails
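The atomic-write pattern described above (write a temporary file, then rename it over the target) can be sketched with the standard library alone. This is a minimal illustration, not the project's actual implementation, and the helper name is hypothetical:

```python
import json
import os
import tempfile

def write_status_atomically(status, path):
    """Serialize JSON to a temp file in the target directory, then
    os.replace() it over the destination. Readers either see the old
    complete file or the new complete file, never a partial write."""
    dir_name = os.path.dirname(os.path.abspath(path))
    # The temp file must live on the same filesystem for the rename to be atomic
    fd, tmp_path = tempfile.mkstemp(dir=dir_name, suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(status, f, ensure_ascii=False)
            f.flush()
            os.fsync(f.fileno())  # make sure bytes hit disk before the rename
        os.replace(tmp_path, path)  # atomic on POSIX and Windows
    except BaseException:
        os.unlink(tmp_path)
        raise
```

`os.replace` is what makes concurrent HTTP reads safe: the server never opens a half-written `training_status.json`.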
#### Notes

1. On first monitoring, an empty status file is created automatically if none exists
2. The `plotly` dependency is required for chart rendering: `pip install "plotly>=5.0.0"`
3. When resuming training from a checkpoint, existing status data is loaded automatically
4. Deploying the monitoring service on the same server as training is recommended to avoid network latency
5. The HTTP service relies on atomic writes, so reads never observe partially written JSON while the training process is writing
6. For remote monitoring, make sure the GPU server's firewall opens the corresponding HTTP port
7. Use `--host 0.0.0.0` so the HTTP service is reachable remotely
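The retry-and-validate behaviour on the reading side can be sketched as follows; this is an illustration under the assumptions stated in the notes above (the function name is hypothetical, not from the project):

```python
import json
import time

def read_status_with_retry(path, retries=3, delay=0.05):
    """Read the status JSON, retrying briefly when the file is missing
    or momentarily invalid (e.g. caught mid-write by a non-atomic writer).
    json.load doubles as the validation step: it raises on partial JSON."""
    last_err = None
    for _ in range(retries):
        try:
            with open(path, "r", encoding="utf-8") as f:
                return json.load(f)
        except (FileNotFoundError, json.JSONDecodeError) as e:
            last_err = e
            time.sleep(delay)
    raise last_err
```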
### 6.7 Evaluating models (work in progress)

@ -0,0 +1 @@
[]
```diff
@ -20,7 +20,6 @@ dependencies = [
     "streamlit>=1.56.0",
     "tensorboard>=2.20.0",
     "torch>=2.10.0",
-    "torchdata>=0.11.0",
     "transformers==5.1.0",
     "typer>=0.21.1",
 ]
```
```diff
@ -246,7 +246,7 @@ class CrossAttentionFusion(nn.Module):
             bool_mask = context_mask == 0  # [batch, ctx_len]
             bool_mask = bool_mask[:, None, None, :]  # [batch, 1, 1, ctx_len]
             # Convert to float mask where True (padding) becomes a large negative value
-            attn_mask = bool_mask.float().masked_fill(bool_mask, float("-inf"))
+            attn_mask = bool_mask.float().masked_fill(bool_mask, -1e9)
 
         # Scaled dot-product attention
         attn_output = F.scaled_dot_product_attention(
```
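The switch from `float("-inf")` to `-1e9` matters when an attention row ends up fully masked (every key is padding): a softmax over an all-`-inf` row is 0/0 and produces NaN, while a large finite negative value keeps the softmax well defined. A small sketch of the difference (my illustration, not code from the commit):

```python
import torch

scores = torch.zeros(1, 4)
all_pad = torch.tensor([[True, True, True, True]])  # every key masked

inf_row = scores.masked_fill(all_pad, float("-inf"))
finite_row = scores.masked_fill(all_pad, -1e9)

# exp(-inf) terms sum to zero, so the -inf row normalizes to NaN;
# the -1e9 row stays finite and (all values equal) comes out uniform.
print(torch.softmax(inf_row, dim=-1))
print(torch.softmax(finite_row, dim=-1))
```

With `-1e9`, a fully padded row degrades to a harmless uniform attention distribution instead of poisoning downstream activations with NaN.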
```diff
@ -311,54 +311,35 @@ class MoELayer(nn.Module):
             out: [batch, seq_len, dim]
         """
         B, L, D = x.shape
-
-        # 1. Compute Gating Scores
-        gates = self.gate(x)  # [B, L, num_experts]
-
-        # 2. Select Top-K Experts
-        topk_vals, topk_indices = torch.topk(gates, self.top_k, dim=-1)  # [B, L, K]
-
-        # Normalize weights for selected experts
-        weights = F.softmax(topk_vals, dim=-1)  # [B, L, K]
-
-        # 3. Dispatch and Compute
-        # Initialize output
-        out = torch.zeros_like(x)
-
-        # Reshape for easier processing: flatten batch and sequence dimensions
-        x_flat = x.view(-1, D)  # [B*L, D]
-        weights_flat = weights.view(-1, self.top_k)  # [B*L, K]
-        topk_indices_flat = topk_indices.view(-1, self.top_k)  # [B*L, K]
-
-        # For each of the top-k positions
-        for k in range(self.top_k):
-            # Get expert indices and weights for this position
-            expert_indices = topk_indices_flat[:, k]  # [B*L]
-            expert_weights = weights_flat[:, k].unsqueeze(-1)  # [B*L, 1]
-
-            # Process each expert separately
-            for e_idx in range(self.num_experts):
-                # Mask for tokens assigned to this expert at position k
-                mask = expert_indices == e_idx  # [B*L]
-                if not mask.any():
-                    continue
-
-                # Extract tokens for this expert
-                x_selected = x_flat[mask]  # [N_selected, D]
-                if x_selected.numel() == 0:
-                    continue
-
-                # Pass through expert
-                expert_out = self.experts[e_idx](x_selected)  # [N_selected, D]
-
-                # Apply expert weights and add to output
-                weighted_out = expert_out * expert_weights[mask]
-
-                # Scatter back to flat output
-                out_flat = out.view(-1, D)
-                out_flat[mask] += weighted_out
-
-        # Reshape back to original shape
-        out = out.view(B, L, D)
-
-        return out
+        num_tokens = B * L
+
+        # Flatten the input for processing
+        x_flat = x.view(num_tokens, D)  # [B*L, D]
+
+        # 1. Compute gating scores
+        gates = self.gate(x_flat)  # [B*L, num_experts]
+
+        # 2. Select the top-k experts
+        topk_weights, topk_indices = torch.topk(gates, self.top_k, dim=-1)  # [B*L, K]
+
+        # Normalize the selected weights
+        topk_weights = F.softmax(topk_weights, dim=-1)  # [B*L, K]
+
+        # 3. Run all experts in parallel (removes dynamic control flow from the Python loop)
+        # torch.compile unrolls this list comprehension because num_experts is a compile-time constant
+        expert_outputs = torch.stack(
+            [expert(x_flat) for expert in self.experts], dim=1
+        )  # [B*L, num_experts, D]
+
+        # 4. Use gather to select each token's chosen expert outputs
+        # Expand the indices to match expert_outputs' dims [B*L, num_experts, D]
+        indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, D)  # [B*L, K, D]
+        selected_outputs = torch.gather(
+            expert_outputs, 1, indices_expanded
+        )  # [B*L, K, D]
+
+        # 5. Weighted sum
+        weighted_outputs = selected_outputs * topk_weights.unsqueeze(-1)  # [B*L, K, D]
+        out_flat = weighted_outputs.sum(dim=1)  # [B*L, D]
+
+        # Restore the original shape
+        return out_flat.view(B, L, D)
```
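The loop-based dispatch and the gather-based dispatch compute the same weighted mixture, just in a different order. A self-contained sketch checking that equivalence with toy linear experts (the toy sizes and the plain `nn.Linear` experts are my assumptions, not the project's `Expert` class):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, L, D, E, K = 2, 3, 4, 5, 2  # batch, seq, dim, experts, top-k
experts = [torch.nn.Linear(D, D) for _ in range(E)]  # toy stand-in experts
gate = torch.nn.Linear(D, E)
x = torch.randn(B, L, D)

x_flat = x.view(-1, D)
w, idx = torch.topk(gate(x_flat), K, dim=-1)
w = F.softmax(w, dim=-1)

# Loop-based reference dispatch (per top-k slot, per expert)
ref = torch.zeros_like(x_flat)
for k in range(K):
    for e in range(E):
        m = idx[:, k] == e
        if m.any():
            ref[m] += experts[e](x_flat[m]) * w[m, k].unsqueeze(-1)

# Parallel dispatch: run every expert on every token, then gather
all_out = torch.stack([e(x_flat) for e in experts], dim=1)  # [B*L, E, D]
sel = torch.gather(all_out, 1, idx.unsqueeze(-1).expand(-1, -1, D))
par = (sel * w.unsqueeze(-1)).sum(dim=1)

assert torch.allclose(ref, par, atol=1e-5)
```

The parallel form trades extra FLOPs (every expert runs on every token) for static shapes and no data-dependent branching, which is what makes it friendly to `torch.compile`.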
@ -0,0 +1,364 @@

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from modelscope import AutoModel


# ---------------------------- Attention pooling ----------------------------
class AttentionPooling(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.attn = nn.Linear(hidden_size, 1)
        # Three learnable biases: text, pinyin, personalization
        self.bias = nn.Parameter(torch.zeros(3))  # [text_bias, pinyin_bias, user_bias]

    def forward(self, x, mask=None, token_type_ids=None):
        scores = self.attn(x).squeeze(-1)  # [batch, seq_len]
        if token_type_ids is not None:
            # Add the matching bias according to token_type_ids
            # bias has shape [3]; indexing expands it to [batch, seq_len]
            bias_per_token = self.bias[token_type_ids]  # [batch, seq_len]
            scores = scores + bias_per_token
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = torch.softmax(scores, dim=-1)
        pooled = torch.sum(weights.unsqueeze(-1) * x, dim=1)
        return pooled


# ---------------------------- Residual block ----------------------------
class ResidualBlock(nn.Module):
    def __init__(self, dim, dropout_prob=0.3):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.ln1 = nn.LayerNorm(dim)
        self.linear2 = nn.Linear(dim, dim)
        self.ln2 = nn.LayerNorm(dim)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        residual = x
        # Fix: use self.gelu instead of the undefined self.relu
        x = self.gelu(self.linear1(x))
        x = self.ln1(x)
        x = self.linear2(x)
        x = self.ln2(x)
        x = self.dropout(x)
        x = x + residual
        return self.gelu(x)


# ---------------------------- Expert network ----------------------------
class Expert(nn.Module):
    def __init__(
        self,
        input_dim,
        d_model=512,
        num_resblocks=4,
        output_multiplier=1,
        dropout_prob=0.3,
    ):
        super().__init__()
        self.output_dim = input_dim * output_multiplier
        self.linear_in = nn.Linear(input_dim, d_model)
        self.res_blocks = nn.ModuleList(
            [ResidualBlock(d_model, dropout_prob) for _ in range(num_resblocks)]
        )
        self.output = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout_prob),
            nn.Linear(d_model, self.output_dim),
        )

    def forward(self, x):
        x = self.linear_in(x)
        for block in self.res_blocks:
            x = block(x)
        return self.output(x)
```
```python
# ------------------------------------------------------------------
# 1. Context Encoder
# Per README 4.1: 4-layer Transformer, 512-dim, outputs H [1]
# ------------------------------------------------------------------
class ContextEncoder(nn.Module):
    def __init__(
        self, vocab_size, pinyin_vocab_size, dim=512, n_layers=4, n_heads=4, max_len=128
    ):
        super().__init__()
        self.dim = dim

        # Embeddings
        self.text_emb = AutoModel.from_pretrained(
            "iic/nlp_structbert_backbone_lite_std"
        ).embeddings
        self.pinyin_emb = nn.Embedding(pinyin_vocab_size, dim)
        self.pos_emb = nn.Embedding(max_len, dim)
        self.pinyin_pooling = AttentionPooling(dim)

        # Transformer Encoder (4 layers, 4 heads) [1]
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=n_heads,
            dim_feedforward=dim * 4,
            dropout=0.1,
            batch_first=True,  # keeps tensors as [B, L, D]
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        # LayerNorm for stability
        self.ln = nn.LayerNorm(dim)

    def forward(self, text_ids, pinyin_ids, mask=None):
        """
        Args:
            text_ids: [batch, seq_len]
            pinyin_ids: [batch, seq_len] (assumed aligned; preprocess if not)
            mask: [batch, seq_len] optional padding mask
        Returns:
            H: [batch, seq_len, 512] Context representation [1]
        """
        # 1. Embed text
        text_emb = self.text_emb(text_ids)  # [B, 128, dim]

        # 2. Embed and pool pinyin into a global feature
        pinyin_emb = self.pinyin_emb(pinyin_ids)  # [B, 24, dim]
        # Option 1: attention pooling (recommended)
        pinyin_global = self.pinyin_pooling(pinyin_emb, mask=None)  # [B, dim]

        # Broadcast pinyin to all text positions
        pinyin_global = pinyin_global.unsqueeze(1)  # [B, 1, dim]
        pinyin_broadcast = pinyin_global.expand_as(text_emb)  # [B, 128, dim]

        # Strategy: add pinyin as an auxiliary feature on top of the text,
        # in line with the lightweight design
        x = text_emb + pinyin_broadcast

        seq_len = x.size(1)
        pos_ids = (
            torch.arange(seq_len, device=x.device).unsqueeze(0).expand_as(text_ids)
        )
        x += self.pos_emb(pos_ids)

        # 3. Transformer encoding
        # src_key_padding_mask expects True for padding positions
        if mask is not None:
            # Convert the 0/1 mask to a bool mask where True marks padding
            src_mask = mask == 0
        else:
            src_mask = None

        H = self.transformer(x, src_key_padding_mask=src_mask)
        return self.ln(H)
```
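The broadcast-and-add fusion used in `ContextEncoder.forward` (a pooled `[B, dim]` pinyin vector added to every token position of a `[B, L, dim]` text embedding) can be illustrated in isolation; the sizes here are just the example dimensions from the comments above:

```python
import torch

pooled = torch.randn(2, 512)        # [B, dim] pooled global pinyin feature
text = torch.randn(2, 128, 512)     # [B, L, dim] token embeddings

# unsqueeze inserts the length axis, expand_as broadcasts it to every position
broadcast = pooled.unsqueeze(1).expand_as(text)  # [B, 1, dim] -> [B, L, dim]
fused = text + broadcast            # same pinyin vector added at each position

assert fused.shape == (2, 128, 512)
# Every position received the same offset (up to float rounding)
assert torch.allclose(fused[0, 0] - text[0, 0], fused[0, 5] - text[0, 5], atol=1e-5)
```

`expand_as` creates a broadcast view without copying memory, so this fusion adds no allocation cost beyond the final sum.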
```python
# ------------------------------------------------------------------
# 2. Slot Memory
# Per README 4.2: 8 slots, 3 steps per slot, concatenation + positional encoding [1]
# ------------------------------------------------------------------
class SlotMemory(nn.Module):
    def __init__(self, vocab_size, max_slots=8, steps_per_slot=3, dim=512):
        super().__init__()
        self.max_slots = max_slots
        self.steps_per_slot = steps_per_slot
        self.total_steps = max_slots * steps_per_slot  # 24 steps [1]

        # Shared embedding layer for history tokens [1]
        self.emb = nn.Embedding(vocab_size, dim)

        # Learnable positional embeddings for the flattened sequence [1]
        self.pos_emb = nn.Embedding(self.total_steps, dim)

        # Start token embedding for empty slots [1]
        self.start_emb = nn.Parameter(torch.randn(1, 1, dim))

    def forward(self, history_ids):
        """
        Args:
            history_ids: [batch, total_steps]
                Flattened sequence of history tokens.
                Empty positions should be filled with a special PAD or handled via mask.
        Returns:
            S: [batch, total_steps, 512] Slot sequence representation [1]
        """
        # Embed history tokens
        S = self.emb(history_ids)  # [B, 24, 512]

        # Add positional embeddings
        pos_ids = (
            torch.arange(S.size(1), device=S.device).unsqueeze(0).expand_as(history_ids)
        )
        S += self.pos_emb(pos_ids)

        return S
```
```python
# ------------------------------------------------------------------
# 3. Cross-Attention Fusion
# Per README: Query = slots, Key/Value = context H [1]
# ------------------------------------------------------------------
class CrossAttentionFusion(nn.Module):
    def __init__(self, dim=512, n_heads=4):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        assert self.head_dim * n_heads == dim, "dim must be divisible by n_heads"

        # Linear projections for Q, K, V
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)
        self.ln = nn.LayerNorm(dim)

    def forward(self, slots_S, context_H, slot_mask=None, context_mask=None):
        """
        Args:
            slots_S: [batch, num_slots_steps, dim] Query
            context_H: [batch, ctx_len, dim] Key/Value
            slot_mask: [batch, num_slots_steps] Optional (not used in scaled_dot_product_attention)
            context_mask: [batch, ctx_len] Optional padding mask
        Returns:
            Fused: [batch, num_slots_steps, dim]
        """
        batch_size, num_slots, _ = slots_S.shape
        _, ctx_len, _ = context_H.shape

        # Project queries, keys, values
        Q = self.q_proj(slots_S)  # [batch, num_slots, dim]
        K = self.k_proj(context_H)  # [batch, ctx_len, dim]
        V = self.v_proj(context_H)  # [batch, ctx_len, dim]

        # Reshape for multi-head attention:
        # [batch, seq_len, n_heads, head_dim] -> [batch, n_heads, seq_len, head_dim]
        Q = Q.view(batch_size, num_slots, self.n_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, ctx_len, self.n_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, ctx_len, self.n_heads, self.head_dim).transpose(1, 2)

        # Prepare attention mask if context_mask is provided
        attn_mask = None
        if context_mask is not None:
            # context_mask: [batch, ctx_len] where 0 means padding
            # Convert to bool mask and reshape for broadcasting
            bool_mask = context_mask == 0  # [batch, ctx_len]
            bool_mask = bool_mask[:, None, None, :]  # [batch, 1, 1, ctx_len]
            # Convert to float mask where True (padding) becomes -inf
            attn_mask = bool_mask.float().masked_fill(bool_mask, float("-inf"))

        # Scaled dot-product attention
        attn_output = F.scaled_dot_product_attention(
            Q,
            K,
            V,
            attn_mask=attn_mask,
            dropout_p=0.0,  # no dropout
        )

        # Reshape back: [batch, n_heads, num_slots, head_dim] -> [batch, num_slots, dim]
        attn_output = (
            attn_output.transpose(1, 2)
            .contiguous()
            .view(batch_size, num_slots, self.dim)
        )

        # Project back
        fused = self.out_proj(attn_output)

        # Residual connection and layer norm
        fused = self.ln(fused + slots_S)

        return fused
```
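The `[batch, 1, 1, ctx_len]` mask shape used above works because `scaled_dot_product_attention` broadcasts an additive float mask across heads and query positions. A standalone sketch of that call with toy dimensions (the sizes are illustrative only):

```python
import torch
import torch.nn.functional as F

B, H, S, Lc, d = 2, 4, 8, 16, 32  # batch, heads, query slots, context len, head dim
Q = torch.randn(B, H, S, d)   # one query per slot step
K = torch.randn(B, H, Lc, d)  # keys/values come from the context
V = torch.randn(B, H, Lc, d)

# Additive float mask, broadcast over the head and query axes
attn_mask = torch.zeros(B, 1, 1, Lc)
attn_mask[0, ..., 10:] = float("-inf")  # pad out the tail of sample 0

out = F.scaled_dot_product_attention(Q, K, V, attn_mask=attn_mask)
assert out.shape == (B, H, S, d)
assert not torch.isnan(out).any()  # some keys remain valid in every row
```

Note that `-inf` is safe here only because every query row still has at least one unmasked key; a fully padded context row would need the finite-value trick from the earlier hunk.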
```python
# ------------------------------------------------------------------
# 4. Mixture-of-Experts layer (MoE Layer)
# Per README: 20 experts [1], using the Expert class from components.py
# ------------------------------------------------------------------
class MoELayer(nn.Module):
    def __init__(self, dim=512, num_experts=20, top_k=2, export_resblocks=4):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.dim = dim

        # Import Expert from your existing components
        # Assuming Expert class is defined as in components.py [2]
        self.experts = nn.ModuleList(
            [
                Expert(
                    input_dim=dim,
                    d_model=dim,
                    num_resblocks=export_resblocks,
                    output_multiplier=1,
                )
                for _ in range(num_experts)
            ]
        )

        # Gating network [2]
        self.gate = nn.Linear(dim, num_experts)

    def forward(self, x):
        """
        Parallelized MoE forward pass, fully compatible with torch.compile and AMP.

        Args:
            x: [batch, seq_len, dim]
        Returns:
            out: [batch, seq_len, dim]
        """
        B, L, D = x.shape

        # 1. Compute gating scores
        gates = self.gate(x)  # [B, L, num_experts]

        # 2. Select the top-k experts
        topk_vals, topk_indices = torch.topk(gates, self.top_k, dim=-1)  # [B, L, K]

        # Normalize weights for the selected experts
        weights = F.softmax(topk_vals, dim=-1)  # [B, L, K]

        # 3. Dispatch and compute
        # Initialize the output
        out = torch.zeros_like(x)

        # Reshape for easier processing: flatten batch and sequence dimensions
        x_flat = x.view(-1, D)  # [B*L, D]
        weights_flat = weights.view(-1, self.top_k)  # [B*L, K]
        topk_indices_flat = topk_indices.view(-1, self.top_k)  # [B*L, K]

        # For each of the top-k positions
        for k in range(self.top_k):
            # Get expert indices and weights for this position
            expert_indices = topk_indices_flat[:, k]  # [B*L]
            expert_weights = weights_flat[:, k].unsqueeze(-1)  # [B*L, 1]

            # Process each expert separately
            for e_idx in range(self.num_experts):
                # Mask for tokens assigned to this expert at position k
                mask = expert_indices == e_idx  # [B*L]
                if not mask.any():
                    continue

                # Extract the tokens for this expert
                x_selected = x_flat[mask]  # [N_selected, D]
                if x_selected.numel() == 0:
                    continue

                # Pass through the expert
                expert_out = self.experts[e_idx](x_selected)  # [N_selected, D]

                # Apply the expert weights and add to the output
                weighted_out = expert_out * expert_weights[mask]

                # Scatter back to the flat output
                out_flat = out.view(-1, D)
                out_flat[mask] += weighted_out

        # Reshape back to the original shape
        out = out.view(B, L, D)

        return out
```
```diff
@ -79,7 +79,9 @@ class InputMethodEngine(nn.Module):
 
         # Enable torch.compile optimization (if requested)
         if compile:
-            self.forward = torch.compile(self.forward)
+            self.forward = torch.compile(
+                self.forward, mode="reduce-overhead", fullgraph=True
+            )
 
     def forward(
         self,
```
```diff
@ -1,8 +1,6 @@
 import json
 import math
-import os
 import random
-import tempfile
 from datetime import datetime
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@ -29,16 +27,6 @@ from torch.amp.grad_scaler import GradScaler
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
 
-# Try to import DataLoader2 for better streaming dataset support
-try:
-    from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
-
-    DATA_LOADER2_AVAILABLE = True
-except ImportError:
-    DATA_LOADER2_AVAILABLE = False
-    DataLoader2 = None
-    MultiProcessingReadingService = None
-
 from .dataset import PinyinInputDataset
 
 # Import the model and data
@ -770,34 +758,6 @@ def create_dataloader(
     Returns:
         The data loader instance
     """
-    if (
-        DATA_LOADER2_AVAILABLE
-        and DataLoader2 is not None
-        and MultiProcessingReadingService is not None
-    ):
-        try:
-            # DataLoader2 configuration, tuned for streaming datasets
-            reading_service = MultiProcessingReadingService(
-                num_workers=num_workers,
-                prefetch_factor=2,  # less prefetching to avoid memory pressure
-                persistent_workers=True,
-                pin_memory=pin_memory,
-                worker_init_fn=worker_init_fn,
-            )
-
-            dataloader = DataLoader2(
-                dataset,
-                reading_service=reading_service,
-                batch_size=batch_size,
-                collate_fn=collate_fn,
-                shuffle=shuffle,
-            )
-            logger.info(f"✅ Created the data loader with DataLoader2, workers: {num_workers}")
-            return dataloader
-        except Exception as e:
-            logger.warning(f"⚠️ DataLoader2 creation failed: {e}; falling back to the standard DataLoader")
-
-    # Fall back to the standard DataLoader
     logger.info(f"📊 Using the standard DataLoader, workers: {num_workers}")
     dataloader = DataLoader(
         dataset,
@ -1012,7 +972,7 @@ def train(
         max_seq_length=max_seq_len,
         text_field="text",
         py_style_weight=(9, 2, 1),
-        shuffle_buffer_size=1000,
+        shuffle_buffer_size=50000,
         length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
     )
```
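The `shuffle_buffer_size` bump from 1000 to 50000 concerns buffered shuffling of a streaming dataset: the stream is shuffled by holding a fixed-size buffer and emitting a random buffered element as each new item arrives, so a larger buffer gives a more uniform shuffle at the cost of memory. A minimal sketch of the idea (my illustration of the general technique, not this project's dataset code):

```python
import random

def buffered_shuffle(iterable, buffer_size, seed=0):
    """Approximately shuffle a stream using a fixed-size buffer.
    Larger buffer_size approaches a true uniform shuffle."""
    rng = random.Random(seed)
    buf = []
    for item in iterable:
        if len(buf) < buffer_size:
            buf.append(item)          # fill phase
        else:
            i = rng.randrange(buffer_size)
            yield buf[i]              # emit a random buffered element
            buf[i] = item             # and replace it with the new one
    rng.shuffle(buf)                  # drain the remaining buffer
    yield from buf
```

Every input element is yielded exactly once; only the order is randomized, and only within roughly a `buffer_size`-wide window of the stream.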
@ -0,0 +1,487 @@

```python
#!/usr/bin/env python3
"""
Script for testing torch.compile compatibility.
Verifies that the changes in components.py are compatible with torch.compile.
"""

import os
import sys
import time

import torch
import torch.nn as nn

# Add the src directory to the path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.model.components import (
    AttentionPooling,
    ContextEncoder,
    CrossAttentionFusion,
    Expert,
    MoELayer,
    ResidualBlock,
    SlotMemory,
)


def test_attention_pooling():
    """Test the AttentionPooling module."""
    print("=" * 60)
    print("Testing the AttentionPooling module")

    batch_size = 4
    seq_len = 10
    hidden_size = 512

    # Create the module
    attn_pool = AttentionPooling(hidden_size)

    # Test data
    x = torch.randn(batch_size, seq_len, hidden_size)
    mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    mask[0, 5:] = 0  # mask part of the first sample
    token_type_ids = torch.randint(0, 3, (batch_size, seq_len))

    # Uncompiled version
    output = attn_pool(x, mask=mask, token_type_ids=token_type_ids)
    print(f"✓ AttentionPooling output shape: {output.shape}")
    assert output.shape == (batch_size, hidden_size)

    # Compiled version
    try:
        compiled_attn_pool = torch.compile(attn_pool, mode="reduce-overhead")
        compiled_output = compiled_attn_pool(
            x, mask=mask, token_type_ids=token_type_ids
        )

        # Check that the outputs match
        diff = torch.abs(output - compiled_output).max().item()
        print(f"✓ Max output difference after compilation: {diff:.6f}")
        assert diff < 1e-4, f"Output difference too large: {diff}"
        print("✓ AttentionPooling passed the torch.compile test")
    except Exception as e:
        print(f"⚠ AttentionPooling torch.compile test failed: {e}")
        raise
```
```python
def test_moe_layer():
    """Test the MoELayer module."""
    print("\n" + "=" * 60)
    print("Testing the MoELayer module")

    batch_size = 2
    seq_len = 8
    dim = 512
    num_experts = 20
    top_k = 2

    # Create the module
    moe = MoELayer(dim=dim, num_experts=num_experts, top_k=top_k)

    # Test data
    x = torch.randn(batch_size, seq_len, dim)

    # Uncompiled version
    output = moe(x)
    print(f"✓ MoELayer output shape: {output.shape}")
    assert output.shape == (batch_size, seq_len, dim)

    # Inspect the gating weights
    gates = moe.gate(x.view(-1, dim))
    topk_vals, topk_indices = torch.topk(gates, top_k, dim=-1)
    print(f"✓ Gate shape: {gates.shape}, top-k index shape: {topk_indices.shape}")

    # Compiled version
    try:
        compiled_moe = torch.compile(moe, mode="reduce-overhead")

        # Warm-up
        for _ in range(3):
            _ = compiled_moe(x)

        compiled_output = compiled_moe(x)

        # Check that the outputs match
        diff = torch.abs(output - compiled_output).max().item()
        print(f"✓ Max output difference after compilation: {diff:.6f}")
        assert diff < 1e-4, f"Output difference too large: {diff}"
        print("✓ MoELayer passed the torch.compile test")

        # Benchmark
        n_iter = 50
        print(f"\nBenchmark ({n_iter} iterations):")

        # Uncompiled version
        torch.cuda.synchronize() if torch.cuda.is_available() else None
        start = time.time()
        for _ in range(n_iter):
            _ = moe(x)
        torch.cuda.synchronize() if torch.cuda.is_available() else None
        base_time = time.time() - start

        # Compiled version
        torch.cuda.synchronize() if torch.cuda.is_available() else None
        start = time.time()
        for _ in range(n_iter):
            _ = compiled_moe(x)
        torch.cuda.synchronize() if torch.cuda.is_available() else None
        compiled_time = time.time() - start

        print(f"  uncompiled: {base_time:.4f} s")
        print(f"  compiled:   {compiled_time:.4f} s")
        if compiled_time > 0:
            speedup = base_time / compiled_time
            print(f"  speedup: {speedup:.2f}x")

    except Exception as e:
        print(f"⚠ MoELayer torch.compile test failed: {e}")
        import traceback

        traceback.print_exc()
        raise
```
```python
def test_cross_attention_fusion():
    """Test the CrossAttentionFusion module."""
    print("\n" + "=" * 60)
    print("Testing the CrossAttentionFusion module")

    batch_size = 2
    num_slots = 8
    ctx_len = 16
    dim = 512

    # Create the module
    cross_attn = CrossAttentionFusion(dim=dim, n_heads=4)

    # Test data
    slots_S = torch.randn(batch_size, num_slots, dim)
    context_H = torch.randn(batch_size, ctx_len, dim)
    context_mask = torch.ones(batch_size, ctx_len, dtype=torch.long)
    context_mask[0, 10:] = 0  # mask part of the first sample

    # Uncompiled version
    output = cross_attn(slots_S, context_H, context_mask=context_mask)
    print(f"✓ CrossAttentionFusion output shape: {output.shape}")
    assert output.shape == (batch_size, num_slots, dim)

    # Compiled version
    try:
        compiled_cross_attn = torch.compile(cross_attn, mode="reduce-overhead")
        compiled_output = compiled_cross_attn(
            slots_S, context_H, context_mask=context_mask
        )

        # Check that the outputs match
        diff = torch.abs(output - compiled_output).max().item()
        print(f"✓ Max output difference after compilation: {diff:.6f}")
        assert diff < 1e-4, f"Output difference too large: {diff}"
        print("✓ CrossAttentionFusion passed the torch.compile test")
    except Exception as e:
        print(f"⚠ CrossAttentionFusion torch.compile test failed: {e}")
        raise
```
```python
def test_context_encoder():
    """Test the ContextEncoder module."""
    print("\n" + "=" * 60)
    print("Testing the ContextEncoder module")

    batch_size = 2
    seq_len = 32
    vocab_size = 1000
    pinyin_vocab_size = 30
    dim = 512

    # Create the module
    context_encoder = ContextEncoder(
        vocab_size=vocab_size,
        pinyin_vocab_size=pinyin_vocab_size,
        dim=dim,
        n_layers=2,  # fewer layers for testing
        n_heads=4,
        max_len=seq_len,
    )

    # Test data
    text_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
    pinyin_ids = torch.randint(
        0, pinyin_vocab_size, (batch_size, 24)
    )  # pinyin_ids has a fixed length of 24
    mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    mask[0, 20:] = 0  # mask part of the first sample

    # Uncompiled version
    output = context_encoder(text_ids, pinyin_ids, mask=mask)
    print(f"✓ ContextEncoder output shape: {output.shape}")
    assert output.shape == (batch_size, seq_len, dim)

    # Compiled version (ContextEncoder wraps an external model and may not be fully compatible)
    try:
        compiled_context_encoder = torch.compile(
            context_encoder, mode="reduce-overhead"
        )
        compiled_output = compiled_context_encoder(text_ids, pinyin_ids, mask=mask)

        # Check that the outputs match
        diff = torch.abs(output - compiled_output).max().item()
        print(f"✓ Max output difference after compilation: {diff:.6f}")
        if diff < 1e-3:  # ContextEncoder tolerates slightly lower precision
            print("✓ ContextEncoder passed the torch.compile test")
        else:
            print(f"⚠ Large ContextEncoder output difference: {diff:.6f} (possibly due to the external model)")
    except Exception as e:
        print(
            f"⚠ ContextEncoder torch.compile test failed: {e} (may be expected, since it wraps an external model)"
        )
```
def test_full_model_compile():
    """Test compiling the full model"""
    print("\n" + "=" * 60)
    print("Testing full-model compilation")

    # Import the full model
    from src.model.model import InputMethodEngine

    batch_size = 2
    vocab_size = 10019
    pinyin_vocab_size = 30
    dim = 512
    num_slots = 8
    seq_len = 32

    # Create the model
    model = InputMethodEngine(
        vocab_size=vocab_size,
        pinyin_vocab_size=pinyin_vocab_size,
        dim=dim,
        num_slots=num_slots,
        n_layers=2,  # fewer layers for testing
        n_heads=4,
        num_experts=20,
        max_seq_len=seq_len,
        compile=False,  # control compilation manually
    )

    # Test data
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
    token_type_ids = torch.randint(0, 2, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    attention_mask[0, 20:] = 0
    pinyin_ids = torch.randint(0, pinyin_vocab_size, (batch_size, 24))
    history_slot_ids = torch.randint(0, vocab_size, (batch_size, num_slots))

    # Uncompiled version
    with torch.no_grad():
        output = model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            pinyin_ids=pinyin_ids,
            history_slot_ids=history_slot_ids,
        )

    print(f"✓ Full model output shape: {output.shape}")
    assert output.shape == (batch_size, vocab_size)

    # Compiled version
    try:
        # Create a fresh model for the compilation test
        model_for_compile = InputMethodEngine(
            vocab_size=vocab_size,
            pinyin_vocab_size=pinyin_vocab_size,
            dim=dim,
            num_slots=num_slots,
            n_layers=2,
            n_heads=4,
            num_experts=20,
            max_seq_len=seq_len,
            compile=False,
        )

        # Compile manually
        compiled_model = torch.compile(model_for_compile, mode="reduce-overhead")

        # Warm up
        for _ in range(3):
            _ = compiled_model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                pinyin_ids=pinyin_ids,
                history_slot_ids=history_slot_ids,
            )

        with torch.no_grad():
            compiled_output = compiled_model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                pinyin_ids=pinyin_ids,
                history_slot_ids=history_slot_ids,
            )

        # Check that the outputs match
        diff = torch.abs(output - compiled_output).max().item()
        print(f"✓ Max output difference after compilation: {diff:.6f}")

        # The full model may need a looser tolerance
        if diff < 1e-3:
            print("✓ Full model passed the torch.compile test")
        else:
            print(f"⚠ Full model output difference is large: {diff:.6f}")

        # Benchmark
        n_iter = 30
        print(f"\nFull model benchmark ({n_iter} iterations):")

        # Uncompiled version
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.time()
        with torch.no_grad():
            for _ in range(n_iter):
                _ = model(
                    input_ids=input_ids,
                    token_type_ids=token_type_ids,
                    attention_mask=attention_mask,
                    pinyin_ids=pinyin_ids,
                    history_slot_ids=history_slot_ids,
                )
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        base_time = time.time() - start

        # Compiled version
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.time()
        with torch.no_grad():
            for _ in range(n_iter):
                _ = compiled_model(
                    input_ids=input_ids,
                    token_type_ids=token_type_ids,
                    attention_mask=attention_mask,
                    pinyin_ids=pinyin_ids,
                    history_slot_ids=history_slot_ids,
                )
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        compiled_time = time.time() - start

        print(f"  Uncompiled: {base_time:.4f} s")
        print(f"  Compiled:   {compiled_time:.4f} s")
        if compiled_time > 0:
            speedup = base_time / compiled_time
            print(f"  Speedup: {speedup:.2f}x")

    except Exception as e:
        print(f"⚠ Full model torch.compile test failed: {e}")
        import traceback

        traceback.print_exc()

def check_compile_issues():
    """Check for code patterns that may break torch.compile"""
    print("\n" + "=" * 60)
    print("Checking for compilation issues")

    issues = []

    # Look for potential problems in components.py
    with open("src/model/components.py", "r") as f:
        content = f.read()

    # Check for float('-inf')
    if "float('-inf')" in content:
        issues.append("❌ Found float('-inf'); replace it with -1e9")

    # Check for .item() calls
    if ".item()" in content and "def forward" in content:
        # Count .item() calls inside forward methods
        lines = content.split("\n")
        in_forward = False
        item_calls = []
        for i, line in enumerate(lines):
            if "def forward" in line:
                in_forward = True
            elif in_forward and "def " in line and "def forward" not in line:
                in_forward = False
            if in_forward and ".item()" in line:
                item_calls.append((i + 1, line.strip()))

        if item_calls:
            issues.append(
                f"❌ Found {len(item_calls)} .item() call(s) inside forward methods:"
            )
            for line_num, line in item_calls[:3]:  # show the first 3
                issues.append(f"  line {line_num}: {line}")

    # Check for dynamic control flow
    dynamic_patterns = [
        "if not mask.any():",
        "if mask is not None:",
        "if token_mask.any():",
        "continue",
    ]
    for pattern in dynamic_patterns:
        if pattern in content:
            # Note: this does not verify the pattern is inside a forward method
            issues.append(f"⚠ Found dynamic control flow: {pattern}")

    if not issues:
        print("✓ No obvious compilation issues found")
    else:
        print("Found the following issues that may affect compilation:")
        for issue in issues:
            print(issue)

def main():
    """Main test entry point"""
    print("=" * 70)
    print("torch.compile compatibility tests")
    print("=" * 70)

    # Set up the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    print(f"PyTorch version: {torch.__version__}")

    # Check whether torch.compile is available
    if hasattr(torch, "compile"):
        print("✓ torch.compile is available")
    else:
        print("❌ torch.compile is not available; PyTorch 2.0+ is required")
        return

    # Run the tests
    try:
        # Static code checks first
        check_compile_issues()

        # Module-level tests
        test_attention_pooling()
        test_moe_layer()
        test_cross_attention_fusion()
        test_context_encoder()

        # Full-model test
        test_full_model_compile()

        print("\n" + "=" * 70)
        print("✅ All tests completed!")

    except Exception as e:
        print(f"\n❌ Test failed: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
15 uv.lock

@@ -2305,7 +2305,6 @@ dependencies = [
     { name = "streamlit" },
     { name = "tensorboard" },
     { name = "torch" },
-    { name = "torchdata" },
     { name = "transformers" },
     { name = "typer" },
 ]
@@ -2333,7 +2332,6 @@ requires-dist = [
     { name = "streamlit", specifier = ">=1.56.0" },
     { name = "tensorboard", specifier = ">=2.20.0" },
     { name = "torch", specifier = ">=2.10.0" },
-    { name = "torchdata", specifier = ">=0.11.0" },
     { name = "transformers", specifier = "==5.1.0" },
     { name = "typer", specifier = ">=0.21.1" },
 ]
@@ -2473,19 +2471,6 @@ wheels = [
     { url = "https://mirrors.aliyun.com/pypi/packages/cf/bf/c8d12a2c86dbfd7f40fb2f56fbf5a505ccf2d9ce131eb559dfc7c51e1a04/torch-2.11.0-cp314-cp314t-win_amd64.whl", hash = "sha256:b2a43985ff5ef6ddd923bbcf99943e5f58059805787c5c9a2622bf05ca2965b0" },
 ]
-
-[[package]]
-name = "torchdata"
-version = "0.11.0"
-source = { registry = "https://mirrors.aliyun.com/pypi/simple/" }
-dependencies = [
-    { name = "requests" },
-    { name = "torch" },
-    { name = "urllib3" },
-]
-wheels = [
-    { url = "https://mirrors.aliyun.com/pypi/packages/95/d4/af694ef718aedbe95a72760ab9ff7a6a7a44ace2d7f70c27bfeb67c5c503/torchdata-0.11.0-py3-none-any.whl", hash = "sha256:52b940fbbe0e00fb21cabddf528449d1bec5bfb0d0823b7487b15f951658ee33" },
-]
 
 [[package]]
 name = "tornado"
 version = "6.5.5"