Compare commits

..

3 Commits

23 changed files with 4302 additions and 395 deletions

6
.gitignore vendored
View File

@@ -177,3 +177,9 @@ cython_debug/
uv.lock
data/*
**/*.onnx
**/*.data
**/*.npz
**/*.pt

105
README.md
View File

@@ -790,112 +790,7 @@ train-model evaluate \
- Compute metrics such as accuracy and perplexity on the evaluation dataset
- Generate a detailed performance report
### 6.8 Two-stage training for model expansion
When you need to increase model capacity (e.g. more experts, a modified layer structure), use the `expand-and-train` command for two-stage training: first freeze the matching layers and train only the newly added parameters, then fine-tune the whole model.
#### Training strategy
1. **Frozen stage**: only the newly added parameters whose shapes do not match the base model (new experts, expanded layers, etc.) are trained
2. **Full fine-tuning stage**: once the validation loss has failed to improve for `--frozen-patience` consecutive evaluations, all layers are automatically unfrozen for full training; a switching sketch follows
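The switching logic can be pictured with a short sketch (illustrative only; `frozen_params`, `train_some_steps`, `evaluate`, and `train_to_convergence` are hypothetical names, not the project's actual API):
```python
# Two-stage schedule, as described above (hypothetical helper names)
patience_left = frozen_patience
best_eval_loss = float("inf")
for p in frozen_params:              # stage 1: shape-matched weights stay frozen
    p.requires_grad = False
while patience_left > 0:
    train_some_steps(lr=frozen_lr)   # only the new parameters receive gradients
    eval_loss = evaluate()
    if eval_loss < best_eval_loss:
        best_eval_loss, patience_left = eval_loss, frozen_patience
    else:
        patience_left -= 1           # one more evaluation without improvement
for p in frozen_params:              # stage 2: unfreeze everything
    p.requires_grad = True
train_to_convergence(lr=full_lr)     # full fine-tuning at the lower LR
```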
#### Basic usage
```bash
train-model expand-and-train \
--train-data-path "path/to/train/dataset" \
--eval-data-path "path/to/eval/dataset" \
--base-model-path "./pretrained/model.pt" \
--new-model-spec "model:InputMethodEngine" \
--num-experts 40 \
--frozen-lr 2e-3 \
--full-lr 5e-5 \
--frozen-patience 8
```
#### Full parameter example
```bash
train-model expand-and-train \
--train-data-path "path/to/train/dataset" \
--eval-data-path "path/to/eval/dataset" \
--output-dir "./expansion_output" \
--base-model-path "./pretrained/model.pt" \
--new-model-spec "custom_model:ExpandedModel" \
--vocab-size 10019 \
--dim 512 \
--num-experts 40 \
--frozen-patience 10 \
--frozen-lr 1e-3 \
--full-lr 1e-4 \
--frozen-scheduler cosine \
--full-scheduler cosine \
--batch-size 128 \
--num-epochs 20 \
--compile
```
#### Parameter reference
**Model expansion parameters**
- `--base-model-path`: path to the pretrained base-model checkpoint (required)
- `--new-model-spec`: the new model spec, in the form `module:ClassName`, e.g. `model:InputMethodEngine` (required)
  - Modules can be imported from arbitrary paths (see the sketch below); the module file must contain the custom model class
  - The custom model class must be a subclass of `InputMethodEngine`
  - Example: `my_model:MyExpandedModel` refers to `MyExpandedModel` in `my_model.py`
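A plausible way to resolve such a spec via the standard import machinery (a sketch; the project's actual loader may differ):
```python
import importlib
import importlib.util
from pathlib import Path

def load_model_class(spec: str):
    """Resolve a 'module:ClassName' spec such as 'my_model:MyExpandedModel'."""
    module_name, class_name = spec.split(":")
    candidate = Path(f"{module_name}.py")
    if candidate.exists():
        # Import a module file from an arbitrary path
        module_spec = importlib.util.spec_from_file_location(module_name, candidate)
        module = importlib.util.module_from_spec(module_spec)
        module_spec.loader.exec_module(module)
    else:
        # Fall back to a regular package import
        module = importlib.import_module(module_name)
    return getattr(module, class_name)
```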
**Two-stage training parameters**
- `--frozen-patience`: number of consecutive evaluations without validation-loss improvement in the frozen stage before switching to full fine-tuning (default: 10)
- `--frozen-lr`: learning rate for the frozen stage (default: 1e-3)
- `--full-lr`: learning rate for the full fine-tuning stage (default: 1e-4)
- `--frozen-scheduler`: LR scheduler for the frozen stage, `cosine` or `plateau` (default: `cosine`)
- `--full-scheduler`: LR scheduler for the full fine-tuning stage, `cosine` or `plateau` (default: `cosine`)
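Assuming the two options map onto the standard PyTorch schedulers (an assumption; the project may wrap them differently):
```python
import torch
opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
# "cosine": anneal the LR along a cosine curve over the stage
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=1000)
# "plateau": cut the LR when the validation loss stops improving
plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min", patience=3)
```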
**Other parameters**
- All common parameters of the `train` subcommand are supported (data, model, and training parameters)
- The existing training infrastructure is inherited (mixed-precision training, TensorBoard logging, checkpoint saving, etc.)
#### Typical scenarios
1. **More experts** (20 → 40)
   - Freezing effect: ~70% of parameters can be frozen (existing expert weights, attention layers, etc.); the sketch after this list shows how to estimate this
   - New parameters: the new expert networks and the gate layer
2. **Larger top_k** (2 → 3)
   - Freezing effect: 100% of parameters can be frozen (logic-only change)
   - New parameters: none
3. **Changed expert internals** (e.g. extra resblocks)
   - Freezing effect: ~50% of parameters can be frozen (linear_in/output can be frozen)
   - New parameters: the added resblock layers
4. **More Transformer layers** (4 → 5)
   - Freezing effect: ~80% of parameters can be frozen (the first 4 layers)
   - New parameters: the new 5th layer
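A minimal way to estimate those fractions, assuming plain state-dict checkpoints (a sketch, not the project's tooling):
```python
import torch

def freezable_fraction(base_ckpt_path, new_model):
    """Share of new-model parameters whose shapes match the base checkpoint."""
    base_state = torch.load(base_ckpt_path, map_location="cpu")
    new_state = new_model.state_dict()
    matched = sum(
        v.numel() for k, v in new_state.items()
        if k in base_state and base_state[k].shape == v.shape
    )
    return matched / sum(v.numel() for v in new_state.values())
```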
#### Custom model class example
```python
# my_model.py
from model.model import InputMethodEngine
class MyExpandedModel(InputMethodEngine):
    def __init__(self, num_experts=40, **kwargs):
        # Call the parent constructor, overriding num_experts
        super().__init__(num_experts=num_experts, **kwargs)
        # Extra layers can be added or existing ones modified here
# Command line:
# train-model expand-and-train --new-model-spec "my_model:MyExpandedModel" ...
```
#### Notes
1. **Model class requirement**: the custom model class must be a subclass of `InputMethodEngine`
2. **Freezing condition**: only layers whose weight shapes match exactly are frozen
3. **Performance**: MoE layers keep the "compute all experts + top-K selection" scheme, ensuring the best performance under `torch.compile`
4. **Stage switching**: driven by evaluation frequency rather than epochs; consider raising `--eval-frequency` accordingly
5. **Module import**: modules can live at arbitrary paths and are loaded through Python's standard import machinery
### 6.9 Exporting models (under development)

436
beam_search_demo.py Normal file
View File

@@ -0,0 +1,436 @@
#!/usr/bin/env python3
"""
Beam search demo
Shows how to run beam-search inference with the two exported ONNX models.
Simulates an input-method scenario: given context, pinyin, and already
confirmed characters, generate candidate Chinese-character sequences.
"""
import argparse
import os
import sys
from pathlib import Path
from typing import List, Tuple, Dict, Any
import numpy as np
import onnxruntime as ort
import torch
import torch.nn.functional as F
# Add the src directory to the path
sys.path.insert(0, str(Path(__file__).parent))
class ONNXBeamSearch:
    """
    Beam-search decoder built on the ONNX models
    """
    def __init__(
        self, context_encoder_path: str, decoder_path: str, device: str = "cpu"
    ):
        """
        Initialize
        Args:
            context_encoder_path: path to the context-encoder ONNX model
            decoder_path: path to the decoder ONNX model
            device: inference device (cpu or cuda)
        """
        self.device = device
        # Configure the ONNX Runtime providers
        if device == "cuda":
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]
        # Create the ONNX Runtime sessions
        self.context_session = ort.InferenceSession(
            context_encoder_path, providers=providers
        )
        self.decoder_session = ort.InferenceSession(decoder_path, providers=providers)
        # Collect input/output metadata
        self.context_input_info = {
            input.name: input for input in self.context_session.get_inputs()
        }
        self.context_output_info = {
            output.name: output for output in self.context_session.get_outputs()
        }
        self.decoder_input_info = {
            input.name: input for input in self.decoder_session.get_inputs()
        }
        self.decoder_output_info = {
            output.name: output for output in self.decoder_session.get_outputs()
        }
        print("📦 ONNX models loaded")
        print(f"   Context encoder inputs: {list(self.context_input_info.keys())}")
        print(f"   Context encoder outputs: {list(self.context_output_info.keys())}")
        print(f"   Decoder inputs: {list(self.decoder_input_info.keys())}")
        print(f"   Decoder outputs: {list(self.decoder_output_info.keys())}")
    def prepare_inputs(
        self,
        text_before: str,
        text_after: str,
        pinyin: str,
        context_prompts: List[str] = None,
        tokenizer=None,
        query_engine=None,
        max_seq_len: int = 128,
    ) -> Dict[str, np.ndarray]:
        """
        Prepare the model inputs
        Note: this is a simplified version; a real application needs the full
        preprocessing logic. Random data is used here for the demo; in practice
        use the tokenizer and the query engine.
        Args:
            text_before: text before the cursor
            text_after: text after the cursor
            pinyin: pinyin input
            context_prompts: context prompts
            tokenizer: tokenizer (not implemented)
            query_engine: query engine (not implemented)
            max_seq_len: maximum sequence length
        Returns:
            input dict containing input_ids, pinyin_ids, attention_mask
        """
        # A real application should:
        # 1. Convert the text to input_ids with the tokenizer
        # 2. Convert the pinyin to pinyin_ids with text_to_pinyin_ids
        # 3. Build the attention_mask
        # Simplified to random data (demo only)
        batch_size = 1
        seq_len = min(max_seq_len, 64)  # shortened for the demo
        # Mock inputs
        input_ids = np.random.randint(0, 1000, (batch_size, seq_len), dtype=np.int64)
        pinyin_ids = np.random.randint(
            0, 30, (batch_size, 24), dtype=np.int64
        )  # fixed length 24
        attention_mask = np.ones((batch_size, seq_len), dtype=np.int64)
        # A real application would handle the pinyin like this:
        # from src.model.dataset import text_to_pinyin_ids
        # pinyin_ids_list = text_to_pinyin_ids(pinyin)
        # pinyin_ids_array = np.array(pinyin_ids_list, dtype=np.int64).reshape(1, -1)
        return {
            "input_ids": input_ids,
            "pinyin_ids": pinyin_ids,
            "attention_mask": attention_mask,
        }
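    # A fuller preprocessing path (sketch only, mirroring the comments above;
    # the real pipeline lives in onnx_inference.py / inference.py):
    #   from src.model.dataset import text_to_pinyin_ids
    #   pinyin_ids = text_to_pinyin_ids(pinyin)
    #   pinyin_ids = (pinyin_ids + [0] * 24)[:24]   # pad/trim to the fixed length 24
    #   encoded = tokenizer(f"{'|'.join(context_prompts)}|{text_before}", text_after,
    #                       max_length=max_seq_len, truncation=True)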
    def run_context_encoder(
        self, inputs: Dict[str, np.ndarray]
    ) -> Tuple[np.ndarray, ...]:
        """
        Run the context encoder
        Args:
            inputs: input dict
        Returns:
            context-encoder outputs: (context_H, pinyin_P, context_mask, pinyin_mask)
        """
        # Prepare the ONNX inputs, keeping the expected order
        onnx_inputs = {}
        for input_name in self.context_input_info.keys():
            if input_name in inputs:
                onnx_inputs[input_name] = inputs[input_name]
            else:
                # Fall back to defaults for missing inputs
                shape = self.context_input_info[input_name].shape
                dtype = self.context_input_info[input_name].type
                # Simplification: use zero arrays
                if "int" in dtype:
                    onnx_inputs[input_name] = np.zeros(shape, dtype=np.int64)
                else:
                    onnx_inputs[input_name] = np.zeros(shape, dtype=np.float32)
        # Run inference
        outputs = self.context_session.run(None, onnx_inputs)
        return tuple(outputs)
    def run_decoder(
        self,
        context_H: np.ndarray,
        pinyin_P: np.ndarray,
        history_slot_ids: np.ndarray,
        context_mask: np.ndarray,
        pinyin_mask: np.ndarray,
    ) -> np.ndarray:
        """
        Run the decoder
        Args:
            context_H: context encoding [batch, seq_len, 512]
            pinyin_P: pinyin encoding [batch, 24, 512]
            history_slot_ids: history slot IDs [batch, 8]
            context_mask: context mask [batch, seq_len]
            pinyin_mask: pinyin mask [batch, 24]
        Returns:
            logits [batch, vocab_size]
        """
        inputs = {
            "context_H": context_H,
            "pinyin_P": pinyin_P,
            "history_slot_ids": history_slot_ids,
            "context_mask": context_mask,
            "pinyin_mask": pinyin_mask,
        }
        # Run inference
        outputs = self.decoder_session.run(None, inputs)
        return outputs[0]
    def beam_search(
        self,
        context_H: np.ndarray,
        pinyin_P: np.ndarray,
        context_mask: np.ndarray,
        pinyin_mask: np.ndarray,
        beam_size: int = 5,
        max_length: int = 8,
        vocab_size: int = 10019,
        temperature: float = 1.0,
        length_penalty: float = 0.6,
    ) -> List[Tuple[List[int], float]]:
        """
        Beam search
        Args:
            context_H: context encoding
            pinyin_P: pinyin encoding
            context_mask: context mask
            pinyin_mask: pinyin mask
            beam_size: beam width
            max_length: maximum generation length
            vocab_size: vocabulary size
            temperature: temperature (controls randomness)
            length_penalty: length-penalty exponent
        Returns:
            sorted list of candidates, each a (sequence, score) pair
        """
        # Initial beam: an empty sequence with log-probability 0
        beams = [([], 0.0)]  # (sequence, cumulative log-probability)
        for step in range(max_length):
            new_beams = []
            for seq, score in beams:
                # Build history_slot_ids
                if len(seq) < 8:
                    # Pad to 8 slots
                    history = seq + [0] * (8 - len(seq))
                else:
                    # Keep only the 8 most recent characters
                    history = seq[-8:]
                history_array = np.array([history], dtype=np.int64)
                # Run the decoder
                logits = self.run_decoder(
                    context_H, pinyin_P, history_array, context_mask, pinyin_mask
                )
                # Apply the temperature
                logits = logits / temperature
                # Convert to probabilities
                probs = F.softmax(torch.from_numpy(logits[0]), dim=-1).numpy()
                # Take the top-k candidates (k = beam_size * 2 for extra diversity)
                top_k = min(beam_size * 2, vocab_size)
                top_indices = np.argsort(probs)[-top_k:][::-1]
                top_probs = probs[top_indices]
                # Expand the beams
                for idx, prob in zip(top_indices, top_probs):
                    new_seq = seq + [int(idx)]
                    # Update the score: cumulative log-probability + log(prob)
                    new_score = score + np.log(prob + 1e-10)
                    # Apply the length penalty
                    length_penalized_score = new_score / (
                        (5 + len(new_seq)) ** length_penalty
                    )
                    new_beams.append((new_seq, length_penalized_score, new_score))
            # Prune: keep the beam_size best candidates
            new_beams.sort(key=lambda x: x[1], reverse=True)  # sort by penalized score
            beams = [(seq, orig_score) for seq, _, orig_score in new_beams[:beam_size]]
            # Stop if every sequence has ended (end-of-sequence ID is 0)
            all_ended = all(seq[-1] == 0 for seq, _ in beams if len(seq) > 0)
            if all_ended:
                break
        # Return the raw scores (without the length penalty)
        return beams
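    # Note: the division above is a GNMT-style length normalization,
    # score_lp = score / (5 + len(seq)) ** length_penalty, which keeps longer
    # candidates competitive with shorter ones (Wu et al., 2016, use the
    # closely related ((5 + len) / 6) ** alpha form).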
    def interactive_beam_search(self):
        """Interactive beam-search demo"""
        print("\n" + "=" * 60)
        print("Interactive beam-search demo")
        print("=" * 60)
        print("\n📝 Simulated input-method scenario:")
        print("   The user is typing the pinyin 'nihao' and has confirmed the first character '你'")
        print("   Context: text before the cursor '今天天气很好', text after the cursor '我们去公园玩'")
        print("   Context prompts: '张三,李四' (proper nouns the model does not know)")
        print("-" * 60)
        # Mock inputs (a real application would get these from the user)
        text_before = "今天天气很好"
        text_after = "我们去公园玩"
        pinyin = "hao"  # continuing with "hao"
        slot_chars = ["你"]  # confirmed characters
        context_prompts = ["张三", "李四"]
        print("\n📋 Input parameters:")
        print(f"   Text before cursor: '{text_before}'")
        print(f"   Text after cursor: '{text_after}'")
        print(f"   Pinyin: '{pinyin}'")
        print(f"   Slot history: {slot_chars}")
        print(f"   Context prompts: {context_prompts}")
        # Prepare the inputs (simplified version using random data)
        print("\n🔧 Preparing model inputs...")
        inputs = self.prepare_inputs(
            text_before=text_before,
            text_after=text_after,
            pinyin=pinyin,
            context_prompts=context_prompts,
        )
        # Run the context encoder
        print("🧠 Running the context encoder...")
        context_outputs = self.run_context_encoder(inputs)
        context_H, pinyin_P, context_mask, pinyin_mask = context_outputs
        print("✅ Context encoding done")
        print(f"   context_H shape: {context_H.shape}")
        print(f"   pinyin_P shape: {pinyin_P.shape}")
        # Convert the slot history to IDs (simplified: a made-up ID)
        # A real application would map characters to IDs with the query engine
        slot_ids = [42] if slot_chars else [0]  # pretend '你' has ID 42
        if len(slot_ids) < 8:
            slot_ids = slot_ids + [0] * (8 - len(slot_ids))
        # Run the beam search
        print("\n🔍 Running beam search (beam_size=3, max_length=4)...")
        beams = self.beam_search(
            context_H,
            pinyin_P,
            context_mask,
            pinyin_mask,
            beam_size=3,
            max_length=4,
            vocab_size=10019,
        )
        # Show the results
        print("\n🏆 Beam-search results:")
        print("-" * 50)
        for i, (seq, score) in enumerate(beams):
            # Convert the ID sequences to characters (simplified: show the IDs)
            seq_str = " ".join([f"ID:{id}" if id != 0 else "END" for id in seq])
            print(f"{i + 1}. Sequence: [{seq_str}]")
            print(f"   Log-probability: {score:.4f}")
            print(f"   Probability: {np.exp(score):.6f}")
            print()
        print("📝 Notes:")
        print("   - 'END' is the end-of-sequence marker (ID: 0)")
        print("   - A real application would convert IDs to characters")
        print("   - The highest-scoring sequence is the final prediction")
        return beams
def main():
    parser = argparse.ArgumentParser(description="Beam search demo")
    parser.add_argument(
        "--context-encoder",
        type=str,
        default="./exported_models/context_encoder.onnx",
        help="Path to the context-encoder ONNX model (default: ./exported_models/context_encoder.onnx)",
    )
    parser.add_argument(
        "--decoder",
        type=str,
        default="./exported_models/decoder.onnx",
        help="Path to the decoder ONNX model (default: ./exported_models/decoder.onnx)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="Inference device (default: cpu)",
    )
    parser.add_argument("--beam-size", type=int, default=3, help="Beam width (default: 3)")
    parser.add_argument(
        "--max-length", type=int, default=4, help="Maximum generation length (default: 4)"
    )
    parser.add_argument(
        "--interactive",
        action="store_true",
        default=True,
        help="Interactive mode (default: True)",
    )
    args = parser.parse_args()
    # Check that the model files exist
    if not os.path.exists(args.context_encoder):
        print(f"❌ Context-encoder file not found: {args.context_encoder}")
        print("   Run export_onnx.py first to export the models")
        return
    if not os.path.exists(args.decoder):
        print(f"❌ Decoder file not found: {args.decoder}")
        print("   Run export_onnx.py first to export the models")
        return
    print("🚀 Beam search demo")
    print("=" * 60)
    print(f"Context encoder: {args.context_encoder}")
    print(f"Decoder: {args.decoder}")
    print(f"Device: {args.device}")
    print(f"Beam width: {args.beam_size}")
    print(f"Maximum length: {args.max_length}")
    # Initialize the beam searcher
    beam_searcher = ONNXBeamSearch(
        context_encoder_path=args.context_encoder,
        decoder_path=args.decoder,
        device=args.device,
    )
    if args.interactive:
        # Run the interactive demo
        beam_searcher.interactive_beam_search()
    print("\n" + "=" * 60)
    print("🎉 Demo finished")
    print("\n💡 Next steps:")
    print("   1. Implement the full input preprocessing (tokenizer and pinyin conversion)")
    print("   2. Integrate the query engine to convert IDs to characters")
    print("   3. Tune the beam-search parameters for the actual scenario")
    print("   4. Optimize performance: batching, caching, etc.")
if __name__ == "__main__":
    main()

View File

@@ -479,118 +479,11 @@ train-model evaluate \
--batch-size 32
```
The command prints an "evaluation not yet implemented" notice. The feature is planned to:
- Load a trained model checkpoint
- Compute metrics such as accuracy and perplexity on the evaluation dataset
- Generate a detailed performance report
### Two-stage training for model expansion
When you need to increase model capacity (e.g. more experts, a modified layer structure), use the `expand-and-train` command for two-stage training: first freeze the matching layers and train only the newly added parameters, then fine-tune the whole model.
#### Training strategy
1. **Frozen stage**: only the newly added parameters whose shapes do not match the base model (new experts, expanded layers, etc.) are trained
2. **Full fine-tuning stage**: once the validation loss has failed to improve for `--frozen-patience` consecutive evaluations, all layers are automatically unfrozen for full training
#### Basic usage
```bash
train-model expand-and-train \
--train-data-path "path/to/train/dataset" \
--eval-data-path "path/to/eval/dataset" \
--base-model-path "./pretrained/model.pt" \
--new-model-spec "model:InputMethodEngine" \
--num-experts 40 \
--frozen-lr 2e-3 \
--full-lr 5e-5 \
--frozen-patience 8
```
#### Full parameter example
```bash
train-model expand-and-train \
--train-data-path "path/to/train/dataset" \
--eval-data-path "path/to/eval/dataset" \
--output-dir "./expansion_output" \
--base-model-path "./pretrained/model.pt" \
--new-model-spec "custom_model:ExpandedModel" \
--vocab-size 10019 \
--dim 512 \
--num-experts 40 \
--frozen-patience 10 \
--frozen-lr 1e-3 \
--full-lr 1e-4 \
--frozen-scheduler cosine \
--full-scheduler cosine \
--batch-size 128 \
--num-epochs 20 \
--compile
```
#### Parameter reference
**Model expansion parameters**
- `--base-model-path`: path to the pretrained base-model checkpoint (required)
- `--new-model-spec`: the new model spec, in the form `module:ClassName`, e.g. `model:InputMethodEngine` (required)
  - Modules can be imported from arbitrary paths; the module file must contain the custom model class
  - The custom model class must be a subclass of `InputMethodEngine`
  - Example: `my_model:MyExpandedModel` refers to `MyExpandedModel` in `my_model.py`
**Two-stage training parameters**
- `--frozen-patience`: number of consecutive evaluations without validation-loss improvement in the frozen stage before switching to full fine-tuning (default: 10)
- `--frozen-lr`: learning rate for the frozen stage (default: 1e-3)
- `--full-lr`: learning rate for the full fine-tuning stage (default: 1e-4)
- `--frozen-scheduler`: LR scheduler for the frozen stage, `cosine` or `plateau` (default: `cosine`)
- `--full-scheduler`: LR scheduler for the full fine-tuning stage, `cosine` or `plateau` (default: `cosine`)
**Other parameters**
- All common parameters of the `train` subcommand are supported (data, model, and training parameters)
- The existing training infrastructure is inherited (mixed-precision training, TensorBoard logging, checkpoint saving, etc.)
#### Typical scenarios
1. **More experts** (20 → 40)
   - Freezing effect: ~70% of parameters can be frozen (existing expert weights, attention layers, etc.)
   - New parameters: the new expert networks and the gate layer
2. **Larger top_k** (2 → 3)
   - Freezing effect: 100% of parameters can be frozen (logic-only change)
   - New parameters: none
3. **Changed expert internals** (e.g. extra resblocks)
   - Freezing effect: ~50% of parameters can be frozen (linear_in/output can be frozen)
   - New parameters: the added resblock layers
4. **More Transformer layers** (4 → 5)
   - Freezing effect: ~80% of parameters can be frozen (the first 4 layers)
   - New parameters: the new 5th layer
#### Custom model class example
```python
# my_model.py
from model.model import InputMethodEngine
class MyExpandedModel(InputMethodEngine):
    def __init__(self, num_experts=40, **kwargs):
        # Call the parent constructor, overriding num_experts
        super().__init__(num_experts=num_experts, **kwargs)
        # Extra layers can be added or existing ones modified here
# Command line:
# train-model expand-and-train --new-model-spec "my_model:MyExpandedModel" ...
```
#### Notes
1. **Model class requirement**: the custom model class must be a subclass of `InputMethodEngine`
2. **Freezing condition**: only layers whose weight shapes match exactly are frozen
3. **Performance**: MoE layers keep the "compute all experts + top-K selection" scheme, ensuring the best performance under `torch.compile`
4. **Stage switching**: driven by evaluation frequency rather than epochs; consider raising `--eval-frequency` accordingly
5. **Module import**: modules can live at arbitrary paths and are loaded through Python's standard import machinery
### Exporting models (under development)
The export feature is currently under development:

112
export.record Normal file

File diff suppressed because one or more lines are too long

562
export_onnx.py Normal file
View File

@@ -0,0 +1,562 @@
#!/usr/bin/env python3
"""
ONNX export script for the input-method model
Exports the model as two ONNX parts:
1. context_encoder.onnx - the context encoder
2. decoder.onnx - the decoder
Usage:
    python export_onnx.py --checkpoint ./output/checkpoints/best_model.pt --output-dir ./exported_models
Dependencies:
    pip install onnx onnxruntime
"""
import argparse
import os
import sys
from pathlib import Path
import torch
import numpy as np
# Check whether ONNX is available
try:
    import onnx
    import onnxruntime as ort
    ONNX_AVAILABLE = True
except ImportError:
    ONNX_AVAILABLE = False
    print("Warning: ONNX or ONNX Runtime is not installed")
    print("Run: pip install onnx onnxruntime")
# Add the src directory to the path
sys.path.insert(0, str(Path(__file__).parent))
from src.model.export_models import create_export_models_from_checkpoint
def check_onnx_available():
    """Check that the ONNX dependencies are available"""
    if not ONNX_AVAILABLE:
        print("Error: ONNX export requires the following dependencies:")
        print("  pip install onnx onnxruntime")
        print("Install them and retry")
        return False
    return True
def export_context_encoder(model, output_path, config):
    """
    Export the context encoder to ONNX
    Args:
        model: a ContextEncoderExport instance
        output_path: output path
        config: model config
    """
    print(f"Exporting the context encoder to: {output_path}")
    # Build example inputs - batch_size=2 so ONNX records dynamic batching
    batch_size = 2
    seq_len = config.get("max_seq_len", 128)
    pinyin_len = 24  # fixed length
    dim = config.get("dim", 512)
    input_ids = torch.randint(0, 100, (batch_size, seq_len), dtype=torch.long)
    pinyin_ids = torch.randint(0, 30, (batch_size, pinyin_len), dtype=torch.long)
    attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)
    # Declare the dynamic axes
    dynamic_axes = {
        "input_ids": {0: "batch_size", 1: "seq_len"},
        "pinyin_ids": {0: "batch_size"},
        "attention_mask": {0: "batch_size", 1: "seq_len"},
        "context_H": {0: "batch_size", 1: "seq_len"},
        "pinyin_P": {0: "batch_size"},
        "context_mask": {0: "batch_size", 1: "seq_len"},
        "pinyin_mask": {0: "batch_size"},
    }
    # Export the model
    torch.onnx.export(
        model,
        (input_ids, pinyin_ids, attention_mask),
        output_path,
        input_names=["input_ids", "pinyin_ids", "attention_mask"],
        output_names=["context_H", "pinyin_P", "context_mask", "pinyin_mask"],
        dynamic_axes=dynamic_axes,
        opset_version=18,  # supports modern operators
        do_constant_folding=True,
        verbose=False,
    )
    print("✅ Context encoder exported")
    # Validate the export
    try:
        onnx_model = onnx.load(output_path)
        onnx.checker.check_model(onnx_model)
        print("✅ ONNX model check passed")
    except Exception as e:
        print(f"⚠️ ONNX model check warning: {e}")
    return input_ids, pinyin_ids, attention_mask
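# Optional numerical parity check (a sketch, not wired into the export flow):
# compare the PyTorch module against ONNX Runtime on the same example inputs.
#
#   sess = ort.InferenceSession(str(output_path), providers=["CPUExecutionProvider"])
#   ort_outs = sess.run(None, {
#       "input_ids": input_ids.numpy(),
#       "pinyin_ids": pinyin_ids.numpy(),
#       "attention_mask": attention_mask.numpy(),
#   })
#   with torch.no_grad():
#       torch_outs = model(input_ids, pinyin_ids, attention_mask)
#   np.testing.assert_allclose(torch_outs[0].numpy(), ort_outs[0], rtol=1e-3, atol=1e-5)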
def export_decoder(model, output_path, config, example_inputs=None):
    """
    Export the decoder to ONNX
    Args:
        model: a DecoderExport instance
        output_path: output path
        config: model config
        example_inputs: example inputs (used for consistency checks)
    """
    print(f"Exporting the decoder to: {output_path}")
    # Build example inputs - batch_size=2 so ONNX records dynamic batching
    batch_size = 2
    seq_len = config.get("max_seq_len", 128)
    pinyin_len = 24
    dim = config.get("dim", 512)
    num_slots = config.get("num_slots", 8)
    # Use the provided example inputs, or create fresh ones
    if example_inputs is not None:
        context_H, pinyin_P, context_mask, pinyin_mask = example_inputs
        batch_size = context_H.size(0)  # use the actual batch size
        history_slot_ids = torch.randint(
            0, 100, (batch_size, num_slots), dtype=torch.long
        )
    else:
        context_H = torch.randn(batch_size, seq_len, dim, dtype=torch.float32)
        pinyin_P = torch.randn(batch_size, pinyin_len, dim, dtype=torch.float32)
        context_mask = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int32)
        pinyin_mask = torch.randint(0, 2, (batch_size, pinyin_len), dtype=torch.int32)
        history_slot_ids = torch.randint(
            0, 100, (batch_size, num_slots), dtype=torch.long
        )
    # Declare the dynamic axes
    dynamic_axes = {
        "context_H": {0: "batch_size", 1: "seq_len"},
        "pinyin_P": {0: "batch_size"},
        "history_slot_ids": {0: "batch_size"},
        "context_mask": {0: "batch_size", 1: "seq_len"},
        "pinyin_mask": {0: "batch_size"},
        "logits": {0: "batch_size"},
    }
    # Export the model
    torch.onnx.export(
        model,
        (context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask),
        output_path,
        input_names=[
            "context_H",
            "pinyin_P",
            "history_slot_ids",
            "context_mask",
            "pinyin_mask",
        ],
        output_names=["logits"],
        dynamic_axes=dynamic_axes,
        opset_version=18,
        do_constant_folding=True,
        verbose=False,
    )
    print("✅ Decoder exported")
    # Validate the export
    try:
        onnx_model = onnx.load(output_path)
        onnx.checker.check_model(onnx_model)
        print("✅ ONNX model check passed")
    except Exception as e:
        print(f"⚠️ ONNX model check warning: {e}")
    return context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask
def save_example_inputs(output_dir, example_inputs_dict):
    """
    Save the example inputs as an NPZ file for verification
    Args:
        output_dir: output directory
        example_inputs_dict: dict of example inputs
    """
    npz_path = os.path.join(output_dir, "example_inputs.npz")
    # Convert to numpy arrays
    np_data = {}
    for key, tensor in example_inputs_dict.items():
        if isinstance(tensor, torch.Tensor):
            np_data[key] = tensor.cpu().numpy()
        elif isinstance(tensor, tuple):
            for i, t in enumerate(tensor):
                if isinstance(t, torch.Tensor):
                    np_data[f"{key}_{i}"] = t.cpu().numpy()
    np.savez(npz_path, **np_data)
    print(f"✅ Example inputs saved to: {npz_path}")
    # Also save them in PyTorch format
    torch_path = os.path.join(output_dir, "example_inputs.pt")
    torch.save(example_inputs_dict, torch_path)
    print(f"✅ PyTorch example inputs saved to: {torch_path}")
def create_inference_example(output_dir, config):
    """
    Create the inference example script
    Args:
        output_dir: output directory
        config: model config
    """
    example_path = os.path.join(output_dir, "inference_example.py")
    example_code = '''#!/usr/bin/env python3
"""
ONNX model inference example
Shows how to run inference with the two exported ONNX models,
including a beam-search algorithm.
"""
import os
import numpy as np
import onnxruntime as ort
import torch
import torch.nn.functional as F
from typing import List, Tuple
class ONNXInference:
    """ONNX model inference helper"""
    def __init__(self, context_encoder_path, decoder_path):
        """
        Initialize the ONNX inference helper
        Args:
            context_encoder_path: path to the context-encoder ONNX model
            decoder_path: path to the decoder ONNX model
        """
        # Create the ONNX Runtime sessions
        self.context_encoder_session = ort.InferenceSession(
            context_encoder_path,
            providers=['CPUExecutionProvider']  # or 'CUDAExecutionProvider'
        )
        self.decoder_session = ort.InferenceSession(
            decoder_path,
            providers=['CPUExecutionProvider']
        )
        # Collect the input/output names
        self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()]
        self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()]
        self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()]
        self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()]
        print(f"Context encoder inputs: {self.context_input_names}")
        print(f"Context encoder outputs: {self.context_output_names}")
        print(f"Decoder inputs: {self.decoder_input_names}")
        print(f"Decoder outputs: {self.decoder_output_names}")
    def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128):
        """
        Prepare the model inputs (kept consistent with the original inference script)
        Note: the text-to-token conversion has to be implemented here;
        the example assumes the helper functions already exist.
        """
        # Call the real preprocessing functions here
        # Returns: input_ids, pinyin_ids, attention_mask, history_slot_ids
        raise NotImplementedError("Implement the actual input preprocessing")
    def run_context_encoder(self, input_ids, pinyin_ids, attention_mask):
        """
        Run the context encoder
        Args:
            input_ids: [batch, seq_len]
            pinyin_ids: [batch, 24]
            attention_mask: [batch, seq_len]
        Returns:
            context_H, pinyin_P, context_mask, pinyin_mask
        """
        # Prepare the inputs
        inputs = {
            "input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids,
            "pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids,
            "attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask,
        }
        # Run inference
        outputs = self.context_encoder_session.run(self.context_output_names, inputs)
        # Unpack the outputs
        context_H, pinyin_P, context_mask, pinyin_mask = outputs
        return (
            torch.from_numpy(context_H),
            torch.from_numpy(pinyin_P),
            torch.from_numpy(context_mask),
            torch.from_numpy(pinyin_mask),
        )
    def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
        """
        Run the decoder
        Args:
            context_H: [batch, seq_len, 512]
            pinyin_P: [batch, 24, 512]
            history_slot_ids: [batch, 8]
            context_mask: [batch, seq_len]
            pinyin_mask: [batch, 24]
        Returns:
            logits: [batch, vocab_size]
        """
        # Prepare the inputs
        inputs = {
            "context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H,
            "pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P,
            "history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids,
            "context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask,
            "pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask,
        }
        # Run inference
        outputs = self.decoder_session.run(self.decoder_output_names, inputs)
        # Unpack the outputs
        logits = outputs[0]
        return torch.from_numpy(logits)
    def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask,
                    beam_size=5, max_length=10, vocab_size=10019):
        """
        Beam-search example
        Args:
            context_H: context encoding
            pinyin_P: pinyin encoding
            context_mask: context mask
            pinyin_mask: pinyin mask
            beam_size: beam width
            max_length: maximum generation length
            vocab_size: vocabulary size
        Returns:
            list of the best sequences
        """
        # Initial beam: an empty sequence with score 0
        beams = [([], 0.0)]  # (sequence, log-probability)
        for step in range(max_length):
            new_beams = []
            for seq, score in beams:
                # Build history_slot_ids (the confirmed character IDs)
                if len(seq) < 8:
                    history = seq + [0] * (8 - len(seq))
                else:
                    history = seq[-8:]  # keep only the 8 most recent
                history_tensor = torch.tensor([history], dtype=torch.long)
                # Run the decoder
                logits = self.run_decoder(
                    context_H, pinyin_P, history_tensor,
                    context_mask, pinyin_mask
                )
                # Get the probabilities
                probs = F.softmax(logits[0], dim=-1)
                # Take the top-k candidates
                top_probs, top_indices = torch.topk(probs, beam_size)
                # Expand the beams
                for prob, idx in zip(top_probs, top_indices):
                    new_seq = seq + [idx.item()]
                    new_score = score + torch.log(prob).item()
                    new_beams.append((new_seq, new_score))
            # Prune: keep the beam_size best candidates
            new_beams.sort(key=lambda x: x[1], reverse=True)
            beams = new_beams[:beam_size]
            # Stop if every sequence has ended (ends with the end marker 0)
            all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
            if all_ended:
                break
        return beams
    def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids):
        """
        Single-step prediction
        Args:
            input_ids: input token IDs
            pinyin_ids: pinyin IDs
            attention_mask: attention mask
            history_slot_ids: history slot IDs
        Returns:
            prediction logits
        """
        # 1. Run the context encoder
        context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
            input_ids, pinyin_ids, attention_mask
        )
        # 2. Run the decoder
        logits = self.run_decoder(
            context_H, pinyin_P, history_slot_ids,
            context_mask, pinyin_mask
        )
        return logits
def main():
    """Example entry point"""
    print("ONNX model inference example")
    print("=" * 60)
    # Initialize the inference helper
    context_encoder_path = "context_encoder.onnx"
    decoder_path = "decoder.onnx"
    if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path):
        print("Error: ONNX model files not found")
        print("Run export_onnx.py first to export the models")
        return
    inference = ONNXInference(context_encoder_path, decoder_path)
    print("✅ ONNX inference helper initialized")
    print("Use this example as a starting point for the full input-method inference flow")
if __name__ == "__main__":
    main()
'''
    with open(example_path, "w", encoding="utf-8") as f:
        f.write(example_code)
    print(f"✅ Inference example script saved to: {example_path}")
def main():
    parser = argparse.ArgumentParser(description="ONNX export for the input-method model")
    parser.add_argument(
        "--checkpoint", "-c", type=str, required=True, help="Path to the model checkpoint"
    )
    parser.add_argument(
        "--output-dir",
        "-o",
        type=str,
        default="./exported_models",
        help="Output directory (default: ./exported_models)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="Export device (default: cpu)",
    )
    parser.add_argument(
        "--skip-verification", action="store_true", help="Skip ONNX model validation"
    )
    args = parser.parse_args()
    # Check the dependencies
    if not check_onnx_available():
        sys.exit(1)
    # Create the output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"📁 Output directory: {output_dir.absolute()}")
    # Load the checkpoint and build the export models
    print(f"📦 Loading checkpoint: {args.checkpoint}")
    context_encoder_export, decoder_export, config = (
        create_export_models_from_checkpoint(args.checkpoint, args.device)
    )
    print(f"📊 Model config: {config}")
    # Export the context encoder
    context_encoder_path = output_dir / "context_encoder.onnx"
    example_inputs = export_context_encoder(
        context_encoder_export, str(context_encoder_path), config
    )
    # Feed the context encoder's outputs to the decoder as example inputs
    with torch.no_grad():
        context_H, pinyin_P, context_mask, pinyin_mask = context_encoder_export(
            *example_inputs
        )
    # Export the decoder
    decoder_path = output_dir / "decoder.onnx"
    export_decoder(
        decoder_export,
        str(decoder_path),
        config,
        example_inputs=(context_H, pinyin_P, context_mask, pinyin_mask),
    )
    # Save the example inputs
    example_inputs_dict = {
        "input_ids": example_inputs[0],
        "pinyin_ids": example_inputs[1],
        "attention_mask": example_inputs[2],
        "context_H": context_H,
        "pinyin_P": pinyin_P,
        "context_mask": context_mask,
        "pinyin_mask": pinyin_mask,
    }
    save_example_inputs(output_dir, example_inputs_dict)
    # Create the inference example script
    create_inference_example(output_dir, config)
    print("\n" + "=" * 60)
    print("🎉 ONNX export finished")
    print("=" * 60)
    print("Generated model files:")
    print(f"  - {context_encoder_path}")
    print(f"  - {decoder_path}")
    print(f"  - {output_dir}/example_inputs.npz")
    print(f"  - {output_dir}/example_inputs.pt")
    print(f"  - {output_dir}/inference_example.py")
    print("\nUsage:")
    print(f"  1. Check the model: python -m onnx.checker {context_encoder_path}")
    print(f"  2. Run the inference example: cd {output_dir} && python inference_example.py")
    print("  3. Integrate into your application: see the ONNXInference class in inference_example.py")
    print("\nNotes:")
    print("  - Make sure onnxruntime is installed: pip install onnxruntime")
    print("  - GPU inference needs onnxruntime-gpu: pip install onnxruntime-gpu")
    print("  - Tune the beam-search algorithm to your actual needs")
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,230 @@
#!/usr/bin/env python3
"""
ONNX model inference example
Shows how to run inference with the two exported ONNX models,
including a beam-search algorithm.
"""
import os
import numpy as np
import onnxruntime as ort
import torch
import torch.nn.functional as F
from typing import List, Tuple
class ONNXInference:
    """ONNX model inference helper"""
    def __init__(self, context_encoder_path, decoder_path):
        """
        Initialize the ONNX inference helper
        Args:
            context_encoder_path: path to the context-encoder ONNX model
            decoder_path: path to the decoder ONNX model
        """
        # Create the ONNX Runtime sessions
        self.context_encoder_session = ort.InferenceSession(
            context_encoder_path,
            providers=['CPUExecutionProvider']  # or 'CUDAExecutionProvider'
        )
        self.decoder_session = ort.InferenceSession(
            decoder_path,
            providers=['CPUExecutionProvider']
        )
        # Collect the input/output names
        self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()]
        self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()]
        self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()]
        self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()]
        print(f"Context encoder inputs: {self.context_input_names}")
        print(f"Context encoder outputs: {self.context_output_names}")
        print(f"Decoder inputs: {self.decoder_input_names}")
        print(f"Decoder outputs: {self.decoder_output_names}")
    def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128):
        """
        Prepare the model inputs (kept consistent with the original inference script)
        Note: the text-to-token conversion has to be implemented here;
        the example assumes the helper functions already exist.
        """
        # Call the real preprocessing functions here
        # Returns: input_ids, pinyin_ids, attention_mask, history_slot_ids
        raise NotImplementedError("Implement the actual input preprocessing")
    def run_context_encoder(self, input_ids, pinyin_ids, attention_mask):
        """
        Run the context encoder
        Args:
            input_ids: [batch, seq_len]
            pinyin_ids: [batch, 24]
            attention_mask: [batch, seq_len]
        Returns:
            context_H, pinyin_P, context_mask, pinyin_mask
        """
        # Prepare the inputs
        inputs = {
            "input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids,
            "pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids,
            "attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask,
        }
        # Run inference
        outputs = self.context_encoder_session.run(self.context_output_names, inputs)
        # Unpack the outputs
        context_H, pinyin_P, context_mask, pinyin_mask = outputs
        return (
            torch.from_numpy(context_H),
            torch.from_numpy(pinyin_P),
            torch.from_numpy(context_mask),
            torch.from_numpy(pinyin_mask),
        )
    def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
        """
        Run the decoder
        Args:
            context_H: [batch, seq_len, 512]
            pinyin_P: [batch, 24, 512]
            history_slot_ids: [batch, 8]
            context_mask: [batch, seq_len]
            pinyin_mask: [batch, 24]
        Returns:
            logits: [batch, vocab_size]
        """
        # Prepare the inputs
        inputs = {
            "context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H,
            "pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P,
            "history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids,
            "context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask,
            "pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask,
        }
        # Run inference
        outputs = self.decoder_session.run(self.decoder_output_names, inputs)
        # Unpack the outputs
        logits = outputs[0]
        return torch.from_numpy(logits)
    def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask,
                    beam_size=5, max_length=10, vocab_size=10019):
        """
        Beam-search example
        Args:
            context_H: context encoding
            pinyin_P: pinyin encoding
            context_mask: context mask
            pinyin_mask: pinyin mask
            beam_size: beam width
            max_length: maximum generation length
            vocab_size: vocabulary size
        Returns:
            list of the best sequences
        """
        # Initial beam: an empty sequence with score 0
        beams = [([], 0.0)]  # (sequence, log-probability)
        for step in range(max_length):
            new_beams = []
            for seq, score in beams:
                # Build history_slot_ids (the confirmed character IDs)
                if len(seq) < 8:
                    history = seq + [0] * (8 - len(seq))
                else:
                    history = seq[-8:]  # keep only the 8 most recent
                history_tensor = torch.tensor([history], dtype=torch.long)
                # Run the decoder
                logits = self.run_decoder(
                    context_H, pinyin_P, history_tensor,
                    context_mask, pinyin_mask
                )
                # Get the probabilities
                probs = F.softmax(logits[0], dim=-1)
                # Take the top-k candidates
                top_probs, top_indices = torch.topk(probs, beam_size)
                # Expand the beams
                for prob, idx in zip(top_probs, top_indices):
                    new_seq = seq + [idx.item()]
                    new_score = score + torch.log(prob).item()
                    new_beams.append((new_seq, new_score))
            # Prune: keep the beam_size best candidates
            new_beams.sort(key=lambda x: x[1], reverse=True)
            beams = new_beams[:beam_size]
            # Stop if every sequence has ended (ends with the end marker 0)
            all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
            if all_ended:
                break
        return beams
    def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids):
        """
        Single-step prediction
        Args:
            input_ids: input token IDs
            pinyin_ids: pinyin IDs
            attention_mask: attention mask
            history_slot_ids: history slot IDs
        Returns:
            prediction logits
        """
        # 1. Run the context encoder
        context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
            input_ids, pinyin_ids, attention_mask
        )
        # 2. Run the decoder
        logits = self.run_decoder(
            context_H, pinyin_P, history_slot_ids,
            context_mask, pinyin_mask
        )
        return logits
def main():
    """Example entry point"""
    print("ONNX model inference example")
    print("=" * 60)
    # Initialize the inference helper
    context_encoder_path = "context_encoder.onnx"
    decoder_path = "decoder.onnx"
    if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path):
        print("Error: ONNX model files not found")
        print("Run export_onnx.py first to export the models")
        return
    inference = ONNXInference(context_encoder_path, decoder_path)
    print("✅ ONNX inference helper initialized")
    print("Use this example as a starting point for the full input-method inference flow")
if __name__ == "__main__":
    main()

View File

@@ -22,6 +22,8 @@
"""
import argparse
import os
import sys
import time
from pathlib import Path
from typing import List, Optional, Tuple
@@ -60,6 +62,98 @@ class InputMethodInference:
        print(f"✅ Inference engine initialized (device: {self.device})")
        # Try to enable readline for better line editing
        try:
            import readline
            # Configure the completer delimiters so word editing behaves sensibly
            readline.set_completer_delims(" \t\n`~!@#$%^&*()-=+[{]}\\|;:'\",<>/?")
            print("📝 readline enabled; improved line editing available")
        except ImportError:
            print("📝 readline unavailable; using standard input")
    def _safe_input(self, prompt: str, default: str = "") -> str:
        """
        Safe input helper that tries to handle UTF-8 characters and backspace correctly
        Args:
            prompt: prompt text
            default: default value
        Returns:
            the string entered by the user
        """
        try:
            # Show the prompt and the default value
            if default:
                full_prompt = f"{prompt} [{default}]: "
            else:
                full_prompt = f"{prompt}: "
            # Use the standard input()
            result = input(full_prompt)
            # If the user just presses Enter and a default exists, return the default
            if not result and default:
                return default
            return result.strip()
        except (EOFError, KeyboardInterrupt):
            # The user pressed Ctrl+D or Ctrl+C
            print()
            return ""
        except Exception as e:
            # Any other error
            print(f"\n⚠️ Input error: {e}")
            return ""
    def _clean_pinyin_input(self, pinyin: str) -> str:
        """
        Clean a raw pinyin input string, handling backspace and other control characters
        Only a-z, `, ', and - are valid pinyin characters;
        Chinese characters and anything else are ignored.
        Args:
            pinyin: raw pinyin input string
        Returns:
            the cleaned pinyin string
        """
        if not pinyin:
            return ""
        result = []
        for c in pinyin:
            # Check whether this is a valid pinyin character (a-z, `, ', -)
            # Note: isalpha() is also True for Chinese characters, so an explicit range check is needed
            is_valid_pinyin_char = (
                ("a" <= c <= "z")
                or ("A" <= c <= "Z")  # allow uppercase, folded to lowercase
                or c in ["`", "'", "-"]
            )
            if is_valid_pinyin_char:
                # Valid pinyin character; fold to lowercase
                result.append(c.lower())
            elif c == " ":
                # Ignore spaces
                continue
            elif c == "\b" or c == "\x7f" or c == "\x08":
                # Backspace/delete: drop the previous character
                if result:
                    result.pop()
            elif c == "\x1b":
                # ESC: clear everything entered so far
                result.clear()
            else:
                # Ignore everything else (including Chinese characters)
                # Note: they are never added to result, so backspace cannot delete them,
                # but Chinese characters typed into the pinyin field by mistake should be ignored
                pass
        return "".join(result)
    def load_model(self):
        """Load the trained model"""
        # Create the model instance (not compiled)
@@ -184,7 +278,7 @@ class InputMethodInference:
        """
        # 1. Build the tokenizer input
        # Following test.py and dataset.py, the format is: "part4|part1" plus part3
        # part4: context prompts (proper nouns, names, etc. the model does not know)
        # part1: text_before
        # part3: text_after
@@ -192,13 +286,14 @@ class InputMethodInference:
        # Handle the context prompts
        context_text = "|".join(context_prompts) if context_prompts else ""
        # Build the input text - consistent with test.py
        # test.py passes f"{part4}|{part1}" as the first argument and part3 as the second
        if context_text:
            input_text = f"{context_text}|{text_before}"
        else:
            input_text = text_before
        # 2. Tokenize - consistent with test.py
        encoded = self.tokenizer(
            input_text,
            text_after,
@@ -209,8 +304,11 @@ class InputMethodInference:
            return_token_type_ids=True,
        )
        # 3. Handle the pinyin input - consistent with test.py
        # First clean the pinyin string, handling backspace and other control characters
        cleaned_pinyin = self._clean_pinyin_input(pinyin)
        pinyin_ids = text_to_pinyin_ids(cleaned_pinyin)
        if len(pinyin_ids) < 24:
            pinyin_ids.extend([0] * (24 - len(pinyin_ids)))
        else:
@@ -396,7 +494,16 @@ class InputMethodInference:
        print("\n" + "=" * 60)
        print("Input-method model inference - interactive mode")
        print("=" * 60)
        # Check the terminal encoding
        encoding = sys.stdout.encoding or "unknown"
        print(f"Terminal encoding: {encoding}")
        if encoding.lower() not in ["utf-8", "utf8"]:
            print("⚠️ Warning: the terminal encoding is not UTF-8; Chinese input may misbehave")
            print("   Suggested: export LANG=en_US.UTF-8")
            print("   or: export LC_ALL=en_US.UTF-8")
        print("\nNotes:")
        print("  - Context prompts: proper nouns, names, etc. the model does not know (may be empty)")
        print("  - Text before cursor: contiguous text before the cursor")
        print("  - Text after cursor: contiguous text after the cursor")
@@ -411,7 +518,7 @@ class InputMethodInference:
        print("Step 1: context prompts (proper nouns, names, etc. the model does not know)")
        print("Format: separate multiple terms with commas; may be empty")
        print("Example: 张三,李四,北京大学")
        context_input = self._safe_input("Enter context prompts (press Enter to skip)")
        if context_input.lower() in ["quit", "exit", "q"]:
            print("Exiting interactive mode")
@@ -434,7 +541,7 @@ class InputMethodInference:
        print("Step 2: text before the cursor")
        print("Note: the contiguous text before the cursor")
        print("Example: 今天天气很好")
        text_before = self._safe_input("Enter the text before the cursor")
        if text_before.lower() in ["quit", "exit", "q"]:
            print("Exiting interactive mode")
@@ -446,7 +553,7 @@ class InputMethodInference:
        print("Step 3: text after the cursor")
        print("Note: the contiguous text after the cursor")
        print("Example: 我们去公园玩")
        text_after = self._safe_input("Enter the text after the cursor")
        if text_after.lower() in ["quit", "exit", "q"]:
            print("Exiting interactive mode")
@@ -458,7 +565,7 @@ class InputMethodInference:
        print("Step 4: pinyin input")
        print("Note: the pinyin currently being typed")
        print("Example: tian, shang, hao")
        pinyin = self._safe_input("Enter the pinyin")
        if pinyin.lower() in ["quit", "exit", "q"]:
            print("Exiting interactive mode")
@@ -471,7 +578,7 @@ class InputMethodInference:
        print("Note: input history already confirmed by the user, comma separated")
        print("Example: 上 (typing 'shanghai' with '上' already confirmed)")
        print("        今天,天气 (two confirmed words)")
        slot_input = self._safe_input("Enter the slot history (press Enter for none)")
        if slot_input.lower() in ["quit", "exit", "q"]:
            print("Exiting interactive mode")
@@ -524,7 +631,9 @@ class InputMethodInference:
            # Ask whether to continue
            print("\n" + "-" * 40)
            continue_input = (
                self._safe_input("Continue inference? (y/n)", "y").strip().lower()
            )
            if continue_input not in ["y", "yes", ""]:
                print("Exiting interactive mode")
                break
@@ -539,7 +648,9 @@ class InputMethodInference:
            traceback.print_exc()
            # Ask whether to continue
            continue_input = (
                self._safe_input("\nContinue? (y/n)", "y").strip().lower()
            )
            if continue_input not in ["y", "yes", ""]:
                print("Exiting interactive mode")
                break

763
onnx_inference.py Normal file
View File

@@ -0,0 +1,763 @@
#!/usr/bin/env python3
"""
ONNX inference script for the input-method model
Runs inference with ONNX Runtime and measures the duration of each stage.
Usage:
    python onnx_inference.py --context-encoder exported_models/context_encoder.onnx --decoder exported_models/decoder.onnx
Interactive mode asks for the inputs step by step:
1. Context prompts: proper nouns, names, etc. the model does not know (may be empty)
2. Text before cursor: contiguous text before the cursor
3. Text after cursor: contiguous text after the cursor
4. Pinyin: the pinyin currently being typed
5. Slot history: input history already confirmed by the user (e.g. typing shanghai with "上" confirmed)
"""
import argparse
import os
import sys
import time
from pathlib import Path
from typing import List, Optional, Tuple
import numpy as np
import onnxruntime as ort
import torch
import torch.nn.functional as F
from modelscope import AutoTokenizer
from src.model.dataset import text_to_pinyin_ids
from src.model.query import QueryEngine
class ONNXInference:
    """ONNX inference engine for the input-method model"""
    def __init__(
        self,
        context_encoder_path: str,
        decoder_path: str,
        vocab_size: int = 10019,
        device: str = "cpu",
        use_beam_search: bool = False,
        beam_size: int = 5,
    ):
        self.vocab_size = vocab_size
        self.device = device
        self.use_beam_search = use_beam_search
        self.beam_size = beam_size
        # Load the components
        print(f"Loading the context encoder: {context_encoder_path}")
        load_start = time.perf_counter()
        self.load_context_encoder(context_encoder_path)
        self.context_encoder_load_time = (time.perf_counter() - load_start) * 1000
        print(f"  ✅ Context encoder loaded ({self.context_encoder_load_time:.2f}ms)")
        print(f"Loading the decoder: {decoder_path}")
        load_start = time.perf_counter()
        self.load_decoder(decoder_path)
        self.decoder_load_time = (time.perf_counter() - load_start) * 1000
        print(f"  ✅ Decoder loaded ({self.decoder_load_time:.2f}ms)")
        # Load the tokenizer
        print("Loading the tokenizer...")
        load_start = time.perf_counter()
        self.load_tokenizer()
        self.tokenizer_load_time = (time.perf_counter() - load_start) * 1000
        print(f"  ✅ Tokenizer loaded ({self.tokenizer_load_time:.2f}ms)")
        # Load the query engine
        print("Loading the query engine...")
        load_start = time.perf_counter()
        self.load_query_engine()
        self.query_engine_load_time = (time.perf_counter() - load_start) * 1000
        print(f"  ✅ Query engine loaded ({self.query_engine_load_time:.2f}ms)")
        total_load_time = (
            self.context_encoder_load_time
            + self.decoder_load_time
            + self.tokenizer_load_time
            + self.query_engine_load_time
        )
        print(f"\n✅ Inference engine initialized (device: {device})")
        print(f"   Total load time: {total_load_time:.2f}ms")
        # Try to enable readline
        try:
            import readline
            readline.set_completer_delims(" \t\n`~!@#$%^&*()-=+[{]}\\|;:'\",<>/?")
        except ImportError:
            pass
    def load_context_encoder(self, model_path: str):
        """Load the context-encoder ONNX model"""
        providers = (
            ["CUDAExecutionProvider", "CPUExecutionProvider"]
            if self.device == "cuda"
            else ["CPUExecutionProvider"]
        )
        self.context_encoder_session = ort.InferenceSession(
            model_path, providers=providers
        )
        self.context_input_names = [
            inp.name for inp in self.context_encoder_session.get_inputs()
        ]
        self.context_output_names = [
            out.name for out in self.context_encoder_session.get_outputs()
        ]
def load_decoder(self, model_path: str):
"""加载解码器ONNX模型"""
providers = (
["CUDAExecutionProvider", "CPUExecutionProvider"]
if self.device == "cuda"
else ["CPUExecutionProvider"]
)
self.decoder_session = ort.InferenceSession(model_path, providers=providers)
self.decoder_input_names = [
inp.name for inp in self.decoder_session.get_inputs()
]
self.decoder_output_names = [
out.name for out in self.decoder_session.get_outputs()
]
def load_tokenizer(self):
"""加载tokenizer"""
try:
tokenizer_path = (
Path(__file__).parent / "src" / "model" / "assets" / "tokenizer"
)
self.tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path))
except Exception:
print(" ⚠️ 无法加载自定义tokenizer使用bert-base-chinese")
self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
def load_query_engine(self):
"""加载查询引擎"""
try:
self.query_engine = QueryEngine()
stats_path = (
Path(__file__).parent
/ "src"
/ "model"
/ "assets"
/ "pinyin_char_statistics.json"
)
if stats_path.exists():
self.query_engine.load(stats_path)
except Exception:
self.query_engine = None
def char_to_id(self, char: str, pinyin: Optional[str] = None) -> int:
"""将汉字转换为ID"""
if char == "//":
return 0
if self.query_engine is None:
return ord(char) if len(char) == 1 else 0
try:
if pinyin is not None:
info = self.query_engine.get_char_info_by_char_pinyin(char, pinyin)
if info:
return info.id
results = self.query_engine.query_by_char(char, limit=1)
if results:
return results[0][0]
return 0
except Exception:
return 0
def id_to_char(self, id: int) -> str:
"""将ID转换为汉字"""
if id == 0:
return "//"
if self.query_engine is None:
return chr(id) if id < 0x110000 else f"[ID:{id}]"
try:
info = self.query_engine.query_by_id(id)
return info.char if info else f"[ID:{id}]"
except Exception:
return f"[ID:{id}]"
def _clean_pinyin_input(self, pinyin: str) -> str:
"""清理拼音输入字符串"""
if not pinyin:
return ""
result = []
for c in pinyin:
is_valid = ("a" <= c <= "z") or ("A" <= c <= "Z") or c in ["`", "'", "-"]
if is_valid:
result.append(c.lower())
elif c == " ":
continue
elif c in ["\b", "\x7f", "\x08"]:
if result:
result.pop()
elif c == "\x1b":
result.clear()
return "".join(result)
def _safe_input(self, prompt: str, default: str = "") -> str:
"""安全的输入函数"""
try:
full_prompt = f"{prompt} [{default}]: " if default else f"{prompt}: "
result = input(full_prompt)
if not result and default:
return default
return result.strip()
except (EOFError, KeyboardInterrupt):
print()
return ""
except Exception as e:
print(f"\n⚠️ 输入错误: {e}")
return ""
def prepare_inputs(
self,
context_prompts: List[str],
text_before: str,
text_after: str,
pinyin: str,
slot_chars: List[str],
max_seq_len: int = 128,
) -> dict:
"""
准备模型输入
Returns:
dict: {
'preprocess_time': float, # 预处理时间(ms)
'input_ids': numpy array,
'attention_mask': numpy array,
'pinyin_ids': numpy array,
'history_slot_ids': numpy array,
}
"""
preprocess_start = time.perf_counter()
# 1. 构建tokenizer输入
context_text = "|".join(context_prompts) if context_prompts else ""
if context_text:
input_text = f"{context_text}|{text_before}"
else:
input_text = text_before
# 2. Tokenize
encoded = self.tokenizer(
input_text,
text_after,
max_length=max_seq_len,
padding="max_length",
truncation=True,
return_tensors="pt",
return_token_type_ids=True,
)
input_ids = encoded["input_ids"].numpy()
attention_mask = encoded["attention_mask"].numpy()
# 3. 处理拼音输入
cleaned_pinyin = self._clean_pinyin_input(pinyin)
pinyin_ids = text_to_pinyin_ids(cleaned_pinyin)
if len(pinyin_ids) < 24:
pinyin_ids.extend([0] * (24 - len(pinyin_ids)))
else:
pinyin_ids = pinyin_ids[:24]
pinyin_ids = np.array([pinyin_ids], dtype=np.int64)
# 4. 处理历史槽位
history_slot_ids = []
for char in slot_chars:
char_id = self.char_to_id(char)
history_slot_ids.append(char_id)
if len(history_slot_ids) < 8:
history_slot_ids.extend([0] * (8 - len(history_slot_ids)))
else:
history_slot_ids = history_slot_ids[:8]
history_slot_ids = np.array([history_slot_ids], dtype=np.int64)
preprocess_time = (time.perf_counter() - preprocess_start) * 1000
return {
"preprocess_time": preprocess_time,
"input_ids": input_ids,
"attention_mask": attention_mask,
"pinyin_ids": pinyin_ids,
"history_slot_ids": history_slot_ids,
}
def run_context_encoder(
self, input_ids: np.ndarray, pinyin_ids: np.ndarray, attention_mask: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""
运行上下文编码器
Returns:
context_H, pinyin_P, context_mask, pinyin_mask
"""
context_start = time.perf_counter()
inputs = {
"input_ids": input_ids,
"pinyin_ids": pinyin_ids,
"attention_mask": attention_mask,
}
outputs = self.context_encoder_session.run(self.context_output_names, inputs)
context_H, pinyin_P, context_mask, pinyin_mask = outputs
self.last_context_encoder_time = (time.perf_counter() - context_start) * 1000
return context_H, pinyin_P, context_mask, pinyin_mask
def run_decoder(
self,
context_H: np.ndarray,
pinyin_P: np.ndarray,
history_slot_ids: np.ndarray,
context_mask: np.ndarray,
pinyin_mask: np.ndarray,
) -> np.ndarray:
"""
运行解码器
Returns:
logits: [batch, vocab_size]
"""
decoder_start = time.perf_counter()
inputs = {
"context_H": context_H,
"pinyin_P": pinyin_P,
"history_slot_ids": history_slot_ids,
"context_mask": context_mask,
"pinyin_mask": pinyin_mask,
}
outputs = self.decoder_session.run(self.decoder_output_names, inputs)
self.last_decoder_time = (time.perf_counter() - decoder_start) * 1000
return outputs[0]
def predict(
self,
context_prompts: List[str],
text_before: str,
text_after: str,
pinyin: str,
slot_chars: List[str],
top_k: int = 20,
use_beam_search: bool = False,
beam_size: int = 5,
max_length: int = 10,
) -> Tuple[List[Tuple[str, float, int]], dict]:
"""
执行推理
Args:
context_prompts: 上下文提示
text_before: 光标前文本
text_after: 光标后文本
pinyin: 当前输入的拼音
slot_chars: 槽位内的汉字列表
top_k: 返回top-k个预测结果
use_beam_search: 是否使用束搜索
beam_size: 束大小
max_length: 最大生成长度
Returns:
(predictions, timing_info)
predictions: List[Tuple[char, prob, id]]
timing_info: 各阶段耗时字典
"""
total_start = time.perf_counter()
# 阶段1: 预处理
prep_start = time.perf_counter()
inputs = self.prepare_inputs(
context_prompts, text_before, text_after, pinyin, slot_chars
)
preprocess_time = inputs["preprocess_time"]
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
pinyin_ids = inputs["pinyin_ids"]
history_slot_ids = inputs["history_slot_ids"]
prep_time = (time.perf_counter() - prep_start) * 1000
# 阶段2: 上下文编码
context_start = time.perf_counter()
context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
input_ids, pinyin_ids, attention_mask
)
context_encoder_time = self.last_context_encoder_time
if use_beam_search:
# 阶段3: 束搜索解码
decode_start = time.perf_counter()
predictions, beam_decode_time = self._beam_search_decode(
context_H,
pinyin_P,
context_mask,
pinyin_mask,
beam_size,
max_length,
top_k,
)
decoder_time = beam_decode_time
else:
# 阶段3: 单步解码
decode_start = time.perf_counter()
logits = self.run_decoder(
context_H,
pinyin_P,
history_slot_ids,
context_mask,
pinyin_mask,
)
# 阶段4: 后处理
postprocess_start = time.perf_counter()
probs = self._softmax(logits)
top_indices, top_probs = self._topk(probs, top_k)
predictions = []
for i in range(top_k):
idx = int(top_indices[0, i])
prob = float(top_probs[0, i])
char = self.id_to_char(idx)
predictions.append((char, prob, idx))
postprocess_time = (time.perf_counter() - postprocess_start) * 1000
decoder_time = self.last_decoder_time
total_time = (time.perf_counter() - total_start) * 1000
timing_info = {
"预处理": prep_time,
"上下文编码": context_encoder_time,
"解码": decoder_time,
"后处理": postprocess_time if not use_beam_search else 0,
"总耗时": total_time,
}
return predictions, timing_info
def _softmax(self, logits: np.ndarray) -> np.ndarray:
"""计算softmax"""
exp_logits = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
return exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
def _topk(self, probs: np.ndarray, k: int) -> Tuple[np.ndarray, np.ndarray]:
"""获取top-k"""
topk_indices = np.argsort(probs, axis=-1)[:, -k:][:, ::-1]
topk_probs = np.take_along_axis(probs, topk_indices, axis=-1)
return topk_indices, topk_probs
def _beam_search_decode(
self,
context_H: np.ndarray,
pinyin_P: np.ndarray,
context_mask: np.ndarray,
pinyin_mask: np.ndarray,
beam_size: int,
max_length: int,
top_k: int,
) -> Tuple[List[Tuple[str, float, int]], float]:
"""束搜索解码"""
beams = [([], 0.0)] # (序列, 对数概率)
for step in range(max_length):
new_beams = []
for seq, score in beams:
if len(seq) < 8:
history = seq + [0] * (8 - len(seq))
else:
history = seq[-8:]
history_tensor = np.array([history], dtype=np.int64)
logits = self.run_decoder(
context_H,
pinyin_P,
history_tensor,
context_mask,
pinyin_mask,
)
probs = self._softmax(logits)[0]
topk_indices = np.argsort(probs)[-beam_size:][::-1]
topk_probs = probs[topk_indices]
for idx, prob in zip(topk_indices, topk_probs):
new_seq = seq + [int(idx)]
new_score = score + np.log(prob + 1e-10)
new_beams.append((new_seq, new_score))
new_beams.sort(key=lambda x: x[1], reverse=True)
beams = new_beams[:beam_size]
all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
if all_ended:
break
# 返回top-k个候选
predictions = []
for seq, score in beams[:top_k]:
if seq:
char = self.id_to_char(seq[-1])
prob = np.exp(score / max(len(seq), 1))
else:
char = self.id_to_char(0)
prob = 0.0
predictions.append((char, prob, seq[-1] if seq else 0))
decode_time = self.last_decoder_time # 只记录最后一次解码的时间
return predictions, decode_time
def interactive_mode(self):
"""交互式推理模式"""
print("\n" + "=" * 60)
print("ONNX输入法模型推理 - 交互模式")
print("=" * 60)
encoding = sys.stdout.encoding or "unknown"
print(f"终端编码: {encoding}")
print("\n说明:")
print(" - 上下文提示: 模型不掌握的专有词汇、姓名等(可为空)")
print(" - 光标前文本: 光标前的连续文本")
print(" - 光标后文本: 光标后的连续文本")
print(" - 拼音: 当前输入的拼音")
print(" - 槽位历史: 用户已确认的输入历史")
if self.use_beam_search:
print(f" - 解码模式: 束搜索 (beam_size={self.beam_size})")
else:
print(" - 解码模式: 单步解码 (使用 --beam 启用束搜索)")
print("提示: 输入 'quit''exit''q' 可随时退出")
print("-" * 60)
while True:
try:
print("\n" + "=" * 60)
context_input = self._safe_input("第1步: 上下文提示(直接回车跳过)")
if context_input.lower() in ["quit", "exit", "q"]:
break
context_prompts = [
item.strip() for item in context_input.split(",") if item.strip()
]
print("\n" + "-" * 40)
text_before = self._safe_input("第2步: 光标前文本")
if text_before.lower() in ["quit", "exit", "q"]:
break
print("\n" + "-" * 40)
text_after = self._safe_input("第3步: 光标后文本")
if text_after.lower() in ["quit", "exit", "q"]:
break
print("\n" + "-" * 40)
pinyin = self._safe_input("第4步: 拼音输入")
if pinyin.lower() in ["quit", "exit", "q"]:
break
print("\n" + "-" * 40)
slot_input = self._safe_input("第5步: 槽位历史(直接回车表示无)")
if slot_input.lower() in ["quit", "exit", "q"]:
break
slot_chars = [
char.strip() for char in slot_input.split(",") if char.strip()
]
print("\n" + "=" * 60)
print("📝 输入汇总:")
print(f" 上下文提示: {context_prompts if context_prompts else ''}")
print(f" 光标前文本: '{text_before}'")
print(f" 光标后文本: '{text_after}'")
print(f" 拼音: '{pinyin}'")
print(f" 槽位历史: {slot_chars if slot_chars else ''}")
print("\n🔮 推理中...")
predictions, timing_info = self.predict(
context_prompts,
text_before,
text_after,
pinyin,
slot_chars,
top_k=20,
use_beam_search=self.use_beam_search,
beam_size=self.beam_size,
)
# 显示时间统计
print(f"\n⏱️ 执行时间统计:")
print("-" * 40)
for stage, duration in timing_info.items():
if duration > 0:
print(f" {stage:<12}: {duration:>8.2f} ms")
print("-" * 40)
# 显示结果
print("\n🏆 Top-20 预测结果:")
print("-" * 50)
for i, (char, prob, idx) in enumerate(predictions):
if char == "//":
print(f"{i + 1:2d}. {'//':<4} (结束符) - 概率: {prob:.4f}")
else:
print(
f"{i + 1:2d}. {char:<4} (ID: {idx:>5}) - 概率: {prob:.4f}"
)
# 显示拼音参考
if pinyin and self.query_engine:
print(f"\n📖 拼音 '{pinyin}' 的常见汉字:")
pinyin_results = self.query_engine.query_by_pinyin(pinyin, limit=10)
if pinyin_results:
for j, (pid, char, count) in enumerate(pinyin_results):
print(f" {char} (频次: {count:,})")
print("\n" + "-" * 40)
continue_input = (
self._safe_input("是否继续推理?(y/n)", "y").strip().lower()
)
if continue_input not in ["y", "yes", ""]:
break
except KeyboardInterrupt:
print("\n\n退出交互模式")
break
except Exception as e:
print(f"\n❌ 推理出错: {e}")
import traceback
traceback.print_exc()
def main():
parser = argparse.ArgumentParser(description="ONNX输入法模型推理")
parser.add_argument(
"--context-encoder",
"-c",
type=str,
required=True,
help="上下文编码器ONNX模型路径",
)
parser.add_argument(
"--decoder",
"-d",
type=str,
required=True,
help="解码器ONNX模型路径",
)
parser.add_argument(
"--vocab-size",
type=int,
default=10019,
help="词汇表大小 (默认: 10019)",
)
parser.add_argument(
"--device",
type=str,
default="cpu",
choices=["cpu", "cuda"],
help="推理设备 (默认: cpu)",
)
parser.add_argument(
"--interactive",
action="store_true",
default=True,
help="交互模式 (默认: True)",
)
parser.add_argument("--test", action="store_true", help="运行测试推理")
parser.add_argument(
"--beam",
action="store_true",
help="使用束搜索解码 (默认: 单步解码)",
)
parser.add_argument(
"--beam-size",
type=int,
default=5,
help="束大小 (默认: 5)",
)
args = parser.parse_args()
# 检查文件是否存在
if not os.path.exists(args.context_encoder):
print(f"❌ 错误: 上下文编码器文件不存在: {args.context_encoder}")
sys.exit(1)
if not os.path.exists(args.decoder):
print(f"❌ 错误: 解码器文件不存在: {args.decoder}")
sys.exit(1)
# 初始化推理器
inference = ONNXInference(
context_encoder_path=args.context_encoder,
decoder_path=args.decoder,
vocab_size=args.vocab_size,
device=args.device,
use_beam_search=args.beam,
beam_size=args.beam_size,
)
# 测试推理
if args.test:
print("\n🧪 运行测试推理...")
print("测试场景: 输入'shanghai',已确认第一个字'',继续输入'tian'")
print("上下文提示: 张三、李四(模型不掌握的专有名词)")
predictions, timing_info = inference.predict(
context_prompts=["张三", "李四"],
text_before="今天天气",
text_after="很好",
pinyin="tian",
slot_chars=["上"],
use_beam_search=args.beam,
beam_size=args.beam_size,
)
print(f"\n⏱️ 执行时间统计:")
print("-" * 40)
for stage, duration in timing_info.items():
if duration > 0:
print(f" {stage:<12}: {duration:>8.2f} ms")
print("-" * 40)
print(f"\nTop-5 结果:")
for i, (char, prob, idx) in enumerate(predictions[:5]):
if char == "//":
print(f" {i + 1}. // (结束符) - 概率: {prob:.4f}")
else:
print(f" {i + 1}. {char} (ID: {idx}) - 概率: {prob:.4f}")
# 交互模式
if args.interactive:
inference.interactive_mode()
if __name__ == "__main__":
main()
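
附:一个脱离 ONNX 会话的束搜索打分示意(非仓库代码)。`_beam_search_decode` 每步对每个候选展开 top-k、按累计对数概率排序截断,输出时按 `exp(score / len)` 做长度归一化;下面用随机 logits 复现同样的打分与剪枝逻辑,函数名与数据均为演示假设:

```python
import numpy as np


def demo_beam_step(step_logits: np.ndarray, beams: list, beam_size: int = 3) -> list:
    """单步束扩展示意:对每个候选取 top-k,并按累计对数概率重排截断。

    step_logits: [len(beams), vocab] 的 logits(这里用随机数代替解码器输出)
    beams: [(token_id 列表, 累计对数概率)] 形式的当前候选
    """
    new_beams = []
    for (seq, score), logits in zip(beams, step_logits):
        # 与 _softmax 相同的数值稳定写法
        probs = np.exp(logits - logits.max())
        probs /= probs.sum()
        for idx in np.argsort(probs)[-beam_size:][::-1]:
            new_beams.append((seq + [int(idx)], score + np.log(probs[idx] + 1e-10)))
    new_beams.sort(key=lambda x: x[1], reverse=True)
    return new_beams[:beam_size]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    beams = [([], 0.0)]
    for _ in range(4):
        beams = demo_beam_step(rng.normal(size=(len(beams), 16)), beams)
    for seq, score in beams:
        # 与脚本中一致的长度归一化:exp(score / len(seq))
        print(seq, round(float(np.exp(score / max(len(seq), 1))), 4))
```

长度归一化的作用是抵消束搜索天然偏向短序列的问题。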


@ -14,6 +14,7 @@ dependencies = [
"onnxruntime>=1.24.2", "onnxruntime>=1.24.2",
"pandas>=3.0.0", "pandas>=3.0.0",
"plotly>=5.0.0", "plotly>=5.0.0",
"jieba>=0.42.1",
"pypinyin>=0.55.0", "pypinyin>=0.55.0",
"requests>=2.32.5", "requests>=2.32.5",
"rich>=14.3.1", "rich>=14.3.1",
@ -23,6 +24,7 @@ dependencies = [
"transformers==5.1.0", "transformers==5.1.0",
"typer>=0.21.1", "typer>=0.21.1",
"waitress>=3.0.2", "waitress>=3.0.2",
"onnx>=1.21.0",
] ]
[project.scripts] [project.scripts]

155
src/analyze_data.py Normal file

@ -0,0 +1,155 @@
#!/usr/bin/env python3
"""
数据分布查看脚本
主要关注 labels 0-10, 200-210, 2900-2910, 5100-5110, 9520-9530 的分布情况
使用方法:
uv run python src/analyze_data.py --data_path ./data/train.txt
注意:
data_path 可以是:
- 本地文本文件 (每行一个文本)
- HuggingFace 数据集路径
- 目录路径 (该目录下需要有 dataset_info.json 或类似配置文件)
"""
import os
import random
from collections import defaultdict
import numpy as np
import torch
from torch.utils.data import DataLoader
from model.dataset import PinyinInputDataset
def analyze_label_distribution(dataset: PinyinInputDataset, sample_size: int = 10000):
"""分析 label 在指定区间的分布"""
target_ranges = [
(0, 10),
(200, 210),
(2900, 2910),
(5100, 5110),
(9520, 9530),
]
id_counts = defaultdict(int)
all_examples = []
dataloader = DataLoader(dataset, batch_size=1, num_workers=16)
total_count = 0
sample_collected = 0
for batch in dataloader:
if sample_collected >= sample_size:
break
label = batch["label"].item()
prefix = batch["prefix"][0]
suffix = batch["suffix"][0]
pinyin = batch["pinyin"][0]
history = batch["history_slot_ids"][0].tolist()
part4 = prefix.split("^")[0] if "^" in prefix else ""
sample_collected += 1
total_count += 1
in_target_range = False
for start, end in target_ranges:
if start <= label <= end:
id_counts[label] += 1
in_target_range = True
if len(all_examples) < 200:
all_examples.append(
{
"label": label,
"prefix": prefix,
"suffix": suffix,
"pinyin": pinyin,
"history": history,
"part4": part4,
}
)
print(f"\n{'=' * 60}")
print(f"采样总数: {total_count}")
print(f"{'=' * 60}\n")
print("Label ID 分布统计:")
print("-" * 50)
for start, end in target_ranges:
print(f"\n区间 [{start:5d} - {end:5d}]:")
for label_id in range(start, end + 1):
count = id_counts[label_id]
percentage = (count / total_count * 100) if total_count > 0 else 0
bar = "" * min(count, 50)
print(f" ID {label_id:5d}: {count:6d} ({percentage:.3f}%) {bar}")
print("\n")
print("=" * 80)
print("随机抽取 20 个样本详情:")
print("=" * 80)
random.shuffle(all_examples)
for idx, ex in enumerate(all_examples[:20], 1):
print(f"\n样本 {idx}:")
print(f" Label: {ex['label']}")
print(f" Part4: {ex['part4']}")
print(f" 光标前: {ex['prefix']}")
print(f" 光标后: {ex['suffix']}")
print(f" 拼音: {ex['pinyin']}")
print(f" 历史槽位: {ex['history']}")
def main():
import argparse
default_data_path = os.path.expanduser("~/Data/corpus/CCI-Data/")
parser = argparse.ArgumentParser(
description="分析数据集 label 分布",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=f"""
默认数据集路径: {default_data_path}
示例:
uv run python src/analyze_data.py --data_path {default_data_path} --sample_size 10000
uv run python src/analyze_data.py --data_path ./data/eval.txt --sample_size 5000
""",
)
parser.add_argument(
"--data_path",
type=str,
default=default_data_path,
help="数据集路径 (本地文件或HuggingFace路径)",
)
parser.add_argument("--sample_size", type=int, default=10000, help="采样大小")
parser.add_argument("--max_workers", type=int, default=-1, help="DataLoader workers")
args = parser.parse_args()
print(f"加载数据集: {args.data_path}")
print()
try:
dataset = PinyinInputDataset(
data_path=args.data_path,
max_workers=args.max_workers,
max_iter_length=args.sample_size,
)
except Exception as e:
print(f"\n错误: 无法加载数据集 '{args.data_path}'")
print(f"原因: {e}")
print("\n请确保:")
print(" 1. 数据路径存在且正确")
print(" 2. 如果是本地文本文件,每行应该是一个 JSON 对象或纯文本")
print(" 3. 如果是 HuggingFace 数据集,路径应该正确")
return
analyze_label_distribution(dataset, sample_size=args.sample_size)
if __name__ == "__main__":
main()


@ -8,7 +8,7 @@
"id": 0, "id": 0,
"char": "", "char": "",
"pinyin": "", "pinyin": "",
"count": 11067734826 "count": 494748360
}, },
"1": { "1": {
"id": 1, "id": 1,


@ -1,3 +1,4 @@
import jieba
import random import random
import re import re
from importlib.resources import files from importlib.resources import files
@ -23,6 +24,26 @@ CHAR_TO_ID["'"] = 28 # 显式添加引号
CHAR_TO_ID["-"] = 29 # 显式添加短横 CHAR_TO_ID["-"] = 29 # 显式添加短横
jieba.setLogLevel(jieba.logging.INFO)
def segment_text(text: str) -> List[str]:
"""使用 jieba 分词,返回词列表"""
return list(jieba.cut(text, HMM=False))
def build_word_boundaries(words: List[str]) -> List[Tuple[int, int]]:
"""建立词边界列表 [(start, end), ...],基于顺序位置累加"""
result = []
pos = 0
for word in words:
start = pos
end = pos + len(word)
result.append((start, end))
pos = end
return result
def text_to_pinyin_ids(pinyin_str: str) -> List[int]: def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
""" """
将拼音字符串转换为 ID 列表 将拼音字符串转换为 ID 列表
@ -41,17 +62,22 @@ class PinyinInputDataset(IterableDataset):
max_iter_length=1e6, max_iter_length=1e6,
max_seq_length=128, max_seq_length=128,
text_field: str = "text", text_field: str = "text",
py_style_weight=(9, 2, 1), py_style_weight=(90, 2, 1),
shuffle_buffer_size: int = 100000, shuffle_buffer_size: int = 100000,
retention_ratio: float = 0.5, retention_ratio: float = 0.8,
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
): ):
# 频率调整参数 (可根据需要调整) # 频率调整参数 (可根据需要调整)
self.drop_start_freq = 30_000_000 self.drop_start_freq = 10_000_000
self.max_drop_prob = 0.8 self.max_drop_prob = 0.9
self.repeat_end_freq = 10_000 self.repeat_end_freq = 10_000
self.max_repeat_expect = 50 self.max_repeat_expect = 50
self.min_freq = 109 self.min_freq = 109
self.word_break_prob = 0.10
self.cont_length_probs = [0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04]
self._history_weights = [0.2, 0.2, 0.2, 0.9, 1.2, 1.8, 2.5, 3.5, 4.0]
jieba.initialize()
self.tokenizer = AutoTokenizer.from_pretrained( self.tokenizer = AutoTokenizer.from_pretrained(
Path(str(files(__package__))) / "assets" / "tokenizer" Path(str(files(__package__))) / "assets" / "tokenizer"
@ -83,14 +109,15 @@ class PinyinInputDataset(IterableDataset):
# 提取每个样本的目标字符及其频率 # 提取每个样本的目标字符及其频率
self.sample_freqs = self.query_engine.get_all_weights() self.sample_freqs = self.query_engine.get_all_weights()
self.max_freq = max(self.sample_freqs.values()) if self.sample_freqs else 0
def adjust_frequency(self, freq: int) -> int: def adjust_frequency(self, freq: int) -> int:
"""削峰填谷 - 根据频率调整采样次数0表示丢弃""" """削峰填谷 - 根据频率调整采样次数0表示丢弃"""
# 1. 削峰处理(高频字) # 1. 削峰处理(高频字)
if freq >= self.drop_start_freq: if freq >= self.drop_start_freq:
# 线性丢弃概率计算 # 线性丢弃概率计算
max_freq = max(self.sample_freqs) # 或使用预定义的全局最大 max_freq = self.max_freq # 使用预计算的最大频率
if max_freq == self.drop_start_freq: if max_freq <= self.drop_start_freq:
drop_prob = 0.0 drop_prob = 0.0
else: else:
drop_prob = ( drop_prob = (
@ -119,7 +146,11 @@ class PinyinInputDataset(IterableDataset):
) )
# 使用泊松分布实现随机重复 # 使用泊松分布实现随机重复
repeat_count = np.random.poisson(repeat_expect) repeat_count = np.random.poisson(repeat_expect)
return max(1, repeat_count) if repeat_expect < 1.0:
# 小期望值时,以概率 repeat_expect 采样 1 次
return 1 if random.random() < repeat_expect else 0
else:
return max(1, repeat_count) # 原逻辑
# 3. 中间频率字 # 3. 中间频率字
else: else:
@ -196,6 +227,55 @@ class PinyinInputDataset(IterableDataset):
mask_pinyin.append(py) mask_pinyin.append(py)
return len(mask_pinyin), mask_pinyin return len(mask_pinyin), mask_pinyin
def _compute_pinyin_ids(self, pinyin_str: str) -> torch.Tensor:
pinyin_ids = text_to_pinyin_ids(pinyin_str)
len_py = len(pinyin_ids)
if len_py < 24:
pinyin_ids.extend([0] * (24 - len_py))
else:
pinyin_ids = pinyin_ids[:24]
return torch.tensor(pinyin_ids, dtype=torch.long)
def _add_word_samples(
self,
batch_samples: list,
labels: list,
encoded: dict,
part4: str,
part1: str,
part3: str,
pinyin_str: str,
pinyin_ids: torch.Tensor,
) -> list:
for label_idx, label in enumerate(labels):
base_repeats = self.adjust_frequency(self.sample_freqs.get(label, 0))
if base_repeats == 0:
continue
weight = (
self._history_weights[label_idx]
if label_idx < len(self._history_weights)
else 3.0
)
repeats = max(1, int(base_repeats * weight))
history = labels[:label_idx]
len_h = len(history)
history.extend([0] * (8 - len_h))
sample_dict = {
"input_ids": encoded["input_ids"],
"token_type_ids": encoded["token_type_ids"],
"attention_mask": encoded["attention_mask"],
"label": torch.tensor([label], dtype=torch.long),
"history_slot_ids": torch.tensor(history, dtype=torch.long),
"prefix": f"{part4}^{part1}",
"suffix": part3,
"pinyin": pinyin_str,
"pinyin_ids": pinyin_ids,
}
batch_samples.extend([sample_dict] * repeats)
return batch_samples
def __iter__(self): def __iter__(self):
worker_info = torch.utils.data.get_worker_info() worker_info = torch.utils.data.get_worker_info()
if worker_info is not None: if worker_info is not None:
@ -208,219 +288,219 @@ class PinyinInputDataset(IterableDataset):
random.seed(seed % (2**32)) random.seed(seed % (2**32))
np.random.seed(seed % (2**32)) np.random.seed(seed % (2**32))
# 安全检查如果worker_id >= num_workers则该worker不应该工作
# 这可能发生在self.max_workers小于实际worker数量时
if worker_id >= num_workers: if worker_id >= num_workers:
return # 产生空迭代器 return
# 使用局部变量存储分片数据集,避免竞争条件
worker_dataset = self.dataset.shard(num_shards=num_workers, index=worker_id) worker_dataset = self.dataset.shard(num_shards=num_workers, index=worker_id)
# 计算每个worker的配额
# 将 max_iter_length 转换为整数以确保整数除法
total_quota = int(self.max_iter_length) total_quota = int(self.max_iter_length)
base_quota = total_quota // num_workers base_quota = total_quota // num_workers
remainder = total_quota % num_workers remainder = total_quota % num_workers
# 最后一个worker处理剩余的样本如果有余数
if worker_id == num_workers - 1: if worker_id == num_workers - 1:
worker_quota = base_quota + remainder worker_quota = base_quota + remainder
else: else:
worker_quota = base_quota worker_quota = base_quota
else: else:
# 单worker情况使用全部配额
worker_quota = int(self.max_iter_length) worker_quota = int(self.max_iter_length)
num_workers = 1 num_workers = 1
worker_dataset = self.dataset # 不使用分片 worker_dataset = self.dataset
# 每个worker有自己的迭代计数器
current_iter_index = 0 current_iter_index = 0
batch_samples = [] batch_samples = []
for sample in worker_dataset: for sample in worker_dataset:
# 检查是否达到最大迭代次数
if current_iter_index >= worker_quota: if current_iter_index >= worker_quota:
break break
text = sample.get(self.text_field, "") text = sample.get(self.text_field, "")
if text: if not text:
pinyin_list = self.generate_pinyin(text) continue
for i in range(len(text)):
# 在开始处理每个字符前检查配额
if current_iter_index >= worker_quota:
break
labels = [] words = segment_text(text)
# 如果text[i]不在字符库中,则跳过 word_boundaries = build_word_boundaries(words)
# 当i小于48时候则将part1取text[0:i] pinyin_list = self.generate_pinyin(text)
# 当i大于48时候则将part1取text[i-48:i]
if not self.query_engine.is_chinese_char(text[i]):
continue
if i < 48:
part1 = text[0:i]
else:
part1 = text[i - 48 : i]
# 方案C提前检查从位置i开始连续有多少个字符在词库中 for word_start, word_end in word_boundaries:
max_valid_len = 0 char_positions = []
for j in range(i, min(i + 8, len(text))): for i in range(word_start, word_end):
if self.query_engine.is_chinese_char(text[j]): if self.query_engine.is_chinese_char(text[i]):
max_valid_len += 1 char_positions.append(i)
if not char_positions:
continue
word_len_chars = len(char_positions)
should_break = (
word_len_chars > 1 and random.random() < self.word_break_prob
)
if should_break:
break_pos = random.randint(1, word_len_chars - 1)
else:
break_pos = word_len_chars
# ========== Phase 1: 前缀/整词 ==========
prefix_positions = char_positions[:break_pos]
prefix_text = "".join(text[i] for i in prefix_positions)
prefix_pinyin = [pinyin_list[i] for i in prefix_positions]
_, mask_pinyin = self.get_mask_pinyin(prefix_text, prefix_pinyin)
split_char = np.random.choice(
["", "`", "'", "-"], p=[0.9, 0.04, 0.04, 0.02]
)
part2 = split_char.join(mask_pinyin)
pinyin_ids = self._compute_pinyin_ids(part2)
try:
labels = [
self.query_engine.get_char_info_by_char_pinyin(
text[i], pinyin_list[i]
).id
for i in prefix_positions
]
except AttributeError as e:
logger.error(
f"e: {e}, (text, pinyin): {prefix_text} - {prefix_pinyin}"
)
continue
# 整词末尾 10% 概率追加 EOS(破词前缀不加)
if not should_break and random.random() <= 0.1:
labels.append(0)
# part1: 词起点前的文本(所有样本共享)
part1 = text[max(0, word_start - 48) : word_start]
# part3: 词后文本
part3 = ""
if random.random() > 0.7:
part3 = text[word_end : word_end + np.random.choice(range(1, 17))]
# part4: 词提示
part4 = ""
if random.random() > 0.7:
num_words = random.randint(1, 3)
if words:
selected_words = random.sample(
words, min(num_words, len(words))
)
part4 = "|".join(selected_words)
encoded = self.tokenizer(
f"{part4}|{part1}",
part3,
max_length=self.max_seq_length,
padding="max_length",
truncation=True,
return_tensors="pt",
return_token_type_ids=True,
)
batch_samples = self._add_word_samples(
batch_samples,
labels,
encoded,
part4,
part1,
part3,
part2,
pinyin_ids,
)
# ========== Phase 2: 破词续接 ==========
if should_break and break_pos < word_len_chars:
cont_start = char_positions[break_pos]
# 续接目标:从断点开始,可延伸到后续词,遇到非汉字停止
target_len = np.random.choice(range(1, 9), p=self.cont_length_probs)
cont_positions = []
pos = cont_start
while len(cont_positions) < target_len and pos < len(text):
if self.query_engine.is_chinese_char(text[pos]):
cont_positions.append(pos)
else: else:
break break
pos += 1
# 如果没有可用字符,跳过 if not cont_positions:
if max_valid_len == 0:
continue continue
# 首先取随机值pinyin_len1-8pinyin_len取值呈高斯分布最大概率取3 cont_text = "".join(text[i] for i in cont_positions)
# 获取text[i + pinyin_len]字符如果无法获取所指向的后如果pinyin_len cont_pinyin = [pinyin_list[i] for i in cont_positions]
# part2的长度为x取pinyin_list[i:i+pinyin_len]为part2
# 但是需要注意边界条件
target_len = np.random.choice(
range(1, 9), p=[0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04]
)
# 根据实际可用长度调整
pinyin_len = min(target_len, max_valid_len)
py_end = min(i + pinyin_len, len(text)) _, mask_pinyin_cont = self.get_mask_pinyin(cont_text, cont_pinyin)
pinyin_len, part2 = self.get_mask_pinyin( split_char_cont = np.random.choice(
text[i:py_end], pinyin_list[i:py_end]
)
split_char = np.random.choice(
["", "`", "'", "-"], p=[0.9, 0.04, 0.04, 0.02] ["", "`", "'", "-"], p=[0.9, 0.04, 0.04, 0.02]
) )
part2_cont = split_char_cont.join(mask_pinyin_cont)
pinyin_ids_cont = self._compute_pinyin_ids(part2_cont)
part2 = split_char.join(part2)
pinyin_ids = text_to_pinyin_ids(part2)
len_py = len(pinyin_ids)
if len_py < 24:
pinyin_ids.extend([0] * (24 - len_py))
else:
pinyin_ids = pinyin_ids[:24]
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long)
# part3为文本大概率0.70)为空
# 不为空则是i+pinyin_len所指向的字符以及所指向字符后x个字符
# x为1-16中的任意整数取值平均分布
part3 = ""
if random.random() > 0.7:
part3 = text[
i + pinyin_len : i
+ pinyin_len
+ np.random.choice(range(1, 17))
]
# part4为文本0.30的概率为空
# 不为空则为1-5个连续字符串
# 连续字符串的取值方法为随机从字符库中取一个字符以及该字符后x个字符
# x为2-6中的任意整数取值平均分布
# 使用|将part4中的字符串连接起来
part4 = ""
if random.random() > 0.7:
# 生成1-5个连续字符串
num_strings = random.randint(1, 5)
string_list = []
for _ in range(num_strings):
# 随机选择起始位置
start_pos = random.randint(0, len(text) - 1)
# 随机选择x的值(2-6)
x = random.randint(2, 6)
# 获取连续字符串
end_pos = min(start_pos + x + 1, len(text))
string_list.append(text[start_pos:end_pos])
# 用|连接所有字符串
part4 = "|".join(string_list)
try: try:
labels = [ cont_labels = [
self.query_engine.get_char_info_by_char_pinyin(c, p).id self.query_engine.get_char_info_by_char_pinyin(
for c, p in zip( text[i], pinyin_list[i]
text[i : i + pinyin_len], ).id
pinyin_list[i : i + pinyin_len], for i in cont_positions
)
] ]
except AttributeError as e: except AttributeError as e:
logger.error( logger.error(
f"e: {e}, (text, pinyin): {text[i : i + pinyin_len]} - {pinyin_list[i : i + pinyin_len]}" f"e: {e}, (text, pinyin): {cont_text} - {cont_pinyin}"
) )
continue continue
if random.random() <= 0.1:
labels.append(0)
encoded = self.tokenizer( # 续接末尾 10% 概率追加 EOS
f"{part4}|{part1}", if random.random() <= 0.1:
part3, cont_labels.append(0)
# part1_cont: 包含已确认前缀的上下文
part1_cont = text[max(0, cont_start - 48) : cont_start]
# part3_cont: 续接目标后的文本
cont_end = cont_positions[-1] + 1
part3_cont = ""
if random.random() > 0.7:
part3_cont = text[
cont_end : cont_end + np.random.choice(range(1, 17))
]
encoded_cont = self.tokenizer(
f"{part4}|{part1_cont}",
part3_cont,
max_length=self.max_seq_length, max_length=self.max_seq_length,
padding="max_length", padding="max_length",
truncation=True, truncation=True,
return_tensors="pt", return_tensors="pt",
return_token_type_ids=True, return_token_type_ids=True,
) )
samples = []
# 历史槽位长度权重:增加长历史采样比例
# 目标分布: H=0-2占45%, H=3-8占55%
history_weights = [0.2, 0.2, 0.2, 0.9, 1.2, 1.8, 2.5, 3.5, 4.0]
# 修复变量名冲突将内层循环变量i重命名为label_idx batch_samples = self._add_word_samples(
for label_idx, label in enumerate(labels): batch_samples,
base_repeats = self.adjust_frequency(label) cont_labels,
# 根据历史槽位长度调整采样次数 encoded_cont,
weight = ( part4,
history_weights[label_idx] part1_cont,
if label_idx < len(history_weights) part3_cont,
else 3.0 part2_cont,
) pinyin_ids_cont,
repeats = max(1, int(base_repeats * weight)) )
# 历史槽位:同一拼音序列中已确认的字符(模拟用户逐步确认过程)
masked_labels = labels[:label_idx]
len_l = len(masked_labels)
masked_labels.extend([0] * (8 - len_l))
samples.extend(
[
{
"input_ids": encoded["input_ids"],
"token_type_ids": encoded["token_type_ids"],
"attention_mask": encoded["attention_mask"],
"label": torch.tensor([label], dtype=torch.long),
"history_slot_ids": torch.tensor(
masked_labels, dtype=torch.long
),
"prefix": f"{part4}^{part1}",
"suffix": part3,
"pinyin": part2,
"pinyin_ids": pinyin_ids,
}
]
* repeats
)
# 添加到缓冲区
batch_samples.extend(samples)
# 处理shuffle buffer - 单缓冲区半保留方案 # 处理shuffle buffer - 单缓冲区半保留方案
if len(batch_samples) >= self.shuffle_buffer_size: if len(batch_samples) >= self.shuffle_buffer_size:
# 全量打乱缓冲区
indices = np.random.permutation(len(batch_samples)) indices = np.random.permutation(len(batch_samples))
# 计算实际保留大小(不超过缓冲区大小)
actual_retention = min(self.retention_size, len(batch_samples)) actual_retention = min(self.retention_size, len(batch_samples))
# 计算输出数量
output_count = len(batch_samples) - actual_retention output_count = len(batch_samples) - actual_retention
# 输出前output_count个样本
for i in range(output_count): for i in range(output_count):
if current_iter_index >= worker_quota: if current_iter_index >= worker_quota:
# 配额用完,清空缓冲区并返回
batch_samples = [] batch_samples = []
return return
yield batch_samples[indices[i]] yield batch_samples[indices[i]]
current_iter_index += 1 current_iter_index += 1
# 保留后actual_retention个样本不清空缓冲区
retained_samples = [ retained_samples = [
batch_samples[idx] for idx in indices[output_count:] batch_samples[idx] for idx in indices[output_count:]
] ]
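
附:本次 dataset.py 改动的核心是按 jieba 词边界组织采样,并以 `word_break_prob` 的概率在词内"破词"。下面是一个独立的最小示意(示例文本与打印仅为演示),其中 `segment_text` 与 `build_word_boundaries` 即 diff 中新增的两个函数:

```python
import random

import jieba


def segment_text(text):
    """与 diff 中一致:jieba 分词,关闭 HMM 新词发现"""
    return list(jieba.cut(text, HMM=False))


def build_word_boundaries(words):
    """与 diff 中一致:按顺序位置累加得到 [(start, end), ...]"""
    result, pos = [], 0
    for word in words:
        result.append((pos, pos + len(word)))
        pos += len(word)
    return result


text = "今天天气很好"
words = segment_text(text)
boundaries = build_word_boundaries(words)
print(words, boundaries)

word_break_prob = 0.10  # 与 diff 中 self.word_break_prob 一致
for start, end in boundaries:
    n = end - start
    should_break = n > 1 and random.random() < word_break_prob
    # 破词时在词内随机选断点;否则整词作为前缀
    break_pos = random.randint(1, n - 1) if should_break else n
    print(text[start:end], "break_pos:", break_pos)
```

破词时前缀与断点后的续接各生成一批样本,分别对应上面的 Phase 1 / Phase 2 两个分支。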


@ -0,0 +1,169 @@
"""
ONNX导出专用组件
为了支持ONNX导出对原始组件进行修改
1. 移除packed sequence操作
2. 处理动态形状问题
3. 确保所有操作符都ONNX兼容
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class ExportPinyinLSTMEncoder(nn.Module):
"""
ONNX兼容的拼音LSTM编码器
简化版本不使用packed sequence
"""
def __init__(self, input_dim, hidden_dim=None, num_layers=2, dropout=0.2):
super().__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim if hidden_dim is not None else input_dim // 2
self.num_layers = num_layers
self.dropout = dropout
self.lstm = nn.LSTM(
input_size=input_dim,
hidden_size=self.hidden_dim,
num_layers=num_layers,
bidirectional=True,
batch_first=True,
dropout=dropout if num_layers > 1 else 0.0,
)
self.proj = nn.Linear(self.hidden_dim * 2, input_dim)
self.layer_norm = nn.LayerNorm(input_dim)
def forward(self, x, mask=None):
"""
ONNX兼容的前向传播
不使用packed sequence改为使用masked计算
"""
# 简化直接使用LSTM不处理padding
# 在ONNX中处理可变长度序列较复杂
# 对于输入法场景拼音长度固定为24所以可以这样处理
output, (hidden, cell) = self.lstm(x)
projected = self.proj(output)
return self.layer_norm(projected)
class ExportContextEncoder(nn.Module):
"""
ONNX兼容的上下文编码器
"""
def __init__(
self, vocab_size, pinyin_vocab_size, dim=512, n_layers=4, n_heads=4, max_len=128
):
super().__init__()
self.dim = dim
self.max_len = max_len
# 使用原始text_emb但需要确保它支持ONNX
from modelscope import AutoModel
self.text_emb = AutoModel.from_pretrained(
"iic/nlp_structbert_backbone_lite_std"
).embeddings
self.pinyin_emb = nn.Embedding(pinyin_vocab_size, dim)
self.pos_emb = nn.Embedding(max_len, dim)
self.pinyin_pooling = ExportPinyinLSTMEncoder(dim)
# Transformer Encoder
encoder_layer = nn.TransformerEncoderLayer(
d_model=dim,
nhead=n_heads,
dim_feedforward=dim * 4,
dropout=0.1,
batch_first=True,
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
self.ln = nn.LayerNorm(dim)
def forward(self, text_ids, pinyin_ids, mask=None):
"""
ONNX兼容的前向传播
"""
# 文本嵌入
text_emb = self.text_emb(text_ids) # [B, seq_len, dim]
# 位置编码 - 使用预计算的pos_ids
seq_len = text_emb.size(1)
# 使用静态最大长度避免动态arange
if seq_len > self.max_len:
seq_len = self.max_len
# 创建位置ID,确保在导出时是静态的
# 注意:torch.arange 在动态形状下可能有问题,上面已将 seq_len 截断到静态上限
pos_ids = torch.arange(
seq_len, device=text_ids.device, dtype=torch.long
).unsqueeze(0)
# 如果实际序列长度小于最大长度,截取位置嵌入
if seq_len < self.max_len:
x = text_emb[:, :seq_len, :] + self.pos_emb(pos_ids)
else:
x = text_emb + self.pos_emb(pos_ids)
# 处理mask
if mask is not None:
# 确保mask是bool类型
src_mask = (mask == 0).to(torch.bool)
else:
src_mask = None
# Transformer
H = self.transformer(x, src_key_padding_mask=src_mask)
H = self.ln(H)
# 拼音编码
pinyin_emb = self.pinyin_emb(pinyin_ids) # [B, pinyin_len, dim]
# 简化不传递mask给LSTM
P = self.pinyin_pooling(pinyin_emb) # [B, pinyin_len, dim]
return H, P
def create_export_context_encoder(original_context_encoder):
"""
从原始ContextEncoder创建导出版本
"""
# 获取原始配置
config = {
"vocab_size": getattr(original_context_encoder, "vocab_size", 10019),
"pinyin_vocab_size": original_context_encoder.pinyin_emb.num_embeddings,
"dim": original_context_encoder.dim,
"n_layers": original_context_encoder.transformer.num_layers,
"n_heads": original_context_encoder.transformer.layers[0].self_attn.num_heads,
"max_len": original_context_encoder.pos_emb.num_embeddings,
}
# 创建导出版本
export_encoder = ExportContextEncoder(**config)
# 复制权重
# 复制text_emb权重从原始AutoModel embeddings
# 注意这里假设原始text_emb的结构
state_dict = original_context_encoder.state_dict()
export_state_dict = export_encoder.state_dict()
# 复制匹配的权重
for key in export_state_dict:
if key in state_dict:
export_state_dict[key] = state_dict[key]
# 特殊处理复制position embeddings
if "pos_emb.weight" in export_state_dict and "pos_emb.weight" in state_dict:
# 确保大小匹配
orig_pos_emb = state_dict["pos_emb.weight"]
export_pos_emb = export_state_dict["pos_emb.weight"]
min_len = min(orig_pos_emb.size(0), export_pos_emb.size(0))
export_state_dict["pos_emb.weight"][:min_len] = orig_pos_emb[:min_len]
export_encoder.load_state_dict(export_state_dict)
return export_encoder
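
附:去掉 packed sequence 后,LSTM 分支即可直接走 `torch.onnx.export`。下面是一个最小导出示意;模块导入路径(假设为 `src.model.export_components`)、输出文件名与 opset 版本均为演示假设:

```python
import torch

# 假设:该组件位于 src/model/export_components.py(实际模块路径以仓库为准)
from src.model.export_components import ExportPinyinLSTMEncoder

encoder = ExportPinyinLSTMEncoder(input_dim=512).eval()
dummy = torch.randn(1, 24, 512)  # [batch, pinyin_len=24, dim],拼音长度在本项目中固定为24

torch.onnx.export(
    encoder,
    (dummy,),
    "pinyin_encoder_demo.onnx",
    input_names=["pinyin_emb"],
    output_names=["pinyin_P"],
    dynamic_axes={"pinyin_emb": {0: "batch"}, "pinyin_P": {0: "batch"}},
    opset_version=17,
)
print("已导出 pinyin_encoder_demo.onnx")
```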

291
src/model/export_models.py Normal file

@ -0,0 +1,291 @@
"""
ONNX导出模型定义
定义两个子模型用于ONNX导出
1. ContextEncoderExport: 输入文本和拼音输出上下文编码
2. DecoderExport: 输入上下文编码拼音编码和槽位历史输出logits
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
# 修补modelscope以支持AutoModel导入
import sys
import modelscope
if not hasattr(modelscope, "AutoModel"):
from modelscope import Model
modelscope.AutoModel = Model
sys.modules["modelscope"].AutoModel = Model
from .components import (
ContextEncoder,
CrossAttentionFusion,
MoELayer,
SlotMemory,
PinyinLSTMEncoder,
)
class ContextEncoderExport(nn.Module):
"""
上下文编码器导出模型
输入: input_ids, pinyin_ids, attention_mask
输出: context_H, pinyin_P, context_mask, pinyin_mask
"""
def __init__(self, context_encoder: ContextEncoder):
super().__init__()
self.context_encoder = context_encoder
self.dim = context_encoder.dim
# 创建ONNX兼容的拼音LSTM编码器简化版不使用packed sequence
# 复制原始LSTM的参数
original_pooling = context_encoder.pinyin_pooling
self.pinyin_lstm = nn.LSTM(
input_size=original_pooling.input_dim,
hidden_size=original_pooling.hidden_dim,
num_layers=original_pooling.num_layers,
bidirectional=True,
batch_first=True,
dropout=original_pooling.dropout
if original_pooling.num_layers > 1
else 0.0,
)
self.pinyin_proj = nn.Linear(
original_pooling.hidden_dim * 2, original_pooling.input_dim
)
self.pinyin_ln = nn.LayerNorm(original_pooling.input_dim)
# 复制权重
self.pinyin_lstm.load_state_dict(original_pooling.lstm.state_dict())
self.pinyin_proj.load_state_dict(original_pooling.proj.state_dict())
self.pinyin_ln.load_state_dict(original_pooling.layer_norm.state_dict())
def forward(self, input_ids, pinyin_ids, attention_mask):
"""
Args:
input_ids: [batch_size, seq_len]
pinyin_ids: [batch_size, pinyin_len] (pinyin_len固定为24)
attention_mask: [batch_size, seq_len]
Returns:
context_H: [batch_size, seq_len, dim]
pinyin_P: [batch_size, pinyin_len, dim]
context_mask: [batch_size, seq_len] (bool -> int32)
pinyin_mask: [batch_size, pinyin_len] (bool -> int32)
"""
# 获取原始context_encoder的组件
text_emb = self.context_encoder.text_emb
pinyin_emb = self.context_encoder.pinyin_emb
pos_emb = self.context_encoder.pos_emb
transformer = self.context_encoder.transformer
ln = self.context_encoder.ln
# 文本嵌入
text_emb_out = text_emb(input_ids) # [B, seq_len, dim]
# 位置编码 - 修复torch.arange动态形状问题
seq_len = text_emb_out.size(1)
# 使用预计算的位置ID确保静态形状
if seq_len > pos_emb.num_embeddings:
seq_len = pos_emb.num_embeddings
# 创建位置ID - 使用torch.arange但确保是整数张量
pos_ids = torch.arange(
seq_len, device=input_ids.device, dtype=torch.long
).unsqueeze(0)
# 如果实际序列长度小于最大长度,截取
if seq_len < text_emb_out.size(1):
x = text_emb_out[:, :seq_len, :] + pos_emb(pos_ids)
else:
x = text_emb_out + pos_emb(pos_ids)
# 处理mask - 确保bool类型
if attention_mask is not None:
src_mask = (attention_mask == 0).to(torch.bool)
else:
src_mask = None
# Transformer
H = transformer(x, src_key_padding_mask=src_mask)
H = ln(H)
# 恢复原始序列长度(如果需要)
if seq_len < text_emb_out.size(1):
# 填充H到原始长度
original_seq_len = text_emb_out.size(1)
padding = torch.zeros(
H.size(0),
original_seq_len - seq_len,
H.size(2),
device=H.device,
dtype=H.dtype,
)
H = torch.cat([H, padding], dim=1)
# 拼音编码 - 使用ONNX兼容版本
pinyin_emb_out = pinyin_emb(pinyin_ids) # [B, pinyin_len, dim]
# 简化的LSTM不使用packed sequence
pinyin_lstm_out, _ = self.pinyin_lstm(pinyin_emb_out)
pinyin_proj_out = self.pinyin_proj(pinyin_lstm_out)
P = self.pinyin_ln(pinyin_proj_out)
# 生成mask,转换为int32以便ONNX支持
context_mask = (attention_mask == 0).to(torch.int32) # 1表示padding,0表示有效
pinyin_mask = (pinyin_ids == 0).to(torch.int32) # 1表示padding,0表示有效
return H, P, context_mask, pinyin_mask
class DecoderExport(nn.Module):
"""
解码器导出模型
输入: context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask
输出: logits
"""
def __init__(
self,
slot_memory: SlotMemory,
cross_attn: CrossAttentionFusion,
moe: MoELayer,
slot_attention: nn.Module,
classifier: nn.Module,
num_slots: int = 8,
dim: int = 512,
):
super().__init__()
self.slot_memory = slot_memory
self.cross_attn = cross_attn
self.moe = moe
self.slot_attention = slot_attention
self.classifier = classifier
self.num_slots = num_slots
self.dim = dim
def forward(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
"""
Args:
context_H: [batch_size, seq_len, dim]
pinyin_P: [batch_size, pinyin_len, dim]
history_slot_ids: [batch_size, num_slots]
context_mask: [batch_size, seq_len] (int32, 1表示padding)
pinyin_mask: [batch_size, pinyin_len] (int32, 1表示padding)
Returns:
logits: [batch_size, vocab_size]
"""
batch_size = context_H.size(0)
# 确保history_slot_ids形状正确
history_slot_ids = history_slot_ids.view(batch_size, self.num_slots)
# 1. 槽位记忆
S = self.slot_memory(history_slot_ids) # [batch_size, num_slots, dim]
# 2. 交叉注意力融合
# 转换mask: int32 -> bool (ONNX中需要bool类型)
context_mask_bool = context_mask.to(torch.bool)
pinyin_mask_bool = pinyin_mask.to(torch.bool)
fused = self.cross_attn(
S,
context_H,
pinyin_P,
context_mask=context_mask_bool,
pinyin_mask=pinyin_mask_bool,
) # [batch_size, num_slots, dim]
# 3. MoE层
moe_out = self.moe(fused) # [batch_size, num_slots, dim]
# 4. 槽位注意力池化
slot_scores = self.slot_attention(moe_out).squeeze(
-1
) # [batch_size, num_slots]
slot_weights = torch.softmax(slot_scores, dim=1) # [batch_size, num_slots]
pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1) # [batch_size, dim]
# 5. 分类头
logits = self.classifier(pooled) # [batch_size, vocab_size]
return logits
def create_export_models_from_checkpoint(checkpoint_path, device="cpu"):
"""
从checkpoint创建导出模型
Args:
checkpoint_path: 模型checkpoint路径
device: 加载设备
Returns:
context_encoder_export: ContextEncoderExport实例
decoder_export: DecoderExport实例
model_config: 模型配置字典
"""
# 加载原始模型配置
checkpoint = torch.load(checkpoint_path, map_location=device)
# 提取模型配置从checkpoint或使用默认值
if "config" in checkpoint:
config = checkpoint["config"]
else:
# 使用模型默认配置
config = {
"vocab_size": 10019,
"pinyin_vocab_size": 30,
"dim": 512,
"num_slots": 8,
"n_layers": 4,
"n_heads": 4,
"num_experts": 10,
"max_seq_len": 128,
}
# 创建原始模型
from .model import InputMethodEngine
model = InputMethodEngine(
vocab_size=config.get("vocab_size", 10019),
pinyin_vocab_size=config.get("pinyin_vocab_size", 30),
dim=config.get("dim", 512),
num_slots=config.get("num_slots", 8),
n_layers=config.get("n_layers", 4),
n_heads=config.get("n_heads", 4),
num_experts=config.get("num_experts", 10),
max_seq_len=config.get("max_seq_len", 128),
compile=False,
)
# 加载权重
if "model_state_dict" in checkpoint:
model.load_state_dict(checkpoint["model_state_dict"])
else:
model.load_state_dict(checkpoint)
model.eval()
model.to(device)
# 创建导出模型
context_encoder_export = ContextEncoderExport(model.context_encoder)
decoder_export = DecoderExport(
slot_memory=model.slot_memory,
cross_attn=model.cross_attn,
moe=model.moe,
slot_attention=model.slot_attention,
classifier=model.classifier,
num_slots=model.num_slots,
dim=model.dim,
)
# 设置为评估模式
context_encoder_export.eval()
decoder_export.eval()
return context_encoder_export, decoder_export, config
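
附:一个调用 `create_export_models_from_checkpoint` 并分别导出两个子模型的最小示意(checkpoint 路径、输出目录与 opset 版本均为演示假设;输入/输出名与上面两个导出模型的 forward 签名一致):

```python
import os

import torch

from src.model.export_models import create_export_models_from_checkpoint

# checkpoint 路径为演示假设
ctx_export, dec_export, config = create_export_models_from_checkpoint(
    "checkpoints/best_model.pt", device="cpu"
)

B = 1
S = config.get("max_seq_len", 128)
P = 24
D = config.get("dim", 512)
N = config.get("num_slots", 8)
os.makedirs("exported_models", exist_ok=True)

# 1) 上下文编码器:input_ids / pinyin_ids / attention_mask -> H, P, masks
torch.onnx.export(
    ctx_export,
    (
        torch.randint(0, 100, (B, S)),
        torch.randint(0, 30, (B, P)),
        torch.ones(B, S, dtype=torch.long),
    ),
    "exported_models/context_encoder.onnx",
    input_names=["input_ids", "pinyin_ids", "attention_mask"],
    output_names=["context_H", "pinyin_P", "context_mask", "pinyin_mask"],
    opset_version=17,
)

# 2) 解码器:编码结果 + 槽位历史 -> logits
torch.onnx.export(
    dec_export,
    (
        torch.randn(B, S, D),
        torch.randn(B, P, D),
        torch.zeros(B, N, dtype=torch.long),
        torch.zeros(B, S, dtype=torch.int32),
        torch.zeros(B, P, dtype=torch.int32),
    ),
    "exported_models/decoder.onnx",
    input_names=["context_H", "pinyin_P", "history_slot_ids", "context_mask", "pinyin_mask"],
    output_names=["logits"],
    opset_version=17,
)
```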


@ -72,7 +72,7 @@ class InputMethodEngine(nn.Module):
self.cross_attn = CrossAttentionFusion(dim=dim, n_heads=n_heads) self.cross_attn = CrossAttentionFusion(dim=dim, n_heads=n_heads)
# 4. 混合专家层 (MoE) # 4. 混合专家层 (MoE)
self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=3, num_resblocks=8) self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=3, num_resblocks=12)
# 5. 槽位注意力池化 # 5. 槽位注意力池化
self.slot_attention = nn.Linear(dim, 1) self.slot_attention = nn.Linear(dim, 1)


@ -1221,7 +1221,7 @@ def train(
max_seq_length=max_seq_len, max_seq_length=max_seq_len,
text_field="text", text_field="text",
py_style_weight=(9, 2, 1), py_style_weight=(9, 2, 1),
shuffle_buffer_size=100000, shuffle_buffer_size=2000000,
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
) )

24
test.py

@ -47,8 +47,8 @@ def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
return [CHAR_TO_ID.get(c, 0) for c in pinyin_str] return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]
part1 = "杉杉看了柳柳一眼,默默地同情了一下。她这个堂姐长得非常" part1 = "从中国回家后,他觉得世界上最好的城市就是"
part2 = "piaoliang" part2 = "shanghai"
pinyin_ids = text_to_pinyin_ids(part2) pinyin_ids = text_to_pinyin_ids(part2)
len_py = len(pinyin_ids) len_py = len(pinyin_ids)
if len_py < 24: if len_py < 24:
@ -56,9 +56,9 @@ if len_py < 24:
else: else:
pinyin_ids = pinyin_ids[:24] pinyin_ids = pinyin_ids[:24]
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0) pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
masked_labels = [1986, 0, 0, 0, 0, 0, 0, 0] masked_labels = [22, 0, 0, 0, 0, 0, 0, 0]
part3 = "" part3 = ""
part4 = "可行|特别|伤害" part4 = ""
encoded = tokenizer( encoded = tokenizer(
f"{part4}|{part1}", f"{part4}|{part1}",
@ -83,8 +83,9 @@ sample = {
model = InputMethodEngine(pinyin_vocab_size=30, compile=False) model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
checkpoint = torch.load("/home/songsenand/下载/best_model.ptrom", map_location="cpu") checkpoint = torch.load("/home/songsenand/下载/20260412epoch2.ptrom", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"]) model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
input_ids = sample["input_ids"] input_ids = sample["input_ids"]
token_type_ids = sample["token_type_ids"] token_type_ids = sample["token_type_ids"]
@ -97,12 +98,13 @@ for k, v in sample.items():
print(f"{k}: {v}") print(f"{k}: {v}")
start = time.time() start = time.time()
res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids) with torch.no_grad():
print(f'计算时长: {(time.time() - start) * 1000:4f}ms') res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
sort_res = sorted( print(f'计算时长: {(time.time() - start) * 1000:4f}ms')
[(i, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True sort_res = sorted(
) [(i, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True
print(sort_res[0:5]) )
print(sort_res[0:5])
query_engine = QueryEngine() query_engine = QueryEngine()
query_engine.load() query_engine.load()
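
附:test.py 打印的是 (id, logit) 对,若要还原为汉字,可按 onnx_inference.py 中 `query_by_id` 的用法处理;以下为独立示意(id 取自 test.py 中出现过的值,仅作演示):

```python
import sys

sys.path.append("src")

from model.query import QueryEngine

# query_by_id 返回带 .char 属性的对象,与 onnx_inference.py 中的用法一致
query_engine = QueryEngine()
query_engine.load()

for idx in (22, 1986):
    info = query_engine.query_by_id(idx)
    char = info.char if info else "[ID:%d]" % idx
    print(f"id={idx} -> {char}")
```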

101
test_dataset.py Normal file

@ -0,0 +1,101 @@
import sys

sys.path.append("src")

import random
from typing import Any, Dict, List

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from model.dataset import PinyinInputDataset
def worker_init_fn(worker_id: int) -> None:
"""
初始化每个DataLoader worker的随机种子确保可复现性
Args:
worker_id: worker的ID
"""
worker_seed = torch.initial_seed() % (2**32)
np.random.seed(worker_seed)
random.seed(worker_seed)
def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
自定义批处理函数将多个样本组合成一个batch
Args:
batch: 样本列表每个样本是一个字典
Returns:
批处理后的字典tensor字段已stack字符串字段保持为列表
"""
# 处理tensor字段 - 使用squeeze去除多余的batch维度
input_ids = torch.stack([item["input_ids"].squeeze(0) for item in batch])
token_type_ids = torch.stack([item["token_type_ids"].squeeze(0) for item in batch])
attention_mask = torch.stack([item["attention_mask"].squeeze(0) for item in batch])
labels = torch.stack([item["label"].squeeze(0) for item in batch])
history_slot_ids = torch.stack([item["history_slot_ids"] for item in batch])
pinyin_ids = torch.stack([item["pinyin_ids"] for item in batch])
# 字符串字段保持为列表
prefixes = [item["prefix"] for item in batch]
suffixes = [item["suffix"] for item in batch]
pinyins = [item["pinyin"] for item in batch]
return {
"input_ids": input_ids,
"token_type_ids": token_type_ids,
"attention_mask": attention_mask,
"labels": labels,
"history_slot_ids": history_slot_ids,
"prefix": prefixes,
"suffix": suffixes,
"pinyin": pinyins,
"pinyin_ids": pinyin_ids,
}
train_dataset = PinyinInputDataset(
data_path="/home/songsenand/Data/corpus/CCI-Data/",
max_workers=-1, # 自动选择worker数量
max_iter_length=1000000,
text_field="text",
py_style_weight=(90, 2, 1),
shuffle_buffer_size=20000,
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
)
dataloader = DataLoader(
train_dataset,
batch_size=512,
num_workers=2,
worker_init_fn=worker_init_fn,
collate_fn=collate_fn,
prefetch_factor=2, # 减少预取以避免内存问题
persistent_workers=True,
)
for i, batch in tqdm(enumerate(dataloader), total=1_000_000 // 512):
pass
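
附:上面的循环只是把数据管道跑空;若要粗略测量吞吐,可以改为如下计时示意(沿用上文的 dataloader,`batch["labels"]` 形状为 [batch_size]):

```python
import time

start = time.perf_counter()
n_samples = 0
for i, batch in enumerate(dataloader):
    n_samples += batch["labels"].size(0)
    if i >= 100:  # 只统计前 100 个 batch
        break
elapsed = time.perf_counter() - start
print(f"约 {n_samples / elapsed:.1f} samples/s")
```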

409
verify_onnx.py Normal file

@ -0,0 +1,409 @@
#!/usr/bin/env python3
"""
ONNX模型验证脚本
验证导出的ONNX模型与原始PyTorch模型输出的一致性
"""
import argparse
import os
import sys
from pathlib import Path
import numpy as np
import torch
import onnxruntime as ort
# 添加src目录到路径
sys.path.insert(0, str(Path(__file__).parent))
from src.model.export_models import create_export_models_from_checkpoint
def compare_outputs(pytorch_output, onnx_output, name="output", rtol=1e-3, atol=1e-5):
"""
比较PyTorch和ONNX输出
Args:
pytorch_output: PyTorch张量
onnx_output: ONNX Runtime输出numpy数组
name: 输出名称用于错误信息
rtol: 相对容差
atol: 绝对容差
Returns:
bool: 是否匹配
"""
# 转换PyTorch输出为numpy
if isinstance(pytorch_output, torch.Tensor):
pytorch_np = pytorch_output.detach().cpu().numpy()
else:
pytorch_np = np.array(pytorch_output)
# 确保形状一致
if pytorch_np.shape != onnx_output.shape:
print(
f"{name} 形状不匹配: PyTorch {pytorch_np.shape} != ONNX {onnx_output.shape}"
)
return False
# 计算差异
diff = np.abs(pytorch_np - onnx_output)
max_diff = np.max(diff)
mean_diff = np.mean(diff)
# 检查是否在容差范围内
is_close = np.allclose(pytorch_np, onnx_output, rtol=rtol, atol=atol)
if is_close:
print(f"{name} 匹配: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
else:
print(f"{name} 不匹配: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
print(
f" 范围: PyTorch [{pytorch_np.min():.6f}, {pytorch_np.max():.6f}], "
f"ONNX [{onnx_output.min():.6f}, {onnx_output.max():.6f}]"
)
return is_close
def verify_context_encoder(checkpoint_path, onnx_path, device="cpu"):
"""
验证上下文编码器
Args:
checkpoint_path: PyTorch checkpoint路径
onnx_path: ONNX模型路径
device: 设备
Returns:
bool: 验证是否通过
"""
print(f"\n🔍 验证上下文编码器: {onnx_path}")
# 加载PyTorch模型
context_encoder_export, _, config = create_export_models_from_checkpoint(
checkpoint_path, device
)
# 创建ONNX Runtime会话
session = ort.InferenceSession(
onnx_path,
providers=[
"CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider"
],
)
# 创建测试输入
batch_size = 2 # 使用batch_size=2测试动态批处理
seq_len = config.get("max_seq_len", 128)
pinyin_len = 24
input_ids = torch.randint(0, 100, (batch_size, seq_len), dtype=torch.long)
pinyin_ids = torch.randint(0, 30, (batch_size, pinyin_len), dtype=torch.long)
attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)
# 随机屏蔽一些位置
attention_mask[:, seq_len // 2 :] = 0
# PyTorch推理
with torch.no_grad():
pytorch_outputs = context_encoder_export(input_ids, pinyin_ids, attention_mask)
# ONNX推理
onnx_inputs = {
"input_ids": input_ids.numpy(),
"pinyin_ids": pinyin_ids.numpy(),
"attention_mask": attention_mask.numpy(),
}
onnx_outputs = session.run(None, onnx_inputs)
# 比较输出
output_names = ["context_H", "pinyin_P", "context_mask", "pinyin_mask"]
all_match = True
for i, name in enumerate(output_names):
if i < len(pytorch_outputs) and i < len(onnx_outputs):
match = compare_outputs(pytorch_outputs[i], onnx_outputs[i], name)
all_match = all_match and match
return all_match
def verify_decoder(checkpoint_path, onnx_path, device="cpu"):
"""
验证解码器
Args:
checkpoint_path: PyTorch checkpoint路径
onnx_path: ONNX模型路径
device: 设备
Returns:
bool: 验证是否通过
"""
print(f"\n🔍 验证解码器: {onnx_path}")
# 加载PyTorch模型
_, decoder_export, config = create_export_models_from_checkpoint(
checkpoint_path, device
)
# 创建ONNX Runtime会话
session = ort.InferenceSession(
onnx_path,
providers=[
"CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider"
],
)
# 创建测试输入
batch_size = 2
seq_len = config.get("max_seq_len", 128)
pinyin_len = 24
dim = config.get("dim", 512)
num_slots = config.get("num_slots", 8)
context_H = torch.randn(batch_size, seq_len, dim, dtype=torch.float32)
pinyin_P = torch.randn(batch_size, pinyin_len, dim, dtype=torch.float32)
context_mask = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int32)
pinyin_mask = torch.randint(0, 2, (batch_size, pinyin_len), dtype=torch.int32)
history_slot_ids = torch.randint(0, 100, (batch_size, num_slots), dtype=torch.long)
# PyTorch推理
with torch.no_grad():
pytorch_output = decoder_export(
context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask
)
# ONNX推理
onnx_inputs = {
"context_H": context_H.numpy(),
"pinyin_P": pinyin_P.numpy(),
"history_slot_ids": history_slot_ids.numpy(),
"context_mask": context_mask.numpy(),
"pinyin_mask": pinyin_mask.numpy(),
}
onnx_outputs = session.run(None, onnx_inputs)
# 比较输出
return compare_outputs(pytorch_output, onnx_outputs[0], "logits")


def verify_end_to_end(
    checkpoint_path, context_encoder_path, decoder_path, device="cpu"
):
    """
    End-to-end verification: compare the full inference pipeline.

    Args:
        checkpoint_path: path to the PyTorch checkpoint
        context_encoder_path: path to the context-encoder ONNX model
        decoder_path: path to the decoder ONNX model
        device: inference device

    Returns:
        bool: True if verification passed
    """
    print("\n🔍 End-to-end verification")

    # Load the original PyTorch model
    from src.model.model import InputMethodEngine

    checkpoint = torch.load(checkpoint_path, map_location=device)
    if "config" in checkpoint:
        config = checkpoint["config"]
    else:
        config = {
            "vocab_size": 10019,
            "pinyin_vocab_size": 30,
            "dim": 512,
            "num_slots": 8,
            "n_layers": 4,
            "n_heads": 4,
            "num_experts": 10,
            "max_seq_len": 128,
        }
    model = InputMethodEngine(
        vocab_size=config.get("vocab_size", 10019),
        pinyin_vocab_size=config.get("pinyin_vocab_size", 30),
        dim=config.get("dim", 512),
        num_slots=config.get("num_slots", 8),
        n_layers=config.get("n_layers", 4),
        n_heads=config.get("n_heads", 4),
        num_experts=config.get("num_experts", 10),
        max_seq_len=config.get("max_seq_len", 128),
        compile=False,
    )
    if "model_state_dict" in checkpoint:
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        model.load_state_dict(checkpoint)
    model.eval()
    model.to(device)
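    # eval() disables dropout and other train-time behavior, so the PyTorch
    # outputs are deterministic and directly comparable with ONNX Runtime.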

    # Create ONNX Runtime sessions
    context_session = ort.InferenceSession(
        context_encoder_path,
        providers=[
            "CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider"
        ],
    )
    decoder_session = ort.InferenceSession(
        decoder_path,
        providers=[
            "CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider"
        ],
    )

    # Create test inputs
    batch_size = 1
    seq_len = config.get("max_seq_len", 128)
    pinyin_len = 24
    input_ids = torch.randint(0, 100, (batch_size, seq_len), dtype=torch.long)
    pinyin_ids = torch.randint(0, 30, (batch_size, pinyin_len), dtype=torch.long)
    attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)
    history_slot_ids = torch.randint(0, 100, (batch_size, 8), dtype=torch.long)

    # Full PyTorch inference
    with torch.no_grad():
        pytorch_logits = model(
            input_ids=input_ids,
            token_type_ids=torch.zeros_like(input_ids),  # simplified: all zeros
            attention_mask=attention_mask,
            pinyin_ids=pinyin_ids,
            history_slot_ids=history_slot_ids,
        )

    # ONNX inference pipeline
    # 1. Context encoder
    context_inputs = {
        "input_ids": input_ids.numpy(),
        "pinyin_ids": pinyin_ids.numpy(),
        "attention_mask": attention_mask.numpy(),
    }
    context_outputs = context_session.run(None, context_inputs)
    context_H, pinyin_P, context_mask, pinyin_mask = context_outputs
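    # The encoder outputs are already numpy arrays, so they can be fed straight
    # into the decoder session; this mirrors the single PyTorch forward pass
    # above, split across the two exported models.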

    # 2. Decoder
    decoder_inputs = {
        "context_H": context_H,
        "pinyin_P": pinyin_P,
        "history_slot_ids": history_slot_ids.numpy(),
        "context_mask": context_mask,
        "pinyin_mask": pinyin_mask,
    }
    onnx_outputs = decoder_session.run(None, decoder_inputs)

    # Compare outputs
    return compare_outputs(pytorch_logits, onnx_outputs[0], "end_to_end_logits")


def main():
    parser = argparse.ArgumentParser(description="ONNX model verification")
    parser.add_argument(
        "--checkpoint", "-c", type=str, required=True, help="Path to the PyTorch checkpoint"
    )
    parser.add_argument(
        "--context-encoder",
        type=str,
        help="Context-encoder ONNX path (default: ./exported_models/context_encoder.onnx)",
    )
    parser.add_argument(
        "--decoder",
        type=str,
        help="Decoder ONNX path (default: ./exported_models/decoder.onnx)",
    )
    parser.add_argument(
        "--output-dir",
        "-o",
        type=str,
        default="./exported_models",
        help="Export directory (used when individual model paths are not given)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="Verification device (default: cpu)",
    )
    parser.add_argument(
        "--skip-context", action="store_true", help="Skip context-encoder verification"
    )
    parser.add_argument("--skip-decoder", action="store_true", help="Skip decoder verification")
    parser.add_argument("--skip-end-to-end", action="store_true", help="Skip end-to-end verification")
    args = parser.parse_args()

    # Resolve model paths
    if args.context_encoder:
        context_encoder_path = args.context_encoder
    else:
        context_encoder_path = os.path.join(args.output_dir, "context_encoder.onnx")
    if args.decoder:
        decoder_path = args.decoder
    else:
        decoder_path = os.path.join(args.output_dir, "decoder.onnx")

    print("🔬 ONNX model verification")
    print("=" * 60)
    print(f"Checkpoint: {args.checkpoint}")
    print(f"Context encoder: {context_encoder_path}")
    print(f"Decoder: {decoder_path}")
    print(f"Device: {args.device}")
    print()

    all_pass = True

    # Verify the context encoder
    if not args.skip_context and os.path.exists(context_encoder_path):
        if verify_context_encoder(args.checkpoint, context_encoder_path, args.device):
            print("✅ Context encoder verification passed")
        else:
            print("❌ Context encoder verification failed")
            all_pass = False
    elif not args.skip_context:
        print("⚠️ Context encoder file not found; skipping verification")

    # Verify the decoder
    if not args.skip_decoder and os.path.exists(decoder_path):
        if verify_decoder(args.checkpoint, decoder_path, args.device):
            print("✅ Decoder verification passed")
        else:
            print("❌ Decoder verification failed")
            all_pass = False
    elif not args.skip_decoder:
        print("⚠️ Decoder file not found; skipping verification")

    # End-to-end verification
    if (
        not args.skip_end_to_end
        and os.path.exists(context_encoder_path)
        and os.path.exists(decoder_path)
    ):
        if verify_end_to_end(
            args.checkpoint, context_encoder_path, decoder_path, args.device
        ):
            print("✅ End-to-end verification passed")
        else:
            print("❌ End-to-end verification failed")
            all_pass = False

    print("\n" + "=" * 60)
    if all_pass:
        print("🎉 All checks passed: ONNX model outputs match the PyTorch model!")
    else:
        print("❌ Some checks failed; please review the model export process")
        sys.exit(1)


if __name__ == "__main__":
    main()