#!/usr/bin/env python3
"""torch.compile compatibility test script.

Verifies that the modifications in components.py are compatible with
torch.compile: each module is run eagerly and compiled, the outputs are
compared, and simple wall-clock benchmarks are reported.
"""

import os
import sys
import time

import torch
import torch.nn as nn

# Repository root (two levels above this file); added to sys.path so the
# `src.*` packages import regardless of the current working directory.
_REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, _REPO_ROOT)

from src.model.components import (
    AttentionPooling,
    ContextEncoder,
    CrossAttentionFusion,
    Expert,
    MoELayer,
    ResidualBlock,
    SlotMemory,
)


def _cuda_sync():
    """Synchronize CUDA (no-op on CPU) so wall-clock timings are accurate."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def test_attention_pooling():
    """Test the AttentionPooling module eagerly and under torch.compile."""
    print("=" * 60)
    print("测试 AttentionPooling 模块")
    batch_size = 4
    seq_len = 10
    hidden_size = 512

    # Build the module.
    attn_pool = AttentionPooling(hidden_size)

    # Test data.
    x = torch.randn(batch_size, seq_len, hidden_size)
    mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    mask[0, 5:] = 0  # mask out part of the first sample
    token_type_ids = torch.randint(0, 3, (batch_size, seq_len))

    # Eager (uncompiled) run.
    output = attn_pool(x, mask=mask, token_type_ids=token_type_ids)
    print(f"✓ AttentionPooling 输出形状: {output.shape}")
    assert output.shape == (batch_size, hidden_size)

    # Compiled run.
    try:
        compiled_attn_pool = torch.compile(attn_pool, mode="reduce-overhead")
        compiled_output = compiled_attn_pool(
            x, mask=mask, token_type_ids=token_type_ids
        )
        # Compare eager vs compiled outputs.
        diff = torch.abs(output - compiled_output).max().item()
        print(f"✓ 编译前后输出差异: {diff:.6f}")
        assert diff < 1e-4, f"输出差异过大: {diff}"
        print("✓ AttentionPooling 通过 torch.compile 测试")
    except Exception as e:
        print(f"⚠ AttentionPooling torch.compile 测试失败: {e}")
        raise


def test_moe_layer():
    """Test the MoELayer module eagerly, compiled, and benchmarked."""
    print("\n" + "=" * 60)
    print("测试 MoELayer 模块")
    batch_size = 2
    seq_len = 8
    dim = 512
    num_experts = 20
    top_k = 2

    # Build the module.
    moe = MoELayer(dim=dim, num_experts=num_experts, top_k=top_k)

    # Test data.
    x = torch.randn(batch_size, seq_len, dim)

    # Eager (uncompiled) run.
    output = moe(x)
    print(f"✓ MoELayer 输出形状: {output.shape}")
    assert output.shape == (batch_size, seq_len, dim)

    # Inspect the gating weights.
    gates = moe.gate(x.view(-1, dim))
    topk_vals, topk_indices = torch.topk(gates, top_k, dim=-1)
    print(f"✓ 门控权重形状: {gates.shape}, top-k 索引形状: {topk_indices.shape}")

    # Compiled run.
    try:
        compiled_moe = torch.compile(moe, mode="reduce-overhead")
        # Warm up: reduce-overhead mode captures CUDA graphs on the first
        # few calls, so compare outputs only after warm-up.
        for _ in range(3):
            _ = compiled_moe(x)
        compiled_output = compiled_moe(x)

        # Compare eager vs compiled outputs.
        diff = torch.abs(output - compiled_output).max().item()
        print(f"✓ 编译前后输出差异: {diff:.6f}")
        assert diff < 1e-4, f"输出差异过大: {diff}"
        print("✓ MoELayer 通过 torch.compile 测试")

        # Benchmark.
        n_iter = 50
        print(f"\n性能测试 ({n_iter} 次迭代):")

        # Eager version.
        _cuda_sync()
        start = time.time()
        for _ in range(n_iter):
            _ = moe(x)
        _cuda_sync()
        base_time = time.time() - start

        # Compiled version.
        _cuda_sync()
        start = time.time()
        for _ in range(n_iter):
            _ = compiled_moe(x)
        _cuda_sync()
        compiled_time = time.time() - start

        print(f" 未编译: {base_time:.4f} 秒")
        print(f" 已编译: {compiled_time:.4f} 秒")
        if compiled_time > 0:
            speedup = base_time / compiled_time
            print(f" 加速比: {speedup:.2f}x")
    except Exception as e:
        print(f"⚠ MoELayer torch.compile 测试失败: {e}")
        import traceback

        traceback.print_exc()
        raise


def test_cross_attention_fusion():
    """Test the CrossAttentionFusion module eagerly and under torch.compile."""
    print("\n" + "=" * 60)
    print("测试 CrossAttentionFusion 模块")
    batch_size = 2
    num_slots = 8
    ctx_len = 16
    dim = 512

    # Build the module.
    cross_attn = CrossAttentionFusion(dim=dim, n_heads=4)

    # Test data.
    slots_S = torch.randn(batch_size, num_slots, dim)
    context_H = torch.randn(batch_size, ctx_len, dim)
    context_mask = torch.ones(batch_size, ctx_len, dtype=torch.long)
    context_mask[0, 10:] = 0  # mask out part of the first sample

    # Eager (uncompiled) run.
    output = cross_attn(slots_S, context_H, context_mask=context_mask)
    print(f"✓ CrossAttentionFusion 输出形状: {output.shape}")
    assert output.shape == (batch_size, num_slots, dim)

    # Compiled run.
    try:
        compiled_cross_attn = torch.compile(cross_attn, mode="reduce-overhead")
        compiled_output = compiled_cross_attn(
            slots_S, context_H, context_mask=context_mask
        )
        # Compare eager vs compiled outputs.
        diff = torch.abs(output - compiled_output).max().item()
        print(f"✓ 编译前后输出差异: {diff:.6f}")
        assert diff < 1e-4, f"输出差异过大: {diff}"
        print("✓ CrossAttentionFusion 通过 torch.compile 测试")
    except Exception as e:
        print(f"⚠ CrossAttentionFusion torch.compile 测试失败: {e}")
        raise


def test_context_encoder():
    """Test the ContextEncoder module; compilation failures are tolerated."""
    print("\n" + "=" * 60)
    print("测试 ContextEncoder 模块")
    batch_size = 2
    seq_len = 32
    vocab_size = 1000
    pinyin_vocab_size = 30
    dim = 512

    # Build the module.
    context_encoder = ContextEncoder(
        vocab_size=vocab_size,
        pinyin_vocab_size=pinyin_vocab_size,
        dim=dim,
        n_layers=2,  # fewer layers for the test
        n_heads=4,
        max_len=seq_len,
    )

    # Test data.
    text_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
    pinyin_ids = torch.randint(
        0, pinyin_vocab_size, (batch_size, 24)
    )  # pinyin_ids has a fixed length of 24
    mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    mask[0, 20:] = 0  # mask out part of the first sample

    # Eager (uncompiled) run.
    output = context_encoder(text_ids, pinyin_ids, mask=mask)
    print(f"✓ ContextEncoder 输出形状: {output.shape}")
    assert output.shape == (batch_size, seq_len, dim)

    # Compiled run — ContextEncoder wraps an external model, so it may not
    # be fully compatible with torch.compile; failures here are non-fatal.
    try:
        compiled_context_encoder = torch.compile(
            context_encoder, mode="reduce-overhead"
        )
        compiled_output = compiled_context_encoder(text_ids, pinyin_ids, mask=mask)

        # Compare eager vs compiled outputs.
        diff = torch.abs(output - compiled_output).max().item()
        print(f"✓ 编译前后输出差异: {diff:.6f}")
        if diff < 1e-3:  # looser tolerance because of the external model
            print("✓ ContextEncoder 通过 torch.compile 测试")
        else:
            print(f"⚠ ContextEncoder 输出差异较大: {diff:.6f} (可能因外部模型)")
    except Exception as e:
        print(
            f"⚠ ContextEncoder torch.compile 测试失败: {e} (可能是正常现象,因为包含外部模型)"
        )


def test_full_model_compile():
    """Test compiling the full InputMethodEngine end to end."""
    print("\n" + "=" * 60)
    print("测试完整模型编译")

    # Imported lazily so the component tests still run if the model is broken.
    from src.model.model import InputMethodEngine

    batch_size = 2
    vocab_size = 10019
    pinyin_vocab_size = 30
    dim = 512
    num_slots = 8
    seq_len = 32

    # Build the model.
    model = InputMethodEngine(
        vocab_size=vocab_size,
        pinyin_vocab_size=pinyin_vocab_size,
        dim=dim,
        num_slots=num_slots,
        n_layers=2,  # fewer layers for the test
        n_heads=4,
        num_experts=20,
        max_seq_len=seq_len,
        compile=False,  # compilation is driven manually below
    )
    # eval mode: dropout off, so eager vs compiled outputs are comparable.
    model.eval()

    # Test data.
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
    token_type_ids = torch.randint(0, 2, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    attention_mask[0, 20:] = 0
    pinyin_ids = torch.randint(0, pinyin_vocab_size, (batch_size, 24))
    history_slot_ids = torch.randint(0, vocab_size, (batch_size, num_slots))

    # Eager (uncompiled) run.
    with torch.no_grad():
        output = model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            pinyin_ids=pinyin_ids,
            history_slot_ids=history_slot_ids,
        )
    print(f"✓ 完整模型输出形状: {output.shape}")
    assert output.shape == (batch_size, vocab_size)

    # Compiled run.
    try:
        # Separate instance so the eager model is left untouched by compile.
        model_for_compile = InputMethodEngine(
            vocab_size=vocab_size,
            pinyin_vocab_size=pinyin_vocab_size,
            dim=dim,
            num_slots=num_slots,
            n_layers=2,
            n_heads=4,
            num_experts=20,
            max_seq_len=seq_len,
            compile=False,
        )
        # Copy the eager model's weights — without this the two models have
        # independent random initializations and the diff below is meaningless.
        model_for_compile.load_state_dict(model.state_dict())
        model_for_compile.eval()

        compiled_model = torch.compile(model_for_compile, mode="reduce-overhead")

        # Warm up under no_grad: the measured call also runs under no_grad,
        # and switching grad mode after warm-up would force a recompile.
        with torch.no_grad():
            for _ in range(3):
                _ = compiled_model(
                    input_ids=input_ids,
                    token_type_ids=token_type_ids,
                    attention_mask=attention_mask,
                    pinyin_ids=pinyin_ids,
                    history_slot_ids=history_slot_ids,
                )
            compiled_output = compiled_model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                pinyin_ids=pinyin_ids,
                history_slot_ids=history_slot_ids,
            )

        # Compare eager vs compiled outputs.
        diff = torch.abs(output - compiled_output).max().item()
        print(f"✓ 编译前后输出差异: {diff:.6f}")
        # Looser tolerance for the full model.
        if diff < 1e-3:
            print("✓ 完整模型通过 torch.compile 测试")
        else:
            print(f"⚠ 完整模型输出差异较大: {diff:.6f}")

        # Benchmark.
        n_iter = 30
        print(f"\n完整模型性能测试 ({n_iter} 次迭代):")

        # Eager version.
        _cuda_sync()
        start = time.time()
        with torch.no_grad():
            for _ in range(n_iter):
                _ = model(
                    input_ids=input_ids,
                    token_type_ids=token_type_ids,
                    attention_mask=attention_mask,
                    pinyin_ids=pinyin_ids,
                    history_slot_ids=history_slot_ids,
                )
        _cuda_sync()
        base_time = time.time() - start

        # Compiled version.
        _cuda_sync()
        start = time.time()
        with torch.no_grad():
            for _ in range(n_iter):
                _ = compiled_model(
                    input_ids=input_ids,
                    token_type_ids=token_type_ids,
                    attention_mask=attention_mask,
                    pinyin_ids=pinyin_ids,
                    history_slot_ids=history_slot_ids,
                )
        _cuda_sync()
        compiled_time = time.time() - start

        print(f" 未编译: {base_time:.4f} 秒")
        print(f" 已编译: {compiled_time:.4f} 秒")
        if compiled_time > 0:
            speedup = base_time / compiled_time
            print(f" 加速比: {speedup:.2f}x")
    except Exception as e:
        print(f"⚠ 完整模型 torch.compile 测试失败: {e}")
        import traceback

        traceback.print_exc()


def check_compile_issues():
    """Statically scan components.py for patterns that commonly break compile."""
    print("\n" + "=" * 60)
    print("检查编译问题")

    issues = []

    # Read the file relative to the repo root (not the cwd) and decode as
    # UTF-8 explicitly — the file contains non-ASCII comments.
    components_path = os.path.join(_REPO_ROOT, "src", "model", "components.py")
    with open(components_path, "r", encoding="utf-8") as f:
        content = f.read()

    # float('-inf') — should be replaced with a large finite negative value.
    if "float('-inf')" in content:
        issues.append("❌ 发现 float('-inf'),应替换为 -1e9")

    # .item() calls inside forward force a graph break (GPU->CPU sync).
    if ".item()" in content and "def forward" in content:
        # Count .item() calls that appear inside a forward method.
        lines = content.split("\n")
        in_forward = False
        item_calls = []
        for i, line in enumerate(lines):
            if "def forward" in line:
                in_forward = True
            elif in_forward and "def " in line and "def forward" not in line:
                in_forward = False
            if in_forward and ".item()" in line:
                item_calls.append((i + 1, line.strip()))
        if item_calls:
            issues.append(
                f"❌ 在 forward 方法中发现 {len(item_calls)} 个 .item() 调用:"
            )
            for line_num, line in item_calls[:3]:  # show at most the first 3
                issues.append(f" 第 {line_num} 行: {line}")

    # Data-dependent control flow causes graph breaks / recompiles.
    dynamic_patterns = [
        "if not mask.any():",
        "if mask is not None:",
        "if token_mask.any():",
        "continue",
    ]
    for pattern in dynamic_patterns:
        if pattern in content:
            # NOTE: this does not verify the pattern is inside a forward method.
            issues.append(f"⚠ 发现动态控制流: {pattern}")

    if not issues:
        print("✓ 未发现明显的编译问题")
    else:
        print("发现以下可能影响编译的问题:")
        for issue in issues:
            print(issue)


def main():
    """Run the static check, all module tests, then the full-model test."""
    print("=" * 70)
    print("torch.compile 兼容性测试")
    print("=" * 70)

    # Report the device and PyTorch version.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")
    print(f"PyTorch 版本: {torch.__version__}")

    # torch.compile requires PyTorch 2.0+.
    if hasattr(torch, "compile"):
        print("✓ torch.compile 可用")
    else:
        print("❌ torch.compile 不可用,需要 PyTorch 2.0+")
        return

    # Run the tests.
    try:
        # Static code check first.
        check_compile_issues()

        # Per-module tests.
        test_attention_pooling()
        test_moe_layer()
        test_cross_attention_fusion()
        test_context_encoder()

        # End-to-end model test.
        test_full_model_compile()

        print("\n" + "=" * 70)
        print("✅ 所有测试完成!")
    except Exception as e:
        print(f"\n❌ 测试失败: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()