SUimeModelTraner/export_onnx.py

#!/usr/bin/env python3
"""
输入法模型ONNX导出脚本
将模型导出为两个ONNX部分
1. context_encoder.onnx - 上下文编码器
2. decoder.onnx - 解码器
使用方法:
python export_onnx.py --checkpoint ./output/checkpoints/best_model.pt --output-dir ./exported_models
依赖安装:
pip install onnx onnxruntime
"""
import argparse
import os
import sys
from pathlib import Path
import torch
import numpy as np
# Check whether ONNX is available
try:
    import onnx
    import onnxruntime as ort

    ONNX_AVAILABLE = True
except ImportError:
    ONNX_AVAILABLE = False
    print("Warning: onnx or onnxruntime is not installed")
    print("Please run: pip install onnx onnxruntime")
# Add the src directory to the path
sys.path.insert(0, str(Path(__file__).parent))
from src.model.export_models import create_export_models_from_checkpoint


def check_onnx_available():
    """Check whether the ONNX dependencies are available."""
    if not ONNX_AVAILABLE:
        print("Error: ONNX export requires the following dependencies:")
        print("  pip install onnx onnxruntime")
        print("Please install them and try again")
        return False
    return True


def export_context_encoder(model, output_path, config):
    """
    Export the context encoder to ONNX format.

    Args:
        model: ContextEncoderExport instance
        output_path: output file path
        config: model configuration
    """
    print(f"Exporting context encoder to: {output_path}")
    # Create example inputs; batch_size=2 ensures the exported graph keeps a dynamic batch axis
    batch_size = 2
    seq_len = config.get("max_seq_len", 128)
    pinyin_len = 24  # fixed length
    input_ids = torch.randint(0, 100, (batch_size, seq_len), dtype=torch.long)
    pinyin_ids = torch.randint(0, 30, (batch_size, pinyin_len), dtype=torch.long)
    attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)
    # Declare the dynamic axes
    dynamic_axes = {
        "input_ids": {0: "batch_size", 1: "seq_len"},
        "pinyin_ids": {0: "batch_size"},
        "attention_mask": {0: "batch_size", 1: "seq_len"},
        "context_H": {0: "batch_size", 1: "seq_len"},
        "pinyin_P": {0: "batch_size"},
        "context_mask": {0: "batch_size", 1: "seq_len"},
        "pinyin_mask": {0: "batch_size"},
    }
    # Export the model
    torch.onnx.export(
        model,
        (input_ids, pinyin_ids, attention_mask),
        output_path,
        input_names=["input_ids", "pinyin_ids", "attention_mask"],
        output_names=["context_H", "pinyin_P", "context_mask", "pinyin_mask"],
        dynamic_axes=dynamic_axes,
        opset_version=18,  # supports modern operators
        do_constant_folding=True,
        verbose=False,
    )
    print("✅ Context encoder export complete")
    # Validate the exported model
    try:
        onnx_model = onnx.load(output_path)
        onnx.checker.check_model(onnx_model)
        print("✅ ONNX model check passed")
    except Exception as e:
        print(f"⚠️ ONNX model check warning: {e}")
    return input_ids, pinyin_ids, attention_mask
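

# A numerical parity check (a sketch, not part of the original pipeline):
# feed the same example inputs to the PyTorch module and to the exported ONNX
# graph via ONNX Runtime, then compare the outputs. The helper name and the
# tolerances are assumptions.
def verify_onnx_parity(onnx_path, torch_module, example_inputs, rtol=1e-3, atol=1e-5):
    """Compare PyTorch outputs against ONNX Runtime outputs element-wise."""
    session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
    # Map session inputs to the example tensors in export order
    ort_inputs = {
        inp.name: tensor.cpu().numpy()
        for inp, tensor in zip(session.get_inputs(), example_inputs)
    }
    ort_outputs = session.run(None, ort_inputs)
    with torch.no_grad():
        torch_outputs = torch_module(*example_inputs)
    if isinstance(torch_outputs, torch.Tensor):
        torch_outputs = (torch_outputs,)
    ok = True
    for i, (t, o) in enumerate(zip(torch_outputs, ort_outputs)):
        if not np.allclose(t.cpu().numpy(), o, rtol=rtol, atol=atol):
            print(f"⚠️ Output {i} differs beyond tolerance")
            ok = False
    if ok:
        print("✅ PyTorch and ONNX Runtime outputs match within tolerance")
    return ok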


def export_decoder(model, output_path, config, example_inputs=None):
    """
    Export the decoder to ONNX format.

    Args:
        model: DecoderExport instance
        output_path: output file path
        config: model configuration
        example_inputs: example inputs (used to keep validation consistent)
    """
    print(f"Exporting decoder to: {output_path}")
    # Create example inputs; batch_size=2 ensures the exported graph keeps a dynamic batch axis
    batch_size = 2
    seq_len = config.get("max_seq_len", 128)
    pinyin_len = 24
    dim = config.get("dim", 512)
    num_slots = config.get("num_slots", 8)
    # Use the provided example inputs, or create fresh ones
    if example_inputs is not None:
        context_H, pinyin_P, context_mask, pinyin_mask = example_inputs
        batch_size = context_H.size(0)  # use the actual batch size
        history_slot_ids = torch.randint(
            0, 100, (batch_size, num_slots), dtype=torch.long
        )
    else:
        context_H = torch.randn(batch_size, seq_len, dim, dtype=torch.float32)
        pinyin_P = torch.randn(batch_size, pinyin_len, dim, dtype=torch.float32)
        context_mask = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int32)
        pinyin_mask = torch.randint(0, 2, (batch_size, pinyin_len), dtype=torch.int32)
        history_slot_ids = torch.randint(
            0, 100, (batch_size, num_slots), dtype=torch.long
        )
    # Declare the dynamic axes
    dynamic_axes = {
        "context_H": {0: "batch_size", 1: "seq_len"},
        "pinyin_P": {0: "batch_size"},
        "history_slot_ids": {0: "batch_size"},
        "context_mask": {0: "batch_size", 1: "seq_len"},
        "pinyin_mask": {0: "batch_size"},
        "logits": {0: "batch_size"},
    }
    # Export the model
    torch.onnx.export(
        model,
        (context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask),
        output_path,
        input_names=[
            "context_H",
            "pinyin_P",
            "history_slot_ids",
            "context_mask",
            "pinyin_mask",
        ],
        output_names=["logits"],
        dynamic_axes=dynamic_axes,
        opset_version=18,
        do_constant_folding=True,
        verbose=False,
    )
    print("✅ Decoder export complete")
    # Validate the exported model
    try:
        onnx_model = onnx.load(output_path)
        onnx.checker.check_model(onnx_model)
        print("✅ ONNX model check passed")
    except Exception as e:
        print(f"⚠️ ONNX model check warning: {e}")
    return context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask


def save_example_inputs(output_dir, example_inputs_dict):
    """
    Save example inputs as an NPZ file for later validation.

    Args:
        output_dir: output directory
        example_inputs_dict: dict of example inputs
    """
    npz_path = os.path.join(output_dir, "example_inputs.npz")
    # Convert to numpy arrays
    np_data = {}
    for key, tensor in example_inputs_dict.items():
        if isinstance(tensor, torch.Tensor):
            np_data[key] = tensor.cpu().numpy()
        elif isinstance(tensor, tuple):
            for i, t in enumerate(tensor):
                if isinstance(t, torch.Tensor):
                    np_data[f"{key}_{i}"] = t.cpu().numpy()
    np.savez(npz_path, **np_data)
    print(f"✅ Example inputs saved to: {npz_path}")
    # Also save in PyTorch format
    torch_path = os.path.join(output_dir, "example_inputs.pt")
    torch.save(example_inputs_dict, torch_path)
    print(f"✅ PyTorch example inputs saved to: {torch_path}")


def create_inference_example(output_dir, config):
    """
    Create an example inference script.

    Args:
        output_dir: output directory
        config: model configuration
    """
    example_path = os.path.join(output_dir, "inference_example.py")
    example_code = '''#!/usr/bin/env python3
"""
ONNX model inference example.

Shows how to run inference with the two exported ONNX models,
including a beam search algorithm.
"""
import os

import numpy as np
import onnxruntime as ort
import torch
import torch.nn.functional as F


class ONNXInference:
    """ONNX model inference runner"""

    def __init__(self, context_encoder_path, decoder_path):
        """
        Initialize the ONNX inference runner.

        Args:
            context_encoder_path: path to the context encoder ONNX model
            decoder_path: path to the decoder ONNX model
        """
        # Create the ONNX Runtime sessions
        self.context_encoder_session = ort.InferenceSession(
            context_encoder_path,
            providers=['CPUExecutionProvider']  # or 'CUDAExecutionProvider'
        )
        self.decoder_session = ort.InferenceSession(
            decoder_path,
            providers=['CPUExecutionProvider']
        )
        # Collect the input/output names
        self.context_input_names = [inp.name for inp in self.context_encoder_session.get_inputs()]
        self.context_output_names = [out.name for out in self.context_encoder_session.get_outputs()]
        self.decoder_input_names = [inp.name for inp in self.decoder_session.get_inputs()]
        self.decoder_output_names = [out.name for out in self.decoder_session.get_outputs()]
        print(f"Context encoder inputs: {self.context_input_names}")
        print(f"Context encoder outputs: {self.context_output_names}")
        print(f"Decoder inputs: {self.decoder_input_names}")
        print(f"Decoder outputs: {self.decoder_output_names}")

    def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128):
        """
        Prepare the model inputs (kept consistent with the original inference script).

        Note: text-to-token conversion must be implemented here.
        It is omitted to keep the example simple.
        """
        # Call the actual preprocessing functions here.
        # Returns: input_ids, pinyin_ids, attention_mask, history_slot_ids
        raise NotImplementedError("Implement the actual input preprocessing")

    def run_context_encoder(self, input_ids, pinyin_ids, attention_mask):
        """
        Run the context encoder.

        Args:
            input_ids: [batch, seq_len]
            pinyin_ids: [batch, 24]
            attention_mask: [batch, seq_len]
        Returns:
            context_H, pinyin_P, context_mask, pinyin_mask
        """
        # Prepare the inputs
        inputs = {
            "input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids,
            "pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids,
            "attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask,
        }
        # Run inference
        outputs = self.context_encoder_session.run(self.context_output_names, inputs)
        # Unpack the outputs
        context_H, pinyin_P, context_mask, pinyin_mask = outputs
        return (
            torch.from_numpy(context_H),
            torch.from_numpy(pinyin_P),
            torch.from_numpy(context_mask),
            torch.from_numpy(pinyin_mask),
        )

    def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
        """
        Run the decoder.

        Args:
            context_H: [batch, seq_len, 512]
            pinyin_P: [batch, 24, 512]
            history_slot_ids: [batch, 8]
            context_mask: [batch, seq_len]
            pinyin_mask: [batch, 24]
        Returns:
            logits: [batch, vocab_size]
        """
        # Prepare the inputs
        inputs = {
            "context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H,
            "pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P,
            "history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids,
            "context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask,
            "pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask,
        }
        # Run inference
        outputs = self.decoder_session.run(self.decoder_output_names, inputs)
        # Unpack the outputs
        logits = outputs[0]
        return torch.from_numpy(logits)

    def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask,
                    beam_size=5, max_length=10, vocab_size=10019):
        """
        Example beam search algorithm.

        Args:
            context_H: context encoding
            pinyin_P: pinyin encoding
            context_mask: context mask
            pinyin_mask: pinyin mask
            beam_size: beam width
            max_length: maximum generation length
            vocab_size: vocabulary size
        Returns:
            list of the best sequences
        """
        # Initial beam: an empty sequence with score 0
        beams = [([], 0.0)]  # (sequence, log probability)
        for step in range(max_length):
            new_beams = []
            for seq, score in beams:
                # Build history_slot_ids from the already-confirmed character ids
                if len(seq) < 8:
                    history = seq + [0] * (8 - len(seq))
                else:
                    history = seq[-8:]  # keep only the most recent 8
                history_tensor = torch.tensor([history], dtype=torch.long)
                # Run the decoder
                logits = self.run_decoder(
                    context_H, pinyin_P, history_tensor,
                    context_mask, pinyin_mask
                )
                # Convert to probabilities
                probs = F.softmax(logits[0], dim=-1)
                # Take the top-k candidates
                top_probs, top_indices = torch.topk(probs, beam_size)
                # Expand the beams
                for prob, idx in zip(top_probs, top_indices):
                    new_seq = seq + [idx.item()]
                    new_score = score + torch.log(prob).item()
                    new_beams.append((new_seq, new_score))
            # Prune: keep the beam_size best candidates
            new_beams.sort(key=lambda x: x[1], reverse=True)
            beams = new_beams[:beam_size]
            # Stop once every sequence has ended (ends with the end id 0)
            all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
            if all_ended:
                break
        return beams
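
    # A simpler alternative to the beam search above (a sketch, not part of
    # the original script): greedily take the argmax id at each step. It
    # assumes the same 8-slot history layout and end-of-sequence id 0.
    def greedy_decode(self, context_H, pinyin_P, context_mask, pinyin_mask, max_length=10):
        """Greedy decoding; equivalent to beam_search with beam_size=1."""
        seq = []
        for _ in range(max_length):
            # Pad or truncate the confirmed ids into the 8 history slots
            history = seq + [0] * (8 - len(seq)) if len(seq) < 8 else seq[-8:]
            history_tensor = torch.tensor([history], dtype=torch.long)
            logits = self.run_decoder(
                context_H, pinyin_P, history_tensor, context_mask, pinyin_mask
            )
            next_id = int(torch.argmax(logits[0]).item())
            seq.append(next_id)
            if next_id == 0:  # end-of-sequence id, as in beam_search
                break
        return seq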

    def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids):
        """
        Single-step prediction.

        Args:
            input_ids: input token IDs
            pinyin_ids: pinyin IDs
            attention_mask: attention mask
            history_slot_ids: history slot IDs
        Returns:
            prediction logits
        """
        # 1. Run the context encoder
        context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
            input_ids, pinyin_ids, attention_mask
        )
        # 2. Run the decoder
        logits = self.run_decoder(
            context_H, pinyin_P, history_slot_ids,
            context_mask, pinyin_mask
        )
        return logits


def main():
    """Example entry point"""
    print("ONNX model inference example")
    print("=" * 60)
    # Initialize the inference runner
    context_encoder_path = "context_encoder.onnx"
    decoder_path = "decoder.onnx"
    if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path):
        print("Error: ONNX model files not found")
        print("Run export_onnx.py first to export the models")
        return
    inference = ONNXInference(context_encoder_path, decoder_path)
    print("✅ ONNX inference runner initialized")
print("请参考此示例实现完整的输入法推理流程")
if __name__ == "__main__":
main()
'''
with open(example_path, "w", encoding="utf-8") as f:
f.write(example_code)
print(f"✅ 推理示例脚本已保存到: {example_path}")


def main():
    parser = argparse.ArgumentParser(description="ONNX export for the input-method model")
    parser.add_argument(
        "--checkpoint", "-c", type=str, required=True, help="path to the model checkpoint"
    )
    parser.add_argument(
        "--output-dir",
        "-o",
        type=str,
        default="./exported_models",
        help="output directory (default: ./exported_models)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="export device (default: cpu)",
    )
    parser.add_argument(
        "--skip-verification", action="store_true", help="skip ONNX model verification"
    )
    args = parser.parse_args()
    # Check the dependencies
    if not check_onnx_available():
        sys.exit(1)
    # Create the output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"📁 Output directory: {output_dir.absolute()}")
    # Load the checkpoint and build the export models
    print(f"📦 Loading checkpoint: {args.checkpoint}")
    context_encoder_export, decoder_export, config = (
        create_export_models_from_checkpoint(args.checkpoint, args.device)
    )
    print(f"📊 Model configuration: {config}")
    # Export the context encoder
    context_encoder_path = output_dir / "context_encoder.onnx"
    example_inputs = export_context_encoder(
        context_encoder_export, str(context_encoder_path), config
    )
    # Use the context encoder's outputs as example inputs for the decoder
    with torch.no_grad():
        context_H, pinyin_P, context_mask, pinyin_mask = context_encoder_export(
            *example_inputs
        )
    # Export the decoder
    decoder_path = output_dir / "decoder.onnx"
    export_decoder(
        decoder_export,
        str(decoder_path),
        config,
        example_inputs=(context_H, pinyin_P, context_mask, pinyin_mask),
    )
    # Save the example inputs
    example_inputs_dict = {
        "input_ids": example_inputs[0],
        "pinyin_ids": example_inputs[1],
        "attention_mask": example_inputs[2],
        "context_H": context_H,
        "pinyin_P": pinyin_P,
        "context_mask": context_mask,
        "pinyin_mask": pinyin_mask,
    }
    save_example_inputs(output_dir, example_inputs_dict)
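    # Optional numerical parity check (a sketch; uses the verify_onnx_parity
    # helper defined above and honors --skip-verification).
    if not args.skip_verification:
        verify_onnx_parity(context_encoder_path, context_encoder_export, example_inputs)
        history_slot_ids = torch.randint(
            0, 100, (context_H.size(0), config.get("num_slots", 8)), dtype=torch.long
        )
        verify_onnx_parity(
            decoder_path,
            decoder_export,
            (context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask),
        )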
    # Create the example inference script
    create_inference_example(output_dir, config)
    print("\n" + "=" * 60)
    print("🎉 ONNX export complete")
    print("=" * 60)
    print("Generated model files:")
    print(f"  - {context_encoder_path}")
    print(f"  - {decoder_path}")
    print(f"  - {output_dir}/example_inputs.npz")
    print(f"  - {output_dir}/example_inputs.pt")
    print(f"  - {output_dir}/inference_example.py")
    print("\nUsage:")
    print(f"  1. Check a model: python -m onnx.checker {context_encoder_path}")
    print(f"  2. Run the inference example: cd {output_dir} && python inference_example.py")
    print("  3. Integrate into your application: see the ONNXInference class in inference_example.py")
    print("\nNotes:")
    print("  - Make sure onnxruntime is installed: pip install onnxruntime")
    print("  - GPU inference requires onnxruntime-gpu: pip install onnxruntime-gpu")
    print("  - Adjust the beam search algorithm to your actual needs")

if __name__ == "__main__":
    main()