Compare commits
No commits in common. "numresblock18" and "main" have entirely different histories.
numresbloc
...
main
|
|
@ -177,9 +177,3 @@ cython_debug/
|
|||
uv.lock
|
||||
|
||||
data/*
|
||||
|
||||
|
||||
**/*.onnx
|
||||
**/*.data
|
||||
**/*.npz
|
||||
**/*.pt
|
||||
|
|
|
|||
105
README.md
105
README.md
|
|
@ -790,7 +790,112 @@ train-model evaluate \
|
|||
- 在评估数据集上计算准确率、困惑度等指标
|
||||
- 生成详细的性能报告
|
||||
|
||||
### 6.8 模型扩容两阶段训练
|
||||
|
||||
当需要增加模型容量(如增加专家数量、修改层结构等)时,可以使用 `expand-and-train` 命令进行两阶段训练:先冻结匹配层训练新增参数,然后全量微调。
|
||||
|
||||
#### 训练策略
|
||||
|
||||
1. **冻结阶段**:只训练形状不匹配的新增参数(如新增的专家、扩容的层等)
|
||||
2. **全量微调阶段**:当验证损失连续 `--frozen-patience` 次不下降时,自动解冻所有层进行全量训练
|
||||
|
||||
#### 基础用法
|
||||
|
||||
```bash
|
||||
train-model expand-and-train \
|
||||
--train-data-path "path/to/train/dataset" \
|
||||
--eval-data-path "path/to/eval/dataset" \
|
||||
--base-model-path "./pretrained/model.pt" \
|
||||
--new-model-spec "model:InputMethodEngine" \
|
||||
--num-experts 40 \
|
||||
--frozen-lr 2e-3 \
|
||||
--full-lr 5e-5 \
|
||||
--frozen-patience 8
|
||||
```
|
||||
|
||||
#### 完整参数示例
|
||||
|
||||
```bash
|
||||
train-model expand-and-train \
|
||||
--train-data-path "path/to/train/dataset" \
|
||||
--eval-data-path "path/to/eval/dataset" \
|
||||
--output-dir "./expansion_output" \
|
||||
--base-model-path "./pretrained/model.pt" \
|
||||
--new-model-spec "custom_model:ExpandedModel" \
|
||||
--vocab-size 10019 \
|
||||
--dim 512 \
|
||||
--num-experts 40 \
|
||||
--frozen-patience 10 \
|
||||
--frozen-lr 1e-3 \
|
||||
--full-lr 1e-4 \
|
||||
--frozen-scheduler cosine \
|
||||
--full-scheduler cosine \
|
||||
--batch-size 128 \
|
||||
--num-epochs 20 \
|
||||
--compile
|
||||
```
|
||||
|
||||
#### 参数详解
|
||||
|
||||
**模型扩容参数**
|
||||
- `--base-model-path`: 预训练基础模型检查点路径(必需)
|
||||
- `--new-model-spec`: 新模型规格,格式:`模块名:类名`,如 `model:InputMethodEngine`(必需)
|
||||
- 支持任意路径的模块导入,模块文件需包含自定义的模型类
|
||||
- 自定义模型类必须是 `InputMethodEngine` 的子类
|
||||
- 示例:`my_model:MyExpandedModel` 对应 `my_model.py` 中的 `MyExpandedModel` 类
|
||||
|
||||
**两阶段训练参数**
|
||||
- `--frozen-patience`: 冻结阶段验证损失连续不下降的评估次数,触发切换到全量微调(默认:10)
|
||||
- `--frozen-lr`: 冻结阶段学习率(默认:1e-3)
|
||||
- `--full-lr`: 全量微调阶段学习率(默认:1e-4)
|
||||
- `--frozen-scheduler`: 冻结阶段学习率调度器,可选 `cosine` 或 `plateau`(默认:`cosine`)
|
||||
- `--full-scheduler`: 全量微调阶段学习率调度器,可选 `cosine` 或 `plateau`(默认:`cosine`)
|
||||
|
||||
**其他参数**
|
||||
- 支持所有 `train` 子命令的通用参数(数据参数、模型参数、训练参数等)
|
||||
- 继承现有的训练基础设施:混合精度训练、TensorBoard日志、checkpoint保存等
|
||||
|
||||
#### 使用场景
|
||||
|
||||
1. **增加专家数量**(20→40)
|
||||
- 冻结效果:~70% 参数可冻结(已有专家权重、注意力层等)
|
||||
- 新增参数:新专家网络、gate层
|
||||
|
||||
2. **增加top_k值**(2→3)
|
||||
- 冻结效果:100% 参数可冻结(仅逻辑变化)
|
||||
- 新增参数:无
|
||||
|
||||
3. **修改专家内部结构**(如增加resblocks)
|
||||
- 冻结效果:~50% 参数可冻结(linear_in/output可冻结)
|
||||
- 新增参数:新增的resblocks层
|
||||
|
||||
4. **增加Transformer层数**(4→5)
|
||||
- 冻结效果:~80% 参数可冻结(前4层可冻结)
|
||||
- 新增参数:新增的第5层
|
||||
|
||||
#### 自定义模型类示例
|
||||
|
||||
```python
|
||||
# my_model.py
|
||||
from model.model import InputMethodEngine
|
||||
|
||||
class MyExpandedModel(InputMethodEngine):
|
||||
def __init__(self, num_experts=40, **kwargs):
|
||||
# 调用父类构造函数,覆盖num_experts参数
|
||||
super().__init__(num_experts=num_experts, **kwargs)
|
||||
# 可以在这里添加额外的层或修改现有层
|
||||
|
||||
# 使用命令
|
||||
# train-model expand-and-train --new-model-spec "my_model:MyExpandedModel" ...
|
||||
```
|
||||
|
||||
#### 注意事项
|
||||
|
||||
1. **模型类要求**:自定义模型类必须是 `InputMethodEngine` 的子类
|
||||
2. **冻结条件**:只有权重形状完全匹配的层才会被冻结
|
||||
3. **性能保持**:MoE层保持"计算所有专家+Top-K选择"方案,确保 `torch.compile` 下的最佳性能
|
||||
4. **阶段切换**:基于评估频率而非epoch,建议适当调高 `--eval-frequency`
|
||||
5. **模块导入**:支持任意路径的模块,通过Python标准导入机制加载
|
||||
|
||||
### 6.9 导出模型(开发中)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,436 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
束搜索算法演示
|
||||
|
||||
展示如何使用导出的两个ONNX模型进行束搜索推理,
|
||||
模拟输入法场景:给定上下文、拼音和已确认字符,生成候选汉字序列。
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Dict, Any
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
# 添加src目录到路径
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
|
||||
class ONNXBeamSearch:
|
||||
"""
|
||||
基于ONNX模型的束搜索解码器
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, context_encoder_path: str, decoder_path: str, device: str = "cpu"
|
||||
):
|
||||
"""
|
||||
初始化
|
||||
|
||||
Args:
|
||||
context_encoder_path: 上下文编码器ONNX模型路径
|
||||
decoder_path: 解码器ONNX模型路径
|
||||
device: 推理设备(cpu或cuda)
|
||||
"""
|
||||
self.device = device
|
||||
|
||||
# 配置ONNX Runtime提供程序
|
||||
if device == "cuda":
|
||||
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
|
||||
else:
|
||||
providers = ["CPUExecutionProvider"]
|
||||
|
||||
# 创建ONNX Runtime会话
|
||||
self.context_session = ort.InferenceSession(
|
||||
context_encoder_path, providers=providers
|
||||
)
|
||||
self.decoder_session = ort.InferenceSession(decoder_path, providers=providers)
|
||||
|
||||
# 获取输入输出信息
|
||||
self.context_input_info = {
|
||||
input.name: input for input in self.context_session.get_inputs()
|
||||
}
|
||||
self.context_output_info = {
|
||||
output.name: output for output in self.context_session.get_outputs()
|
||||
}
|
||||
self.decoder_input_info = {
|
||||
input.name: input for input in self.decoder_session.get_inputs()
|
||||
}
|
||||
self.decoder_output_info = {
|
||||
output.name: output for output in self.decoder_session.get_outputs()
|
||||
}
|
||||
|
||||
print("📦 ONNX模型加载完成")
|
||||
print(f" 上下文编码器输入: {list(self.context_input_info.keys())}")
|
||||
print(f" 上下文编码器输出: {list(self.context_output_info.keys())}")
|
||||
print(f" 解码器输入: {list(self.decoder_input_info.keys())}")
|
||||
print(f" 解码器输出: {list(self.decoder_output_info.keys())}")
|
||||
|
||||
def prepare_inputs(
|
||||
self,
|
||||
text_before: str,
|
||||
text_after: str,
|
||||
pinyin: str,
|
||||
context_prompts: List[str] = None,
|
||||
tokenizer=None,
|
||||
query_engine=None,
|
||||
max_seq_len: int = 128,
|
||||
) -> Dict[str, np.ndarray]:
|
||||
"""
|
||||
准备模型输入
|
||||
|
||||
注意: 这是一个简化的版本,实际应用中需要实现完整的预处理逻辑
|
||||
这里使用随机数据模拟,实际应使用tokenizer和查询引擎
|
||||
|
||||
Args:
|
||||
text_before: 光标前文本
|
||||
text_after: 光标后文本
|
||||
pinyin: 拼音输入
|
||||
context_prompts: 上下文提示
|
||||
tokenizer: 分词器(未实现)
|
||||
query_engine: 查询引擎(未实现)
|
||||
max_seq_len: 最大序列长度
|
||||
|
||||
Returns:
|
||||
输入字典,包含input_ids, pinyin_ids, attention_mask
|
||||
"""
|
||||
# 在实际应用中,这里应该实现:
|
||||
# 1. 使用tokenizer将文本转换为input_ids
|
||||
# 2. 使用text_to_pinyin_ids将拼音转换为pinyin_ids
|
||||
# 3. 构建attention_mask
|
||||
|
||||
# 简化为随机数据(用于演示)
|
||||
batch_size = 1
|
||||
seq_len = min(max_seq_len, 64) # 简化长度
|
||||
|
||||
# 模拟输入
|
||||
input_ids = np.random.randint(0, 1000, (batch_size, seq_len), dtype=np.int64)
|
||||
pinyin_ids = np.random.randint(
|
||||
0, 30, (batch_size, 24), dtype=np.int64
|
||||
) # 固定24长度
|
||||
attention_mask = np.ones((batch_size, seq_len), dtype=np.int64)
|
||||
|
||||
# 在实际应用中,应该这样处理拼音:
|
||||
# from src.model.dataset import text_to_pinyin_ids
|
||||
# pinyin_ids_list = text_to_pinyin_ids(pinyin)
|
||||
# pinyin_ids_array = np.array(pinyin_ids_list, dtype=np.int64).reshape(1, -1)
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"pinyin_ids": pinyin_ids,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
|
||||
def run_context_encoder(
|
||||
self, inputs: Dict[str, np.ndarray]
|
||||
) -> Tuple[np.ndarray, ...]:
|
||||
"""
|
||||
运行上下文编码器
|
||||
|
||||
Args:
|
||||
inputs: 输入字典
|
||||
|
||||
Returns:
|
||||
上下文编码器输出: (context_H, pinyin_P, context_mask, pinyin_mask)
|
||||
"""
|
||||
# 准备ONNX输入(确保顺序正确)
|
||||
onnx_inputs = {}
|
||||
for input_name in self.context_input_info.keys():
|
||||
if input_name in inputs:
|
||||
onnx_inputs[input_name] = inputs[input_name]
|
||||
else:
|
||||
# 对于缺失的输入,使用默认值
|
||||
shape = self.context_input_info[input_name].shape
|
||||
dtype = self.context_input_info[input_name].type
|
||||
# 简化处理:创建零数组
|
||||
if "int" in dtype:
|
||||
onnx_inputs[input_name] = np.zeros(shape, dtype=np.int64)
|
||||
else:
|
||||
onnx_inputs[input_name] = np.zeros(shape, dtype=np.float32)
|
||||
|
||||
# 运行推理
|
||||
outputs = self.context_session.run(None, onnx_inputs)
|
||||
|
||||
return tuple(outputs)
|
||||
|
||||
def run_decoder(
|
||||
self,
|
||||
context_H: np.ndarray,
|
||||
pinyin_P: np.ndarray,
|
||||
history_slot_ids: np.ndarray,
|
||||
context_mask: np.ndarray,
|
||||
pinyin_mask: np.ndarray,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
运行解码器
|
||||
|
||||
Args:
|
||||
context_H: 上下文编码 [batch, seq_len, 512]
|
||||
pinyin_P: 拼音编码 [batch, 24, 512]
|
||||
history_slot_ids: 历史槽位IDs [batch, 8]
|
||||
context_mask: 上下文掩码 [batch, seq_len]
|
||||
pinyin_mask: 拼音掩码 [batch, 24]
|
||||
|
||||
Returns:
|
||||
logits [batch, vocab_size]
|
||||
"""
|
||||
inputs = {
|
||||
"context_H": context_H,
|
||||
"pinyin_P": pinyin_P,
|
||||
"history_slot_ids": history_slot_ids,
|
||||
"context_mask": context_mask,
|
||||
"pinyin_mask": pinyin_mask,
|
||||
}
|
||||
|
||||
# 运行推理
|
||||
outputs = self.decoder_session.run(None, inputs)
|
||||
|
||||
return outputs[0]
|
||||
|
||||
def beam_search(
|
||||
self,
|
||||
context_H: np.ndarray,
|
||||
pinyin_P: np.ndarray,
|
||||
context_mask: np.ndarray,
|
||||
pinyin_mask: np.ndarray,
|
||||
beam_size: int = 5,
|
||||
max_length: int = 8,
|
||||
vocab_size: int = 10019,
|
||||
temperature: float = 1.0,
|
||||
length_penalty: float = 0.6,
|
||||
) -> List[Tuple[List[int], float]]:
|
||||
"""
|
||||
束搜索算法
|
||||
|
||||
Args:
|
||||
context_H: 上下文编码
|
||||
pinyin_P: 拼音编码
|
||||
context_mask: 上下文掩码
|
||||
pinyin_mask: 拼音掩码
|
||||
beam_size: 束大小
|
||||
max_length: 最大生成长度
|
||||
vocab_size: 词汇表大小
|
||||
temperature: 温度参数(控制随机性)
|
||||
length_penalty: 长度惩罚系数
|
||||
|
||||
Returns:
|
||||
排序后的候选序列列表,每个元素为(序列, 分数)
|
||||
"""
|
||||
# 初始束:空序列,对数概率为0
|
||||
beams = [([], 0.0)] # (序列, 累计对数概率)
|
||||
|
||||
for step in range(max_length):
|
||||
new_beams = []
|
||||
|
||||
for seq, score in beams:
|
||||
# 构建history_slot_ids
|
||||
if len(seq) < 8:
|
||||
# 填充到8个槽位
|
||||
history = seq + [0] * (8 - len(seq))
|
||||
else:
|
||||
# 只保留最近8个字符
|
||||
history = seq[-8:]
|
||||
|
||||
history_array = np.array([history], dtype=np.int64)
|
||||
|
||||
# 运行解码器
|
||||
logits = self.run_decoder(
|
||||
context_H, pinyin_P, history_array, context_mask, pinyin_mask
|
||||
)
|
||||
|
||||
# 应用温度
|
||||
logits = logits / temperature
|
||||
|
||||
# 转换为概率
|
||||
probs = F.softmax(torch.from_numpy(logits[0]), dim=-1).numpy()
|
||||
|
||||
# 获取top-k候选(k = beam_size * 2 以增加多样性)
|
||||
top_k = min(beam_size * 2, vocab_size)
|
||||
top_indices = np.argsort(probs)[-top_k:][::-1]
|
||||
top_probs = probs[top_indices]
|
||||
|
||||
# 扩展束
|
||||
for idx, prob in zip(top_indices, top_probs):
|
||||
new_seq = seq + [int(idx)]
|
||||
# 更新分数:累计对数概率 + log(prob)
|
||||
new_score = score + np.log(prob + 1e-10)
|
||||
|
||||
# 应用长度惩罚
|
||||
length_penalized_score = new_score / (
|
||||
(5 + len(new_seq)) ** length_penalty
|
||||
)
|
||||
|
||||
new_beams.append((new_seq, length_penalized_score, new_score))
|
||||
|
||||
# 剪枝:保留beam_size个最佳候选
|
||||
new_beams.sort(key=lambda x: x[1], reverse=True) # 按惩罚后分数排序
|
||||
beams = [(seq, orig_score) for seq, _, orig_score in new_beams[:beam_size]]
|
||||
|
||||
# 检查是否所有序列都已结束(结束符ID为0)
|
||||
all_ended = all(seq[-1] == 0 for seq, _ in beams if len(seq) > 0)
|
||||
if all_ended:
|
||||
break
|
||||
|
||||
# 返回原始分数(未应用长度惩罚)
|
||||
return beams
|
||||
|
||||
def interactive_beam_search(self):
|
||||
"""交互式束搜索演示"""
|
||||
print("\n" + "=" * 60)
|
||||
print("束搜索交互演示")
|
||||
print("=" * 60)
|
||||
|
||||
print("\n📝 输入法场景模拟:")
|
||||
print(" 假设用户正在输入拼音 'nihao',已确认第一个字 '你'")
|
||||
print(" 上下文: 光标前文本 '今天天气很好',光标后文本 '我们去公园玩'")
|
||||
print(" 上下文提示: '张三,李四'(模型不掌握的专有名词)")
|
||||
print("-" * 60)
|
||||
|
||||
# 模拟输入(实际应用中应从用户获取)
|
||||
text_before = "今天天气很好"
|
||||
text_after = "我们去公园玩"
|
||||
pinyin = "hao" # 继续输入"hao"
|
||||
slot_chars = ["你"] # 已确认的字符
|
||||
context_prompts = ["张三", "李四"]
|
||||
|
||||
print(f"\n📋 输入参数:")
|
||||
print(f" 光标前文本: '{text_before}'")
|
||||
print(f" 光标后文本: '{text_after}'")
|
||||
print(f" 拼音: '{pinyin}'")
|
||||
print(f" 槽位历史: {slot_chars}")
|
||||
print(f" 上下文提示: {context_prompts}")
|
||||
|
||||
# 准备输入(简化版,使用随机数据)
|
||||
print(f"\n🔧 准备模型输入...")
|
||||
inputs = self.prepare_inputs(
|
||||
text_before=text_before,
|
||||
text_after=text_after,
|
||||
pinyin=pinyin,
|
||||
context_prompts=context_prompts,
|
||||
)
|
||||
|
||||
# 运行上下文编码器
|
||||
print(f"🧠 运行上下文编码器...")
|
||||
context_outputs = self.run_context_encoder(inputs)
|
||||
context_H, pinyin_P, context_mask, pinyin_mask = context_outputs
|
||||
|
||||
print(f"✅ 上下文编码完成")
|
||||
print(f" context_H形状: {context_H.shape}")
|
||||
print(f" pinyin_P形状: {pinyin_P.shape}")
|
||||
|
||||
# 将槽位历史转换为ID(简化:使用随机ID)
|
||||
# 实际应用中应使用query_engine将汉字转换为ID
|
||||
slot_ids = [42] if slot_chars else [0] # 假设'你'的ID是42
|
||||
if len(slot_ids) < 8:
|
||||
slot_ids = slot_ids + [0] * (8 - len(slot_ids))
|
||||
|
||||
# 运行束搜索
|
||||
print(f"\n🔍 运行束搜索 (beam_size=3, max_length=4)...")
|
||||
beams = self.beam_search(
|
||||
context_H,
|
||||
pinyin_P,
|
||||
context_mask,
|
||||
pinyin_mask,
|
||||
beam_size=3,
|
||||
max_length=4,
|
||||
vocab_size=10019,
|
||||
)
|
||||
|
||||
# 显示结果
|
||||
print(f"\n🏆 束搜索结果:")
|
||||
print("-" * 50)
|
||||
for i, (seq, score) in enumerate(beams):
|
||||
# 将ID序列转换为汉字(简化:显示ID)
|
||||
seq_str = " ".join([f"ID:{id}" if id != 0 else "END" for id in seq])
|
||||
print(f"{i + 1}. 序列: [{seq_str}]")
|
||||
print(f" 对数概率: {score:.4f}")
|
||||
print(f" 概率: {np.exp(score):.6f}")
|
||||
print()
|
||||
|
||||
print("📝 说明:")
|
||||
print(" - 'END' 表示结束符 (ID: 0)")
|
||||
print(" - 实际应用中应将ID转换为汉字")
|
||||
print(" - 最高分数的序列作为最终预测")
|
||||
|
||||
return beams
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="束搜索算法演示")
|
||||
parser.add_argument(
|
||||
"--context-encoder",
|
||||
type=str,
|
||||
default="./exported_models/context_encoder.onnx",
|
||||
help="上下文编码器ONNX路径(默认: ./exported_models/context_encoder.onnx)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decoder",
|
||||
type=str,
|
||||
default="./exported_models/decoder.onnx",
|
||||
help="解码器ONNX路径(默认: ./exported_models/decoder.onnx)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cpu",
|
||||
choices=["cpu", "cuda"],
|
||||
help="推理设备(默认: cpu)",
|
||||
)
|
||||
parser.add_argument("--beam-size", type=int, default=3, help="束大小(默认: 3)")
|
||||
parser.add_argument(
|
||||
"--max-length", type=int, default=4, help="最大生成长度(默认: 4)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--interactive",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="交互模式(默认: True)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 检查模型文件是否存在
|
||||
if not os.path.exists(args.context_encoder):
|
||||
print(f"❌ 上下文编码器文件不存在: {args.context_encoder}")
|
||||
print(f" 请先运行 export_onnx.py 导出模型")
|
||||
return
|
||||
|
||||
if not os.path.exists(args.decoder):
|
||||
print(f"❌ 解码器文件不存在: {args.decoder}")
|
||||
print(f" 请先运行 export_onnx.py 导出模型")
|
||||
return
|
||||
|
||||
print("🚀 束搜索算法演示")
|
||||
print("=" * 60)
|
||||
print(f"上下文编码器: {args.context_encoder}")
|
||||
print(f"解码器: {args.decoder}")
|
||||
print(f"设备: {args.device}")
|
||||
print(f"束大小: {args.beam_size}")
|
||||
print(f"最大长度: {args.max_length}")
|
||||
|
||||
# 初始化束搜索器
|
||||
beam_searcher = ONNXBeamSearch(
|
||||
context_encoder_path=args.context_encoder,
|
||||
decoder_path=args.decoder,
|
||||
device=args.device,
|
||||
)
|
||||
|
||||
if args.interactive:
|
||||
# 运行交互演示
|
||||
beam_searcher.interactive_beam_search()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("🎉 演示完成")
|
||||
print("\n💡 下一步:")
|
||||
print(" 1. 实现完整的输入预处理(tokenizer和拼音转换)")
|
||||
print(" 2. 集成查询引擎以将ID转换为汉字")
|
||||
print(" 3. 根据实际场景调整束搜索参数")
|
||||
print(" 4. 性能优化:批量处理、缓存等")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
107
docs/TRAINING.md
107
docs/TRAINING.md
|
|
@ -484,6 +484,113 @@ train-model evaluate \
|
|||
- 在评估数据集上计算准确率、困惑度等指标
|
||||
- 生成详细的性能报告
|
||||
|
||||
### 模型扩容两阶段训练
|
||||
|
||||
当需要增加模型容量(如增加专家数量、修改层结构等)时,可以使用 `expand-and-train` 命令进行两阶段训练:先冻结匹配层训练新增参数,然后全量微调。
|
||||
|
||||
#### 训练策略
|
||||
|
||||
1. **冻结阶段**:只训练形状不匹配的新增参数(如新增的专家、扩容的层等)
|
||||
2. **全量微调阶段**:当验证损失连续 `--frozen-patience` 次不下降时,自动解冻所有层进行全量训练
|
||||
|
||||
#### 基础用法
|
||||
|
||||
```bash
|
||||
train-model expand-and-train \
|
||||
--train-data-path "path/to/train/dataset" \
|
||||
--eval-data-path "path/to/eval/dataset" \
|
||||
--base-model-path "./pretrained/model.pt" \
|
||||
--new-model-spec "model:InputMethodEngine" \
|
||||
--num-experts 40 \
|
||||
--frozen-lr 2e-3 \
|
||||
--full-lr 5e-5 \
|
||||
--frozen-patience 8
|
||||
```
|
||||
|
||||
#### 完整参数示例
|
||||
|
||||
```bash
|
||||
train-model expand-and-train \
|
||||
--train-data-path "path/to/train/dataset" \
|
||||
--eval-data-path "path/to/eval/dataset" \
|
||||
--output-dir "./expansion_output" \
|
||||
--base-model-path "./pretrained/model.pt" \
|
||||
--new-model-spec "custom_model:ExpandedModel" \
|
||||
--vocab-size 10019 \
|
||||
--dim 512 \
|
||||
--num-experts 40 \
|
||||
--frozen-patience 10 \
|
||||
--frozen-lr 1e-3 \
|
||||
--full-lr 1e-4 \
|
||||
--frozen-scheduler cosine \
|
||||
--full-scheduler cosine \
|
||||
--batch-size 128 \
|
||||
--num-epochs 20 \
|
||||
--compile
|
||||
```
|
||||
|
||||
#### 参数详解
|
||||
|
||||
**模型扩容参数**
|
||||
- `--base-model-path`: 预训练基础模型检查点路径(必需)
|
||||
- `--new-model-spec`: 新模型规格,格式:`模块名:类名`,如 `model:InputMethodEngine`(必需)
|
||||
- 支持任意路径的模块导入,模块文件需包含自定义的模型类
|
||||
- 自定义模型类必须是 `InputMethodEngine` 的子类
|
||||
- 示例:`my_model:MyExpandedModel` 对应 `my_model.py` 中的 `MyExpandedModel` 类
|
||||
|
||||
**两阶段训练参数**
|
||||
- `--frozen-patience`: 冻结阶段验证损失连续不下降的评估次数,触发切换到全量微调(默认:10)
|
||||
- `--frozen-lr`: 冻结阶段学习率(默认:1e-3)
|
||||
- `--full-lr`: 全量微调阶段学习率(默认:1e-4)
|
||||
- `--frozen-scheduler`: 冻结阶段学习率调度器,可选 `cosine` 或 `plateau`(默认:`cosine`)
|
||||
- `--full-scheduler`: 全量微调阶段学习率调度器,可选 `cosine` 或 `plateau`(默认:`cosine`)
|
||||
|
||||
**其他参数**
|
||||
- 支持所有 `train` 子命令的通用参数(数据参数、模型参数、训练参数等)
|
||||
- 继承现有的训练基础设施:混合精度训练、TensorBoard日志、checkpoint保存等
|
||||
|
||||
#### 使用场景
|
||||
|
||||
1. **增加专家数量**(20→40)
|
||||
- 冻结效果:~70% 参数可冻结(已有专家权重、注意力层等)
|
||||
- 新增参数:新专家网络、gate层
|
||||
|
||||
2. **增加top_k值**(2→3)
|
||||
- 冻结效果:100% 参数可冻结(仅逻辑变化)
|
||||
- 新增参数:无
|
||||
|
||||
3. **修改专家内部结构**(如增加resblocks)
|
||||
- 冻结效果:~50% 参数可冻结(linear_in/output可冻结)
|
||||
- 新增参数:新增的resblocks层
|
||||
|
||||
4. **增加Transformer层数**(4→5)
|
||||
- 冻结效果:~80% 参数可冻结(前4层可冻结)
|
||||
- 新增参数:新增的第5层
|
||||
|
||||
#### 自定义模型类示例
|
||||
|
||||
```python
|
||||
# my_model.py
|
||||
from model.model import InputMethodEngine
|
||||
|
||||
class MyExpandedModel(InputMethodEngine):
|
||||
def __init__(self, num_experts=40, **kwargs):
|
||||
# 调用父类构造函数,覆盖num_experts参数
|
||||
super().__init__(num_experts=num_experts, **kwargs)
|
||||
# 可以在这里添加额外的层或修改现有层
|
||||
|
||||
# 使用命令
|
||||
# train-model expand-and-train --new-model-spec "my_model:MyExpandedModel" ...
|
||||
```
|
||||
|
||||
#### 注意事项
|
||||
|
||||
1. **模型类要求**:自定义模型类必须是 `InputMethodEngine` 的子类
|
||||
2. **冻结条件**:只有权重形状完全匹配的层才会被冻结
|
||||
3. **性能保持**:MoE层保持"计算所有专家+Top-K选择"方案,确保 `torch.compile` 下的最佳性能
|
||||
4. **阶段切换**:基于评估频率而非epoch,建议适当调高 `--eval-frequency`
|
||||
5. **模块导入**:支持任意路径的模块,通过Python标准导入机制加载
|
||||
|
||||
### 导出模型(开发中)
|
||||
|
||||
当前导出功能尚在开发中:
|
||||
|
|
|
|||
112
export.record
112
export.record
File diff suppressed because one or more lines are too long
562
export_onnx.py
562
export_onnx.py
|
|
@ -1,562 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
输入法模型ONNX导出脚本
|
||||
|
||||
将模型导出为两个ONNX部分:
|
||||
1. context_encoder.onnx - 上下文编码器
|
||||
2. decoder.onnx - 解码器
|
||||
|
||||
使用方法:
|
||||
python export_onnx.py --checkpoint ./output/checkpoints/best_model.pt --output-dir ./exported_models
|
||||
|
||||
依赖安装:
|
||||
pip install onnx onnxruntime
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
# 检查ONNX是否可用
|
||||
try:
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
|
||||
ONNX_AVAILABLE = True
|
||||
except ImportError:
|
||||
ONNX_AVAILABLE = False
|
||||
print("警告: ONNX或ONNX Runtime未安装")
|
||||
print("请运行: pip install onnx onnxruntime")
|
||||
|
||||
# 添加src目录到路径
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from src.model.export_models import create_export_models_from_checkpoint
|
||||
|
||||
|
||||
def check_onnx_available():
|
||||
"""检查ONNX依赖是否可用"""
|
||||
if not ONNX_AVAILABLE:
|
||||
print("错误: ONNX导出需要以下依赖:")
|
||||
print(" pip install onnx onnxruntime")
|
||||
print("请安装后重试")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def export_context_encoder(model, output_path, config):
|
||||
"""
|
||||
导出上下文编码器为ONNX格式
|
||||
|
||||
Args:
|
||||
model: ContextEncoderExport实例
|
||||
output_path: 输出路径
|
||||
config: 模型配置
|
||||
"""
|
||||
print(f"正在导出上下文编码器到: {output_path}")
|
||||
|
||||
# 创建示例输入 - 使用batch_size=2以确保ONNX支持动态批处理
|
||||
batch_size = 2
|
||||
seq_len = config.get("max_seq_len", 128)
|
||||
pinyin_len = 24 # 固定长度
|
||||
dim = config.get("dim", 512)
|
||||
|
||||
input_ids = torch.randint(0, 100, (batch_size, seq_len), dtype=torch.long)
|
||||
pinyin_ids = torch.randint(0, 30, (batch_size, pinyin_len), dtype=torch.long)
|
||||
attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)
|
||||
|
||||
# 设置动态轴
|
||||
dynamic_axes = {
|
||||
"input_ids": {0: "batch_size", 1: "seq_len"},
|
||||
"pinyin_ids": {0: "batch_size"},
|
||||
"attention_mask": {0: "batch_size", 1: "seq_len"},
|
||||
"context_H": {0: "batch_size", 1: "seq_len"},
|
||||
"pinyin_P": {0: "batch_size"},
|
||||
"context_mask": {0: "batch_size", 1: "seq_len"},
|
||||
"pinyin_mask": {0: "batch_size"},
|
||||
}
|
||||
|
||||
# 导出模型
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(input_ids, pinyin_ids, attention_mask),
|
||||
output_path,
|
||||
input_names=["input_ids", "pinyin_ids", "attention_mask"],
|
||||
output_names=["context_H", "pinyin_P", "context_mask", "pinyin_mask"],
|
||||
dynamic_axes=dynamic_axes,
|
||||
opset_version=18, # 支持现代操作符
|
||||
do_constant_folding=True,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
print("✅ 上下文编码器导出完成")
|
||||
|
||||
# 验证导出
|
||||
try:
|
||||
onnx_model = onnx.load(output_path)
|
||||
onnx.checker.check_model(onnx_model)
|
||||
print("✅ ONNX模型验证通过")
|
||||
except Exception as e:
|
||||
print(f"⚠️ ONNX模型验证警告: {e}")
|
||||
|
||||
return input_ids, pinyin_ids, attention_mask
|
||||
|
||||
|
||||
def export_decoder(model, output_path, config, example_inputs=None):
|
||||
"""
|
||||
导出解码器为ONNX格式
|
||||
|
||||
Args:
|
||||
model: DecoderExport实例
|
||||
output_path: 输出路径
|
||||
config: 模型配置
|
||||
example_inputs: 示例输入(用于验证一致性)
|
||||
"""
|
||||
print(f"正在导解码器到: {output_path}")
|
||||
|
||||
# 创建示例输入 - 使用batch_size=2以确保ONNX支持动态批处理
|
||||
batch_size = 2
|
||||
seq_len = config.get("max_seq_len", 128)
|
||||
pinyin_len = 24
|
||||
dim = config.get("dim", 512)
|
||||
num_slots = config.get("num_slots", 8)
|
||||
|
||||
# 使用提供的示例输入或创建新的
|
||||
if example_inputs is not None:
|
||||
context_H, pinyin_P, context_mask, pinyin_mask = example_inputs
|
||||
batch_size = context_H.size(0) # 使用实际batch size
|
||||
history_slot_ids = torch.randint(
|
||||
0, 100, (batch_size, num_slots), dtype=torch.long
|
||||
)
|
||||
else:
|
||||
context_H = torch.randn(batch_size, seq_len, dim, dtype=torch.float32)
|
||||
pinyin_P = torch.randn(batch_size, pinyin_len, dim, dtype=torch.float32)
|
||||
context_mask = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int32)
|
||||
pinyin_mask = torch.randint(0, 2, (batch_size, pinyin_len), dtype=torch.int32)
|
||||
history_slot_ids = torch.randint(
|
||||
0, 100, (batch_size, num_slots), dtype=torch.long
|
||||
)
|
||||
|
||||
# 设置动态轴
|
||||
dynamic_axes = {
|
||||
"context_H": {0: "batch_size", 1: "seq_len"},
|
||||
"pinyin_P": {0: "batch_size"},
|
||||
"history_slot_ids": {0: "batch_size"},
|
||||
"context_mask": {0: "batch_size", 1: "seq_len"},
|
||||
"pinyin_mask": {0: "batch_size"},
|
||||
"logits": {0: "batch_size"},
|
||||
}
|
||||
|
||||
# 导出模型
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask),
|
||||
output_path,
|
||||
input_names=[
|
||||
"context_H",
|
||||
"pinyin_P",
|
||||
"history_slot_ids",
|
||||
"context_mask",
|
||||
"pinyin_mask",
|
||||
],
|
||||
output_names=["logits"],
|
||||
dynamic_axes=dynamic_axes,
|
||||
opset_version=18,
|
||||
do_constant_folding=True,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
print("✅ 解码器导出完成")
|
||||
|
||||
# 验证导出
|
||||
try:
|
||||
onnx_model = onnx.load(output_path)
|
||||
onnx.checker.check_model(onnx_model)
|
||||
print("✅ ONNX模型验证通过")
|
||||
except Exception as e:
|
||||
print(f"⚠️ ONNX模型验证警告: {e}")
|
||||
|
||||
return context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask
|
||||
|
||||
|
||||
def save_example_inputs(output_dir, example_inputs_dict):
|
||||
"""
|
||||
保存示例输入为NPZ文件,用于验证
|
||||
|
||||
Args:
|
||||
output_dir: 输出目录
|
||||
example_inputs_dict: 示例输入字典
|
||||
"""
|
||||
npz_path = os.path.join(output_dir, "example_inputs.npz")
|
||||
|
||||
# 转换为numpy数组
|
||||
np_data = {}
|
||||
for key, tensor in example_inputs_dict.items():
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
np_data[key] = tensor.cpu().numpy()
|
||||
elif isinstance(tensor, tuple):
|
||||
for i, t in enumerate(tensor):
|
||||
if isinstance(t, torch.Tensor):
|
||||
np_data[f"{key}_{i}"] = t.cpu().numpy()
|
||||
|
||||
np.savez(npz_path, **np_data)
|
||||
print(f"✅ 示例输入已保存到: {npz_path}")
|
||||
|
||||
# 同时保存为PyTorch格式
|
||||
torch_path = os.path.join(output_dir, "example_inputs.pt")
|
||||
torch.save(example_inputs_dict, torch_path)
|
||||
print(f"✅ PyTorch示例输入已保存到: {torch_path}")
|
||||
|
||||
|
||||
def create_inference_example(output_dir, config):
|
||||
"""
|
||||
创建推理示例脚本
|
||||
|
||||
Args:
|
||||
output_dir: 输出目录
|
||||
config: 模型配置
|
||||
"""
|
||||
example_path = os.path.join(output_dir, "inference_example.py")
|
||||
|
||||
example_code = '''#!/usr/bin/env python3
|
||||
"""
|
||||
ONNX模型推理示例
|
||||
|
||||
展示如何使用导出的两个ONNX模型进行推理,
|
||||
包括束搜索(beam search)算法。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
class ONNXInference:
|
||||
"""ONNX模型推理器"""
|
||||
|
||||
def __init__(self, context_encoder_path, decoder_path):
|
||||
"""
|
||||
初始化ONNX推理器
|
||||
|
||||
Args:
|
||||
context_encoder_path: 上下文编码器ONNX模型路径
|
||||
decoder_path: 解码器ONNX模型路径
|
||||
"""
|
||||
# 创建ONNX Runtime会话
|
||||
self.context_encoder_session = ort.InferenceSession(
|
||||
context_encoder_path,
|
||||
providers=['CPUExecutionProvider'] # 或 'CUDAExecutionProvider'
|
||||
)
|
||||
self.decoder_session = ort.InferenceSession(
|
||||
decoder_path,
|
||||
providers=['CPUExecutionProvider']
|
||||
)
|
||||
|
||||
# 获取输入输出名称
|
||||
self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()]
|
||||
self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()]
|
||||
self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()]
|
||||
self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()]
|
||||
|
||||
print(f"上下文编码器输入: {self.context_input_names}")
|
||||
print(f"上下文编码器输出: {self.context_output_names}")
|
||||
print(f"解码器输入: {self.decoder_input_names}")
|
||||
print(f"解码器输出: {self.decoder_output_names}")
|
||||
|
||||
def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128):
|
||||
"""
|
||||
准备模型输入(与原始推理脚本保持一致)
|
||||
|
||||
注意: 这里需要实现文本到token的转换
|
||||
为了简化示例,假设已经实现了相关函数
|
||||
"""
|
||||
# 这里应该调用实际的预处理函数
|
||||
# 返回: input_ids, pinyin_ids, attention_mask, history_slot_ids
|
||||
raise NotImplementedError("请实现实际的输入预处理")
|
||||
|
||||
def run_context_encoder(self, input_ids, pinyin_ids, attention_mask):
|
||||
"""
|
||||
运行上下文编码器
|
||||
|
||||
Args:
|
||||
input_ids: [batch, seq_len]
|
||||
pinyin_ids: [batch, 24]
|
||||
attention_mask: [batch, seq_len]
|
||||
|
||||
Returns:
|
||||
context_H, pinyin_P, context_mask, pinyin_mask
|
||||
"""
|
||||
# 准备输入
|
||||
inputs = {
|
||||
"input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids,
|
||||
"pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids,
|
||||
"attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask,
|
||||
}
|
||||
|
||||
# 运行推理
|
||||
outputs = self.context_encoder_session.run(self.context_output_names, inputs)
|
||||
|
||||
# 解包输出
|
||||
context_H, pinyin_P, context_mask, pinyin_mask = outputs
|
||||
|
||||
return (
|
||||
torch.from_numpy(context_H),
|
||||
torch.from_numpy(pinyin_P),
|
||||
torch.from_numpy(context_mask),
|
||||
torch.from_numpy(pinyin_mask),
|
||||
)
|
||||
|
||||
def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
|
||||
"""
|
||||
运行解码器
|
||||
|
||||
Args:
|
||||
context_H: [batch, seq_len, 512]
|
||||
pinyin_P: [batch, 24, 512]
|
||||
history_slot_ids: [batch, 8]
|
||||
context_mask: [batch, seq_len]
|
||||
pinyin_mask: [batch, 24]
|
||||
|
||||
Returns:
|
||||
logits: [batch, vocab_size]
|
||||
"""
|
||||
# 准备输入
|
||||
inputs = {
|
||||
"context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H,
|
||||
"pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P,
|
||||
"history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids,
|
||||
"context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask,
|
||||
"pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask,
|
||||
}
|
||||
|
||||
# 运行推理
|
||||
outputs = self.decoder_session.run(self.decoder_output_names, inputs)
|
||||
|
||||
# 解包输出
|
||||
logits = outputs[0]
|
||||
return torch.from_numpy(logits)
|
||||
|
||||
def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask,
|
||||
beam_size=5, max_length=10, vocab_size=10019):
|
||||
"""
|
||||
束搜索算法示例
|
||||
|
||||
Args:
|
||||
context_H: 上下文编码
|
||||
pinyin_P: 拼音编码
|
||||
context_mask: 上下文掩码
|
||||
pinyin_mask: 拼音掩码
|
||||
beam_size: 束大小
|
||||
max_length: 最大生成长度
|
||||
vocab_size: 词汇表大小
|
||||
|
||||
Returns:
|
||||
最佳序列列表
|
||||
"""
|
||||
# 初始束:空序列,分数为0
|
||||
beams = [([], 0.0)] # (序列, 对数概率)
|
||||
|
||||
for step in range(max_length):
|
||||
new_beams = []
|
||||
|
||||
for seq, score in beams:
|
||||
# 构建history_slot_ids:已确认的字符ID
|
||||
if len(seq) < 8:
|
||||
history = seq + [0] * (8 - len(seq))
|
||||
else:
|
||||
history = seq[-8:] # 只保留最近8个
|
||||
|
||||
history_tensor = torch.tensor([history], dtype=torch.long)
|
||||
|
||||
# 运行解码器
|
||||
logits = self.run_decoder(
|
||||
context_H, pinyin_P, history_tensor,
|
||||
context_mask, pinyin_mask
|
||||
)
|
||||
|
||||
# 获取概率
|
||||
probs = F.softmax(logits[0], dim=-1)
|
||||
|
||||
# 获取top-k候选
|
||||
top_probs, top_indices = torch.topk(probs, beam_size)
|
||||
|
||||
# 扩展束
|
||||
for prob, idx in zip(top_probs, top_indices):
|
||||
new_seq = seq + [idx.item()]
|
||||
new_score = score + torch.log(prob).item()
|
||||
new_beams.append((new_seq, new_score))
|
||||
|
||||
# 剪枝:保留beam_size个最佳候选
|
||||
new_beams.sort(key=lambda x: x[1], reverse=True)
|
||||
beams = new_beams[:beam_size]
|
||||
|
||||
# 检查是否所有序列都已结束(以结束符0结尾)
|
||||
all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
|
||||
if all_ended:
|
||||
break
|
||||
|
||||
return beams
|
||||
|
||||
def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids):
|
||||
"""
|
||||
单步预测
|
||||
|
||||
Args:
|
||||
input_ids: 输入token IDs
|
||||
pinyin_ids: 拼音IDs
|
||||
attention_mask: 注意力掩码
|
||||
history_slot_ids: 历史槽位IDs
|
||||
|
||||
Returns:
|
||||
预测logits
|
||||
"""
|
||||
# 1. 运行上下文编码器
|
||||
context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
|
||||
input_ids, pinyin_ids, attention_mask
|
||||
)
|
||||
|
||||
# 2. 运行解码器
|
||||
logits = self.run_decoder(
|
||||
context_H, pinyin_P, history_slot_ids,
|
||||
context_mask, pinyin_mask
|
||||
)
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
def main():
|
||||
"""示例主函数"""
|
||||
print("ONNX模型推理示例")
|
||||
print("=" * 60)
|
||||
|
||||
# 初始化推理器
|
||||
context_encoder_path = "context_encoder.onnx"
|
||||
decoder_path = "decoder.onnx"
|
||||
|
||||
if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path):
|
||||
print("错误: 找不到ONNX模型文件")
|
||||
print("请先运行export_onnx.py导出模型")
|
||||
return
|
||||
|
||||
inference = ONNXInference(context_encoder_path, decoder_path)
|
||||
|
||||
print("✅ ONNX推理器初始化完成")
|
||||
print("请参考此示例实现完整的输入法推理流程")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
'''
|
||||
|
||||
with open(example_path, "w", encoding="utf-8") as f:
|
||||
f.write(example_code)
|
||||
|
||||
print(f"✅ 推理示例脚本已保存到: {example_path}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="输入法模型ONNX导出")
|
||||
parser.add_argument(
|
||||
"--checkpoint", "-c", type=str, required=True, help="模型checkpoint路径"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
"-o",
|
||||
type=str,
|
||||
default="./exported_models",
|
||||
help="输出目录(默认: ./exported_models)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cpu",
|
||||
choices=["cpu", "cuda"],
|
||||
help="导出设备(默认: cpu)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-verification", action="store_true", help="跳过ONNX模型验证"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 检查依赖
|
||||
if not check_onnx_available():
|
||||
sys.exit(1)
|
||||
|
||||
# 创建输出目录
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"📁 输出目录: {output_dir.absolute()}")
|
||||
|
||||
# 加载checkpoint并创建导出模型
|
||||
print(f"📦 加载checkpoint: {args.checkpoint}")
|
||||
context_encoder_export, decoder_export, config = (
|
||||
create_export_models_from_checkpoint(args.checkpoint, args.device)
|
||||
)
|
||||
|
||||
print(f"📊 模型配置: {config}")
|
||||
|
||||
# 导出上下文编码器
|
||||
context_encoder_path = output_dir / "context_encoder.onnx"
|
||||
example_inputs = export_context_encoder(
|
||||
context_encoder_export, str(context_encoder_path), config
|
||||
)
|
||||
|
||||
# 使用上下文编码器的输出作为解码器的示例输入
|
||||
with torch.no_grad():
|
||||
context_H, pinyin_P, context_mask, pinyin_mask = context_encoder_export(
|
||||
*example_inputs
|
||||
)
|
||||
|
||||
# 导出解码器
|
||||
decoder_path = output_dir / "decoder.onnx"
|
||||
export_decoder(
|
||||
decoder_export,
|
||||
str(decoder_path),
|
||||
config,
|
||||
example_inputs=(context_H, pinyin_P, context_mask, pinyin_mask),
|
||||
)
|
||||
|
||||
# 保存示例输入
|
||||
example_inputs_dict = {
|
||||
"input_ids": example_inputs[0],
|
||||
"pinyin_ids": example_inputs[1],
|
||||
"attention_mask": example_inputs[2],
|
||||
"context_H": context_H,
|
||||
"pinyin_P": pinyin_P,
|
||||
"context_mask": context_mask,
|
||||
"pinyin_mask": pinyin_mask,
|
||||
}
|
||||
save_example_inputs(output_dir, example_inputs_dict)
|
||||
|
||||
# 创建推理示例脚本
|
||||
create_inference_example(output_dir, config)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("🎉 ONNX导出完成!")
|
||||
print("=" * 60)
|
||||
print(f"生成的模型文件:")
|
||||
print(f" - {context_encoder_path}")
|
||||
print(f" - {decoder_path}")
|
||||
print(f" - {output_dir}/example_inputs.npz")
|
||||
print(f" - {output_dir}/example_inputs.pt")
|
||||
print(f" - {output_dir}/inference_example.py")
|
||||
print("\n使用方法:")
|
||||
print(f" 1. 检查模型: python -m onnx.checker {context_encoder_path}")
|
||||
print(f" 2. 运行推理示例: cd {output_dir} && python inference_example.py")
|
||||
print(f" 3. 集成到您的应用: 参考inference_example.py中的ONNXInference类")
|
||||
print("\n注意:")
|
||||
print(" - 请确保安装了onnxruntime: pip install onnxruntime")
|
||||
print(" - GPU推理需要onnxruntime-gpu: pip install onnxruntime-gpu")
|
||||
print(" - 束搜索算法需要根据实际需求进行调整")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,230 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
ONNX模型推理示例
|
||||
|
||||
展示如何使用导出的两个ONNX模型进行推理,
|
||||
包括束搜索(beam search)算法。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
class ONNXInference:
|
||||
"""ONNX模型推理器"""
|
||||
|
||||
def __init__(self, context_encoder_path, decoder_path):
|
||||
"""
|
||||
初始化ONNX推理器
|
||||
|
||||
Args:
|
||||
context_encoder_path: 上下文编码器ONNX模型路径
|
||||
decoder_path: 解码器ONNX模型路径
|
||||
"""
|
||||
# 创建ONNX Runtime会话
|
||||
self.context_encoder_session = ort.InferenceSession(
|
||||
context_encoder_path,
|
||||
providers=['CPUExecutionProvider'] # 或 'CUDAExecutionProvider'
|
||||
)
|
||||
self.decoder_session = ort.InferenceSession(
|
||||
decoder_path,
|
||||
providers=['CPUExecutionProvider']
|
||||
)
|
||||
|
||||
# 获取输入输出名称
|
||||
self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()]
|
||||
self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()]
|
||||
self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()]
|
||||
self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()]
|
||||
|
||||
print(f"上下文编码器输入: {self.context_input_names}")
|
||||
print(f"上下文编码器输出: {self.context_output_names}")
|
||||
print(f"解码器输入: {self.decoder_input_names}")
|
||||
print(f"解码器输出: {self.decoder_output_names}")
|
||||
|
||||
def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128):
|
||||
"""
|
||||
准备模型输入(与原始推理脚本保持一致)
|
||||
|
||||
注意: 这里需要实现文本到token的转换
|
||||
为了简化示例,假设已经实现了相关函数
|
||||
"""
|
||||
# 这里应该调用实际的预处理函数
|
||||
# 返回: input_ids, pinyin_ids, attention_mask, history_slot_ids
|
||||
raise NotImplementedError("请实现实际的输入预处理")
|
||||
|
||||
def run_context_encoder(self, input_ids, pinyin_ids, attention_mask):
|
||||
"""
|
||||
运行上下文编码器
|
||||
|
||||
Args:
|
||||
input_ids: [batch, seq_len]
|
||||
pinyin_ids: [batch, 24]
|
||||
attention_mask: [batch, seq_len]
|
||||
|
||||
Returns:
|
||||
context_H, pinyin_P, context_mask, pinyin_mask
|
||||
"""
|
||||
# 准备输入
|
||||
inputs = {
|
||||
"input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids,
|
||||
"pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids,
|
||||
"attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask,
|
||||
}
|
||||
|
||||
# 运行推理
|
||||
outputs = self.context_encoder_session.run(self.context_output_names, inputs)
|
||||
|
||||
# 解包输出
|
||||
context_H, pinyin_P, context_mask, pinyin_mask = outputs
|
||||
|
||||
return (
|
||||
torch.from_numpy(context_H),
|
||||
torch.from_numpy(pinyin_P),
|
||||
torch.from_numpy(context_mask),
|
||||
torch.from_numpy(pinyin_mask),
|
||||
)
|
||||
|
||||
def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
|
||||
"""
|
||||
运行解码器
|
||||
|
||||
Args:
|
||||
context_H: [batch, seq_len, 512]
|
||||
pinyin_P: [batch, 24, 512]
|
||||
history_slot_ids: [batch, 8]
|
||||
context_mask: [batch, seq_len]
|
||||
pinyin_mask: [batch, 24]
|
||||
|
||||
Returns:
|
||||
logits: [batch, vocab_size]
|
||||
"""
|
||||
# 准备输入
|
||||
inputs = {
|
||||
"context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H,
|
||||
"pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P,
|
||||
"history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids,
|
||||
"context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask,
|
||||
"pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask,
|
||||
}
|
||||
|
||||
# 运行推理
|
||||
outputs = self.decoder_session.run(self.decoder_output_names, inputs)
|
||||
|
||||
# 解包输出
|
||||
logits = outputs[0]
|
||||
return torch.from_numpy(logits)
|
||||
|
||||
def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask,
|
||||
beam_size=5, max_length=10, vocab_size=10019):
|
||||
"""
|
||||
束搜索算法示例
|
||||
|
||||
Args:
|
||||
context_H: 上下文编码
|
||||
pinyin_P: 拼音编码
|
||||
context_mask: 上下文掩码
|
||||
pinyin_mask: 拼音掩码
|
||||
beam_size: 束大小
|
||||
max_length: 最大生成长度
|
||||
vocab_size: 词汇表大小
|
||||
|
||||
Returns:
|
||||
最佳序列列表
|
||||
"""
|
||||
# 初始束:空序列,分数为0
|
||||
beams = [([], 0.0)] # (序列, 对数概率)
|
||||
|
||||
for step in range(max_length):
|
||||
new_beams = []
|
||||
|
||||
for seq, score in beams:
|
||||
# 构建history_slot_ids:已确认的字符ID
|
||||
if len(seq) < 8:
|
||||
history = seq + [0] * (8 - len(seq))
|
||||
else:
|
||||
history = seq[-8:] # 只保留最近8个
|
||||
|
||||
history_tensor = torch.tensor([history], dtype=torch.long)
|
||||
|
||||
# 运行解码器
|
||||
logits = self.run_decoder(
|
||||
context_H, pinyin_P, history_tensor,
|
||||
context_mask, pinyin_mask
|
||||
)
|
||||
|
||||
# 获取概率
|
||||
probs = F.softmax(logits[0], dim=-1)
|
||||
|
||||
# 获取top-k候选
|
||||
top_probs, top_indices = torch.topk(probs, beam_size)
|
||||
|
||||
# 扩展束
|
||||
for prob, idx in zip(top_probs, top_indices):
|
||||
new_seq = seq + [idx.item()]
|
||||
new_score = score + torch.log(prob).item()
|
||||
new_beams.append((new_seq, new_score))
|
||||
|
||||
# 剪枝:保留beam_size个最佳候选
|
||||
new_beams.sort(key=lambda x: x[1], reverse=True)
|
||||
beams = new_beams[:beam_size]
|
||||
|
||||
# 检查是否所有序列都已结束(以结束符0结尾)
|
||||
all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
|
||||
if all_ended:
|
||||
break
|
||||
|
||||
return beams
|
||||
|
||||
def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids):
|
||||
"""
|
||||
单步预测
|
||||
|
||||
Args:
|
||||
input_ids: 输入token IDs
|
||||
pinyin_ids: 拼音IDs
|
||||
attention_mask: 注意力掩码
|
||||
history_slot_ids: 历史槽位IDs
|
||||
|
||||
Returns:
|
||||
预测logits
|
||||
"""
|
||||
# 1. 运行上下文编码器
|
||||
context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
|
||||
input_ids, pinyin_ids, attention_mask
|
||||
)
|
||||
|
||||
# 2. 运行解码器
|
||||
logits = self.run_decoder(
|
||||
context_H, pinyin_P, history_slot_ids,
|
||||
context_mask, pinyin_mask
|
||||
)
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
def main():
|
||||
"""示例主函数"""
|
||||
print("ONNX模型推理示例")
|
||||
print("=" * 60)
|
||||
|
||||
# 初始化推理器
|
||||
context_encoder_path = "context_encoder.onnx"
|
||||
decoder_path = "decoder.onnx"
|
||||
|
||||
if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path):
|
||||
print("错误: 找不到ONNX模型文件")
|
||||
print("请先运行export_onnx.py导出模型")
|
||||
return
|
||||
|
||||
inference = ONNXInference(context_encoder_path, decoder_path)
|
||||
|
||||
print("✅ ONNX推理器初始化完成")
|
||||
print("请参考此示例实现完整的输入法推理流程")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,230 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
ONNX模型推理示例
|
||||
|
||||
展示如何使用导出的两个ONNX模型进行推理,
|
||||
包括束搜索(beam search)算法。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
class ONNXInference:
|
||||
"""ONNX模型推理器"""
|
||||
|
||||
def __init__(self, context_encoder_path, decoder_path):
|
||||
"""
|
||||
初始化ONNX推理器
|
||||
|
||||
Args:
|
||||
context_encoder_path: 上下文编码器ONNX模型路径
|
||||
decoder_path: 解码器ONNX模型路径
|
||||
"""
|
||||
# 创建ONNX Runtime会话
|
||||
self.context_encoder_session = ort.InferenceSession(
|
||||
context_encoder_path,
|
||||
providers=['CPUExecutionProvider'] # 或 'CUDAExecutionProvider'
|
||||
)
|
||||
self.decoder_session = ort.InferenceSession(
|
||||
decoder_path,
|
||||
providers=['CPUExecutionProvider']
|
||||
)
|
||||
|
||||
# 获取输入输出名称
|
||||
self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()]
|
||||
self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()]
|
||||
self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()]
|
||||
self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()]
|
||||
|
||||
print(f"上下文编码器输入: {self.context_input_names}")
|
||||
print(f"上下文编码器输出: {self.context_output_names}")
|
||||
print(f"解码器输入: {self.decoder_input_names}")
|
||||
print(f"解码器输出: {self.decoder_output_names}")
|
||||
|
||||
def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128):
|
||||
"""
|
||||
准备模型输入(与原始推理脚本保持一致)
|
||||
|
||||
注意: 这里需要实现文本到token的转换
|
||||
为了简化示例,假设已经实现了相关函数
|
||||
"""
|
||||
# 这里应该调用实际的预处理函数
|
||||
# 返回: input_ids, pinyin_ids, attention_mask, history_slot_ids
|
||||
raise NotImplementedError("请实现实际的输入预处理")
|
||||
|
||||
def run_context_encoder(self, input_ids, pinyin_ids, attention_mask):
|
||||
"""
|
||||
运行上下文编码器
|
||||
|
||||
Args:
|
||||
input_ids: [batch, seq_len]
|
||||
pinyin_ids: [batch, 24]
|
||||
attention_mask: [batch, seq_len]
|
||||
|
||||
Returns:
|
||||
context_H, pinyin_P, context_mask, pinyin_mask
|
||||
"""
|
||||
# 准备输入
|
||||
inputs = {
|
||||
"input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids,
|
||||
"pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids,
|
||||
"attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask,
|
||||
}
|
||||
|
||||
# 运行推理
|
||||
outputs = self.context_encoder_session.run(self.context_output_names, inputs)
|
||||
|
||||
# 解包输出
|
||||
context_H, pinyin_P, context_mask, pinyin_mask = outputs
|
||||
|
||||
return (
|
||||
torch.from_numpy(context_H),
|
||||
torch.from_numpy(pinyin_P),
|
||||
torch.from_numpy(context_mask),
|
||||
torch.from_numpy(pinyin_mask),
|
||||
)
|
||||
|
||||
def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
|
||||
"""
|
||||
运行解码器
|
||||
|
||||
Args:
|
||||
context_H: [batch, seq_len, 512]
|
||||
pinyin_P: [batch, 24, 512]
|
||||
history_slot_ids: [batch, 8]
|
||||
context_mask: [batch, seq_len]
|
||||
pinyin_mask: [batch, 24]
|
||||
|
||||
Returns:
|
||||
logits: [batch, vocab_size]
|
||||
"""
|
||||
# 准备输入
|
||||
inputs = {
|
||||
"context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H,
|
||||
"pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P,
|
||||
"history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids,
|
||||
"context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask,
|
||||
"pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask,
|
||||
}
|
||||
|
||||
# 运行推理
|
||||
outputs = self.decoder_session.run(self.decoder_output_names, inputs)
|
||||
|
||||
# 解包输出
|
||||
logits = outputs[0]
|
||||
return torch.from_numpy(logits)
|
||||
|
||||
def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask,
|
||||
beam_size=5, max_length=10, vocab_size=10019):
|
||||
"""
|
||||
束搜索算法示例
|
||||
|
||||
Args:
|
||||
context_H: 上下文编码
|
||||
pinyin_P: 拼音编码
|
||||
context_mask: 上下文掩码
|
||||
pinyin_mask: 拼音掩码
|
||||
beam_size: 束大小
|
||||
max_length: 最大生成长度
|
||||
vocab_size: 词汇表大小
|
||||
|
||||
Returns:
|
||||
最佳序列列表
|
||||
"""
|
||||
# 初始束:空序列,分数为0
|
||||
beams = [([], 0.0)] # (序列, 对数概率)
|
||||
|
||||
for step in range(max_length):
|
||||
new_beams = []
|
||||
|
||||
for seq, score in beams:
|
||||
# 构建history_slot_ids:已确认的字符ID
|
||||
if len(seq) < 8:
|
||||
history = seq + [0] * (8 - len(seq))
|
||||
else:
|
||||
history = seq[-8:] # 只保留最近8个
|
||||
|
||||
history_tensor = torch.tensor([history], dtype=torch.long)
|
||||
|
||||
# 运行解码器
|
||||
logits = self.run_decoder(
|
||||
context_H, pinyin_P, history_tensor,
|
||||
context_mask, pinyin_mask
|
||||
)
|
||||
|
||||
# 获取概率
|
||||
probs = F.softmax(logits[0], dim=-1)
|
||||
|
||||
# 获取top-k候选
|
||||
top_probs, top_indices = torch.topk(probs, beam_size)
|
||||
|
||||
# 扩展束
|
||||
for prob, idx in zip(top_probs, top_indices):
|
||||
new_seq = seq + [idx.item()]
|
||||
new_score = score + torch.log(prob).item()
|
||||
new_beams.append((new_seq, new_score))
|
||||
|
||||
# 剪枝:保留beam_size个最佳候选
|
||||
new_beams.sort(key=lambda x: x[1], reverse=True)
|
||||
beams = new_beams[:beam_size]
|
||||
|
||||
# 检查是否所有序列都已结束(以结束符0结尾)
|
||||
all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
|
||||
if all_ended:
|
||||
break
|
||||
|
||||
return beams
|
||||
|
||||
def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids):
|
||||
"""
|
||||
单步预测
|
||||
|
||||
Args:
|
||||
input_ids: 输入token IDs
|
||||
pinyin_ids: 拼音IDs
|
||||
attention_mask: 注意力掩码
|
||||
history_slot_ids: 历史槽位IDs
|
||||
|
||||
Returns:
|
||||
预测logits
|
||||
"""
|
||||
# 1. 运行上下文编码器
|
||||
context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
|
||||
input_ids, pinyin_ids, attention_mask
|
||||
)
|
||||
|
||||
# 2. 运行解码器
|
||||
logits = self.run_decoder(
|
||||
context_H, pinyin_P, history_slot_ids,
|
||||
context_mask, pinyin_mask
|
||||
)
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
def main():
|
||||
"""示例主函数"""
|
||||
print("ONNX模型推理示例")
|
||||
print("=" * 60)
|
||||
|
||||
# 初始化推理器
|
||||
context_encoder_path = "context_encoder.onnx"
|
||||
decoder_path = "decoder.onnx"
|
||||
|
||||
if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path):
|
||||
print("错误: 找不到ONNX模型文件")
|
||||
print("请先运行export_onnx.py导出模型")
|
||||
return
|
||||
|
||||
inference = ONNXInference(context_encoder_path, decoder_path)
|
||||
|
||||
print("✅ ONNX推理器初始化完成")
|
||||
print("请参考此示例实现完整的输入法推理流程")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,230 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
ONNX模型推理示例
|
||||
|
||||
展示如何使用导出的两个ONNX模型进行推理,
|
||||
包括束搜索(beam search)算法。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
class ONNXInference:
|
||||
"""ONNX模型推理器"""
|
||||
|
||||
def __init__(self, context_encoder_path, decoder_path):
|
||||
"""
|
||||
初始化ONNX推理器
|
||||
|
||||
Args:
|
||||
context_encoder_path: 上下文编码器ONNX模型路径
|
||||
decoder_path: 解码器ONNX模型路径
|
||||
"""
|
||||
# 创建ONNX Runtime会话
|
||||
self.context_encoder_session = ort.InferenceSession(
|
||||
context_encoder_path,
|
||||
providers=['CPUExecutionProvider'] # 或 'CUDAExecutionProvider'
|
||||
)
|
||||
self.decoder_session = ort.InferenceSession(
|
||||
decoder_path,
|
||||
providers=['CPUExecutionProvider']
|
||||
)
|
||||
|
||||
# 获取输入输出名称
|
||||
self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()]
|
||||
self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()]
|
||||
self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()]
|
||||
self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()]
|
||||
|
||||
print(f"上下文编码器输入: {self.context_input_names}")
|
||||
print(f"上下文编码器输出: {self.context_output_names}")
|
||||
print(f"解码器输入: {self.decoder_input_names}")
|
||||
print(f"解码器输出: {self.decoder_output_names}")
|
||||
|
||||
def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128):
|
||||
"""
|
||||
准备模型输入(与原始推理脚本保持一致)
|
||||
|
||||
注意: 这里需要实现文本到token的转换
|
||||
为了简化示例,假设已经实现了相关函数
|
||||
"""
|
||||
# 这里应该调用实际的预处理函数
|
||||
# 返回: input_ids, pinyin_ids, attention_mask, history_slot_ids
|
||||
raise NotImplementedError("请实现实际的输入预处理")
|
||||
|
||||
def run_context_encoder(self, input_ids, pinyin_ids, attention_mask):
|
||||
"""
|
||||
运行上下文编码器
|
||||
|
||||
Args:
|
||||
input_ids: [batch, seq_len]
|
||||
pinyin_ids: [batch, 24]
|
||||
attention_mask: [batch, seq_len]
|
||||
|
||||
Returns:
|
||||
context_H, pinyin_P, context_mask, pinyin_mask
|
||||
"""
|
||||
# 准备输入
|
||||
inputs = {
|
||||
"input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids,
|
||||
"pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids,
|
||||
"attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask,
|
||||
}
|
||||
|
||||
# 运行推理
|
||||
outputs = self.context_encoder_session.run(self.context_output_names, inputs)
|
||||
|
||||
# 解包输出
|
||||
context_H, pinyin_P, context_mask, pinyin_mask = outputs
|
||||
|
||||
return (
|
||||
torch.from_numpy(context_H),
|
||||
torch.from_numpy(pinyin_P),
|
||||
torch.from_numpy(context_mask),
|
||||
torch.from_numpy(pinyin_mask),
|
||||
)
|
||||
|
||||
def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
|
||||
"""
|
||||
运行解码器
|
||||
|
||||
Args:
|
||||
context_H: [batch, seq_len, 512]
|
||||
pinyin_P: [batch, 24, 512]
|
||||
history_slot_ids: [batch, 8]
|
||||
context_mask: [batch, seq_len]
|
||||
pinyin_mask: [batch, 24]
|
||||
|
||||
Returns:
|
||||
logits: [batch, vocab_size]
|
||||
"""
|
||||
# 准备输入
|
||||
inputs = {
|
||||
"context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H,
|
||||
"pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P,
|
||||
"history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids,
|
||||
"context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask,
|
||||
"pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask,
|
||||
}
|
||||
|
||||
# 运行推理
|
||||
outputs = self.decoder_session.run(self.decoder_output_names, inputs)
|
||||
|
||||
# 解包输出
|
||||
logits = outputs[0]
|
||||
return torch.from_numpy(logits)
|
||||
|
||||
def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask,
|
||||
beam_size=5, max_length=10, vocab_size=10019):
|
||||
"""
|
||||
束搜索算法示例
|
||||
|
||||
Args:
|
||||
context_H: 上下文编码
|
||||
pinyin_P: 拼音编码
|
||||
context_mask: 上下文掩码
|
||||
pinyin_mask: 拼音掩码
|
||||
beam_size: 束大小
|
||||
max_length: 最大生成长度
|
||||
vocab_size: 词汇表大小
|
||||
|
||||
Returns:
|
||||
最佳序列列表
|
||||
"""
|
||||
# 初始束:空序列,分数为0
|
||||
beams = [([], 0.0)] # (序列, 对数概率)
|
||||
|
||||
for step in range(max_length):
|
||||
new_beams = []
|
||||
|
||||
for seq, score in beams:
|
||||
# 构建history_slot_ids:已确认的字符ID
|
||||
if len(seq) < 8:
|
||||
history = seq + [0] * (8 - len(seq))
|
||||
else:
|
||||
history = seq[-8:] # 只保留最近8个
|
||||
|
||||
history_tensor = torch.tensor([history], dtype=torch.long)
|
||||
|
||||
# 运行解码器
|
||||
logits = self.run_decoder(
|
||||
context_H, pinyin_P, history_tensor,
|
||||
context_mask, pinyin_mask
|
||||
)
|
||||
|
||||
# 获取概率
|
||||
probs = F.softmax(logits[0], dim=-1)
|
||||
|
||||
# 获取top-k候选
|
||||
top_probs, top_indices = torch.topk(probs, beam_size)
|
||||
|
||||
# 扩展束
|
||||
for prob, idx in zip(top_probs, top_indices):
|
||||
new_seq = seq + [idx.item()]
|
||||
new_score = score + torch.log(prob).item()
|
||||
new_beams.append((new_seq, new_score))
|
||||
|
||||
# 剪枝:保留beam_size个最佳候选
|
||||
new_beams.sort(key=lambda x: x[1], reverse=True)
|
||||
beams = new_beams[:beam_size]
|
||||
|
||||
# 检查是否所有序列都已结束(以结束符0结尾)
|
||||
all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
|
||||
if all_ended:
|
||||
break
|
||||
|
||||
return beams
|
||||
|
||||
def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids):
|
||||
"""
|
||||
单步预测
|
||||
|
||||
Args:
|
||||
input_ids: 输入token IDs
|
||||
pinyin_ids: 拼音IDs
|
||||
attention_mask: 注意力掩码
|
||||
history_slot_ids: 历史槽位IDs
|
||||
|
||||
Returns:
|
||||
预测logits
|
||||
"""
|
||||
# 1. 运行上下文编码器
|
||||
context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
|
||||
input_ids, pinyin_ids, attention_mask
|
||||
)
|
||||
|
||||
# 2. 运行解码器
|
||||
logits = self.run_decoder(
|
||||
context_H, pinyin_P, history_slot_ids,
|
||||
context_mask, pinyin_mask
|
||||
)
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
def main():
|
||||
"""示例主函数"""
|
||||
print("ONNX模型推理示例")
|
||||
print("=" * 60)
|
||||
|
||||
# 初始化推理器
|
||||
context_encoder_path = "context_encoder.onnx"
|
||||
decoder_path = "decoder.onnx"
|
||||
|
||||
if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path):
|
||||
print("错误: 找不到ONNX模型文件")
|
||||
print("请先运行export_onnx.py导出模型")
|
||||
return
|
||||
|
||||
inference = ONNXInference(context_encoder_path, decoder_path)
|
||||
|
||||
print("✅ ONNX推理器初始化完成")
|
||||
print("请参考此示例实现完整的输入法推理流程")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,230 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
ONNX模型推理示例
|
||||
|
||||
展示如何使用导出的两个ONNX模型进行推理,
|
||||
包括束搜索(beam search)算法。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
class ONNXInference:
|
||||
"""ONNX模型推理器"""
|
||||
|
||||
def __init__(self, context_encoder_path, decoder_path):
|
||||
"""
|
||||
初始化ONNX推理器
|
||||
|
||||
Args:
|
||||
context_encoder_path: 上下文编码器ONNX模型路径
|
||||
decoder_path: 解码器ONNX模型路径
|
||||
"""
|
||||
# 创建ONNX Runtime会话
|
||||
self.context_encoder_session = ort.InferenceSession(
|
||||
context_encoder_path,
|
||||
providers=['CPUExecutionProvider'] # 或 'CUDAExecutionProvider'
|
||||
)
|
||||
self.decoder_session = ort.InferenceSession(
|
||||
decoder_path,
|
||||
providers=['CPUExecutionProvider']
|
||||
)
|
||||
|
||||
# 获取输入输出名称
|
||||
self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()]
|
||||
self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()]
|
||||
self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()]
|
||||
self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()]
|
||||
|
||||
print(f"上下文编码器输入: {self.context_input_names}")
|
||||
print(f"上下文编码器输出: {self.context_output_names}")
|
||||
print(f"解码器输入: {self.decoder_input_names}")
|
||||
print(f"解码器输出: {self.decoder_output_names}")
|
||||
|
||||
def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128):
|
||||
"""
|
||||
准备模型输入(与原始推理脚本保持一致)
|
||||
|
||||
注意: 这里需要实现文本到token的转换
|
||||
为了简化示例,假设已经实现了相关函数
|
||||
"""
|
||||
# 这里应该调用实际的预处理函数
|
||||
# 返回: input_ids, pinyin_ids, attention_mask, history_slot_ids
|
||||
raise NotImplementedError("请实现实际的输入预处理")
|
||||
|
||||
def run_context_encoder(self, input_ids, pinyin_ids, attention_mask):
|
||||
"""
|
||||
运行上下文编码器
|
||||
|
||||
Args:
|
||||
input_ids: [batch, seq_len]
|
||||
pinyin_ids: [batch, 24]
|
||||
attention_mask: [batch, seq_len]
|
||||
|
||||
Returns:
|
||||
context_H, pinyin_P, context_mask, pinyin_mask
|
||||
"""
|
||||
# 准备输入
|
||||
inputs = {
|
||||
"input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids,
|
||||
"pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids,
|
||||
"attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask,
|
||||
}
|
||||
|
||||
# 运行推理
|
||||
outputs = self.context_encoder_session.run(self.context_output_names, inputs)
|
||||
|
||||
# 解包输出
|
||||
context_H, pinyin_P, context_mask, pinyin_mask = outputs
|
||||
|
||||
return (
|
||||
torch.from_numpy(context_H),
|
||||
torch.from_numpy(pinyin_P),
|
||||
torch.from_numpy(context_mask),
|
||||
torch.from_numpy(pinyin_mask),
|
||||
)
|
||||
|
||||
def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
|
||||
"""
|
||||
运行解码器
|
||||
|
||||
Args:
|
||||
context_H: [batch, seq_len, 512]
|
||||
pinyin_P: [batch, 24, 512]
|
||||
history_slot_ids: [batch, 8]
|
||||
context_mask: [batch, seq_len]
|
||||
pinyin_mask: [batch, 24]
|
||||
|
||||
Returns:
|
||||
logits: [batch, vocab_size]
|
||||
"""
|
||||
# 准备输入
|
||||
inputs = {
|
||||
"context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H,
|
||||
"pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P,
|
||||
"history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids,
|
||||
"context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask,
|
||||
"pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask,
|
||||
}
|
||||
|
||||
# 运行推理
|
||||
outputs = self.decoder_session.run(self.decoder_output_names, inputs)
|
||||
|
||||
# 解包输出
|
||||
logits = outputs[0]
|
||||
return torch.from_numpy(logits)
|
||||
|
||||
def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask,
|
||||
beam_size=5, max_length=10, vocab_size=10019):
|
||||
"""
|
||||
束搜索算法示例
|
||||
|
||||
Args:
|
||||
context_H: 上下文编码
|
||||
pinyin_P: 拼音编码
|
||||
context_mask: 上下文掩码
|
||||
pinyin_mask: 拼音掩码
|
||||
beam_size: 束大小
|
||||
max_length: 最大生成长度
|
||||
vocab_size: 词汇表大小
|
||||
|
||||
Returns:
|
||||
最佳序列列表
|
||||
"""
|
||||
# 初始束:空序列,分数为0
|
||||
beams = [([], 0.0)] # (序列, 对数概率)
|
||||
|
||||
for step in range(max_length):
|
||||
new_beams = []
|
||||
|
||||
for seq, score in beams:
|
||||
# 构建history_slot_ids:已确认的字符ID
|
||||
if len(seq) < 8:
|
||||
history = seq + [0] * (8 - len(seq))
|
||||
else:
|
||||
history = seq[-8:] # 只保留最近8个
|
||||
|
||||
history_tensor = torch.tensor([history], dtype=torch.long)
|
||||
|
||||
# 运行解码器
|
||||
logits = self.run_decoder(
|
||||
context_H, pinyin_P, history_tensor,
|
||||
context_mask, pinyin_mask
|
||||
)
|
||||
|
||||
# 获取概率
|
||||
probs = F.softmax(logits[0], dim=-1)
|
||||
|
||||
# 获取top-k候选
|
||||
top_probs, top_indices = torch.topk(probs, beam_size)
|
||||
|
||||
# 扩展束
|
||||
for prob, idx in zip(top_probs, top_indices):
|
||||
new_seq = seq + [idx.item()]
|
||||
new_score = score + torch.log(prob).item()
|
||||
new_beams.append((new_seq, new_score))
|
||||
|
||||
# 剪枝:保留beam_size个最佳候选
|
||||
new_beams.sort(key=lambda x: x[1], reverse=True)
|
||||
beams = new_beams[:beam_size]
|
||||
|
||||
# 检查是否所有序列都已结束(以结束符0结尾)
|
||||
all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
|
||||
if all_ended:
|
||||
break
|
||||
|
||||
return beams
|
||||
|
||||
def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids):
|
||||
"""
|
||||
单步预测
|
||||
|
||||
Args:
|
||||
input_ids: 输入token IDs
|
||||
pinyin_ids: 拼音IDs
|
||||
attention_mask: 注意力掩码
|
||||
history_slot_ids: 历史槽位IDs
|
||||
|
||||
Returns:
|
||||
预测logits
|
||||
"""
|
||||
# 1. 运行上下文编码器
|
||||
context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
|
||||
input_ids, pinyin_ids, attention_mask
|
||||
)
|
||||
|
||||
# 2. 运行解码器
|
||||
logits = self.run_decoder(
|
||||
context_H, pinyin_P, history_slot_ids,
|
||||
context_mask, pinyin_mask
|
||||
)
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
def main():
    """Example entry point demonstrating ONNX inference setup."""
    print("ONNX模型推理示例")
    print("=" * 60)

    encoder_file = "context_encoder.onnx"
    decoder_file = "decoder.onnx"

    # Bail out early when either exported model is missing.
    if not (os.path.exists(encoder_file) and os.path.exists(decoder_file)):
        print("错误: 找不到ONNX模型文件")
        print("请先运行export_onnx.py导出模型")
        return

    inference = ONNXInference(encoder_file, decoder_file)

    print("✅ ONNX推理器初始化完成")
    print("请参考此示例实现完整的输入法推理流程")


if __name__ == "__main__":
    main()
137
inference.py
137
inference.py
|
|
@ -22,8 +22,6 @@
|
|||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
|
@ -62,98 +60,6 @@ class InputMethodInference:
|
|||
|
||||
print(f"✅ 推理器初始化完成 (设备: {self.device})")
|
||||
|
||||
# 尝试启用readline以获得更好的行编辑功能
|
||||
try:
|
||||
import readline
|
||||
|
||||
# 设置readline使用UTF-8编码
|
||||
readline.set_completer_delims(" \t\n`~!@#$%^&*()-=+[{]}\\|;:'\",<>/?")
|
||||
print("📝 readline已启用,支持更好的行编辑功能")
|
||||
except ImportError:
|
||||
print("📝 readline不可用,使用标准输入")
|
||||
|
||||
def _safe_input(self, prompt: str, default: str = "") -> str:
|
||||
"""
|
||||
安全的输入函数,尝试正确处理UTF-8字符和退格键
|
||||
|
||||
Args:
|
||||
prompt: 提示文本
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
用户输入的字符串
|
||||
"""
|
||||
try:
|
||||
# 显示提示和默认值
|
||||
if default:
|
||||
full_prompt = f"{prompt} [{default}]: "
|
||||
else:
|
||||
full_prompt = f"{prompt}: "
|
||||
|
||||
# 使用标准input
|
||||
result = input(full_prompt)
|
||||
|
||||
# 如果用户直接回车且存在默认值,则返回默认值
|
||||
if not result and default:
|
||||
return default
|
||||
|
||||
return result.strip()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
# 用户按Ctrl+D或Ctrl+C
|
||||
print()
|
||||
return ""
|
||||
except Exception as e:
|
||||
# 其他错误
|
||||
print(f"\n⚠️ 输入错误: {e}")
|
||||
return ""
|
||||
|
||||
def _clean_pinyin_input(self, pinyin: str) -> str:
|
||||
"""
|
||||
清理拼音输入字符串,处理退格键等特殊字符
|
||||
|
||||
拼音只允许: a-z, `, ', -
|
||||
中文字符和其他字符会被忽略
|
||||
|
||||
Args:
|
||||
pinyin: 原始拼音输入字符串
|
||||
|
||||
Returns:
|
||||
清理后的拼音字符串
|
||||
"""
|
||||
if not pinyin:
|
||||
return ""
|
||||
|
||||
result = []
|
||||
for c in pinyin:
|
||||
# 检查是否为合法拼音字符 (a-z, `, ', -)
|
||||
# 注意: 中文字符的isalpha()也返回True,所以需要额外检查
|
||||
is_valid_pinyin_char = (
|
||||
("a" <= c <= "z")
|
||||
or ("A" <= c <= "Z") # 允许大写字母,转换为小写
|
||||
or c in ["`", "'", "-"]
|
||||
)
|
||||
|
||||
if is_valid_pinyin_char:
|
||||
# 合法拼音字符,转换为小写
|
||||
result.append(c.lower())
|
||||
elif c == " ":
|
||||
# 空格忽略
|
||||
continue
|
||||
elif c == "\b" or c == "\x7f" or c == "\x08":
|
||||
# 退格键、删除键:删除前一个字符
|
||||
if result:
|
||||
result.pop()
|
||||
elif c == "\x1b":
|
||||
# ESC键:清空所有输入
|
||||
result.clear()
|
||||
else:
|
||||
# 其他字符(包括中文字符)忽略
|
||||
# 注意:这里不添加到result,所以退格键无法删除它们
|
||||
# 但用户可能在拼音输入中误输入中文字符,应该忽略
|
||||
pass
|
||||
|
||||
return "".join(result)
|
||||
|
||||
def load_model(self):
|
||||
"""加载训练好的模型"""
|
||||
# 创建模型实例(不编译)
|
||||
|
|
@ -278,7 +184,7 @@ class InputMethodInference:
|
|||
"""
|
||||
|
||||
# 1. 构建tokenizer输入
|
||||
# 根据test.py和dataset.py,格式为: "part4|part1" 和 part3
|
||||
# 根据dataset.py,格式为: "part4|part1" 和 part3
|
||||
# part4: 上下文提示(专有词汇、姓名等,模型不掌握)
|
||||
# part1: text_before
|
||||
# part3: text_after
|
||||
|
|
@ -286,14 +192,13 @@ class InputMethodInference:
|
|||
# 处理上下文提示
|
||||
context_text = "|".join(context_prompts) if context_prompts else ""
|
||||
|
||||
# 构建输入文本 - 与test.py保持一致
|
||||
# test.py: f"{part4}|{part1}" 作为第一个参数,part3作为第二个参数
|
||||
# 构建输入文本
|
||||
if context_text:
|
||||
input_text = f"{context_text}|{text_before}"
|
||||
else:
|
||||
input_text = text_before
|
||||
|
||||
# 2. Tokenize - 与test.py保持一致
|
||||
# 2. Tokenize
|
||||
encoded = self.tokenizer(
|
||||
input_text,
|
||||
text_after,
|
||||
|
|
@ -304,11 +209,8 @@ class InputMethodInference:
|
|||
return_token_type_ids=True,
|
||||
)
|
||||
|
||||
# 3. 处理拼音输入 - 与test.py保持一致
|
||||
# 首先清理拼音字符串,处理退格键等特殊字符
|
||||
cleaned_pinyin = self._clean_pinyin_input(pinyin)
|
||||
|
||||
pinyin_ids = text_to_pinyin_ids(cleaned_pinyin)
|
||||
# 3. 处理拼音输入
|
||||
pinyin_ids = text_to_pinyin_ids(pinyin)
|
||||
if len(pinyin_ids) < 24:
|
||||
pinyin_ids.extend([0] * (24 - len(pinyin_ids)))
|
||||
else:
|
||||
|
|
@ -494,16 +396,7 @@ class InputMethodInference:
|
|||
print("\n" + "=" * 60)
|
||||
print("输入法模型推理 - 交互模式")
|
||||
print("=" * 60)
|
||||
|
||||
# 检查终端编码
|
||||
encoding = sys.stdout.encoding or "unknown"
|
||||
print(f"终端编码: {encoding}")
|
||||
if encoding.lower() not in ["utf-8", "utf8"]:
|
||||
print("⚠️ 警告: 终端编码不是UTF-8,中文输入可能有问题")
|
||||
print(" 建议设置: export LANG=en_US.UTF-8")
|
||||
print(" 或设置: export LC_ALL=en_US.UTF-8")
|
||||
|
||||
print("\n说明:")
|
||||
print("说明:")
|
||||
print(" - 上下文提示: 模型不掌握的专有词汇、姓名等(可为空)")
|
||||
print(" - 光标前文本: 光标前的连续文本")
|
||||
print(" - 光标后文本: 光标后的连续文本")
|
||||
|
|
@ -518,7 +411,7 @@ class InputMethodInference:
|
|||
print("第1步: 上下文提示(模型不掌握的专有词汇、姓名等)")
|
||||
print("格式: 用逗号分隔多个词汇,可为空")
|
||||
print("示例: 张三,李四,北京大学")
|
||||
context_input = self._safe_input("请输入上下文提示(直接回车跳过)")
|
||||
context_input = input("请输入上下文提示(直接回车跳过): ").strip()
|
||||
|
||||
if context_input.lower() in ["quit", "exit", "q"]:
|
||||
print("退出交互模式")
|
||||
|
|
@ -541,7 +434,7 @@ class InputMethodInference:
|
|||
print("第2步: 光标前文本")
|
||||
print("说明: 光标前的连续文本内容")
|
||||
print("示例: 今天天气很好")
|
||||
text_before = self._safe_input("请输入光标前文本")
|
||||
text_before = input("请输入光标前文本: ").strip()
|
||||
|
||||
if text_before.lower() in ["quit", "exit", "q"]:
|
||||
print("退出交互模式")
|
||||
|
|
@ -553,7 +446,7 @@ class InputMethodInference:
|
|||
print("第3步: 光标后文本")
|
||||
print("说明: 光标后的连续文本内容")
|
||||
print("示例: 我们去公园玩")
|
||||
text_after = self._safe_input("请输入光标后文本")
|
||||
text_after = input("请输入光标后文本: ").strip()
|
||||
|
||||
if text_after.lower() in ["quit", "exit", "q"]:
|
||||
print("退出交互模式")
|
||||
|
|
@ -565,7 +458,7 @@ class InputMethodInference:
|
|||
print("第4步: 拼音输入")
|
||||
print("说明: 当前正在输入的拼音")
|
||||
print("示例: tian, shang, hao")
|
||||
pinyin = self._safe_input("请输入拼音")
|
||||
pinyin = input("请输入拼音: ").strip()
|
||||
|
||||
if pinyin.lower() in ["quit", "exit", "q"]:
|
||||
print("退出交互模式")
|
||||
|
|
@ -578,7 +471,7 @@ class InputMethodInference:
|
|||
print("说明: 用户已确认的输入历史,用逗号分隔")
|
||||
print("示例: 上 (表示输入'shanghai'已确认'上')")
|
||||
print(" 今天,天气 (表示已确认两个词)")
|
||||
slot_input = self._safe_input("请输入槽位历史(直接回车表示无)")
|
||||
slot_input = input("请输入槽位历史(直接回车表示无): ").strip()
|
||||
|
||||
if slot_input.lower() in ["quit", "exit", "q"]:
|
||||
print("退出交互模式")
|
||||
|
|
@ -631,9 +524,7 @@ class InputMethodInference:
|
|||
|
||||
# 询问是否继续
|
||||
print("\n" + "-" * 40)
|
||||
continue_input = (
|
||||
self._safe_input("是否继续推理?(y/n)", "y").strip().lower()
|
||||
)
|
||||
continue_input = input("是否继续推理?(y/n): ").strip().lower()
|
||||
if continue_input not in ["y", "yes", ""]:
|
||||
print("退出交互模式")
|
||||
break
|
||||
|
|
@ -648,9 +539,7 @@ class InputMethodInference:
|
|||
traceback.print_exc()
|
||||
|
||||
# 询问是否继续
|
||||
continue_input = (
|
||||
self._safe_input("\n是否继续?(y/n)", "y").strip().lower()
|
||||
)
|
||||
continue_input = input("\n是否继续?(y/n): ").strip().lower()
|
||||
if continue_input not in ["y", "yes", ""]:
|
||||
print("退出交互模式")
|
||||
break
|
||||
|
|
|
|||
|
|
@ -1,763 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
ONNX输入法模型推理脚本
|
||||
|
||||
使用ONNX Runtime进行推理,测量每个阶段的执行时长
|
||||
|
||||
使用方法:
|
||||
python onnx_inference.py --context-encoder exported_models/context_encoder.onnx --decoder exported_models/decoder.onnx
|
||||
|
||||
交互模式: 分步询问输入
|
||||
1. 上下文提示: 模型不掌握的专有词汇、姓名等(可为空)
|
||||
2. 光标前文本: 光标前的连续文本
|
||||
3. 光标后文本: 光标后的连续文本
|
||||
4. 拼音: 当前输入的拼音
|
||||
5. 槽位历史: 用户已确认的输入历史(如输入shanghai已确认"上")
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from modelscope import AutoTokenizer
|
||||
|
||||
from src.model.dataset import text_to_pinyin_ids
|
||||
from src.model.query import QueryEngine
|
||||
|
||||
|
||||
class ONNXInference:
|
||||
"""ONNX输入法模型推理器"""
|
||||
|
||||
def __init__(
    self,
    context_encoder_path: str,
    decoder_path: str,
    vocab_size: int = 10019,
    device: str = "cpu",
    use_beam_search: bool = False,
    beam_size: int = 5,
):
    """Build the inference pipeline: ONNX sessions, tokenizer and query engine.

    Args:
        context_encoder_path: path to the context-encoder ONNX file.
        decoder_path: path to the decoder ONNX file.
        vocab_size: size of the output vocabulary.
        device: "cpu" or "cuda".
        use_beam_search: default decoding mode for interactive use.
        beam_size: default beam width when beam search is enabled.
    """
    self.vocab_size = vocab_size
    self.device = device
    self.use_beam_search = use_beam_search
    self.beam_size = beam_size

    def timed(loader, *args):
        # Run a loader and return its wall-clock duration in milliseconds.
        begin = time.perf_counter()
        loader(*args)
        return (time.perf_counter() - begin) * 1000

    print(f"正在加载上下文编码器: {context_encoder_path}")
    self.context_encoder_load_time = timed(
        self.load_context_encoder, context_encoder_path
    )
    print(f" ✅ 上下文编码器加载完成 ({self.context_encoder_load_time:.2f}ms)")

    print(f"正在加载解码器: {decoder_path}")
    self.decoder_load_time = timed(self.load_decoder, decoder_path)
    print(f" ✅ 解码器加载完成 ({self.decoder_load_time:.2f}ms)")

    print("正在加载tokenizer...")
    self.tokenizer_load_time = timed(self.load_tokenizer)
    print(f" ✅ Tokenizer加载完成 ({self.tokenizer_load_time:.2f}ms)")

    print("正在加载查询引擎...")
    self.query_engine_load_time = timed(self.load_query_engine)
    print(f" ✅ 查询引擎加载完成 ({self.query_engine_load_time:.2f}ms)")

    total_load_time = sum(
        (
            self.context_encoder_load_time,
            self.decoder_load_time,
            self.tokenizer_load_time,
            self.query_engine_load_time,
        )
    )
    print(f"\n✅ 推理器初始化完成 (设备: {device})")
    print(f" 总加载时间: {total_load_time:.2f}ms")

    # Best-effort line-editing support for the interactive mode.
    try:
        import readline

        readline.set_completer_delims(" \t\n`~!@#$%^&*()-=+[{]}\\|;:'\",<>/?")
    except ImportError:
        pass
|
||||
def load_context_encoder(self, model_path: str):
    """Create the ONNX Runtime session for the context encoder and cache
    its input/output tensor names."""
    if self.device == "cuda":
        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    else:
        providers = ["CPUExecutionProvider"]

    session = ort.InferenceSession(model_path, providers=providers)
    self.context_encoder_session = session
    self.context_input_names = [spec.name for spec in session.get_inputs()]
    self.context_output_names = [spec.name for spec in session.get_outputs()]
|
||||
def load_decoder(self, model_path: str):
    """Create the ONNX Runtime session for the decoder and cache its
    input/output tensor names."""
    if self.device == "cuda":
        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    else:
        providers = ["CPUExecutionProvider"]

    session = ort.InferenceSession(model_path, providers=providers)
    self.decoder_session = session
    self.decoder_input_names = [spec.name for spec in session.get_inputs()]
    self.decoder_output_names = [spec.name for spec in session.get_outputs()]
|
||||
def load_tokenizer(self):
    """Load the bundled project tokenizer, falling back to bert-base-chinese."""
    bundled = Path(__file__).parent / "src" / "model" / "assets" / "tokenizer"
    try:
        self.tokenizer = AutoTokenizer.from_pretrained(str(bundled))
    except Exception:
        # Bundled tokenizer missing or unreadable; use the public model instead.
        print(" ⚠️ 无法加载自定义tokenizer,使用bert-base-chinese")
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
|
||||
def load_query_engine(self):
    """Load the pinyin/character statistics engine.

    Best-effort: on any failure `query_engine` is set to None and callers
    fall back to code-point based conversion.
    """
    try:
        engine = QueryEngine()
        stats_file = (
            Path(__file__).parent
            / "src"
            / "model"
            / "assets"
            / "pinyin_char_statistics.json"
        )
        if stats_file.exists():
            engine.load(stats_file)
        self.query_engine = engine
    except Exception:
        # Inference still works without the statistics table.
        self.query_engine = None
|
||||
def char_to_id(self, char: str, pinyin: Optional[str] = None) -> int:
    """Map a Chinese character to its vocabulary id.

    Args:
        char: the character to look up; the sentinel "//" is the end marker.
        pinyin: optional pinyin reading used to disambiguate polyphones.

    Returns:
        The vocabulary id, or 0 when the character cannot be resolved.
    """
    if char == "//":
        return 0  # end-of-sequence marker

    if self.query_engine is None:
        # No statistics table: use the code point of a single character.
        return ord(char) if len(char) == 1 else 0

    try:
        if pinyin is not None:
            info = self.query_engine.get_char_info_by_char_pinyin(char, pinyin)
            if info:
                return info.id

        results = self.query_engine.query_by_char(char, limit=1)
        if results:
            return results[0][0]
        return 0
    # FIX: was a bare `except:` which also swallowed SystemExit and
    # KeyboardInterrupt; only engine lookup failures should be absorbed.
    except Exception:
        return 0
|
||||
def id_to_char(self, id: int) -> str:
    """Map a vocabulary id back to its character.

    Returns "//" for id 0 (the end marker) and a placeholder "[ID:n]"
    when the id cannot be resolved.
    """
    if id == 0:
        return "//"

    if self.query_engine is None:
        # No statistics table: treat the id as a Unicode code point if valid.
        return chr(id) if id < 0x110000 else f"[ID:{id}]"

    try:
        info = self.query_engine.query_by_id(id)
        return info.char if info else f"[ID:{id}]"
    # FIX: was a bare `except:`; restrict to real lookup failures.
    except Exception:
        return f"[ID:{id}]"
|
||||
def _clean_pinyin_input(self, pinyin: str) -> str:
|
||||
"""清理拼音输入字符串"""
|
||||
if not pinyin:
|
||||
return ""
|
||||
|
||||
result = []
|
||||
for c in pinyin:
|
||||
is_valid = ("a" <= c <= "z") or ("A" <= c <= "Z") or c in ["`", "'", "-"]
|
||||
if is_valid:
|
||||
result.append(c.lower())
|
||||
elif c == " ":
|
||||
continue
|
||||
elif c in ["\b", "\x7f", "\x08"]:
|
||||
if result:
|
||||
result.pop()
|
||||
elif c == "\x1b":
|
||||
result.clear()
|
||||
return "".join(result)
|
||||
|
||||
def _safe_input(self, prompt: str, default: str = "") -> str:
    """Prompt the user; return `default` on empty input and "" on EOF,
    Ctrl+C or any other input failure."""
    shown = f"{prompt} [{default}]: " if default else f"{prompt}: "
    try:
        answer = input(shown)
    except (EOFError, KeyboardInterrupt):
        # User pressed Ctrl+D / Ctrl+C: finish the line and give up.
        print()
        return ""
    except Exception as e:
        print(f"\n⚠️ 输入错误: {e}")
        return ""
    if not answer and default:
        return default
    return answer.strip()
|
||||
def prepare_inputs(
    self,
    context_prompts: List[str],
    text_before: str,
    text_after: str,
    pinyin: str,
    slot_chars: List[str],
    max_seq_len: int = 128,
) -> dict:
    """Tokenize the text and encode pinyin/history slots as model-ready arrays.

    Returns:
        dict with keys 'preprocess_time' (ms), 'input_ids',
        'attention_mask', 'pinyin_ids' and 'history_slot_ids'
        (arrays batched with a leading dimension of 1).
    """
    started = time.perf_counter()

    def fit(seq, size):
        # Right-pad with 0, or truncate, to exactly `size` elements.
        return seq[:size] if len(seq) >= size else seq + [0] * (size - len(seq))

    # 1. Assemble the first tokenizer segment: "prompt1|prompt2|...|text_before".
    joined_prompts = "|".join(context_prompts) if context_prompts else ""
    first_segment = (
        f"{joined_prompts}|{text_before}" if joined_prompts else text_before
    )

    # 2. Tokenize; text_after is the second segment.
    encoded = self.tokenizer(
        first_segment,
        text_after,
        max_length=max_seq_len,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
        return_token_type_ids=True,
    )

    # 3. Pinyin: sanitize edit keys, convert to ids, fix length to 24.
    pinyin_ids = fit(text_to_pinyin_ids(self._clean_pinyin_input(pinyin)), 24)

    # 4. History slots: map confirmed characters to ids, fix length to 8.
    slot_ids = fit([self.char_to_id(c) for c in slot_chars], 8)

    elapsed_ms = (time.perf_counter() - started) * 1000
    return {
        "preprocess_time": elapsed_ms,
        "input_ids": encoded["input_ids"].numpy(),
        "attention_mask": encoded["attention_mask"].numpy(),
        "pinyin_ids": np.array([pinyin_ids], dtype=np.int64),
        "history_slot_ids": np.array([slot_ids], dtype=np.int64),
    }
|
||||
def run_context_encoder(
    self, input_ids: np.ndarray, pinyin_ids: np.ndarray, attention_mask: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Run the context-encoder ONNX session.

    Side effect: stores the call duration (ms) in `last_context_encoder_time`.

    Returns:
        (context_H, pinyin_P, context_mask, pinyin_mask)
    """
    started = time.perf_counter()

    feeds = {
        "input_ids": input_ids,
        "pinyin_ids": pinyin_ids,
        "attention_mask": attention_mask,
    }
    context_H, pinyin_P, context_mask, pinyin_mask = self.context_encoder_session.run(
        self.context_output_names, feeds
    )

    self.last_context_encoder_time = (time.perf_counter() - started) * 1000
    return context_H, pinyin_P, context_mask, pinyin_mask
|
||||
def run_decoder(
    self,
    context_H: np.ndarray,
    pinyin_P: np.ndarray,
    history_slot_ids: np.ndarray,
    context_mask: np.ndarray,
    pinyin_mask: np.ndarray,
) -> np.ndarray:
    """Run the decoder ONNX session for one step.

    Side effect: stores the call duration (ms) in `last_decoder_time`.

    Returns:
        logits of shape [batch, vocab_size].
    """
    started = time.perf_counter()

    feeds = {
        "context_H": context_H,
        "pinyin_P": pinyin_P,
        "history_slot_ids": history_slot_ids,
        "context_mask": context_mask,
        "pinyin_mask": pinyin_mask,
    }
    results = self.decoder_session.run(self.decoder_output_names, feeds)

    self.last_decoder_time = (time.perf_counter() - started) * 1000
    return results[0]
|
||||
def predict(
    self,
    context_prompts: List[str],
    text_before: str,
    text_after: str,
    pinyin: str,
    slot_chars: List[str],
    top_k: int = 20,
    use_beam_search: bool = False,
    beam_size: int = 5,
    max_length: int = 10,
) -> Tuple[List[Tuple[str, float, int]], dict]:
    """Full inference pipeline: preprocess -> encode -> decode -> post-process.

    Args:
        context_prompts: proper nouns / hints the model does not know.
        text_before: text before the cursor.
        text_after: text after the cursor.
        pinyin: pinyin currently being typed.
        slot_chars: already-confirmed characters.
        top_k: number of candidates to return.
        use_beam_search: decode with beam search instead of a single step.
        beam_size: beam width for beam search.
        max_length: maximum generated length for beam search.

    Returns:
        (predictions, timing_info) where predictions is a list of
        (char, probability, vocab_id) tuples and timing_info maps stage
        names to durations in milliseconds.
    """
    # FIX: removed dead locals of the original (`decode_start` in both
    # branches, `context_start`, and the unused `preprocess_time` read).
    total_start = time.perf_counter()

    # Stage 1: preprocessing (tokenization + id encoding).
    prep_start = time.perf_counter()
    inputs = self.prepare_inputs(
        context_prompts, text_before, text_after, pinyin, slot_chars
    )
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    pinyin_ids = inputs["pinyin_ids"]
    history_slot_ids = inputs["history_slot_ids"]
    prep_time = (time.perf_counter() - prep_start) * 1000

    # Stage 2: context encoding (duration recorded by run_context_encoder).
    context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
        input_ids, pinyin_ids, attention_mask
    )
    context_encoder_time = self.last_context_encoder_time

    if use_beam_search:
        # Stage 3 (beam): multi-step beam-search decoding.
        predictions, decoder_time = self._beam_search_decode(
            context_H,
            pinyin_P,
            context_mask,
            pinyin_mask,
            beam_size,
            max_length,
            top_k,
        )
        postprocess_time = 0.0
    else:
        # Stage 3 (single step): one decoder call.
        logits = self.run_decoder(
            context_H,
            pinyin_P,
            history_slot_ids,
            context_mask,
            pinyin_mask,
        )
        decoder_time = self.last_decoder_time

        # Stage 4: softmax + top-k + id->char mapping.
        postprocess_start = time.perf_counter()
        probs = self._softmax(logits)
        top_indices, top_probs = self._topk(probs, top_k)
        predictions = []
        for i in range(top_k):
            idx = int(top_indices[0, i])
            prob = float(top_probs[0, i])
            predictions.append((self.id_to_char(idx), prob, idx))
        postprocess_time = (time.perf_counter() - postprocess_start) * 1000

    total_time = (time.perf_counter() - total_start) * 1000

    timing_info = {
        "预处理": prep_time,
        "上下文编码": context_encoder_time,
        "解码": decoder_time,
        "后处理": postprocess_time,
        "总耗时": total_time,
    }

    return predictions, timing_info
|
||||
def _softmax(self, logits: np.ndarray) -> np.ndarray:
|
||||
"""计算softmax"""
|
||||
exp_logits = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
|
||||
return exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
|
||||
|
||||
def _topk(self, probs: np.ndarray, k: int) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""获取top-k"""
|
||||
topk_indices = np.argsort(probs, axis=-1)[:, -k:][:, ::-1]
|
||||
topk_probs = np.take_along_axis(probs, topk_indices, axis=-1)
|
||||
return topk_indices, topk_probs
|
||||
|
||||
def _beam_search_decode(
    self,
    context_H: np.ndarray,
    pinyin_P: np.ndarray,
    context_mask: np.ndarray,
    pinyin_mask: np.ndarray,
    beam_size: int,
    max_length: int,
    top_k: int,
) -> Tuple[List[Tuple[str, float, int]], float]:
    """Beam-search decoding via repeated single-step decoder calls.

    Returns:
        (predictions, decode_time_ms) where predictions holds up to
        `top_k` (char, length-normalized probability, vocab_id) tuples for
        the final character of each surviving beam, and decode_time_ms is
        the total time spent in the decoder across ALL steps.
        (FIX: the original only reported the last decoder call's time.)
    """
    beams = [([], 0.0)]  # (sequence of ids, cumulative log-probability)
    total_decode_time = 0.0

    for _ in range(max_length):
        expanded = []

        for seq, score in beams:
            # History window: last 8 confirmed ids, zero-padded on the right.
            if len(seq) < 8:
                history = seq + [0] * (8 - len(seq))
            else:
                history = seq[-8:]
            history_tensor = np.array([history], dtype=np.int64)

            logits = self.run_decoder(
                context_H,
                pinyin_P,
                history_tensor,
                context_mask,
                pinyin_mask,
            )
            total_decode_time += self.last_decoder_time

            probs = self._softmax(logits)[0]
            best = np.argsort(probs)[-beam_size:][::-1]
            for idx in best:
                # Epsilon guards log(0) for zero-probability candidates.
                expanded.append(
                    (seq + [int(idx)], score + np.log(probs[idx] + 1e-10))
                )

        # Prune to the beam_size highest-scoring candidates.
        expanded.sort(key=lambda item: item[1], reverse=True)
        beams = expanded[:beam_size]

        # Stop once every non-empty beam ends with the terminator id 0.
        if all(seq[-1] == 0 for seq, _ in beams if seq):
            break

    # Emit up to top_k candidates based on each beam's last character.
    predictions = []
    for seq, score in beams[:top_k]:
        if seq:
            char = self.id_to_char(seq[-1])
            # Length-normalized probability of the whole sequence.
            prob = np.exp(score / max(len(seq), 1))
        else:
            char = self.id_to_char(0)
            prob = 0.0
        predictions.append((char, prob, seq[-1] if seq else 0))

    return predictions, total_decode_time
|
||||
def interactive_mode(self):
    """Interactive loop: gather the five inputs, run inference, print a report."""
    quit_words = ("quit", "exit", "q")

    print("\n" + "=" * 60)
    print("ONNX输入法模型推理 - 交互模式")
    print("=" * 60)

    print(f"终端编码: {sys.stdout.encoding or 'unknown'}")

    print("\n说明:")
    print(" - 上下文提示: 模型不掌握的专有词汇、姓名等(可为空)")
    print(" - 光标前文本: 光标前的连续文本")
    print(" - 光标后文本: 光标后的连续文本")
    print(" - 拼音: 当前输入的拼音")
    print(" - 槽位历史: 用户已确认的输入历史")
    if self.use_beam_search:
        print(f" - 解码模式: 束搜索 (beam_size={self.beam_size})")
    else:
        print(" - 解码模式: 单步解码 (使用 --beam 启用束搜索)")
    print("提示: 输入 'quit' 或 'exit' 或 'q' 可随时退出")
    print("-" * 60)

    while True:
        try:
            # --- gather the five inputs, aborting on a quit word ---
            print("\n" + "=" * 60)
            raw_context = self._safe_input("第1步: 上下文提示(直接回车跳过)")
            if raw_context.lower() in quit_words:
                break
            context_prompts = [
                part.strip() for part in raw_context.split(",") if part.strip()
            ]

            print("\n" + "-" * 40)
            text_before = self._safe_input("第2步: 光标前文本")
            if text_before.lower() in quit_words:
                break

            print("\n" + "-" * 40)
            text_after = self._safe_input("第3步: 光标后文本")
            if text_after.lower() in quit_words:
                break

            print("\n" + "-" * 40)
            pinyin = self._safe_input("第4步: 拼音输入")
            if pinyin.lower() in quit_words:
                break

            print("\n" + "-" * 40)
            raw_slots = self._safe_input("第5步: 槽位历史(直接回车表示无)")
            if raw_slots.lower() in quit_words:
                break
            slot_chars = [c.strip() for c in raw_slots.split(",") if c.strip()]

            # --- echo the collected inputs ---
            print("\n" + "=" * 60)
            print("📝 输入汇总:")
            print(f" 上下文提示: {context_prompts if context_prompts else '无'}")
            print(f" 光标前文本: '{text_before}'")
            print(f" 光标后文本: '{text_after}'")
            print(f" 拼音: '{pinyin}'")
            print(f" 槽位历史: {slot_chars if slot_chars else '无'}")

            # --- run the model ---
            print("\n🔮 推理中...")
            predictions, timing_info = self.predict(
                context_prompts,
                text_before,
                text_after,
                pinyin,
                slot_chars,
                top_k=20,
                use_beam_search=self.use_beam_search,
                beam_size=self.beam_size,
            )

            # --- timing report (skip zero-duration stages) ---
            print(f"\n⏱️ 执行时间统计:")
            print("-" * 40)
            for stage, duration in timing_info.items():
                if duration > 0:
                    print(f" {stage:<12}: {duration:>8.2f} ms")
            print("-" * 40)

            # --- candidate list ---
            print("\n🏆 Top-20 预测结果:")
            print("-" * 50)
            for rank, (char, prob, idx) in enumerate(predictions):
                if char == "//":
                    print(f"{rank + 1:2d}. {'//':<4} (结束符) - 概率: {prob:.4f}")
                else:
                    print(
                        f"{rank + 1:2d}. {char:<4} (ID: {idx:>5}) - 概率: {prob:.4f}"
                    )

            # --- frequency reference for the typed pinyin ---
            if pinyin and self.query_engine:
                print(f"\n📖 拼音 '{pinyin}' 的常见汉字:")
                frequent = self.query_engine.query_by_pinyin(pinyin, limit=10)
                if frequent:
                    for _pid, char, count in frequent:
                        print(f" {char} (频次: {count:,})")

            print("\n" + "-" * 40)
            answer = (
                self._safe_input("是否继续推理?(y/n)", "y").strip().lower()
            )
            if answer not in ["y", "yes", ""]:
                break

        except KeyboardInterrupt:
            print("\n\n退出交互模式")
            break
        except Exception as e:
            print(f"\n❌ 推理出错: {e}")
            import traceback

            traceback.print_exc()
|
||||
|
||||
def main():
    """CLI entry point: parse arguments, load models, optionally run a
    canned test, then drop into the interactive loop."""
    parser = argparse.ArgumentParser(description="ONNX输入法模型推理")
    parser.add_argument(
        "--context-encoder",
        "-c",
        type=str,
        required=True,
        help="上下文编码器ONNX模型路径",
    )
    parser.add_argument(
        "--decoder",
        "-d",
        type=str,
        required=True,
        help="解码器ONNX模型路径",
    )
    parser.add_argument(
        "--vocab-size",
        type=int,
        default=10019,
        help="词汇表大小 (默认: 10019)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="推理设备 (默认: cpu)",
    )
    # FIX: the original used action="store_true" with default=True, so the
    # flag could never be turned off. BooleanOptionalAction keeps the same
    # default and --interactive behavior while adding --no-interactive.
    parser.add_argument(
        "--interactive",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="交互模式 (默认: True)",
    )
    parser.add_argument("--test", action="store_true", help="运行测试推理")
    parser.add_argument(
        "--beam",
        action="store_true",
        help="使用束搜索解码 (默认: 单步解码)",
    )
    parser.add_argument(
        "--beam-size",
        type=int,
        default=5,
        help="束大小 (默认: 5)",
    )

    args = parser.parse_args()

    # Validate model paths before constructing any ONNX sessions.
    if not os.path.exists(args.context_encoder):
        print(f"❌ 错误: 上下文编码器文件不存在: {args.context_encoder}")
        sys.exit(1)
    if not os.path.exists(args.decoder):
        print(f"❌ 错误: 解码器文件不存在: {args.decoder}")
        sys.exit(1)

    inference = ONNXInference(
        context_encoder_path=args.context_encoder,
        decoder_path=args.decoder,
        vocab_size=args.vocab_size,
        device=args.device,
        use_beam_search=args.beam,
        beam_size=args.beam_size,
    )

    if args.test:
        # Canned scenario mirroring real usage: after confirming "上" for
        # "shanghai", the user continues typing "tian".
        print("\n🧪 运行测试推理...")
        print("测试场景: 输入'shanghai',已确认第一个字'上',继续输入'tian'")
        print("上下文提示: 张三、李四(模型不掌握的专有名词)")

        predictions, timing_info = inference.predict(
            context_prompts=["张三", "李四"],
            text_before="今天天气",
            text_after="很好",
            pinyin="tian",
            slot_chars=["上"],
            use_beam_search=args.beam,
            beam_size=args.beam_size,
        )

        print(f"\n⏱️ 执行时间统计:")
        print("-" * 40)
        for stage, duration in timing_info.items():
            if duration > 0:
                print(f" {stage:<12}: {duration:>8.2f} ms")
        print("-" * 40)

        print(f"\nTop-5 结果:")
        for i, (char, prob, idx) in enumerate(predictions[:5]):
            if char == "//":
                print(f" {i + 1}. // (结束符) - 概率: {prob:.4f}")
            else:
                print(f" {i + 1}. {char} (ID: {idx}) - 概率: {prob:.4f}")

    if args.interactive:
        inference.interactive_mode()


if __name__ == "__main__":
    main()
||||
|
|
@ -14,7 +14,6 @@ dependencies = [
|
|||
"onnxruntime>=1.24.2",
|
||||
"pandas>=3.0.0",
|
||||
"plotly>=5.0.0",
|
||||
"jieba>=0.42.1",
|
||||
"pypinyin>=0.55.0",
|
||||
"requests>=2.32.5",
|
||||
"rich>=14.3.1",
|
||||
|
|
@ -24,7 +23,6 @@ dependencies = [
|
|||
"transformers==5.1.0",
|
||||
"typer>=0.21.1",
|
||||
"waitress>=3.0.2",
|
||||
"onnx>=1.21.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
|
|
|||
|
|
@ -1,155 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
数据分布查看脚本
|
||||
主要关注 labels 为:0-10, 200-210, 2900-2910, 5100-5110, 9520-9530 的分布情况
|
||||
|
||||
使用方法:
|
||||
uv run python src/analyze_data.py --data_path ./data/train.txt
|
||||
|
||||
注意:
|
||||
data_path 可以是:
|
||||
- 本地文本文件 (每行一个文本)
|
||||
- HuggingFace 数据集路径
|
||||
- 目录路径 (该目录下需要有 dataset_info.json 或类似配置文件)
|
||||
"""
|
||||
|
||||
import os
|
||||
import random
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from model.dataset import PinyinInputDataset
|
||||
|
||||
|
||||
def analyze_label_distribution(dataset: PinyinInputDataset, sample_size: int = 10000):
    """Count label frequencies inside a few hand-picked id ranges and dump
    a random selection of raw samples.

    Args:
        dataset: the pinyin input dataset to sample from.
        sample_size: maximum number of samples to draw.
    """
    target_ranges = [
        (0, 10),
        (200, 210),
        (2900, 2910),
        (5100, 5110),
        (9520, 9530),
    ]

    id_counts = defaultdict(int)
    all_examples = []

    dataloader = DataLoader(dataset, batch_size=1, num_workers=16)

    # FIX: the original kept two counters (`sample_collected` and
    # `total_count`) that were always incremented together; merged into one.
    # The dead `in_target_range` flag was removed as well.
    total_count = 0

    for batch in dataloader:
        if total_count >= sample_size:
            break

        label = batch["label"].item()
        prefix = batch["prefix"][0]
        suffix = batch["suffix"][0]
        pinyin = batch["pinyin"][0]
        history = batch["history_slot_ids"][0].tolist()
        # part4 is the context-prompt section encoded before '^' in the prefix.
        part4 = prefix.split("^")[0] if "^" in prefix else ""

        total_count += 1

        # Tally labels falling into one of the target ranges (ranges are disjoint).
        for start, end in target_ranges:
            if start <= label <= end:
                id_counts[label] += 1

        # Keep up to 200 raw examples for the detail dump below.
        # NOTE(review): reconstructed placement — the original indentation was
        # ambiguous; confirm whether collection was meant to be range-gated.
        if len(all_examples) < 200:
            all_examples.append(
                {
                    "label": label,
                    "prefix": prefix,
                    "suffix": suffix,
                    "pinyin": pinyin,
                    "history": history,
                    "part4": part4,
                }
            )

    print(f"\n{'=' * 60}")
    print(f"采样总数: {total_count}")
    print(f"{'=' * 60}\n")

    print("Label ID 分布统计:")
    print("-" * 50)
    for start, end in target_ranges:
        print(f"\n区间 [{start:5d} - {end:5d}]:")
        for label_id in range(start, end + 1):
            count = id_counts[label_id]
            percentage = (count / total_count * 100) if total_count > 0 else 0
            bar = "█" * min(count, 50)
            print(f" ID {label_id:5d}: {count:6d} ({percentage:.3f}%) {bar}")

    print("\n")
    print("=" * 80)
    print("随机抽取 20 个样本详情:")
    print("=" * 80)

    random.shuffle(all_examples)
    for idx, ex in enumerate(all_examples[:20], 1):
        print(f"\n样本 {idx}:")
        print(f" Label: {ex['label']}")
        print(f" Part4: {ex['part4']}")
        print(f" 光标前: {ex['prefix']}")
        print(f" 光标后: {ex['suffix']}")
        print(f" 拼音: {ex['pinyin']}")
        print(f" 历史槽位: {ex['history']}")
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
|
||||
default_data_path = os.path.expanduser("~/Data/corpus/CCI-Data/")
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="分析数据集 label 分布",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog=f"""
|
||||
默认数据集路径: {default_data_path}
|
||||
|
||||
示例:
|
||||
uv run python src/analyze_data.py --data_path {default_data_path} --sample_size 10000
|
||||
uv run python src/analyze_data.py --data_path ./data/eval.txt --sample_size 5000
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_path",
|
||||
type=str,
|
||||
default=default_data_path,
|
||||
help="数据集路径 (本地文件或HuggingFace路径)",
|
||||
)
|
||||
parser.add_argument("--sample_size", type=int, default=10000, help="采样大小")
|
||||
parser.add_argument("--max_workers", type=int, default=-1, help="DataLoader workers")
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"加载数据集: {args.data_path}")
|
||||
print()
|
||||
|
||||
try:
|
||||
dataset = PinyinInputDataset(
|
||||
data_path=args.data_path,
|
||||
max_workers=args.max_workers,
|
||||
max_iter_length=args.sample_size,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"\n错误: 无法加载数据集 '{args.data_path}'")
|
||||
print(f"原因: {e}")
|
||||
print("\n请确保:")
|
||||
print(" 1. 数据路径存在且正确")
|
||||
print(" 2. 如果是本地文本文件,每行应该是一个 JSON 对象或纯文本")
|
||||
print(" 3. 如果是 HuggingFace 数据集,路径应该正确")
|
||||
return
|
||||
|
||||
analyze_label_distribution(dataset, sample_size=args.sample_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -8,7 +8,7 @@
|
|||
"id": 0,
|
||||
"char": "",
|
||||
"pinyin": "",
|
||||
"count": 494748360
|
||||
"count": 11067734826
|
||||
},
|
||||
"1": {
|
||||
"id": 1,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import jieba
|
||||
import random
|
||||
import re
|
||||
from importlib.resources import files
|
||||
|
|
@ -24,26 +23,6 @@ CHAR_TO_ID["'"] = 28 # 显式添加引号
|
|||
CHAR_TO_ID["-"] = 29 # 显式添加短横
|
||||
|
||||
|
||||
jieba.setLogLevel(jieba.logging.INFO)
|
||||
|
||||
|
||||
def segment_text(text: str) -> List[str]:
|
||||
"""使用 jieba 分词,返回词列表"""
|
||||
return list(jieba.cut(text, HMM=False))
|
||||
|
||||
|
||||
def build_word_boundaries(words: List[str]) -> List[Tuple[int, int]]:
|
||||
"""建立词边界列表 [(start, end), ...],基于顺序位置累加"""
|
||||
result = []
|
||||
pos = 0
|
||||
for word in words:
|
||||
start = pos
|
||||
end = pos + len(word)
|
||||
result.append((start, end))
|
||||
pos = end
|
||||
return result
|
||||
|
||||
|
||||
def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
|
||||
"""
|
||||
将拼音字符串转换为 ID 列表。
|
||||
|
|
@ -62,22 +41,17 @@ class PinyinInputDataset(IterableDataset):
|
|||
max_iter_length=1e6,
|
||||
max_seq_length=128,
|
||||
text_field: str = "text",
|
||||
py_style_weight=(90, 2, 1),
|
||||
py_style_weight=(9, 2, 1),
|
||||
shuffle_buffer_size: int = 100000,
|
||||
retention_ratio: float = 0.8,
|
||||
retention_ratio: float = 0.5,
|
||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||
):
|
||||
# 频率调整参数 (可根据需要调整)
|
||||
self.drop_start_freq = 10_000_000
|
||||
self.max_drop_prob = 0.9
|
||||
self.drop_start_freq = 30_000_000
|
||||
self.max_drop_prob = 0.8
|
||||
self.repeat_end_freq = 10_000
|
||||
self.max_repeat_expect = 50
|
||||
self.min_freq = 109
|
||||
self.word_break_prob = 0.10
|
||||
self.cont_length_probs = [0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04]
|
||||
self._history_weights = [0.2, 0.2, 0.2, 0.9, 1.2, 1.8, 2.5, 3.5, 4.0]
|
||||
|
||||
jieba.initialize()
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
Path(str(files(__package__))) / "assets" / "tokenizer"
|
||||
|
|
@ -109,15 +83,14 @@ class PinyinInputDataset(IterableDataset):
|
|||
|
||||
# 提取每个样本的目标字符及其频率
|
||||
self.sample_freqs = self.query_engine.get_all_weights()
|
||||
self.max_freq = max(self.sample_freqs.values()) if self.sample_freqs else 0
|
||||
|
||||
def adjust_frequency(self, freq: int) -> int:
|
||||
"""削峰填谷 - 根据频率调整采样次数,0表示丢弃"""
|
||||
# 1. 削峰处理(高频字)
|
||||
if freq >= self.drop_start_freq:
|
||||
# 线性丢弃概率计算
|
||||
max_freq = self.max_freq # 使用预计算的最大频率值
|
||||
if max_freq <= self.drop_start_freq:
|
||||
max_freq = max(self.sample_freqs) # 或使用预定义的全局最大值
|
||||
if max_freq == self.drop_start_freq:
|
||||
drop_prob = 0.0
|
||||
else:
|
||||
drop_prob = (
|
||||
|
|
@ -146,11 +119,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
)
|
||||
# 使用泊松分布实现随机重复
|
||||
repeat_count = np.random.poisson(repeat_expect)
|
||||
if repeat_expect < 1.0:
|
||||
# 小期望值时,以概率 repeat_expect 采样 1 次
|
||||
return 1 if random.random() < repeat_expect else 0
|
||||
else:
|
||||
return max(1, repeat_count) # 原逻辑
|
||||
return max(1, repeat_count)
|
||||
|
||||
# 3. 中间频率字
|
||||
else:
|
||||
|
|
@ -227,55 +196,6 @@ class PinyinInputDataset(IterableDataset):
|
|||
mask_pinyin.append(py)
|
||||
return len(mask_pinyin), mask_pinyin
|
||||
|
||||
def _compute_pinyin_ids(self, pinyin_str: str) -> torch.Tensor:
|
||||
pinyin_ids = text_to_pinyin_ids(pinyin_str)
|
||||
len_py = len(pinyin_ids)
|
||||
if len_py < 24:
|
||||
pinyin_ids.extend([0] * (24 - len_py))
|
||||
else:
|
||||
pinyin_ids = pinyin_ids[:24]
|
||||
return torch.tensor(pinyin_ids, dtype=torch.long)
|
||||
|
||||
def _add_word_samples(
|
||||
self,
|
||||
batch_samples: list,
|
||||
labels: list,
|
||||
encoded: dict,
|
||||
part4: str,
|
||||
part1: str,
|
||||
part3: str,
|
||||
pinyin_str: str,
|
||||
pinyin_ids: torch.Tensor,
|
||||
) -> list:
|
||||
for label_idx, label in enumerate(labels):
|
||||
base_repeats = self.adjust_frequency(self.sample_freqs.get(label, 0))
|
||||
if base_repeats == 0:
|
||||
continue
|
||||
weight = (
|
||||
self._history_weights[label_idx]
|
||||
if label_idx < len(self._history_weights)
|
||||
else 3.0
|
||||
)
|
||||
repeats = max(1, int(base_repeats * weight))
|
||||
|
||||
history = labels[:label_idx]
|
||||
len_h = len(history)
|
||||
history.extend([0] * (8 - len_h))
|
||||
|
||||
sample_dict = {
|
||||
"input_ids": encoded["input_ids"],
|
||||
"token_type_ids": encoded["token_type_ids"],
|
||||
"attention_mask": encoded["attention_mask"],
|
||||
"label": torch.tensor([label], dtype=torch.long),
|
||||
"history_slot_ids": torch.tensor(history, dtype=torch.long),
|
||||
"prefix": f"{part4}^{part1}",
|
||||
"suffix": part3,
|
||||
"pinyin": pinyin_str,
|
||||
"pinyin_ids": pinyin_ids,
|
||||
}
|
||||
batch_samples.extend([sample_dict] * repeats)
|
||||
return batch_samples
|
||||
|
||||
def __iter__(self):
|
||||
worker_info = torch.utils.data.get_worker_info()
|
||||
if worker_info is not None:
|
||||
|
|
@ -288,106 +208,146 @@ class PinyinInputDataset(IterableDataset):
|
|||
random.seed(seed % (2**32))
|
||||
np.random.seed(seed % (2**32))
|
||||
|
||||
# 安全检查:如果worker_id >= num_workers,则该worker不应该工作
|
||||
# 这可能发生在self.max_workers小于实际worker数量时
|
||||
if worker_id >= num_workers:
|
||||
return
|
||||
return # 产生空迭代器
|
||||
|
||||
# 使用局部变量存储分片数据集,避免竞争条件
|
||||
worker_dataset = self.dataset.shard(num_shards=num_workers, index=worker_id)
|
||||
|
||||
# 计算每个worker的配额
|
||||
# 将 max_iter_length 转换为整数以确保整数除法
|
||||
total_quota = int(self.max_iter_length)
|
||||
base_quota = total_quota // num_workers
|
||||
remainder = total_quota % num_workers
|
||||
|
||||
# 最后一个worker处理剩余的样本(如果有余数)
|
||||
if worker_id == num_workers - 1:
|
||||
worker_quota = base_quota + remainder
|
||||
else:
|
||||
worker_quota = base_quota
|
||||
else:
|
||||
# 单worker情况,使用全部配额
|
||||
worker_quota = int(self.max_iter_length)
|
||||
num_workers = 1
|
||||
worker_dataset = self.dataset
|
||||
worker_dataset = self.dataset # 不使用分片
|
||||
|
||||
# 每个worker有自己的迭代计数器
|
||||
current_iter_index = 0
|
||||
|
||||
batch_samples = []
|
||||
for sample in worker_dataset:
|
||||
# 检查是否达到最大迭代次数
|
||||
if current_iter_index >= worker_quota:
|
||||
break
|
||||
|
||||
text = sample.get(self.text_field, "")
|
||||
if not text:
|
||||
continue
|
||||
|
||||
words = segment_text(text)
|
||||
word_boundaries = build_word_boundaries(words)
|
||||
if text:
|
||||
pinyin_list = self.generate_pinyin(text)
|
||||
for i in range(len(text)):
|
||||
# 在开始处理每个字符前检查配额
|
||||
if current_iter_index >= worker_quota:
|
||||
break
|
||||
|
||||
for word_start, word_end in word_boundaries:
|
||||
char_positions = []
|
||||
for i in range(word_start, word_end):
|
||||
if self.query_engine.is_chinese_char(text[i]):
|
||||
char_positions.append(i)
|
||||
labels = []
|
||||
# 如果text[i]不在字符库中,则跳过
|
||||
# 当i小于48时候,则将part1取text[0:i]
|
||||
# 当i大于48时候,则将part1取text[i-48:i]
|
||||
if not self.query_engine.is_chinese_char(text[i]):
|
||||
continue
|
||||
if i < 48:
|
||||
part1 = text[0:i]
|
||||
else:
|
||||
part1 = text[i - 48 : i]
|
||||
|
||||
if not char_positions:
|
||||
# 方案C:提前检查从位置i开始连续有多少个字符在词库中
|
||||
max_valid_len = 0
|
||||
for j in range(i, min(i + 8, len(text))):
|
||||
if self.query_engine.is_chinese_char(text[j]):
|
||||
max_valid_len += 1
|
||||
else:
|
||||
break
|
||||
|
||||
# 如果没有可用字符,跳过
|
||||
if max_valid_len == 0:
|
||||
continue
|
||||
|
||||
word_len_chars = len(char_positions)
|
||||
# 首先取随机值pinyin_len(1-8),pinyin_len取值呈高斯分布,最大概率取3
|
||||
# 获取text[i + pinyin_len]字符,如果无法获取所指向的后,如果pinyin_len
|
||||
# part2的长度为x,取pinyin_list[i:i+pinyin_len],为part2
|
||||
# 但是需要注意边界条件
|
||||
target_len = np.random.choice(
|
||||
range(1, 9), p=[0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04]
|
||||
)
|
||||
# 根据实际可用长度调整
|
||||
pinyin_len = min(target_len, max_valid_len)
|
||||
|
||||
should_break = (
|
||||
word_len_chars > 1 and random.random() < self.word_break_prob
|
||||
py_end = min(i + pinyin_len, len(text))
|
||||
pinyin_len, part2 = self.get_mask_pinyin(
|
||||
text[i:py_end], pinyin_list[i:py_end]
|
||||
)
|
||||
|
||||
if should_break:
|
||||
break_pos = random.randint(1, word_len_chars - 1)
|
||||
else:
|
||||
break_pos = word_len_chars
|
||||
|
||||
# ========== Phase 1: 前缀/整词 ==========
|
||||
prefix_positions = char_positions[:break_pos]
|
||||
prefix_text = "".join(text[i] for i in prefix_positions)
|
||||
prefix_pinyin = [pinyin_list[i] for i in prefix_positions]
|
||||
|
||||
_, mask_pinyin = self.get_mask_pinyin(prefix_text, prefix_pinyin)
|
||||
split_char = np.random.choice(
|
||||
["", "`", "'", "-"], p=[0.9, 0.04, 0.04, 0.02]
|
||||
)
|
||||
part2 = split_char.join(mask_pinyin)
|
||||
pinyin_ids = self._compute_pinyin_ids(part2)
|
||||
|
||||
part2 = split_char.join(part2)
|
||||
pinyin_ids = text_to_pinyin_ids(part2)
|
||||
len_py = len(pinyin_ids)
|
||||
if len_py < 24:
|
||||
pinyin_ids.extend([0] * (24 - len_py))
|
||||
else:
|
||||
pinyin_ids = pinyin_ids[:24]
|
||||
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long)
|
||||
|
||||
# part3为文本,大概率(0.70)为空
|
||||
# 不为空则是i+pinyin_len所指向的字符以及所指向字符后x个字符
|
||||
# x为1-16中的任意整数,取值平均分布
|
||||
part3 = ""
|
||||
if random.random() > 0.7:
|
||||
part3 = text[
|
||||
i + pinyin_len : i
|
||||
+ pinyin_len
|
||||
+ np.random.choice(range(1, 17))
|
||||
]
|
||||
|
||||
# part4为文本,0.30的概率为空
|
||||
# 不为空则为1-5个连续字符串
|
||||
# 连续字符串的取值方法为:随机从字符库中取一个字符,以及该字符后x个字符
|
||||
# x为2-6中的任意整数,取值平均分布
|
||||
# 使用|将part4中的字符串连接起来
|
||||
part4 = ""
|
||||
if random.random() > 0.7:
|
||||
# 生成1-5个连续字符串
|
||||
num_strings = random.randint(1, 5)
|
||||
string_list = []
|
||||
for _ in range(num_strings):
|
||||
# 随机选择起始位置
|
||||
start_pos = random.randint(0, len(text) - 1)
|
||||
# 随机选择x的值(2-6)
|
||||
x = random.randint(2, 6)
|
||||
# 获取连续字符串
|
||||
end_pos = min(start_pos + x + 1, len(text))
|
||||
string_list.append(text[start_pos:end_pos])
|
||||
# 用|连接所有字符串
|
||||
part4 = "|".join(string_list)
|
||||
try:
|
||||
labels = [
|
||||
self.query_engine.get_char_info_by_char_pinyin(
|
||||
text[i], pinyin_list[i]
|
||||
).id
|
||||
for i in prefix_positions
|
||||
self.query_engine.get_char_info_by_char_pinyin(c, p).id
|
||||
for c, p in zip(
|
||||
text[i : i + pinyin_len],
|
||||
pinyin_list[i : i + pinyin_len],
|
||||
)
|
||||
]
|
||||
except AttributeError as e:
|
||||
logger.error(
|
||||
f"e: {e}, (text, pinyin): {prefix_text} - {prefix_pinyin}"
|
||||
f"e: {e}, (text, pinyin): {text[i : i + pinyin_len]} - {pinyin_list[i : i + pinyin_len]}"
|
||||
)
|
||||
continue
|
||||
|
||||
# 整词末尾 10% 概率追加 EOS(破词前缀不加)
|
||||
if not should_break and random.random() <= 0.1:
|
||||
if random.random() <= 0.1:
|
||||
labels.append(0)
|
||||
|
||||
# part1: 词起点前的文本(所有样本共享)
|
||||
part1 = text[max(0, word_start - 48) : word_start]
|
||||
|
||||
# part3: 词后文本
|
||||
part3 = ""
|
||||
if random.random() > 0.7:
|
||||
part3 = text[word_end : word_end + np.random.choice(range(1, 17))]
|
||||
|
||||
# part4: 词提示
|
||||
part4 = ""
|
||||
if random.random() > 0.7:
|
||||
num_words = random.randint(1, 3)
|
||||
if words:
|
||||
selected_words = random.sample(
|
||||
words, min(num_words, len(words))
|
||||
)
|
||||
part4 = "|".join(selected_words)
|
||||
|
||||
encoded = self.tokenizer(
|
||||
f"{part4}|{part1}",
|
||||
part3,
|
||||
|
|
@ -397,110 +357,70 @@ class PinyinInputDataset(IterableDataset):
|
|||
return_tensors="pt",
|
||||
return_token_type_ids=True,
|
||||
)
|
||||
samples = []
|
||||
# 历史槽位长度权重:增加长历史采样比例
|
||||
# 目标分布: H=0-2占45%, H=3-8占55%
|
||||
history_weights = [0.2, 0.2, 0.2, 0.9, 1.2, 1.8, 2.5, 3.5, 4.0]
|
||||
|
||||
batch_samples = self._add_word_samples(
|
||||
batch_samples,
|
||||
labels,
|
||||
encoded,
|
||||
part4,
|
||||
part1,
|
||||
part3,
|
||||
part2,
|
||||
pinyin_ids,
|
||||
# 修复变量名冲突:将内层循环变量i重命名为label_idx
|
||||
for label_idx, label in enumerate(labels):
|
||||
base_repeats = self.adjust_frequency(label)
|
||||
# 根据历史槽位长度调整采样次数
|
||||
weight = (
|
||||
history_weights[label_idx]
|
||||
if label_idx < len(history_weights)
|
||||
else 3.0
|
||||
)
|
||||
repeats = max(1, int(base_repeats * weight))
|
||||
|
||||
# ========== Phase 2: 破词续接 ==========
|
||||
if should_break and break_pos < word_len_chars:
|
||||
cont_start = char_positions[break_pos]
|
||||
# 历史槽位:同一拼音序列中已确认的字符(模拟用户逐步确认过程)
|
||||
masked_labels = labels[:label_idx]
|
||||
len_l = len(masked_labels)
|
||||
masked_labels.extend([0] * (8 - len_l))
|
||||
|
||||
# 续接目标:从断点开始,可延伸到后续词,遇到非汉字停止
|
||||
target_len = np.random.choice(range(1, 9), p=self.cont_length_probs)
|
||||
cont_positions = []
|
||||
pos = cont_start
|
||||
while len(cont_positions) < target_len and pos < len(text):
|
||||
if self.query_engine.is_chinese_char(text[pos]):
|
||||
cont_positions.append(pos)
|
||||
else:
|
||||
break
|
||||
pos += 1
|
||||
|
||||
if not cont_positions:
|
||||
continue
|
||||
|
||||
cont_text = "".join(text[i] for i in cont_positions)
|
||||
cont_pinyin = [pinyin_list[i] for i in cont_positions]
|
||||
|
||||
_, mask_pinyin_cont = self.get_mask_pinyin(cont_text, cont_pinyin)
|
||||
split_char_cont = np.random.choice(
|
||||
["", "`", "'", "-"], p=[0.9, 0.04, 0.04, 0.02]
|
||||
)
|
||||
part2_cont = split_char_cont.join(mask_pinyin_cont)
|
||||
pinyin_ids_cont = self._compute_pinyin_ids(part2_cont)
|
||||
|
||||
try:
|
||||
cont_labels = [
|
||||
self.query_engine.get_char_info_by_char_pinyin(
|
||||
text[i], pinyin_list[i]
|
||||
).id
|
||||
for i in cont_positions
|
||||
samples.extend(
|
||||
[
|
||||
{
|
||||
"input_ids": encoded["input_ids"],
|
||||
"token_type_ids": encoded["token_type_ids"],
|
||||
"attention_mask": encoded["attention_mask"],
|
||||
"label": torch.tensor([label], dtype=torch.long),
|
||||
"history_slot_ids": torch.tensor(
|
||||
masked_labels, dtype=torch.long
|
||||
),
|
||||
"prefix": f"{part4}^{part1}",
|
||||
"suffix": part3,
|
||||
"pinyin": part2,
|
||||
"pinyin_ids": pinyin_ids,
|
||||
}
|
||||
]
|
||||
except AttributeError as e:
|
||||
logger.error(
|
||||
f"e: {e}, (text, pinyin): {cont_text} - {cont_pinyin}"
|
||||
)
|
||||
continue
|
||||
|
||||
# 续接末尾 10% 概率追加 EOS
|
||||
if random.random() <= 0.1:
|
||||
cont_labels.append(0)
|
||||
|
||||
# part1_cont: 包含已确认前缀的上下文
|
||||
part1_cont = text[max(0, cont_start - 48) : cont_start]
|
||||
|
||||
# part3_cont: 续接目标后的文本
|
||||
cont_end = cont_positions[-1] + 1
|
||||
part3_cont = ""
|
||||
if random.random() > 0.7:
|
||||
part3_cont = text[
|
||||
cont_end : cont_end + np.random.choice(range(1, 17))
|
||||
]
|
||||
|
||||
encoded_cont = self.tokenizer(
|
||||
f"{part4}|{part1_cont}",
|
||||
part3_cont,
|
||||
max_length=self.max_seq_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
return_token_type_ids=True,
|
||||
* repeats
|
||||
)
|
||||
|
||||
batch_samples = self._add_word_samples(
|
||||
batch_samples,
|
||||
cont_labels,
|
||||
encoded_cont,
|
||||
part4,
|
||||
part1_cont,
|
||||
part3_cont,
|
||||
part2_cont,
|
||||
pinyin_ids_cont,
|
||||
)
|
||||
# 添加到缓冲区
|
||||
batch_samples.extend(samples)
|
||||
|
||||
# 处理shuffle buffer - 单缓冲区半保留方案
|
||||
if len(batch_samples) >= self.shuffle_buffer_size:
|
||||
# 全量打乱缓冲区
|
||||
indices = np.random.permutation(len(batch_samples))
|
||||
|
||||
# 计算实际保留大小(不超过缓冲区大小)
|
||||
actual_retention = min(self.retention_size, len(batch_samples))
|
||||
|
||||
# 计算输出数量
|
||||
output_count = len(batch_samples) - actual_retention
|
||||
|
||||
# 输出前output_count个样本
|
||||
for i in range(output_count):
|
||||
if current_iter_index >= worker_quota:
|
||||
# 配额用完,清空缓冲区并返回
|
||||
batch_samples = []
|
||||
return
|
||||
yield batch_samples[indices[i]]
|
||||
current_iter_index += 1
|
||||
|
||||
# 保留后actual_retention个样本(不清空缓冲区)
|
||||
retained_samples = [
|
||||
batch_samples[idx] for idx in indices[output_count:]
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,169 +0,0 @@
|
|||
"""
|
||||
ONNX导出专用组件
|
||||
|
||||
为了支持ONNX导出,对原始组件进行修改:
|
||||
1. 移除packed sequence操作
|
||||
2. 处理动态形状问题
|
||||
3. 确保所有操作符都ONNX兼容
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ExportPinyinLSTMEncoder(nn.Module):
|
||||
"""
|
||||
ONNX兼容的拼音LSTM编码器
|
||||
简化版本,不使用packed sequence
|
||||
"""
|
||||
|
||||
def __init__(self, input_dim, hidden_dim=None, num_layers=2, dropout=0.2):
|
||||
super().__init__()
|
||||
self.input_dim = input_dim
|
||||
self.hidden_dim = hidden_dim if hidden_dim is not None else input_dim // 2
|
||||
self.num_layers = num_layers
|
||||
self.dropout = dropout
|
||||
|
||||
self.lstm = nn.LSTM(
|
||||
input_size=input_dim,
|
||||
hidden_size=self.hidden_dim,
|
||||
num_layers=num_layers,
|
||||
bidirectional=True,
|
||||
batch_first=True,
|
||||
dropout=dropout if num_layers > 1 else 0.0,
|
||||
)
|
||||
|
||||
self.proj = nn.Linear(self.hidden_dim * 2, input_dim)
|
||||
self.layer_norm = nn.LayerNorm(input_dim)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
"""
|
||||
ONNX兼容的前向传播
|
||||
不使用packed sequence,改为使用masked计算
|
||||
"""
|
||||
# 简化:直接使用LSTM,不处理padding
|
||||
# 在ONNX中,处理可变长度序列较复杂
|
||||
# 对于输入法场景,拼音长度固定为24,所以可以这样处理
|
||||
output, (hidden, cell) = self.lstm(x)
|
||||
projected = self.proj(output)
|
||||
return self.layer_norm(projected)
|
||||
|
||||
|
||||
class ExportContextEncoder(nn.Module):
|
||||
"""
|
||||
ONNX兼容的上下文编码器
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, vocab_size, pinyin_vocab_size, dim=512, n_layers=4, n_heads=4, max_len=128
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.max_len = max_len
|
||||
|
||||
# 使用原始text_emb,但需要确保它支持ONNX
|
||||
from modelscope import AutoModel
|
||||
|
||||
self.text_emb = AutoModel.from_pretrained(
|
||||
"iic/nlp_structbert_backbone_lite_std"
|
||||
).embeddings
|
||||
|
||||
self.pinyin_emb = nn.Embedding(pinyin_vocab_size, dim)
|
||||
self.pos_emb = nn.Embedding(max_len, dim)
|
||||
self.pinyin_pooling = ExportPinyinLSTMEncoder(dim)
|
||||
|
||||
# Transformer Encoder
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=dim,
|
||||
nhead=n_heads,
|
||||
dim_feedforward=dim * 4,
|
||||
dropout=0.1,
|
||||
batch_first=True,
|
||||
)
|
||||
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
|
||||
self.ln = nn.LayerNorm(dim)
|
||||
|
||||
def forward(self, text_ids, pinyin_ids, mask=None):
|
||||
"""
|
||||
ONNX兼容的前向传播
|
||||
"""
|
||||
# 文本嵌入
|
||||
text_emb = self.text_emb(text_ids) # [B, seq_len, dim]
|
||||
|
||||
# 位置编码 - 使用预计算的pos_ids
|
||||
seq_len = text_emb.size(1)
|
||||
# 使用静态最大长度,避免动态arange
|
||||
if seq_len > self.max_len:
|
||||
seq_len = self.max_len
|
||||
|
||||
# 创建位置ID,确保在导出时是静态的
|
||||
# 使用torch.full而不是torch.arange,因为arange在动态形状下可能有问题
|
||||
pos_ids = torch.arange(
|
||||
seq_len, device=text_ids.device, dtype=torch.long
|
||||
).unsqueeze(0)
|
||||
|
||||
# 如果实际序列长度小于最大长度,截取位置嵌入
|
||||
if seq_len < self.max_len:
|
||||
x = text_emb[:, :seq_len, :] + self.pos_emb(pos_ids)
|
||||
else:
|
||||
x = text_emb + self.pos_emb(pos_ids)
|
||||
|
||||
# 处理mask
|
||||
if mask is not None:
|
||||
# 确保mask是bool类型
|
||||
src_mask = (mask == 0).to(torch.bool)
|
||||
else:
|
||||
src_mask = None
|
||||
|
||||
# Transformer
|
||||
H = self.transformer(x, src_key_padding_mask=src_mask)
|
||||
H = self.ln(H)
|
||||
|
||||
# 拼音编码
|
||||
pinyin_emb = self.pinyin_emb(pinyin_ids) # [B, pinyin_len, dim]
|
||||
# 简化:不传递mask给LSTM
|
||||
P = self.pinyin_pooling(pinyin_emb) # [B, pinyin_len, dim]
|
||||
|
||||
return H, P
|
||||
|
||||
|
||||
def create_export_context_encoder(original_context_encoder):
|
||||
"""
|
||||
从原始ContextEncoder创建导出版本
|
||||
"""
|
||||
# 获取原始配置
|
||||
config = {
|
||||
"vocab_size": getattr(original_context_encoder, "vocab_size", 10019),
|
||||
"pinyin_vocab_size": original_context_encoder.pinyin_emb.num_embeddings,
|
||||
"dim": original_context_encoder.dim,
|
||||
"n_layers": original_context_encoder.transformer.num_layers,
|
||||
"n_heads": original_context_encoder.transformer.layers[0].self_attn.num_heads,
|
||||
"max_len": original_context_encoder.pos_emb.num_embeddings,
|
||||
}
|
||||
|
||||
# 创建导出版本
|
||||
export_encoder = ExportContextEncoder(**config)
|
||||
|
||||
# 复制权重
|
||||
# 复制text_emb权重(从原始AutoModel embeddings)
|
||||
# 注意:这里假设原始text_emb的结构
|
||||
state_dict = original_context_encoder.state_dict()
|
||||
export_state_dict = export_encoder.state_dict()
|
||||
|
||||
# 复制匹配的权重
|
||||
for key in export_state_dict:
|
||||
if key in state_dict:
|
||||
export_state_dict[key] = state_dict[key]
|
||||
|
||||
# 特殊处理:复制position embeddings
|
||||
if "pos_emb.weight" in export_state_dict and "pos_emb.weight" in state_dict:
|
||||
# 确保大小匹配
|
||||
orig_pos_emb = state_dict["pos_emb.weight"]
|
||||
export_pos_emb = export_state_dict["pos_emb.weight"]
|
||||
min_len = min(orig_pos_emb.size(0), export_pos_emb.size(0))
|
||||
export_state_dict["pos_emb.weight"][:min_len] = orig_pos_emb[:min_len]
|
||||
|
||||
export_encoder.load_state_dict(export_state_dict)
|
||||
|
||||
return export_encoder
|
||||
|
|
@ -1,291 +0,0 @@
|
|||
"""
|
||||
ONNX导出模型定义
|
||||
|
||||
定义两个子模型用于ONNX导出:
|
||||
1. ContextEncoderExport: 输入文本和拼音,输出上下文编码
|
||||
2. DecoderExport: 输入上下文编码、拼音编码和槽位历史,输出logits
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# 修补modelscope以支持AutoModel导入
|
||||
import sys
|
||||
import modelscope
|
||||
|
||||
if not hasattr(modelscope, "AutoModel"):
|
||||
from modelscope import Model
|
||||
|
||||
modelscope.AutoModel = Model
|
||||
sys.modules["modelscope"].AutoModel = Model
|
||||
|
||||
from .components import (
|
||||
ContextEncoder,
|
||||
CrossAttentionFusion,
|
||||
MoELayer,
|
||||
SlotMemory,
|
||||
PinyinLSTMEncoder,
|
||||
)
|
||||
|
||||
|
||||
class ContextEncoderExport(nn.Module):
|
||||
"""
|
||||
上下文编码器导出模型
|
||||
输入: input_ids, pinyin_ids, attention_mask
|
||||
输出: context_H, pinyin_P, context_mask, pinyin_mask
|
||||
"""
|
||||
|
||||
def __init__(self, context_encoder: ContextEncoder):
|
||||
super().__init__()
|
||||
self.context_encoder = context_encoder
|
||||
self.dim = context_encoder.dim
|
||||
|
||||
# 创建ONNX兼容的拼音LSTM编码器(简化版,不使用packed sequence)
|
||||
# 复制原始LSTM的参数
|
||||
original_pooling = context_encoder.pinyin_pooling
|
||||
self.pinyin_lstm = nn.LSTM(
|
||||
input_size=original_pooling.input_dim,
|
||||
hidden_size=original_pooling.hidden_dim,
|
||||
num_layers=original_pooling.num_layers,
|
||||
bidirectional=True,
|
||||
batch_first=True,
|
||||
dropout=original_pooling.dropout
|
||||
if original_pooling.num_layers > 1
|
||||
else 0.0,
|
||||
)
|
||||
self.pinyin_proj = nn.Linear(
|
||||
original_pooling.hidden_dim * 2, original_pooling.input_dim
|
||||
)
|
||||
self.pinyin_ln = nn.LayerNorm(original_pooling.input_dim)
|
||||
|
||||
# 复制权重
|
||||
self.pinyin_lstm.load_state_dict(original_pooling.lstm.state_dict())
|
||||
self.pinyin_proj.load_state_dict(original_pooling.proj.state_dict())
|
||||
self.pinyin_ln.load_state_dict(original_pooling.layer_norm.state_dict())
|
||||
|
||||
def forward(self, input_ids, pinyin_ids, attention_mask):
|
||||
"""
|
||||
Args:
|
||||
input_ids: [batch_size, seq_len]
|
||||
pinyin_ids: [batch_size, pinyin_len] (pinyin_len固定为24)
|
||||
attention_mask: [batch_size, seq_len]
|
||||
|
||||
Returns:
|
||||
context_H: [batch_size, seq_len, dim]
|
||||
pinyin_P: [batch_size, pinyin_len, dim]
|
||||
context_mask: [batch_size, seq_len] (bool -> int32)
|
||||
pinyin_mask: [batch_size, pinyin_len] (bool -> int32)
|
||||
"""
|
||||
# 获取原始context_encoder的组件
|
||||
text_emb = self.context_encoder.text_emb
|
||||
pinyin_emb = self.context_encoder.pinyin_emb
|
||||
pos_emb = self.context_encoder.pos_emb
|
||||
transformer = self.context_encoder.transformer
|
||||
ln = self.context_encoder.ln
|
||||
|
||||
# 文本嵌入
|
||||
text_emb_out = text_emb(input_ids) # [B, seq_len, dim]
|
||||
|
||||
# 位置编码 - 修复torch.arange动态形状问题
|
||||
seq_len = text_emb_out.size(1)
|
||||
# 使用预计算的位置ID,确保静态形状
|
||||
if seq_len > pos_emb.num_embeddings:
|
||||
seq_len = pos_emb.num_embeddings
|
||||
|
||||
# 创建位置ID - 使用torch.arange但确保是整数张量
|
||||
pos_ids = torch.arange(
|
||||
seq_len, device=input_ids.device, dtype=torch.long
|
||||
).unsqueeze(0)
|
||||
|
||||
# 如果实际序列长度小于最大长度,截取
|
||||
if seq_len < text_emb_out.size(1):
|
||||
x = text_emb_out[:, :seq_len, :] + pos_emb(pos_ids)
|
||||
else:
|
||||
x = text_emb_out + pos_emb(pos_ids)
|
||||
|
||||
# 处理mask - 确保bool类型
|
||||
if attention_mask is not None:
|
||||
src_mask = (attention_mask == 0).to(torch.bool)
|
||||
else:
|
||||
src_mask = None
|
||||
|
||||
# Transformer
|
||||
H = transformer(x, src_key_padding_mask=src_mask)
|
||||
H = ln(H)
|
||||
|
||||
# 恢复原始序列长度(如果需要)
|
||||
if seq_len < text_emb_out.size(1):
|
||||
# 填充H到原始长度
|
||||
original_seq_len = text_emb_out.size(1)
|
||||
padding = torch.zeros(
|
||||
H.size(0),
|
||||
original_seq_len - seq_len,
|
||||
H.size(2),
|
||||
device=H.device,
|
||||
dtype=H.dtype,
|
||||
)
|
||||
H = torch.cat([H, padding], dim=1)
|
||||
|
||||
# 拼音编码 - 使用ONNX兼容版本
|
||||
pinyin_emb_out = pinyin_emb(pinyin_ids) # [B, pinyin_len, dim]
|
||||
# 简化的LSTM,不使用packed sequence
|
||||
pinyin_lstm_out, _ = self.pinyin_lstm(pinyin_emb_out)
|
||||
pinyin_proj_out = self.pinyin_proj(pinyin_lstm_out)
|
||||
P = self.pinyin_ln(pinyin_proj_out)
|
||||
|
||||
# 生成mask(转换为int32以便ONNX支持)
|
||||
context_mask = (attention_mask == 0).to(torch.int32) # 1表示padding,0表示有效
|
||||
pinyin_mask = (pinyin_ids == 0).to(torch.int32) # 1表示padding,0表示有效
|
||||
|
||||
return H, P, context_mask, pinyin_mask
|
||||
|
||||
|
||||
class DecoderExport(nn.Module):
|
||||
"""
|
||||
解码器导出模型
|
||||
输入: context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask
|
||||
输出: logits
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
slot_memory: SlotMemory,
|
||||
cross_attn: CrossAttentionFusion,
|
||||
moe: MoELayer,
|
||||
slot_attention: nn.Module,
|
||||
classifier: nn.Module,
|
||||
num_slots: int = 8,
|
||||
dim: int = 512,
|
||||
):
|
||||
super().__init__()
|
||||
self.slot_memory = slot_memory
|
||||
self.cross_attn = cross_attn
|
||||
self.moe = moe
|
||||
self.slot_attention = slot_attention
|
||||
self.classifier = classifier
|
||||
self.num_slots = num_slots
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
|
||||
"""
|
||||
Args:
|
||||
context_H: [batch_size, seq_len, dim]
|
||||
pinyin_P: [batch_size, pinyin_len, dim]
|
||||
history_slot_ids: [batch_size, num_slots]
|
||||
context_mask: [batch_size, seq_len] (int32, 1表示padding)
|
||||
pinyin_mask: [batch_size, pinyin_len] (int32, 1表示padding)
|
||||
|
||||
Returns:
|
||||
logits: [batch_size, vocab_size]
|
||||
"""
|
||||
batch_size = context_H.size(0)
|
||||
|
||||
# 确保history_slot_ids形状正确
|
||||
history_slot_ids = history_slot_ids.view(batch_size, self.num_slots)
|
||||
|
||||
# 1. 槽位记忆
|
||||
S = self.slot_memory(history_slot_ids) # [batch_size, num_slots, dim]
|
||||
|
||||
# 2. 交叉注意力融合
|
||||
# 转换mask: int32 -> bool (ONNX中需要bool类型)
|
||||
context_mask_bool = context_mask.to(torch.bool)
|
||||
pinyin_mask_bool = pinyin_mask.to(torch.bool)
|
||||
|
||||
fused = self.cross_attn(
|
||||
S,
|
||||
context_H,
|
||||
pinyin_P,
|
||||
context_mask=context_mask_bool,
|
||||
pinyin_mask=pinyin_mask_bool,
|
||||
) # [batch_size, num_slots, dim]
|
||||
|
||||
# 3. MoE层
|
||||
moe_out = self.moe(fused) # [batch_size, num_slots, dim]
|
||||
|
||||
# 4. 槽位注意力池化
|
||||
slot_scores = self.slot_attention(moe_out).squeeze(
|
||||
-1
|
||||
) # [batch_size, num_slots]
|
||||
slot_weights = torch.softmax(slot_scores, dim=1) # [batch_size, num_slots]
|
||||
pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1) # [batch_size, dim]
|
||||
|
||||
# 5. 分类头
|
||||
logits = self.classifier(pooled) # [batch_size, vocab_size]
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
def create_export_models_from_checkpoint(checkpoint_path, device="cpu"):
|
||||
"""
|
||||
从checkpoint创建导出模型
|
||||
|
||||
Args:
|
||||
checkpoint_path: 模型checkpoint路径
|
||||
device: 加载设备
|
||||
|
||||
Returns:
|
||||
context_encoder_export: ContextEncoderExport实例
|
||||
decoder_export: DecoderExport实例
|
||||
model_config: 模型配置字典
|
||||
"""
|
||||
# 加载原始模型配置
|
||||
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||
|
||||
# 提取模型配置(从checkpoint或使用默认值)
|
||||
if "config" in checkpoint:
|
||||
config = checkpoint["config"]
|
||||
else:
|
||||
# 使用模型默认配置
|
||||
config = {
|
||||
"vocab_size": 10019,
|
||||
"pinyin_vocab_size": 30,
|
||||
"dim": 512,
|
||||
"num_slots": 8,
|
||||
"n_layers": 4,
|
||||
"n_heads": 4,
|
||||
"num_experts": 10,
|
||||
"max_seq_len": 128,
|
||||
}
|
||||
|
||||
# 创建原始模型
|
||||
from .model import InputMethodEngine
|
||||
|
||||
model = InputMethodEngine(
|
||||
vocab_size=config.get("vocab_size", 10019),
|
||||
pinyin_vocab_size=config.get("pinyin_vocab_size", 30),
|
||||
dim=config.get("dim", 512),
|
||||
num_slots=config.get("num_slots", 8),
|
||||
n_layers=config.get("n_layers", 4),
|
||||
n_heads=config.get("n_heads", 4),
|
||||
num_experts=config.get("num_experts", 10),
|
||||
max_seq_len=config.get("max_seq_len", 128),
|
||||
compile=False,
|
||||
)
|
||||
|
||||
# 加载权重
|
||||
if "model_state_dict" in checkpoint:
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
else:
|
||||
model.load_state_dict(checkpoint)
|
||||
|
||||
model.eval()
|
||||
model.to(device)
|
||||
|
||||
# 创建导出模型
|
||||
context_encoder_export = ContextEncoderExport(model.context_encoder)
|
||||
decoder_export = DecoderExport(
|
||||
slot_memory=model.slot_memory,
|
||||
cross_attn=model.cross_attn,
|
||||
moe=model.moe,
|
||||
slot_attention=model.slot_attention,
|
||||
classifier=model.classifier,
|
||||
num_slots=model.num_slots,
|
||||
dim=model.dim,
|
||||
)
|
||||
|
||||
# 设置为评估模式
|
||||
context_encoder_export.eval()
|
||||
decoder_export.eval()
|
||||
|
||||
return context_encoder_export, decoder_export, config
|
||||
|
|
@ -72,7 +72,7 @@ class InputMethodEngine(nn.Module):
|
|||
self.cross_attn = CrossAttentionFusion(dim=dim, n_heads=n_heads)
|
||||
|
||||
# 4. 混合专家层 (MoE)
|
||||
self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=3, num_resblocks=12)
|
||||
self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=3, num_resblocks=8)
|
||||
|
||||
# 5. 槽位注意力池化
|
||||
self.slot_attention = nn.Linear(dim, 1)
|
||||
|
|
|
|||
|
|
@ -1221,7 +1221,7 @@ def train(
|
|||
max_seq_length=max_seq_len,
|
||||
text_field="text",
|
||||
py_style_weight=(9, 2, 1),
|
||||
shuffle_buffer_size=2000000,
|
||||
shuffle_buffer_size=100000,
|
||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||
)
|
||||
|
||||
|
|
|
|||
12
test.py
12
test.py
|
|
@ -47,8 +47,8 @@ def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
|
|||
return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]
|
||||
|
||||
|
||||
part1 = "从中国回家后,他觉得世界上最好的城市就是"
|
||||
part2 = "shanghai"
|
||||
part1 = "杉杉看了柳柳一眼,默默地同情了一下。她这个堂姐长得非常"
|
||||
part2 = "piaoliang"
|
||||
pinyin_ids = text_to_pinyin_ids(part2)
|
||||
len_py = len(pinyin_ids)
|
||||
if len_py < 24:
|
||||
|
|
@ -56,9 +56,9 @@ if len_py < 24:
|
|||
else:
|
||||
pinyin_ids = pinyin_ids[:24]
|
||||
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
|
||||
masked_labels = [22, 0, 0, 0, 0, 0, 0, 0]
|
||||
masked_labels = [1986, 0, 0, 0, 0, 0, 0, 0]
|
||||
part3 = ""
|
||||
part4 = ""
|
||||
part4 = "可行|特别|伤害"
|
||||
|
||||
encoded = tokenizer(
|
||||
f"{part4}|{part1}",
|
||||
|
|
@ -83,9 +83,8 @@ sample = {
|
|||
|
||||
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
|
||||
|
||||
checkpoint = torch.load("/home/songsenand/下载/20260412epoch2.ptrom", map_location="cpu")
|
||||
checkpoint = torch.load("/home/songsenand/下载/best_model.ptrom", map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
model.eval()
|
||||
|
||||
input_ids = sample["input_ids"]
|
||||
token_type_ids = sample["token_type_ids"]
|
||||
|
|
@ -98,7 +97,6 @@ for k, v in sample.items():
|
|||
print(f"{k}: {v}")
|
||||
|
||||
start = time.time()
|
||||
with torch.no_grad():
|
||||
res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
|
||||
print(f'计算时长: {(time.time() - start) * 1000:4f}ms')
|
||||
sort_res = sorted(
|
||||
|
|
|
|||
101
test_dataset.py
101
test_dataset.py
|
|
@ -1,101 +0,0 @@
|
|||
import sys
|
||||
|
||||
sys.path.append("src")
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import time
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from model.model import InputMethodEngine
|
||||
from model.query import QueryEngine
|
||||
|
||||
import random
|
||||
import re
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from loguru import logger
|
||||
from modelscope import AutoTokenizer
|
||||
from pypinyin import lazy_pinyin
|
||||
from pypinyin.contrib.tone_convert import to_initials
|
||||
from torch.utils.data import IterableDataset
|
||||
from model.dataset import PinyinInputDataset
|
||||
|
||||
|
||||
def worker_init_fn(worker_id: int) -> None:
|
||||
"""
|
||||
初始化每个DataLoader worker的随机种子,确保可复现性
|
||||
|
||||
Args:
|
||||
worker_id: worker的ID
|
||||
"""
|
||||
worker_seed = torch.initial_seed() % (2**32)
|
||||
np.random.seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
|
||||
|
||||
def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""
|
||||
自定义批处理函数,将多个样本组合成一个batch
|
||||
|
||||
Args:
|
||||
batch: 样本列表,每个样本是一个字典
|
||||
|
||||
Returns:
|
||||
批处理后的字典,tensor字段已stack,字符串字段保持为列表
|
||||
"""
|
||||
# 处理tensor字段 - 使用squeeze去除多余的batch维度
|
||||
input_ids = torch.stack([item["input_ids"].squeeze(0) for item in batch])
|
||||
token_type_ids = torch.stack([item["token_type_ids"].squeeze(0) for item in batch])
|
||||
attention_mask = torch.stack([item["attention_mask"].squeeze(0) for item in batch])
|
||||
labels = torch.stack([item["label"].squeeze(0) for item in batch])
|
||||
history_slot_ids = torch.stack([item["history_slot_ids"] for item in batch])
|
||||
pinyin_ids = torch.stack([item["pinyin_ids"] for item in batch])
|
||||
|
||||
# 字符串字段保持为列表
|
||||
prefixes = [item["prefix"] for item in batch]
|
||||
suffixes = [item["suffix"] for item in batch]
|
||||
pinyins = [item["pinyin"] for item in batch]
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"token_type_ids": token_type_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"labels": labels,
|
||||
"history_slot_ids": history_slot_ids,
|
||||
"prefix": prefixes,
|
||||
"suffix": suffixes,
|
||||
"pinyin": pinyins,
|
||||
"pinyin_ids": pinyin_ids,
|
||||
}
|
||||
|
||||
|
||||
train_dataset = PinyinInputDataset(
|
||||
data_path="/home/songsenand/Data/corpus/CCI-Data/",
|
||||
max_workers=-1, # 自动选择worker数量
|
||||
max_iter_length=1000000,
|
||||
text_field="text",
|
||||
py_style_weight=(90, 2, 1),
|
||||
shuffle_buffer_size=20000,
|
||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||
)
|
||||
|
||||
dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=512,
|
||||
num_workers=2,
|
||||
worker_init_fn=worker_init_fn,
|
||||
collate_fn=collate_fn,
|
||||
prefetch_factor=2, # 减少预取以避免内存问题
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
for i, shape in tqdm(enumerate(dataloader), total=1000000/512):
|
||||
pass
|
||||
|
||||
409
verify_onnx.py
409
verify_onnx.py
|
|
@ -1,409 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
ONNX模型验证脚本
|
||||
|
||||
验证导出的ONNX模型与原始PyTorch模型输出的一致性
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import onnxruntime as ort
|
||||
|
||||
# 添加src目录到路径
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from src.model.export_models import create_export_models_from_checkpoint
|
||||
|
||||
|
||||
def compare_outputs(pytorch_output, onnx_output, name="output", rtol=1e-3, atol=1e-5):
|
||||
"""
|
||||
比较PyTorch和ONNX输出
|
||||
|
||||
Args:
|
||||
pytorch_output: PyTorch张量
|
||||
onnx_output: ONNX Runtime输出(numpy数组)
|
||||
name: 输出名称(用于错误信息)
|
||||
rtol: 相对容差
|
||||
atol: 绝对容差
|
||||
|
||||
Returns:
|
||||
bool: 是否匹配
|
||||
"""
|
||||
# 转换PyTorch输出为numpy
|
||||
if isinstance(pytorch_output, torch.Tensor):
|
||||
pytorch_np = pytorch_output.detach().cpu().numpy()
|
||||
else:
|
||||
pytorch_np = np.array(pytorch_output)
|
||||
|
||||
# 确保形状一致
|
||||
if pytorch_np.shape != onnx_output.shape:
|
||||
print(
|
||||
f"❌ {name} 形状不匹配: PyTorch {pytorch_np.shape} != ONNX {onnx_output.shape}"
|
||||
)
|
||||
return False
|
||||
|
||||
# 计算差异
|
||||
diff = np.abs(pytorch_np - onnx_output)
|
||||
max_diff = np.max(diff)
|
||||
mean_diff = np.mean(diff)
|
||||
|
||||
# 检查是否在容差范围内
|
||||
is_close = np.allclose(pytorch_np, onnx_output, rtol=rtol, atol=atol)
|
||||
|
||||
if is_close:
|
||||
print(f"✅ {name} 匹配: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
else:
|
||||
print(f"❌ {name} 不匹配: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
print(
|
||||
f" 范围: PyTorch [{pytorch_np.min():.6f}, {pytorch_np.max():.6f}], "
|
||||
f"ONNX [{onnx_output.min():.6f}, {onnx_output.max():.6f}]"
|
||||
)
|
||||
|
||||
return is_close
|
||||
|
||||
|
||||
def verify_context_encoder(checkpoint_path, onnx_path, device="cpu"):
|
||||
"""
|
||||
验证上下文编码器
|
||||
|
||||
Args:
|
||||
checkpoint_path: PyTorch checkpoint路径
|
||||
onnx_path: ONNX模型路径
|
||||
device: 设备
|
||||
|
||||
Returns:
|
||||
bool: 验证是否通过
|
||||
"""
|
||||
print(f"\n🔍 验证上下文编码器: {onnx_path}")
|
||||
|
||||
# 加载PyTorch模型
|
||||
context_encoder_export, _, config = create_export_models_from_checkpoint(
|
||||
checkpoint_path, device
|
||||
)
|
||||
|
||||
# 创建ONNX Runtime会话
|
||||
session = ort.InferenceSession(
|
||||
onnx_path,
|
||||
providers=[
|
||||
"CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider"
|
||||
],
|
||||
)
|
||||
|
||||
# 创建测试输入
|
||||
batch_size = 2 # 使用batch_size=2测试动态批处理
|
||||
seq_len = config.get("max_seq_len", 128)
|
||||
pinyin_len = 24
|
||||
|
||||
input_ids = torch.randint(0, 100, (batch_size, seq_len), dtype=torch.long)
|
||||
pinyin_ids = torch.randint(0, 30, (batch_size, pinyin_len), dtype=torch.long)
|
||||
attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)
|
||||
# 随机屏蔽一些位置
|
||||
attention_mask[:, seq_len // 2 :] = 0
|
||||
|
||||
# PyTorch推理
|
||||
with torch.no_grad():
|
||||
pytorch_outputs = context_encoder_export(input_ids, pinyin_ids, attention_mask)
|
||||
|
||||
# ONNX推理
|
||||
onnx_inputs = {
|
||||
"input_ids": input_ids.numpy(),
|
||||
"pinyin_ids": pinyin_ids.numpy(),
|
||||
"attention_mask": attention_mask.numpy(),
|
||||
}
|
||||
|
||||
onnx_outputs = session.run(None, onnx_inputs)
|
||||
|
||||
# 比较输出
|
||||
output_names = ["context_H", "pinyin_P", "context_mask", "pinyin_mask"]
|
||||
all_match = True
|
||||
|
||||
for i, name in enumerate(output_names):
|
||||
if i < len(pytorch_outputs) and i < len(onnx_outputs):
|
||||
match = compare_outputs(pytorch_outputs[i], onnx_outputs[i], name)
|
||||
all_match = all_match and match
|
||||
|
||||
return all_match
|
||||
|
||||
|
||||
def verify_decoder(checkpoint_path, onnx_path, device="cpu"):
|
||||
"""
|
||||
验证解码器
|
||||
|
||||
Args:
|
||||
checkpoint_path: PyTorch checkpoint路径
|
||||
onnx_path: ONNX模型路径
|
||||
device: 设备
|
||||
|
||||
Returns:
|
||||
bool: 验证是否通过
|
||||
"""
|
||||
print(f"\n🔍 验证解码器: {onnx_path}")
|
||||
|
||||
# 加载PyTorch模型
|
||||
_, decoder_export, config = create_export_models_from_checkpoint(
|
||||
checkpoint_path, device
|
||||
)
|
||||
|
||||
# 创建ONNX Runtime会话
|
||||
session = ort.InferenceSession(
|
||||
onnx_path,
|
||||
providers=[
|
||||
"CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider"
|
||||
],
|
||||
)
|
||||
|
||||
# 创建测试输入
|
||||
batch_size = 2
|
||||
seq_len = config.get("max_seq_len", 128)
|
||||
pinyin_len = 24
|
||||
dim = config.get("dim", 512)
|
||||
num_slots = config.get("num_slots", 8)
|
||||
|
||||
context_H = torch.randn(batch_size, seq_len, dim, dtype=torch.float32)
|
||||
pinyin_P = torch.randn(batch_size, pinyin_len, dim, dtype=torch.float32)
|
||||
context_mask = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int32)
|
||||
pinyin_mask = torch.randint(0, 2, (batch_size, pinyin_len), dtype=torch.int32)
|
||||
history_slot_ids = torch.randint(0, 100, (batch_size, num_slots), dtype=torch.long)
|
||||
|
||||
# PyTorch推理
|
||||
with torch.no_grad():
|
||||
pytorch_output = decoder_export(
|
||||
context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask
|
||||
)
|
||||
|
||||
# ONNX推理
|
||||
onnx_inputs = {
|
||||
"context_H": context_H.numpy(),
|
||||
"pinyin_P": pinyin_P.numpy(),
|
||||
"history_slot_ids": history_slot_ids.numpy(),
|
||||
"context_mask": context_mask.numpy(),
|
||||
"pinyin_mask": pinyin_mask.numpy(),
|
||||
}
|
||||
|
||||
onnx_outputs = session.run(None, onnx_inputs)
|
||||
|
||||
# 比较输出
|
||||
return compare_outputs(pytorch_output, onnx_outputs[0], "logits")
|
||||
|
||||
|
||||
def verify_end_to_end(
|
||||
checkpoint_path, context_encoder_path, decoder_path, device="cpu"
|
||||
):
|
||||
"""
|
||||
端到端验证:比较完整推理流程
|
||||
|
||||
Args:
|
||||
checkpoint_path: PyTorch checkpoint路径
|
||||
context_encoder_path: 上下文编码器ONNX路径
|
||||
decoder_path: 解码器ONNX路径
|
||||
device: 设备
|
||||
|
||||
Returns:
|
||||
bool: 验证是否通过
|
||||
"""
|
||||
print(f"\n🔍 端到端验证")
|
||||
|
||||
# 加载原始PyTorch模型
|
||||
from src.model.model import InputMethodEngine
|
||||
|
||||
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||
|
||||
if "config" in checkpoint:
|
||||
config = checkpoint["config"]
|
||||
else:
|
||||
config = {
|
||||
"vocab_size": 10019,
|
||||
"pinyin_vocab_size": 30,
|
||||
"dim": 512,
|
||||
"num_slots": 8,
|
||||
"n_layers": 4,
|
||||
"n_heads": 4,
|
||||
"num_experts": 10,
|
||||
"max_seq_len": 128,
|
||||
}
|
||||
|
||||
model = InputMethodEngine(
|
||||
vocab_size=config.get("vocab_size", 10019),
|
||||
pinyin_vocab_size=config.get("pinyin_vocab_size", 30),
|
||||
dim=config.get("dim", 512),
|
||||
num_slots=config.get("num_slots", 8),
|
||||
n_layers=config.get("n_layers", 4),
|
||||
n_heads=config.get("n_heads", 4),
|
||||
num_experts=config.get("num_experts", 10),
|
||||
max_seq_len=config.get("max_seq_len", 128),
|
||||
compile=False,
|
||||
)
|
||||
|
||||
if "model_state_dict" in checkpoint:
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
else:
|
||||
model.load_state_dict(checkpoint)
|
||||
|
||||
model.eval()
|
||||
model.to(device)
|
||||
|
||||
# 创建ONNX Runtime会话
|
||||
context_session = ort.InferenceSession(
|
||||
context_encoder_path,
|
||||
providers=[
|
||||
"CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider"
|
||||
],
|
||||
)
|
||||
decoder_session = ort.InferenceSession(
|
||||
decoder_path,
|
||||
providers=[
|
||||
"CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider"
|
||||
],
|
||||
)
|
||||
|
||||
# 创建测试输入
|
||||
batch_size = 1
|
||||
seq_len = config.get("max_seq_len", 128)
|
||||
pinyin_len = 24
|
||||
|
||||
input_ids = torch.randint(0, 100, (batch_size, seq_len), dtype=torch.long)
|
||||
pinyin_ids = torch.randint(0, 30, (batch_size, pinyin_len), dtype=torch.long)
|
||||
attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)
|
||||
history_slot_ids = torch.randint(0, 100, (batch_size, 8), dtype=torch.long)
|
||||
|
||||
# PyTorch完整推理
|
||||
with torch.no_grad():
|
||||
pytorch_logits = model(
|
||||
input_ids=input_ids,
|
||||
token_type_ids=torch.zeros_like(input_ids), # 简化处理
|
||||
attention_mask=attention_mask,
|
||||
pinyin_ids=pinyin_ids,
|
||||
history_slot_ids=history_slot_ids,
|
||||
)
|
||||
|
||||
# ONNX推理流程
|
||||
# 1. 上下文编码器
|
||||
context_inputs = {
|
||||
"input_ids": input_ids.numpy(),
|
||||
"pinyin_ids": pinyin_ids.numpy(),
|
||||
"attention_mask": attention_mask.numpy(),
|
||||
}
|
||||
context_outputs = context_session.run(None, context_inputs)
|
||||
context_H, pinyin_P, context_mask, pinyin_mask = context_outputs
|
||||
|
||||
# 2. 解码器
|
||||
decoder_inputs = {
|
||||
"context_H": context_H,
|
||||
"pinyin_P": pinyin_P,
|
||||
"history_slot_ids": history_slot_ids.numpy(),
|
||||
"context_mask": context_mask,
|
||||
"pinyin_mask": pinyin_mask,
|
||||
}
|
||||
onnx_outputs = decoder_session.run(None, decoder_inputs)
|
||||
|
||||
# 比较输出
|
||||
return compare_outputs(pytorch_logits, onnx_outputs[0], "end_to_end_logits")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="ONNX模型验证")
|
||||
parser.add_argument(
|
||||
"--checkpoint", "-c", type=str, required=True, help="PyTorch checkpoint路径"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--context-encoder",
|
||||
type=str,
|
||||
help="上下文编码器ONNX路径(默认: ./exported_models/context_encoder.onnx)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decoder",
|
||||
type=str,
|
||||
help="解码器ONNX路径(默认: ./exported_models/decoder.onnx)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
"-o",
|
||||
type=str,
|
||||
default="./exported_models",
|
||||
help="导出目录(如果未指定单个模型路径)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cpu",
|
||||
choices=["cpu", "cuda"],
|
||||
help="验证设备(默认: cpu)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-context", action="store_true", help="跳过上下文编码器验证"
|
||||
)
|
||||
parser.add_argument("--skip-decoder", action="store_true", help="跳解码器验证")
|
||||
parser.add_argument("--skip-end-to-end", action="store_true", help="跳过端到端验证")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 确定模型路径
|
||||
if args.context_encoder:
|
||||
context_encoder_path = args.context_encoder
|
||||
else:
|
||||
context_encoder_path = os.path.join(args.output_dir, "context_encoder.onnx")
|
||||
|
||||
if args.decoder:
|
||||
decoder_path = args.decoder
|
||||
else:
|
||||
decoder_path = os.path.join(args.output_dir, "decoder.onnx")
|
||||
|
||||
print("🔬 ONNX模型验证")
|
||||
print("=" * 60)
|
||||
print(f"Checkpoint: {args.checkpoint}")
|
||||
print(f"上下文编码器: {context_encoder_path}")
|
||||
print(f"解码器: {decoder_path}")
|
||||
print(f"设备: {args.device}")
|
||||
print()
|
||||
|
||||
all_pass = True
|
||||
|
||||
# 验证上下文编码器
|
||||
if not args.skip_context and os.path.exists(context_encoder_path):
|
||||
if verify_context_encoder(args.checkpoint, context_encoder_path, args.device):
|
||||
print("✅ 上下文编码器验证通过")
|
||||
else:
|
||||
print("❌ 上下文编码器验证失败")
|
||||
all_pass = False
|
||||
elif not args.skip_context:
|
||||
print("⚠️ 上下文编码器文件不存在,跳过验证")
|
||||
|
||||
# 验证解码器
|
||||
if not args.skip_decoder and os.path.exists(decoder_path):
|
||||
if verify_decoder(args.checkpoint, decoder_path, args.device):
|
||||
print("✅ 解码器验证通过")
|
||||
else:
|
||||
print("❌ 解码器验证失败")
|
||||
all_pass = False
|
||||
elif not args.skip_decoder:
|
||||
print("⚠️ 解码器文件不存在,跳过验证")
|
||||
|
||||
# 端到端验证
|
||||
if (
|
||||
not args.skip_end_to_end
|
||||
and os.path.exists(context_encoder_path)
|
||||
and os.path.exists(decoder_path)
|
||||
):
|
||||
if verify_end_to_end(
|
||||
args.checkpoint, context_encoder_path, decoder_path, args.device
|
||||
):
|
||||
print("✅ 端到端验证通过")
|
||||
else:
|
||||
print("❌ 端到端验证失败")
|
||||
all_pass = False
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
if all_pass:
|
||||
print("🎉 所有验证通过!ONNX模型与PyTorch模型输出一致")
|
||||
else:
|
||||
print("❌ 部分验证失败,请检查模型导出过程")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Reference in New Issue