#!/usr/bin/env python3
"""
ONNX model inference example.

Shows how to run inference with the two exported ONNX models (a context
encoder and a decoder), including a simple beam search over the decoder.
"""

import os
import numpy as np
import onnxruntime as ort
import torch
import torch.nn.functional as F
from typing import List, Tuple


class ONNXInference:
    """Run the exported context-encoder and decoder ONNX models on CPU."""

    def __init__(self, context_encoder_path, decoder_path):
        """Load both ONNX models and print their input/output names.

        Args:
            context_encoder_path: Path to the context-encoder ``.onnx`` file.
            decoder_path: Path to the decoder ``.onnx`` file.
        """
        self.context_encoder_session = ort.InferenceSession(
            context_encoder_path, providers=['CPUExecutionProvider']
        )
        self.decoder_session = ort.InferenceSession(
            decoder_path, providers=['CPUExecutionProvider']
        )

        self.context_input_names = [node.name for node in self.context_encoder_session.get_inputs()]
        self.context_output_names = [node.name for node in self.context_encoder_session.get_outputs()]
        self.decoder_input_names = [node.name for node in self.decoder_session.get_inputs()]
        self.decoder_output_names = [node.name for node in self.decoder_session.get_outputs()]

        print(f"上下文编码器输入: {self.context_input_names}")
        print(f"上下文编码器输出: {self.context_output_names}")
        print(f"解码器输入: {self.decoder_input_names}")
        print(f"解码器输出: {self.decoder_output_names}")

    @staticmethod
    def _to_numpy(value):
        """Return ``value`` as a NumPy array if it is a torch tensor, else unchanged."""
        if isinstance(value, torch.Tensor):
            # detach()/cpu() make the conversion safe for tensors that track
            # gradients or live on a non-CPU device; both are no-ops otherwise.
            return value.detach().cpu().numpy()
        return value

    def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128):
        """Build model inputs from raw text.

        Placeholder: callers must supply their own preprocessing.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("请实现实际的输入预处理")

    def run_context_encoder(self, input_ids, pinyin_ids, attention_mask):
        """Run the context encoder once.

        Args:
            input_ids, pinyin_ids, attention_mask: torch tensors or NumPy
                arrays fed to the inputs of the same names.

        Returns:
            Tuple ``(context_H, pinyin_P, context_mask, pinyin_mask)`` of torch
            tensors. They are unpacked from the session outputs in the order
            the model declares them, so this assumes exactly four outputs.
        """
        inputs = {
            "input_ids": self._to_numpy(input_ids),
            "pinyin_ids": self._to_numpy(pinyin_ids),
            "attention_mask": self._to_numpy(attention_mask),
        }
        outputs = self.context_encoder_session.run(self.context_output_names, inputs)
        context_H, pinyin_P, context_mask, pinyin_mask = outputs
        return (
            torch.from_numpy(context_H),
            torch.from_numpy(pinyin_P),
            torch.from_numpy(context_mask),
            torch.from_numpy(pinyin_mask),
        )

    def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
        """Run the decoder once and return its first output (logits) as a tensor.

        All arguments may be torch tensors or NumPy arrays. They are fed to
        the decoder inputs of the same names.
        """
        inputs = {
            "context_H": self._to_numpy(context_H),
            "pinyin_P": self._to_numpy(pinyin_P),
            "history_slot_ids": self._to_numpy(history_slot_ids),
            "context_mask": self._to_numpy(context_mask),
            "pinyin_mask": self._to_numpy(pinyin_mask),
        }
        outputs = self.decoder_session.run(self.decoder_output_names, inputs)
        logits = outputs[0]
        return torch.from_numpy(logits)

    def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask,
                    beam_size=5, max_length=10, vocab_size=10019,
                    history_len=8) -> List[Tuple[List[int], float]]:
        """Decode token sequences with beam search.

        Token id 0 serves two roles here. It pads the history window, and a
        beam whose last token is 0 is treated as finished. A finished beam is
        carried over unchanged, so no further tokens are appended after the
        end token. Search stops early once every surviving beam has finished.

        Args:
            context_H, pinyin_P, context_mask, pinyin_mask: outputs of
                ``run_context_encoder``, passed unchanged to every decoder call.
            beam_size: number of hypotheses kept after each step.
            max_length: maximum number of decoding steps.
            vocab_size: currently unused. Kept for interface compatibility;
                the effective vocabulary size is taken from the decoder logits.
            history_len: length of the token-history window fed to the
                decoder. Must match what the exported decoder expects (the
                original code used 8).

        Returns:
            Up to ``beam_size`` ``(token_ids, log_prob_sum)`` pairs, sorted by
            descending score.

        Raises:
            ValueError: if ``history_len`` is not positive.
        """
        if history_len < 1:
            raise ValueError(f"history_len must be positive, got {history_len}")

        beams: List[Tuple[List[int], float]] = [([], 0.0)]

        for _ in range(max_length):
            candidates: List[Tuple[List[int], float]] = []
            for seq, score in beams:
                if seq and seq[-1] == 0:
                    # Already emitted the end token: keep it as-is so it can
                    # still compete with unfinished beams on score.
                    candidates.append((seq, score))
                    continue

                # Fixed-length history: right-pad short sequences with 0,
                # otherwise keep only the most recent `history_len` tokens.
                if len(seq) < history_len:
                    history = seq + [0] * (history_len - len(seq))
                else:
                    history = seq[len(seq) - history_len:]
                history_tensor = torch.tensor([history], dtype=torch.long)

                logits = self.run_decoder(
                    context_H, pinyin_P, history_tensor, context_mask, pinyin_mask
                )
                # Assumes logits are (1, vocab) for this single-item batch — TODO confirm.
                # log_softmax avoids log(0) = -inf when a softmax probability underflows.
                log_probs = F.log_softmax(logits[0], dim=-1)
                # topk raises if k exceeds the vocabulary size.
                k = min(beam_size, log_probs.size(-1))
                top_log_probs, top_indices = torch.topk(log_probs, k)

                for log_p, idx in zip(top_log_probs.tolist(), top_indices.tolist()):
                    candidates.append((seq + [idx], score + log_p))

            candidates.sort(key=lambda item: item[1], reverse=True)
            beams = candidates[:beam_size]

            if all(seq[-1] == 0 for seq, _ in beams if seq):
                break

        return beams

    def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids):
        """Run the encoder and then one decoder step; return the decoder logits."""
        context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
            input_ids, pinyin_ids, attention_mask
        )
        logits = self.run_decoder(
            context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask
        )
        return logits


def main():
    """Example entry point: check that the model files exist and load them."""
    print("ONNX模型推理示例")
    print("=" * 60)

    context_encoder_path = "context_encoder.onnx"
    decoder_path = "decoder.onnx"

    if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path):
        print("错误: 找不到ONNX模型文件")
        print("请先运行 train-model export 导出模型")
        return

    inference = ONNXInference(context_encoder_path, decoder_path)
    print("\u2705 ONNX推理器初始化完成")
    print("请参考此示例实现完整的输入法推理流程")


if __name__ == "__main__":
    main()