docs: 详细描述history_slot_ids的设计策略与使用场景

This commit is contained in:
songsenand 2026-05-23 13:25:32 +08:00
parent 71ef54e3d4
commit 53f244de2f
8 changed files with 935 additions and 653 deletions

View File

@ -72,6 +72,36 @@
* **目标槽位序列**:真实用户输入的文字 ID 序列,作为模型的监督信号 [1]。 * **目标槽位序列**:真实用户输入的文字 ID 序列,作为模型的监督信号 [1]。
* **标签处理**:在每一个槽位步(Step),模型需要预测该步对应的真实文字 ID [1]。 * **标签处理**:在每一个槽位步(Step),模型需要预测该步对应的真实文字 ID [1]。
#### 4.1.1 历史槽位(history_slot_ids)的设计策略
`history_slot_ids` 的语义是**当前拼音组内已经确认的字符**,用于模拟真实的输入法输入场景。
**设计原则:**
1. **拼音组隔离**:每个拼音组(jieba 分词的一个词或合并后的短词组)对应一次独立的拼音输入会话。拼音组开始时 `history_slot_ids` 为空(全 0),组内逐字累积。
```
输入: 特种兵 → 拼音组 "tezhongbing"
Step 1: pinyin="tezhongbing", history=[] → 预测 特
Step 2: pinyin="tezhongbing", history=[特] → 预测 种
Step 3: pinyin="tezhongbing", history=[特, 种] → 预测 兵
输入: 十分倾佩 → 拼音组 "shifen" + 拼音组 "qing" + 拼音组 "pei"
拼音组 "qing": pinyin="qing", history=[] → 预测 倾
拼音组 "pei": pinyin="pei", history=[] → 预测 佩
```
2. **上下文承载前缀字**:当拼音组切换时,前一个拼音组已确认的字符(如"特")已写入文本,通过 BERT 编码的 `input_ids`(而非 `history_slot_ids`)传递给模型。这模拟了真实输入法环境——用户打完"te"确认"特"后,文本中已有"特",再输入"zhongbing"时,模型从上下文(`input_ids`)中能看到"特"。
3. **破词续接**:当一个词被按一定概率(~10%)拆分时(`should_break=True`),Phase 1 处理前缀(如"特"),Phase 2 处理续接(如"种兵")。Phase 2 的 `cont_processed_history = []` 是**正确的设计**——前缀已确认为文本的一部分,续接是一个新的拼音会话,history 从空开始。
```
"特种兵" 断在"特"后:
Phase 1: pinyin="te", history=[] → 预测 特 → 确认为文本上下文
Phase 2: pinyin="zhongbing", history=[] → 预测 种
history=[种] → 预测 兵
```
Phase 2 中"特"已通过 `part1`(光标前文本)进入 BERT 上下文中,不在 `history_slot_ids` 中。
4. **短词合并**:相邻单字词(≤2 字符)有 50% 概率合并为一个拼音组。合并后组内字符共享 history 累积;未合并时各自为独立拼音组(history 为空)。两种方式都对应真实输入法场景——用户可能一次性输入多个字,也可能逐字输入。
### 4.2 损失函数与优化 ### 4.2 损失函数与优化
* **损失函数**:使用 **CrossEntropyLoss** 计算每一步预测结果与真实标签之间的差异 [1]。 * **损失函数**:使用 **CrossEntropyLoss** 计算每一步预测结果与真实标签之间的差异 [1]。
* **掩码机制**:仅计算非填充位置(Non-padding positions)的损失,忽略无效的时间步 [1]。 * **掩码机制**:仅计算非填充位置(Non-padding positions)的损失,忽略无效的时间步 [1]。

View File

@ -14,450 +14,12 @@
""" """
import argparse import argparse
import os
import sys import sys
from pathlib import Path from pathlib import Path
import torch
import numpy as np
# 检查ONNX是否可用
try:
import onnx
import onnxruntime as ort
ONNX_AVAILABLE = True
except ImportError:
ONNX_AVAILABLE = False
print("警告: ONNX或ONNX Runtime未安装")
print("请运行: pip install onnx onnxruntime")
# 添加src目录到路径
sys.path.insert(0, str(Path(__file__).parent)) sys.path.insert(0, str(Path(__file__).parent))
from src.model.export_models import create_export_models_from_checkpoint from src.model.onnx_export import check_onnx_available, run_full_export
def check_onnx_available():
    """Report whether the ONNX export dependencies were importable.

    Reads the module-level ``ONNX_AVAILABLE`` flag. When the flag is false,
    installation instructions are printed before returning.

    Returns:
        True if the dependencies are available, otherwise False.
    """
    if ONNX_AVAILABLE:
        return True

    # Dependencies are missing: tell the user how to install them.
    for message in (
        "错误: ONNX导出需要以下依赖:",
        " pip install onnx onnxruntime",
        "请安装后重试",
    ):
        print(message)
    return False
def export_context_encoder(model, output_path, config):
    """
    Export the context encoder to ONNX format.

    Args:
        model: ContextEncoderExport instance to export.
        output_path: Destination path of the ONNX file.
        config: Model configuration; only ``max_seq_len`` is read here
            (default 128) to size the example inputs.

    Returns:
        The ``(input_ids, pinyin_ids, attention_mask)`` example tensors that
        were fed to the exporter, so callers can reuse them.
    """
    print(f"正在导出上下文编码器到: {output_path}")

    # Build example inputs with batch_size=2 to make sure ONNX supports
    # dynamic batching.
    batch_size = 2
    seq_len = config.get("max_seq_len", 128)
    pinyin_len = 24  # fixed length

    input_ids = torch.randint(0, 100, (batch_size, seq_len), dtype=torch.long)
    pinyin_ids = torch.randint(0, 30, (batch_size, pinyin_len), dtype=torch.long)
    attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)

    # Dynamic axes: batch is dynamic everywhere, the sequence axis only on
    # the context-side tensors (the pinyin axis stays fixed).
    dynamic_axes = {
        "input_ids": {0: "batch_size", 1: "seq_len"},
        "pinyin_ids": {0: "batch_size"},
        "attention_mask": {0: "batch_size", 1: "seq_len"},
        "context_H": {0: "batch_size", 1: "seq_len"},
        "pinyin_P": {0: "batch_size"},
        "context_mask": {0: "batch_size", 1: "seq_len"},
        "pinyin_mask": {0: "batch_size"},
    }

    torch.onnx.export(
        model,
        (input_ids, pinyin_ids, attention_mask),
        output_path,
        input_names=["input_ids", "pinyin_ids", "attention_mask"],
        output_names=["context_H", "pinyin_P", "context_mask", "pinyin_mask"],
        dynamic_axes=dynamic_axes,
        opset_version=18,  # supports modern operators
        do_constant_folding=True,
        verbose=False,
    )
    print("✅ 上下文编码器导出完成")

    # Validate the exported file. This is deliberately best-effort: a checker
    # failure is reported as a warning and does not abort the export.
    try:
        onnx_model = onnx.load(output_path)
        onnx.checker.check_model(onnx_model)
        print("✅ ONNX模型验证通过")
    except Exception as e:
        print(f"⚠️ ONNX模型验证警告: {e}")

    return input_ids, pinyin_ids, attention_mask
def export_decoder(model, output_path, config, example_inputs=None):
    """
    Export the decoder to ONNX format.

    Args:
        model: DecoderExport instance to export.
        output_path: Destination path of the ONNX file.
        config: Model configuration; ``num_slots`` (default 8) is always read,
            ``max_seq_len`` (default 128) and ``dim`` (default 512) only when
            random example inputs have to be generated.
        example_inputs: Optional ``(context_H, pinyin_P, context_mask,
            pinyin_mask)`` tuple used as example inputs for consistency with
            the encoder export. When omitted, random tensors are generated.

    Returns:
        The ``(context_H, pinyin_P, history_slot_ids, context_mask,
        pinyin_mask)`` example tensors that were fed to the exporter.
    """
    print(f"正在导出解码器到: {output_path}")

    num_slots = config.get("num_slots", 8)

    if example_inputs is not None:
        # Reuse the caller's tensors and follow their actual batch size.
        context_H, pinyin_P, context_mask, pinyin_mask = example_inputs
        batch_size = context_H.size(0)
    else:
        # Build example inputs with batch_size=2 to make sure ONNX supports
        # dynamic batching.
        batch_size = 2
        seq_len = config.get("max_seq_len", 128)
        pinyin_len = 24
        dim = config.get("dim", 512)
        context_H = torch.randn(batch_size, seq_len, dim, dtype=torch.float32)
        pinyin_P = torch.randn(batch_size, pinyin_len, dim, dtype=torch.float32)
        context_mask = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int32)
        pinyin_mask = torch.randint(0, 2, (batch_size, pinyin_len), dtype=torch.int32)

    # Generated after the branch so the random-number call order matches the
    # previous per-branch construction.
    history_slot_ids = torch.randint(
        0, 100, (batch_size, num_slots), dtype=torch.long
    )

    # Dynamic axes: batch is dynamic everywhere, the sequence axis only on
    # the context-side tensors.
    dynamic_axes = {
        "context_H": {0: "batch_size", 1: "seq_len"},
        "pinyin_P": {0: "batch_size"},
        "history_slot_ids": {0: "batch_size"},
        "context_mask": {0: "batch_size", 1: "seq_len"},
        "pinyin_mask": {0: "batch_size"},
        "logits": {0: "batch_size"},
    }

    torch.onnx.export(
        model,
        (context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask),
        output_path,
        input_names=[
            "context_H",
            "pinyin_P",
            "history_slot_ids",
            "context_mask",
            "pinyin_mask",
        ],
        output_names=["logits"],
        dynamic_axes=dynamic_axes,
        opset_version=18,
        do_constant_folding=True,
        verbose=False,
    )
    print("✅ 解码器导出完成")

    # Validate the exported file. This is deliberately best-effort: a checker
    # failure is reported as a warning and does not abort the export.
    try:
        onnx_model = onnx.load(output_path)
        onnx.checker.check_model(onnx_model)
        print("✅ ONNX模型验证通过")
    except Exception as e:
        print(f"⚠️ ONNX模型验证警告: {e}")

    return context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask
