diff --git a/README.md b/README.md index 4bcaee6..9d10226 100644 --- a/README.md +++ b/README.md @@ -72,6 +72,36 @@ * **目标槽位序列**:真实用户输入的文字 ID 序列,作为模型的监督信号 [1]。 * **标签处理**:在每一个槽位步(Step),模型需要预测该步对应的真实文字 ID [1]。 +#### 4.1.1 历史槽位(history_slot_ids)的设计策略 + +`history_slot_ids` 的语义是**当前拼音组内已经确认的字符**,用于模拟真实的输入法输入场景。 + +**设计原则:** + +1. **拼音组隔离**:每个拼音组(jieba 分词的一个词或合并后的短词组)对应一次独立的拼音输入会话。拼音组开始时 `history_slot_ids` 为空(全 0),组内逐字累积。 + ``` + 输入: 特种兵 → 拼音组 "tezhongbing" + Step 1: pinyin="tezhongbing", history=[] → 预测 特 + Step 2: pinyin="tezhongbing", history=[特] → 预测 种 + Step 3: pinyin="tezhongbing", history=[特, 种] → 预测 兵 + 输入: 十分倾佩 → 拼音组 "shifen" + 拼音组 "qing" + 拼音组 "pei" + 拼音组 "qing": pinyin="qing", history=[] → 预测 倾 + 拼音组 "pei": pinyin="pei", history=[] → 预测 佩 + ``` + +2. **上下文承载前缀字**:当拼音组切换时,前一个拼音组已确认的字符(如"特")已写入文本,通过 BERT 编码的 `input_ids`(而非 `history_slot_ids`)传递给模型。这模拟了真实输入法环境——用户打完"te"确认"特"后,文本中已有"特",再输入"zhongbing"时,模型从上下文(`input_ids`)中能看到"特"。 + +3. **破词续接**:当一个词被按一定概率(~10%)拆分时(`should_break=True`),Phase 1 处理前缀(如"特"),Phase 2 处理续接(如"种兵")。Phase 2 的 `cont_processed_history = []` 是**正确的设计**——前缀已确认为文本的一部分,续接是一个新的拼音会话,history 从空开始。 + ``` + "特种兵" 断在"特"后: + Phase 1: pinyin="te", history=[] → 预测 特 → 确认为文本上下文 + Phase 2: pinyin="zhongbing", history=[] → 预测 种 + history=[种] → 预测 兵 + ``` + Phase 2 中"特"已通过 `part1`(光标前文本)进入 BERT 上下文中,不在 `history_slot_ids` 中。 + +4. 
**短词合并**:相邻单字词(≤2 字符)有 50% 概率合并为一个拼音组。合并后组内字符共享 history 累积;未合并时各自为独立拼音组(history 为空)。两种方式都对应真实输入法场景——用户可能一次性输入多个字,也可能逐字输入。 + ### 4.2 损失函数与优化 * **损失函数**:使用 **CrossEntropyLoss** 计算每一步预测结果与真实标签之间的差异 [1]。 * **掩码机制**:仅计算非填充位置(Non-padding positions)的损失,忽略无效的时间步 [1]。 diff --git a/export_onnx.py b/export_onnx.py index b9c461b..b5b2662 100644 --- a/export_onnx.py +++ b/export_onnx.py @@ -14,450 +14,12 @@ """ import argparse -import os import sys from pathlib import Path -import torch -import numpy as np - -# 检查ONNX是否可用 -try: - import onnx - import onnxruntime as ort - - ONNX_AVAILABLE = True -except ImportError: - ONNX_AVAILABLE = False - print("警告: ONNX或ONNX Runtime未安装") - print("请运行: pip install onnx onnxruntime") - -# 添加src目录到路径 sys.path.insert(0, str(Path(__file__).parent)) -from src.model.export_models import create_export_models_from_checkpoint - - -def check_onnx_available(): - """检查ONNX依赖是否可用""" - if not ONNX_AVAILABLE: - print("错误: ONNX导出需要以下依赖:") - print(" pip install onnx onnxruntime") - print("请安装后重试") - return False - return True - - -def export_context_encoder(model, output_path, config): - """ - 导出上下文编码器为ONNX格式 - - Args: - model: ContextEncoderExport实例 - output_path: 输出路径 - config: 模型配置 - """ - print(f"正在导出上下文编码器到: {output_path}") - - # 创建示例输入 - 使用batch_size=2以确保ONNX支持动态批处理 - batch_size = 2 - seq_len = config.get("max_seq_len", 128) - pinyin_len = 24 # 固定长度 - dim = config.get("dim", 512) - - input_ids = torch.randint(0, 100, (batch_size, seq_len), dtype=torch.long) - pinyin_ids = torch.randint(0, 30, (batch_size, pinyin_len), dtype=torch.long) - attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long) - - # 设置动态轴 - dynamic_axes = { - "input_ids": {0: "batch_size", 1: "seq_len"}, - "pinyin_ids": {0: "batch_size"}, - "attention_mask": {0: "batch_size", 1: "seq_len"}, - "context_H": {0: "batch_size", 1: "seq_len"}, - "pinyin_P": {0: "batch_size"}, - "context_mask": {0: "batch_size", 1: "seq_len"}, - "pinyin_mask": {0: "batch_size"}, - } - - # 导出模型 - 
torch.onnx.export( - model, - (input_ids, pinyin_ids, attention_mask), - output_path, - input_names=["input_ids", "pinyin_ids", "attention_mask"], - output_names=["context_H", "pinyin_P", "context_mask", "pinyin_mask"], - dynamic_axes=dynamic_axes, - opset_version=18, # 支持现代操作符 - do_constant_folding=True, - verbose=False, - ) - - print("✅ 上下文编码器导出完成") - - # 验证导出 - try: - onnx_model = onnx.load(output_path) - onnx.checker.check_model(onnx_model) - print("✅ ONNX模型验证通过") - except Exception as e: - print(f"⚠️ ONNX模型验证警告: {e}") - - return input_ids, pinyin_ids, attention_mask - - -def export_decoder(model, output_path, config, example_inputs=None): - """ - 导出解码器为ONNX格式 - - Args: - model: DecoderExport实例 - output_path: 输出路径 - config: 模型配置 - example_inputs: 示例输入(用于验证一致性) - """ - print(f"正在导解码器到: {output_path}") - - # 创建示例输入 - 使用batch_size=2以确保ONNX支持动态批处理 - batch_size = 2 - seq_len = config.get("max_seq_len", 128) - pinyin_len = 24 - dim = config.get("dim", 512) - num_slots = config.get("num_slots", 8) - - # 使用提供的示例输入或创建新的 - if example_inputs is not None: - context_H, pinyin_P, context_mask, pinyin_mask = example_inputs - batch_size = context_H.size(0) # 使用实际batch size - history_slot_ids = torch.randint( - 0, 100, (batch_size, num_slots), dtype=torch.long - ) - else: - context_H = torch.randn(batch_size, seq_len, dim, dtype=torch.float32) - pinyin_P = torch.randn(batch_size, pinyin_len, dim, dtype=torch.float32) - context_mask = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int32) - pinyin_mask = torch.randint(0, 2, (batch_size, pinyin_len), dtype=torch.int32) - history_slot_ids = torch.randint( - 0, 100, (batch_size, num_slots), dtype=torch.long - ) - - # 设置动态轴 - dynamic_axes = { - "context_H": {0: "batch_size", 1: "seq_len"}, - "pinyin_P": {0: "batch_size"}, - "history_slot_ids": {0: "batch_size"}, - "context_mask": {0: "batch_size", 1: "seq_len"}, - "pinyin_mask": {0: "batch_size"}, - "logits": {0: "batch_size"}, - } - - # 导出模型 - torch.onnx.export( - model, - 
(context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask), - output_path, - input_names=[ - "context_H", - "pinyin_P", - "history_slot_ids", - "context_mask", - "pinyin_mask", - ], - output_names=["logits"], - dynamic_axes=dynamic_axes, - opset_version=18, - do_constant_folding=True, - verbose=False, - ) - - print("✅ 解码器导出完成") - - # 验证导出 - try: - onnx_model = onnx.load(output_path) - onnx.checker.check_model(onnx_model) - print("✅ ONNX模型验证通过") - except Exception as e: - print(f"⚠️ ONNX模型验证警告: {e}") - - return context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask - - -def save_example_inputs(output_dir, example_inputs_dict): - """ - 保存示例输入为NPZ文件,用于验证 - - Args: - output_dir: 输出目录 - example_inputs_dict: 示例输入字典 - """ - npz_path = os.path.join(output_dir, "example_inputs.npz") - - # 转换为numpy数组 - np_data = {} - for key, tensor in example_inputs_dict.items(): - if isinstance(tensor, torch.Tensor): - np_data[key] = tensor.cpu().numpy() - elif isinstance(tensor, tuple): - for i, t in enumerate(tensor): - if isinstance(t, torch.Tensor): - np_data[f"{key}_{i}"] = t.cpu().numpy() - - np.savez(npz_path, **np_data) - print(f"✅ 示例输入已保存到: {npz_path}") - - # 同时保存为PyTorch格式 - torch_path = os.path.join(output_dir, "example_inputs.pt") - torch.save(example_inputs_dict, torch_path) - print(f"✅ PyTorch示例输入已保存到: {torch_path}") - - -def create_inference_example(output_dir, config): - """ - 创建推理示例脚本 - - Args: - output_dir: 输出目录 - config: 模型配置 - """ - example_path = os.path.join(output_dir, "inference_example.py") - - example_code = '''#!/usr/bin/env python3 -""" -ONNX模型推理示例 - -展示如何使用导出的两个ONNX模型进行推理, -包括束搜索(beam search)算法。 -""" - -import numpy as np -import onnxruntime as ort -import torch -import torch.nn.functional as F -from typing import List, Tuple - - -class ONNXInference: - """ONNX模型推理器""" - - def __init__(self, context_encoder_path, decoder_path): - """ - 初始化ONNX推理器 - - Args: - context_encoder_path: 上下文编码器ONNX模型路径 - decoder_path: 解码器ONNX模型路径 - """ - # 创建ONNX 
Runtime会话 - self.context_encoder_session = ort.InferenceSession( - context_encoder_path, - providers=['CPUExecutionProvider'] # 或 'CUDAExecutionProvider' - ) - self.decoder_session = ort.InferenceSession( - decoder_path, - providers=['CPUExecutionProvider'] - ) - - # 获取输入输出名称 - self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()] - self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()] - self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()] - self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()] - - print(f"上下文编码器输入: {self.context_input_names}") - print(f"上下文编码器输出: {self.context_output_names}") - print(f"解码器输入: {self.decoder_input_names}") - print(f"解码器输出: {self.decoder_output_names}") - - def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128): - """ - 准备模型输入(与原始推理脚本保持一致) - - 注意: 这里需要实现文本到token的转换 - 为了简化示例,假设已经实现了相关函数 - """ - # 这里应该调用实际的预处理函数 - # 返回: input_ids, pinyin_ids, attention_mask, history_slot_ids - raise NotImplementedError("请实现实际的输入预处理") - - def run_context_encoder(self, input_ids, pinyin_ids, attention_mask): - """ - 运行上下文编码器 - - Args: - input_ids: [batch, seq_len] - pinyin_ids: [batch, 24] - attention_mask: [batch, seq_len] - - Returns: - context_H, pinyin_P, context_mask, pinyin_mask - """ - # 准备输入 - inputs = { - "input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids, - "pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids, - "attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask, - } - - # 运行推理 - outputs = self.context_encoder_session.run(self.context_output_names, inputs) - - # 解包输出 - context_H, pinyin_P, context_mask, pinyin_mask = outputs - - return ( - torch.from_numpy(context_H), - torch.from_numpy(pinyin_P), - 
torch.from_numpy(context_mask), - torch.from_numpy(pinyin_mask), - ) - - def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask): - """ - 运行解码器 - - Args: - context_H: [batch, seq_len, 512] - pinyin_P: [batch, 24, 512] - history_slot_ids: [batch, 8] - context_mask: [batch, seq_len] - pinyin_mask: [batch, 24] - - Returns: - logits: [batch, vocab_size] - """ - # 准备输入 - inputs = { - "context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H, - "pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P, - "history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids, - "context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask, - "pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask, - } - - # 运行推理 - outputs = self.decoder_session.run(self.decoder_output_names, inputs) - - # 解包输出 - logits = outputs[0] - return torch.from_numpy(logits) - - def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask, - beam_size=5, max_length=10, vocab_size=10019): - """ - 束搜索算法示例 - - Args: - context_H: 上下文编码 - pinyin_P: 拼音编码 - context_mask: 上下文掩码 - pinyin_mask: 拼音掩码 - beam_size: 束大小 - max_length: 最大生成长度 - vocab_size: 词汇表大小 - - Returns: - 最佳序列列表 - """ - # 初始束:空序列,分数为0 - beams = [([], 0.0)] # (序列, 对数概率) - - for step in range(max_length): - new_beams = [] - - for seq, score in beams: - # 构建history_slot_ids:已确认的字符ID - if len(seq) < 8: - history = seq + [0] * (8 - len(seq)) - else: - history = seq[-8:] # 只保留最近8个 - - history_tensor = torch.tensor([history], dtype=torch.long) - - # 运行解码器 - logits = self.run_decoder( - context_H, pinyin_P, history_tensor, - context_mask, pinyin_mask - ) - - # 获取概率 - probs = F.softmax(logits[0], dim=-1) - - # 获取top-k候选 - top_probs, top_indices = torch.topk(probs, beam_size) - - # 扩展束 - for prob, idx in zip(top_probs, top_indices): - new_seq = seq + 
[idx.item()] - new_score = score + torch.log(prob).item() - new_beams.append((new_seq, new_score)) - - # 剪枝:保留beam_size个最佳候选 - new_beams.sort(key=lambda x: x[1], reverse=True) - beams = new_beams[:beam_size] - - # 检查是否所有序列都已结束(以结束符0结尾) - all_ended = all(seq[-1] == 0 for seq, _ in beams if seq) - if all_ended: - break - - return beams - - def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids): - """ - 单步预测 - - Args: - input_ids: 输入token IDs - pinyin_ids: 拼音IDs - attention_mask: 注意力掩码 - history_slot_ids: 历史槽位IDs - - Returns: - 预测logits - """ - # 1. 运行上下文编码器 - context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder( - input_ids, pinyin_ids, attention_mask - ) - - # 2. 运行解码器 - logits = self.run_decoder( - context_H, pinyin_P, history_slot_ids, - context_mask, pinyin_mask - ) - - return logits - - -def main(): - """示例主函数""" - print("ONNX模型推理示例") - print("=" * 60) - - # 初始化推理器 - context_encoder_path = "context_encoder.onnx" - decoder_path = "decoder.onnx" - - if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path): - print("错误: 找不到ONNX模型文件") - print("请先运行export_onnx.py导出模型") - return - - inference = ONNXInference(context_encoder_path, decoder_path) - - print("✅ ONNX推理器初始化完成") - print("请参考此示例实现完整的输入法推理流程") - - -if __name__ == "__main__": - main() -''' - - with open(example_path, "w", encoding="utf-8") as f: - f.write(example_code) - - print(f"✅ 推理示例脚本已保存到: {example_path}") +from src.model.onnx_export import check_onnx_available, run_full_export def main(): @@ -485,77 +47,39 @@ def main(): args = parser.parse_args() - # 检查依赖 if not check_onnx_available(): sys.exit(1) - # 创建输出目录 + print(f"输出目录: {Path(args.output_dir).absolute()}") + + context_encoder_path, decoder_path, config = run_full_export( + checkpoint_path=args.checkpoint, + output_dir=args.output_dir, + device=args.device, + skip_verification=args.skip_verification, + ) + output_dir = Path(args.output_dir) - output_dir.mkdir(parents=True, 
exist_ok=True) - - print(f"📁 输出目录: {output_dir.absolute()}") - - # 加载checkpoint并创建导出模型 - print(f"📦 加载checkpoint: {args.checkpoint}") - context_encoder_export, decoder_export, config = ( - create_export_models_from_checkpoint(args.checkpoint, args.device) - ) - - print(f"📊 模型配置: {config}") - - # 导出上下文编码器 - context_encoder_path = output_dir / "context_encoder.onnx" - example_inputs = export_context_encoder( - context_encoder_export, str(context_encoder_path), config - ) - - # 使用上下文编码器的输出作为解码器的示例输入 - with torch.no_grad(): - context_H, pinyin_P, context_mask, pinyin_mask = context_encoder_export( - *example_inputs - ) - - # 导出解码器 - decoder_path = output_dir / "decoder.onnx" - export_decoder( - decoder_export, - str(decoder_path), - config, - example_inputs=(context_H, pinyin_P, context_mask, pinyin_mask), - ) - - # 保存示例输入 - example_inputs_dict = { - "input_ids": example_inputs[0], - "pinyin_ids": example_inputs[1], - "attention_mask": example_inputs[2], - "context_H": context_H, - "pinyin_P": pinyin_P, - "context_mask": context_mask, - "pinyin_mask": pinyin_mask, - } - save_example_inputs(output_dir, example_inputs_dict) - - # 创建推理示例脚本 - create_inference_example(output_dir, config) - - print("\n" + "=" * 60) - print("🎉 ONNX导出完成!") + print() print("=" * 60) - print(f"生成的模型文件:") + print("ONNX导出完成!") + print("=" * 60) + print("生成的模型文件:") print(f" - {context_encoder_path}") print(f" - {decoder_path}") - print(f" - {output_dir}/example_inputs.npz") - print(f" - {output_dir}/example_inputs.pt") - print(f" - {output_dir}/inference_example.py") - print("\n使用方法:") + print(f" - {output_dir / 'example_inputs.npz'}") + print(f" - {output_dir / 'example_inputs.pt'}") + print(f" - {output_dir / 'inference_example.py'}") + print() + print("使用方法:") print(f" 1. 检查模型: python -m onnx.checker {context_encoder_path}") print(f" 2. 运行推理示例: cd {output_dir} && python inference_example.py") - print(f" 3. 
集成到您的应用: 参考inference_example.py中的ONNXInference类") - print("\n注意:") - print(" - 请确保安装了onnxruntime: pip install onnxruntime") - print(" - GPU推理需要onnxruntime-gpu: pip install onnxruntime-gpu") - print(" - 束搜索算法需要根据实际需求进行调整") + print(f" 3. 集成到您的应用: 参考 inference_example.py 中的 ONNXInference 类") + print() + print("注意:") + print(" - 请确保安装了 onnxruntime: pip install onnxruntime") + print(" - GPU推理需要 onnxruntime-gpu: pip install onnxruntime-gpu") + print(" - MoE 层当前使用 'all' 模式(全量计算),稀疏化优化可后续迭代") if __name__ == "__main__": diff --git a/exported_models/inference_example.py b/exported_models/inference_example.py index a140239..1277b81 100644 --- a/exported_models/inference_example.py +++ b/exported_models/inference_example.py @@ -5,7 +5,7 @@ ONNX模型推理示例 展示如何使用导出的两个ONNX模型进行推理, 包括束搜索(beam search)算法。 """ - +import os import numpy as np import onnxruntime as ort import torch @@ -15,94 +15,46 @@ from typing import List, Tuple class ONNXInference: """ONNX模型推理器""" - + def __init__(self, context_encoder_path, decoder_path): - """ - 初始化ONNX推理器 - - Args: - context_encoder_path: 上下文编码器ONNX模型路径 - decoder_path: 解码器ONNX模型路径 - """ - # 创建ONNX Runtime会话 self.context_encoder_session = ort.InferenceSession( context_encoder_path, - providers=['CPUExecutionProvider'] # 或 'CUDAExecutionProvider' + providers=['CPUExecutionProvider'] ) self.decoder_session = ort.InferenceSession( decoder_path, providers=['CPUExecutionProvider'] ) - - # 获取输入输出名称 + self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()] self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()] self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()] self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()] - + print(f"上下文编码器输入: {self.context_input_names}") print(f"上下文编码器输出: {self.context_output_names}") print(f"解码器输入: {self.decoder_input_names}") print(f"解码器输出: {self.decoder_output_names}") - + def 
prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128): - """ - 准备模型输入(与原始推理脚本保持一致) - - 注意: 这里需要实现文本到token的转换 - 为了简化示例,假设已经实现了相关函数 - """ - # 这里应该调用实际的预处理函数 - # 返回: input_ids, pinyin_ids, attention_mask, history_slot_ids raise NotImplementedError("请实现实际的输入预处理") - + def run_context_encoder(self, input_ids, pinyin_ids, attention_mask): - """ - 运行上下文编码器 - - Args: - input_ids: [batch, seq_len] - pinyin_ids: [batch, 24] - attention_mask: [batch, seq_len] - - Returns: - context_H, pinyin_P, context_mask, pinyin_mask - """ - # 准备输入 inputs = { "input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids, "pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids, "attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask, } - - # 运行推理 outputs = self.context_encoder_session.run(self.context_output_names, inputs) - - # 解包输出 context_H, pinyin_P, context_mask, pinyin_mask = outputs - return ( torch.from_numpy(context_H), torch.from_numpy(pinyin_P), torch.from_numpy(context_mask), torch.from_numpy(pinyin_mask), ) - + def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask): - """ - 运行解码器 - - Args: - context_H: [batch, seq_len, 512] - pinyin_P: [batch, 24, 512] - history_slot_ids: [batch, 8] - context_mask: [batch, seq_len] - pinyin_mask: [batch, 24] - - Returns: - logits: [batch, vocab_size] - """ - # 准备输入 inputs = { "context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H, "pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P, @@ -110,99 +62,46 @@ class ONNXInference: "context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask, "pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask, } - - # 运行推理 outputs = self.decoder_session.run(self.decoder_output_names, inputs) - - # 
解包输出 logits = outputs[0] return torch.from_numpy(logits) - - def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask, + + def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask, beam_size=5, max_length=10, vocab_size=10019): - """ - 束搜索算法示例 - - Args: - context_H: 上下文编码 - pinyin_P: 拼音编码 - context_mask: 上下文掩码 - pinyin_mask: 拼音掩码 - beam_size: 束大小 - max_length: 最大生成长度 - vocab_size: 词汇表大小 - - Returns: - 最佳序列列表 - """ - # 初始束:空序列,分数为0 - beams = [([], 0.0)] # (序列, 对数概率) - + beams = [([], 0.0)] for step in range(max_length): new_beams = [] - for seq, score in beams: - # 构建history_slot_ids:已确认的字符ID if len(seq) < 8: history = seq + [0] * (8 - len(seq)) else: - history = seq[-8:] # 只保留最近8个 - + history = seq[-8:] history_tensor = torch.tensor([history], dtype=torch.long) - - # 运行解码器 logits = self.run_decoder( context_H, pinyin_P, history_tensor, context_mask, pinyin_mask ) - - # 获取概率 probs = F.softmax(logits[0], dim=-1) - - # 获取top-k候选 top_probs, top_indices = torch.topk(probs, beam_size) - - # 扩展束 for prob, idx in zip(top_probs, top_indices): new_seq = seq + [idx.item()] new_score = score + torch.log(prob).item() new_beams.append((new_seq, new_score)) - - # 剪枝:保留beam_size个最佳候选 new_beams.sort(key=lambda x: x[1], reverse=True) beams = new_beams[:beam_size] - - # 检查是否所有序列都已结束(以结束符0结尾) all_ended = all(seq[-1] == 0 for seq, _ in beams if seq) if all_ended: break - return beams - + def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids): - """ - 单步预测 - - Args: - input_ids: 输入token IDs - pinyin_ids: 拼音IDs - attention_mask: 注意力掩码 - history_slot_ids: 历史槽位IDs - - Returns: - 预测logits - """ - # 1. 运行上下文编码器 context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder( input_ids, pinyin_ids, attention_mask ) - - # 2. 
运行解码器 logits = self.run_decoder( context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask ) - return logits @@ -210,21 +109,19 @@ def main(): """示例主函数""" print("ONNX模型推理示例") print("=" * 60) - - # 初始化推理器 + context_encoder_path = "context_encoder.onnx" decoder_path = "decoder.onnx" - + if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path): print("错误: 找不到ONNX模型文件") - print("请先运行export_onnx.py导出模型") + print("请先运行 train-model export 导出模型") return - + inference = ONNXInference(context_encoder_path, decoder_path) - - print("✅ ONNX推理器初始化完成") + print("\u2705 ONNX推理器初始化完成") print("请参考此示例实现完整的输入法推理流程") - + if __name__ == "__main__": main() diff --git a/scripts/finetune_slots.py b/scripts/finetune_slots.py new file mode 100644 index 0000000..6b8bdbe --- /dev/null +++ b/scripts/finetune_slots.py @@ -0,0 +1,409 @@ +#!/usr/bin/env python3 +""" +临时迁移训练脚本:在预训练模型基础上重新训练,支持冻结 context_encoder。 + +与 src/model/trainer.py 的 train 命令行为完全一致,额外增加: + - --pretrained-checkpoint: 加载预训练权重(必需的迁移学习源) + - --freeze-context-encoder: 冻结 context_encoder 层(默认开启) + +运行方式: + python scripts/finetune_slots.py \ + --pretrained-checkpoint ./output/checkpoints/best_model.pt \ + --train-data-path /path/to/train_data \ + --eval-data-path /path/to/eval_data \ + --output-dir ./finetune_output \ + --freeze-context-encoder +""" + +import argparse +import json +import os +import random +import sys +from datetime import datetime +from pathlib import Path + +import numpy as np +import torch + +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src")) + +from model.model import InputMethodEngine +from model.trainer import ( + Trainer, + create_dataloader, + worker_init_fn, +) +from model.dataset import PinyinInputDataset +from model.preprocessed_dataset import ( + PreProcessedDataset, + is_preprocessed_data, +) + +from loguru import logger +from rich.console import Console +from rich.panel import Panel +from rich.table import Table + + +def main(): + parser = 
argparse.ArgumentParser( + description="迁移学习训练:加载预训练模型,冻结指定层后重新训练", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + # === 数据参数 === + parser.add_argument("--train-data-path", "-t", required=True, help="训练数据集路径") + parser.add_argument("--eval-data-path", "-e", required=True, help="评估数据集路径") + parser.add_argument("--output-dir", "-o", default="./finetune_output", help="输出目录") + parser.add_argument("--max-iter-length", type=int, + default=1024 * 1024 * 128, help="每个 epoch 最大样本数") + + # === 迁移学习参数 === + parser.add_argument("--pretrained-checkpoint", "-c", required=True, + help="预训练模型检查点路径") + parser.add_argument("--freeze-context-encoder", action="store_true", default=True, + help="冻结 context_encoder 层 (默认开启)") + parser.add_argument("--no-freeze-context-encoder", dest="freeze_context_encoder", + action="store_false", + help="不冻结 context_encoder") + + # === 训练参数 === + parser.add_argument("--batch-size", "-b", type=int, default=128, help="批次大小") + parser.add_argument("--num-epochs", type=int, default=10, help="训练轮数") + parser.add_argument("--learning-rate", "-lr", type=float, default=2e-4, + help="学习率") + parser.add_argument("--min-learning-rate", type=float, default=1e-9, + help="最小学习率") + parser.add_argument("--weight-decay", type=float, default=0.05, help="权重衰减") + parser.add_argument("--warmup-ratio", type=float, default=0.1, help="热身步数比例") + parser.add_argument("--label-smoothing", type=float, default=0.1, + help="标签平滑参数") + parser.add_argument("--grad-accum-steps", type=int, default=1, + help="梯度累积步数") + parser.add_argument("--clip-grad-norm", type=float, default=1.0, + help="梯度裁剪范数") + parser.add_argument("--eval-frequency", type=int, default=500, help="评估频率") + parser.add_argument("--save-frequency", type=int, default=1000, help="保存频率") + + # === 其他参数 === + parser.add_argument("--mixed-precision", action="store_true", default=True) + parser.add_argument("--no-mixed-precision", dest="mixed_precision", + action="store_false", help="禁用混合精度") + 
parser.add_argument("--num-workers", type=int, default=2, help="数据加载worker数") + parser.add_argument("--tensorboard", action="store_true", default=True) + parser.add_argument("--no-tensorboard", dest="tensorboard", action="store_false", + help="禁用 TensorBoard") + parser.add_argument("--seed", type=int, default=42, help="随机种子") + parser.add_argument("--compile", action="store_true", default=False, + help="使用 torch.compile 优化") + parser.add_argument("--moe-mode", default="all", + choices=["all", "sparse", "sparse_allow_graph"], + help="MoE 计算策略") + + args = parser.parse_args() + + # ================================================================ + # 初始化 + # ================================================================ + torch.multiprocessing.set_sharing_strategy("file_system") + if torch.cuda.is_available(): + torch.set_float32_matmul_precision("high") + + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + + console = Console() + output_path = Path(args.output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + # ================================================================ + # 模型常量 (与 trainer.py 保持一致) + # ================================================================ + vocab_size = 10019 + pinyin_vocab_size = 30 + dim = 512 + num_slots = 8 + n_layers = 4 + n_heads = 4 + num_experts = 10 + max_seq_len = 128 + + # ================================================================ + # 打印配置 + # ================================================================ + console.print(Panel.fit( + "[bold cyan]迁移学习训练配置[/bold cyan]", border_style="cyan")) + config_table = Table(show_header=True, header_style="bold magenta") + config_table.add_column("Category", style="cyan") + config_table.add_column("Parameter", style="green") + config_table.add_column("Value", style="yellow") + + config_table.add_row("迁移学习", "预训练检查点", args.pretrained_checkpoint) + config_table.add_row("迁移学习", "冻结 context_encoder", + 
str(args.freeze_context_encoder)) + config_table.add_row("数据", "训练数据路径", args.train_data_path) + config_table.add_row("数据", "评估数据路径", args.eval_data_path) + config_table.add_row("数据", "输出目录", args.output_dir) + config_table.add_row("数据", "批次大小", str(args.batch_size)) + config_table.add_row("数据", "Worker数量", str(args.num_workers)) + config_table.add_row("模型", "MoE策略", args.moe_mode) + config_table.add_row("模型", "编译优化", str(args.compile)) + config_table.add_row("训练", "训练轮数", str(args.num_epochs)) + config_table.add_row("训练", "学习率", f"{args.learning_rate:.2e}") + config_table.add_row("训练", "最小学习率", f"{args.min_learning_rate:.2e}") + config_table.add_row("训练", "权重衰减", str(args.weight_decay)) + config_table.add_row("训练", "热身比例", str(args.warmup_ratio)) + config_table.add_row("训练", "标签平滑", str(args.label_smoothing)) + config_table.add_row("训练", "梯度累积", str(args.grad_accum_steps)) + config_table.add_row("训练", "梯度裁剪", str(args.clip_grad_norm)) + config_table.add_row("训练", "混合精度", str(args.mixed_precision)) + + # ================================================================ + # 创建数据加载器 (逻辑与 trainer.py CLI 完全一致) + # ================================================================ + console.print("[bold cyan]正在创建数据加载器...[/bold cyan]") + + is_train_preprocessed = is_preprocessed_data(args.train_data_path) + is_eval_preprocessed = is_preprocessed_data(args.eval_data_path) + + if is_train_preprocessed: + train_dataset = PreProcessedDataset(args.train_data_path, + max_cache_shards=2) + pre_shuffled = train_dataset.metadata.get("pre_shuffled", False) + shuffle_train = not pre_shuffled + if args.max_iter_length > 0: + capped_samples = min(len(train_dataset), args.max_iter_length) + else: + capped_samples = len(train_dataset) + total_steps = (capped_samples // args.batch_size) * args.num_epochs + train_num_workers = min(args.num_workers, 1) + logger.info( + f"Preprocessed dataset: {len(train_dataset):,} samples, " + f"shuffle={shuffle_train}, pre_shuffled={pre_shuffled}, " + 
f"workers={train_num_workers}, steps={total_steps:,}") + train_dataloader = create_dataloader( + dataset=train_dataset, + batch_size=args.batch_size, + num_workers=train_num_workers, + pin_memory=torch.cuda.is_available(), + shuffle=shuffle_train, + ) + config_table.add_row("数据", "训练数据类型", "预处理数据") + else: + train_dataset = PinyinInputDataset( + data_path=args.train_data_path, + max_workers=-1, + max_iter_length=args.max_iter_length, + max_seq_length=max_seq_len, + text_field="text", + py_style_weight=(9, 2, 1), + shuffle_buffer_size=2000000, + length_weights={1: 10, 2: 50, 3: 50, 4: 40, + 5: 15, 6: 10, 7: 5, 8: 2}, + ) + total_steps = int(args.max_iter_length * + args.num_epochs / args.batch_size) + train_dataloader = create_dataloader( + dataset=train_dataset, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=torch.cuda.is_available(), + max_iter_length=args.max_iter_length, + ) + config_table.add_row("数据", "训练数据类型", "流式数据") + + if is_eval_preprocessed: + eval_dataset = PreProcessedDataset(args.eval_data_path, + max_cache_shards=1) + eval_dataloader = create_dataloader( + dataset=eval_dataset, + batch_size=args.batch_size, + num_workers=0, + pin_memory=torch.cuda.is_available(), + shuffle=False, + ) + config_table.add_row("数据", "评估数据类型", "预处理数据") + else: + eval_dataset = PinyinInputDataset( + data_path=args.eval_data_path, + max_workers=-1, + max_iter_length=args.batch_size * 64, + max_seq_length=max_seq_len, + text_field="text", + py_style_weight=(9, 2, 1), + shuffle_buffer_size=2000000, + length_weights={1: 10, 2: 50, 3: 50, 4: 40, + 5: 15, 6: 10, 7: 5, 8: 2}, + ) + eval_dataloader = create_dataloader( + dataset=eval_dataset, + batch_size=args.batch_size, + num_workers=2, + pin_memory=torch.cuda.is_available(), + max_iter_length=args.batch_size * 64, + ) + config_table.add_row("数据", "评估数据类型", "流式数据") + + config_table.add_row("数据", "总步数", str(total_steps)) + console.print(config_table) + + # 
================================================================ + # 创建模型并加载预训练权重 + # ================================================================ + console.print("[bold cyan]正在创建模型并加载预训练权重...[/bold cyan]") + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = InputMethodEngine( + vocab_size=vocab_size, + pinyin_vocab_size=pinyin_vocab_size, + dim=dim, + num_slots=num_slots, + n_layers=n_layers, + n_heads=n_heads, + num_experts=num_experts, + max_seq_len=max_seq_len, + compile=args.compile, + moe_mode=args.moe_mode, + ) + model.to(device) + + # 加载预训练权重 + pretrained_path = Path(args.pretrained_checkpoint) + if not pretrained_path.exists(): + console.print(f"[red]❌ 预训练检查点不存在: {args.pretrained_checkpoint}[/red]") + sys.exit(1) + + checkpoint = torch.load(args.pretrained_checkpoint, map_location=device) + if "model_state_dict" in checkpoint: + pretrained_weights = checkpoint["model_state_dict"] + else: + pretrained_weights = checkpoint + + missing_keys, unexpected_keys = model.load_state_dict( + pretrained_weights, strict=False) + + if missing_keys: + console.print(f"[yellow]⚠ 缺失的键 ({len(missing_keys)}): " + f"{missing_keys[:5]}...[/yellow]") + if unexpected_keys: + console.print(f"[yellow]⚠ 多余的键 ({len(unexpected_keys)}): " + f"{unexpected_keys[:5]}...[/yellow]") + + console.print(f"[green]✓ 预训练权重加载完成[/green]") + + # ================================================================ + # 冻结 context_encoder + # ================================================================ + if args.freeze_context_encoder: + console.print("[bold cyan]正在冻结 context_encoder...[/bold cyan]") + frozen_count = 0 + trainable_count = 0 + for name, param in model.named_parameters(): + if name.startswith("context_encoder"): + param.requires_grad = False + frozen_count += param.numel() + else: + trainable_count += param.numel() + + total_params = frozen_count + trainable_count + console.print(f"[green]✓ context_encoder 已冻结[/green]") + logger.info( + f"冻结参数: 
{frozen_count:,} / {total_params:,} " + f"({frozen_count / total_params * 100:.1f}%), " + f"可训练参数: {trainable_count:,} / {total_params:,} " + f"({trainable_count / total_params * 100:.1f}%)") + else: + logger.info("未冻结任何层,全模型参与训练") + + # ================================================================ + # 保存配置 + # ================================================================ + config = { + "pretrained_checkpoint": args.pretrained_checkpoint, + "freeze_context_encoder": args.freeze_context_encoder, + "train_data_path": args.train_data_path, + "eval_data_path": args.eval_data_path, + "output_dir": args.output_dir, + "batch_size": args.batch_size, + "num_epochs": args.num_epochs, + "learning_rate": args.learning_rate, + "min_learning_rate": args.min_learning_rate, + "weight_decay": args.weight_decay, + "warmup_ratio": args.warmup_ratio, + "label_smoothing": args.label_smoothing, + "grad_accum_steps": args.grad_accum_steps, + "clip_grad_norm": args.clip_grad_norm, + "eval_frequency": args.eval_frequency, + "save_frequency": args.save_frequency, + "mixed_precision": args.mixed_precision, + "num_workers": args.num_workers, + "use_tensorboard": args.tensorboard, + "seed": args.seed, + "compile": args.compile, + "moe_mode": args.moe_mode, + "total_steps": total_steps, + "vocab_size": vocab_size, + "pinyin_vocab_size": pinyin_vocab_size, + "dim": dim, + "num_slots": num_slots, + "n_layers": n_layers, + "n_heads": n_heads, + "num_experts": num_experts, + "max_seq_len": max_seq_len, + "is_train_preprocessed": is_train_preprocessed, + "is_eval_preprocessed": is_eval_preprocessed, + } + config_file = output_path / "training_config.json" + with open(config_file, "w", encoding="utf-8") as f: + json.dump(config, f, indent=2, ensure_ascii=False) + logger.info(f"Configuration saved to {config_file}") + + # ================================================================ + # 创建 Trainer 并开始训练 + # ================================================================ + 
def forward(self, x, mask=None):
    """Encode the input with the LSTM, then project and layer-normalise it.

    Args:
        x: Tensor handed straight to ``self.lstm`` (presumably
            ``[batch, seq_len, input_dim]`` pinyin embeddings -- TODO confirm
            against the caller).
        mask: Accepted for call-site compatibility but not read here, so
            padded positions run through the LSTM like any other position.

    Returns:
        The per-position LSTM outputs after ``self.proj`` and
        ``self.layer_norm``.
    """
    # nn.LSTM returns (output, (h_n, c_n)); only the per-step output is used.
    lstm_out = self.lstm(x)[0]
    return self.layer_norm(self.proj(lstm_out))
