SUimeModelTraner/verify_onnx.py


#!/usr/bin/env python3
"""
ONNX model verification script.

Verifies that the exported ONNX models produce outputs consistent with the
original PyTorch model.
"""
import argparse
import os
import sys
from pathlib import Path
import numpy as np
import torch
import onnxruntime as ort
# Add the src directory to the import path
sys.path.insert(0, str(Path(__file__).parent))
from src.model.export_models import create_export_models_from_checkpoint
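# Inferred from the call sites below: create_export_models_from_checkpoint is
# expected to return (context_encoder_export, decoder_export, config).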


def compare_outputs(pytorch_output, onnx_output, name="output", rtol=1e-3, atol=1e-5):
    """
    Compare PyTorch and ONNX Runtime outputs.

    Args:
        pytorch_output: PyTorch tensor
        onnx_output: ONNX Runtime output (numpy array)
        name: output name (used in messages)
        rtol: relative tolerance
        atol: absolute tolerance

    Returns:
        bool: whether the outputs match
    """
    # Convert the PyTorch output to numpy
    if isinstance(pytorch_output, torch.Tensor):
        pytorch_np = pytorch_output.detach().cpu().numpy()
    else:
        pytorch_np = np.array(pytorch_output)

    # Shapes must agree before comparing values
    if pytorch_np.shape != onnx_output.shape:
        print(
            f"{name} shape mismatch: PyTorch {pytorch_np.shape} != ONNX {onnx_output.shape}"
        )
        return False

    # Compute elementwise differences
    diff = np.abs(pytorch_np - onnx_output)
    max_diff = np.max(diff)
    mean_diff = np.mean(diff)

    # Check whether the outputs agree within tolerance
    is_close = np.allclose(pytorch_np, onnx_output, rtol=rtol, atol=atol)
    if is_close:
        print(f"{name} matches: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
    else:
        print(f"{name} MISMATCH: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
        print(
            f"  range: PyTorch [{pytorch_np.min():.6f}, {pytorch_np.max():.6f}], "
            f"ONNX [{onnx_output.min():.6f}, {onnx_output.max():.6f}]"
        )
    return is_close
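# Note: np.allclose applies the elementwise test |pytorch - onnx| <= atol + rtol * |onnx|,
# so the comparison tolerates a small absolute error everywhere plus an error
# proportional to the magnitude of the ONNX value.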


def verify_context_encoder(checkpoint_path, onnx_path, device="cpu"):
    """
    Verify the context encoder.

    Args:
        checkpoint_path: path to the PyTorch checkpoint
        onnx_path: path to the ONNX model
        device: device to run on

    Returns:
        bool: whether verification passed
    """
    print(f"\n🔍 Verifying context encoder: {onnx_path}")

    # Load the PyTorch export model
    context_encoder_export, _, config = create_export_models_from_checkpoint(
        checkpoint_path, device
    )

    # Create the ONNX Runtime session
    session = ort.InferenceSession(
        onnx_path,
        providers=[
            "CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider"
        ],
    )

    # Build test inputs (batch_size=2 exercises the dynamic batch axis)
    batch_size = 2
    seq_len = config.get("max_seq_len", 128)
    pinyin_len = 24
    input_ids = torch.randint(0, 100, (batch_size, seq_len), dtype=torch.long)
    pinyin_ids = torch.randint(0, 30, (batch_size, pinyin_len), dtype=torch.long)
    attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)
    # Mask out the second half of the sequence to exercise the attention mask
    attention_mask[:, seq_len // 2 :] = 0

    # PyTorch inference
    with torch.no_grad():
        pytorch_outputs = context_encoder_export(input_ids, pinyin_ids, attention_mask)

    # ONNX inference
    onnx_inputs = {
        "input_ids": input_ids.numpy(),
        "pinyin_ids": pinyin_ids.numpy(),
        "attention_mask": attention_mask.numpy(),
    }
    onnx_outputs = session.run(None, onnx_inputs)

    # Compare each output tensor
    output_names = ["context_H", "pinyin_P", "context_mask", "pinyin_mask"]
    all_match = True
    for i, name in enumerate(output_names):
        if i < len(pytorch_outputs) and i < len(onnx_outputs):
            match = compare_outputs(pytorch_outputs[i], onnx_outputs[i], name)
            all_match = all_match and match
    return all_match


def verify_decoder(checkpoint_path, onnx_path, device="cpu"):
    """
    Verify the decoder.

    Args:
        checkpoint_path: path to the PyTorch checkpoint
        onnx_path: path to the ONNX model
        device: device to run on

    Returns:
        bool: whether verification passed
    """
    print(f"\n🔍 Verifying decoder: {onnx_path}")

    # Load the PyTorch export model
    _, decoder_export, config = create_export_models_from_checkpoint(
        checkpoint_path, device
    )

    # Create the ONNX Runtime session
    session = ort.InferenceSession(
        onnx_path,
        providers=[
            "CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider"
        ],
    )

    # Build test inputs
    batch_size = 2
    seq_len = config.get("max_seq_len", 128)
    pinyin_len = 24
    dim = config.get("dim", 512)
    num_slots = config.get("num_slots", 8)
    context_H = torch.randn(batch_size, seq_len, dim, dtype=torch.float32)
    pinyin_P = torch.randn(batch_size, pinyin_len, dim, dtype=torch.float32)
    context_mask = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int32)
    pinyin_mask = torch.randint(0, 2, (batch_size, pinyin_len), dtype=torch.int32)
    history_slot_ids = torch.randint(0, 100, (batch_size, num_slots), dtype=torch.long)

    # PyTorch inference
    with torch.no_grad():
        pytorch_output = decoder_export(
            context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask
        )

    # ONNX inference
    onnx_inputs = {
        "context_H": context_H.numpy(),
        "pinyin_P": pinyin_P.numpy(),
        "history_slot_ids": history_slot_ids.numpy(),
        "context_mask": context_mask.numpy(),
        "pinyin_mask": pinyin_mask.numpy(),
    }
    onnx_outputs = session.run(None, onnx_inputs)

    # Compare the logits
    return compare_outputs(pytorch_output, onnx_outputs[0], "logits")
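# Note: verify_decoder feeds the decoder random feature tensors, checking the
# exported graph in isolation; the composed encoder -> decoder pipeline is
# covered separately by verify_end_to_end below.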


def verify_end_to_end(
    checkpoint_path, context_encoder_path, decoder_path, device="cpu"
):
    """
    End-to-end verification: compare the full inference pipeline.

    Args:
        checkpoint_path: path to the PyTorch checkpoint
        context_encoder_path: path to the context encoder ONNX model
        decoder_path: path to the decoder ONNX model
        device: device to run on

    Returns:
        bool: whether verification passed
    """
    print("\n🔍 End-to-end verification")

    # Load the original PyTorch model
    from src.model.model import InputMethodEngine

    checkpoint = torch.load(checkpoint_path, map_location=device)
    if "config" in checkpoint:
        config = checkpoint["config"]
    else:
        config = {
            "vocab_size": 10019,
            "pinyin_vocab_size": 30,
            "dim": 512,
            "num_slots": 8,
            "n_layers": 4,
            "n_heads": 4,
            "num_experts": 10,
            "max_seq_len": 128,
        }
    model = InputMethodEngine(
        vocab_size=config.get("vocab_size", 10019),
        pinyin_vocab_size=config.get("pinyin_vocab_size", 30),
        dim=config.get("dim", 512),
        num_slots=config.get("num_slots", 8),
        n_layers=config.get("n_layers", 4),
        n_heads=config.get("n_heads", 4),
        num_experts=config.get("num_experts", 10),
        max_seq_len=config.get("max_seq_len", 128),
        compile=False,
    )
    if "model_state_dict" in checkpoint:
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        model.load_state_dict(checkpoint)
    model.eval()
    model.to(device)

    # Create the ONNX Runtime sessions
    context_session = ort.InferenceSession(
        context_encoder_path,
        providers=[
            "CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider"
        ],
    )
    decoder_session = ort.InferenceSession(
        decoder_path,
        providers=[
            "CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider"
        ],
    )

    # Build test inputs
    batch_size = 1
    seq_len = config.get("max_seq_len", 128)
    pinyin_len = 24
    input_ids = torch.randint(0, 100, (batch_size, seq_len), dtype=torch.long)
    pinyin_ids = torch.randint(0, 30, (batch_size, pinyin_len), dtype=torch.long)
    attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)
    history_slot_ids = torch.randint(0, 100, (batch_size, 8), dtype=torch.long)

    # Full PyTorch inference
    with torch.no_grad():
        pytorch_logits = model(
            input_ids=input_ids,
            token_type_ids=torch.zeros_like(input_ids),  # simplification: all-zero token types
            attention_mask=attention_mask,
            pinyin_ids=pinyin_ids,
            history_slot_ids=history_slot_ids,
        )

    # ONNX inference pipeline
    # 1. Context encoder
    context_inputs = {
        "input_ids": input_ids.numpy(),
        "pinyin_ids": pinyin_ids.numpy(),
        "attention_mask": attention_mask.numpy(),
    }
    context_outputs = context_session.run(None, context_inputs)
    context_H, pinyin_P, context_mask, pinyin_mask = context_outputs

    # 2. Decoder
    decoder_inputs = {
        "context_H": context_H,
        "pinyin_P": pinyin_P,
        "history_slot_ids": history_slot_ids.numpy(),
        "context_mask": context_mask,
        "pinyin_mask": pinyin_mask,
    }
    onnx_outputs = decoder_session.run(None, decoder_inputs)

    # Compare the final logits
    return compare_outputs(pytorch_logits, onnx_outputs[0], "end_to_end_logits")
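# Since both paths above consume the exact same random inputs, any divergence
# in the final logits points at the export or runtime, not the test harness.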


def main():
    parser = argparse.ArgumentParser(description="ONNX model verification")
    parser.add_argument(
        "--checkpoint", "-c", type=str, required=True, help="Path to the PyTorch checkpoint"
    )
    parser.add_argument(
        "--context-encoder",
        type=str,
        help="Path to the context encoder ONNX model (default: ./exported_models/context_encoder.onnx)",
    )
    parser.add_argument(
        "--decoder",
        type=str,
        help="Path to the decoder ONNX model (default: ./exported_models/decoder.onnx)",
    )
    parser.add_argument(
        "--output-dir",
        "-o",
        type=str,
        default="./exported_models",
        help="Export directory (used when individual model paths are not given)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="Device for verification (default: cpu)",
    )
    parser.add_argument(
        "--skip-context", action="store_true", help="Skip context encoder verification"
    )
    parser.add_argument("--skip-decoder", action="store_true", help="Skip decoder verification")
    parser.add_argument(
        "--skip-end-to-end", action="store_true", help="Skip end-to-end verification"
    )
    args = parser.parse_args()

    # Resolve model paths
    if args.context_encoder:
        context_encoder_path = args.context_encoder
    else:
        context_encoder_path = os.path.join(args.output_dir, "context_encoder.onnx")
    if args.decoder:
        decoder_path = args.decoder
    else:
        decoder_path = os.path.join(args.output_dir, "decoder.onnx")

    print("🔬 ONNX model verification")
    print("=" * 60)
    print(f"Checkpoint: {args.checkpoint}")
    print(f"Context encoder: {context_encoder_path}")
    print(f"Decoder: {decoder_path}")
    print(f"Device: {args.device}")
    print()

    all_pass = True

    # Verify the context encoder
    if not args.skip_context and os.path.exists(context_encoder_path):
        if verify_context_encoder(args.checkpoint, context_encoder_path, args.device):
            print("✅ Context encoder verification passed")
        else:
            print("❌ Context encoder verification failed")
            all_pass = False
    elif not args.skip_context:
        print("⚠️ Context encoder file not found; skipping verification")

    # Verify the decoder
    if not args.skip_decoder and os.path.exists(decoder_path):
        if verify_decoder(args.checkpoint, decoder_path, args.device):
            print("✅ Decoder verification passed")
        else:
            print("❌ Decoder verification failed")
            all_pass = False
    elif not args.skip_decoder:
        print("⚠️ Decoder file not found; skipping verification")

    # End-to-end verification
    if (
        not args.skip_end_to_end
        and os.path.exists(context_encoder_path)
        and os.path.exists(decoder_path)
    ):
        if verify_end_to_end(
            args.checkpoint, context_encoder_path, decoder_path, args.device
        ):
            print("✅ End-to-end verification passed")
        else:
            print("❌ End-to-end verification failed")
            all_pass = False

    print("\n" + "=" * 60)
    if all_pass:
        print("🎉 All checks passed: ONNX model outputs match the PyTorch model")
    else:
        print("❌ Some checks failed; review the model export process")
        sys.exit(1)


if __name__ == "__main__":
    main()