def save_example_inputs(output_dir, example_inputs_dict):
    """
    Save the example inputs as an NPZ file (for verification) and as a .pt file.

    Args:
        output_dir: Directory that receives ``example_inputs.npz`` and
            ``example_inputs.pt``.
        example_inputs_dict: Mapping of name to tensor, or name to tuple of
            tensors. Tuple members are stored in the NPZ as ``<name>_<index>``.
            Values of any other type are left out of the NPZ file but are
            still part of the .pt file.
    """
    npz_path = os.path.join(output_dir, "example_inputs.npz")

    # Convert to numpy arrays. detach() is required because numpy() raises on
    # tensors that are still attached to the autograd graph.
    np_data = {}
    for key, value in example_inputs_dict.items():
        if isinstance(value, torch.Tensor):
            np_data[key] = value.detach().cpu().numpy()
        elif isinstance(value, tuple):
            for i, item in enumerate(value):
                if isinstance(item, torch.Tensor):
                    np_data[f"{key}_{i}"] = item.detach().cpu().numpy()
    np.savez(npz_path, **np_data)
    print(f"✅ 示例输入已保存到: {npz_path}")

    # Also keep the original dictionary in PyTorch format.
    torch_path = os.path.join(output_dir, "example_inputs.pt")
    torch.save(example_inputs_dict, torch_path)
    print(f"✅ PyTorch示例输入已保存到: {torch_path}")
def create_inference_example(output_dir, config):
    """
    Write a standalone ``inference_example.py`` script into ``output_dir``.

    The script is emitted from a fixed template. The template imports ``os``
    because its ``main()`` calls ``os.path.exists``; without that import the
    generated script fails with a NameError.

    Args:
        output_dir: Directory that receives ``inference_example.py``.
        config: Model configuration. Not used by the template; kept so the
            call signature stays unchanged.
    """
    example_path = os.path.join(output_dir, "inference_example.py")
    example_code = '''#!/usr/bin/env python3
"""
ONNX模型推理示例
展示如何使用导出的两个ONNX模型进行推理
包括束搜索beam search算法
"""
import os

import numpy as np
import onnxruntime as ort
import torch
import torch.nn.functional as F
from typing import List, Tuple


class ONNXInference:
    """ONNX模型推理器"""

    def __init__(self, context_encoder_path, decoder_path):
        """
        初始化ONNX推理器
        Args:
            context_encoder_path: 上下文编码器ONNX模型路径
            decoder_path: 解码器ONNX模型路径
        """
        # 创建ONNX Runtime会话
        self.context_encoder_session = ort.InferenceSession(
            context_encoder_path,
            providers=['CPUExecutionProvider'] # 或 'CUDAExecutionProvider'
        )
        self.decoder_session = ort.InferenceSession(
            decoder_path,
            providers=['CPUExecutionProvider']
        )
        # 获取输入输出名称
        self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()]
        self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()]
        self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()]
        self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()]
        print(f"上下文编码器输入: {self.context_input_names}")
        print(f"上下文编码器输出: {self.context_output_names}")
        print(f"解码器输入: {self.decoder_input_names}")
        print(f"解码器输出: {self.decoder_output_names}")

    def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128):
        """
        准备模型输入与原始推理脚本保持一致
        注意: 这里需要实现文本到token的转换
        为了简化示例假设已经实现了相关函数
        """
        # 这里应该调用实际的预处理函数
        # 返回: input_ids, pinyin_ids, attention_mask, history_slot_ids
        raise NotImplementedError("请实现实际的输入预处理")

    def run_context_encoder(self, input_ids, pinyin_ids, attention_mask):
        """
        运行上下文编码器
        Args:
            input_ids: [batch, seq_len]
            pinyin_ids: [batch, 24]
            attention_mask: [batch, seq_len]
        Returns:
            context_H, pinyin_P, context_mask, pinyin_mask
        """
        # 准备输入
        inputs = {
            "input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids,
            "pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids,
            "attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask,
        }
        # 运行推理
        outputs = self.context_encoder_session.run(self.context_output_names, inputs)
        # 解包输出
        context_H, pinyin_P, context_mask, pinyin_mask = outputs
        return (
            torch.from_numpy(context_H),
            torch.from_numpy(pinyin_P),
            torch.from_numpy(context_mask),
            torch.from_numpy(pinyin_mask),
        )

    def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
        """
        运行解码器
        Args:
            context_H: [batch, seq_len, 512]
            pinyin_P: [batch, 24, 512]
            history_slot_ids: [batch, 8]
            context_mask: [batch, seq_len]
            pinyin_mask: [batch, 24]
        Returns:
            logits: [batch, vocab_size]
        """
        # 准备输入
        inputs = {
            "context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H,
            "pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P,
            "history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids,
            "context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask,
            "pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask,
        }
        # 运行推理
        outputs = self.decoder_session.run(self.decoder_output_names, inputs)
        # 解包输出
        logits = outputs[0]
        return torch.from_numpy(logits)

    def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask,
                    beam_size=5, max_length=10, vocab_size=10019):
        """
        束搜索算法示例
        Args:
            context_H: 上下文编码
            pinyin_P: 拼音编码
            context_mask: 上下文掩码
            pinyin_mask: 拼音掩码
            beam_size: 束大小
            max_length: 最大生成长度
            vocab_size: 词汇表大小
        Returns:
            最佳序列列表
        """
        # 初始束空序列分数为0
        beams = [([], 0.0)] # (序列, 对数概率)
        for step in range(max_length):
            new_beams = []
            for seq, score in beams:
                # 构建history_slot_ids已确认的字符ID
                if len(seq) < 8:
                    history = seq + [0] * (8 - len(seq))
                else:
                    history = seq[-8:] # 只保留最近8个
                history_tensor = torch.tensor([history], dtype=torch.long)
                # 运行解码器
                logits = self.run_decoder(
                    context_H, pinyin_P, history_tensor,
                    context_mask, pinyin_mask
                )
                # 获取概率
                probs = F.softmax(logits[0], dim=-1)
                # 获取top-k候选
                top_probs, top_indices = torch.topk(probs, beam_size)
                # 扩展束
                for prob, idx in zip(top_probs, top_indices):
                    new_seq = seq + [idx.item()]
                    new_score = score + torch.log(prob).item()
                    new_beams.append((new_seq, new_score))
            # 剪枝保留beam_size个最佳候选
            new_beams.sort(key=lambda x: x[1], reverse=True)
            beams = new_beams[:beam_size]
            # 检查是否所有序列都已结束以结束符0结尾
            all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
            if all_ended:
                break
        return beams

    def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids):
        """
        单步预测
        Args:
            input_ids: 输入token IDs
            pinyin_ids: 拼音IDs
            attention_mask: 注意力掩码
            history_slot_ids: 历史槽位IDs
        Returns:
            预测logits
        """
        # 1. 运行上下文编码器
        context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
            input_ids, pinyin_ids, attention_mask
        )
        # 2. 运行解码器
        logits = self.run_decoder(
            context_H, pinyin_P, history_slot_ids,
            context_mask, pinyin_mask
        )
        return logits


def main():
    """示例主函数"""
    print("ONNX模型推理示例")
    print("=" * 60)
    # 初始化推理器
    context_encoder_path = "context_encoder.onnx"
    decoder_path = "decoder.onnx"
    if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path):
        print("错误: 找不到ONNX模型文件")
        print("请先运行export_onnx.py导出模型")
        return
    inference = ONNXInference(context_encoder_path, decoder_path)
    print("✅ ONNX推理器初始化完成")
    print("请参考此示例实现完整的输入法推理流程")


if __name__ == "__main__":
    main()
'''
    with open(example_path, "w", encoding="utf-8") as f:
        f.write(example_code)
    print(f"✅ 推理示例脚本已保存到: {example_path}")
