#!/usr/bin/env python3
"""Script for testing torch.compile compatibility.

Used to verify that the modifications in components.py are compatible
with torch.compile.
"""

import os
import sys
import time

import torch
import torch.nn as nn

# Add the project root (parent of the src directory) to the import path.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.model.components import (
    AttentionPooling,
    ContextEncoder,
    CrossAttentionFusion,
    Expert,
    MoELayer,
    ResidualBlock,
    SlotMemory,
)
|
||
|
||
|
||
def test_attention_pooling():
    """Run AttentionPooling eagerly and under torch.compile and compare outputs."""
    print("=" * 60)
    print("测试 AttentionPooling 模块")

    n_batch, n_tokens, width = 4, 10, 512

    # Module under test.
    pool = AttentionPooling(width)

    # Synthetic inputs; the first sample has its tail positions masked out.
    x = torch.randn(n_batch, n_tokens, width)
    mask = torch.ones(n_batch, n_tokens, dtype=torch.long)
    mask[0, 5:] = 0
    token_type_ids = torch.randint(0, 3, (n_batch, n_tokens))

    # Eager (uncompiled) forward pass.
    output = pool(x, mask=mask, token_type_ids=token_type_ids)
    print(f"✓ AttentionPooling 输出形状: {output.shape}")
    assert output.shape == (n_batch, width)

    # Compiled forward pass must match the eager result.
    try:
        compiled_pool = torch.compile(pool, mode="reduce-overhead")
        compiled_out = compiled_pool(x, mask=mask, token_type_ids=token_type_ids)

        diff = torch.abs(output - compiled_out).max().item()
        print(f"✓ 编译前后输出差异: {diff:.6f}")
        assert diff < 1e-4, f"输出差异过大: {diff}"
        print("✓ AttentionPooling 通过 torch.compile 测试")
    except Exception as e:
        print(f"⚠ AttentionPooling torch.compile 测试失败: {e}")
        raise
|
||
|
||
|
||
def test_moe_layer():
    """Test the MoELayer module: correctness under torch.compile plus a timing run.

    Compares eager vs. compiled outputs, then measures wall-clock time for
    both over a fixed number of iterations.
    """
    print("\n" + "=" * 60)
    print("测试 MoELayer 模块")

    def _sync():
        # Make CUDA timings meaningful; no-op on CPU. (Replaces the former
        # `torch.cuda.synchronize() if torch.cuda.is_available() else None`
        # conditional-expression statements, matching the style used in
        # test_full_model_compile.)
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    batch_size = 2
    seq_len = 8
    dim = 512
    num_experts = 20
    top_k = 2

    # Module under test.
    moe = MoELayer(dim=dim, num_experts=num_experts, top_k=top_k)

    # Synthetic input.
    x = torch.randn(batch_size, seq_len, dim)

    # Eager (uncompiled) forward pass.
    output = moe(x)
    print(f"✓ MoELayer 输出形状: {output.shape}")
    assert output.shape == (batch_size, seq_len, dim)

    # Inspect the gating weights / top-k routing indices.
    gates = moe.gate(x.view(-1, dim))
    _, topk_indices = torch.topk(gates, top_k, dim=-1)  # values unused
    print(f"✓ 门控权重形状: {gates.shape}, top-k 索引形状: {topk_indices.shape}")

    try:
        compiled_moe = torch.compile(moe, mode="reduce-overhead")

        # Warm up so compilation cost is excluded from the checks below.
        for _ in range(3):
            _ = compiled_moe(x)

        compiled_output = compiled_moe(x)

        # Compiled output must match the eager output.
        diff = torch.abs(output - compiled_output).max().item()
        print(f"✓ 编译前后输出差异: {diff:.6f}")
        assert diff < 1e-4, f"输出差异过大: {diff}"
        print("✓ MoELayer 通过 torch.compile 测试")

        # Timing comparison.
        n_iter = 50
        print(f"\n性能测试 ({n_iter} 次迭代):")

        # Uncompiled baseline.
        _sync()
        start = time.time()
        for _ in range(n_iter):
            _ = moe(x)
        _sync()
        base_time = time.time() - start

        # Compiled version.
        _sync()
        start = time.time()
        for _ in range(n_iter):
            _ = compiled_moe(x)
        _sync()
        compiled_time = time.time() - start

        print(f"  未编译: {base_time:.4f} 秒")
        print(f"  已编译: {compiled_time:.4f} 秒")
        if compiled_time > 0:
            speedup = base_time / compiled_time
            print(f"  加速比: {speedup:.2f}x")

    except Exception as e:
        print(f"⚠ MoELayer torch.compile 测试失败: {e}")
        import traceback

        traceback.print_exc()
        raise
|
||
|
||
|
||
def test_cross_attention_fusion():
    """Run CrossAttentionFusion eagerly and compiled; assert outputs agree."""
    print("\n" + "=" * 60)
    print("测试 CrossAttentionFusion 模块")

    n_batch, n_slots, n_ctx, width = 2, 8, 16, 512

    # Module under test.
    fusion = CrossAttentionFusion(dim=width, n_heads=4)

    # Synthetic slots and context; the first sample's context tail is masked.
    slots_S = torch.randn(n_batch, n_slots, width)
    context_H = torch.randn(n_batch, n_ctx, width)
    context_mask = torch.ones(n_batch, n_ctx, dtype=torch.long)
    context_mask[0, 10:] = 0

    # Eager forward pass.
    output = fusion(slots_S, context_H, context_mask=context_mask)
    print(f"✓ CrossAttentionFusion 输出形状: {output.shape}")
    assert output.shape == (n_batch, n_slots, width)

    # Compiled forward pass must match.
    try:
        compiled_fusion = torch.compile(fusion, mode="reduce-overhead")
        compiled_out = compiled_fusion(slots_S, context_H, context_mask=context_mask)

        diff = torch.abs(output - compiled_out).max().item()
        print(f"✓ 编译前后输出差异: {diff:.6f}")
        assert diff < 1e-4, f"输出差异过大: {diff}"
        print("✓ CrossAttentionFusion 通过 torch.compile 测试")
    except Exception as e:
        print(f"⚠ CrossAttentionFusion torch.compile 测试失败: {e}")
        raise
|
||
|
||
|
||
def test_context_encoder():
    """Run ContextEncoder eagerly, then best-effort under torch.compile."""
    print("\n" + "=" * 60)
    print("测试 ContextEncoder 模块")

    n_batch, n_tokens = 2, 32
    n_vocab, n_pinyin = 1000, 30
    width = 512

    # Module under test (shallow stack to keep the test fast).
    encoder = ContextEncoder(
        vocab_size=n_vocab,
        pinyin_vocab_size=n_pinyin,
        dim=width,
        n_layers=2,
        n_heads=4,
        max_len=n_tokens,
    )

    # Synthetic inputs; pinyin_ids use a fixed length of 24, and the first
    # sample's tail positions are masked out.
    text_ids = torch.randint(0, n_vocab, (n_batch, n_tokens))
    pinyin_ids = torch.randint(0, n_pinyin, (n_batch, 24))
    mask = torch.ones(n_batch, n_tokens, dtype=torch.long)
    mask[0, 20:] = 0

    # Eager forward pass.
    output = encoder(text_ids, pinyin_ids, mask=mask)
    print(f"✓ ContextEncoder 输出形状: {output.shape}")
    assert output.shape == (n_batch, n_tokens, width)

    # Compiled pass: the encoder may wrap external models, so failures and
    # larger numeric drift are tolerated here rather than raised.
    try:
        compiled_encoder = torch.compile(encoder, mode="reduce-overhead")
        compiled_out = compiled_encoder(text_ids, pinyin_ids, mask=mask)

        diff = torch.abs(output - compiled_out).max().item()
        print(f"✓ 编译前后输出差异: {diff:.6f}")
        if diff < 1e-3:
            print("✓ ContextEncoder 通过 torch.compile 测试")
        else:
            print(f"⚠ ContextEncoder 输出差异较大: {diff:.6f} (可能因外部模型)")
    except Exception as e:
        print(
            f"⚠ ContextEncoder torch.compile 测试失败: {e} (可能是正常现象,因为包含外部模型)"
        )
|
||
|
||
|
||
def test_full_model_compile():
    """Test compiling the full InputMethodEngine model.

    Builds two identically-configured models, compares eager vs. compiled
    outputs under ``torch.no_grad()``, then times both over a fixed number
    of iterations. Compile failures are reported, not raised (best-effort).
    """
    print("\n" + "=" * 60)
    print("测试完整模型编译")

    # Import here so the rest of the script works even if the full model
    # module is unavailable.
    from src.model.model import InputMethodEngine

    batch_size = 2
    vocab_size = 10019
    pinyin_vocab_size = 30
    dim = 512
    num_slots = 8
    seq_len = 32

    # Shared config for both model instances (deduplicates the two
    # previously copy-pasted constructor calls).
    model_kwargs = dict(
        vocab_size=vocab_size,
        pinyin_vocab_size=pinyin_vocab_size,
        dim=dim,
        num_slots=num_slots,
        n_layers=2,  # fewer layers to keep the test fast
        n_heads=4,
        num_experts=20,
        max_seq_len=seq_len,
        compile=False,  # compilation is controlled manually below
    )
    model = InputMethodEngine(**model_kwargs)

    # Shared inputs (deduplicates the five previously copy-pasted call
    # sites); the first sample's attention tail is masked out.
    inputs = dict(
        input_ids=torch.randint(0, vocab_size, (batch_size, seq_len)),
        token_type_ids=torch.randint(0, 2, (batch_size, seq_len)),
        attention_mask=torch.ones(batch_size, seq_len, dtype=torch.long),
        pinyin_ids=torch.randint(0, pinyin_vocab_size, (batch_size, 24)),
        history_slot_ids=torch.randint(0, vocab_size, (batch_size, num_slots)),
    )
    inputs["attention_mask"][0, 20:] = 0

    # Eager (uncompiled) forward pass.
    with torch.no_grad():
        output = model(**inputs)

    print(f"✓ 完整模型输出形状: {output.shape}")
    assert output.shape == (batch_size, vocab_size)

    try:
        # A fresh model for the compile test.
        model_for_compile = InputMethodEngine(**model_kwargs)
        compiled_model = torch.compile(model_for_compile, mode="reduce-overhead")

        # BUGFIX: warm up under no_grad. The original warmed up in grad mode
        # while every measured call ran under no_grad, so reduce-overhead
        # cached graphs for the wrong mode and the timed calls recompiled.
        with torch.no_grad():
            for _ in range(3):
                _ = compiled_model(**inputs)

            compiled_output = compiled_model(**inputs)

        # Compare eager vs. compiled outputs (looser tolerance for the
        # full model).
        diff = torch.abs(output - compiled_output).max().item()
        print(f"✓ 编译前后输出差异: {diff:.6f}")

        if diff < 1e-3:
            print("✓ 完整模型通过 torch.compile 测试")
        else:
            print(f"⚠ 完整模型输出差异较大: {diff:.6f}")

        # Timing comparison.
        n_iter = 30
        print(f"\n完整模型性能测试 ({n_iter} 次迭代):")

        def _sync():
            # Make CUDA timings meaningful; no-op on CPU.
            if torch.cuda.is_available():
                torch.cuda.synchronize()

        # Uncompiled baseline.
        _sync()
        start = time.time()
        with torch.no_grad():
            for _ in range(n_iter):
                _ = model(**inputs)
        _sync()
        base_time = time.time() - start

        # Compiled version.
        _sync()
        start = time.time()
        with torch.no_grad():
            for _ in range(n_iter):
                _ = compiled_model(**inputs)
        _sync()
        compiled_time = time.time() - start

        print(f"  未编译: {base_time:.4f} 秒")
        print(f"  已编译: {compiled_time:.4f} 秒")
        if compiled_time > 0:
            speedup = base_time / compiled_time
            print(f"  加速比: {speedup:.2f}x")

    except Exception as e:
        print(f"⚠ 完整模型 torch.compile 测试失败: {e}")
        import traceback

        traceback.print_exc()