import os
import sys
from pathlib import Path
from typing import Dict, Optional, Tuple

import numpy as np
import torch


def _module_device(model) -> torch.device:
    """Return the device of ``model``'s first parameter, or CPU if it has none."""
    parameters = getattr(model, "parameters", None)
    if callable(parameters):
        first = next(iter(parameters()), None)
        if first is not None:
            return first.device
    return torch.device("cpu")


def _verify_onnx_file(output_path: str) -> None:
    """Run ``onnx.checker`` on an exported file and print the outcome.

    Verification is best-effort: any failure (including ``onnx`` not being
    importable) is printed as a warning and never raised.
    """
    try:
        import onnx

        onnx.checker.check_model(onnx.load(output_path))
        print(" ONNX 模型验证通过")
    except Exception as e:  # deliberate best-effort: report, do not abort
        print(f" ONNX 模型验证警告: {e}")


def check_onnx_available() -> bool:
    """Return True when both ``onnx`` and ``onnxruntime`` can be imported.

    On failure an installation hint is printed and False is returned.
    """
    try:
        import onnx  # noqa: F401
        import onnxruntime as ort  # noqa: F401
    except ImportError:
        print("错误: ONNX导出需要以下依赖:")
        print(" pip install onnx onnxruntime")
        print("请安装后重试")
        return False
    return True