def main(): def main():
@ -485,77 +47,39 @@ def main():
args = parser.parse_args() args = parser.parse_args()
# 检查依赖
if not check_onnx_available(): if not check_onnx_available():
sys.exit(1) sys.exit(1)
# 创建输出目录 print(f"输出目录: {Path(args.output_dir).absolute()}")
context_encoder_path, decoder_path, config = run_full_export(
checkpoint_path=args.checkpoint,
output_dir=args.output_dir,
device=args.device,
skip_verification=args.skip_verification,
)
output_dir = Path(args.output_dir) output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True) print()
print(f"📁 输出目录: {output_dir.absolute()}")
# 加载checkpoint并创建导出模型
print(f"📦 加载checkpoint: {args.checkpoint}")
context_encoder_export, decoder_export, config = (
create_export_models_from_checkpoint(args.checkpoint, args.device)
)
print(f"📊 模型配置: {config}")
# 导出上下文编码器
context_encoder_path = output_dir / "context_encoder.onnx"
example_inputs = export_context_encoder(
context_encoder_export, str(context_encoder_path), config
)
# 使用上下文编码器的输出作为解码器的示例输入
with torch.no_grad():
context_H, pinyin_P, context_mask, pinyin_mask = context_encoder_export(
*example_inputs
)
# 导出解码器
decoder_path = output_dir / "decoder.onnx"
export_decoder(
decoder_export,
str(decoder_path),
config,
example_inputs=(context_H, pinyin_P, context_mask, pinyin_mask),
)
# 保存示例输入
example_inputs_dict = {
"input_ids": example_inputs[0],
"pinyin_ids": example_inputs[1],
"attention_mask": example_inputs[2],
"context_H": context_H,
"pinyin_P": pinyin_P,
"context_mask": context_mask,
"pinyin_mask": pinyin_mask,
}
save_example_inputs(output_dir, example_inputs_dict)
# 创建推理示例脚本
create_inference_example(output_dir, config)
print("\n" + "=" * 60)
print("🎉 ONNX导出完成")
print("=" * 60) print("=" * 60)
print(f"生成的模型文件:") print("ONNX导出完成")
print("=" * 60)
print("生成的模型文件:")
print(f" - {context_encoder_path}") print(f" - {context_encoder_path}")
print(f" - {decoder_path}") print(f" - {decoder_path}")
print(f" - {output_dir}/example_inputs.npz") print(f" - {output_dir / 'example_inputs.npz'}")
print(f" - {output_dir}/example_inputs.pt") print(f" - {output_dir / 'example_inputs.pt'}")
print(f" - {output_dir}/inference_example.py") print(f" - {output_dir / 'inference_example.py'}")
print("\n使用方法:") print()
print("使用方法:")
print(f" 1. 检查模型: python -m onnx.checker {context_encoder_path}") print(f" 1. 检查模型: python -m onnx.checker {context_encoder_path}")
print(f" 2. 运行推理示例: cd {output_dir} && python inference_example.py") print(f" 2. 运行推理示例: cd {output_dir} && python inference_example.py")
print(f" 3. 集成到您的应用: 参考inference_example.py中的ONNXInference类") print(f" 3. 集成到您的应用: 参考 inference_example.py 中的 ONNXInference 类")
print("\n注意:") print()
print(" - 请确保安装了onnxruntime: pip install onnxruntime") print("注意:")
print(" - GPU推理需要onnxruntime-gpu: pip install onnxruntime-gpu") print(" - 请确保安装了 onnxruntime: pip install onnxruntime")
print(" - 束搜索算法需要根据实际需求进行调整") print(" - GPU推理需要 onnxruntime-gpu: pip install onnxruntime-gpu")
print(" - MoE 层当前使用 'all' 模式(全量计算),稀疏化优化可后续迭代")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -5,7 +5,7 @@ ONNX模型推理示例
展示如何使用导出的两个ONNX模型进行推理 展示如何使用导出的两个ONNX模型进行推理
包括束搜索beam search算法 包括束搜索beam search算法
""" """
import os
import numpy as np import numpy as np
import onnxruntime as ort import onnxruntime as ort
import torch import torch
@ -15,94 +15,46 @@ from typing import List, Tuple
class ONNXInference: class ONNXInference:
"""ONNX模型推理器""" """ONNX模型推理器"""
def __init__(self, context_encoder_path, decoder_path): def __init__(self, context_encoder_path, decoder_path):
"""
初始化ONNX推理器
Args:
context_encoder_path: 上下文编码器ONNX模型路径
decoder_path: 解码器ONNX模型路径
"""
# 创建ONNX Runtime会话
self.context_encoder_session = ort.InferenceSession( self.context_encoder_session = ort.InferenceSession(
context_encoder_path, context_encoder_path,
providers=['CPUExecutionProvider'] # 或 'CUDAExecutionProvider' providers=['CPUExecutionProvider']
) )
self.decoder_session = ort.InferenceSession( self.decoder_session = ort.InferenceSession(
decoder_path, decoder_path,
providers=['CPUExecutionProvider'] providers=['CPUExecutionProvider']
) )
# 获取输入输出名称
self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()] self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()]
self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()] self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()]
self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()] self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()]
self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()] self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()]
print(f"上下文编码器输入: {self.context_input_names}") print(f"上下文编码器输入: {self.context_input_names}")
print(f"上下文编码器输出: {self.context_output_names}") print(f"上下文编码器输出: {self.context_output_names}")
print(f"解码器输入: {self.decoder_input_names}") print(f"解码器输入: {self.decoder_input_names}")
print(f"解码器输出: {self.decoder_output_names}") print(f"解码器输出: {self.decoder_output_names}")
def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128): def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128):
"""
准备模型输入与原始推理脚本保持一致
注意: 这里需要实现文本到token的转换
为了简化示例假设已经实现了相关函数
"""
# 这里应该调用实际的预处理函数
# 返回: input_ids, pinyin_ids, attention_mask, history_slot_ids
raise NotImplementedError("请实现实际的输入预处理") raise NotImplementedError("请实现实际的输入预处理")
def run_context_encoder(self, input_ids, pinyin_ids, attention_mask): def run_context_encoder(self, input_ids, pinyin_ids, attention_mask):
"""
运行上下文编码器
Args:
input_ids: [batch, seq_len]
pinyin_ids: [batch, 24]
attention_mask: [batch, seq_len]
Returns:
context_H, pinyin_P, context_mask, pinyin_mask
"""
# 准备输入
inputs = { inputs = {
"input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids, "input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids,
"pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids, "pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids,
"attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask, "attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask,
} }
# 运行推理
outputs = self.context_encoder_session.run(self.context_output_names, inputs) outputs = self.context_encoder_session.run(self.context_output_names, inputs)
# 解包输出
context_H, pinyin_P, context_mask, pinyin_mask = outputs context_H, pinyin_P, context_mask, pinyin_mask = outputs
return ( return (
torch.from_numpy(context_H), torch.from_numpy(context_H),
torch.from_numpy(pinyin_P), torch.from_numpy(pinyin_P),
torch.from_numpy(context_mask), torch.from_numpy(context_mask),
torch.from_numpy(pinyin_mask), torch.from_numpy(pinyin_mask),
) )
def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask): def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
"""
运行解码器
Args:
context_H: [batch, seq_len, 512]
pinyin_P: [batch, 24, 512]
history_slot_ids: [batch, 8]
context_mask: [batch, seq_len]
pinyin_mask: [batch, 24]
Returns:
logits: [batch, vocab_size]
"""
# 准备输入
inputs = { inputs = {
"context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H, "context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H,
"pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P, "pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P,
@ -110,99 +62,46 @@ class ONNXInference:
"context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask, "context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask,
"pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask, "pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask,
} }
# 运行推理
outputs = self.decoder_session.run(self.decoder_output_names, inputs) outputs = self.decoder_session.run(self.decoder_output_names, inputs)
# 解包输出
logits = outputs[0] logits = outputs[0]
return torch.from_numpy(logits) return torch.from_numpy(logits)
def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask, def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask,
beam_size=5, max_length=10, vocab_size=10019): beam_size=5, max_length=10, vocab_size=10019):
""" beams = [([], 0.0)]
束搜索算法示例
Args:
context_H: 上下文编码
pinyin_P: 拼音编码
context_mask: 上下文掩码
pinyin_mask: 拼音掩码
beam_size: 束大小
max_length: 最大生成长度
vocab_size: 词汇表大小
Returns:
最佳序列列表
"""
# 初始束空序列分数为0
beams = [([], 0.0)] # (序列, 对数概率)
for step in range(max_length): for step in range(max_length):
new_beams = [] new_beams = []
for seq, score in beams: for seq, score in beams:
# 构建history_slot_ids已确认的字符ID
if len(seq) < 8: if len(seq) < 8:
history = seq + [0] * (8 - len(seq)) history = seq + [0] * (8 - len(seq))
else: else:
history = seq[-8:] # 只保留最近8个 history = seq[-8:]
history_tensor = torch.tensor([history], dtype=torch.long) history_tensor = torch.tensor([history], dtype=torch.long)
# 运行解码器
logits = self.run_decoder( logits = self.run_decoder(
context_H, pinyin_P, history_tensor, context_H, pinyin_P, history_tensor,
context_mask, pinyin_mask context_mask, pinyin_mask
) )
# 获取概率
probs = F.softmax(logits[0], dim=-1) probs = F.softmax(logits[0], dim=-1)
# 获取top-k候选
top_probs, top_indices = torch.topk(probs, beam_size) top_probs, top_indices = torch.topk(probs, beam_size)
# 扩展束
for prob, idx in zip(top_probs, top_indices): for prob, idx in zip(top_probs, top_indices):
new_seq = seq + [idx.item()] new_seq = seq + [idx.item()]
new_score = score + torch.log(prob).item() new_score = score + torch.log(prob).item()
new_beams.append((new_seq, new_score)) new_beams.append((new_seq, new_score))
# 剪枝保留beam_size个最佳候选
new_beams.sort(key=lambda x: x[1], reverse=True) new_beams.sort(key=lambda x: x[1], reverse=True)
beams = new_beams[:beam_size] beams = new_beams[:beam_size]
# 检查是否所有序列都已结束以结束符0结尾
all_ended = all(seq[-1] == 0 for seq, _ in beams if seq) all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
if all_ended: if all_ended:
break break
return beams return beams
def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids): def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids):
"""
单步预测
Args:
input_ids: 输入token IDs
pinyin_ids: 拼音IDs
attention_mask: 注意力掩码
history_slot_ids: 历史槽位IDs
Returns:
预测logits
"""
# 1. 运行上下文编码器
context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder( context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
input_ids, pinyin_ids, attention_mask input_ids, pinyin_ids, attention_mask
) )
# 2. 运行解码器
logits = self.run_decoder( logits = self.run_decoder(
context_H, pinyin_P, history_slot_ids, context_H, pinyin_P, history_slot_ids,
context_mask, pinyin_mask context_mask, pinyin_mask
) )
return logits return logits
@ -210,21 +109,19 @@ def main():
"""示例主函数""" """示例主函数"""
print("ONNX模型推理示例") print("ONNX模型推理示例")
print("=" * 60) print("=" * 60)
# 初始化推理器
context_encoder_path = "context_encoder.onnx" context_encoder_path = "context_encoder.onnx"
decoder_path = "decoder.onnx" decoder_path = "decoder.onnx"
if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path): if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path):
print("错误: 找不到ONNX模型文件") print("错误: 找不到ONNX模型文件")
print("请先运行export_onnx.py导出模型") print("请先运行 train-model export 导出模型")
return return
inference = ONNXInference(context_encoder_path, decoder_path) inference = ONNXInference(context_encoder_path, decoder_path)
print("\u2705 ONNX推理器初始化完成")
print("✅ ONNX推理器初始化完成")
print("请参考此示例实现完整的输入法推理流程") print("请参考此示例实现完整的输入法推理流程")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

