#!/usr/bin/env python3
"""
ONNX model verification script.

Verifies that the exported ONNX models produce outputs consistent with the
original PyTorch model.
"""

import argparse
import os
import sys
from pathlib import Path

import numpy as np
import torch
import onnxruntime as ort

# Add the src directory to the import path
sys.path.insert(0, str(Path(__file__).parent))

from src.model.export_models import create_export_models_from_checkpoint


def compare_outputs(pytorch_output, onnx_output, name="output", rtol=1e-3, atol=1e-5):
    """
    Compare PyTorch and ONNX outputs.

    Args:
        pytorch_output: PyTorch tensor
        onnx_output: ONNX Runtime output (numpy array)
        name: output name (used in error messages)
        rtol: relative tolerance
        atol: absolute tolerance

    Returns:
        bool: whether the outputs match
    """
    # Convert the PyTorch output to numpy
    if isinstance(pytorch_output, torch.Tensor):
        pytorch_np = pytorch_output.detach().cpu().numpy()
    else:
        pytorch_np = np.array(pytorch_output)

    # Make sure the shapes agree
    if pytorch_np.shape != onnx_output.shape:
        print(
            f"❌ {name} shape mismatch: PyTorch {pytorch_np.shape} != ONNX {onnx_output.shape}"
        )
        return False

    # Compute the elementwise difference
    diff = np.abs(pytorch_np - onnx_output)
    max_diff = np.max(diff)
    mean_diff = np.mean(diff)

    # Check whether the difference is within tolerance
    is_close = np.allclose(pytorch_np, onnx_output, rtol=rtol, atol=atol)

    if is_close:
        print(f"✅ {name} matches: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
    else:
        print(f"❌ {name} mismatch: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
        print(
            f"   range: PyTorch [{pytorch_np.min():.6f}, {pytorch_np.max():.6f}], "
            f"ONNX [{onnx_output.min():.6f}, {onnx_output.max():.6f}]"
        )

    return is_close


def verify_context_encoder(checkpoint_path, onnx_path, device="cpu"):
    """
    Verify the context encoder.

    Args:
        checkpoint_path: path to the PyTorch checkpoint
        onnx_path: path to the ONNX model
        device: device to run on

    Returns:
        bool: whether verification passed
    """
    print(f"\n🔍 Verifying context encoder: {onnx_path}")

    # Load the PyTorch model
    context_encoder_export, _, config = create_export_models_from_checkpoint(
        checkpoint_path, device
    )

    # Create the ONNX Runtime session
    session = ort.InferenceSession(
        onnx_path,
        providers=[
            "CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider"
        ],
    )

    # Create test inputs
    batch_size = 2  # use batch_size=2 to exercise dynamic batching
    seq_len = config.get("max_seq_len", 128)
    pinyin_len = 24

    input_ids = torch.randint(0, 100, (batch_size, seq_len), dtype=torch.long)
    pinyin_ids = torch.randint(0, 30, (batch_size, pinyin_len), dtype=torch.long)
    attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)
    # Mask out the second half of the sequence to exercise padding handling
    attention_mask[:, seq_len // 2 :] = 0

    # PyTorch inference
    with torch.no_grad():
        pytorch_outputs = context_encoder_export(input_ids, pinyin_ids, attention_mask)

    # ONNX inference
    onnx_inputs = {
        "input_ids": input_ids.numpy(),
        "pinyin_ids": pinyin_ids.numpy(),
        "attention_mask": attention_mask.numpy(),
    }
    onnx_outputs = session.run(None, onnx_inputs)

    # Compare the outputs
    output_names = ["context_H", "pinyin_P", "context_mask", "pinyin_mask"]
    all_match = True
    for i, name in enumerate(output_names):
        if i < len(pytorch_outputs) and i < len(onnx_outputs):
            match = compare_outputs(pytorch_outputs[i], onnx_outputs[i], name)
            all_match = all_match and match

    return all_match
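
# Tolerance note (factual numpy behavior; the thresholds are this script's
# defaults): np.allclose applies the asymmetric elementwise test
#     |pytorch - onnx| <= atol + rtol * |onnx|
# so with rtol=1e-3 a logit of magnitude ~10 may drift by roughly 0.01 and
# still pass. For a stricter check, pass tighter tolerances, e.g.
#     compare_outputs(pt_out, ort_out, "logits", rtol=1e-4, atol=1e-6)
# (pt_out / ort_out are illustrative variable names).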

def verify_decoder(checkpoint_path, onnx_path, device="cpu"):
    """
    Verify the decoder.

    Args:
        checkpoint_path: path to the PyTorch checkpoint
        onnx_path: path to the ONNX model
        device: device to run on

    Returns:
        bool: whether verification passed
    """
    print(f"\n🔍 Verifying decoder: {onnx_path}")

    # Load the PyTorch model
    _, decoder_export, config = create_export_models_from_checkpoint(
        checkpoint_path, device
    )

    # Create the ONNX Runtime session
    session = ort.InferenceSession(
        onnx_path,
        providers=[
            "CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider"
        ],
    )

    # Create test inputs
    batch_size = 2
    seq_len = config.get("max_seq_len", 128)
    pinyin_len = 24
    dim = config.get("dim", 512)
    num_slots = config.get("num_slots", 8)

    context_H = torch.randn(batch_size, seq_len, dim, dtype=torch.float32)
    pinyin_P = torch.randn(batch_size, pinyin_len, dim, dtype=torch.float32)
    context_mask = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int32)
    pinyin_mask = torch.randint(0, 2, (batch_size, pinyin_len), dtype=torch.int32)
    history_slot_ids = torch.randint(0, 100, (batch_size, num_slots), dtype=torch.long)

    # PyTorch inference
    with torch.no_grad():
        pytorch_output = decoder_export(
            context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask
        )

    # ONNX inference
    onnx_inputs = {
        "context_H": context_H.numpy(),
        "pinyin_P": pinyin_P.numpy(),
        "history_slot_ids": history_slot_ids.numpy(),
        "context_mask": context_mask.numpy(),
        "pinyin_mask": pinyin_mask.numpy(),
    }
    onnx_outputs = session.run(None, onnx_inputs)

    # Compare the outputs
    return compare_outputs(pytorch_output, onnx_outputs[0], "logits")


def verify_end_to_end(
    checkpoint_path, context_encoder_path, decoder_path, device="cpu"
):
    """
    End-to-end verification: compare the full inference pipeline.

    Args:
        checkpoint_path: path to the PyTorch checkpoint
        context_encoder_path: path to the context encoder ONNX model
        decoder_path: path to the decoder ONNX model
        device: device to run on

    Returns:
        bool: whether verification passed
    """
    print("\n🔍 End-to-end verification")

    # Load the original PyTorch model
    from src.model.model import InputMethodEngine

    checkpoint = torch.load(checkpoint_path, map_location=device)

    if "config" in checkpoint:
        config = checkpoint["config"]
    else:
        config = {
            "vocab_size": 10019,
            "pinyin_vocab_size": 30,
            "dim": 512,
            "num_slots": 8,
            "n_layers": 4,
            "n_heads": 4,
            "num_experts": 10,
            "max_seq_len": 128,
        }

    model = InputMethodEngine(
        vocab_size=config.get("vocab_size", 10019),
        pinyin_vocab_size=config.get("pinyin_vocab_size", 30),
        dim=config.get("dim", 512),
        num_slots=config.get("num_slots", 8),
        n_layers=config.get("n_layers", 4),
        n_heads=config.get("n_heads", 4),
        num_experts=config.get("num_experts", 10),
        max_seq_len=config.get("max_seq_len", 128),
        compile=False,
    )

    if "model_state_dict" in checkpoint:
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        model.load_state_dict(checkpoint)

    model.eval()
    model.to(device)

    # Create the ONNX Runtime sessions
    context_session = ort.InferenceSession(
        context_encoder_path,
        providers=[
            "CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider"
        ],
    )
    decoder_session = ort.InferenceSession(
        decoder_path,
        providers=[
            "CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider"
        ],
    )

    # Create test inputs
    batch_size = 1
    seq_len = config.get("max_seq_len", 128)
    pinyin_len = 24
    num_slots = config.get("num_slots", 8)

    input_ids = torch.randint(0, 100, (batch_size, seq_len), dtype=torch.long)
    pinyin_ids = torch.randint(0, 30, (batch_size, pinyin_len), dtype=torch.long)
    attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)
    history_slot_ids = torch.randint(0, 100, (batch_size, num_slots), dtype=torch.long)

    # Full PyTorch inference
    with torch.no_grad():
        pytorch_logits = model(
            input_ids=input_ids,
            token_type_ids=torch.zeros_like(input_ids),  # simplified: all zeros
            attention_mask=attention_mask,
            pinyin_ids=pinyin_ids,
            history_slot_ids=history_slot_ids,
        )

    # ONNX inference pipeline
    # 1. Context encoder
    context_inputs = {
        "input_ids": input_ids.numpy(),
        "pinyin_ids": pinyin_ids.numpy(),
        "attention_mask": attention_mask.numpy(),
    }
    context_outputs = context_session.run(None, context_inputs)
    context_H, pinyin_P, context_mask, pinyin_mask = context_outputs

    # 2. Decoder
    decoder_inputs = {
        "context_H": context_H,
        "pinyin_P": pinyin_P,
        "history_slot_ids": history_slot_ids.numpy(),
        "context_mask": context_mask,
        "pinyin_mask": pinyin_mask,
    }
    onnx_outputs = decoder_session.run(None, decoder_inputs)

    # Compare the outputs
    return compare_outputs(pytorch_logits, onnx_outputs[0], "end_to_end_logits")
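
# The verification above mirrors the two-stage pipeline a deployment would
# run: encode the context once, then feed the encoder outputs to the decoder.
# A minimal sketch (the helper name and the `feeds` dict are illustrative,
# not part of this repo):
#
#   def onnx_predict(context_session, decoder_session, feeds, history_slot_ids):
#       context_H, pinyin_P, context_mask, pinyin_mask = context_session.run(None, feeds)
#       return decoder_session.run(None, {
#           "context_H": context_H,
#           "pinyin_P": pinyin_P,
#           "history_slot_ids": history_slot_ids,
#           "context_mask": context_mask,
#           "pinyin_mask": pinyin_mask,
#       })[0]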

def main():
    parser = argparse.ArgumentParser(description="ONNX model verification")
    parser.add_argument(
        "--checkpoint", "-c", type=str, required=True, help="path to the PyTorch checkpoint"
    )
    parser.add_argument(
        "--context-encoder",
        type=str,
        help="context encoder ONNX path (default: ./exported_models/context_encoder.onnx)",
    )
    parser.add_argument(
        "--decoder",
        type=str,
        help="decoder ONNX path (default: ./exported_models/decoder.onnx)",
    )
    parser.add_argument(
        "--output-dir",
        "-o",
        type=str,
        default="./exported_models",
        help="export directory (used when individual model paths are not given)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="device to verify on (default: cpu)",
    )
    parser.add_argument(
        "--skip-context", action="store_true", help="skip context encoder verification"
    )
    parser.add_argument(
        "--skip-decoder", action="store_true", help="skip decoder verification"
    )
    parser.add_argument(
        "--skip-end-to-end", action="store_true", help="skip end-to-end verification"
    )

    args = parser.parse_args()

    # Resolve model paths
    if args.context_encoder:
        context_encoder_path = args.context_encoder
    else:
        context_encoder_path = os.path.join(args.output_dir, "context_encoder.onnx")

    if args.decoder:
        decoder_path = args.decoder
    else:
        decoder_path = os.path.join(args.output_dir, "decoder.onnx")

    print("🔬 ONNX model verification")
    print("=" * 60)
    print(f"Checkpoint: {args.checkpoint}")
    print(f"Context encoder: {context_encoder_path}")
    print(f"Decoder: {decoder_path}")
    print(f"Device: {args.device}")
    print()

    all_pass = True

    # Verify the context encoder
    if not args.skip_context and os.path.exists(context_encoder_path):
        if verify_context_encoder(args.checkpoint, context_encoder_path, args.device):
            print("✅ context encoder verification passed")
        else:
            print("❌ context encoder verification failed")
            all_pass = False
    elif not args.skip_context:
        print("⚠️ context encoder file not found, skipping verification")

    # Verify the decoder
    if not args.skip_decoder and os.path.exists(decoder_path):
        if verify_decoder(args.checkpoint, decoder_path, args.device):
            print("✅ decoder verification passed")
        else:
            print("❌ decoder verification failed")
            all_pass = False
    elif not args.skip_decoder:
        print("⚠️ decoder file not found, skipping verification")

    # End-to-end verification
    if (
        not args.skip_end_to_end
        and os.path.exists(context_encoder_path)
        and os.path.exists(decoder_path)
    ):
        if verify_end_to_end(
            args.checkpoint, context_encoder_path, decoder_path, args.device
        ):
            print("✅ end-to-end verification passed")
        else:
            print("❌ end-to-end verification failed")
            all_pass = False

    print("\n" + "=" * 60)
    if all_pass:
        print("🎉 All checks passed! ONNX model outputs match the PyTorch model")
    else:
        print("❌ Some checks failed; please review the model export process")
        sys.exit(1)


if __name__ == "__main__":
    main()
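
# Example invocations (the script filename and checkpoint path below are
# illustrative; the flags match the argparse definitions above):
#
#   python verify_onnx.py --checkpoint checkpoints/best.pt
#   python verify_onnx.py -c checkpoints/best.pt --output-dir ./exported_models
#   python verify_onnx.py -c checkpoints/best.pt --device cuda --skip-end-to-end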