SUimeModelTraner/beam_search_demo.py

437 lines
15 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
束搜索算法演示
展示如何使用导出的两个ONNX模型进行束搜索推理
模拟输入法场景:给定上下文、拼音和已确认字符,生成候选汉字序列。
"""
import argparse
import os
import sys
from pathlib import Path
from typing import List, Tuple, Dict, Any
import numpy as np
import onnxruntime as ort
import torch
import torch.nn.functional as F
# 添加src目录到路径
sys.path.insert(0, str(Path(__file__).parent))
class ONNXBeamSearch:
"""
基于ONNX模型的束搜索解码器
"""
def __init__(
self, context_encoder_path: str, decoder_path: str, device: str = "cpu"
):
"""
初始化
Args:
context_encoder_path: 上下文编码器ONNX模型路径
decoder_path: 解码器ONNX模型路径
device: 推理设备cpu或cuda
"""
self.device = device
# 配置ONNX Runtime提供程序
if device == "cuda":
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
else:
providers = ["CPUExecutionProvider"]
# 创建ONNX Runtime会话
self.context_session = ort.InferenceSession(
context_encoder_path, providers=providers
)
self.decoder_session = ort.InferenceSession(decoder_path, providers=providers)
# 获取输入输出信息
self.context_input_info = {
input.name: input for input in self.context_session.get_inputs()
}
self.context_output_info = {
output.name: output for output in self.context_session.get_outputs()
}
self.decoder_input_info = {
input.name: input for input in self.decoder_session.get_inputs()
}
self.decoder_output_info = {
output.name: output for output in self.decoder_session.get_outputs()
}
print("📦 ONNX模型加载完成")
print(f" 上下文编码器输入: {list(self.context_input_info.keys())}")
print(f" 上下文编码器输出: {list(self.context_output_info.keys())}")
print(f" 解码器输入: {list(self.decoder_input_info.keys())}")
print(f" 解码器输出: {list(self.decoder_output_info.keys())}")
def prepare_inputs(
self,
text_before: str,
text_after: str,
pinyin: str,
context_prompts: List[str] = None,
tokenizer=None,
query_engine=None,
max_seq_len: int = 128,
) -> Dict[str, np.ndarray]:
"""
准备模型输入
注意: 这是一个简化的版本,实际应用中需要实现完整的预处理逻辑
这里使用随机数据模拟实际应使用tokenizer和查询引擎
Args:
text_before: 光标前文本
text_after: 光标后文本
pinyin: 拼音输入
context_prompts: 上下文提示
tokenizer: 分词器(未实现)
query_engine: 查询引擎(未实现)
max_seq_len: 最大序列长度
Returns:
输入字典包含input_ids, pinyin_ids, attention_mask
"""
# 在实际应用中,这里应该实现:
# 1. 使用tokenizer将文本转换为input_ids
# 2. 使用text_to_pinyin_ids将拼音转换为pinyin_ids
# 3. 构建attention_mask
# 简化为随机数据(用于演示)
batch_size = 1
seq_len = min(max_seq_len, 64) # 简化长度
# 模拟输入
input_ids = np.random.randint(0, 1000, (batch_size, seq_len), dtype=np.int64)
pinyin_ids = np.random.randint(
0, 30, (batch_size, 24), dtype=np.int64
) # 固定24长度
attention_mask = np.ones((batch_size, seq_len), dtype=np.int64)
# 在实际应用中,应该这样处理拼音:
# from src.model.dataset import text_to_pinyin_ids
# pinyin_ids_list = text_to_pinyin_ids(pinyin)
# pinyin_ids_array = np.array(pinyin_ids_list, dtype=np.int64).reshape(1, -1)
return {
"input_ids": input_ids,
"pinyin_ids": pinyin_ids,
"attention_mask": attention_mask,
}
def run_context_encoder(
self, inputs: Dict[str, np.ndarray]
) -> Tuple[np.ndarray, ...]:
"""
运行上下文编码器
Args:
inputs: 输入字典
Returns:
上下文编码器输出: (context_H, pinyin_P, context_mask, pinyin_mask)
"""
# 准备ONNX输入确保顺序正确
onnx_inputs = {}
for input_name in self.context_input_info.keys():
if input_name in inputs:
onnx_inputs[input_name] = inputs[input_name]
else:
# 对于缺失的输入,使用默认值
shape = self.context_input_info[input_name].shape
dtype = self.context_input_info[input_name].type
# 简化处理:创建零数组
if "int" in dtype:
onnx_inputs[input_name] = np.zeros(shape, dtype=np.int64)
else:
onnx_inputs[input_name] = np.zeros(shape, dtype=np.float32)
# 运行推理
outputs = self.context_session.run(None, onnx_inputs)
return tuple(outputs)
def run_decoder(
self,
context_H: np.ndarray,
pinyin_P: np.ndarray,
history_slot_ids: np.ndarray,
context_mask: np.ndarray,
pinyin_mask: np.ndarray,
) -> np.ndarray:
"""
运行解码器
Args:
context_H: 上下文编码 [batch, seq_len, 512]
pinyin_P: 拼音编码 [batch, 24, 512]
history_slot_ids: 历史槽位IDs [batch, 8]
context_mask: 上下文掩码 [batch, seq_len]
pinyin_mask: 拼音掩码 [batch, 24]
Returns:
logits [batch, vocab_size]
"""
inputs = {
"context_H": context_H,
"pinyin_P": pinyin_P,
"history_slot_ids": history_slot_ids,
"context_mask": context_mask,
"pinyin_mask": pinyin_mask,
}
# 运行推理
outputs = self.decoder_session.run(None, inputs)
return outputs[0]
def beam_search(
self,
context_H: np.ndarray,
pinyin_P: np.ndarray,
context_mask: np.ndarray,
pinyin_mask: np.ndarray,
beam_size: int = 5,
max_length: int = 8,
vocab_size: int = 10019,
temperature: float = 1.0,
length_penalty: float = 0.6,
) -> List[Tuple[List[int], float]]:
"""
束搜索算法
Args:
context_H: 上下文编码
pinyin_P: 拼音编码
context_mask: 上下文掩码
pinyin_mask: 拼音掩码
beam_size: 束大小
max_length: 最大生成长度
vocab_size: 词汇表大小
temperature: 温度参数(控制随机性)
length_penalty: 长度惩罚系数
Returns:
排序后的候选序列列表,每个元素为(序列, 分数)
"""
# 初始束空序列对数概率为0
beams = [([], 0.0)] # (序列, 累计对数概率)
for step in range(max_length):
new_beams = []
for seq, score in beams:
# 构建history_slot_ids
if len(seq) < 8:
# 填充到8个槽位
history = seq + [0] * (8 - len(seq))
else:
# 只保留最近8个字符
history = seq[-8:]
history_array = np.array([history], dtype=np.int64)
# 运行解码器
logits = self.run_decoder(
context_H, pinyin_P, history_array, context_mask, pinyin_mask
)
# 应用温度
logits = logits / temperature
# 转换为概率
probs = F.softmax(torch.from_numpy(logits[0]), dim=-1).numpy()
# 获取top-k候选k = beam_size * 2 以增加多样性)
top_k = min(beam_size * 2, vocab_size)
top_indices = np.argsort(probs)[-top_k:][::-1]
top_probs = probs[top_indices]
# 扩展束
for idx, prob in zip(top_indices, top_probs):
new_seq = seq + [int(idx)]
# 更新分数:累计对数概率 + log(prob)
new_score = score + np.log(prob + 1e-10)
# 应用长度惩罚
length_penalized_score = new_score / (
(5 + len(new_seq)) ** length_penalty
)
new_beams.append((new_seq, length_penalized_score, new_score))
# 剪枝保留beam_size个最佳候选
new_beams.sort(key=lambda x: x[1], reverse=True) # 按惩罚后分数排序
beams = [(seq, orig_score) for seq, _, orig_score in new_beams[:beam_size]]
# 检查是否所有序列都已结束结束符ID为0
all_ended = all(seq[-1] == 0 for seq, _ in beams if len(seq) > 0)
if all_ended:
break
# 返回原始分数(未应用长度惩罚)
return beams
def interactive_beam_search(self):
"""交互式束搜索演示"""
print("\n" + "=" * 60)
print("束搜索交互演示")
print("=" * 60)
print("\n📝 输入法场景模拟:")
print(" 假设用户正在输入拼音 'nihao',已确认第一个字 ''")
print(" 上下文: 光标前文本 '今天天气很好',光标后文本 '我们去公园玩'")
print(" 上下文提示: '张三,李四'(模型不掌握的专有名词)")
print("-" * 60)
# 模拟输入(实际应用中应从用户获取)
text_before = "今天天气很好"
text_after = "我们去公园玩"
pinyin = "hao" # 继续输入"hao"
slot_chars = [""] # 已确认的字符
context_prompts = ["张三", "李四"]
print(f"\n📋 输入参数:")
print(f" 光标前文本: '{text_before}'")
print(f" 光标后文本: '{text_after}'")
print(f" 拼音: '{pinyin}'")
print(f" 槽位历史: {slot_chars}")
print(f" 上下文提示: {context_prompts}")
# 准备输入(简化版,使用随机数据)
print(f"\n🔧 准备模型输入...")
inputs = self.prepare_inputs(
text_before=text_before,
text_after=text_after,
pinyin=pinyin,
context_prompts=context_prompts,
)
# 运行上下文编码器
print(f"🧠 运行上下文编码器...")
context_outputs = self.run_context_encoder(inputs)
context_H, pinyin_P, context_mask, pinyin_mask = context_outputs
print(f"✅ 上下文编码完成")
print(f" context_H形状: {context_H.shape}")
print(f" pinyin_P形状: {pinyin_P.shape}")
# 将槽位历史转换为ID简化使用随机ID
# 实际应用中应使用query_engine将汉字转换为ID
slot_ids = [42] if slot_chars else [0] # 假设'你'的ID是42
if len(slot_ids) < 8:
slot_ids = slot_ids + [0] * (8 - len(slot_ids))
# 运行束搜索
print(f"\n🔍 运行束搜索 (beam_size=3, max_length=4)...")
beams = self.beam_search(
context_H,
pinyin_P,
context_mask,
pinyin_mask,
beam_size=3,
max_length=4,
vocab_size=10019,
)
# 显示结果
print(f"\n🏆 束搜索结果:")
print("-" * 50)
for i, (seq, score) in enumerate(beams):
# 将ID序列转换为汉字简化显示ID
seq_str = " ".join([f"ID:{id}" if id != 0 else "END" for id in seq])
print(f"{i + 1}. 序列: [{seq_str}]")
print(f" 对数概率: {score:.4f}")
print(f" 概率: {np.exp(score):.6f}")
print()
print("📝 说明:")
print(" - 'END' 表示结束符 (ID: 0)")
print(" - 实际应用中应将ID转换为汉字")
print(" - 最高分数的序列作为最终预测")
return beams
def main():
parser = argparse.ArgumentParser(description="束搜索算法演示")
parser.add_argument(
"--context-encoder",
type=str,
default="./exported_models/context_encoder.onnx",
help="上下文编码器ONNX路径默认: ./exported_models/context_encoder.onnx",
)
parser.add_argument(
"--decoder",
type=str,
default="./exported_models/decoder.onnx",
help="解码器ONNX路径默认: ./exported_models/decoder.onnx",
)
parser.add_argument(
"--device",
type=str,
default="cpu",
choices=["cpu", "cuda"],
help="推理设备(默认: cpu",
)
parser.add_argument("--beam-size", type=int, default=3, help="束大小(默认: 3")
parser.add_argument(
"--max-length", type=int, default=4, help="最大生成长度(默认: 4"
)
parser.add_argument(
"--interactive",
action="store_true",
default=True,
help="交互模式(默认: True",
)
args = parser.parse_args()
# 检查模型文件是否存在
if not os.path.exists(args.context_encoder):
print(f"❌ 上下文编码器文件不存在: {args.context_encoder}")
print(f" 请先运行 export_onnx.py 导出模型")
return
if not os.path.exists(args.decoder):
print(f"❌ 解码器文件不存在: {args.decoder}")
print(f" 请先运行 export_onnx.py 导出模型")
return
print("🚀 束搜索算法演示")
print("=" * 60)
print(f"上下文编码器: {args.context_encoder}")
print(f"解码器: {args.decoder}")
print(f"设备: {args.device}")
print(f"束大小: {args.beam_size}")
print(f"最大长度: {args.max_length}")
# 初始化束搜索器
beam_searcher = ONNXBeamSearch(
context_encoder_path=args.context_encoder,
decoder_path=args.decoder,
device=args.device,
)
if args.interactive:
# 运行交互演示
beam_searcher.interactive_beam_search()
print("\n" + "=" * 60)
print("🎉 演示完成")
print("\n💡 下一步:")
print(" 1. 实现完整的输入预处理tokenizer和拼音转换")
print(" 2. 集成查询引擎以将ID转换为汉字")
print(" 3. 根据实际场景调整束搜索参数")
print(" 4. 性能优化:批量处理、缓存等")
if __name__ == "__main__":
main()