#!/usr/bin/env python3
"""
束搜索算法演示

展示如何使用导出的两个ONNX模型进行束搜索推理,
模拟输入法场景:给定上下文、拼音和已确认字符,生成候选汉字序列。
"""

import argparse
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import onnxruntime as ort
import torch
import torch.nn.functional as F

# 添加src目录到路径
|
||
sys.path.insert(0, str(Path(__file__).parent))
|
||
|
||
|
||
class ONNXBeamSearch:
    """Beam-search decoder driven by two exported ONNX models.

    A context encoder turns (text, pinyin) inputs into dense encodings plus
    masks; a step decoder then scores the next candidate character given
    those encodings and the already-committed character history.
    """

    def __init__(
        self, context_encoder_path: str, decoder_path: str, device: str = "cpu"
    ):
        """Load both ONNX Runtime sessions and cache their I/O metadata.

        Args:
            context_encoder_path: path to the context-encoder ONNX model.
            decoder_path: path to the step-decoder ONNX model.
            device: inference device, "cpu" or "cuda".
        """
        self.device = device

        # Select execution providers; with CUDA we still keep the CPU
        # provider as a fallback.
        if device == "cuda":
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]

        # Create the two inference sessions.
        self.context_session = ort.InferenceSession(
            context_encoder_path, providers=providers
        )
        self.decoder_session = ort.InferenceSession(decoder_path, providers=providers)

        # Cache input/output metadata keyed by tensor name.
        # (Loop variables renamed: the original shadowed the builtins
        # `input` and `output`.)
        self.context_input_info = {
            inp.name: inp for inp in self.context_session.get_inputs()
        }
        self.context_output_info = {
            out.name: out for out in self.context_session.get_outputs()
        }
        self.decoder_input_info = {
            inp.name: inp for inp in self.decoder_session.get_inputs()
        }
        self.decoder_output_info = {
            out.name: out for out in self.decoder_session.get_outputs()
        }

        print("📦 ONNX模型加载完成")
        print(f" 上下文编码器输入: {list(self.context_input_info.keys())}")
        print(f" 上下文编码器输出: {list(self.context_output_info.keys())}")
        print(f" 解码器输入: {list(self.decoder_input_info.keys())}")
        print(f" 解码器输出: {list(self.decoder_output_info.keys())}")

    def prepare_inputs(
        self,
        text_before: str,
        text_after: str,
        pinyin: str,
        context_prompts: Optional[List[str]] = None,
        tokenizer=None,
        query_engine=None,
        max_seq_len: int = 128,
    ) -> Dict[str, np.ndarray]:
        """Build the model input dict.

        NOTE: demo-only stub — it emits *random* ids instead of running a
        real tokenizer / pinyin converter.  A production version must:
          1. tokenize the text into ``input_ids``,
          2. convert ``pinyin`` with ``text_to_pinyin_ids``,
          3. build ``attention_mask`` from the real lengths.

        Args:
            text_before: text before the cursor (unused by this stub).
            text_after: text after the cursor (unused by this stub).
            pinyin: raw pinyin input (unused by this stub).
            context_prompts: optional context hint strings (unused).
            tokenizer: tokenizer hook (not implemented).
            query_engine: lookup-engine hook (not implemented).
            max_seq_len: maximum sequence length.

        Returns:
            Dict with ``input_ids``, ``pinyin_ids`` and ``attention_mask``.
        """
        batch_size = 1
        seq_len = min(max_seq_len, 64)  # demo: cap the simulated length

        # Random placeholder tensors with the shapes the models expect.
        input_ids = np.random.randint(0, 1000, (batch_size, seq_len), dtype=np.int64)
        pinyin_ids = np.random.randint(
            0, 30, (batch_size, 24), dtype=np.int64
        )  # pinyin slots are fixed at length 24
        attention_mask = np.ones((batch_size, seq_len), dtype=np.int64)

        # Real pinyin handling would look like:
        #   from src.model.dataset import text_to_pinyin_ids
        #   ids = np.array(text_to_pinyin_ids(pinyin), dtype=np.int64).reshape(1, -1)

        return {
            "input_ids": input_ids,
            "pinyin_ids": pinyin_ids,
            "attention_mask": attention_mask,
        }

    def run_context_encoder(
        self, inputs: Dict[str, np.ndarray]
    ) -> Tuple[np.ndarray, ...]:
        """Run the context encoder.

        Args:
            inputs: dict of input arrays keyed by ONNX input name.

        Returns:
            Encoder outputs, expected to be
            (context_H, pinyin_P, context_mask, pinyin_mask).
        """
        onnx_inputs = {}
        for input_name, info in self.context_input_info.items():
            if input_name in inputs:
                onnx_inputs[input_name] = inputs[input_name]
                continue
            # Missing input: synthesize a zero array.  ONNX dynamic axes are
            # reported as strings/None, so substitute 1 for any symbolic
            # dimension (the original passed them straight to np.zeros,
            # which raises TypeError on symbolic dims).
            shape = [dim if isinstance(dim, int) else 1 for dim in info.shape]
            if "int" in info.type:
                onnx_inputs[input_name] = np.zeros(shape, dtype=np.int64)
            else:
                onnx_inputs[input_name] = np.zeros(shape, dtype=np.float32)

        outputs = self.context_session.run(None, onnx_inputs)
        return tuple(outputs)

    def run_decoder(
        self,
        context_H: np.ndarray,
        pinyin_P: np.ndarray,
        history_slot_ids: np.ndarray,
        context_mask: np.ndarray,
        pinyin_mask: np.ndarray,
    ) -> np.ndarray:
        """Run one decoder step.

        Args:
            context_H: context encoding [batch, seq_len, 512].
            pinyin_P: pinyin encoding [batch, 24, 512].
            history_slot_ids: committed-history ids [batch, 8].
            context_mask: context mask [batch, seq_len].
            pinyin_mask: pinyin mask [batch, 24].
            (Shapes as documented by the original author — confirm against
            the exported model.)

        Returns:
            Next-token logits [batch, vocab_size].
        """
        inputs = {
            "context_H": context_H,
            "pinyin_P": pinyin_P,
            "history_slot_ids": history_slot_ids,
            "context_mask": context_mask,
            "pinyin_mask": pinyin_mask,
        }

        outputs = self.decoder_session.run(None, inputs)
        return outputs[0]

    def beam_search(
        self,
        context_H: np.ndarray,
        pinyin_P: np.ndarray,
        context_mask: np.ndarray,
        pinyin_mask: np.ndarray,
        beam_size: int = 5,
        max_length: int = 8,
        vocab_size: int = 10019,
        temperature: float = 1.0,
        length_penalty: float = 0.6,
    ) -> List[Tuple[List[int], float]]:
        """Beam search over the step decoder.

        Args:
            context_H: context encoding.
            pinyin_P: pinyin encoding.
            context_mask: context mask.
            pinyin_mask: pinyin mask.
            beam_size: number of beams kept per step.
            max_length: maximum generated length.
            vocab_size: vocabulary size.
            temperature: softmax temperature (randomness control).
            length_penalty: GNMT-style length-penalty exponent.

        Returns:
            Candidate list sorted best-first, each entry being
            (token id sequence, raw cumulative log-probability).
        """

        def _penalize(log_prob: float, length: int) -> float:
            # GNMT-style length normalization: score / (5 + len)^alpha.
            return log_prob / ((5 + length) ** length_penalty)

        # Each beam is (token sequence, cumulative log-probability).
        beams = [([], 0.0)]

        for _ in range(max_length):
            candidates = []  # (seq, penalized score, raw log-prob)

            for seq, score in beams:
                # FIX: a finished beam (EOS id 0) is carried through
                # unchanged.  The original kept expanding past EOS, which
                # corrupted both sequences and scores and made the
                # early-stop check below almost never trigger.
                if seq and seq[-1] == 0:
                    candidates.append((seq, _penalize(score, len(seq)), score))
                    continue

                # history_slot_ids: last 8 tokens, zero-padded on the right.
                if len(seq) < 8:
                    history = seq + [0] * (8 - len(seq))
                else:
                    history = seq[-8:]
                history_array = np.array([history], dtype=np.int64)

                logits = self.run_decoder(
                    context_H, pinyin_P, history_array, context_mask, pinyin_mask
                )

                # Temperature-scaled, numerically stable log-softmax in
                # numpy (replaces the torch softmax + log(p + 1e-10)
                # round-trip of the original).
                scaled = logits[0] / temperature
                shifted = scaled - scaled.max()
                log_probs = shifted - np.log(np.exp(shifted).sum())

                # Top-k expansion (k = 2 * beam_size for extra diversity).
                top_k = min(beam_size * 2, vocab_size)
                top_indices = np.argsort(log_probs)[-top_k:][::-1]

                for idx in top_indices:
                    new_seq = seq + [int(idx)]
                    new_score = score + float(log_probs[idx])
                    candidates.append(
                        (new_seq, _penalize(new_score, len(new_seq)), new_score)
                    )

            # Prune: keep the beam_size best by length-penalized score.
            candidates.sort(key=lambda cand: cand[1], reverse=True)
            beams = [(seq, raw) for seq, _, raw in candidates[:beam_size]]

            # Early stop once every beam has emitted EOS (id 0).
            if all(seq and seq[-1] == 0 for seq, _ in beams):
                break

        # Return the raw (unpenalized) cumulative log-probabilities.
        return beams

    def interactive_beam_search(self):
        """Interactive demo of the input-method beam-search scenario."""
        print("\n" + "=" * 60)
        print("束搜索交互演示")
        print("=" * 60)

        print("\n📝 输入法场景模拟:")
        print(" 假设用户正在输入拼音 'nihao',已确认第一个字 '你'")
        print(" 上下文: 光标前文本 '今天天气很好',光标后文本 '我们去公园玩'")
        print(" 上下文提示: '张三,李四'(模型不掌握的专有名词)")
        print("-" * 60)

        # Hard-coded demo inputs (a real app would read these from the user).
        text_before = "今天天气很好"
        text_after = "我们去公园玩"
        pinyin = "hao"  # the user keeps typing "hao"
        slot_chars = ["你"]  # characters already committed
        context_prompts = ["张三", "李四"]

        print(f"\n📋 输入参数:")
        print(f" 光标前文本: '{text_before}'")
        print(f" 光标后文本: '{text_after}'")
        print(f" 拼音: '{pinyin}'")
        print(f" 槽位历史: {slot_chars}")
        print(f" 上下文提示: {context_prompts}")

        # Build the (randomized, demo-only) model inputs.
        print(f"\n🔧 准备模型输入...")
        inputs = self.prepare_inputs(
            text_before=text_before,
            text_after=text_after,
            pinyin=pinyin,
            context_prompts=context_prompts,
        )

        print(f"🧠 运行上下文编码器...")
        context_outputs = self.run_context_encoder(inputs)
        context_H, pinyin_P, context_mask, pinyin_mask = context_outputs

        print(f"✅ 上下文编码完成")
        print(f" context_H形状: {context_H.shape}")
        print(f" pinyin_P形状: {pinyin_P.shape}")

        # Map committed characters to ids (demo: assume '你' -> 42; a real
        # app would use query_engine).
        # NOTE(review): slot_ids is computed but never consumed —
        # beam_search always starts from an empty history.  Wire this in
        # once beam_search accepts an initial history.
        slot_ids = [42] if slot_chars else [0]
        if len(slot_ids) < 8:
            slot_ids = slot_ids + [0] * (8 - len(slot_ids))

        print(f"\n🔍 运行束搜索 (beam_size=3, max_length=4)...")
        beams = self.beam_search(
            context_H,
            pinyin_P,
            context_mask,
            pinyin_mask,
            beam_size=3,
            max_length=4,
            vocab_size=10019,
        )

        print(f"\n🏆 束搜索结果:")
        print("-" * 50)
        for i, (seq, score) in enumerate(beams):
            # Render ids (renamed loop var: the original shadowed builtin
            # `id`); a real app would map ids back to characters.
            seq_str = " ".join([f"ID:{tok}" if tok != 0 else "END" for tok in seq])
            print(f"{i + 1}. 序列: [{seq_str}]")
            print(f" 对数概率: {score:.4f}")
            print(f" 概率: {np.exp(score):.6f}")
            print()

        print("📝 说明:")
        print(" - 'END' 表示结束符 (ID: 0)")
        print(" - 实际应用中应将ID转换为汉字")
        print(" - 最高分数的序列作为最终预测")

        return beams