|
||
|
||
|
||
def check_compile_issues(path="src/model/components.py"):
    """Statically scan a source file for patterns that hurt torch.compile.

    Looks for ``float('-inf')`` constants, ``.item()`` calls inside
    ``forward`` methods (each forces a sync / graph break), and
    data-dependent control flow. Prints a report and returns the findings.

    Args:
        path: Source file to scan. Defaults to the components module, so
            existing callers are unaffected (generalized from a hard-coded
            path).

    Returns:
        list[str]: Human-readable issue descriptions; empty if clean.
    """
    print("\n" + "=" * 60)
    print("检查编译问题")

    issues = []

    # Explicit encoding: the scanned file may contain non-ASCII comments,
    # and the platform default encoding is not guaranteed to be UTF-8.
    with open(path, "r", encoding="utf-8") as f:
        content = f.read()

    # BUGFIX: also match the double-quoted spelling float("-inf").
    if "float('-inf')" in content or 'float("-inf")' in content:
        issues.append("❌ 发现 float('-inf'),应替换为 -1e9")

    # Count .item() calls inside forward methods. NOTE(review): this is a
    # line-based heuristic — it assumes a new `def ` ends the forward body
    # and ignores nesting/indentation.
    if ".item()" in content and "def forward" in content:
        in_forward = False
        item_calls = []
        for i, line in enumerate(content.split("\n")):
            if "def forward" in line:
                in_forward = True
            elif in_forward and "def " in line and "def forward" not in line:
                in_forward = False
            if in_forward and ".item()" in line:
                item_calls.append((i + 1, line.strip()))

        if item_calls:
            issues.append(
                f"❌ 在 forward 方法中发现 {len(item_calls)} 个 .item() 调用:"
            )
            for line_num, line in item_calls[:3]:  # show at most the first 3
                issues.append(f"  第 {line_num} 行: {line}")

    # Data-dependent control flow triggers graph breaks / recompilation.
    dynamic_patterns = [
        "if not mask.any():",
        "if mask is not None:",
        "if token_mask.any():",
        "continue",
    ]
    for pattern in dynamic_patterns:
        if pattern in content:
            issues.append(f"⚠ 发现动态控制流: {pattern}")

    if not issues:
        print("✓ 未发现明显的编译问题")
    else:
        print("发现以下可能影响编译的问题:")
        for issue in issues:
            print(issue)

    # Returning the findings (previously None) lets callers assert on them;
    # existing callers that ignore the return value are unaffected.
    return issues
|
||
|
||
|
||
def main():
    """Run all torch.compile compatibility checks; exit with status 1 on failure."""
    print("=" * 70)
    print("torch.compile 兼容性测试")
    print("=" * 70)

    # Report the execution environment.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")
    print(f"PyTorch 版本: {torch.__version__}")

    # torch.compile requires PyTorch 2.0+.
    if hasattr(torch, "compile"):
        # BUGFIX: dropped the pointless f-prefix on a placeholder-free literal.
        print("✓ torch.compile 可用")
    else:
        print("❌ torch.compile 不可用,需要 PyTorch 2.0+")
        return

    try:
        # Static scan first, then per-module tests, then the full model.
        check_compile_issues()

        test_attention_pooling()
        test_moe_layer()
        test_cross_attention_fusion()
        test_context_encoder()

        test_full_model_compile()

        print("\n" + "=" * 70)
        print("✅ 所有测试完成!")

    except Exception as e:
        print(f"\n❌ 测试失败: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)
|
||
|
||
|
||
# Script entry point: run all torch.compile compatibility checks.
if __name__ == "__main__":
    main()
|