409
scripts/finetune_slots.py Normal file
View File

@ -0,0 +1,409 @@
#!/usr/bin/env python3
"""
临时迁移训练脚本:在预训练模型基础上重新训练,支持冻结 context_encoder。
与 src/model/trainer.py train 命令行为完全一致,额外增加:
- --pretrained-checkpoint: 加载预训练权重(必需的迁移学习源)
- --freeze-context-encoder: 冻结 context_encoder(默认开启)
运行方式:
python scripts/finetune_slots.py \
--pretrained-checkpoint ./output/checkpoints/best_model.pt \
--train-data-path /path/to/train_data \
--eval-data-path /path/to/eval_data \
--output-dir ./finetune_output \
--freeze-context-encoder
"""
import argparse
import json
import os
import random
import sys
from datetime import datetime
from pathlib import Path
import numpy as np
import torch
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
from model.model import InputMethodEngine
from model.trainer import (
Trainer,
create_dataloader,
worker_init_fn,
)
from model.dataset import PinyinInputDataset
from model.preprocessed_dataset import (
PreProcessedDataset,
is_preprocessed_data,
)
from loguru import logger
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
def main():
    """Fine-tune a pretrained InputMethodEngine, optionally freezing context_encoder.

    Parses the command line, builds the train/eval dataloaders (preprocessed
    shards or streaming data, detected via ``is_preprocessed_data``), creates
    the model, loads the pretrained weights non-strictly, freezes every
    parameter whose name starts with ``context_encoder`` when requested,
    writes ``training_config.json`` into the output directory and finally
    runs ``Trainer.train``.

    The process exits with status 1 when the pretrained checkpoint does not
    exist. This is checked right after argument parsing so that a typo in the
    path is reported before any dataset or model is constructed.
    """
    parser = argparse.ArgumentParser(
        description="迁移学习训练:加载预训练模型,冻结指定层后重新训练",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # === Data arguments ===
    parser.add_argument("--train-data-path", "-t", required=True, help="训练数据集路径")
    parser.add_argument("--eval-data-path", "-e", required=True, help="评估数据集路径")
    parser.add_argument("--output-dir", "-o", default="./finetune_output", help="输出目录")
    parser.add_argument("--max-iter-length", type=int,
                        default=1024 * 1024 * 128, help="每个 epoch 最大样本数")

    # === Transfer-learning arguments ===
    parser.add_argument("--pretrained-checkpoint", "-c", required=True,
                        help="预训练模型检查点路径")
    parser.add_argument("--freeze-context-encoder", action="store_true", default=True,
                        help="冻结 context_encoder 层 (默认开启)")
    parser.add_argument("--no-freeze-context-encoder", dest="freeze_context_encoder",
                        action="store_false",
                        help="不冻结 context_encoder")

    # === Training arguments ===
    parser.add_argument("--batch-size", "-b", type=int, default=128, help="批次大小")
    parser.add_argument("--num-epochs", type=int, default=10, help="训练轮数")
    parser.add_argument("--learning-rate", "-lr", type=float, default=2e-4,
                        help="学习率")
    parser.add_argument("--min-learning-rate", type=float, default=1e-9,
                        help="最小学习率")
    parser.add_argument("--weight-decay", type=float, default=0.05, help="权重衰减")
    parser.add_argument("--warmup-ratio", type=float, default=0.1, help="热身步数比例")
    parser.add_argument("--label-smoothing", type=float, default=0.1,
                        help="标签平滑参数")
    parser.add_argument("--grad-accum-steps", type=int, default=1,
                        help="梯度累积步数")
    parser.add_argument("--clip-grad-norm", type=float, default=1.0,
                        help="梯度裁剪范数")
    parser.add_argument("--eval-frequency", type=int, default=500, help="评估频率")
    parser.add_argument("--save-frequency", type=int, default=1000, help="保存频率")

    # === Miscellaneous arguments ===
    parser.add_argument("--mixed-precision", action="store_true", default=True)
    parser.add_argument("--no-mixed-precision", dest="mixed_precision",
                        action="store_false", help="禁用混合精度")
    parser.add_argument("--num-workers", type=int, default=2, help="数据加载worker数")
    parser.add_argument("--tensorboard", action="store_true", default=True)
    parser.add_argument("--no-tensorboard", dest="tensorboard", action="store_false",
                        help="禁用 TensorBoard")
    parser.add_argument("--seed", type=int, default=42, help="随机种子")
    parser.add_argument("--compile", action="store_true", default=False,
                        help="使用 torch.compile 优化")
    parser.add_argument("--moe-mode", default="all",
                        choices=["all", "sparse", "sparse_allow_graph"],
                        help="MoE 计算策略")
    args = parser.parse_args()

    console = Console()

    # Fail fast: validate the checkpoint path before building datasets or the
    # model, so a wrong path does not cost a full data-pipeline start-up.
    pretrained_path = Path(args.pretrained_checkpoint)
    if not pretrained_path.exists():
        console.print(f"[red]❌ 预训练检查点不存在: {args.pretrained_checkpoint}[/red]")
        sys.exit(1)

    # ================================================================
    # Initialisation
    # ================================================================
    torch.multiprocessing.set_sharing_strategy("file_system")
    if torch.cuda.is_available():
        torch.set_float32_matmul_precision("high")

    # Seed every RNG this script has in scope, not only torch, so that
    # --seed also covers code relying on `random` / `numpy`.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    output_path = Path(args.output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # ================================================================
    # Model constants (kept in sync with trainer.py)
    # ================================================================
    vocab_size = 10019
    pinyin_vocab_size = 30
    dim = 512
    num_slots = 8
    n_layers = 4
    n_heads = 4
    num_experts = 10
    max_seq_len = 128

    # ================================================================
    # Print the configuration
    # ================================================================
    console.print(Panel.fit(
        "[bold cyan]迁移学习训练配置[/bold cyan]", border_style="cyan"))
    config_table = Table(show_header=True, header_style="bold magenta")
    config_table.add_column("Category", style="cyan")
    config_table.add_column("Parameter", style="green")
    config_table.add_column("Value", style="yellow")
    config_table.add_row("迁移学习", "预训练检查点", args.pretrained_checkpoint)
    config_table.add_row("迁移学习", "冻结 context_encoder",
                         str(args.freeze_context_encoder))
    config_table.add_row("数据", "训练数据路径", args.train_data_path)
    config_table.add_row("数据", "评估数据路径", args.eval_data_path)
    config_table.add_row("数据", "输出目录", args.output_dir)
    config_table.add_row("数据", "批次大小", str(args.batch_size))
    config_table.add_row("数据", "Worker数量", str(args.num_workers))
    config_table.add_row("模型", "MoE策略", args.moe_mode)
    config_table.add_row("模型", "编译优化", str(args.compile))
    config_table.add_row("训练", "训练轮数", str(args.num_epochs))
    config_table.add_row("训练", "学习率", f"{args.learning_rate:.2e}")
    config_table.add_row("训练", "最小学习率", f"{args.min_learning_rate:.2e}")
    config_table.add_row("训练", "权重衰减", str(args.weight_decay))
    config_table.add_row("训练", "热身比例", str(args.warmup_ratio))
    config_table.add_row("训练", "标签平滑", str(args.label_smoothing))
    config_table.add_row("训练", "梯度累积", str(args.grad_accum_steps))
    config_table.add_row("训练", "梯度裁剪", str(args.clip_grad_norm))
    config_table.add_row("训练", "混合精度", str(args.mixed_precision))

    # ================================================================
    # Build the dataloaders (same logic as the trainer.py CLI)
    # ================================================================
    console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")
    is_train_preprocessed = is_preprocessed_data(args.train_data_path)
    is_eval_preprocessed = is_preprocessed_data(args.eval_data_path)

    if is_train_preprocessed:
        train_dataset = PreProcessedDataset(args.train_data_path,
                                            max_cache_shards=2)
        pre_shuffled = train_dataset.metadata.get("pre_shuffled", False)
        # Only shuffle at load time when the shards were not shuffled offline.
        shuffle_train = not pre_shuffled
        if args.max_iter_length > 0:
            capped_samples = min(len(train_dataset), args.max_iter_length)
        else:
            capped_samples = len(train_dataset)
        total_steps = (capped_samples // args.batch_size) * args.num_epochs
        # The preprocessed path uses at most one loader worker.
        train_num_workers = min(args.num_workers, 1)
        logger.info(
            f"Preprocessed dataset: {len(train_dataset):,} samples, "
            f"shuffle={shuffle_train}, pre_shuffled={pre_shuffled}, "
            f"workers={train_num_workers}, steps={total_steps:,}")
        train_dataloader = create_dataloader(
            dataset=train_dataset,
            batch_size=args.batch_size,
            num_workers=train_num_workers,
            pin_memory=torch.cuda.is_available(),
            shuffle=shuffle_train,
        )
        config_table.add_row("数据", "训练数据类型", "预处理数据")
    else:
        train_dataset = PinyinInputDataset(
            data_path=args.train_data_path,
            max_workers=-1,
            max_iter_length=args.max_iter_length,
            max_seq_length=max_seq_len,
            text_field="text",
            py_style_weight=(9, 2, 1),
            shuffle_buffer_size=2000000,
            length_weights={1: 10, 2: 50, 3: 50, 4: 40,
                            5: 15, 6: 10, 7: 5, 8: 2},
        )
        total_steps = int(args.max_iter_length *
                          args.num_epochs / args.batch_size)
        train_dataloader = create_dataloader(
            dataset=train_dataset,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=torch.cuda.is_available(),
            max_iter_length=args.max_iter_length,
        )
        config_table.add_row("数据", "训练数据类型", "流式数据")

    if is_eval_preprocessed:
        eval_dataset = PreProcessedDataset(args.eval_data_path,
                                           max_cache_shards=1)
        eval_dataloader = create_dataloader(
            dataset=eval_dataset,
            batch_size=args.batch_size,
            num_workers=0,
            pin_memory=torch.cuda.is_available(),
            shuffle=False,
        )
        config_table.add_row("数据", "评估数据类型", "预处理数据")
    else:
        eval_dataset = PinyinInputDataset(
            data_path=args.eval_data_path,
            max_workers=-1,
            max_iter_length=args.batch_size * 64,
            max_seq_length=max_seq_len,
            text_field="text",
            py_style_weight=(9, 2, 1),
            shuffle_buffer_size=2000000,
            length_weights={1: 10, 2: 50, 3: 50, 4: 40,
                            5: 15, 6: 10, 7: 5, 8: 2},
        )
        eval_dataloader = create_dataloader(
            dataset=eval_dataset,
            batch_size=args.batch_size,
            num_workers=2,
            pin_memory=torch.cuda.is_available(),
            max_iter_length=args.batch_size * 64,
        )
        config_table.add_row("数据", "评估数据类型", "流式数据")

    config_table.add_row("数据", "总步数", str(total_steps))
    console.print(config_table)

    # ================================================================
    # Create the model and load the pretrained weights
    # ================================================================
    console.print("[bold cyan]正在创建模型并加载预训练权重...[/bold cyan]")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = InputMethodEngine(
        vocab_size=vocab_size,
        pinyin_vocab_size=pinyin_vocab_size,
        dim=dim,
        num_slots=num_slots,
        n_layers=n_layers,
        n_heads=n_heads,
        num_experts=num_experts,
        max_seq_len=max_seq_len,
        compile=args.compile,
        moe_mode=args.moe_mode,
    )
    model.to(device)

    checkpoint = torch.load(args.pretrained_checkpoint, map_location=device)
    # Accept both a full training checkpoint and a bare state dict.
    if "model_state_dict" in checkpoint:
        pretrained_weights = checkpoint["model_state_dict"]
    else:
        pretrained_weights = checkpoint

    # strict=False: key mismatches are reported below instead of raising.
    missing_keys, unexpected_keys = model.load_state_dict(
        pretrained_weights, strict=False)
    if missing_keys:
        console.print(f"[yellow]⚠ 缺失的键 ({len(missing_keys)}): "
                      f"{missing_keys[:5]}...[/yellow]")
    if unexpected_keys:
        console.print(f"[yellow]⚠ 多余的键 ({len(unexpected_keys)}): "
                      f"{unexpected_keys[:5]}...[/yellow]")
    console.print("[green]✓ 预训练权重加载完成[/green]")

    # ================================================================
    # Freeze context_encoder
    # ================================================================
    if args.freeze_context_encoder:
        console.print("[bold cyan]正在冻结 context_encoder...[/bold cyan]")
        frozen_count = 0
        trainable_count = 0
        for name, param in model.named_parameters():
            if name.startswith("context_encoder"):
                param.requires_grad = False
                frozen_count += param.numel()
            else:
                trainable_count += param.numel()
        total_params = frozen_count + trainable_count
        console.print("[green]✓ context_encoder 已冻结[/green]")
        logger.info(
            f"冻结参数: {frozen_count:,} / {total_params:,} "
            f"({frozen_count / total_params * 100:.1f}%), "
            f"可训练参数: {trainable_count:,} / {total_params:,} "
            f"({trainable_count / total_params * 100:.1f}%)")
    else:
        logger.info("未冻结任何层,全模型参与训练")

    # ================================================================
    # Persist the run configuration
    # ================================================================
    config = {
        "pretrained_checkpoint": args.pretrained_checkpoint,
        "freeze_context_encoder": args.freeze_context_encoder,
        "train_data_path": args.train_data_path,
        "eval_data_path": args.eval_data_path,
        "output_dir": args.output_dir,
        "batch_size": args.batch_size,
        "num_epochs": args.num_epochs,
        "learning_rate": args.learning_rate,
        "min_learning_rate": args.min_learning_rate,
        "weight_decay": args.weight_decay,
        "warmup_ratio": args.warmup_ratio,
        "label_smoothing": args.label_smoothing,
        "grad_accum_steps": args.grad_accum_steps,
        "clip_grad_norm": args.clip_grad_norm,
        "eval_frequency": args.eval_frequency,
        "save_frequency": args.save_frequency,
        "mixed_precision": args.mixed_precision,
        "num_workers": args.num_workers,
        "use_tensorboard": args.tensorboard,
        "seed": args.seed,
        "compile": args.compile,
        "moe_mode": args.moe_mode,
        "total_steps": total_steps,
        "vocab_size": vocab_size,
        "pinyin_vocab_size": pinyin_vocab_size,
        "dim": dim,
        "num_slots": num_slots,
        "n_layers": n_layers,
        "n_heads": n_heads,
        "num_experts": num_experts,
        "max_seq_len": max_seq_len,
        "is_train_preprocessed": is_train_preprocessed,
        "is_eval_preprocessed": is_eval_preprocessed,
    }
    config_file = output_path / "training_config.json"
    with open(config_file, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)
    logger.info(f"Configuration saved to {config_file}")

    # ================================================================
    # Create the Trainer and start training
    # ================================================================
    console.print("[bold cyan]正在创建训练器...[/bold cyan]")
    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        total_steps=total_steps,
        output_dir=args.output_dir,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        min_learning_rate=args.min_learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        label_smoothing=args.label_smoothing,
        grad_accum_steps=args.grad_accum_steps,
        clip_grad_norm=args.clip_grad_norm,
        eval_frequency=args.eval_frequency,
        save_frequency=args.save_frequency,
        mixed_precision=args.mixed_precision,
        device=device,
        use_tensorboard=args.tensorboard,
        status_file="training_status.json",
    )
    console.print("[green]✓ 训练器创建完成[/green]")
    console.print("\n[bold cyan]开始训练...[/bold cyan]")
    console.print(f"开始时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    try:
        trainer.train(
            resume_from=None,
            reset_training_state=False,
            auto_resume=False,  # fine-tuning starts from scratch; never auto-resume
        )
    except KeyboardInterrupt:
        console.print("[bold green]训练被终止[/bold green]")
        trainer.save_checkpoint("interrupted_model.pt")
    else:
        # Only claim completion when training was not interrupted.
        console.print("[bold green]✓ 训练完成![/bold green]")

    console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    console.print(f"模型和日志保存在: {args.output_dir}")


if __name__ == "__main__":
    main()

View File

@ -53,27 +53,7 @@ class PinyinLSTMEncoder(nn.Module):
self.layer_norm = nn.LayerNorm(input_dim) self.layer_norm = nn.LayerNorm(input_dim)
def forward(self, x, mask=None): def forward(self, x, mask=None):
""" output, _ = self.lstm(x)
Args:
x: [batch, seq_len, input_dim] pinyin embeddings
mask: [batch, seq_len] optional padding mask (True for valid, False for padding)
Returns:
output: [batch, seq_len, input_dim] 每位置的拼音编码
"""
total_len = x.size(1)
if mask is not None:
lengths = mask.sum(dim=1).cpu().clamp(min=1)
packed = nn.utils.rnn.pack_padded_sequence(
x, lengths, batch_first=True, enforce_sorted=False
)
packed_out, (hidden, cell) = self.lstm(packed)
output, _ = nn.utils.rnn.pad_packed_sequence(
packed_out, batch_first=True, total_length=total_len
)
else:
output, (hidden, cell) = self.lstm(x)
projected = self.proj(output) projected = self.proj(output)
return self.layer_norm(projected) return self.layer_norm(projected)

View File

@ -110,6 +110,10 @@ class InputMethodEngine(nn.Module):
history_slot_ids = history_slot_ids.view(-1, self.num_slots) history_slot_ids = history_slot_ids.view(-1, self.num_slots)
is_zero = (history_slot_ids == 0)
cumsum_zero = is_zero.cumsum(dim=1)
slot_mask = (cumsum_zero <= 1).to(torch.get_default_dtype())
H, P = self.context_encoder(input_ids, pinyin_ids, mask=attention_mask) H, P = self.context_encoder(input_ids, pinyin_ids, mask=attention_mask)
S = self.slot_memory(history_slot_ids) S = self.slot_memory(history_slot_ids)
@ -120,10 +124,13 @@ class InputMethodEngine(nn.Module):
S, H, P, context_mask=context_mask, pinyin_mask=pinyin_mask S, H, P, context_mask=context_mask, pinyin_mask=pinyin_mask
) )
fused = fused * slot_mask.unsqueeze(-1)
moe_out = self.moe(fused) moe_out = self.moe(fused)
batch_size = input_ids.size(0) batch_size = input_ids.size(0)
slot_scores = self.slot_attention(moe_out).squeeze(-1) slot_scores = self.slot_attention(moe_out).squeeze(-1)
slot_scores = slot_scores.masked_fill(slot_mask == 0, -1e9)
slot_weights = torch.softmax(slot_scores, dim=1) slot_weights = torch.softmax(slot_scores, dim=1)
pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1) pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1)

