SUimeModelTraner/README.md

568 lines
21 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 输入法预测模型架构设计 (Input Method Prediction Model)
## 1. 概述
本项目旨在构建一个轻量级、高精度的中文输入法预测模型。核心设计理念是通过**结构化槽位记忆**与**交叉注意力机制**,将当前语境(光标前后文本+拼音)与历史输入习惯深度融合。为了在有限的计算资源下保持高表达能力,模型引入了**混合专家网络 (MoE)** 模块。
## 2. 核心架构流程
数据流遵循以下路径:
`输入编码``Transformer 上下文编码``槽位记忆嵌入``交叉注意力融合``门控+专家混合 (MoE)``分类预测``束搜索解码`
### 2.1 输入层设计
模型接收三类输入,分别处理以保持语义清晰:
1. **当前文本上下文**包含光标前文本Prefix和光标后文本Suffix
2. **拼音序列**:与当前文本对应的拼音信息,作为增强特征融入文本编码。
3. **历史槽位序列**:最近 N 个历史输入词汇,作为结构化记忆输入。
### 2.2 模块详解
#### A. Transformer 编码器 (Context Encoder)
负责提取当前语境的深层语义表示。
* **输入处理**:将 Prefix、Suffix 及拼音通过 Embedding 层映射。拼音采用**特征叠加**或**独立 Token** 方式融入,避免双流架构的复杂性。
* **骨干网络**:使用标准的 Transformer Encoder。
* **隐藏层维度**512 [1]
* **Transformer 层数**4 层(轻量级设计,从头训练) [1]
* **注意力头数**4 头 [1]
* **输出**:上下文表示 $H$,形状为 `[batch, L, 512]` [1]。
#### B. 槽位记忆模块 (Slot Memory)
负责将非结构化的历史输入转化为结构化的记忆向量。
* **嵌入方式**:历史词汇通过独立的 `Slot Embedding` 查找表映射。
* **位置编码**:添加可学习的 `Positional Embedding` 以保留历史输入的时间顺序信息。
* **输出**:槽位序列 $S$,形状为 `[batch, Num_Slots, 512]`
#### C. 交叉注意力融合 (Cross-Attention Fusion)
这是模型的核心创新点,用于动态关联"历史记忆"与"当前语境"。
* **Query (Q)**:当前步的槽位序列 $S$(经过位置编码后)。
* **Key/Value (K/V)**Transformer 编码器输出的上下文表示 $H$ [1]。
* **机制**:让历史槽位主动关注当前文本语境,捕捉如"在'班级第一名'语境下,'王次香'比'王慈祥'更相关"的逻辑。
* **输出**:融合后的特征序列,形状为 `[batch, Num_Slots, 512]`
#### D. 门控与专家混合 (Gating + MoE)
实际测试表明,移除 MoE 会导致模型性能显著下降,因此该模块对于捕捉复杂分布至关重要。
* **专家数量**20 个专家 [1]。
* **门控机制**:根据输入特征动态选择激活部分专家,实现稀疏激活,在增加模型容量的同时控制计算成本。
* **输出**:经过专家网络增强后的特征向量。
#### E. 分类头与解码
* **分类预测**MoE 输出的特征向量通过全连接层映射到词表空间,输出下一个字/词的概率分布。
* **解码策略**:推理阶段使用**束搜索 (Beam Search)**,束宽设为 5 [1]。
## 3. 关键超参数配置
为确保模型性能与效率的平衡,建议采用以下超参数 [1]
| 参数项 | 推荐值 | 说明 |
| :--- | :--- | :--- |
| **序列长度 (L)** | 128 | 上下文窗口大小 [1] |
| **隐藏层维度** | 512 | Embedding 及 Transformer 内部维度 [1] |
| **Transformer 层数** | 4 | 轻量级骨干,降低延迟 [1] |
| **注意力头数** | 4 | 适配 512 维度的高效配置 [1] |
| **专家数量** | 20 | MoE 层中的专家总数,对性能至关重要 [1] |
| **束宽 (Beam Width)** | 5 | 推理时平衡速度与准确率 [1] |
| **学习率** | 1e-4 ~ 5e-4 | 建议配合 Warmup 策略 [1] |
## 4. 训练策略
本模型采用标准的**序列到序列Seq2Seq监督学习**范式,直接对目标槽位序列进行逐步预测。
### 4.1 数据构造与标签
* **输入三元组**:训练数据由 `(上下文, 拼音, 目标槽位序列)` 构成 [1]。
* **上下文**:光标前后的文本片段。
* **拼音**:当前待输入字的拼音序列。
* **目标槽位序列**:真实用户输入的文字 ID 序列,作为模型的监督信号 [1]。
* **标签处理**在每一个槽位步Step模型需要预测该步对应的真实文字 ID [1]。
### 4.2 损失函数与优化
* **损失函数**:使用 **CrossEntropyLoss** 计算每一步预测结果与真实标签之间的差异 [1]。
* **掩码机制**仅计算非填充位置Non-padding positions的损失忽略无效的时间步 [1]。
* **优化器**:采用 **AdamW** 进行参数更新 [1]。
### 4.3 训练流程细节
1. **前向传播**
* 模型接收上下文和拼音,通过 Transformer 编码得到语境表示。
* 结合历史槽位记忆,通过交叉注意力和 MoE 模块融合特征。
* 分类头输出当前步所有候选字的概率分布。
2. **Teacher Forcing**
* 在训练过程中,**强制使用真实的上一槽位输出**作为下一步的输入条件。这意味着模型在训练时始终基于"正确的历史"进行预测,从而快速收敛。
3. **反向传播**
* 根据 CrossEntropyLoss [1] 计算梯度,并通过 AdamW [1] 更新模型权重。
### 4.4 推理与训练的差异
* **训练时**:使用 Ground Truth真实标签作为槽位输入确保模型学习到最优的条件概率分布。
* **推理时**:由于无法获取真实标签,模型采用**束搜索Beam Search** [1]。
* **束宽**:默认为 5 [1]。
* **候选维护**:每个候选路径独立维护其历史槽位序列及累计概率 [1]。
* **终止条件**:当所有槽位填满(如 8×3=24 步)或所有候选分支的最高概率词均为终止符时退出 [1]。
## 5. Jupyter Lab 训练示例
以下是在 Jupyter Lab 环境中使用 `trainer.Trainer` 类训练输入法模型的完整示例:
```python
# %% [markdown]
# # 输入法模型训练示例
# 本笔记本展示如何使用 trainer.Trainer 类训练输入法模型
# %% [code]
# 1. 导入必要的库
import sys
import os
from pathlib import Path
from datetime import datetime
import torch
from torch.utils.data import DataLoader
# 添加项目路径适应不同的Jupyter Lab运行位置
project_root = Path.cwd()
# 检查当前目录是否包含src目录如果不包含则使用父目录
if not (project_root / "src").exists():
project_root = project_root.parent
sys.path.insert(0, str(project_root)) # 优先搜索项目目录
# 导入项目模块
from src.model.model import InputMethodEngine
from src.model.dataset import PinyinInputDataset
from src.model.trainer import Trainer, worker_init_fn, collate_fn
# %% [code]
# 2. 配置训练参数
config = {
# 数据参数
"train_data_path": "/path/to/your/train/dataset", # 替换为训练数据集路径
"eval_data_path": "/path/to/your/eval/dataset", # 替换为评估数据集路径
"output_dir": "./training_output",
# 模型参数
"vocab_size": 10019,
"pinyin_vocab_size": 30,
"dim": 512,
"num_slots": 8,
"n_layers": 4,
"n_heads": 4,
"num_experts": 20,
"max_seq_len": 128,
# 训练参数
"batch_size": 64, # 根据GPU内存调整
"num_epochs": 10,
"learning_rate": 3e-4,
"min_learning_rate": 1e-9,
"weight_decay": 0.1,
"warmup_ratio": 0.1,
"label_smoothing": 0.15,
"grad_accum_steps": 2, # 梯度累积模拟更大batch size
"clip_grad_norm": 1.0,
"eval_frequency": 500, # 每500步评估一次
"save_frequency": 2000, # 每2000步保存检查点
# 高级选项
"mixed_precision": True,
"use_tensorboard": True,
"seed": 42,
"max_iter_length": 1024 * 1024 * 128, # 最大迭代长度
}
# %% [code]
# 3. 设置随机种子和设备
torch.manual_seed(config["seed"])
if torch.cuda.is_available():
torch.cuda.manual_seed_all(config["seed"])
device = torch.device("cuda")
print(f"✅ 使用 GPU: {torch.cuda.get_device_name(0)}")
else:
device = torch.device("cpu")
print("⚠️ 使用 CPU 进行训练(建议使用 GPU 以获得更好性能)")
# %% [code]
# 4. 创建数据集和数据加载器
print("📊 创建数据集和数据加载器...")
# 训练数据集
train_dataset = PinyinInputDataset(
data_path=config["train_data_path"],
max_workers=-1, # 自动选择worker数量
max_iter_length=config["max_iter_length"],
max_seq_length=config["max_seq_len"],
text_field="text",
py_style_weight=(9, 2, 1),
shuffle_buffer_size=5000,
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
)
# 训练数据加载器
train_dataloader = DataLoader(
train_dataset,
batch_size=config["batch_size"],
num_workers=min(max(1, (os.cpu_count() or 1) - 1), 8), # 合理数量的worker
pin_memory=torch.cuda.is_available(),
worker_init_fn=worker_init_fn,
collate_fn=collate_fn,
prefetch_factor=32,
persistent_workers=True,
)
# 评估数据集
eval_dataset = PinyinInputDataset(
data_path=config["eval_data_path"],
max_workers=-1,
max_iter_length=1024, # 评估集较小
max_seq_length=config["max_seq_len"],
text_field="text",
py_style_weight=(9, 2, 1),
shuffle_buffer_size=1000,
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
)
eval_dataloader = DataLoader(
eval_dataset,
batch_size=config["batch_size"],
num_workers=1,
pin_memory=torch.cuda.is_available(),
worker_init_fn=worker_init_fn,
collate_fn=collate_fn,
prefetch_factor=32,
persistent_workers=True,
)
print(f"✅ 数据加载器创建完成")
print(f" 训练批次大小: {config['batch_size']}")
print(f" 预估训练步数: {config['max_iter_length'] // config['batch_size']}")
# %% [code]
# 5. 创建模型
print("🧠 创建输入法模型...")
model = InputMethodEngine(
vocab_size=config["vocab_size"],
pinyin_vocab_size=config["pinyin_vocab_size"],
dim=config["dim"],
num_slots=config["num_slots"],
n_layers=config["n_layers"],
n_heads=config["n_heads"],
num_experts=config["num_experts"],
max_seq_len=config["max_seq_len"],
)
# 将模型移动到设备
model.to(device)
# 计算参数量
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"✅ 模型创建完成")
print(f" 总参数量: {total_params:,}")
print(f" 可训练参数量: {trainable_params:,}")
print(f" 模型架构: {config['n_layers']}层Transformer, {config['dim']}维度, {config['num_experts']}个MoE专家")
# %% [code]
# 6. 创建训练器
print("⚙️ 创建训练器...")
# 计算总训练步数
total_steps = int(config["max_iter_length"] / config["batch_size"])
trainer = Trainer(
model=model,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
total_steps=total_steps,
output_dir=config["output_dir"],
num_epochs=config["num_epochs"],
learning_rate=config["learning_rate"],
min_learning_rate=config["min_learning_rate"],
weight_decay=config["weight_decay"],
warmup_ratio=config["warmup_ratio"],
label_smoothing=config["label_smoothing"],
grad_accum_steps=config["grad_accum_steps"],
clip_grad_norm=config["clip_grad_norm"],
eval_frequency=config["eval_frequency"],
save_frequency=config["save_frequency"],
mixed_precision=config["mixed_precision"],
use_tensorboard=config["use_tensorboard"],
)
print(f"✅ 训练器创建完成")
print(f" 总训练步数: {total_steps:,}")
print(f" 学习率: {config['learning_rate']:.2e} -> {config['min_learning_rate']:.2e}")
print(f" 输出目录: {config['output_dir']}")
# %% [code]
# 7. 开始训练
print("🚀 开始训练...")
print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
try:
# 开始训练(可以从检查点恢复训练)
trainer.train(resume_from=None) # 设置检查点路径以恢复训练
print("✅ 训练完成!")
print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"模型和日志保存在: {config['output_dir']}")
except KeyboardInterrupt:
print("⏹️ 训练被用户中断")
print("💾 保存当前检查点...")
trainer.save_checkpoint("interrupted")
print(f"检查点已保存到: {config['output_dir']}/checkpoint_interrupted.pt")
except Exception as e:
print(f"❌ 训练过程中出现错误: {e}")
import traceback
traceback.print_exc()
# %% [code]
# 8. 监控训练进度如果使用TensorBoard
if config["use_tensorboard"]:
print("📈 TensorBoard日志已记录在:")
print(f" {config['output_dir']}/tensorboard")
print("\n启动TensorBoard查看训练进度:")
print(" tensorboard --logdir ./training_output/tensorboard")
print("然后在浏览器中打开: http://localhost:6006")
# %% [code]
# 9. 加载训练好的模型进行推理(示例)
def load_trained_model(checkpoint_path):
"""加载训练好的模型进行检查点"""
print(f"📥 加载检查点: {checkpoint_path}")
# 创建与训练时相同配置的模型
loaded_model = InputMethodEngine(
vocab_size=config["vocab_size"],
pinyin_vocab_size=config["pinyin_vocab_size"],
dim=config["dim"],
num_slots=config["num_slots"],
n_layers=config["n_layers"],
n_heads=config["n_heads"],
num_experts=config["num_experts"],
max_seq_len=config["max_seq_len"],
)
# 加载检查点
checkpoint = torch.load(checkpoint_path, map_location=device)
loaded_model.load_state_dict(checkpoint["model_state_dict"])
loaded_model.to(device)
loaded_model.eval()
print(f"✅ 模型加载完成,训练步数: {checkpoint.get('global_step', 'N/A')}")
print(f" 训练损失: {checkpoint.get('train_loss', 'N/A'):.4f}")
return loaded_model
# 使用示例(取消注释以使用)
# trained_model = load_trained_model("./training_output/checkpoint_final.pt")
```
### 关键说明
1. **环境要求**
- Python 3.12+
- PyTorch 2.10+
- 建议使用GPU进行训练
- 安装项目依赖:`pip install -e .`
2. **数据集格式**
- 使用Hugging Face `datasets`格式
- 必须包含`text`字段
- 支持流式读取streaming=True
3. **训练监控**
- 控制台输出训练进度和指标
- TensorBoard记录损失、准确率、学习率等
- 定期保存模型检查点
4. **可调整参数**
- `batch_size`: 根据GPU内存调整
- `learning_rate`: 建议在1e-4到5e-4之间
- `grad_accum_steps`: 模拟更大batch size
- `num_epochs`: 根据数据集大小调整
5. **故障排除**
- GPU内存不足减小`batch_size`或增加`grad_accum_steps`
- 训练不稳定:降低`learning_rate`或增加`warmup_ratio`
- 过拟合:增加`label_smoothing`或使用更大数据集
## 6. 使用指南
本项目的训练功能通过命令行工具 `train-model` 提供,支持训练、评估和导出模型。
### 6.1 安装与准备
#### 使用 uv推荐
本项目使用 [`uv`](https://github.com/astral-sh/uv) 作为Python包管理器它比传统的 pip 更快且更可靠。
1. **安装 uv**(如果尚未安装):
```bash
# Linux/macOS
curl -LsSf https://astral.sh/uv/install.sh | sh
# 或使用 pipx
pipx install uv
# Windows (PowerShell)
powershell -c "irm https://astral.sh/uv/install.ps1 | iex"
```
2. **安装项目依赖**
```bash
uv pip install -e .
```
#### 使用传统 pip
如果不使用 uv也可以用标准的 pip 安装:
```bash
# 创建并激活虚拟环境(推荐)
python -m venv .venv
source .venv/bin/activate # Linux/macOS
# .venv\Scripts\activate # Windows
# 安装依赖
pip install -e .
```
#### 验证安装
安装完成后,可通过以下命令验证:
```bash
train-model --help
```
### 6.2 数据格式
训练数据应为Hugging Face数据集格式支持本地文件或远程数据集仓库。数据集需包含 `text` 字段并支持流式读取streaming=True
#### 本地数据集示例
```python
# dataset.py
from datasets import Dataset
data = {
"text": ["这是第一个样本文本。", "这是第二个样本,用于训练输入法模型。"]
}
dataset = Dataset.from_dict(data)
dataset.save_to_disk("./local_dataset")
```
#### 远程数据集示例
支持Hugging Face Hub或ModelScope上的数据集
- `huggingface.co/datasets/username/dataset_name`
- `modelscope.cn/datasets/username/dataset_name`
#### 数据格式要求
- **必需字段**: `text`(字符串类型,包含中文文本)
- **流式读取**: 数据集必须支持 `streaming=True` 参数
- **数据量**: 建议至少数百万条文本以获得良好效果
#### 数据预处理
数据集会自动进行以下处理:
1. 文本分词和编码
2. 拼音转换和编码
3. 上下文窗口滑动生成训练样本
4. 频率调整(削峰填谷)以平衡高频/低频字词
### 6.3 基本训练命令
使用 `train-model train` 命令开始训练:
```bash
train-model train \
--train-data-path "path/to/train/dataset" \
--eval-data-path "path/to/eval/dataset" \
--output-dir "./output" \
--batch-size 128 \
--num-epochs 10 \
--learning-rate 1e-5
```
#### 学习率建议
根据模型架构和超参数配置4层Transformer512维度推荐使用以下学习率范围
- **标准范围**: 1e-4 ~ 5e-4
- **配合Warmup策略**:在训练初期逐步提高学习率
- **余弦退火**:使用最小学习率 1e-9 进行细调
### 6.4 参数详解
#### 数据参数
- `--train-data-path`, `-t`: 训练数据集路径(必需)
- `--eval-data-path`, `-e`: 评估数据集路径(必需)
- `--output-dir`, `-o`: 输出目录(默认:`./output`
- `--max_iter_length`: 最大迭代长度控制每次训练迭代处理的数据量默认134217728
#### 模型参数
- `--vocab-size`: 词汇表大小默认10019
- `--pinyin-vocab-size`: 拼音词汇表大小默认30
- `--dim`: 模型维度默认512
- `--num-slots`: 历史槽位数量默认8
- `--n-layers`: Transformer层数默认4
- `--n-heads`: 注意力头数默认4
- `--num-experts`: MoE专家数量默认20
- `--max-seq-len`: 最大序列长度默认128
- `--use-pinyin`: 是否使用拼音特征默认False
#### 训练参数
- `--batch-size`, `-b`: 批次大小默认128
- `--num-epochs`: 训练轮数默认10
- `--learning-rate`, `-lr`: 学习率默认1e-5
- `--min-learning-rate`: 最小学习率默认1e-9
- `--weight-decay`: 权重衰减默认0.1
- `--warmup-ratio`: 热身步数比例默认0.1
- `--label-smoothing`: 标签平滑参数默认0.15
- `--grad-accum-steps`: 梯度累积步数默认1
- `--clip-grad-norm`: 梯度裁剪范数默认1.0
- `--eval-frequency`: 评估频率默认500步
- `--save-frequency`: 保存频率默认10000步
#### 高级选项
- `--mixed-precision/--no-mixed-precision`: 是否使用混合精度训练(默认:启用)
- `--tensorboard/--no-tensorboard`: 是否使用TensorBoard默认启用
- `--resume-from`: 从检查点恢复训练(可选)
- `--seed`: 随机种子默认42
### 6.5 监控训练进度
训练过程中会显示:
- 当前训练步数/总步数
- 损失值和准确率
- 学习率变化
- 内存使用情况
启用TensorBoard后可使用以下命令查看可视化结果
```bash
tensorboard --logdir ./output/tensorboard
```
### 6.6 评估模型(开发中)
当前评估功能尚在开发中:
```bash
train-model evaluate \
--checkpoint "./output/checkpoint_final.pt" \
--data-path "path/to/eval/dataset" \
--batch-size 32
```
命令将显示"评估功能待实现"的提示信息。该功能计划用于:
- 加载训练好的模型检查点
- 在评估数据集上计算准确率、困惑度等指标
- 生成详细的性能报告
### 6.7 导出模型(开发中)
当前导出功能尚在开发中:
```bash
train-model export \
--checkpoint "./output/checkpoint_final.pt" \
--output "./exported_model.onnx"
```
命令将显示"导出功能待实现"的提示信息。该功能计划用于:
- 将PyTorch模型转换为ONNX格式
- 支持在不同推理引擎上部署
- 提供优化后的推理模型
## 7. 总结
本方案通过**单流 Transformer 编码**结合**结构化槽位交叉注意力**,并引入**20个专家的 MoE 模块** [1]在保证模型轻量4层 Transformer的同时有效利用了历史输入习惯并提升了模型表达上限。相比暴力拼接或双流架构该设计在工程实现上更优雅在推理效率上更高效是轻量级输入法模型的局部最优解。