def export_context_encoder(
    model,
    output_path: str,
    config: Dict,
    skip_verification: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Export the context encoder to ONNX and return the example inputs used.

    Args:
        model: Module called as ``model(input_ids, pinyin_ids, attention_mask)``.
        output_path: Destination ``.onnx`` file.
        config: Model configuration; ``max_seq_len`` (default 128) sizes the
            example sequence.
        skip_verification: When True, skip the ``onnx.checker`` pass.

    Returns:
        ``(input_ids, pinyin_ids, attention_mask)`` example tensors.
    """
    # Build the example inputs on the model's own device so exporting a
    # CUDA-resident model does not fail with a device mismatch.
    device = _module_device(model)
    # Batch of 2 rather than 1; the earlier export script notes this is to
    # make sure the dynamic batch axis is supported.
    batch_size = 2
    seq_len = config.get("max_seq_len", 128)
    pinyin_len = 24  # fixed pinyin length (only the batch axis is dynamic)

    input_ids = torch.randint(
        0, 100, (batch_size, seq_len), dtype=torch.long, device=device
    )
    pinyin_ids = torch.randint(
        0, 30, (batch_size, pinyin_len), dtype=torch.long, device=device
    )
    attention_mask = torch.ones(
        (batch_size, seq_len), dtype=torch.long, device=device
    )

    dynamic_axes = {
        "input_ids": {0: "batch_size", 1: "seq_len"},
        "pinyin_ids": {0: "batch_size"},
        "attention_mask": {0: "batch_size", 1: "seq_len"},
        "context_H": {0: "batch_size", 1: "seq_len"},
        "pinyin_P": {0: "batch_size"},
        "context_mask": {0: "batch_size", 1: "seq_len"},
        "pinyin_mask": {0: "batch_size"},
    }

    torch.onnx.export(
        model,
        (input_ids, pinyin_ids, attention_mask),
        output_path,
        input_names=["input_ids", "pinyin_ids", "attention_mask"],
        output_names=["context_H", "pinyin_P", "context_mask", "pinyin_mask"],
        dynamic_axes=dynamic_axes,
        opset_version=18,
        do_constant_folding=True,
        verbose=False,
    )

    print(f" 上下文编码器导出完成 -> {output_path}")

    if not skip_verification:
        _verify_onnx_file(output_path)

    return input_ids, pinyin_ids, attention_mask


def export_decoder(
    model,
    output_path: str,
    config: Dict,
    example_inputs: Optional[Tuple] = None,
    skip_verification: bool = False,
) -> Tuple[torch.Tensor, ...]:
    """Export the decoder to ONNX and return the example inputs used.

    Args:
        model: Module called as ``model(context_H, pinyin_P,
            history_slot_ids, context_mask, pinyin_mask)``.
        output_path: Destination ``.onnx`` file.
        config: Model configuration; reads ``max_seq_len`` (128), ``dim``
            (512) and ``num_slots`` (8) with those defaults.
        example_inputs: Optional ``(context_H, pinyin_P, context_mask,
            pinyin_mask)`` tuple, typically real encoder outputs. When
            omitted, random tensors of the configured sizes are used.
        skip_verification: When True, skip the ``onnx.checker`` pass.

    Returns:
        ``(context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask)``.
    """
    num_slots = config.get("num_slots", 8)

    if example_inputs is not None:
        context_H, pinyin_P, context_mask, pinyin_mask = example_inputs
        batch_size = context_H.size(0)
        # Follow the encoder outputs so every decoder input shares a device.
        device = context_H.device
    else:
        batch_size = 2
        seq_len = config.get("max_seq_len", 128)
        pinyin_len = 24
        dim = config.get("dim", 512)
        device = _module_device(model)
        context_H = torch.randn(
            batch_size, seq_len, dim, dtype=torch.float32, device=device
        )
        pinyin_P = torch.randn(
            batch_size, pinyin_len, dim, dtype=torch.float32, device=device
        )
        context_mask = torch.randint(
            0, 2, (batch_size, seq_len), dtype=torch.int32, device=device
        )
        pinyin_mask = torch.randint(
            0, 2, (batch_size, pinyin_len), dtype=torch.int32, device=device
        )

    # Drawn after the other tensors in both branches, as in the original.
    history_slot_ids = torch.randint(
        0, 100, (batch_size, num_slots), dtype=torch.long, device=device
    )

    dynamic_axes = {
        "context_H": {0: "batch_size", 1: "seq_len"},
        "pinyin_P": {0: "batch_size"},
        "history_slot_ids": {0: "batch_size"},
        "context_mask": {0: "batch_size", 1: "seq_len"},
        "pinyin_mask": {0: "batch_size"},
        "logits": {0: "batch_size"},
    }

    torch.onnx.export(
        model,
        (context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask),
        output_path,
        input_names=[
            "context_H",
            "pinyin_P",
            "history_slot_ids",
            "context_mask",
            "pinyin_mask",
        ],
        output_names=["logits"],
        dynamic_axes=dynamic_axes,
        opset_version=18,
        do_constant_folding=True,
        verbose=False,
    )

    print(f" 解码器导出完成 -> {output_path}")

    if not skip_verification:
        _verify_onnx_file(output_path)

    return context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask


def save_example_inputs(output_dir: str, example_inputs_dict: Dict) -> None:
    """Save example inputs as ``example_inputs.npz`` and ``example_inputs.pt``.

    Tensor values are stored in the ``.npz`` under their key; tensors inside
    tuple/list values are stored as ``"<key>_<index>"``. Non-tensor entries
    are left out of the ``.npz`` but remain in the ``.pt`` file, which holds
    the dictionary unchanged.

    Args:
        output_dir: Target directory; created if it does not exist.
        example_inputs_dict: Mapping of names to tensors (or tensor sequences).
    """
    os.makedirs(output_dir, exist_ok=True)

    np_data = {}
    for key, value in example_inputs_dict.items():
        if isinstance(value, torch.Tensor):
            # detach() first: numpy() refuses tensors that require grad.
            np_data[key] = value.detach().cpu().numpy()
        elif isinstance(value, (tuple, list)):
            for i, item in enumerate(value):
                if isinstance(item, torch.Tensor):
                    np_data[f"{key}_{i}"] = item.detach().cpu().numpy()

    npz_path = os.path.join(output_dir, "example_inputs.npz")
    np.savez(npz_path, **np_data)
    print(f" 示例输入已保存到: {npz_path}")

    torch_path = os.path.join(output_dir, "example_inputs.pt")
    torch.save(example_inputs_dict, torch_path)
    print(f" PyTorch 示例输入已保存到: {torch_path}")


# Source of the stand-alone script written next to the exported models.
# It runs outside this package, so it must import everything it uses itself
# (including ``os``, which its ``main()`` needs for the file-existence check).
_INFERENCE_EXAMPLE_TEMPLATE = '''#!/usr/bin/env python3
"""
ONNX模型推理示例

展示如何使用导出的两个ONNX模型进行推理,
包括束搜索(beam search)算法。
"""

import os

import numpy as np
import onnxruntime as ort
import torch
import torch.nn.functional as F
from typing import List, Tuple


class ONNXInference:
    """ONNX模型推理器"""

    def __init__(self, context_encoder_path, decoder_path):
        self.context_encoder_session = ort.InferenceSession(
            context_encoder_path,
            providers=['CPUExecutionProvider']
        )
        self.decoder_session = ort.InferenceSession(
            decoder_path,
            providers=['CPUExecutionProvider']
        )

        self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()]
        self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()]
        self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()]
        self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()]

        print(f"上下文编码器输入: {self.context_input_names}")
        print(f"上下文编码器输出: {self.context_output_names}")
        print(f"解码器输入: {self.decoder_input_names}")
        print(f"解码器输出: {self.decoder_output_names}")

    def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128):
        raise NotImplementedError("请实现实际的输入预处理")

    def run_context_encoder(self, input_ids, pinyin_ids, attention_mask):
        inputs = {
            "input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids,
            "pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids,
            "attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask,
        }
        outputs = self.context_encoder_session.run(self.context_output_names, inputs)
        context_H, pinyin_P, context_mask, pinyin_mask = outputs
        return (
            torch.from_numpy(context_H),
            torch.from_numpy(pinyin_P),
            torch.from_numpy(context_mask),
            torch.from_numpy(pinyin_mask),
        )

    def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
        inputs = {
            "context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H,
            "pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P,
            "history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids,
            "context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask,
            "pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask,
        }
        outputs = self.decoder_session.run(self.decoder_output_names, inputs)
        logits = outputs[0]
        return torch.from_numpy(logits)

    def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask,
                    beam_size=5, max_length=10, vocab_size=10019):
        beams = [([], 0.0)]
        for step in range(max_length):
            new_beams = []
            for seq, score in beams:
                if len(seq) < 8:
                    history = seq + [0] * (8 - len(seq))
                else:
                    history = seq[-8:]
                history_tensor = torch.tensor([history], dtype=torch.long)
                logits = self.run_decoder(
                    context_H, pinyin_P, history_tensor,
                    context_mask, pinyin_mask
                )
                probs = F.softmax(logits[0], dim=-1)
                top_probs, top_indices = torch.topk(probs, beam_size)
                for prob, idx in zip(top_probs, top_indices):
                    new_seq = seq + [idx.item()]
                    new_score = score + torch.log(prob).item()
                    new_beams.append((new_seq, new_score))
            new_beams.sort(key=lambda x: x[1], reverse=True)
            beams = new_beams[:beam_size]
            all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
            if all_ended:
                break
        return beams

    def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids):
        context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
            input_ids, pinyin_ids, attention_mask
        )
        logits = self.run_decoder(
            context_H, pinyin_P, history_slot_ids,
            context_mask, pinyin_mask
        )
        return logits


def main():
    """示例主函数"""
    print("ONNX模型推理示例")
    print("=" * 60)

    context_encoder_path = "context_encoder.onnx"
    decoder_path = "decoder.onnx"

    if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path):
        print("错误: 找不到ONNX模型文件")
        print("请先运行 train-model export 导出模型")
        return

    inference = ONNXInference(context_encoder_path, decoder_path)
    print("\\u2705 ONNX推理器初始化完成")
    print("请参考此示例实现完整的输入法推理流程")


if __name__ == "__main__":
    main()
'''


def create_inference_example(output_dir: str, config: Dict) -> None:
    """Write ``inference_example.py`` (the template above) into ``output_dir``.

    Args:
        output_dir: Existing directory that receives the script.
        config: Model configuration; currently not used by the template.
    """
    example_path = os.path.join(output_dir, "inference_example.py")
    with open(example_path, "w", encoding="utf-8") as f:
        f.write(_INFERENCE_EXAMPLE_TEMPLATE)
    print(f" 推理示例脚本已保存到: {example_path}")


def run_full_export(
    checkpoint_path: str,
    output_dir: str,
    device: str = "cpu",
    skip_verification: bool = False,
) -> Tuple[str, str, Dict]:
    """Run the complete ONNX export pipeline for one checkpoint.

    Exports the context encoder, runs it once on the example inputs to obtain
    realistic decoder inputs, exports the decoder, then saves the example
    inputs and a stand-alone inference script next to the models.

    Args:
        checkpoint_path: Checkpoint handed to
            ``create_export_models_from_checkpoint``.
        output_dir: Directory for all generated files; created if missing.
        device: Device string forwarded to the checkpoint loader.
        skip_verification: When True, skip the ``onnx.checker`` passes.

    Returns:
        ``(context_encoder_path, decoder_path, config)``.
    """
    # Imported on use, like the optional onnx imports above, so the helpers in
    # this module stay importable without loading the model definitions.
    from .export_models import create_export_models_from_checkpoint

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    print(f"加载 checkpoint: {checkpoint_path}")
    context_encoder_export, decoder_export, config = create_export_models_from_checkpoint(
        checkpoint_path, device
    )

    print(f"模型配置: vocab_size={config.get('vocab_size')}, "
          f"dim={config.get('dim')}, "
          f"num_slots={config.get('num_slots')}, "
          f"num_experts={config.get('num_experts')}, "
          f"moe_mode={config.get('moe_mode', 'all')}")

    print("正在导出模型...")

    context_encoder_path = str(output_path / "context_encoder.onnx")
    example_inputs = export_context_encoder(
        context_encoder_export, context_encoder_path, config,
        skip_verification=skip_verification,
    )

    # Real encoder outputs make better decoder example inputs than noise.
    with torch.no_grad():
        context_H, pinyin_P, context_mask, pinyin_mask = context_encoder_export(
            *example_inputs
        )

    decoder_path = str(output_path / "decoder.onnx")
    export_decoder(
        decoder_export, decoder_path, config,
        example_inputs=(context_H, pinyin_P, context_mask, pinyin_mask),
        skip_verification=skip_verification,
    )

    example_inputs_dict = {
        "input_ids": example_inputs[0],
        "pinyin_ids": example_inputs[1],
        "attention_mask": example_inputs[2],
        "context_H": context_H,
        "pinyin_P": pinyin_P,
        "context_mask": context_mask,
        "pinyin_mask": pinyin_mask,
    }
    save_example_inputs(output_dir, example_inputs_dict)
    create_inference_example(output_dir, config)

    return context_encoder_path, decoder_path, config
@app.command()
def export(
    checkpoint_path: str = typer.Option(..., "--checkpoint", "-c", help="检查点路径"),
    output_dir: str = typer.Option(
        "./exported_models", "--output", "-o", help="输出目录"
    ),
    device: str = typer.Option("cpu", "--device", help="导出设备(cpu/cuda)"),
    skip_verification: bool = typer.Option(
        False, "--skip-verification", help="跳过ONNX模型验证"
    ),
):
    """
    导出模型为ONNX格式(生成 context_encoder.onnx 和 decoder.onnx)
    """
    # NOTE: the docstring above is shown by typer as the command's help text,
    # so it is user-facing output and is left in its original wording.
    #
    # CLI entry point for ONNX export: checks that the ONNX tooling is
    # installed, delegates to run_full_export, and exits with status 1 on a
    # missing dependency or any export failure.

    # Imported here so the rest of the CLI works without the export module.
    from .onnx_export import check_onnx_available, run_full_export

    console = Console()
    console.print("[bold cyan]━━━ ONNX 模型导出 ━━━[/bold cyan]")
    console.print(f" 检查点: {checkpoint_path}")
    console.print(f" 输出目录: {output_dir}")
    console.print(f" 设备: {device}")

    if not check_onnx_available():
        raise typer.Exit(1)

    try:
        context_encoder_path, decoder_path, config = run_full_export(
            checkpoint_path=checkpoint_path,
            output_dir=output_dir,
            device=device,
            skip_verification=skip_verification,
        )
    except Exception as e:
        console.print(f"[bold red]导出失败: {e}[/bold red]")
        # Chain the cause so the original traceback is not lost.
        raise typer.Exit(1) from e

    console.print("\n[bold green]✓ 导出完成![/bold green]")
    console.print(f" context_encoder.onnx -> {context_encoder_path}")
    console.print(f" decoder.onnx -> {decoder_path}")
    console.print(
        f"\n MoE 层使用 'all' 模式(全量计算 {config.get('num_experts', '?')} 个专家),"
        f"稀疏化优化可后续迭代"
    )