128 lines
5.3 KiB
Python
128 lines
5.3 KiB
Python
#!/usr/bin/env python3
|
||
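"""
ONNX模型推理示例

展示如何使用导出的两个ONNX模型（上下文编码器与解码器）进行推理，
包括束搜索（beam search）算法。

ONNX inference example: shows how to run the two exported ONNX models
(context encoder and decoder), including a beam-search decoding loop.
"""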
|
||
import os
|
||
import numpy as np
|
||
import onnxruntime as ort
|
||
import torch
|
||
import torch.nn.functional as F
|
||
from typing import List, Tuple
|
||
|
||
|
||
class ONNXInference:
    """ONNX model inferencer for the exported context encoder and decoder.

    The context encoder is run once per input; its four outputs are then fed
    to the decoder, either for a single step (``predict_single``) or
    repeatedly inside ``beam_search``.
    """

    def __init__(self, context_encoder_path, decoder_path, providers=None):
        """Load both ONNX models and record their input/output names.

        Args:
            context_encoder_path: Path to the context-encoder ``.onnx`` file.
            decoder_path: Path to the decoder ``.onnx`` file.
            providers: onnxruntime execution providers used for both sessions.
                Defaults to ``['CPUExecutionProvider']``.
        """
        if providers is None:
            providers = ['CPUExecutionProvider']

        self.context_encoder_session = ort.InferenceSession(
            context_encoder_path,
            providers=list(providers),
        )
        self.decoder_session = ort.InferenceSession(
            decoder_path,
            providers=list(providers),
        )

        self.context_input_names = [node.name for node in self.context_encoder_session.get_inputs()]
        self.context_output_names = [node.name for node in self.context_encoder_session.get_outputs()]
        self.decoder_input_names = [node.name for node in self.decoder_session.get_inputs()]
        self.decoder_output_names = [node.name for node in self.decoder_session.get_outputs()]

        print(f"上下文编码器输入: {self.context_input_names}")
        print(f"上下文编码器输出: {self.context_output_names}")
        print(f"解码器输入: {self.decoder_input_names}")
        print(f"解码器输出: {self.decoder_output_names}")

    def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128):
        """Placeholder for input preprocessing; not implemented in this example.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("请实现实际的输入预处理")

    @staticmethod
    def _to_numpy(value):
        """Convert a torch tensor to a NumPy array; pass anything else through.

        ``detach().cpu()`` makes this work for tensors that require grad or live
        on a GPU, where a bare ``.numpy()`` would raise.
        """
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return value

    @staticmethod
    def _pad_history(seq, history_len, pad_token_id):
        """Return a fixed-length history window for ``seq``.

        Shorter sequences are right-padded with ``pad_token_id``; longer ones
        are truncated to their last ``history_len`` tokens.
        """
        if history_len <= 0:
            # seq[-0:] would return the whole sequence, so handle 0 explicitly.
            return []
        if len(seq) < history_len:
            return list(seq) + [pad_token_id] * (history_len - len(seq))
        return list(seq[-history_len:])

    def run_context_encoder(self, input_ids, pinyin_ids, attention_mask):
        """Run the context encoder and return its four outputs as torch tensors.

        Inputs may be torch tensors or NumPy arrays. The feed names are
        hard-coded and must match the exported graph's input names.

        Returns:
            Tuple ``(context_H, pinyin_P, context_mask, pinyin_mask)``.
        """
        feeds = {
            "input_ids": self._to_numpy(input_ids),
            "pinyin_ids": self._to_numpy(pinyin_ids),
            "attention_mask": self._to_numpy(attention_mask),
        }
        outputs = self.context_encoder_session.run(self.context_output_names, feeds)
        # The unpacking enforces that the encoder produces exactly four outputs.
        context_H, pinyin_P, context_mask, pinyin_mask = outputs
        return (
            torch.from_numpy(context_H),
            torch.from_numpy(pinyin_P),
            torch.from_numpy(context_mask),
            torch.from_numpy(pinyin_mask),
        )

    def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
        """Run one decoder step and return its first output (logits) as a tensor."""
        feeds = {
            "context_H": self._to_numpy(context_H),
            "pinyin_P": self._to_numpy(pinyin_P),
            "history_slot_ids": self._to_numpy(history_slot_ids),
            "context_mask": self._to_numpy(context_mask),
            "pinyin_mask": self._to_numpy(pinyin_mask),
        }
        outputs = self.decoder_session.run(self.decoder_output_names, feeds)
        logits = outputs[0]
        return torch.from_numpy(logits)

    def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask,
                    beam_size=5, max_length=10, vocab_size=10019,
                    history_len=8, pad_token_id=0, end_token_id=0) -> List[Tuple[List[int], float]]:
        """Decode a token sequence with beam search over the decoder.

        Each step feeds the last ``history_len`` tokens of every unfinished
        beam to the decoder, scores the top ``beam_size`` continuations by
        summed log-probability, and keeps the best ``beam_size`` candidates.
        A beam whose last token is ``end_token_id`` is finished: it is carried
        forward unchanged instead of being extended further.

        Args:
            context_H, pinyin_P, context_mask, pinyin_mask: Encoder outputs,
                passed through to ``run_decoder`` unchanged.
            beam_size: Number of hypotheses kept per step.
            max_length: Maximum number of decoding steps.
            vocab_size: Unused; kept for backward compatibility. The candidate
                count is taken from the decoder output instead.
            history_len: Length of the history window fed to the decoder.
            pad_token_id: Token used to right-pad short histories.
            end_token_id: Token that marks a beam as finished.

        Returns:
            List of ``(token_ids, log_prob_score)`` pairs, best first.
        """
        beams = [([], 0.0)]

        for _ in range(max_length):
            candidates = []
            for seq, score in beams:
                if seq and seq[-1] == end_token_id:
                    # Finished hypothesis: keep it competing on its current score.
                    candidates.append((seq, score))
                    continue

                history = self._pad_history(seq, history_len, pad_token_id)
                history_tensor = torch.tensor([history], dtype=torch.long)
                logits = self.run_decoder(
                    context_H, pinyin_P, history_tensor,
                    context_mask, pinyin_mask,
                )
                # NOTE(review): assumes logits is (1, vocab) for a batch of one history.
                # log_softmax avoids the -inf that log(softmax(x)) gives on underflow.
                log_probs = F.log_softmax(logits[0], dim=-1)
                k = min(beam_size, log_probs.size(-1))
                top_log_probs, top_indices = torch.topk(log_probs, k)
                for log_prob, token_id in zip(top_log_probs.tolist(), top_indices.tolist()):
                    candidates.append((seq + [token_id], score + log_prob))

            candidates.sort(key=lambda item: item[1], reverse=True)
            beams = candidates[:beam_size]

            if all(seq[-1] == end_token_id for seq, _ in beams if seq):
                break

        return beams

    def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids):
        """Encode the context once and return decoder logits for one history."""
        context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
            input_ids, pinyin_ids, attention_mask
        )
        logits = self.run_decoder(
            context_H, pinyin_P, history_slot_ids,
            context_mask, pinyin_mask,
        )
        return logits
|
||
|
||
|
||
def main(context_encoder_path="context_encoder.onnx", decoder_path="decoder.onnx"):
    """Example entry point: load the two exported ONNX models if they exist.

    Args:
        context_encoder_path: Path of the exported context-encoder model.
        decoder_path: Path of the exported decoder model.

    Returns:
        None. If a model file is missing, prints which one(s) and returns early.
    """
    print("ONNX模型推理示例")
    print("=" * 60)

    missing = [
        path for path in (context_encoder_path, decoder_path)
        if not os.path.isfile(path)
    ]
    if missing:
        print("错误: 找不到ONNX模型文件")
        for path in missing:
            # Name each missing file so the user knows which export is absent.
            print(f"  缺少: {path}")
        print("请先运行 train-model export 导出模型")
        return

    # Constructing the inferencer loads both sessions and prints their I/O names.
    ONNXInference(context_encoder_path, decoder_path)
    print("\u2705 ONNX推理器初始化完成")
    print("请参考此示例实现完整的输入法推理流程")
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: run the demo only when executed directly, not on import.
    main()
|