SUimeModelTraner/exported_models/inference_example.py

128 lines
5.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
ONNX模型推理示例
展示如何使用导出的两个ONNX模型进行推理
包括束搜索beam search算法。
"""
import os
import numpy as np
import onnxruntime as ort
import torch
import torch.nn.functional as F
from typing import List, Tuple
class ONNXInference:
    """Inference runner for the two exported ONNX models.

    Wraps one ONNX Runtime session for the context encoder and one for the
    decoder, both on ``CPUExecutionProvider``. Model inputs may be given as
    torch tensors or NumPy arrays; model outputs are returned as torch tensors.
    """

    def __init__(self, context_encoder_path, decoder_path):
        """Load both ONNX models and record their graph input/output names.

        Args:
            context_encoder_path: Path to the exported context-encoder model.
            decoder_path: Path to the exported decoder model.
        """
        self.context_encoder_session = ort.InferenceSession(
            context_encoder_path,
            providers=['CPUExecutionProvider']
        )
        self.decoder_session = ort.InferenceSession(
            decoder_path,
            providers=['CPUExecutionProvider']
        )
        # The output-name lists are passed to session.run() so results come
        # back in that order. The input-name lists are kept for inspection.
        self.context_input_names = [node.name for node in self.context_encoder_session.get_inputs()]
        self.context_output_names = [node.name for node in self.context_encoder_session.get_outputs()]
        self.decoder_input_names = [node.name for node in self.decoder_session.get_inputs()]
        self.decoder_output_names = [node.name for node in self.decoder_session.get_outputs()]
        print(f"上下文编码器输入: {self.context_input_names}")
        print(f"上下文编码器输出: {self.context_output_names}")
        print(f"解码器输入: {self.decoder_input_names}")
        print(f"解码器输出: {self.decoder_output_names}")

    @staticmethod
    def _to_numpy(value):
        """Return *value* in a form ONNX Runtime accepts as an input feed.

        Torch tensors are detached and moved to CPU first, because
        ``Tensor.numpy()`` raises for tensors that require grad or that live
        on a non-CPU device. Any other value is returned unchanged.
        """
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return value

    def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128):
        """Build model inputs from raw text (placeholder, not implemented).

        Raises:
            NotImplementedError: Always; callers must supply their own
                preprocessing.
        """
        raise NotImplementedError("请实现实际的输入预处理")

    def run_context_encoder(self, input_ids, pinyin_ids, attention_mask):
        """Run the context encoder once.

        Args:
            input_ids: Token ids (torch tensor or NumPy array).
            pinyin_ids: Pinyin ids (torch tensor or NumPy array).
            attention_mask: Attention mask (torch tensor or NumPy array).
            Shapes and dtypes must match the exported graph; they are not
            checked here.

        Returns:
            Tuple ``(context_H, pinyin_P, context_mask, pinyin_mask)`` of torch
            tensors, in the order of the encoder's graph outputs.

        Raises:
            ValueError: If the encoder does not produce exactly four outputs.
        """
        inputs = {
            "input_ids": self._to_numpy(input_ids),
            "pinyin_ids": self._to_numpy(pinyin_ids),
            "attention_mask": self._to_numpy(attention_mask),
        }
        outputs = self.context_encoder_session.run(self.context_output_names, inputs)
        context_H, pinyin_P, context_mask, pinyin_mask = outputs
        return (
            torch.from_numpy(context_H),
            torch.from_numpy(pinyin_P),
            torch.from_numpy(context_mask),
            torch.from_numpy(pinyin_mask),
        )

    def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
        """Run the decoder once and return its first output (the logits).

        Args:
            context_H, pinyin_P, context_mask, pinyin_mask: Values as returned
                by :meth:`run_context_encoder` (tensors or arrays).
            history_slot_ids: Previously decoded slot ids fed back as decoder
                history (tensor or array).

        Returns:
            The decoder's first graph output as a torch tensor.
        """
        inputs = {
            "context_H": self._to_numpy(context_H),
            "pinyin_P": self._to_numpy(pinyin_P),
            "history_slot_ids": self._to_numpy(history_slot_ids),
            "context_mask": self._to_numpy(context_mask),
            "pinyin_mask": self._to_numpy(pinyin_mask),
        }
        outputs = self.decoder_session.run(self.decoder_output_names, inputs)
        return torch.from_numpy(outputs[0])

    def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask,
                    beam_size=5, max_length=10, vocab_size=10019,
                    eos_id=0, history_len=8) -> List[Tuple[List[int], float]]:
        """Decode with beam search over the decoder.

        Each hypothesis is scored by the sum of its token log-probabilities
        (no length normalization). A hypothesis whose last token is
        ``eos_id`` is finished: it is carried forward unchanged instead of
        being extended. The search stops once every surviving beam is
        finished, or after ``max_length`` steps.

        Args:
            context_H, pinyin_P, context_mask, pinyin_mask: Encoder outputs
                from :meth:`run_context_encoder`.
            beam_size: Number of hypotheses kept after each step.
            max_length: Maximum number of decoding steps.
            vocab_size: Unused; kept for backward compatibility.
            eos_id: Token id that marks the end of a hypothesis.
            history_len: Number of most recent tokens fed to the decoder
                (must be positive). Shorter histories are right-padded with 0.

        Returns:
            Up to ``beam_size`` ``(token_ids, log_prob_score)`` pairs, best
            first.
        """
        beams = [([], 0.0)]
        for _ in range(max_length):
            candidates = []
            for seq, score in beams:
                if seq and seq[-1] == eos_id:
                    # Finished hypothesis: keep it as-is so it competes on its
                    # final score rather than being extended past the end token.
                    candidates.append((seq, score))
                    continue
                if len(seq) < history_len:
                    # NOTE(review): right-padding with 0 matches the original
                    # example; confirm it matches the training-time convention.
                    history = seq + [0] * (history_len - len(seq))
                else:
                    history = seq[-history_len:]
                history_tensor = torch.tensor([history], dtype=torch.long)
                logits = self.run_decoder(
                    context_H, pinyin_P, history_tensor,
                    context_mask, pinyin_mask
                )
                # Assumes logits has shape (1, vocab), so logits[0] is a 1-D
                # score vector. TODO: confirm against the exported decoder.
                # log_softmax avoids the -inf that log(softmax(x)) produces
                # when a probability underflows to 0.
                log_probs = F.log_softmax(logits[0], dim=-1)
                k = min(beam_size, log_probs.size(-1))
                top_log_probs, top_indices = torch.topk(log_probs, k)
                for log_prob, idx in zip(top_log_probs.tolist(), top_indices.tolist()):
                    candidates.append((seq + [idx], score + log_prob))
            candidates.sort(key=lambda beam: beam[1], reverse=True)
            beams = candidates[:beam_size]
            if all(seq and seq[-1] == eos_id for seq, _ in beams):
                break
        return beams

    def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids):
        """Run the encoder and then one decoder step, returning the raw logits.

        Args:
            input_ids, pinyin_ids, attention_mask: Encoder inputs (see
                :meth:`run_context_encoder`).
            history_slot_ids: Decoder history ids (see :meth:`run_decoder`).

        Returns:
            The decoder's first output as a torch tensor.
        """
        context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
            input_ids, pinyin_ids, attention_mask
        )
        return self.run_decoder(
            context_H, pinyin_P, history_slot_ids,
            context_mask, pinyin_mask
        )
def main(context_encoder_path="context_encoder.onnx", decoder_path="decoder.onnx"):
    """Example entry point: check for the exported models and build the runner.

    Args:
        context_encoder_path: Path to the exported context encoder. Relative
            paths are resolved against the current working directory.
        decoder_path: Path to the exported decoder. Relative paths are
            resolved against the current working directory.

    Prints an error message and returns early if either model file is missing.
    No inference is run; the final message points users to this example as a
    starting point for the full pipeline.
    """
    print("ONNX模型推理示例")
    print("=" * 60)
    if not (os.path.exists(context_encoder_path) and os.path.exists(decoder_path)):
        print("错误: 找不到ONNX模型文件")
        print("请先运行 train-model export 导出模型")
        return
    # Constructing the runner loads both sessions; the handle is not used further.
    inference = ONNXInference(context_encoder_path, decoder_path)
    print("\u2705 ONNX推理器初始化完成")
    print("请参考此示例实现完整的输入法推理流程")


if __name__ == "__main__":
    main()