def main():
    """CLI entry point: parse args, check the model files, run the demo.

    Returns None; exits early (with a hint) when either exported ONNX
    model file is missing.
    """
    parser = argparse.ArgumentParser(description="束搜索算法演示")
    parser.add_argument(
        "--context-encoder",
        type=str,
        default="./exported_models/context_encoder.onnx",
        help="上下文编码器ONNX路径(默认: ./exported_models/context_encoder.onnx)",
    )
    parser.add_argument(
        "--decoder",
        type=str,
        default="./exported_models/decoder.onnx",
        help="解码器ONNX路径(默认: ./exported_models/decoder.onnx)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="推理设备(默认: cpu)",
    )
    # NOTE(review): --beam-size / --max-length are parsed but the demo
    # method currently hard-codes beam_size=3, max_length=4; pass them
    # through once interactive_beam_search accepts parameters.
    parser.add_argument("--beam-size", type=int, default=3, help="束大小(默认: 3)")
    parser.add_argument(
        "--max-length", type=int, default=4, help="最大生成长度(默认: 4)"
    )
    # FIX: the original declared --interactive as store_true with
    # default=True, so the flag could never be turned off (a no-op).
    # --interactive is kept for backward compatibility and
    # --no-interactive is added to actually disable the demo.
    parser.add_argument(
        "--interactive",
        action="store_true",
        default=True,
        help="交互模式(默认: True)",
    )
    parser.add_argument(
        "--no-interactive",
        dest="interactive",
        action="store_false",
        help="关闭交互模式",
    )

    args = parser.parse_args()

    # Fail fast if either exported model is missing.
    if not os.path.exists(args.context_encoder):
        print(f"❌ 上下文编码器文件不存在: {args.context_encoder}")
        print(f" 请先运行 export_onnx.py 导出模型")
        return

    if not os.path.exists(args.decoder):
        print(f"❌ 解码器文件不存在: {args.decoder}")
        print(f" 请先运行 export_onnx.py 导出模型")
        return

    print("🚀 束搜索算法演示")
    print("=" * 60)
    print(f"上下文编码器: {args.context_encoder}")
    print(f"解码器: {args.decoder}")
    print(f"设备: {args.device}")
    print(f"束大小: {args.beam_size}")
    print(f"最大长度: {args.max_length}")

    # Build the beam-search driver on top of the two ONNX sessions.
    beam_searcher = ONNXBeamSearch(
        context_encoder_path=args.context_encoder,
        decoder_path=args.decoder,
        device=args.device,
    )

    if args.interactive:
        beam_searcher.interactive_beam_search()

    print("\n" + "=" * 60)
    print("🎉 演示完成")
    print("\n💡 下一步:")
    print(" 1. 实现完整的输入预处理(tokenizer和拼音转换)")
    print(" 2. 集成查询引擎以将ID转换为汉字")
    print(" 3. 根据实际场景调整束搜索参数")
    print(" 4. 性能优化:批量处理、缓存等")
if __name__ == "__main__":
|
||
main()
|