#!/usr/bin/env python3
"""
Input-method model ONNX export script.

Exports the model as two ONNX parts:
1. context_encoder.onnx - the context encoder
2. decoder.onnx - the decoder

Usage:
    python export_onnx.py --checkpoint ./output/checkpoints/best_model.pt --output-dir ./exported_models

Dependencies:
    pip install onnx onnxruntime
"""

import argparse
import os
import sys
from pathlib import Path

import numpy as np
import torch

# Check whether ONNX and ONNX Runtime are available
try:
    import onnx
    import onnxruntime as ort

    ONNX_AVAILABLE = True
except ImportError:
    ONNX_AVAILABLE = False
    print("Warning: ONNX or ONNX Runtime is not installed")
    print("Please run: pip install onnx onnxruntime")

# Add the src directory to the path
sys.path.insert(0, str(Path(__file__).parent))

from src.model.export_models import create_export_models_from_checkpoint


def check_onnx_available():
    """Check that the ONNX dependencies are available."""
    if not ONNX_AVAILABLE:
        print("Error: ONNX export requires the following dependencies:")
        print("  pip install onnx onnxruntime")
        print("Please install them and retry")
        return False
    return True


def export_context_encoder(model, output_path, config):
    """
    Export the context encoder to ONNX format.

    Args:
        model: ContextEncoderExport instance
        output_path: output file path
        config: model configuration
    """
    print(f"Exporting context encoder to: {output_path}")

    # Create example inputs - batch_size=2 ensures the exported graph supports dynamic batching
    batch_size = 2
    seq_len = config.get("max_seq_len", 128)
    pinyin_len = 24  # fixed length

    input_ids = torch.randint(0, 100, (batch_size, seq_len), dtype=torch.long)
    pinyin_ids = torch.randint(0, 30, (batch_size, pinyin_len), dtype=torch.long)
    attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)

    # Declare dynamic axes
    dynamic_axes = {
        "input_ids": {0: "batch_size", 1: "seq_len"},
        "pinyin_ids": {0: "batch_size"},
        "attention_mask": {0: "batch_size", 1: "seq_len"},
        "context_H": {0: "batch_size", 1: "seq_len"},
        "pinyin_P": {0: "batch_size"},
        "context_mask": {0: "batch_size", 1: "seq_len"},
        "pinyin_mask": {0: "batch_size"},
    }

    # Export the model
    torch.onnx.export(
        model,
        (input_ids, pinyin_ids, attention_mask),
        output_path,
        input_names=["input_ids", "pinyin_ids", "attention_mask"],
        output_names=["context_H", "pinyin_P", "context_mask", "pinyin_mask"],
        dynamic_axes=dynamic_axes,
        opset_version=18,  # supports modern operators
        do_constant_folding=True,
        verbose=False,
    )

    print("✅ Context encoder export complete")

    # Validate the exported model
    try:
        onnx_model = onnx.load(output_path)
        onnx.checker.check_model(onnx_model)
        print("✅ ONNX model validation passed")
    except Exception as e:
        print(f"⚠️ ONNX model validation warning: {e}")

    return input_ids, pinyin_ids, attention_mask


def export_decoder(model, output_path, config, example_inputs=None):
    """
    Export the decoder to ONNX format.

    Args:
        model: DecoderExport instance
        output_path: output file path
        config: model configuration
        example_inputs: example inputs (used to keep the two exports consistent)
    """
    print(f"Exporting decoder to: {output_path}")

    # Create example inputs - batch_size=2 ensures the exported graph supports dynamic batching
    batch_size = 2
    seq_len = config.get("max_seq_len", 128)
    pinyin_len = 24
    dim = config.get("dim", 512)
    num_slots = config.get("num_slots", 8)

    # Reuse the provided example inputs, or create fresh ones
    if example_inputs is not None:
        context_H, pinyin_P, context_mask, pinyin_mask = example_inputs
        batch_size = context_H.size(0)  # use the actual batch size
        history_slot_ids = torch.randint(
            0, 100, (batch_size, num_slots), dtype=torch.long
        )
    else:
        context_H = torch.randn(batch_size, seq_len, dim, dtype=torch.float32)
        pinyin_P = torch.randn(batch_size, pinyin_len, dim, dtype=torch.float32)
        context_mask = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int32)
        pinyin_mask = torch.randint(0, 2, (batch_size, pinyin_len), dtype=torch.int32)
        history_slot_ids = torch.randint(
            0, 100, (batch_size, num_slots), dtype=torch.long
        )

    # Declare dynamic axes
    dynamic_axes = {
        "context_H": {0: "batch_size", 1: "seq_len"},
        "pinyin_P": {0: "batch_size"},
        "history_slot_ids": {0: "batch_size"},
        "context_mask": {0: "batch_size", 1: "seq_len"},
        "pinyin_mask": {0: "batch_size"},
        "logits": {0: "batch_size"},
    }

    # Export the model
    torch.onnx.export(
        model,
        (context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask),
        output_path,
        input_names=[
            "context_H",
            "pinyin_P",
            "history_slot_ids",
            "context_mask",
            "pinyin_mask",
        ],
        output_names=["logits"],
        dynamic_axes=dynamic_axes,
        opset_version=18,
        do_constant_folding=True,
        verbose=False,
    )

    print("✅ Decoder export complete")

    # Validate the exported model
    try:
        onnx_model = onnx.load(output_path)
        onnx.checker.check_model(onnx_model)
        print("✅ ONNX model validation passed")
    except Exception as e:
        print(f"⚠️ ONNX model validation warning: {e}")

    return context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask
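

# Optional numerical parity check (a sketch, not part of the original export
# flow): runs an exported graph with ONNX Runtime on the same example inputs
# and compares the results against the PyTorch module. The input names must
# match those passed to torch.onnx.export above.
def verify_onnx_parity(model, onnx_path, input_names, example_inputs, atol=1e-4):
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    feed = {name: t.cpu().numpy() for name, t in zip(input_names, example_inputs)}
    onnx_outputs = session.run(None, feed)
    with torch.no_grad():
        torch_outputs = model(*example_inputs)
    if isinstance(torch_outputs, torch.Tensor):
        torch_outputs = (torch_outputs,)
    for i, (ref, got) in enumerate(zip(torch_outputs, onnx_outputs)):
        np.testing.assert_allclose(
            ref.cpu().numpy(), got, atol=atol, err_msg=f"output {i} diverged"
        )
    print(f"✅ {onnx_path}: PyTorch and ONNX Runtime outputs match (atol={atol})")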


def save_example_inputs(output_dir, example_inputs_dict):
    """
    Save the example inputs as an NPZ file for later verification.

    Args:
        output_dir: output directory
        example_inputs_dict: dict of example inputs
    """
    npz_path = os.path.join(output_dir, "example_inputs.npz")

    # Convert to numpy arrays
    np_data = {}
    for key, tensor in example_inputs_dict.items():
        if isinstance(tensor, torch.Tensor):
            np_data[key] = tensor.cpu().numpy()
        elif isinstance(tensor, tuple):
            for i, t in enumerate(tensor):
                if isinstance(t, torch.Tensor):
                    np_data[f"{key}_{i}"] = t.cpu().numpy()

    np.savez(npz_path, **np_data)
    print(f"✅ Example inputs saved to: {npz_path}")

    # Also save in PyTorch format
    torch_path = os.path.join(output_dir, "example_inputs.pt")
    torch.save(example_inputs_dict, torch_path)
    print(f"✅ PyTorch example inputs saved to: {torch_path}")
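

# Hypothetical reload of the NPZ written above (a sketch): it can drive a
# standalone consistency check without rebuilding the PyTorch model, e.g.
#
#   data = np.load("exported_models/example_inputs.npz")
#   session = ort.InferenceSession("exported_models/context_encoder.onnx")
#   outs = session.run(None, {k: data[k] for k in ("input_ids", "pinyin_ids", "attention_mask")})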


def create_inference_example(output_dir, config):
    """
    Create an example inference script.

    Args:
        output_dir: output directory
        config: model configuration
    """
    example_path = os.path.join(output_dir, "inference_example.py")

    example_code = '''#!/usr/bin/env python3
"""
ONNX model inference example.

Shows how to run inference with the two exported ONNX models,
including a beam search algorithm.
"""

import os

import torch
import torch.nn.functional as F
import onnxruntime as ort


class ONNXInference:
    """ONNX model inference helper"""

    def __init__(self, context_encoder_path, decoder_path):
        """
        Initialize the ONNX inference helper.

        Args:
            context_encoder_path: path to the context encoder ONNX model
            decoder_path: path to the decoder ONNX model
        """
        # Create ONNX Runtime sessions
        self.context_encoder_session = ort.InferenceSession(
            context_encoder_path,
            providers=['CPUExecutionProvider']  # or 'CUDAExecutionProvider'
        )
        self.decoder_session = ort.InferenceSession(
            decoder_path,
            providers=['CPUExecutionProvider']
        )

        # Collect input/output names
        self.context_input_names = [inp.name for inp in self.context_encoder_session.get_inputs()]
        self.context_output_names = [out.name for out in self.context_encoder_session.get_outputs()]
        self.decoder_input_names = [inp.name for inp in self.decoder_session.get_inputs()]
        self.decoder_output_names = [out.name for out in self.decoder_session.get_outputs()]

        print(f"Context encoder inputs: {self.context_input_names}")
        print(f"Context encoder outputs: {self.context_output_names}")
        print(f"Decoder inputs: {self.decoder_input_names}")
        print(f"Decoder outputs: {self.decoder_output_names}")

    def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128):
        """
        Prepare model inputs (should match the original inference script).

        Note: this requires the actual text-to-token conversion;
        it is left unimplemented to keep the example short.
        """
        # The real preprocessing function should be called here.
        # Returns: input_ids, pinyin_ids, attention_mask, history_slot_ids
        raise NotImplementedError("Implement the actual input preprocessing")

    def run_context_encoder(self, input_ids, pinyin_ids, attention_mask):
        """
        Run the context encoder.

        Args:
            input_ids: [batch, seq_len]
            pinyin_ids: [batch, 24]
            attention_mask: [batch, seq_len]

        Returns:
            context_H, pinyin_P, context_mask, pinyin_mask
        """
        # Prepare inputs
        inputs = {
            "input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids,
            "pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids,
            "attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask,
        }

        # Run inference
        outputs = self.context_encoder_session.run(self.context_output_names, inputs)

        # Unpack outputs
        context_H, pinyin_P, context_mask, pinyin_mask = outputs

        return (
            torch.from_numpy(context_H),
            torch.from_numpy(pinyin_P),
            torch.from_numpy(context_mask),
            torch.from_numpy(pinyin_mask),
        )

    def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
        """
        Run the decoder.

        Args:
            context_H: [batch, seq_len, 512]
            pinyin_P: [batch, 24, 512]
            history_slot_ids: [batch, 8]
            context_mask: [batch, seq_len]
            pinyin_mask: [batch, 24]

        Returns:
            logits: [batch, vocab_size]
        """
        # Prepare inputs
        inputs = {
            "context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H,
            "pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P,
            "history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids,
            "context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask,
            "pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask,
        }

        # Run inference
        outputs = self.decoder_session.run(self.decoder_output_names, inputs)

        # Unpack outputs
        logits = outputs[0]
        return torch.from_numpy(logits)

    def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask,
                    beam_size=5, max_length=10, vocab_size=10019):
        """
        Example beam search.

        Args:
            context_H: context encoding
            pinyin_P: pinyin encoding
            context_mask: context mask
            pinyin_mask: pinyin mask
            beam_size: beam width
            max_length: maximum generation length
            vocab_size: vocabulary size

        Returns:
            list of the best sequences
        """
        # Initial beam: an empty sequence with score 0
        beams = [([], 0.0)]  # (sequence, log probability)

        for step in range(max_length):
            new_beams = []

            for seq, score in beams:
                # Build history_slot_ids: the already-confirmed character IDs
                if len(seq) < 8:
                    history = seq + [0] * (8 - len(seq))
                else:
                    history = seq[-8:]  # keep only the most recent 8

                history_tensor = torch.tensor([history], dtype=torch.long)

                # Run the decoder
                logits = self.run_decoder(
                    context_H, pinyin_P, history_tensor,
                    context_mask, pinyin_mask
                )

                # Convert to probabilities
                probs = F.softmax(logits[0], dim=-1)

                # Take the top-k candidates
                top_probs, top_indices = torch.topk(probs, beam_size)

                # Expand the beam
                for prob, idx in zip(top_probs, top_indices):
                    new_seq = seq + [idx.item()]
                    new_score = score + torch.log(prob).item()
                    new_beams.append((new_seq, new_score))

            # Prune: keep the beam_size best candidates
            new_beams.sort(key=lambda x: x[1], reverse=True)
            beams = new_beams[:beam_size]

            # Stop if every sequence has ended (end token 0)
            all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
            if all_ended:
                break

        return beams

    def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids):
        """
        Single-step prediction.

        Args:
            input_ids: input token IDs
            pinyin_ids: pinyin IDs
            attention_mask: attention mask
            history_slot_ids: history slot IDs

        Returns:
            prediction logits
        """
        # 1. Run the context encoder
        context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
            input_ids, pinyin_ids, attention_mask
        )

        # 2. Run the decoder
        logits = self.run_decoder(
            context_H, pinyin_P, history_slot_ids,
            context_mask, pinyin_mask
        )

        return logits


def main():
    """Example entry point"""
    print("ONNX model inference example")
    print("=" * 60)

    # Initialize the inference helper
    context_encoder_path = "context_encoder.onnx"
    decoder_path = "decoder.onnx"

    if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path):
        print("Error: ONNX model files not found")
        print("Run export_onnx.py first to export the models")
        return

    inference = ONNXInference(context_encoder_path, decoder_path)

    print("✅ ONNX inference helper initialized")
    print("Use this example as a template for the full input-method inference flow")


if __name__ == "__main__":
    main()
'''

    with open(example_path, "w", encoding="utf-8") as f:
        f.write(example_code)

    print(f"✅ Inference example script saved to: {example_path}")


def main():
    parser = argparse.ArgumentParser(description="Input-method model ONNX export")
    parser.add_argument(
        "--checkpoint", "-c", type=str, required=True, help="path to the model checkpoint"
    )
    parser.add_argument(
        "--output-dir",
        "-o",
        type=str,
        default="./exported_models",
        help="output directory (default: ./exported_models)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="export device (default: cpu)",
    )
    parser.add_argument(
        "--skip-verification", action="store_true", help="skip ONNX model verification"
    )

    args = parser.parse_args()

    # Check dependencies
    if not check_onnx_available():
        sys.exit(1)

    # Create the output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"📁 Output directory: {output_dir.absolute()}")

    # Load the checkpoint and build the export models
    print(f"📦 Loading checkpoint: {args.checkpoint}")
    context_encoder_export, decoder_export, config = (
        create_export_models_from_checkpoint(args.checkpoint, args.device)
    )

    print(f"📊 Model config: {config}")

    # Export the context encoder
    context_encoder_path = output_dir / "context_encoder.onnx"
    example_inputs = export_context_encoder(
        context_encoder_export, str(context_encoder_path), config
    )

    # Use the context encoder's outputs as the decoder's example inputs
    with torch.no_grad():
        context_H, pinyin_P, context_mask, pinyin_mask = context_encoder_export(
            *example_inputs
        )

    # Export the decoder
    decoder_path = output_dir / "decoder.onnx"
    export_decoder(
        decoder_export,
        str(decoder_path),
        config,
        example_inputs=(context_H, pinyin_P, context_mask, pinyin_mask),
    )
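
    # Optional numerical parity check (a sketch using verify_onnx_parity defined
    # above; wired to --skip-verification, which the original script declared but
    # never used). The decoder check generates fresh history slot IDs, since the
    # exported graph does not depend on their specific values.
    if not args.skip_verification:
        verify_onnx_parity(
            context_encoder_export,
            str(context_encoder_path),
            ["input_ids", "pinyin_ids", "attention_mask"],
            example_inputs,
        )
        history_slot_ids = torch.randint(
            0, 100, (context_H.size(0), config.get("num_slots", 8)), dtype=torch.long
        )
        verify_onnx_parity(
            decoder_export,
            str(decoder_path),
            ["context_H", "pinyin_P", "history_slot_ids", "context_mask", "pinyin_mask"],
            (context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask),
        )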

    # Save the example inputs
    example_inputs_dict = {
        "input_ids": example_inputs[0],
        "pinyin_ids": example_inputs[1],
        "attention_mask": example_inputs[2],
        "context_H": context_H,
        "pinyin_P": pinyin_P,
        "context_mask": context_mask,
        "pinyin_mask": pinyin_mask,
    }
    save_example_inputs(output_dir, example_inputs_dict)

    # Create the inference example script
    create_inference_example(output_dir, config)

    print("\n" + "=" * 60)
    print("🎉 ONNX export complete!")
    print("=" * 60)
    print("Generated files:")
    print(f"  - {context_encoder_path}")
    print(f"  - {decoder_path}")
    print(f"  - {output_dir}/example_inputs.npz")
    print(f"  - {output_dir}/example_inputs.pt")
    print(f"  - {output_dir}/inference_example.py")
    print("\nNext steps:")
    print(f"  1. Check a model: python -c \"import onnx; onnx.checker.check_model('{context_encoder_path}')\"")
    print(f"  2. Run the inference example: cd {output_dir} && python inference_example.py")
    print("  3. Integrate into your application: see the ONNXInference class in inference_example.py")
    print("\nNotes:")
    print("  - Make sure onnxruntime is installed: pip install onnxruntime")
    print("  - GPU inference requires onnxruntime-gpu: pip install onnxruntime-gpu")
    print("  - Adapt the beam search algorithm to your actual needs")


if __name__ == "__main__":
    main()