docs: 详细描述history_slot_ids的设计策略与使用场景
This commit is contained in:
parent
71ef54e3d4
commit
53f244de2f
30
README.md
30
README.md
|
|
@ -72,6 +72,36 @@
|
||||||
* **目标槽位序列**:真实用户输入的文字 ID 序列,作为模型的监督信号 [1]。
|
* **目标槽位序列**:真实用户输入的文字 ID 序列,作为模型的监督信号 [1]。
|
||||||
* **标签处理**:在每一个槽位步(Step),模型需要预测该步对应的真实文字 ID [1]。
|
* **标签处理**:在每一个槽位步(Step),模型需要预测该步对应的真实文字 ID [1]。
|
||||||
|
|
||||||
|
#### 4.1.1 历史槽位(history_slot_ids)的设计策略
|
||||||
|
|
||||||
|
`history_slot_ids` 的语义是**当前拼音组内已经确认的字符**,用于模拟真实的输入法输入场景。
|
||||||
|
|
||||||
|
**设计原则:**
|
||||||
|
|
||||||
|
1. **拼音组隔离**:每个拼音组(jieba 分词的一个词或合并后的短词组)对应一次独立的拼音输入会话。拼音组开始时 `history_slot_ids` 为空(全 0),组内逐字累积。
|
||||||
|
```
|
||||||
|
输入: 特种兵 → 拼音组 "tezhongbing"
|
||||||
|
Step 1: pinyin="tezhongbing", history=[] → 预测 特
|
||||||
|
Step 2: pinyin="tezhongbing", history=[特] → 预测 种
|
||||||
|
Step 3: pinyin="tezhongbing", history=[特, 种] → 预测 兵
|
||||||
|
输入: 十分倾佩 → 拼音组 "shifen" + 拼音组 "qing" + 拼音组 "pei"(各组 history 均从空开始;"shifen" 组内逐字累积,此处从略)
|
||||||
|
拼音组 "qing": pinyin="qing", history=[] → 预测 倾
|
||||||
|
拼音组 "pei": pinyin="pei", history=[] → 预测 佩
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **上下文承载前缀字**:当拼音组切换时,前一个拼音组已确认的字符(如"特")已写入文本,通过 BERT 编码的 `input_ids`(而非 `history_slot_ids`)传递给模型。这模拟了真实输入法环境——用户打完"te"确认"特"后,文本中已有"特",再输入"zhongbing"时,模型从上下文(`input_ids`)中能看到"特"。
|
||||||
|
|
||||||
|
3. **破词续接**:当一个词以一定概率(约 10%)被拆分时(`should_break=True`),Phase 1 处理前缀(如"特"),Phase 2 处理续接部分(如"种兵")。Phase 2 的 `cont_processed_history = []` 是**正确的设计**——前缀已被确认并成为文本的一部分,续接部分是一次新的拼音会话,因此 history 从空开始。
|
||||||
|
```
|
||||||
|
"特种兵" 断在"特"后:
|
||||||
|
Phase 1: pinyin="te", history=[] → 预测 特 → 确认为文本上下文
|
||||||
|
Phase 2: pinyin="zhongbing", history=[] → 预测 种
|
||||||
|
history=[种] → 预测 兵
|
||||||
|
```
|
||||||
|
Phase 2 中,"特"已通过 `part1`(光标前文本)进入 BERT 上下文,因此不会出现在 `history_slot_ids` 中。
|
||||||
|
|
||||||
|
4. **短词合并**:相邻的短词(每个词 ≤2 字符)有 50% 概率合并为一个拼音组。合并后组内字符共享 history 累积;未合并时各自为独立拼音组(各组 history 均从空开始)。两种方式都对应真实输入法场景——用户可能一次性输入多个字,也可能逐字输入。
|
||||||
|
|
||||||
### 4.2 损失函数与优化
|
### 4.2 损失函数与优化
|
||||||
* **损失函数**:使用 **CrossEntropyLoss** 计算每一步预测结果与真实标签之间的差异 [1]。
|
* **损失函数**:使用 **CrossEntropyLoss** 计算每一步预测结果与真实标签之间的差异 [1]。
|
||||||
* **掩码机制**:仅计算非填充位置(Non-padding positions)的损失,忽略无效的时间步 [1]。
|
* **掩码机制**:仅计算非填充位置(Non-padding positions)的损失,忽略无效的时间步 [1]。
|
||||||
|
|
|
||||||
526
export_onnx.py
526
export_onnx.py
|
|
@ -14,450 +14,12 @@
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
# 检查ONNX是否可用
|
|
||||||
try:
|
|
||||||
import onnx
|
|
||||||
import onnxruntime as ort
|
|
||||||
|
|
||||||
ONNX_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
ONNX_AVAILABLE = False
|
|
||||||
print("警告: ONNX或ONNX Runtime未安装")
|
|
||||||
print("请运行: pip install onnx onnxruntime")
|
|
||||||
|
|
||||||
# 添加src目录到路径
|
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
from src.model.export_models import create_export_models_from_checkpoint
|
from src.model.onnx_export import check_onnx_available, run_full_export
|
||||||
|
|
||||||
|
|
||||||
def check_onnx_available():
|
|
||||||
"""检查ONNX依赖是否可用"""
|
|
||||||
if not ONNX_AVAILABLE:
|
|
||||||
print("错误: ONNX导出需要以下依赖:")
|
|
||||||
print(" pip install onnx onnxruntime")
|
|
||||||
print("请安装后重试")
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def export_context_encoder(model, output_path, config):
|
|
||||||
"""
|
|
||||||
导出上下文编码器为ONNX格式
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: ContextEncoderExport实例
|
|
||||||
output_path: 输出路径
|
|
||||||
config: 模型配置
|
|
||||||
"""
|
|
||||||
print(f"正在导出上下文编码器到: {output_path}")
|
|
||||||
|
|
||||||
# 创建示例输入 - 使用batch_size=2以确保ONNX支持动态批处理
|
|
||||||
batch_size = 2
|
|
||||||
seq_len = config.get("max_seq_len", 128)
|
|
||||||
pinyin_len = 24 # 固定长度
|
|
||||||
dim = config.get("dim", 512)
|
|
||||||
|
|
||||||
input_ids = torch.randint(0, 100, (batch_size, seq_len), dtype=torch.long)
|
|
||||||
pinyin_ids = torch.randint(0, 30, (batch_size, pinyin_len), dtype=torch.long)
|
|
||||||
attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)
|
|
||||||
|
|
||||||
# 设置动态轴
|
|
||||||
dynamic_axes = {
|
|
||||||
"input_ids": {0: "batch_size", 1: "seq_len"},
|
|
||||||
"pinyin_ids": {0: "batch_size"},
|
|
||||||
"attention_mask": {0: "batch_size", 1: "seq_len"},
|
|
||||||
"context_H": {0: "batch_size", 1: "seq_len"},
|
|
||||||
"pinyin_P": {0: "batch_size"},
|
|
||||||
"context_mask": {0: "batch_size", 1: "seq_len"},
|
|
||||||
"pinyin_mask": {0: "batch_size"},
|
|
||||||
}
|
|
||||||
|
|
||||||
# 导出模型
|
|
||||||
torch.onnx.export(
|
|
||||||
model,
|
|
||||||
(input_ids, pinyin_ids, attention_mask),
|
|
||||||
output_path,
|
|
||||||
input_names=["input_ids", "pinyin_ids", "attention_mask"],
|
|
||||||
output_names=["context_H", "pinyin_P", "context_mask", "pinyin_mask"],
|
|
||||||
dynamic_axes=dynamic_axes,
|
|
||||||
opset_version=18, # 支持现代操作符
|
|
||||||
do_constant_folding=True,
|
|
||||||
verbose=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
print("✅ 上下文编码器导出完成")
|
|
||||||
|
|
||||||
# 验证导出
|
|
||||||
try:
|
|
||||||
onnx_model = onnx.load(output_path)
|
|
||||||
onnx.checker.check_model(onnx_model)
|
|
||||||
print("✅ ONNX模型验证通过")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"⚠️ ONNX模型验证警告: {e}")
|
|
||||||
|
|
||||||
return input_ids, pinyin_ids, attention_mask
|
|
||||||
|
|
||||||
|
|
||||||
def export_decoder(model, output_path, config, example_inputs=None):
|
|
||||||
"""
|
|
||||||
导出解码器为ONNX格式
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: DecoderExport实例
|
|
||||||
output_path: 输出路径
|
|
||||||
config: 模型配置
|
|
||||||
example_inputs: 示例输入(用于验证一致性)
|
|
||||||
"""
|
|
||||||
print(f"正在导解码器到: {output_path}")
|
|
||||||
|
|
||||||
# 创建示例输入 - 使用batch_size=2以确保ONNX支持动态批处理
|
|
||||||
batch_size = 2
|
|
||||||
seq_len = config.get("max_seq_len", 128)
|
|
||||||
pinyin_len = 24
|
|
||||||
dim = config.get("dim", 512)
|
|
||||||
num_slots = config.get("num_slots", 8)
|
|
||||||
|
|
||||||
# 使用提供的示例输入或创建新的
|
|
||||||
if example_inputs is not None:
|
|
||||||
context_H, pinyin_P, context_mask, pinyin_mask = example_inputs
|
|
||||||
batch_size = context_H.size(0) # 使用实际batch size
|
|
||||||
history_slot_ids = torch.randint(
|
|
||||||
0, 100, (batch_size, num_slots), dtype=torch.long
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
context_H = torch.randn(batch_size, seq_len, dim, dtype=torch.float32)
|
|
||||||
pinyin_P = torch.randn(batch_size, pinyin_len, dim, dtype=torch.float32)
|
|
||||||
context_mask = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int32)
|
|
||||||
pinyin_mask = torch.randint(0, 2, (batch_size, pinyin_len), dtype=torch.int32)
|
|
||||||
history_slot_ids = torch.randint(
|
|
||||||
0, 100, (batch_size, num_slots), dtype=torch.long
|
|
||||||
)
|
|
||||||
|
|
||||||
# 设置动态轴
|
|
||||||
dynamic_axes = {
|
|
||||||
"context_H": {0: "batch_size", 1: "seq_len"},
|
|
||||||
"pinyin_P": {0: "batch_size"},
|
|
||||||
"history_slot_ids": {0: "batch_size"},
|
|
||||||
"context_mask": {0: "batch_size", 1: "seq_len"},
|
|
||||||
"pinyin_mask": {0: "batch_size"},
|
|
||||||
"logits": {0: "batch_size"},
|
|
||||||
}
|
|
||||||
|
|
||||||
# 导出模型
|
|
||||||
torch.onnx.export(
|
|
||||||
model,
|
|
||||||
(context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask),
|
|
||||||
output_path,
|
|
||||||
input_names=[
|
|
||||||
"context_H",
|
|
||||||
"pinyin_P",
|
|
||||||
"history_slot_ids",
|
|
||||||
"context_mask",
|
|
||||||
"pinyin_mask",
|
|
||||||
],
|
|
||||||
output_names=["logits"],
|
|
||||||
dynamic_axes=dynamic_axes,
|
|
||||||
opset_version=18,
|
|
||||||
do_constant_folding=True,
|
|
||||||
verbose=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
print("✅ 解码器导出完成")
|
|
||||||
|
|
||||||
# 验证导出
|
|
||||||
try:
|
|
||||||
onnx_model = onnx.load(output_path)
|
|
||||||
onnx.checker.check_model(onnx_model)
|
|
||||||
print("✅ ONNX模型验证通过")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"⚠️ ONNX模型验证警告: {e}")
|
|
||||||
|
|
||||||
return context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask
|
|
||||||
|
|
||||||
|
|
||||||
def save_example_inputs(output_dir, example_inputs_dict):
|
|
||||||
"""
|
|
||||||
保存示例输入为NPZ文件,用于验证
|
|
||||||
|
|
||||||
Args:
|
|
||||||
output_dir: 输出目录
|
|
||||||
example_inputs_dict: 示例输入字典
|
|
||||||
"""
|
|
||||||
npz_path = os.path.join(output_dir, "example_inputs.npz")
|
|
||||||
|
|
||||||
# 转换为numpy数组
|
|
||||||
np_data = {}
|
|
||||||
for key, tensor in example_inputs_dict.items():
|
|
||||||
if isinstance(tensor, torch.Tensor):
|
|
||||||
np_data[key] = tensor.cpu().numpy()
|
|
||||||
elif isinstance(tensor, tuple):
|
|
||||||
for i, t in enumerate(tensor):
|
|
||||||
if isinstance(t, torch.Tensor):
|
|
||||||
np_data[f"{key}_{i}"] = t.cpu().numpy()
|
|
||||||
|
|
||||||
np.savez(npz_path, **np_data)
|
|
||||||
print(f"✅ 示例输入已保存到: {npz_path}")
|
|
||||||
|
|
||||||
# 同时保存为PyTorch格式
|
|
||||||
torch_path = os.path.join(output_dir, "example_inputs.pt")
|
|
||||||
torch.save(example_inputs_dict, torch_path)
|
|
||||||
print(f"✅ PyTorch示例输入已保存到: {torch_path}")
|
|
||||||
|
|
||||||
|
|
||||||
def create_inference_example(output_dir, config):
|
|
||||||
"""
|
|
||||||
创建推理示例脚本
|
|
||||||
|
|
||||||
Args:
|
|
||||||
output_dir: 输出目录
|
|
||||||
config: 模型配置
|
|
||||||
"""
|
|
||||||
example_path = os.path.join(output_dir, "inference_example.py")
|
|
||||||
|
|
||||||
example_code = '''#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
ONNX模型推理示例
|
|
||||||
|
|
||||||
展示如何使用导出的两个ONNX模型进行推理,
|
|
||||||
包括束搜索(beam search)算法。
|
|
||||||
"""
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import onnxruntime as ort
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
|
|
||||||
class ONNXInference:
|
|
||||||
"""ONNX模型推理器"""
|
|
||||||
|
|
||||||
def __init__(self, context_encoder_path, decoder_path):
|
|
||||||
"""
|
|
||||||
初始化ONNX推理器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
context_encoder_path: 上下文编码器ONNX模型路径
|
|
||||||
decoder_path: 解码器ONNX模型路径
|
|
||||||
"""
|
|
||||||
# 创建ONNX Runtime会话
|
|
||||||
self.context_encoder_session = ort.InferenceSession(
|
|
||||||
context_encoder_path,
|
|
||||||
providers=['CPUExecutionProvider'] # 或 'CUDAExecutionProvider'
|
|
||||||
)
|
|
||||||
self.decoder_session = ort.InferenceSession(
|
|
||||||
decoder_path,
|
|
||||||
providers=['CPUExecutionProvider']
|
|
||||||
)
|
|
||||||
|
|
||||||
# 获取输入输出名称
|
|
||||||
self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()]
|
|
||||||
self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()]
|
|
||||||
self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()]
|
|
||||||
self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()]
|
|
||||||
|
|
||||||
print(f"上下文编码器输入: {self.context_input_names}")
|
|
||||||
print(f"上下文编码器输出: {self.context_output_names}")
|
|
||||||
print(f"解码器输入: {self.decoder_input_names}")
|
|
||||||
print(f"解码器输出: {self.decoder_output_names}")
|
|
||||||
|
|
||||||
def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128):
|
|
||||||
"""
|
|
||||||
准备模型输入(与原始推理脚本保持一致)
|
|
||||||
|
|
||||||
注意: 这里需要实现文本到token的转换
|
|
||||||
为了简化示例,假设已经实现了相关函数
|
|
||||||
"""
|
|
||||||
# 这里应该调用实际的预处理函数
|
|
||||||
# 返回: input_ids, pinyin_ids, attention_mask, history_slot_ids
|
|
||||||
raise NotImplementedError("请实现实际的输入预处理")
|
|
||||||
|
|
||||||
def run_context_encoder(self, input_ids, pinyin_ids, attention_mask):
|
|
||||||
"""
|
|
||||||
运行上下文编码器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_ids: [batch, seq_len]
|
|
||||||
pinyin_ids: [batch, 24]
|
|
||||||
attention_mask: [batch, seq_len]
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
context_H, pinyin_P, context_mask, pinyin_mask
|
|
||||||
"""
|
|
||||||
# 准备输入
|
|
||||||
inputs = {
|
|
||||||
"input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids,
|
|
||||||
"pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids,
|
|
||||||
"attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 运行推理
|
|
||||||
outputs = self.context_encoder_session.run(self.context_output_names, inputs)
|
|
||||||
|
|
||||||
# 解包输出
|
|
||||||
context_H, pinyin_P, context_mask, pinyin_mask = outputs
|
|
||||||
|
|
||||||
return (
|
|
||||||
torch.from_numpy(context_H),
|
|
||||||
torch.from_numpy(pinyin_P),
|
|
||||||
torch.from_numpy(context_mask),
|
|
||||||
torch.from_numpy(pinyin_mask),
|
|
||||||
)
|
|
||||||
|
|
||||||
def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
|
|
||||||
"""
|
|
||||||
运行解码器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
context_H: [batch, seq_len, 512]
|
|
||||||
pinyin_P: [batch, 24, 512]
|
|
||||||
history_slot_ids: [batch, 8]
|
|
||||||
context_mask: [batch, seq_len]
|
|
||||||
pinyin_mask: [batch, 24]
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
logits: [batch, vocab_size]
|
|
||||||
"""
|
|
||||||
# 准备输入
|
|
||||||
inputs = {
|
|
||||||
"context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H,
|
|
||||||
"pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P,
|
|
||||||
"history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids,
|
|
||||||
"context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask,
|
|
||||||
"pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 运行推理
|
|
||||||
outputs = self.decoder_session.run(self.decoder_output_names, inputs)
|
|
||||||
|
|
||||||
# 解包输出
|
|
||||||
logits = outputs[0]
|
|
||||||
return torch.from_numpy(logits)
|
|
||||||
|
|
||||||
def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask,
|
|
||||||
beam_size=5, max_length=10, vocab_size=10019):
|
|
||||||
"""
|
|
||||||
束搜索算法示例
|
|
||||||
|
|
||||||
Args:
|
|
||||||
context_H: 上下文编码
|
|
||||||
pinyin_P: 拼音编码
|
|
||||||
context_mask: 上下文掩码
|
|
||||||
pinyin_mask: 拼音掩码
|
|
||||||
beam_size: 束大小
|
|
||||||
max_length: 最大生成长度
|
|
||||||
vocab_size: 词汇表大小
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
最佳序列列表
|
|
||||||
"""
|
|
||||||
# 初始束:空序列,分数为0
|
|
||||||
beams = [([], 0.0)] # (序列, 对数概率)
|
|
||||||
|
|
||||||
for step in range(max_length):
|
|
||||||
new_beams = []
|
|
||||||
|
|
||||||
for seq, score in beams:
|
|
||||||
# 构建history_slot_ids:已确认的字符ID
|
|
||||||
if len(seq) < 8:
|
|
||||||
history = seq + [0] * (8 - len(seq))
|
|
||||||
else:
|
|
||||||
history = seq[-8:] # 只保留最近8个
|
|
||||||
|
|
||||||
history_tensor = torch.tensor([history], dtype=torch.long)
|
|
||||||
|
|
||||||
# 运行解码器
|
|
||||||
logits = self.run_decoder(
|
|
||||||
context_H, pinyin_P, history_tensor,
|
|
||||||
context_mask, pinyin_mask
|
|
||||||
)
|
|
||||||
|
|
||||||
# 获取概率
|
|
||||||
probs = F.softmax(logits[0], dim=-1)
|
|
||||||
|
|
||||||
# 获取top-k候选
|
|
||||||
top_probs, top_indices = torch.topk(probs, beam_size)
|
|
||||||
|
|
||||||
# 扩展束
|
|
||||||
for prob, idx in zip(top_probs, top_indices):
|
|
||||||
new_seq = seq + [idx.item()]
|
|
||||||
new_score = score + torch.log(prob).item()
|
|
||||||
new_beams.append((new_seq, new_score))
|
|
||||||
|
|
||||||
# 剪枝:保留beam_size个最佳候选
|
|
||||||
new_beams.sort(key=lambda x: x[1], reverse=True)
|
|
||||||
beams = new_beams[:beam_size]
|
|
||||||
|
|
||||||
# 检查是否所有序列都已结束(以结束符0结尾)
|
|
||||||
all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
|
|
||||||
if all_ended:
|
|
||||||
break
|
|
||||||
|
|
||||||
return beams
|
|
||||||
|
|
||||||
def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids):
|
|
||||||
"""
|
|
||||||
单步预测
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_ids: 输入token IDs
|
|
||||||
pinyin_ids: 拼音IDs
|
|
||||||
attention_mask: 注意力掩码
|
|
||||||
history_slot_ids: 历史槽位IDs
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
预测logits
|
|
||||||
"""
|
|
||||||
# 1. 运行上下文编码器
|
|
||||||
context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
|
|
||||||
input_ids, pinyin_ids, attention_mask
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. 运行解码器
|
|
||||||
logits = self.run_decoder(
|
|
||||||
context_H, pinyin_P, history_slot_ids,
|
|
||||||
context_mask, pinyin_mask
|
|
||||||
)
|
|
||||||
|
|
||||||
return logits
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""示例主函数"""
|
|
||||||
print("ONNX模型推理示例")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
# 初始化推理器
|
|
||||||
context_encoder_path = "context_encoder.onnx"
|
|
||||||
decoder_path = "decoder.onnx"
|
|
||||||
|
|
||||||
if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path):
|
|
||||||
print("错误: 找不到ONNX模型文件")
|
|
||||||
print("请先运行export_onnx.py导出模型")
|
|
||||||
return
|
|
||||||
|
|
||||||
inference = ONNXInference(context_encoder_path, decoder_path)
|
|
||||||
|
|
||||||
print("✅ ONNX推理器初始化完成")
|
|
||||||
print("请参考此示例实现完整的输入法推理流程")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
'''
|
|
||||||
|
|
||||||
with open(example_path, "w", encoding="utf-8") as f:
|
|
||||||
f.write(example_code)
|
|
||||||
|
|
||||||
print(f"✅ 推理示例脚本已保存到: {example_path}")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
@ -485,77 +47,39 @@ def main():
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# 检查依赖
|
|
||||||
if not check_onnx_available():
|
if not check_onnx_available():
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
# 创建输出目录
|
print(f"输出目录: {Path(args.output_dir).absolute()}")
|
||||||
|
|
||||||
|
context_encoder_path, decoder_path, config = run_full_export(
|
||||||
|
checkpoint_path=args.checkpoint,
|
||||||
|
output_dir=args.output_dir,
|
||||||
|
device=args.device,
|
||||||
|
skip_verification=args.skip_verification,
|
||||||
|
)
|
||||||
|
|
||||||
output_dir = Path(args.output_dir)
|
output_dir = Path(args.output_dir)
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
print()
|
||||||
|
|
||||||
print(f"📁 输出目录: {output_dir.absolute()}")
|
|
||||||
|
|
||||||
# 加载checkpoint并创建导出模型
|
|
||||||
print(f"📦 加载checkpoint: {args.checkpoint}")
|
|
||||||
context_encoder_export, decoder_export, config = (
|
|
||||||
create_export_models_from_checkpoint(args.checkpoint, args.device)
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"📊 模型配置: {config}")
|
|
||||||
|
|
||||||
# 导出上下文编码器
|
|
||||||
context_encoder_path = output_dir / "context_encoder.onnx"
|
|
||||||
example_inputs = export_context_encoder(
|
|
||||||
context_encoder_export, str(context_encoder_path), config
|
|
||||||
)
|
|
||||||
|
|
||||||
# 使用上下文编码器的输出作为解码器的示例输入
|
|
||||||
with torch.no_grad():
|
|
||||||
context_H, pinyin_P, context_mask, pinyin_mask = context_encoder_export(
|
|
||||||
*example_inputs
|
|
||||||
)
|
|
||||||
|
|
||||||
# 导出解码器
|
|
||||||
decoder_path = output_dir / "decoder.onnx"
|
|
||||||
export_decoder(
|
|
||||||
decoder_export,
|
|
||||||
str(decoder_path),
|
|
||||||
config,
|
|
||||||
example_inputs=(context_H, pinyin_P, context_mask, pinyin_mask),
|
|
||||||
)
|
|
||||||
|
|
||||||
# 保存示例输入
|
|
||||||
example_inputs_dict = {
|
|
||||||
"input_ids": example_inputs[0],
|
|
||||||
"pinyin_ids": example_inputs[1],
|
|
||||||
"attention_mask": example_inputs[2],
|
|
||||||
"context_H": context_H,
|
|
||||||
"pinyin_P": pinyin_P,
|
|
||||||
"context_mask": context_mask,
|
|
||||||
"pinyin_mask": pinyin_mask,
|
|
||||||
}
|
|
||||||
save_example_inputs(output_dir, example_inputs_dict)
|
|
||||||
|
|
||||||
# 创建推理示例脚本
|
|
||||||
create_inference_example(output_dir, config)
|
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("🎉 ONNX导出完成!")
|
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print(f"生成的模型文件:")
|
print("ONNX导出完成!")
|
||||||
|
print("=" * 60)
|
||||||
|
print("生成的模型文件:")
|
||||||
print(f" - {context_encoder_path}")
|
print(f" - {context_encoder_path}")
|
||||||
print(f" - {decoder_path}")
|
print(f" - {decoder_path}")
|
||||||
print(f" - {output_dir}/example_inputs.npz")
|
print(f" - {output_dir / 'example_inputs.npz'}")
|
||||||
print(f" - {output_dir}/example_inputs.pt")
|
print(f" - {output_dir / 'example_inputs.pt'}")
|
||||||
print(f" - {output_dir}/inference_example.py")
|
print(f" - {output_dir / 'inference_example.py'}")
|
||||||
print("\n使用方法:")
|
print()
|
||||||
|
print("使用方法:")
|
||||||
print(f" 1. 检查模型: python -m onnx.checker {context_encoder_path}")
|
print(f" 1. 检查模型: python -m onnx.checker {context_encoder_path}")
|
||||||
print(f" 2. 运行推理示例: cd {output_dir} && python inference_example.py")
|
print(f" 2. 运行推理示例: cd {output_dir} && python inference_example.py")
|
||||||
print(f" 3. 集成到您的应用: 参考inference_example.py中的ONNXInference类")
|
print(f" 3. 集成到您的应用: 参考 inference_example.py 中的 ONNXInference 类")
|
||||||
print("\n注意:")
|
print()
|
||||||
print(" - 请确保安装了onnxruntime: pip install onnxruntime")
|
print("注意:")
|
||||||
print(" - GPU推理需要onnxruntime-gpu: pip install onnxruntime-gpu")
|
print(" - 请确保安装了 onnxruntime: pip install onnxruntime")
|
||||||
print(" - 束搜索算法需要根据实际需求进行调整")
|
print(" - GPU推理需要 onnxruntime-gpu: pip install onnxruntime-gpu")
|
||||||
|
print(" - MoE 层当前使用 'all' 模式(全量计算),稀疏化优化可后续迭代")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ ONNX模型推理示例
|
||||||
展示如何使用导出的两个ONNX模型进行推理,
|
展示如何使用导出的两个ONNX模型进行推理,
|
||||||
包括束搜索(beam search)算法。
|
包括束搜索(beam search)算法。
|
||||||
"""
|
"""
|
||||||
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -15,94 +15,46 @@ from typing import List, Tuple
|
||||||
|
|
||||||
class ONNXInference:
|
class ONNXInference:
|
||||||
"""ONNX模型推理器"""
|
"""ONNX模型推理器"""
|
||||||
|
|
||||||
def __init__(self, context_encoder_path, decoder_path):
|
def __init__(self, context_encoder_path, decoder_path):
|
||||||
"""
|
|
||||||
初始化ONNX推理器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
context_encoder_path: 上下文编码器ONNX模型路径
|
|
||||||
decoder_path: 解码器ONNX模型路径
|
|
||||||
"""
|
|
||||||
# 创建ONNX Runtime会话
|
|
||||||
self.context_encoder_session = ort.InferenceSession(
|
self.context_encoder_session = ort.InferenceSession(
|
||||||
context_encoder_path,
|
context_encoder_path,
|
||||||
providers=['CPUExecutionProvider'] # 或 'CUDAExecutionProvider'
|
providers=['CPUExecutionProvider']
|
||||||
)
|
)
|
||||||
self.decoder_session = ort.InferenceSession(
|
self.decoder_session = ort.InferenceSession(
|
||||||
decoder_path,
|
decoder_path,
|
||||||
providers=['CPUExecutionProvider']
|
providers=['CPUExecutionProvider']
|
||||||
)
|
)
|
||||||
|
|
||||||
# 获取输入输出名称
|
|
||||||
self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()]
|
self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()]
|
||||||
self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()]
|
self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()]
|
||||||
self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()]
|
self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()]
|
||||||
self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()]
|
self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()]
|
||||||
|
|
||||||
print(f"上下文编码器输入: {self.context_input_names}")
|
print(f"上下文编码器输入: {self.context_input_names}")
|
||||||
print(f"上下文编码器输出: {self.context_output_names}")
|
print(f"上下文编码器输出: {self.context_output_names}")
|
||||||
print(f"解码器输入: {self.decoder_input_names}")
|
print(f"解码器输入: {self.decoder_input_names}")
|
||||||
print(f"解码器输出: {self.decoder_output_names}")
|
print(f"解码器输出: {self.decoder_output_names}")
|
||||||
|
|
||||||
def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128):
|
def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128):
|
||||||
"""
|
|
||||||
准备模型输入(与原始推理脚本保持一致)
|
|
||||||
|
|
||||||
注意: 这里需要实现文本到token的转换
|
|
||||||
为了简化示例,假设已经实现了相关函数
|
|
||||||
"""
|
|
||||||
# 这里应该调用实际的预处理函数
|
|
||||||
# 返回: input_ids, pinyin_ids, attention_mask, history_slot_ids
|
|
||||||
raise NotImplementedError("请实现实际的输入预处理")
|
raise NotImplementedError("请实现实际的输入预处理")
|
||||||
|
|
||||||
def run_context_encoder(self, input_ids, pinyin_ids, attention_mask):
|
def run_context_encoder(self, input_ids, pinyin_ids, attention_mask):
|
||||||
"""
|
|
||||||
运行上下文编码器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_ids: [batch, seq_len]
|
|
||||||
pinyin_ids: [batch, 24]
|
|
||||||
attention_mask: [batch, seq_len]
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
context_H, pinyin_P, context_mask, pinyin_mask
|
|
||||||
"""
|
|
||||||
# 准备输入
|
|
||||||
inputs = {
|
inputs = {
|
||||||
"input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids,
|
"input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids,
|
||||||
"pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids,
|
"pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids,
|
||||||
"attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask,
|
"attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask,
|
||||||
}
|
}
|
||||||
|
|
||||||
# 运行推理
|
|
||||||
outputs = self.context_encoder_session.run(self.context_output_names, inputs)
|
outputs = self.context_encoder_session.run(self.context_output_names, inputs)
|
||||||
|
|
||||||
# 解包输出
|
|
||||||
context_H, pinyin_P, context_mask, pinyin_mask = outputs
|
context_H, pinyin_P, context_mask, pinyin_mask = outputs
|
||||||
|
|
||||||
return (
|
return (
|
||||||
torch.from_numpy(context_H),
|
torch.from_numpy(context_H),
|
||||||
torch.from_numpy(pinyin_P),
|
torch.from_numpy(pinyin_P),
|
||||||
torch.from_numpy(context_mask),
|
torch.from_numpy(context_mask),
|
||||||
torch.from_numpy(pinyin_mask),
|
torch.from_numpy(pinyin_mask),
|
||||||
)
|
)
|
||||||
|
|
||||||
def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
|
def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
|
||||||
"""
|
|
||||||
运行解码器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
context_H: [batch, seq_len, 512]
|
|
||||||
pinyin_P: [batch, 24, 512]
|
|
||||||
history_slot_ids: [batch, 8]
|
|
||||||
context_mask: [batch, seq_len]
|
|
||||||
pinyin_mask: [batch, 24]
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
logits: [batch, vocab_size]
|
|
||||||
"""
|
|
||||||
# 准备输入
|
|
||||||
inputs = {
|
inputs = {
|
||||||
"context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H,
|
"context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H,
|
||||||
"pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P,
|
"pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P,
|
||||||
|
|
@ -110,99 +62,46 @@ class ONNXInference:
|
||||||
"context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask,
|
"context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask,
|
||||||
"pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask,
|
"pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask,
|
||||||
}
|
}
|
||||||
|
|
||||||
# 运行推理
|
|
||||||
outputs = self.decoder_session.run(self.decoder_output_names, inputs)
|
outputs = self.decoder_session.run(self.decoder_output_names, inputs)
|
||||||
|
|
||||||
# 解包输出
|
|
||||||
logits = outputs[0]
|
logits = outputs[0]
|
||||||
return torch.from_numpy(logits)
|
return torch.from_numpy(logits)
|
||||||
|
|
||||||
def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask,
|
def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask,
|
||||||
beam_size=5, max_length=10, vocab_size=10019):
|
beam_size=5, max_length=10, vocab_size=10019):
|
||||||
"""
|
beams = [([], 0.0)]
|
||||||
束搜索算法示例
|
|
||||||
|
|
||||||
Args:
|
|
||||||
context_H: 上下文编码
|
|
||||||
pinyin_P: 拼音编码
|
|
||||||
context_mask: 上下文掩码
|
|
||||||
pinyin_mask: 拼音掩码
|
|
||||||
beam_size: 束大小
|
|
||||||
max_length: 最大生成长度
|
|
||||||
vocab_size: 词汇表大小
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
最佳序列列表
|
|
||||||
"""
|
|
||||||
# 初始束:空序列,分数为0
|
|
||||||
beams = [([], 0.0)] # (序列, 对数概率)
|
|
||||||
|
|
||||||
for step in range(max_length):
|
for step in range(max_length):
|
||||||
new_beams = []
|
new_beams = []
|
||||||
|
|
||||||
for seq, score in beams:
|
for seq, score in beams:
|
||||||
# 构建history_slot_ids:已确认的字符ID
|
|
||||||
if len(seq) < 8:
|
if len(seq) < 8:
|
||||||
history = seq + [0] * (8 - len(seq))
|
history = seq + [0] * (8 - len(seq))
|
||||||
else:
|
else:
|
||||||
history = seq[-8:] # 只保留最近8个
|
history = seq[-8:]
|
||||||
|
|
||||||
history_tensor = torch.tensor([history], dtype=torch.long)
|
history_tensor = torch.tensor([history], dtype=torch.long)
|
||||||
|
|
||||||
# 运行解码器
|
|
||||||
logits = self.run_decoder(
|
logits = self.run_decoder(
|
||||||
context_H, pinyin_P, history_tensor,
|
context_H, pinyin_P, history_tensor,
|
||||||
context_mask, pinyin_mask
|
context_mask, pinyin_mask
|
||||||
)
|
)
|
||||||
|
|
||||||
# 获取概率
|
|
||||||
probs = F.softmax(logits[0], dim=-1)
|
probs = F.softmax(logits[0], dim=-1)
|
||||||
|
|
||||||
# 获取top-k候选
|
|
||||||
top_probs, top_indices = torch.topk(probs, beam_size)
|
top_probs, top_indices = torch.topk(probs, beam_size)
|
||||||
|
|
||||||
# 扩展束
|
|
||||||
for prob, idx in zip(top_probs, top_indices):
|
for prob, idx in zip(top_probs, top_indices):
|
||||||
new_seq = seq + [idx.item()]
|
new_seq = seq + [idx.item()]
|
||||||
new_score = score + torch.log(prob).item()
|
new_score = score + torch.log(prob).item()
|
||||||
new_beams.append((new_seq, new_score))
|
new_beams.append((new_seq, new_score))
|
||||||
|
|
||||||
# 剪枝:保留beam_size个最佳候选
|
|
||||||
new_beams.sort(key=lambda x: x[1], reverse=True)
|
new_beams.sort(key=lambda x: x[1], reverse=True)
|
||||||
beams = new_beams[:beam_size]
|
beams = new_beams[:beam_size]
|
||||||
|
|
||||||
# 检查是否所有序列都已结束(以结束符0结尾)
|
|
||||||
all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
|
all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
|
||||||
if all_ended:
|
if all_ended:
|
||||||
break
|
break
|
||||||
|
|
||||||
return beams
|
return beams
|
||||||
|
|
||||||
def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids):
|
def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids):
|
||||||
"""
|
|
||||||
单步预测
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_ids: 输入token IDs
|
|
||||||
pinyin_ids: 拼音IDs
|
|
||||||
attention_mask: 注意力掩码
|
|
||||||
history_slot_ids: 历史槽位IDs
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
预测logits
|
|
||||||
"""
|
|
||||||
# 1. 运行上下文编码器
|
|
||||||
context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
|
context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
|
||||||
input_ids, pinyin_ids, attention_mask
|
input_ids, pinyin_ids, attention_mask
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. 运行解码器
|
|
||||||
logits = self.run_decoder(
|
logits = self.run_decoder(
|
||||||
context_H, pinyin_P, history_slot_ids,
|
context_H, pinyin_P, history_slot_ids,
|
||||||
context_mask, pinyin_mask
|
context_mask, pinyin_mask
|
||||||
)
|
)
|
||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -210,21 +109,19 @@ def main():
|
||||||
"""示例主函数"""
|
"""示例主函数"""
|
||||||
print("ONNX模型推理示例")
|
print("ONNX模型推理示例")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
# 初始化推理器
|
|
||||||
context_encoder_path = "context_encoder.onnx"
|
context_encoder_path = "context_encoder.onnx"
|
||||||
decoder_path = "decoder.onnx"
|
decoder_path = "decoder.onnx"
|
||||||
|
|
||||||
if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path):
|
if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path):
|
||||||
print("错误: 找不到ONNX模型文件")
|
print("错误: 找不到ONNX模型文件")
|
||||||
print("请先运行export_onnx.py导出模型")
|
print("请先运行 train-model export 导出模型")
|
||||||
return
|
return
|
||||||
|
|
||||||
inference = ONNXInference(context_encoder_path, decoder_path)
|
inference = ONNXInference(context_encoder_path, decoder_path)
|
||||||
|
print("\u2705 ONNX推理器初始化完成")
|
||||||
print("✅ ONNX推理器初始化完成")
|
|
||||||
print("请参考此示例实现完整的输入法推理流程")
|
print("请参考此示例实现完整的输入法推理流程")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,409 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
临时迁移训练脚本:在预训练模型基础上重新训练,支持冻结 context_encoder。
|
||||||
|
|
||||||
|
与 src/model/trainer.py 的 train 命令行为完全一致,额外增加:
|
||||||
|
- --pretrained-checkpoint: 加载预训练权重(必需的迁移学习源)
|
||||||
|
- --freeze-context-encoder: 冻结 context_encoder 层(默认开启)
|
||||||
|
|
||||||
|
运行方式:
|
||||||
|
python scripts/finetune_slots.py \
|
||||||
|
--pretrained-checkpoint ./output/checkpoints/best_model.pt \
|
||||||
|
--train-data-path /path/to/train_data \
|
||||||
|
--eval-data-path /path/to/eval_data \
|
||||||
|
--output-dir ./finetune_output \
|
||||||
|
--freeze-context-encoder
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
|
||||||
|
|
||||||
|
from model.model import InputMethodEngine
|
||||||
|
from model.trainer import (
|
||||||
|
Trainer,
|
||||||
|
create_dataloader,
|
||||||
|
worker_init_fn,
|
||||||
|
)
|
||||||
|
from model.dataset import PinyinInputDataset
|
||||||
|
from model.preprocessed_dataset import (
|
||||||
|
PreProcessedDataset,
|
||||||
|
is_preprocessed_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
|
|
||||||
|
def _build_arg_parser():
    """Build the command-line parser for the transfer-learning script."""
    parser = argparse.ArgumentParser(
        description="迁移学习训练:加载预训练模型,冻结指定层后重新训练",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # === Data arguments ===
    parser.add_argument("--train-data-path", "-t", required=True, help="训练数据集路径")
    parser.add_argument("--eval-data-path", "-e", required=True, help="评估数据集路径")
    parser.add_argument("--output-dir", "-o", default="./finetune_output", help="输出目录")
    parser.add_argument("--max-iter-length", type=int,
                        default=1024 * 1024 * 128, help="每个 epoch 最大样本数")

    # === Transfer-learning arguments ===
    parser.add_argument("--pretrained-checkpoint", "-c", required=True,
                        help="预训练模型检查点路径")
    parser.add_argument("--freeze-context-encoder", action="store_true", default=True,
                        help="冻结 context_encoder 层 (默认开启)")
    parser.add_argument("--no-freeze-context-encoder", dest="freeze_context_encoder",
                        action="store_false",
                        help="不冻结 context_encoder")

    # === Training arguments ===
    parser.add_argument("--batch-size", "-b", type=int, default=128, help="批次大小")
    parser.add_argument("--num-epochs", type=int, default=10, help="训练轮数")
    parser.add_argument("--learning-rate", "-lr", type=float, default=2e-4,
                        help="学习率")
    parser.add_argument("--min-learning-rate", type=float, default=1e-9,
                        help="最小学习率")
    parser.add_argument("--weight-decay", type=float, default=0.05, help="权重衰减")
    parser.add_argument("--warmup-ratio", type=float, default=0.1, help="热身步数比例")
    parser.add_argument("--label-smoothing", type=float, default=0.1,
                        help="标签平滑参数")
    parser.add_argument("--grad-accum-steps", type=int, default=1,
                        help="梯度累积步数")
    parser.add_argument("--clip-grad-norm", type=float, default=1.0,
                        help="梯度裁剪范数")
    parser.add_argument("--eval-frequency", type=int, default=500, help="评估频率")
    parser.add_argument("--save-frequency", type=int, default=1000, help="保存频率")

    # === Miscellaneous arguments ===
    parser.add_argument("--mixed-precision", action="store_true", default=True)
    parser.add_argument("--no-mixed-precision", dest="mixed_precision",
                        action="store_false", help="禁用混合精度")
    parser.add_argument("--num-workers", type=int, default=2, help="数据加载worker数")
    parser.add_argument("--tensorboard", action="store_true", default=True)
    parser.add_argument("--no-tensorboard", dest="tensorboard", action="store_false",
                        help="禁用 TensorBoard")
    parser.add_argument("--seed", type=int, default=42, help="随机种子")
    parser.add_argument("--compile", action="store_true", default=False,
                        help="使用 torch.compile 优化")
    parser.add_argument("--moe-mode", default="all",
                        choices=["all", "sparse", "sparse_allow_graph"],
                        help="MoE 计算策略")
    return parser


def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs with the same value."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def main():
    """Fine-tune a pretrained ``InputMethodEngine``, optionally freezing its context encoder.

    Steps, in order: parse CLI arguments, seed the RNGs, verify that the
    pretrained checkpoint exists, build the train/eval dataloaders, build the
    model and load the pretrained weights (non-strict), optionally freeze every
    parameter whose name starts with ``context_encoder``, dump the effective
    configuration to ``<output_dir>/training_config.json`` and run ``Trainer``.

    Exits with status 1 when the pretrained checkpoint path does not exist.
    On ``KeyboardInterrupt`` during training, saves ``interrupted_model.pt``
    via the trainer instead of reporting a successful completion.
    """
    args = _build_arg_parser().parse_args()

    # ------------------------------------------------------------------
    # Initialisation
    # ------------------------------------------------------------------
    torch.multiprocessing.set_sharing_strategy("file_system")
    if torch.cuda.is_available():
        torch.set_float32_matmul_precision("high")

    # Seed random and numpy as well as torch so --seed covers all three RNGs.
    _seed_everything(args.seed)

    console = Console()
    output_path = Path(args.output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Fail fast: check the checkpoint before building datasets and the model.
    pretrained_path = Path(args.pretrained_checkpoint)
    if not pretrained_path.exists():
        console.print(f"[red]❌ 预训练检查点不存在: {args.pretrained_checkpoint}[/red]")
        sys.exit(1)

    # ------------------------------------------------------------------
    # Model constants (meant to stay in sync with trainer.py)
    # ------------------------------------------------------------------
    vocab_size = 10019
    pinyin_vocab_size = 30
    dim = 512
    num_slots = 8
    n_layers = 4
    n_heads = 4
    num_experts = 10
    max_seq_len = 128

    # ------------------------------------------------------------------
    # Configuration summary table
    # ------------------------------------------------------------------
    console.print(Panel.fit(
        "[bold cyan]迁移学习训练配置[/bold cyan]", border_style="cyan"))
    config_table = Table(show_header=True, header_style="bold magenta")
    config_table.add_column("Category", style="cyan")
    config_table.add_column("Parameter", style="green")
    config_table.add_column("Value", style="yellow")

    config_table.add_row("迁移学习", "预训练检查点", args.pretrained_checkpoint)
    config_table.add_row("迁移学习", "冻结 context_encoder",
                         str(args.freeze_context_encoder))
    config_table.add_row("数据", "训练数据路径", args.train_data_path)
    config_table.add_row("数据", "评估数据路径", args.eval_data_path)
    config_table.add_row("数据", "输出目录", args.output_dir)
    config_table.add_row("数据", "批次大小", str(args.batch_size))
    config_table.add_row("数据", "Worker数量", str(args.num_workers))
    config_table.add_row("模型", "MoE策略", args.moe_mode)
    config_table.add_row("模型", "编译优化", str(args.compile))
    config_table.add_row("训练", "训练轮数", str(args.num_epochs))
    config_table.add_row("训练", "学习率", f"{args.learning_rate:.2e}")
    config_table.add_row("训练", "最小学习率", f"{args.min_learning_rate:.2e}")
    config_table.add_row("训练", "权重衰减", str(args.weight_decay))
    config_table.add_row("训练", "热身比例", str(args.warmup_ratio))
    config_table.add_row("训练", "标签平滑", str(args.label_smoothing))
    config_table.add_row("训练", "梯度累积", str(args.grad_accum_steps))
    config_table.add_row("训练", "梯度裁剪", str(args.clip_grad_norm))
    config_table.add_row("训练", "混合精度", str(args.mixed_precision))

    # ------------------------------------------------------------------
    # Dataloaders (same logic as the trainer.py CLI)
    # ------------------------------------------------------------------
    console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")

    is_train_preprocessed = is_preprocessed_data(args.train_data_path)
    is_eval_preprocessed = is_preprocessed_data(args.eval_data_path)
    pin_memory = torch.cuda.is_available()

    # Sampling weights shared by the streaming train and eval datasets.
    length_weights = {1: 10, 2: 50, 3: 50, 4: 40,
                      5: 15, 6: 10, 7: 5, 8: 2}

    if is_train_preprocessed:
        train_dataset = PreProcessedDataset(args.train_data_path,
                                            max_cache_shards=2)
        pre_shuffled = train_dataset.metadata.get("pre_shuffled", False)
        # Only shuffle at load time when the shards were not shuffled offline.
        shuffle_train = not pre_shuffled
        if args.max_iter_length > 0:
            capped_samples = min(len(train_dataset), args.max_iter_length)
        else:
            capped_samples = len(train_dataset)
        total_steps = (capped_samples // args.batch_size) * args.num_epochs
        # Preprocessed data is read with at most one worker.
        train_num_workers = min(args.num_workers, 1)
        logger.info(
            f"Preprocessed dataset: {len(train_dataset):,} samples, "
            f"shuffle={shuffle_train}, pre_shuffled={pre_shuffled}, "
            f"workers={train_num_workers}, steps={total_steps:,}")
        train_dataloader = create_dataloader(
            dataset=train_dataset,
            batch_size=args.batch_size,
            num_workers=train_num_workers,
            pin_memory=pin_memory,
            shuffle=shuffle_train,
        )
        config_table.add_row("数据", "训练数据类型", "预处理数据")
    else:
        train_dataset = PinyinInputDataset(
            data_path=args.train_data_path,
            max_workers=-1,
            max_iter_length=args.max_iter_length,
            max_seq_length=max_seq_len,
            text_field="text",
            py_style_weight=(9, 2, 1),
            shuffle_buffer_size=2000000,
            length_weights=dict(length_weights),
        )
        total_steps = int(args.max_iter_length *
                          args.num_epochs / args.batch_size)
        train_dataloader = create_dataloader(
            dataset=train_dataset,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=pin_memory,
            max_iter_length=args.max_iter_length,
        )
        config_table.add_row("数据", "训练数据类型", "流式数据")

    if is_eval_preprocessed:
        eval_dataset = PreProcessedDataset(args.eval_data_path,
                                           max_cache_shards=1)
        eval_dataloader = create_dataloader(
            dataset=eval_dataset,
            batch_size=args.batch_size,
            num_workers=0,
            pin_memory=pin_memory,
            shuffle=False,
        )
        config_table.add_row("数据", "评估数据类型", "预处理数据")
    else:
        # Streaming evaluation is capped at 64 batches' worth of samples.
        eval_iter_length = args.batch_size * 64
        eval_dataset = PinyinInputDataset(
            data_path=args.eval_data_path,
            max_workers=-1,
            max_iter_length=eval_iter_length,
            max_seq_length=max_seq_len,
            text_field="text",
            py_style_weight=(9, 2, 1),
            shuffle_buffer_size=2000000,
            length_weights=dict(length_weights),
        )
        eval_dataloader = create_dataloader(
            dataset=eval_dataset,
            batch_size=args.batch_size,
            num_workers=2,
            pin_memory=pin_memory,
            max_iter_length=eval_iter_length,
        )
        config_table.add_row("数据", "评估数据类型", "流式数据")

    config_table.add_row("数据", "总步数", str(total_steps))
    console.print(config_table)

    # ------------------------------------------------------------------
    # Model creation and pretrained weights
    # ------------------------------------------------------------------
    console.print("[bold cyan]正在创建模型并加载预训练权重...[/bold cyan]")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = InputMethodEngine(
        vocab_size=vocab_size,
        pinyin_vocab_size=pinyin_vocab_size,
        dim=dim,
        num_slots=num_slots,
        n_layers=n_layers,
        n_heads=n_heads,
        num_experts=num_experts,
        max_seq_len=max_seq_len,
        compile=args.compile,
        moe_mode=args.moe_mode,
    )
    model.to(device)

    checkpoint = torch.load(args.pretrained_checkpoint, map_location=device)
    # Accept either a full training checkpoint or a bare state dict.
    if "model_state_dict" in checkpoint:
        pretrained_weights = checkpoint["model_state_dict"]
    else:
        pretrained_weights = checkpoint

    # strict=False: mismatched keys are reported below rather than raising.
    missing_keys, unexpected_keys = model.load_state_dict(
        pretrained_weights, strict=False)

    if missing_keys:
        console.print(f"[yellow]⚠ 缺失的键 ({len(missing_keys)}): "
                      f"{missing_keys[:5]}...[/yellow]")
    if unexpected_keys:
        console.print(f"[yellow]⚠ 多余的键 ({len(unexpected_keys)}): "
                      f"{unexpected_keys[:5]}...[/yellow]")

    console.print("[green]✓ 预训练权重加载完成[/green]")

    # ------------------------------------------------------------------
    # Freeze context_encoder
    # ------------------------------------------------------------------
    if args.freeze_context_encoder:
        console.print("[bold cyan]正在冻结 context_encoder...[/bold cyan]")
        frozen_count = 0
        trainable_count = 0
        for name, param in model.named_parameters():
            if name.startswith("context_encoder"):
                param.requires_grad = False
                frozen_count += param.numel()
            else:
                trainable_count += param.numel()

        total_params = frozen_count + trainable_count
        console.print("[green]✓ context_encoder 已冻结[/green]")
        logger.info(
            f"冻结参数: {frozen_count:,} / {total_params:,} "
            f"({frozen_count / total_params * 100:.1f}%), "
            f"可训练参数: {trainable_count:,} / {total_params:,} "
            f"({trainable_count / total_params * 100:.1f}%)")
    else:
        logger.info("未冻结任何层,全模型参与训练")

    # ------------------------------------------------------------------
    # Persist the effective configuration
    # ------------------------------------------------------------------
    config = {
        "pretrained_checkpoint": args.pretrained_checkpoint,
        "freeze_context_encoder": args.freeze_context_encoder,
        "train_data_path": args.train_data_path,
        "eval_data_path": args.eval_data_path,
        "output_dir": args.output_dir,
        "batch_size": args.batch_size,
        "num_epochs": args.num_epochs,
        "learning_rate": args.learning_rate,
        "min_learning_rate": args.min_learning_rate,
        "weight_decay": args.weight_decay,
        "warmup_ratio": args.warmup_ratio,
        "label_smoothing": args.label_smoothing,
        "grad_accum_steps": args.grad_accum_steps,
        "clip_grad_norm": args.clip_grad_norm,
        "eval_frequency": args.eval_frequency,
        "save_frequency": args.save_frequency,
        "mixed_precision": args.mixed_precision,
        "num_workers": args.num_workers,
        "use_tensorboard": args.tensorboard,
        "seed": args.seed,
        "compile": args.compile,
        "moe_mode": args.moe_mode,
        "total_steps": total_steps,
        "vocab_size": vocab_size,
        "pinyin_vocab_size": pinyin_vocab_size,
        "dim": dim,
        "num_slots": num_slots,
        "n_layers": n_layers,
        "n_heads": n_heads,
        "num_experts": num_experts,
        "max_seq_len": max_seq_len,
        "is_train_preprocessed": is_train_preprocessed,
        "is_eval_preprocessed": is_eval_preprocessed,
    }
    config_file = output_path / "training_config.json"
    with open(config_file, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)
    logger.info(f"Configuration saved to {config_file}")

    # ------------------------------------------------------------------
    # Trainer creation and training run
    # ------------------------------------------------------------------
    console.print("[bold cyan]正在创建训练器...[/bold cyan]")
    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        total_steps=total_steps,
        output_dir=args.output_dir,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        min_learning_rate=args.min_learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        label_smoothing=args.label_smoothing,
        grad_accum_steps=args.grad_accum_steps,
        clip_grad_norm=args.clip_grad_norm,
        eval_frequency=args.eval_frequency,
        save_frequency=args.save_frequency,
        mixed_precision=args.mixed_precision,
        device=device,
        use_tensorboard=args.tensorboard,
        status_file="training_status.json",
    )
    console.print("[green]✓ 训练器创建完成[/green]")

    console.print("\n[bold cyan]开始训练...[/bold cyan]")
    console.print(f"开始时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    try:
        trainer.train(
            resume_from=None,
            reset_training_state=False,
            auto_resume=False,  # transfer learning starts fresh; never auto-resume
        )
    except KeyboardInterrupt:
        console.print("[bold green]训练被终止[/bold green]")
        trainer.save_checkpoint("interrupted_model.pt")
    else:
        # Report completion only when training was not interrupted.
        console.print("[bold green]✓ 训练完成![/bold green]")

    console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    console.print(f"模型和日志保存在: {args.output_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -53,27 +53,7 @@ class PinyinLSTMEncoder(nn.Module):
|
||||||
self.layer_norm = nn.LayerNorm(input_dim)
|
self.layer_norm = nn.LayerNorm(input_dim)
|
||||||
|
|
||||||
def forward(self, x, mask=None):
    """Encode pinyin embeddings with the LSTM, then project and layer-normalise.

    Args:
        x: Pinyin embeddings fed to ``self.lstm``
            (presumably [batch, seq_len, input_dim] — confirm against caller).
        mask: Accepted for interface compatibility; not used by this
            implementation.

    Returns:
        ``self.layer_norm(self.proj(lstm_output))``, one encoding per position.
    """
    # Only the per-position outputs are needed; the final (h, c) state is dropped.
    lstm_out, _ = self.lstm(x)
    return self.layer_norm(self.proj(lstm_out))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -110,6 +110,10 @@ class InputMethodEngine(nn.Module):
|
||||||
|
|
||||||
history_slot_ids = history_slot_ids.view(-1, self.num_slots)
|
history_slot_ids = history_slot_ids.view(-1, self.num_slots)
|
||||||
|
|
||||||
|
is_zero = (history_slot_ids == 0)
|
||||||
|
cumsum_zero = is_zero.cumsum(dim=1)
|
||||||
|
slot_mask = (cumsum_zero <= 1).to(torch.get_default_dtype())
|
||||||
|
|
||||||
H, P = self.context_encoder(input_ids, pinyin_ids, mask=attention_mask)
|
H, P = self.context_encoder(input_ids, pinyin_ids, mask=attention_mask)
|
||||||
|
|
||||||
S = self.slot_memory(history_slot_ids)
|
S = self.slot_memory(history_slot_ids)
|
||||||
|
|
@ -120,10 +124,13 @@ class InputMethodEngine(nn.Module):
|
||||||
S, H, P, context_mask=context_mask, pinyin_mask=pinyin_mask
|
S, H, P, context_mask=context_mask, pinyin_mask=pinyin_mask
|
||||||
)
|
)
|
||||||
|
|
||||||
|
fused = fused * slot_mask.unsqueeze(-1)
|
||||||
|
|
||||||
moe_out = self.moe(fused)
|
moe_out = self.moe(fused)
|
||||||
|
|
||||||
batch_size = input_ids.size(0)
|
batch_size = input_ids.size(0)
|
||||||
slot_scores = self.slot_attention(moe_out).squeeze(-1)
|
slot_scores = self.slot_attention(moe_out).squeeze(-1)
|
||||||
|
slot_scores = slot_scores.masked_fill(slot_mask == 0, -1e9)
|
||||||
slot_weights = torch.softmax(slot_scores, dim=1)
|
slot_weights = torch.softmax(slot_scores, dim=1)
|
||||||
pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1)
|
pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,366 @@
|
||||||
|
"""
|
||||||
|
ONNX模型导出核心逻辑
|
||||||
|
|
||||||
|
将 InputMethodEngine 模型导出为两个 ONNX 文件:
|
||||||
|
1. context_encoder.onnx - 上下文编码器(可复用)
|
||||||
|
2. decoder.onnx - 解码器
|
||||||
|
|
||||||
|
共用此模块的入口:
|
||||||
|
- CLI: train-model export
|
||||||
|
- 脚本: python export_onnx.py(薄壳)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .export_models import create_export_models_from_checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
def check_onnx_available() -> bool:
    """Return True when both ``onnx`` and ``onnxruntime`` can be imported.

    On ``ImportError`` an installation hint is printed and False is returned.
    """
    try:
        import onnx  # noqa: F401
        import onnxruntime as ort  # noqa: F401
    except ImportError:
        for hint in (
            "错误: ONNX导出需要以下依赖:",
            " pip install onnx onnxruntime",
            "请安装后重试",
        ):
            print(hint)
        return False
    else:
        return True
|
||||||
|
|
||||||
|
|
||||||
|
def export_context_encoder(
    model,
    output_path: str,
    config: Dict,
    skip_verification: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Export the context-encoder wrapper to ONNX and return the dummy inputs used.

    Args:
        model: Module passed to ``torch.onnx.export``; it is invoked with
            ``(input_ids, pinyin_ids, attention_mask)`` and its outputs are
            named ``context_H``, ``pinyin_P``, ``context_mask``, ``pinyin_mask``.
        output_path: Destination path of the ``.onnx`` file.
        config: Model configuration. ``max_seq_len`` (default 128) sizes the
            dummy token sequence; ``pinyin_vocab_size`` (default 30) bounds the
            dummy pinyin ids so they are valid for the configured vocabulary.
        skip_verification: When True, skip the ``onnx.checker`` validation.

    Returns:
        The ``(input_ids, pinyin_ids, attention_mask)`` tensors used for export.
    """
    batch_size = 2
    seq_len = config.get("max_seq_len", 128)
    pinyin_len = 24
    # Take the id bound from the config instead of hard-coding 30, so a model
    # with a smaller pinyin vocabulary never receives out-of-range ids.
    pinyin_vocab_size = config.get("pinyin_vocab_size", 30)

    input_ids = torch.randint(0, 100, (batch_size, seq_len), dtype=torch.long)
    pinyin_ids = torch.randint(
        0, pinyin_vocab_size, (batch_size, pinyin_len), dtype=torch.long
    )
    attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)

    # Batch is dynamic everywhere; sequence length is dynamic only for the
    # text-side tensors (the pinyin length is left fixed).
    dynamic_axes = {
        "input_ids": {0: "batch_size", 1: "seq_len"},
        "pinyin_ids": {0: "batch_size"},
        "attention_mask": {0: "batch_size", 1: "seq_len"},
        "context_H": {0: "batch_size", 1: "seq_len"},
        "pinyin_P": {0: "batch_size"},
        "context_mask": {0: "batch_size", 1: "seq_len"},
        "pinyin_mask": {0: "batch_size"},
    }

    torch.onnx.export(
        model,
        (input_ids, pinyin_ids, attention_mask),
        output_path,
        input_names=["input_ids", "pinyin_ids", "attention_mask"],
        output_names=["context_H", "pinyin_P", "context_mask", "pinyin_mask"],
        dynamic_axes=dynamic_axes,
        opset_version=18,
        do_constant_folding=True,
        verbose=False,
    )

    print(f" 上下文编码器导出完成 -> {output_path}")

    if not skip_verification:
        # Best-effort check: a validation failure is reported, not raised.
        try:
            import onnx

            onnx.checker.check_model(onnx.load(output_path))
            print(" ONNX 模型验证通过")
        except Exception as e:
            print(f" ONNX 模型验证警告: {e}")

    return input_ids, pinyin_ids, attention_mask
|
||||||
|
|
||||||
|
|
||||||
|
def export_decoder(
    model,
    output_path: str,
    config: Dict,
    example_inputs: Optional[Tuple] = None,
    skip_verification: bool = False,
) -> Tuple[torch.Tensor, ...]:
    """Export the decoder wrapper to ONNX and return the dummy inputs used.

    Args:
        model: Module passed to ``torch.onnx.export``; it is invoked with
            ``(context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask)``
            and its single output is named ``logits``.
        output_path: Destination path of the ``.onnx`` file.
        config: Model configuration; ``max_seq_len``, ``dim`` and ``num_slots``
            size the dummy tensors (defaults 128, 512 and 8).
        example_inputs: Optional ``(context_H, pinyin_P, context_mask,
            pinyin_mask)`` tuple to trace with; random tensors are generated
            when omitted.
        skip_verification: When True, skip the ``onnx.checker`` validation.

    Returns:
        ``(context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask)``
        exactly as passed to the exporter.
    """
    slot_count = config.get("num_slots", 8)

    if example_inputs is None:
        # No encoder outputs supplied: fabricate tensors of the configured sizes.
        n_batch = 2
        n_seq = config.get("max_seq_len", 128)
        n_pinyin = 24
        width = config.get("dim", 512)
        context_H = torch.randn(n_batch, n_seq, width, dtype=torch.float32)
        pinyin_P = torch.randn(n_batch, n_pinyin, width, dtype=torch.float32)
        context_mask = torch.randint(0, 2, (n_batch, n_seq), dtype=torch.int32)
        pinyin_mask = torch.randint(0, 2, (n_batch, n_pinyin), dtype=torch.int32)
    else:
        context_H, pinyin_P, context_mask, pinyin_mask = example_inputs
        # Follow the batch size of the supplied tensors.
        n_batch = context_H.size(0)

    # The slot history is always random; only its batch dimension varies.
    history_slot_ids = torch.randint(
        0, 100, (n_batch, slot_count), dtype=torch.long
    )

    decoder_inputs = (
        context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask
    )
    decoder_input_names = [
        "context_H",
        "pinyin_P",
        "history_slot_ids",
        "context_mask",
        "pinyin_mask",
    ]
    axes = {
        "context_H": {0: "batch_size", 1: "seq_len"},
        "pinyin_P": {0: "batch_size"},
        "history_slot_ids": {0: "batch_size"},
        "context_mask": {0: "batch_size", 1: "seq_len"},
        "pinyin_mask": {0: "batch_size"},
        "logits": {0: "batch_size"},
    }

    torch.onnx.export(
        model,
        decoder_inputs,
        output_path,
        input_names=decoder_input_names,
        output_names=["logits"],
        dynamic_axes=axes,
        opset_version=18,
        do_constant_folding=True,
        verbose=False,
    )

    print(f" 解码器导出完成 -> {output_path}")

    if not skip_verification:
        # Best-effort check: a validation failure is reported, not raised.
        try:
            import onnx

            onnx.checker.check_model(onnx.load(output_path))
            print(" ONNX 模型验证通过")
        except Exception as e:
            print(f" ONNX 模型验证警告: {e}")

    return decoder_inputs
|
||||||
|
|
||||||
|
|
||||||
|
def save_example_inputs(output_dir: str, example_inputs_dict: Dict) -> None:
    """Save example inputs as ``example_inputs.npz`` and ``example_inputs.pt``.

    Args:
        output_dir: Directory that receives both files; created if missing.
        example_inputs_dict: Mapping of name to tensor or tuple of tensors.
            In the ``.npz`` file a tensor is stored under its key, and each
            tensor of a tuple under ``"<key>_<index>"``; other values are
            skipped there. The ``.pt`` file stores the mapping unchanged.
    """
    # Ensure the target directory exists so both writes succeed on a fresh path.
    os.makedirs(output_dir, exist_ok=True)

    npz_path = os.path.join(output_dir, "example_inputs.npz")
    np_data = {}
    for key, tensor in example_inputs_dict.items():
        if isinstance(tensor, torch.Tensor):
            np_data[key] = tensor.cpu().numpy()
        elif isinstance(tensor, tuple):
            # Flatten tuples into individually named arrays.
            for i, t in enumerate(tensor):
                if isinstance(t, torch.Tensor):
                    np_data[f"{key}_{i}"] = t.cpu().numpy()
    np.savez(npz_path, **np_data)
    print(f" 示例输入已保存到: {npz_path}")

    torch_path = os.path.join(output_dir, "example_inputs.pt")
    torch.save(example_inputs_dict, torch_path)
    print(f" PyTorch 示例输入已保存到: {torch_path}")
|
||||||
|
|
||||||
|
|
||||||
|
_INFERENCE_EXAMPLE_TEMPLATE = '''#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
ONNX模型推理示例
|
||||||
|
|
||||||
|
展示如何使用导出的两个ONNX模型进行推理,
|
||||||
|
包括束搜索(beam search)算法。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import onnxruntime as ort
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
class ONNXInference:
|
||||||
|
"""ONNX模型推理器"""
|
||||||
|
|
||||||
|
def __init__(self, context_encoder_path, decoder_path):
|
||||||
|
self.context_encoder_session = ort.InferenceSession(
|
||||||
|
context_encoder_path,
|
||||||
|
providers=['CPUExecutionProvider']
|
||||||
|
)
|
||||||
|
self.decoder_session = ort.InferenceSession(
|
||||||
|
decoder_path,
|
||||||
|
providers=['CPUExecutionProvider']
|
||||||
|
)
|
||||||
|
|
||||||
|
self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()]
|
||||||
|
self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()]
|
||||||
|
self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()]
|
||||||
|
self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()]
|
||||||
|
|
||||||
|
print(f"上下文编码器输入: {self.context_input_names}")
|
||||||
|
print(f"上下文编码器输出: {self.context_output_names}")
|
||||||
|
print(f"解码器输入: {self.decoder_input_names}")
|
||||||
|
print(f"解码器输出: {self.decoder_output_names}")
|
||||||
|
|
||||||
|
def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128):
|
||||||
|
raise NotImplementedError("请实现实际的输入预处理")
|
||||||
|
|
||||||
|
def run_context_encoder(self, input_ids, pinyin_ids, attention_mask):
|
||||||
|
inputs = {
|
||||||
|
"input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids,
|
||||||
|
"pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids,
|
||||||
|
"attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask,
|
||||||
|
}
|
||||||
|
outputs = self.context_encoder_session.run(self.context_output_names, inputs)
|
||||||
|
context_H, pinyin_P, context_mask, pinyin_mask = outputs
|
||||||
|
return (
|
||||||
|
torch.from_numpy(context_H),
|
||||||
|
torch.from_numpy(pinyin_P),
|
||||||
|
torch.from_numpy(context_mask),
|
||||||
|
torch.from_numpy(pinyin_mask),
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
|
||||||
|
inputs = {
|
||||||
|
"context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H,
|
||||||
|
"pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P,
|
||||||
|
"history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids,
|
||||||
|
"context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask,
|
||||||
|
"pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask,
|
||||||
|
}
|
||||||
|
outputs = self.decoder_session.run(self.decoder_output_names, inputs)
|
||||||
|
logits = outputs[0]
|
||||||
|
return torch.from_numpy(logits)
|
||||||
|
|
||||||
|
def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask,
|
||||||
|
beam_size=5, max_length=10, vocab_size=10019):
|
||||||
|
beams = [([], 0.0)]
|
||||||
|
for step in range(max_length):
|
||||||
|
new_beams = []
|
||||||
|
for seq, score in beams:
|
||||||
|
if len(seq) < 8:
|
||||||
|
history = seq + [0] * (8 - len(seq))
|
||||||
|
else:
|
||||||
|
history = seq[-8:]
|
||||||
|
history_tensor = torch.tensor([history], dtype=torch.long)
|
||||||
|
logits = self.run_decoder(
|
||||||
|
context_H, pinyin_P, history_tensor,
|
||||||
|
context_mask, pinyin_mask
|
||||||
|
)
|
||||||
|
probs = F.softmax(logits[0], dim=-1)
|
||||||
|
top_probs, top_indices = torch.topk(probs, beam_size)
|
||||||
|
for prob, idx in zip(top_probs, top_indices):
|
||||||
|
new_seq = seq + [idx.item()]
|
||||||
|
new_score = score + torch.log(prob).item()
|
||||||
|
new_beams.append((new_seq, new_score))
|
||||||
|
new_beams.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
beams = new_beams[:beam_size]
|
||||||
|
all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
|
||||||
|
if all_ended:
|
||||||
|
break
|
||||||
|
return beams
|
||||||
|
|
||||||
|
def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids):
|
||||||
|
context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
|
||||||
|
input_ids, pinyin_ids, attention_mask
|
||||||
|
)
|
||||||
|
logits = self.run_decoder(
|
||||||
|
context_H, pinyin_P, history_slot_ids,
|
||||||
|
context_mask, pinyin_mask
|
||||||
|
)
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""示例主函数"""
|
||||||
|
print("ONNX模型推理示例")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
context_encoder_path = "context_encoder.onnx"
|
||||||
|
decoder_path = "decoder.onnx"
|
||||||
|
|
||||||
|
if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path):
|
||||||
|
print("错误: 找不到ONNX模型文件")
|
||||||
|
print("请先运行 train-model export 导出模型")
|
||||||
|
return
|
||||||
|
|
||||||
|
inference = ONNXInference(context_encoder_path, decoder_path)
|
||||||
|
print("\\u2705 ONNX推理器初始化完成")
|
||||||
|
print("请参考此示例实现完整的输入法推理流程")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
|
def create_inference_example(output_dir: str, config: Dict) -> None:
    """Write the ONNX inference example script into ``output_dir``.

    The text of ``_INFERENCE_EXAMPLE_TEMPLATE`` is written verbatim, encoded
    as UTF-8, to ``<output_dir>/inference_example.py``; the destination path
    is then printed.

    Args:
        output_dir: Directory that receives ``inference_example.py``. It is
            not created here, so it must already exist.
        config: Not read by this function.
    """
    script_path = os.path.join(output_dir, "inference_example.py")

    with open(script_path, "w", encoding="utf-8") as script_file:
        script_file.write(_INFERENCE_EXAMPLE_TEMPLATE)

    print(f" 推理示例脚本已保存到: {script_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def run_full_export(
    checkpoint_path: str,
    output_dir: str,
    device: str = "cpu",
    skip_verification: bool = False,
) -> Tuple[str, str, Dict]:
    """Run the complete ONNX export pipeline for one checkpoint.

    Steps, in order:

    1. Create ``output_dir`` (including parents) if it does not exist.
    2. Build the context-encoder and decoder export modules from the
       checkpoint.
    3. Export the context encoder to ``context_encoder.onnx``.
    4. Run the context-encoder module once under ``torch.no_grad()`` on the
       example inputs returned by step 3, to obtain example decoder inputs.
    5. Export the decoder to ``decoder.onnx``.
    6. Save the example inputs and write the inference example script.

    Args:
        checkpoint_path: Path of the checkpoint to export.
        output_dir: Directory that receives all exported artifacts.
        device: Forwarded to ``create_export_models_from_checkpoint``.
        skip_verification: Forwarded to both export helpers.

    Returns:
        ``(context_encoder_path, decoder_path, config)`` where both paths are
        strings inside ``output_dir`` and ``config`` is the mapping returned
        by ``create_export_models_from_checkpoint``.
    """
    export_root = Path(output_dir)
    export_root.mkdir(parents=True, exist_ok=True)

    print(f"加载 checkpoint: {checkpoint_path}")
    loaded = create_export_models_from_checkpoint(checkpoint_path, device)
    context_encoder_export, decoder_export, config = loaded

    # Same single-line summary as before, assembled from its individual fields.
    config_fields = (
        f"vocab_size={config.get('vocab_size')}",
        f"dim={config.get('dim')}",
        f"num_slots={config.get('num_slots')}",
        f"num_experts={config.get('num_experts')}",
        f"moe_mode={config.get('moe_mode', 'all')}",
    )
    print("模型配置: " + ", ".join(config_fields))

    print("正在导出模型...")

    context_encoder_path = str(export_root / "context_encoder.onnx")
    example_inputs = export_context_encoder(
        context_encoder_export,
        context_encoder_path,
        config,
        skip_verification=skip_verification,
    )

    # One gradient-free forward pass yields the tensors used as example
    # inputs when exporting the decoder.
    with torch.no_grad():
        encoder_outputs = context_encoder_export(*example_inputs)
    context_H, pinyin_P, context_mask, pinyin_mask = encoder_outputs

    decoder_path = str(export_root / "decoder.onnx")
    export_decoder(
        decoder_export,
        decoder_path,
        config,
        example_inputs=(context_H, pinyin_P, context_mask, pinyin_mask),
        skip_verification=skip_verification,
    )

    # Encoder inputs first, then encoder outputs, keyed by tensor name.
    example_inputs_dict = {
        "input_ids": example_inputs[0],
        "pinyin_ids": example_inputs[1],
        "attention_mask": example_inputs[2],
        "context_H": context_H,
        "pinyin_P": pinyin_P,
        "context_mask": context_mask,
        "pinyin_mask": pinyin_mask,
    }
    save_example_inputs(output_dir, example_inputs_dict)
    create_inference_example(output_dir, config)

    return context_encoder_path, decoder_path, config
|
||||||
|
|
@ -838,6 +838,20 @@ class Trainer:
|
||||||
# epoch 完成
|
# epoch 完成
|
||||||
progress.update(epoch_task, advance=1)
|
progress.update(epoch_task, advance=1)
|
||||||
|
|
||||||
|
# 每个 epoch 结束后评估并保存 best model
|
||||||
|
epoch_eval_metrics = self.evaluate()
|
||||||
|
if epoch_eval_metrics:
|
||||||
|
epoch_log_metrics = {
|
||||||
|
"epoch/eval_loss": epoch_eval_metrics["eval_loss"],
|
||||||
|
"epoch/eval_accuracy": epoch_eval_metrics["eval_accuracy"],
|
||||||
|
}
|
||||||
|
if self.writer is not None:
|
||||||
|
for key, value in epoch_log_metrics.items():
|
||||||
|
self.writer.add_scalar(key, value, global_step)
|
||||||
|
if epoch_eval_metrics["eval_loss"] < self.best_eval_loss:
|
||||||
|
self.best_eval_loss = epoch_eval_metrics["eval_loss"]
|
||||||
|
self.save_checkpoint("best_model.pt", is_best=True)
|
||||||
|
|
||||||
# 每个 epoch 结束后保存检查点
|
# 每个 epoch 结束后保存检查点
|
||||||
self.save_epoch_checkpoint(epoch + 1)
|
self.save_epoch_checkpoint(epoch + 1)
|
||||||
|
|
||||||
|
|
@ -848,6 +862,34 @@ class Trainer:
|
||||||
# 训练完成
|
# 训练完成
|
||||||
logger.info("Training completed!")
|
logger.info("Training completed!")
|
||||||
|
|
||||||
|
# 最终评估并保存 best model
|
||||||
|
logger.info("Running final evaluation...")
|
||||||
|
final_eval_metrics = self.evaluate()
|
||||||
|
if final_eval_metrics:
|
||||||
|
final_log_metrics = {
|
||||||
|
"train/loss": accumulated_loss / accumulation_counter
|
||||||
|
if accumulation_counter > 0
|
||||||
|
else 0.0,
|
||||||
|
"train/accuracy": accumulated_accuracy / accumulation_counter
|
||||||
|
if accumulation_counter > 0
|
||||||
|
else 0.0,
|
||||||
|
"train/learning_rate": self._get_current_lr(),
|
||||||
|
"eval/loss": final_eval_metrics["eval_loss"],
|
||||||
|
"eval/accuracy": final_eval_metrics["eval_accuracy"],
|
||||||
|
}
|
||||||
|
self._log_to_tensorboard(final_log_metrics, global_step)
|
||||||
|
logger.info(
|
||||||
|
f"Final eval loss: {final_eval_metrics['eval_loss']:.4f}, "
|
||||||
|
f"Final eval accuracy: {final_eval_metrics['eval_accuracy']:.4f}"
|
||||||
|
)
|
||||||
|
if final_eval_metrics["eval_loss"] < self.best_eval_loss:
|
||||||
|
self.best_eval_loss = final_eval_metrics["eval_loss"]
|
||||||
|
self.save_checkpoint("best_model.pt", is_best=True)
|
||||||
|
|
||||||
|
# 保存最终模型
|
||||||
|
self.save_checkpoint("final_model.pt")
|
||||||
|
logger.info("Final model saved.")
|
||||||
|
|
||||||
# 显示保存的 epoch checkpoint 信息
|
# 显示保存的 epoch checkpoint 信息
|
||||||
if self.epoch_checkpoints:
|
if self.epoch_checkpoints:
|
||||||
sorted_checkpoints = self.get_epoch_checkpoints()
|
sorted_checkpoints = self.get_epoch_checkpoints()
|
||||||
|
|
@ -1096,6 +1138,7 @@ def create_dataloader(
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
collate_fn=preprocessed_collate_fn,
|
collate_fn=preprocessed_collate_fn,
|
||||||
|
drop_last=True,
|
||||||
persistent_workers=True if num_workers > 0 else False,
|
persistent_workers=True if num_workers > 0 else False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1108,7 +1151,8 @@ def create_dataloader(
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
worker_init_fn=worker_init_fn,
|
worker_init_fn=worker_init_fn,
|
||||||
collate_fn=preprocess_collate_fn(fixed_max_seq_length),
|
collate_fn=preprocess_collate_fn(fixed_max_seq_length),
|
||||||
prefetch_factor=2, # 减少预取以避免内存问题
|
drop_last=True,
|
||||||
|
prefetch_factor=2,
|
||||||
persistent_workers=True,
|
persistent_workers=True,
|
||||||
shuffle=shuffle,
|
shuffle=shuffle,
|
||||||
)
|
)
|
||||||
|
|
@ -1464,21 +1508,46 @@ def evaluate(
|
||||||
@app.command()
|
@app.command()
|
||||||
def export(
|
def export(
|
||||||
checkpoint_path: str = typer.Option(..., "--checkpoint", "-c", help="检查点路径"),
|
checkpoint_path: str = typer.Option(..., "--checkpoint", "-c", help="检查点路径"),
|
||||||
output_path: str = typer.Option(
|
output_dir: str = typer.Option(
|
||||||
"./exported_model.onnx", "--output", "-o", help="输出路径"
|
"./exported_models", "--output", "-o", help="输出目录"
|
||||||
|
),
|
||||||
|
device: str = typer.Option("cpu", "--device", help="导出设备(cpu/cuda)"),
|
||||||
|
skip_verification: bool = typer.Option(
|
||||||
|
False, "--skip-verification", help="跳过ONNX模型验证"
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
导出模型为ONNX格式
|
导出模型为ONNX格式(生成 context_encoder.onnx 和 decoder.onnx)
|
||||||
"""
|
"""
|
||||||
|
from .onnx_export import check_onnx_available, run_full_export
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
console.print(f"[bold cyan]导出模型到: {output_path}[/bold cyan]")
|
console.print("[bold cyan]━━━ ONNX 模型导出 ━━━[/bold cyan]")
|
||||||
|
console.print(f" 检查点: {checkpoint_path}")
|
||||||
|
console.print(f" 输出目录: {output_dir}")
|
||||||
|
console.print(f" 设备: {device}")
|
||||||
|
|
||||||
# 这里应该实现导出逻辑
|
if not check_onnx_available():
|
||||||
# 1. 加载检查点
|
raise typer.Exit(1)
|
||||||
# 2. 导出为ONNX
|
|
||||||
|
|
||||||
console.print("[yellow]导出功能待实现[/yellow]")
|
try:
|
||||||
|
context_encoder_path, decoder_path, config = run_full_export(
|
||||||
|
checkpoint_path=checkpoint_path,
|
||||||
|
output_dir=output_dir,
|
||||||
|
device=device,
|
||||||
|
skip_verification=skip_verification,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[bold red]导出失败: {e}[/bold red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
console.print("\n[bold green]✓ 导出完成![/bold green]")
|
||||||
|
console.print(f" context_encoder.onnx -> {context_encoder_path}")
|
||||||
|
console.print(f" decoder.onnx -> {decoder_path}")
|
||||||
|
console.print(
|
||||||
|
f"\n MoE 层使用 'all' 模式(全量计算 {config.get('num_experts', '?')} 个专家),"
|
||||||
|
f"稀疏化优化可后续迭代"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue