#!/usr/bin/env python3
"""
ONNX model verification script.

Verifies that the exported ONNX models produce outputs consistent with the
original PyTorch model.
"""

import argparse
import os
import sys
from pathlib import Path

import numpy as np
import torch
import onnxruntime as ort

# Add the script directory to sys.path so that the `src` package can be imported
sys.path.insert(0, str(Path(__file__).parent))

from src.model.export_models import create_export_models_from_checkpoint


def compare_outputs(pytorch_output, onnx_output, name="output", rtol=1e-3, atol=1e-5):
    """
    Compare PyTorch and ONNX outputs.

    Args:
        pytorch_output: PyTorch tensor
        onnx_output: ONNX Runtime output (numpy array)
        name: output name (used in error messages)
        rtol: relative tolerance
        atol: absolute tolerance

    Returns:
        bool: whether the outputs match
    """
    # Convert the PyTorch output to numpy
    if isinstance(pytorch_output, torch.Tensor):
        pytorch_np = pytorch_output.detach().cpu().numpy()
    else:
        pytorch_np = np.array(pytorch_output)

    # Make sure the shapes match
    if pytorch_np.shape != onnx_output.shape:
        print(
            f"❌ {name} shape mismatch: PyTorch {pytorch_np.shape} != ONNX {onnx_output.shape}"
        )
        return False

    # Compute the element-wise differences
    diff = np.abs(pytorch_np - onnx_output)
    max_diff = np.max(diff)
    mean_diff = np.mean(diff)

    # Check whether the outputs are within tolerance
    is_close = np.allclose(pytorch_np, onnx_output, rtol=rtol, atol=atol)

    if is_close:
        print(f"✅ {name} matches: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
    else:
        print(f"❌ {name} mismatch: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
        print(
            f"   Range: PyTorch [{pytorch_np.min():.6f}, {pytorch_np.max():.6f}], "
            f"ONNX [{onnx_output.min():.6f}, {onnx_output.max():.6f}]"
        )

    return is_close


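# Note on compare_outputs: np.allclose passes when |pytorch - onnx| <= atol + rtol * |onnx|
# holds element-wise, so the defaults above (rtol=1e-3, atol=1e-5) are a deliberately
# loose FP32 export tolerance. A minimal sketch of a passing call:
#   compare_outputs(torch.ones(2), np.ones(2) * 1.0005, "demo")  # 5e-4 diff, within rtol=1e-3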
def verify_context_encoder(checkpoint_path, onnx_path, device="cpu"):
    """
    Verify the context encoder.

    Args:
        checkpoint_path: path to the PyTorch checkpoint
        onnx_path: path to the ONNX model
        device: device to run on

    Returns:
        bool: whether verification passed
    """
    print(f"\n🔍 Verifying context encoder: {onnx_path}")

    # Load the PyTorch model
    context_encoder_export, _, config = create_export_models_from_checkpoint(
        checkpoint_path, device
    )

    # Create the ONNX Runtime session
    session = ort.InferenceSession(
        onnx_path,
        providers=[
            "CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider"
        ],
    )

    # Create test inputs
    batch_size = 2  # use batch_size=2 to exercise dynamic batching
    seq_len = config.get("max_seq_len", 128)
    pinyin_len = 24

    input_ids = torch.randint(0, 100, (batch_size, seq_len), dtype=torch.long)
    pinyin_ids = torch.randint(0, 30, (batch_size, pinyin_len), dtype=torch.long)
    attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)
    # Mask out the second half of each sequence to exercise the attention mask
    attention_mask[:, seq_len // 2 :] = 0

    # PyTorch inference
    with torch.no_grad():
        pytorch_outputs = context_encoder_export(input_ids, pinyin_ids, attention_mask)

    # ONNX inference
    onnx_inputs = {
        "input_ids": input_ids.numpy(),
        "pinyin_ids": pinyin_ids.numpy(),
        "attention_mask": attention_mask.numpy(),
    }

    onnx_outputs = session.run(None, onnx_inputs)

    # Compare the outputs
    output_names = ["context_H", "pinyin_P", "context_mask", "pinyin_mask"]
    all_match = True

    for i, name in enumerate(output_names):
        if i < len(pytorch_outputs) and i < len(onnx_outputs):
            match = compare_outputs(pytorch_outputs[i], onnx_outputs[i], name)
            all_match = all_match and match

    return all_match


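# Note: the CUDAExecutionProvider branch above assumes the GPU build of ONNX
# Runtime (the onnxruntime-gpu package) is installed; the CPU-only onnxruntime
# package does not ship the CUDA provider.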
def verify_decoder(checkpoint_path, onnx_path, device="cpu"):
    """
    Verify the decoder.

    Args:
        checkpoint_path: path to the PyTorch checkpoint
        onnx_path: path to the ONNX model
        device: device to run on

    Returns:
        bool: whether verification passed
    """
    print(f"\n🔍 Verifying decoder: {onnx_path}")

    # Load the PyTorch model
    _, decoder_export, config = create_export_models_from_checkpoint(
        checkpoint_path, device
    )

    # Create the ONNX Runtime session
    session = ort.InferenceSession(
        onnx_path,
        providers=[
            "CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider"
        ],
    )

    # Create test inputs
    batch_size = 2
    seq_len = config.get("max_seq_len", 128)
    pinyin_len = 24
    dim = config.get("dim", 512)
    num_slots = config.get("num_slots", 8)

    context_H = torch.randn(batch_size, seq_len, dim, dtype=torch.float32)
    pinyin_P = torch.randn(batch_size, pinyin_len, dim, dtype=torch.float32)
    context_mask = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int32)
    pinyin_mask = torch.randint(0, 2, (batch_size, pinyin_len), dtype=torch.int32)
    history_slot_ids = torch.randint(0, 100, (batch_size, num_slots), dtype=torch.long)

    # PyTorch inference
    with torch.no_grad():
        pytorch_output = decoder_export(
            context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask
        )

    # ONNX inference
    onnx_inputs = {
        "context_H": context_H.numpy(),
        "pinyin_P": pinyin_P.numpy(),
        "history_slot_ids": history_slot_ids.numpy(),
        "context_mask": context_mask.numpy(),
        "pinyin_mask": pinyin_mask.numpy(),
    }

    onnx_outputs = session.run(None, onnx_inputs)

    # Compare the outputs
    return compare_outputs(pytorch_output, onnx_outputs[0], "logits")


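# Note: verify_decoder feeds the decoder random hidden states rather than real
# encoder outputs, which checks the decoder's numerics in isolation; the actual
# encoder -> decoder handoff is covered by verify_end_to_end below.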
def verify_end_to_end(
    checkpoint_path, context_encoder_path, decoder_path, device="cpu"
):
    """
    End-to-end verification: compare the full inference pipeline.

    Args:
        checkpoint_path: path to the PyTorch checkpoint
        context_encoder_path: path to the context encoder ONNX model
        decoder_path: path to the decoder ONNX model
        device: device to run on

    Returns:
        bool: whether verification passed
    """
    print("\n🔍 End-to-end verification")

    # Load the original PyTorch model
    from src.model.model import InputMethodEngine

    checkpoint = torch.load(checkpoint_path, map_location=device)

    if "config" in checkpoint:
        config = checkpoint["config"]
    else:
        config = {
            "vocab_size": 10019,
            "pinyin_vocab_size": 30,
            "dim": 512,
            "num_slots": 8,
            "n_layers": 4,
            "n_heads": 4,
            "num_experts": 10,
            "max_seq_len": 128,
        }

    model = InputMethodEngine(
        vocab_size=config.get("vocab_size", 10019),
        pinyin_vocab_size=config.get("pinyin_vocab_size", 30),
        dim=config.get("dim", 512),
        num_slots=config.get("num_slots", 8),
        n_layers=config.get("n_layers", 4),
        n_heads=config.get("n_heads", 4),
        num_experts=config.get("num_experts", 10),
        max_seq_len=config.get("max_seq_len", 128),
        compile=False,
    )

    if "model_state_dict" in checkpoint:
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        model.load_state_dict(checkpoint)

    model.eval()
    model.to(device)

    # Create the ONNX Runtime sessions
    context_session = ort.InferenceSession(
        context_encoder_path,
        providers=[
            "CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider"
        ],
    )
    decoder_session = ort.InferenceSession(
        decoder_path,
        providers=[
            "CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider"
        ],
    )

    # Create test inputs
    batch_size = 1
    seq_len = config.get("max_seq_len", 128)
    pinyin_len = 24

    input_ids = torch.randint(0, 100, (batch_size, seq_len), dtype=torch.long)
    pinyin_ids = torch.randint(0, 30, (batch_size, pinyin_len), dtype=torch.long)
    attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)
    history_slot_ids = torch.randint(0, 100, (batch_size, 8), dtype=torch.long)

    # Full PyTorch inference
    with torch.no_grad():
        pytorch_logits = model(
            input_ids=input_ids,
            token_type_ids=torch.zeros_like(input_ids),  # simplified: all zeros
            attention_mask=attention_mask,
            pinyin_ids=pinyin_ids,
            history_slot_ids=history_slot_ids,
        )

    # ONNX inference pipeline
    # 1. Context encoder
    context_inputs = {
        "input_ids": input_ids.numpy(),
        "pinyin_ids": pinyin_ids.numpy(),
        "attention_mask": attention_mask.numpy(),
    }
    context_outputs = context_session.run(None, context_inputs)
    context_H, pinyin_P, context_mask, pinyin_mask = context_outputs

    # 2. Decoder
    decoder_inputs = {
        "context_H": context_H,
        "pinyin_P": pinyin_P,
        "history_slot_ids": history_slot_ids.numpy(),
        "context_mask": context_mask,
        "pinyin_mask": pinyin_mask,
    }
    onnx_outputs = decoder_session.run(None, decoder_inputs)

    # Compare the outputs
    return compare_outputs(pytorch_logits, onnx_outputs[0], "end_to_end_logits")


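# Note: the hard-coded fallback config above duplicates the .get() defaults used
# when constructing InputMethodEngine. If a checkpoint lacks an embedded "config"
# and was trained with different hyperparameters, load_state_dict will fail on
# shape mismatches before any comparison runs; keep the two in sync.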
def main():
    parser = argparse.ArgumentParser(description="ONNX model verification")
    parser.add_argument(
        "--checkpoint", "-c", type=str, required=True, help="path to the PyTorch checkpoint"
    )
    parser.add_argument(
        "--context-encoder",
        type=str,
        help="path to the context encoder ONNX model (default: ./exported_models/context_encoder.onnx)",
    )
    parser.add_argument(
        "--decoder",
        type=str,
        help="path to the decoder ONNX model (default: ./exported_models/decoder.onnx)",
    )
    parser.add_argument(
        "--output-dir",
        "-o",
        type=str,
        default="./exported_models",
        help="export directory (used when individual model paths are not given)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="device to verify on (default: cpu)",
    )
    parser.add_argument(
        "--skip-context", action="store_true", help="skip context encoder verification"
    )
    parser.add_argument("--skip-decoder", action="store_true", help="skip decoder verification")
    parser.add_argument("--skip-end-to-end", action="store_true", help="skip end-to-end verification")

    args = parser.parse_args()

    # Resolve the model paths
    if args.context_encoder:
        context_encoder_path = args.context_encoder
    else:
        context_encoder_path = os.path.join(args.output_dir, "context_encoder.onnx")

    if args.decoder:
        decoder_path = args.decoder
    else:
        decoder_path = os.path.join(args.output_dir, "decoder.onnx")

    print("🔬 ONNX model verification")
    print("=" * 60)
    print(f"Checkpoint: {args.checkpoint}")
    print(f"Context encoder: {context_encoder_path}")
    print(f"Decoder: {decoder_path}")
    print(f"Device: {args.device}")
    print()

    all_pass = True

    # Verify the context encoder
    if not args.skip_context and os.path.exists(context_encoder_path):
        if verify_context_encoder(args.checkpoint, context_encoder_path, args.device):
            print("✅ Context encoder verification passed")
        else:
            print("❌ Context encoder verification failed")
            all_pass = False
    elif not args.skip_context:
        print("⚠️ Context encoder file not found, skipping verification")

    # Verify the decoder
    if not args.skip_decoder and os.path.exists(decoder_path):
        if verify_decoder(args.checkpoint, decoder_path, args.device):
            print("✅ Decoder verification passed")
        else:
            print("❌ Decoder verification failed")
            all_pass = False
    elif not args.skip_decoder:
        print("⚠️ Decoder file not found, skipping verification")

    # End-to-end verification
    if (
        not args.skip_end_to_end
        and os.path.exists(context_encoder_path)
        and os.path.exists(decoder_path)
    ):
        if verify_end_to_end(
            args.checkpoint, context_encoder_path, decoder_path, args.device
        ):
            print("✅ End-to-end verification passed")
        else:
            print("❌ End-to-end verification failed")
            all_pass = False

    print("\n" + "=" * 60)
    if all_pass:
        print("🎉 All checks passed! The ONNX models match the PyTorch model outputs")
    else:
        print("❌ Some checks failed; please review the model export process")
        sys.exit(1)


if __name__ == "__main__":
    main()
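# Example invocations (script and checkpoint filenames are illustrative; the
# flags are the ones defined above):
#   python verify_onnx.py --checkpoint checkpoints/model.pt
#   python verify_onnx.py -c checkpoints/model.pt -o ./exported_models --device cuda
#   python verify_onnx.py -c checkpoints/model.pt --skip-end-to-end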