|
|
||
|---|---|---|
| src/model | ||
| .gitignore | ||
| .python-version | ||
| LICENSE | ||
| README.md | ||
| pyproject.toml | ||
| resign_stat.py | ||
| test.py | ||
| test_trainer.py | ||
| uv.lock | ||
README.md
输入法预测模型架构设计 (Input Method Prediction Model)
1. 概述
本项目旨在构建一个轻量级、高精度的中文输入法预测模型。核心设计理念是通过结构化槽位记忆与交叉注意力机制,将当前语境(光标前后文本+拼音)与历史输入习惯深度融合。为了在有限的计算资源下保持高表达能力,模型引入了混合专家网络 (MoE) 模块。
2. 核心架构流程
数据流遵循以下路径:
输入编码 → Transformer 上下文编码 → 槽位记忆嵌入 → 交叉注意力融合 → 门控+专家混合 (MoE) → 分类预测 → 束搜索解码
2.1 输入层设计
模型接收三类输入,分别处理以保持语义清晰:
- 当前文本上下文:包含光标前文本(Prefix)和光标后文本(Suffix)。
- 拼音序列:与当前文本对应的拼音信息,作为增强特征融入文本编码。
- 历史槽位序列:最近 N 个历史输入词汇,作为结构化记忆输入。
2.2 模块详解
A. Transformer 编码器 (Context Encoder)
负责提取当前语境的深层语义表示。
- 输入处理:将 Prefix、Suffix 及拼音通过 Embedding 层映射。拼音采用特征叠加或独立 Token 方式融入,避免双流架构的复杂性。
- 骨干网络:使用标准的 Transformer Encoder。
- 隐藏层维度:512 [1]
- Transformer 层数:4 层(轻量级设计,从头训练) [1]
- 注意力头数:4 头 [1]
- 输出:上下文表示
H,形状为[batch, L, 512][1]。
B. 槽位记忆模块 (Slot Memory)
负责将非结构化的历史输入转化为结构化的记忆向量。
- 嵌入方式:历史词汇通过独立的
Slot Embedding查找表映射。 - 位置编码:添加可学习的
Positional Embedding以保留历史输入的时间顺序信息。 - 输出:槽位序列
S,形状为[batch, Num_Slots, 512]。
C. 交叉注意力融合 (Cross-Attention Fusion)
这是模型的核心创新点,用于动态关联"历史记忆"与"当前语境"。
- Query (Q):当前步的槽位序列
S(经过位置编码后)。 - Key/Value (K/V):Transformer 编码器输出的上下文表示
H[1]。 - 机制:让历史槽位主动关注当前文本语境,捕捉如"在'班级第一名'语境下,'王次香'比'王慈祥'更相关"的逻辑。
- 输出:融合后的特征序列,形状为
[batch, Num_Slots, 512]。
D. 门控与专家混合 (Gating + MoE)
实际测试表明,移除 MoE 会导致模型性能显著下降,因此该模块对于捕捉复杂分布至关重要。
- 专家数量:20 个专家 [1]。
- 门控机制:根据输入特征动态选择激活部分专家,实现稀疏激活,在增加模型容量的同时控制计算成本。
- 输出:经过专家网络增强后的特征向量。
E. 分类头与解码
- 分类预测:MoE 输出的特征向量通过全连接层映射到词表空间,输出下一个字/词的概率分布。
- 解码策略:推理阶段使用束搜索 (Beam Search),束宽设为 5 [1]。
3. 关键超参数配置
为确保模型性能与效率的平衡,建议采用以下超参数 [1]:
| 参数项 | 推荐值 | 说明 |
|---|---|---|
| 序列长度 (L) | 128 | 上下文窗口大小 [1] |
| 隐藏层维度 | 512 | Embedding 及 Transformer 内部维度 [1] |
| Transformer 层数 | 4 | 轻量级骨干,降低延迟 [1] |
| 注意力头数 | 4 | 适配 512 维度的高效配置 [1] |
| 专家数量 | 20 | MoE 层中的专家总数,对性能至关重要 [1] |
| 束宽 (Beam Width) | 5 | 推理时平衡速度与准确率 [1] |
| 学习率 | 1e-4 ~ 5e-4 | 建议配合 Warmup 策略 [1] |
4. 训练策略
本模型采用标准的序列到序列(Seq2Seq)监督学习范式,直接对目标槽位序列进行逐步预测。
4.1 数据构造与标签
- 输入三元组:训练数据由
(上下文, 拼音, 目标槽位序列)构成 [1]。- 上下文:光标前后的文本片段。
- 拼音:当前待输入字的拼音序列。
- 目标槽位序列:真实用户输入的文字 ID 序列,作为模型的监督信号 [1]。
- 标签处理:在每一个槽位步(Step),模型需要预测该步对应的真实文字 ID [1]。
4.2 损失函数与优化
- 损失函数:使用 CrossEntropyLoss 计算每一步预测结果与真实标签之间的差异 [1]。
- 掩码机制:仅计算非填充位置(Non-padding positions)的损失,忽略无效的时间步 [1]。
- 优化器:采用 AdamW 进行参数更新 [1]。
4.3 训练流程细节
- 前向传播:
- 模型接收上下文和拼音,通过 Transformer 编码得到语境表示。
- 结合历史槽位记忆,通过交叉注意力和 MoE 模块融合特征。
- 分类头输出当前步所有候选字的概率分布。
- Teacher Forcing:
- 在训练过程中,强制使用真实的上一槽位输出作为下一步的输入条件。这意味着模型在训练时始终基于"正确的历史"进行预测,从而快速收敛。
- 反向传播:
- 根据 CrossEntropyLoss [1] 计算梯度,并通过 AdamW [1] 更新模型权重。
4.4 推理与训练的差异
- 训练时:使用 Ground Truth(真实标签)作为槽位输入,确保模型学习到最优的条件概率分布。
- 推理时:由于无法获取真实标签,模型采用束搜索(Beam Search) [1]。
- 束宽:默认为 5 [1]。
- 候选维护:每个候选路径独立维护其历史槽位序列及累计概率 [1]。
- 终止条件:当所有槽位填满(如 8×3=24 步)或所有候选分支的最高概率词均为终止符时退出 [1]。
5. Jupyter Lab 训练示例
以下是在 Jupyter Lab 环境中使用 trainer.Trainer 类训练输入法模型的完整示例:
# %% [markdown]
# # 输入法模型训练示例
# 本笔记本展示如何使用 trainer.Trainer 类训练输入法模型
# %% [code]
# 1. 导入必要的库
import sys
import os
from pathlib import Path
from datetime import datetime
import torch
from torch.utils.data import DataLoader
# 添加项目路径(适应不同的Jupyter Lab运行位置)
project_root = Path.cwd()
# 检查当前目录是否包含src目录,如果不包含则使用父目录
if not (project_root / "src").exists():
project_root = project_root.parent
sys.path.insert(0, str(project_root)) # 优先搜索项目目录
# 导入项目模块
from src.model.model import InputMethodEngine
from src.model.dataset import PinyinInputDataset
from src.model.trainer import Trainer, worker_init_fn, collate_fn
# %% [code]
# 2. 配置训练参数
config = {
# 数据参数
"train_data_path": "/path/to/your/train/dataset", # 替换为训练数据集路径
"eval_data_path": "/path/to/your/eval/dataset", # 替换为评估数据集路径
"output_dir": "./training_output",
# 模型参数
"vocab_size": 10019,
"pinyin_vocab_size": 30,
"dim": 512,
"num_slots": 8,
"n_layers": 4,
"n_heads": 4,
"num_experts": 20,
"max_seq_len": 128,
# 训练参数
"batch_size": 64, # 根据GPU内存调整
"num_epochs": 10,
"learning_rate": 3e-4,
"min_learning_rate": 1e-9,
"weight_decay": 0.1,
"warmup_ratio": 0.1,
"label_smoothing": 0.15,
"grad_accum_steps": 2, # 梯度累积,模拟更大batch size
"clip_grad_norm": 1.0,
"eval_frequency": 500, # 每500步评估一次
"save_frequency": 2000, # 每2000步保存检查点
# 高级选项
"mixed_precision": True,
"use_tensorboard": True,
"seed": 42,
"max_iter_length": 1024 * 1024 * 128, # 最大迭代长度
}
# %% [code]
# 3. 设置随机种子和设备
torch.manual_seed(config["seed"])
if torch.cuda.is_available():
torch.cuda.manual_seed_all(config["seed"])
device = torch.device("cuda")
print(f"✅ 使用 GPU: {torch.cuda.get_device_name(0)}")
else:
device = torch.device("cpu")
print("⚠️ 使用 CPU 进行训练(建议使用 GPU 以获得更好性能)")
# %% [code]
# 4. 创建数据集和数据加载器
print("📊 创建数据集和数据加载器...")
# 训练数据集
train_dataset = PinyinInputDataset(
data_path=config["train_data_path"],
max_workers=-1, # 自动选择worker数量
max_iter_length=config["max_iter_length"],
max_seq_length=config["max_seq_len"],
text_field="text",
py_style_weight=(9, 2, 1),
shuffle_buffer_size=5000,
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
)
# 训练数据加载器
train_dataloader = DataLoader(
train_dataset,
batch_size=config["batch_size"],
num_workers=min(max(1, (os.cpu_count() or 1) - 1), 8), # 合理数量的worker
pin_memory=torch.cuda.is_available(),
worker_init_fn=worker_init_fn,
collate_fn=collate_fn,
prefetch_factor=32,
persistent_workers=True,
)
# 评估数据集
eval_dataset = PinyinInputDataset(
data_path=config["eval_data_path"],
max_workers=-1,
max_iter_length=1024, # 评估集较小
max_seq_length=config["max_seq_len"],
text_field="text",
py_style_weight=(9, 2, 1),
shuffle_buffer_size=1000,
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
)
eval_dataloader = DataLoader(
eval_dataset,
batch_size=config["batch_size"],
num_workers=1,
pin_memory=torch.cuda.is_available(),
worker_init_fn=worker_init_fn,
collate_fn=collate_fn,
prefetch_factor=32,
persistent_workers=True,
)
print(f"✅ 数据加载器创建完成")
print(f" 训练批次大小: {config['batch_size']}")
print(f" 预估训练步数: {config['max_iter_length'] // config['batch_size']}")
# %% [code]
# 5. 创建模型
print("🧠 创建输入法模型...")
model = InputMethodEngine(
vocab_size=config["vocab_size"],
pinyin_vocab_size=config["pinyin_vocab_size"],
dim=config["dim"],
num_slots=config["num_slots"],
n_layers=config["n_layers"],
n_heads=config["n_heads"],
num_experts=config["num_experts"],
max_seq_len=config["max_seq_len"],
)
# 将模型移动到设备
model.to(device)
# 计算参数量
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"✅ 模型创建完成")
print(f" 总参数量: {total_params:,}")
print(f" 可训练参数量: {trainable_params:,}")
print(f" 模型架构: {config['n_layers']}层Transformer, {config['dim']}维度, {config['num_experts']}个MoE专家")
# %% [code]
# 6. 创建训练器
print("⚙️ 创建训练器...")
# 计算总训练步数
total_steps = int(config["max_iter_length"] / config["batch_size"])
trainer = Trainer(
model=model,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
total_steps=total_steps,
output_dir=config["output_dir"],
num_epochs=config["num_epochs"],
learning_rate=config["learning_rate"],
min_learning_rate=config["min_learning_rate"],
weight_decay=config["weight_decay"],
warmup_ratio=config["warmup_ratio"],
label_smoothing=config["label_smoothing"],
grad_accum_steps=config["grad_accum_steps"],
clip_grad_norm=config["clip_grad_norm"],
eval_frequency=config["eval_frequency"],
save_frequency=config["save_frequency"],
mixed_precision=config["mixed_precision"],
use_tensorboard=config["use_tensorboard"],
)
print(f"✅ 训练器创建完成")
print(f" 总训练步数: {total_steps:,}")
print(f" 学习率: {config['learning_rate']:.2e} -> {config['min_learning_rate']:.2e}")
print(f" 输出目录: {config['output_dir']}")
# %% [code]
# 7. 开始训练
print("🚀 开始训练...")
print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
try:
# 开始训练(可以从检查点恢复训练)
trainer.train(resume_from=None) # 设置检查点路径以恢复训练
print("✅ 训练完成!")
print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"模型和日志保存在: {config['output_dir']}")
except KeyboardInterrupt:
print("⏹️ 训练被用户中断")
print("💾 保存当前检查点...")
trainer.save_checkpoint("interrupted")
print(f"检查点已保存到: {config['output_dir']}/checkpoint_interrupted.pt")
except Exception as e:
print(f"❌ 训练过程中出现错误: {e}")
import traceback
traceback.print_exc()
# %% [code]
# 8. 监控训练进度(如果使用TensorBoard)
if config["use_tensorboard"]:
print("📈 TensorBoard日志已记录在:")
print(f" {config['output_dir']}/tensorboard")
print("\n启动TensorBoard查看训练进度:")
print(" tensorboard --logdir ./training_output/tensorboard")
print("然后在浏览器中打开: http://localhost:6006")
# %% [code]
# 9. 加载训练好的模型进行推理(示例)
def load_trained_model(checkpoint_path):
"""加载训练好的模型进行检查点"""
print(f"📥 加载检查点: {checkpoint_path}")
# 创建与训练时相同配置的模型
loaded_model = InputMethodEngine(
vocab_size=config["vocab_size"],
pinyin_vocab_size=config["pinyin_vocab_size"],
dim=config["dim"],
num_slots=config["num_slots"],
n_layers=config["n_layers"],
n_heads=config["n_heads"],
num_experts=config["num_experts"],
max_seq_len=config["max_seq_len"],
)
# 加载检查点
checkpoint = torch.load(checkpoint_path, map_location=device)
loaded_model.load_state_dict(checkpoint["model_state_dict"])
loaded_model.to(device)
loaded_model.eval()
print(f"✅ 模型加载完成,训练步数: {checkpoint.get('global_step', 'N/A')}")
print(f" 训练损失: {checkpoint.get('train_loss', 'N/A'):.4f}")
return loaded_model
# 使用示例(取消注释以使用)
# trained_model = load_trained_model("./training_output/checkpoint_final.pt")
关键说明
-
环境要求:
- Python 3.12+
- PyTorch 2.10+
- 建议使用GPU进行训练
- 安装项目依赖:
pip install -e .
-
数据集格式:
- 使用Hugging Face
datasets格式 - 必须包含
text字段 - 支持流式读取(streaming=True)
- 使用Hugging Face
-
训练监控:
- 控制台输出训练进度和指标
- TensorBoard记录损失、准确率、学习率等
- 定期保存模型检查点
-
可调整参数:
batch_size: 根据GPU内存调整learning_rate: 建议在1e-4到5e-4之间grad_accum_steps: 模拟更大batch sizenum_epochs: 根据数据集大小调整
-
故障排除:
- GPU内存不足:减小
batch_size或增加grad_accum_steps - 训练不稳定:降低
learning_rate或增加warmup_ratio - 过拟合:增加
label_smoothing或使用更大数据集
- GPU内存不足:减小
6. 使用指南
本项目的训练功能通过命令行工具 train-model 提供,支持训练、评估和导出模型。
6.1 安装与准备
使用 uv(推荐)
本项目使用 uv 作为Python包管理器,它比传统的 pip 更快且更可靠。
-
安装 uv(如果尚未安装):
# Linux/macOS curl -LsSf https://astral.sh/uv/install.sh | sh # 或使用 pipx pipx install uv # Windows (PowerShell) powershell -c "irm https://astral.sh/uv/install.ps1 | iex" -
安装项目依赖:
uv pip install -e .
使用传统 pip
如果不使用 uv,也可以用标准的 pip 安装:
# 创建并激活虚拟环境(推荐)
python -m venv .venv
source .venv/bin/activate # Linux/macOS
# .venv\Scripts\activate # Windows
# 安装依赖
pip install -e .
验证安装
安装完成后,可通过以下命令验证:
train-model --help
6.2 数据格式
训练数据应为Hugging Face数据集格式,支持本地文件或远程数据集仓库。数据集需包含 text 字段,并支持流式读取(streaming=True)。
本地数据集示例
# dataset.py
from datasets import Dataset
data = {
"text": ["这是第一个样本文本。", "这是第二个样本,用于训练输入法模型。"]
}
dataset = Dataset.from_dict(data)
dataset.save_to_disk("./local_dataset")
远程数据集示例
支持Hugging Face Hub或ModelScope上的数据集:
huggingface.co/datasets/username/dataset_namemodelscope.cn/datasets/username/dataset_name
数据格式要求
- 必需字段:
text(字符串类型,包含中文文本) - 流式读取: 数据集必须支持
streaming=True参数 - 数据量: 建议至少数百万条文本以获得良好效果
数据预处理
数据集会自动进行以下处理:
- 文本分词和编码
- 拼音转换和编码
- 上下文窗口滑动生成训练样本
- 频率调整(削峰填谷)以平衡高频/低频字词
6.3 基本训练命令
使用 train-model train 命令开始训练:
train-model train \
--train-data-path "path/to/train/dataset" \
--eval-data-path "path/to/eval/dataset" \
--output-dir "./output" \
--batch-size 128 \
--num-epochs 10 \
--learning-rate 1e-5
学习率建议
根据模型架构和超参数配置(4层Transformer,512维度),推荐使用以下学习率范围:
- 标准范围: 1e-4 ~ 5e-4
- 配合Warmup策略:在训练初期逐步提高学习率
- 余弦退火:使用最小学习率 1e-9 进行细调
6.4 参数详解
数据参数
--train-data-path,-t: 训练数据集路径(必需)--eval-data-path,-e: 评估数据集路径(必需)--output-dir,-o: 输出目录(默认:./output)--max_iter_length: 最大迭代长度,控制每次训练迭代处理的数据量(默认:134217728)
模型参数
--vocab-size: 词汇表大小(默认:10019)--pinyin-vocab-size: 拼音词汇表大小(默认:30)--dim: 模型维度(默认:512)--num-slots: 历史槽位数量(默认:8)--n-layers: Transformer层数(默认:4)--n-heads: 注意力头数(默认:4)--num-experts: MoE专家数量(默认:20)--max-seq-len: 最大序列长度(默认:128)--use-pinyin: 是否使用拼音特征(默认:False)
训练参数
--batch-size,-b: 批次大小(默认:128)--num-epochs: 训练轮数(默认:10)--learning-rate,-lr: 学习率(默认:1e-5)--min-learning-rate: 最小学习率(默认:1e-9)--weight-decay: 权重衰减(默认:0.1)--warmup-ratio: 热身步数比例(默认:0.1)--label-smoothing: 标签平滑参数(默认:0.15)--grad-accum-steps: 梯度累积步数(默认:1)--clip-grad-norm: 梯度裁剪范数(默认:1.0)--eval-frequency: 评估频率(默认:500步)--save-frequency: 保存频率(默认:10000步)
高级选项
--mixed-precision/--no-mixed-precision: 是否使用混合精度训练(默认:启用)--tensorboard/--no-tensorboard: 是否使用TensorBoard(默认:启用)--resume-from: 从检查点恢复训练(可选)--seed: 随机种子(默认:42)
6.5 监控训练进度
训练过程中会显示:
- 当前训练步数/总步数
- 损失值和准确率
- 学习率变化
- 内存使用情况
启用TensorBoard后,可使用以下命令查看可视化结果:
tensorboard --logdir ./output/tensorboard
6.6 评估模型(开发中)
当前评估功能尚在开发中:
train-model evaluate \
--checkpoint "./output/checkpoint_final.pt" \
--data-path "path/to/eval/dataset" \
--batch-size 32
命令将显示"评估功能待实现"的提示信息。该功能计划用于:
- 加载训练好的模型检查点
- 在评估数据集上计算准确率、困惑度等指标
- 生成详细的性能报告
6.7 导出模型(开发中)
当前导出功能尚在开发中:
train-model export \
--checkpoint "./output/checkpoint_final.pt" \
--output "./exported_model.onnx"
命令将显示"导出功能待实现"的提示信息。该功能计划用于:
- 将PyTorch模型转换为ONNX格式
- 支持在不同推理引擎上部署
- 提供优化后的推理模型
7. 总结
本方案通过单流 Transformer 编码结合结构化槽位交叉注意力,并引入20个专家的 MoE 模块 [1],在保证模型轻量(4层 Transformer)的同时,有效利用了历史输入习惯并提升了模型表达上限。相比暴力拼接或双流架构,该设计在工程实现上更优雅,在推理效率上更高效,是轻量级输入法模型的局部最优解。