366
src/model/onnx_export.py Normal file
View File

@ -0,0 +1,366 @@
"""
ONNX模型导出核心逻辑
InputMethodEngine 模型导出为两个 ONNX 文件
1. context_encoder.onnx - 上下文编码器可复用
2. decoder.onnx - 解码器
共用此模块的入口
- CLI: train-model export
- 脚本: python export_onnx.py薄壳
"""
import os
import sys
from pathlib import Path
from typing import Dict, Optional, Tuple
import numpy as np
import torch
from .export_models import create_export_models_from_checkpoint
def check_onnx_available() -> bool:
    """Return True when both ``onnx`` and ``onnxruntime`` can be imported.

    When either import fails, installation instructions are printed to
    stdout and False is returned.
    """
    try:
        import onnx  # noqa: F401
        import onnxruntime  # noqa: F401
    except ImportError:
        install_hint = (
            "错误: ONNX导出需要以下依赖:",
            " pip install onnx onnxruntime",
            "请安装后重试",
        )
        for message in install_hint:
            print(message)
        return False
    return True
def export_context_encoder(
    model,
    output_path: str,
    config: Dict,
    skip_verification: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Export the context-encoder wrapper to ONNX and return the example inputs.

    Args:
        model: Export wrapper called as ``model(input_ids, pinyin_ids,
            attention_mask)``; its four outputs are named ``context_H``,
            ``pinyin_P``, ``context_mask`` and ``pinyin_mask`` in the graph.
        output_path: Destination ``.onnx`` file.
        config: Model configuration. ``max_seq_len`` (default 128) sets the
            example sequence length and ``pinyin_vocab_size`` (default 30)
            bounds the random pinyin ids.
        skip_verification: When True, skip the ``onnx.checker`` pass.

    Returns:
        The ``(input_ids, pinyin_ids, attention_mask)`` tensors used for
        tracing, so the caller can reuse them.
    """
    batch_size = 2
    seq_len = config.get("max_seq_len", 128)
    pinyin_len = 24
    # Bound the random pinyin ids by the configured vocabulary rather than a
    # hard-coded 30, so a smaller vocabulary cannot yield out-of-range ids.
    pinyin_vocab_size = config.get("pinyin_vocab_size", 30)

    # Create the example inputs on the model's device; tracing a model with
    # inputs on a different device fails.
    first_param = (
        next(iter(model.parameters()), None)
        if hasattr(model, "parameters")
        else None
    )
    device = first_param.device if first_param is not None else torch.device("cpu")

    input_ids = torch.randint(
        0, 100, (batch_size, seq_len), dtype=torch.long, device=device
    )
    pinyin_ids = torch.randint(
        0, pinyin_vocab_size, (batch_size, pinyin_len), dtype=torch.long, device=device
    )
    attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long, device=device)

    # Batch is dynamic everywhere; sequence length is dynamic only for the
    # context inputs/outputs. The pinyin length stays fixed at trace time.
    dynamic_axes = {
        "input_ids": {0: "batch_size", 1: "seq_len"},
        "pinyin_ids": {0: "batch_size"},
        "attention_mask": {0: "batch_size", 1: "seq_len"},
        "context_H": {0: "batch_size", 1: "seq_len"},
        "pinyin_P": {0: "batch_size"},
        "context_mask": {0: "batch_size", 1: "seq_len"},
        "pinyin_mask": {0: "batch_size"},
    }

    torch.onnx.export(
        model,
        (input_ids, pinyin_ids, attention_mask),
        output_path,
        input_names=["input_ids", "pinyin_ids", "attention_mask"],
        output_names=["context_H", "pinyin_P", "context_mask", "pinyin_mask"],
        dynamic_axes=dynamic_axes,
        opset_version=18,
        do_constant_folding=True,
        verbose=False,
    )
    print(f" 上下文编码器导出完成 -> {output_path}")

    if not skip_verification:
        # Best-effort structural check: a failure is reported as a warning
        # and does not abort the export.
        try:
            import onnx

            onnx.checker.check_model(onnx.load(output_path))
            print(" ONNX 模型验证通过")
        except Exception as e:
            print(f" ONNX 模型验证警告: {e}")

    return input_ids, pinyin_ids, attention_mask
def export_decoder(
    model,
    output_path: str,
    config: Dict,
    example_inputs: Optional[Tuple] = None,
    skip_verification: bool = False,
) -> Tuple[torch.Tensor, ...]:
    """Export the decoder wrapper to ONNX and return the example inputs.

    Args:
        model: Export wrapper called as ``model(context_H, pinyin_P,
            history_slot_ids, context_mask, pinyin_mask)``; its output is
            named ``logits`` in the graph.
        output_path: Destination ``.onnx`` file.
        config: Model configuration. ``max_seq_len`` (default 128), ``dim``
            (default 512) and ``num_slots`` (default 8) shape the example
            inputs.
        example_inputs: Optional ``(context_H, pinyin_P, context_mask,
            pinyin_mask)`` tuple, typically the context encoder's outputs.
            When omitted, random tensors are generated instead.
        skip_verification: When True, skip the ``onnx.checker`` pass.

    Returns:
        ``(context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask)``
        as used for tracing.
    """
    batch_size = 2
    seq_len = config.get("max_seq_len", 128)
    pinyin_len = 24
    dim = config.get("dim", 512)
    num_slots = config.get("num_slots", 8)

    if example_inputs is not None:
        context_H, pinyin_P, context_mask, pinyin_mask = example_inputs
        batch_size = context_H.size(0)
        # Follow the device of the supplied encoder outputs.
        device = context_H.device
    else:
        # No encoder outputs supplied: fabricate inputs on the model's device.
        first_param = (
            next(iter(model.parameters()), None)
            if hasattr(model, "parameters")
            else None
        )
        device = (
            first_param.device if first_param is not None else torch.device("cpu")
        )
        context_H = torch.randn(
            batch_size, seq_len, dim, dtype=torch.float32, device=device
        )
        pinyin_P = torch.randn(
            batch_size, pinyin_len, dim, dtype=torch.float32, device=device
        )
        context_mask = torch.randint(
            0, 2, (batch_size, seq_len), dtype=torch.int32, device=device
        )
        pinyin_mask = torch.randint(
            0, 2, (batch_size, pinyin_len), dtype=torch.int32, device=device
        )

    # Created on the same device as the other decoder inputs; a CPU tensor
    # here would clash with device-resident encoder outputs.
    history_slot_ids = torch.randint(
        0, 100, (batch_size, num_slots), dtype=torch.long, device=device
    )

    dynamic_axes = {
        "context_H": {0: "batch_size", 1: "seq_len"},
        "pinyin_P": {0: "batch_size"},
        "history_slot_ids": {0: "batch_size"},
        "context_mask": {0: "batch_size", 1: "seq_len"},
        "pinyin_mask": {0: "batch_size"},
        "logits": {0: "batch_size"},
    }

    torch.onnx.export(
        model,
        (context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask),
        output_path,
        input_names=[
            "context_H",
            "pinyin_P",
            "history_slot_ids",
            "context_mask",
            "pinyin_mask",
        ],
        output_names=["logits"],
        dynamic_axes=dynamic_axes,
        opset_version=18,
        do_constant_folding=True,
        verbose=False,
    )
    print(f" 解码器导出完成 -> {output_path}")

    if not skip_verification:
        # Best-effort structural check: a failure is reported as a warning
        # and does not abort the export.
        try:
            import onnx

            onnx.checker.check_model(onnx.load(output_path))
            print(" ONNX 模型验证通过")
        except Exception as e:
            print(f" ONNX 模型验证警告: {e}")

    return context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask
def save_example_inputs(output_dir: str, example_inputs_dict: Dict) -> None:
    """Save example tensors as ``example_inputs.npz`` and ``example_inputs.pt``.

    Args:
        output_dir: Directory to write into; created if it does not exist.
        example_inputs_dict: Mapping of names to tensors or tuples of tensors.
            In the ``.npz`` file a tensor is stored under its key and the
            tensor elements of a tuple under ``"<key>_<index>"``; anything
            else is skipped. The ``.pt`` file receives the mapping unchanged.
    """
    os.makedirs(output_dir, exist_ok=True)

    np_data = {}
    for key, value in example_inputs_dict.items():
        if isinstance(value, torch.Tensor):
            # detach() first: numpy() raises on tensors that require grad.
            np_data[key] = value.detach().cpu().numpy()
        elif isinstance(value, tuple):
            for i, item in enumerate(value):
                if isinstance(item, torch.Tensor):
                    np_data[f"{key}_{i}"] = item.detach().cpu().numpy()

    npz_path = os.path.join(output_dir, "example_inputs.npz")
    np.savez(npz_path, **np_data)
    print(f" 示例输入已保存到: {npz_path}")

    torch_path = os.path.join(output_dir, "example_inputs.pt")
    torch.save(example_inputs_dict, torch_path)
    print(f" PyTorch 示例输入已保存到: {torch_path}")
_INFERENCE_EXAMPLE_TEMPLATE = '''#!/usr/bin/env python3
"""
ONNX模型推理示例
展示如何使用导出的两个ONNX模型进行推理
包括束搜索beam search算法
"""
import numpy as np
import onnxruntime as ort
import torch
import torch.nn.functional as F
from typing import List, Tuple
class ONNXInference:
"""ONNX模型推理器"""
def __init__(self, context_encoder_path, decoder_path):
self.context_encoder_session = ort.InferenceSession(
context_encoder_path,
providers=['CPUExecutionProvider']
)
self.decoder_session = ort.InferenceSession(
decoder_path,
providers=['CPUExecutionProvider']
)
self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()]
self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()]
self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()]
self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()]
print(f"上下文编码器输入: {self.context_input_names}")
print(f"上下文编码器输出: {self.context_output_names}")
print(f"解码器输入: {self.decoder_input_names}")
print(f"解码器输出: {self.decoder_output_names}")
def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128):
raise NotImplementedError("请实现实际的输入预处理")
def run_context_encoder(self, input_ids, pinyin_ids, attention_mask):
inputs = {
"input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids,
"pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids,
"attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask,
}
outputs = self.context_encoder_session.run(self.context_output_names, inputs)
context_H, pinyin_P, context_mask, pinyin_mask = outputs
return (
torch.from_numpy(context_H),
torch.from_numpy(pinyin_P),
torch.from_numpy(context_mask),
torch.from_numpy(pinyin_mask),
)
def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
inputs = {
"context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H,
"pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P,
"history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids,
"context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask,
"pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask,
}
outputs = self.decoder_session.run(self.decoder_output_names, inputs)
logits = outputs[0]
return torch.from_numpy(logits)
def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask,
beam_size=5, max_length=10, vocab_size=10019):
beams = [([], 0.0)]
for step in range(max_length):
new_beams = []
for seq, score in beams:
if len(seq) < 8:
history = seq + [0] * (8 - len(seq))
else:
history = seq[-8:]
history_tensor = torch.tensor([history], dtype=torch.long)
logits = self.run_decoder(
context_H, pinyin_P, history_tensor,
context_mask, pinyin_mask
)
probs = F.softmax(logits[0], dim=-1)
top_probs, top_indices = torch.topk(probs, beam_size)
for prob, idx in zip(top_probs, top_indices):
new_seq = seq + [idx.item()]
new_score = score + torch.log(prob).item()
new_beams.append((new_seq, new_score))
new_beams.sort(key=lambda x: x[1], reverse=True)
beams = new_beams[:beam_size]
all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
if all_ended:
break
return beams
def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids):
context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
input_ids, pinyin_ids, attention_mask
)
logits = self.run_decoder(
context_H, pinyin_P, history_slot_ids,
context_mask, pinyin_mask
)
return logits
def main():
"""示例主函数"""
print("ONNX模型推理示例")
print("=" * 60)
context_encoder_path = "context_encoder.onnx"
decoder_path = "decoder.onnx"
if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path):
print("错误: 找不到ONNX模型文件")
print("请先运行 train-model export 导出模型")
return
inference = ONNXInference(context_encoder_path, decoder_path)
print("\\u2705 ONNX推理器初始化完成")
print("请参考此示例实现完整的输入法推理流程")
if __name__ == "__main__":
main()
'''
def create_inference_example(output_dir: str, config: Dict) -> None:
    """Write the bundled inference example script into ``output_dir``.

    The file is named ``inference_example.py`` and is written as UTF-8.
    ``config`` is accepted for interface compatibility but is not used.
    """
    script_file = os.path.join(output_dir, "inference_example.py")
    with open(script_file, "w", encoding="utf-8") as handle:
        handle.write(_INFERENCE_EXAMPLE_TEMPLATE)
    print(f" 推理示例脚本已保存到: {script_file}")
def run_full_export(
    checkpoint_path: str,
    output_dir: str,
    device: str = "cpu",
    skip_verification: bool = False,
) -> Tuple[str, str, Dict]:
    """Run the complete ONNX export pipeline for one checkpoint.

    Builds the two export wrappers from the checkpoint, exports
    ``context_encoder.onnx`` and ``decoder.onnx`` into ``output_dir``, then
    saves the example inputs and the inference example script alongside them.

    Args:
        checkpoint_path: Checkpoint to export.
        output_dir: Destination directory; created if missing.
        device: Device string forwarded to
            ``create_export_models_from_checkpoint``.
        skip_verification: When True, skip the ``onnx.checker`` pass for
            both models.

    Returns:
        ``(context_encoder_path, decoder_path, config)``.
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    print(f"加载 checkpoint: {checkpoint_path}")
    context_encoder_export, decoder_export, config = create_export_models_from_checkpoint(
        checkpoint_path, device
    )
    print(f"模型配置: vocab_size={config.get('vocab_size')}, "
          f"dim={config.get('dim')}, "
          f"num_slots={config.get('num_slots')}, "
          f"num_experts={config.get('num_experts')}, "
          f"moe_mode={config.get('moe_mode', 'all')}")

    print("正在导出模型...")
    context_encoder_path = str(output_path / "context_encoder.onnx")
    example_inputs = export_context_encoder(
        context_encoder_export, context_encoder_path, config,
        skip_verification=skip_verification,
    )

    # Feed the decoder export with real encoder outputs instead of random
    # tensors.
    with torch.no_grad():
        context_H, pinyin_P, context_mask, pinyin_mask = context_encoder_export(
            *example_inputs
        )

    decoder_path = str(output_path / "decoder.onnx")
    decoder_inputs = export_decoder(
        decoder_export, decoder_path, config,
        example_inputs=(context_H, pinyin_P, context_mask, pinyin_mask),
        skip_verification=skip_verification,
    )
    # export_decoder returns (context_H, pinyin_P, history_slot_ids,
    # context_mask, pinyin_mask). Keep the history ids so the saved example
    # set covers every decoder input.
    history_slot_ids = decoder_inputs[2]

    example_inputs_dict = {
        "input_ids": example_inputs[0],
        "pinyin_ids": example_inputs[1],
        "attention_mask": example_inputs[2],
        "context_H": context_H,
        "pinyin_P": pinyin_P,
        "history_slot_ids": history_slot_ids,
        "context_mask": context_mask,
        "pinyin_mask": pinyin_mask,
    }
    save_example_inputs(output_dir, example_inputs_dict)
    create_inference_example(output_dir, config)
    return context_encoder_path, decoder_path, config

View File

@ -838,6 +838,20 @@ class Trainer:
# epoch 完成 # epoch 完成
progress.update(epoch_task, advance=1) progress.update(epoch_task, advance=1)
# 每个 epoch 结束后评估并保存 best model
epoch_eval_metrics = self.evaluate()
if epoch_eval_metrics:
epoch_log_metrics = {
"epoch/eval_loss": epoch_eval_metrics["eval_loss"],
"epoch/eval_accuracy": epoch_eval_metrics["eval_accuracy"],
}
if self.writer is not None:
for key, value in epoch_log_metrics.items():
self.writer.add_scalar(key, value, global_step)
if epoch_eval_metrics["eval_loss"] < self.best_eval_loss:
self.best_eval_loss = epoch_eval_metrics["eval_loss"]
self.save_checkpoint("best_model.pt", is_best=True)
# 每个 epoch 结束后保存检查点 # 每个 epoch 结束后保存检查点
self.save_epoch_checkpoint(epoch + 1) self.save_epoch_checkpoint(epoch + 1)
@ -848,6 +862,34 @@ class Trainer:
# 训练完成 # 训练完成
logger.info("Training completed!") logger.info("Training completed!")
# 最终评估并保存 best model
logger.info("Running final evaluation...")
final_eval_metrics = self.evaluate()
if final_eval_metrics:
final_log_metrics = {
"train/loss": accumulated_loss / accumulation_counter
if accumulation_counter > 0
else 0.0,
"train/accuracy": accumulated_accuracy / accumulation_counter
if accumulation_counter > 0
else 0.0,
"train/learning_rate": self._get_current_lr(),
"eval/loss": final_eval_metrics["eval_loss"],
"eval/accuracy": final_eval_metrics["eval_accuracy"],
}
self._log_to_tensorboard(final_log_metrics, global_step)
logger.info(
f"Final eval loss: {final_eval_metrics['eval_loss']:.4f}, "
f"Final eval accuracy: {final_eval_metrics['eval_accuracy']:.4f}"
)
if final_eval_metrics["eval_loss"] < self.best_eval_loss:
self.best_eval_loss = final_eval_metrics["eval_loss"]
self.save_checkpoint("best_model.pt", is_best=True)
# 保存最终模型
self.save_checkpoint("final_model.pt")
logger.info("Final model saved.")
# 显示保存的 epoch checkpoint 信息 # 显示保存的 epoch checkpoint 信息
if self.epoch_checkpoints: if self.epoch_checkpoints:
sorted_checkpoints = self.get_epoch_checkpoints() sorted_checkpoints = self.get_epoch_checkpoints()
@ -1096,6 +1138,7 @@ def create_dataloader(
num_workers=num_workers, num_workers=num_workers,
pin_memory=pin_memory, pin_memory=pin_memory,
collate_fn=preprocessed_collate_fn, collate_fn=preprocessed_collate_fn,
drop_last=True,
persistent_workers=True if num_workers > 0 else False, persistent_workers=True if num_workers > 0 else False,
) )
@ -1108,7 +1151,8 @@ def create_dataloader(
pin_memory=pin_memory, pin_memory=pin_memory,
worker_init_fn=worker_init_fn, worker_init_fn=worker_init_fn,
collate_fn=preprocess_collate_fn(fixed_max_seq_length), collate_fn=preprocess_collate_fn(fixed_max_seq_length),
prefetch_factor=2, # 减少预取以避免内存问题 drop_last=True,
prefetch_factor=2,
persistent_workers=True, persistent_workers=True,
shuffle=shuffle, shuffle=shuffle,
) )
@ -1464,21 +1508,46 @@ def evaluate(
@app.command() @app.command()
def export( def export(
checkpoint_path: str = typer.Option(..., "--checkpoint", "-c", help="检查点路径"), checkpoint_path: str = typer.Option(..., "--checkpoint", "-c", help="检查点路径"),
output_path: str = typer.Option( output_dir: str = typer.Option(
"./exported_model.onnx", "--output", "-o", help="输出路径" "./exported_models", "--output", "-o", help="输出目录"
),
device: str = typer.Option("cpu", "--device", help="导出设备cpu/cuda"),
skip_verification: bool = typer.Option(
False, "--skip-verification", help="跳过ONNX模型验证"
), ),
): ):
""" """
导出模型为ONNX格式 导出模型为ONNX格式生成 context_encoder.onnx decoder.onnx
""" """
from .onnx_export import check_onnx_available, run_full_export
console = Console() console = Console()
console.print(f"[bold cyan]导出模型到: {output_path}[/bold cyan]") console.print("[bold cyan]━━━ ONNX 模型导出 ━━━[/bold cyan]")
console.print(f" 检查点: {checkpoint_path}")
console.print(f" 输出目录: {output_dir}")
console.print(f" 设备: {device}")
# 这里应该实现导出逻辑 if not check_onnx_available():
# 1. 加载检查点 raise typer.Exit(1)
# 2. 导出为ONNX
console.print("[yellow]导出功能待实现[/yellow]") try:
context_encoder_path, decoder_path, config = run_full_export(
checkpoint_path=checkpoint_path,
output_dir=output_dir,
device=device,
skip_verification=skip_verification,
)
except Exception as e:
console.print(f"[bold red]导出失败: {e}[/bold red]")
raise typer.Exit(1)
console.print("\n[bold green]✓ 导出完成![/bold green]")
console.print(f" context_encoder.onnx -> {context_encoder_path}")
console.print(f" decoder.onnx -> {decoder_path}")
console.print(
f"\n MoE 层使用 'all' 模式(全量计算 {config.get('num_experts', '?')} 个专家),"
f"稀疏化优化可后续迭代"
)
if __name__ == "__main__": if __name__ == "__main__":