SUimeModelTraner/test_compile.py

#!/usr/bin/env python3
"""
测试 torch.compile 兼容性的脚本
用于验证 components.py 中的修改是否与 torch.compile 兼容
"""
import os
import sys
import time
import torch
import torch.nn as nn
# Make the directory containing src/ importable
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.model.components import (
    AttentionPooling,
    ContextEncoder,
    CrossAttentionFusion,
    Expert,
    MoELayer,
    ResidualBlock,
    SlotMemory,
)


def test_attention_pooling():
    """Test the AttentionPooling module."""
    print("=" * 60)
    print("Testing the AttentionPooling module")
    batch_size = 4
    seq_len = 10
    hidden_size = 512

    # Build the module
    attn_pool = AttentionPooling(hidden_size)
    attn_pool.eval()  # ensure deterministic forward passes for the comparison

    # Test data
    x = torch.randn(batch_size, seq_len, hidden_size)
    mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    mask[0, 5:] = 0  # mask out part of the first sample
    token_type_ids = torch.randint(0, 3, (batch_size, seq_len))

    # Uncompiled (eager) forward pass
    output = attn_pool(x, mask=mask, token_type_ids=token_type_ids)
    print(f"✓ AttentionPooling output shape: {output.shape}")
    assert output.shape == (batch_size, hidden_size)
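    # mode="reduce-overhead" (documented PyTorch behavior) targets small graphs
    # that are called many times: where possible it replays the compiled graph
    # with CUDA graphs to cut per-call launch overhead, without changing the
    # numerical semantics.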
    # Compiled version
    try:
        compiled_attn_pool = torch.compile(attn_pool, mode="reduce-overhead")
        compiled_output = compiled_attn_pool(
            x, mask=mask, token_type_ids=token_type_ids
        )
        # Check that the outputs match
        diff = torch.abs(output - compiled_output).max().item()
        print(f"✓ Max difference between eager and compiled output: {diff:.6f}")
        assert diff < 1e-4, f"Output difference too large: {diff}"
        print("✓ AttentionPooling passed the torch.compile test")
    except Exception as e:
        print(f"⚠ AttentionPooling torch.compile test failed: {e}")
        raise


def test_moe_layer():
    """Test the MoELayer module."""
    print("\n" + "=" * 60)
    print("Testing the MoELayer module")
    batch_size = 2
    seq_len = 8
    dim = 512
    num_experts = 20
    top_k = 2

    # Build the module
    moe = MoELayer(dim=dim, num_experts=num_experts, top_k=top_k)
    moe.eval()  # ensure deterministic forward passes for the comparison

    # Test data
    x = torch.randn(batch_size, seq_len, dim)

    # Uncompiled (eager) forward pass
    output = moe(x)
    print(f"✓ MoELayer output shape: {output.shape}")
    assert output.shape == (batch_size, seq_len, dim)

    # Inspect the gating weights
    gates = moe.gate(x.view(-1, dim))
    topk_vals, topk_indices = torch.topk(gates, top_k, dim=-1)
    print(f"✓ Gate weight shape: {gates.shape}, top-k index shape: {topk_indices.shape}")
    # Compiled version
    try:
        compiled_moe = torch.compile(moe, mode="reduce-overhead")
        # Warm up: the first calls trigger compilation and must not be timed
        for _ in range(3):
            _ = compiled_moe(x)
        compiled_output = compiled_moe(x)
        # Check that the outputs match
        diff = torch.abs(output - compiled_output).max().item()
        print(f"✓ Max difference between eager and compiled output: {diff:.6f}")
        assert diff < 1e-4, f"Output difference too large: {diff}"
        print("✓ MoELayer passed the torch.compile test")
        # Performance test
        n_iter = 50
        print(f"\nPerformance test ({n_iter} iterations):")
        # Eager version
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.time()
        for _ in range(n_iter):
            _ = moe(x)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        base_time = time.time() - start
        # Compiled version
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.time()
        for _ in range(n_iter):
            _ = compiled_moe(x)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        compiled_time = time.time() - start
        print(f"  Eager:    {base_time:.4f}s")
        print(f"  Compiled: {compiled_time:.4f}s")
        if compiled_time > 0:
            speedup = base_time / compiled_time
            print(f"  Speedup: {speedup:.2f}x")
    except Exception as e:
        print(f"⚠ MoELayer torch.compile test failed: {e}")
        import traceback

        traceback.print_exc()
        raise


def test_cross_attention_fusion():
    """Test the CrossAttentionFusion module."""
    print("\n" + "=" * 60)
    print("Testing the CrossAttentionFusion module")
    batch_size = 2
    num_slots = 8
    ctx_len = 16
    dim = 512

    # Build the module
    cross_attn = CrossAttentionFusion(dim=dim, n_heads=4)
    cross_attn.eval()  # ensure deterministic forward passes for the comparison

    # Test data
    slots_S = torch.randn(batch_size, num_slots, dim)
    context_H = torch.randn(batch_size, ctx_len, dim)
    context_mask = torch.ones(batch_size, ctx_len, dtype=torch.long)
    context_mask[0, 10:] = 0  # mask out part of the first sample's context

    # Uncompiled (eager) forward pass
    output = cross_attn(slots_S, context_H, context_mask=context_mask)
    print(f"✓ CrossAttentionFusion output shape: {output.shape}")
    assert output.shape == (batch_size, num_slots, dim)
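    # The output has one vector per slot, consistent with the slots acting as
    # attention queries over the context (inferred from the shapes here, not
    # from reading the CrossAttentionFusion source).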
    # Compiled version
    try:
        compiled_cross_attn = torch.compile(cross_attn, mode="reduce-overhead")
        compiled_output = compiled_cross_attn(
            slots_S, context_H, context_mask=context_mask
        )
        # Check that the outputs match
        diff = torch.abs(output - compiled_output).max().item()
        print(f"✓ Max difference between eager and compiled output: {diff:.6f}")
        assert diff < 1e-4, f"Output difference too large: {diff}"
        print("✓ CrossAttentionFusion passed the torch.compile test")
    except Exception as e:
        print(f"⚠ CrossAttentionFusion torch.compile test failed: {e}")
        raise


def test_context_encoder():
    """Test the ContextEncoder module."""
    print("\n" + "=" * 60)
    print("Testing the ContextEncoder module")
    batch_size = 2
    seq_len = 32
    vocab_size = 1000
    pinyin_vocab_size = 30
    dim = 512

    # Build the module
    context_encoder = ContextEncoder(
        vocab_size=vocab_size,
        pinyin_vocab_size=pinyin_vocab_size,
        dim=dim,
        n_layers=2,  # fewer layers for testing
        n_heads=4,
        max_len=seq_len,
    )
    context_encoder.eval()  # ensure deterministic forward passes

    # Test data
    text_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
    pinyin_ids = torch.randint(
        0, pinyin_vocab_size, (batch_size, 24)
    )  # pinyin_ids has a fixed length of 24
    mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    mask[0, 20:] = 0  # mask out part of the first sample

    # Uncompiled (eager) forward pass
    output = context_encoder(text_ids, pinyin_ids, mask=mask)
    print(f"✓ ContextEncoder output shape: {output.shape}")
    assert output.shape == (batch_size, seq_len, dim)
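    # If the external model wrapped by ContextEncoder hits unsupported ops,
    # Dynamo inserts graph breaks and runs those regions eagerly, so the test
    # below downgrades failures and larger numeric drift to warnings.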
    # Compiled version (ContextEncoder wraps an external model and may not be
    # fully compatible)
    try:
        compiled_context_encoder = torch.compile(
            context_encoder, mode="reduce-overhead"
        )
        compiled_output = compiled_context_encoder(text_ids, pinyin_ids, mask=mask)
        # Check that the outputs match
        diff = torch.abs(output - compiled_output).max().item()
        print(f"✓ Max difference between eager and compiled output: {diff:.6f}")
        if diff < 1e-3:  # ContextEncoder gets a slightly looser tolerance
            print("✓ ContextEncoder passed the torch.compile test")
        else:
            print(f"⚠ ContextEncoder output difference is large: {diff:.6f} (possibly due to the external model)")
    except Exception as e:
        print(
            f"⚠ ContextEncoder torch.compile test failed: {e} (may be expected, since it wraps an external model)"
        )


def test_full_model_compile():
    """Test compiling the full model."""
    print("\n" + "=" * 60)
    print("Testing full model compilation")
    # Import the full model
    from src.model.model import InputMethodEngine

    batch_size = 2
    vocab_size = 10019
    pinyin_vocab_size = 30
    dim = 512
    num_slots = 8
    seq_len = 32

    # Build the model
    model = InputMethodEngine(
        vocab_size=vocab_size,
        pinyin_vocab_size=pinyin_vocab_size,
        dim=dim,
        num_slots=num_slots,
        n_layers=2,  # fewer layers for testing
        n_heads=4,
        num_experts=20,
        max_seq_len=seq_len,
        compile=False,  # control compilation manually
    )
    model.eval()  # disable dropout etc. so eager and compiled outputs are comparable

    # Test data
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
    token_type_ids = torch.randint(0, 2, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    attention_mask[0, 20:] = 0
    pinyin_ids = torch.randint(0, pinyin_vocab_size, (batch_size, 24))
    history_slot_ids = torch.randint(0, vocab_size, (batch_size, num_slots))

    # Uncompiled (eager) forward pass
    with torch.no_grad():
        output = model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            pinyin_ids=pinyin_ids,
            history_slot_ids=history_slot_ids,
        )
    print(f"✓ Full model output shape: {output.shape}")
    assert output.shape == (batch_size, vocab_size)
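    # Compiled kernels may fuse and reorder floating-point operations, so small
    # numeric differences versus eager mode are expected; the comparison below
    # uses a tolerance rather than exact equality.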
    # Compiled version
    try:
        # Build a fresh model for the compile test
        model_for_compile = InputMethodEngine(
            vocab_size=vocab_size,
            pinyin_vocab_size=pinyin_vocab_size,
            dim=dim,
            num_slots=num_slots,
            n_layers=2,
            n_heads=4,
            num_experts=20,
            max_seq_len=seq_len,
            compile=False,
        )
        # Copy the eager model's weights: a freshly initialized model would not
        # produce comparable outputs in the diff check below
        model_for_compile.load_state_dict(model.state_dict())
        model_for_compile.eval()
        # Compile manually
        compiled_model = torch.compile(model_for_compile, mode="reduce-overhead")
        # Warm up: the first calls trigger compilation and must not be timed
        with torch.no_grad():
            for _ in range(3):
                _ = compiled_model(
                    input_ids=input_ids,
                    token_type_ids=token_type_ids,
                    attention_mask=attention_mask,
                    pinyin_ids=pinyin_ids,
                    history_slot_ids=history_slot_ids,
                )
        with torch.no_grad():
            compiled_output = compiled_model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                pinyin_ids=pinyin_ids,
                history_slot_ids=history_slot_ids,
            )
        # Check that the outputs match
        diff = torch.abs(output - compiled_output).max().item()
        print(f"✓ Max difference between eager and compiled output: {diff:.6f}")
        # The full model gets a looser tolerance
        if diff < 1e-3:
            print("✓ Full model passed the torch.compile test")
        else:
            print(f"⚠ Full model output difference is large: {diff:.6f}")
        # Performance test
        n_iter = 30
        print(f"\nFull model performance test ({n_iter} iterations):")
        # Eager version
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.time()
        with torch.no_grad():
            for _ in range(n_iter):
                _ = model(
                    input_ids=input_ids,
                    token_type_ids=token_type_ids,
                    attention_mask=attention_mask,
                    pinyin_ids=pinyin_ids,
                    history_slot_ids=history_slot_ids,
                )
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        base_time = time.time() - start
        # Compiled version
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.time()
        with torch.no_grad():
            for _ in range(n_iter):
                _ = compiled_model(
                    input_ids=input_ids,
                    token_type_ids=token_type_ids,
                    attention_mask=attention_mask,
                    pinyin_ids=pinyin_ids,
                    history_slot_ids=history_slot_ids,
                )
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        compiled_time = time.time() - start
        print(f"  Eager:    {base_time:.4f}s")
        print(f"  Compiled: {compiled_time:.4f}s")
        if compiled_time > 0:
            speedup = base_time / compiled_time
            print(f"  Speedup: {speedup:.2f}x")
    except Exception as e:
        print(f"⚠ Full model torch.compile test failed: {e}")
        import traceback

        traceback.print_exc()


def check_compile_issues():
    """Scan for code patterns that are known to cause compilation problems."""
    print("\n" + "=" * 60)
    print("Checking for compile issues")
    issues = []
    # Scan components.py for potential problems; resolve the path relative to
    # this file (matching the sys.path setup above) so the check works from any
    # working directory
    components_path = os.path.join(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
        "src", "model", "components.py",
    )
    with open(components_path, "r") as f:
        content = f.read()
    # Check for float('-inf')
    if "float('-inf')" in content:
        issues.append("❌ Found float('-inf'); replace it with -1e9")
    # Check for .item() calls, which force a GPU-to-CPU sync and cause a graph
    # break when they appear inside a compiled forward method
    if ".item()" in content and "def forward" in content:
        # Count .item() calls inside forward methods
        lines = content.split("\n")
        in_forward = False
        item_calls = []
        for i, line in enumerate(lines):
            if "def forward" in line:
                in_forward = True
            elif in_forward and "def " in line and "def forward" not in line:
                in_forward = False
            if in_forward and ".item()" in line:
                item_calls.append((i + 1, line.strip()))
        if item_calls:
            issues.append(
                f"❌ Found {len(item_calls)} .item() call(s) inside forward methods:"
            )
            for line_num, line in item_calls[:3]:  # show the first 3
                issues.append(f"  line {line_num}: {line}")
    # Check for dynamic control flow
    dynamic_patterns = [
        "if not mask.any():",
        "if mask is not None:",
        "if token_mask.any():",
        "continue",
    ]
    for pattern in dynamic_patterns:
        if pattern in content:
            # Heuristic only: the pattern may not be inside a forward method
            issues.append(f"⚠ Found dynamic control flow: {pattern}")
    if not issues:
        print("✓ No obvious compile issues found")
    else:
        print("Found the following issues that may affect compilation:")
        for issue in issues:
            print(issue)


def main():
    """Main test entry point."""
    print("=" * 70)
    print("torch.compile compatibility test")
    print("=" * 70)
    # Set up the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    print(f"PyTorch version: {torch.__version__}")
    # Check that torch.compile is available
    if hasattr(torch, "compile"):
        print("✓ torch.compile is available")
    else:
        print("❌ torch.compile is not available; PyTorch 2.0+ is required")
        return
    # Run the tests
    try:
        # First scan the source for known problem patterns
        check_compile_issues()
        # Module tests
        test_attention_pooling()
        test_moe_layer()
        test_cross_attention_fusion()
        test_context_encoder()
        # Full model test
        test_full_model_compile()
        print("\n" + "=" * 70)
        print("✅ All tests completed!")
    except Exception as e:
        print(f"\n❌ Test failed: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()