|
|
||
|---|---|---|
| src/model | ||
| .gitignore | ||
| .python-version | ||
| LICENSE | ||
| README.md | ||
| pyproject.toml | ||
| resign_stat.py | ||
| test.py | ||
| uv.lock | ||
README.md
输入法预测模型架构设计 (Input Method Prediction Model)
1. 概述
本项目旨在构建一个轻量级、高精度的中文输入法预测模型。核心设计理念是通过结构化槽位记忆与交叉注意力机制,将当前语境(光标前后文本+拼音)与历史输入习惯深度融合。为了在有限的计算资源下保持高表达能力,模型引入了混合专家网络 (MoE) 模块。
2. 核心架构流程
数据流遵循以下路径:
输入编码 → Transformer 上下文编码 → 槽位记忆嵌入 → 交叉注意力融合 → 门控+专家混合 (MoE) → 分类预测 → 束搜索解码
2.1 输入层设计
模型接收三类输入,分别处理以保持语义清晰:
- 当前文本上下文:包含光标前文本(Prefix)和光标后文本(Suffix)。
- 拼音序列:与当前文本对应的拼音信息,作为增强特征融入文本编码。
- 历史槽位序列:最近 N 个历史输入词汇,作为结构化记忆输入。
2.2 模块详解
A. Transformer 编码器 (Context Encoder)
负责提取当前语境的深层语义表示。
- 输入处理:将 Prefix、Suffix 及拼音通过 Embedding 层映射。拼音采用特征叠加或独立 Token 方式融入,避免双流架构的复杂性。
- 骨干网络:使用标准的 Transformer Encoder。
- 隐藏层维度:512 [1]
- Transformer 层数:4 层(轻量级设计,从头训练) [1]
- 注意力头数:4 头 [1]
- 输出:上下文表示
H,形状为[batch, L, 512][1]。
B. 槽位记忆模块 (Slot Memory)
负责将非结构化的历史输入转化为结构化的记忆向量。
- 嵌入方式:历史词汇通过独立的
Slot Embedding查找表映射。 - 位置编码:添加可学习的
Positional Embedding以保留历史输入的时间顺序信息。 - 输出:槽位序列
S,形状为[batch, Num_Slots, 512]。
C. 交叉注意力融合 (Cross-Attention Fusion)
这是模型的核心创新点,用于动态关联“历史记忆”与“当前语境”。
- Query (Q):当前步的槽位序列
S(经过位置编码后)。 - Key/Value (K/V):Transformer 编码器输出的上下文表示
H[1]。 - 机制:让历史槽位主动关注当前文本语境,捕捉如“在‘班级第一名’语境下,‘王次香’比‘王慈祥’更相关”的逻辑。
- 输出:融合后的特征序列,形状为
[batch, Num_Slots, 512]。
D. 门控与专家混合 (Gating + MoE)
实际测试表明,移除 MoE 会导致模型性能显著下降,因此该模块对于捕捉复杂分布至关重要。
- 专家数量:20 个专家 [1]。
- 门控机制:根据输入特征动态选择激活部分专家,实现稀疏激活,在增加模型容量的同时控制计算成本。
- 输出:经过专家网络增强后的特征向量。
E. 分类头与解码
- 分类预测:MoE 输出的特征向量通过全连接层映射到词表空间,输出下一个字/词的概率分布。
- 解码策略:推理阶段使用束搜索 (Beam Search),束宽设为 5 [1]。
3. 关键超参数配置
为确保模型性能与效率的平衡,建议采用以下超参数 [1]:
| 参数项 | 推荐值 | 说明 |
|---|---|---|
| 序列长度 (L) | 128 | 上下文窗口大小 [1] |
| 隐藏层维度 | 512 | Embedding 及 Transformer 内部维度 [1] |
| Transformer 层数 | 4 | 轻量级骨干,降低延迟 [1] |
| 注意力头数 | 4 | 适配 512 维度的高效配置 [1] |
| 专家数量 | 20 | MoE 层中的专家总数,对性能至关重要 [1] |
| 束宽 (Beam Width) | 5 | 推理时平衡速度与准确率 [1] |
| 学习率 | 1e-4 ~ 5e-4 | 建议配合 Warmup 策略 [1] |
4. 训练策略
本模型采用标准的序列到序列(Seq2Seq)监督学习范式,直接对目标槽位序列进行逐步预测。
4.1 数据构造与标签
- 输入三元组:训练数据由
(上下文, 拼音, 目标槽位序列)构成 [1]。- 上下文:光标前后的文本片段。
- 拼音:当前待输入字的拼音序列。
- 目标槽位序列:真实用户输入的文字 ID 序列,作为模型的监督信号 [1]。
- 标签处理:在每一个槽位步(Step),模型需要预测该步对应的真实文字 ID [1]。
4.2 损失函数与优化
- 损失函数:使用 CrossEntropyLoss 计算每一步预测结果与真实标签之间的差异 [1]。
- 掩码机制:仅计算非填充位置(Non-padding positions)的损失,忽略无效的时间步 [1]。
- 优化器:采用 AdamW 进行参数更新 [1]。
4.3 训练流程细节
- 前向传播:
- 模型接收上下文和拼音,通过 Transformer 编码得到语境表示。
- 结合历史槽位记忆,通过交叉注意力和 MoE 模块融合特征。
- 分类头输出当前步所有候选字的概率分布。
- Teacher Forcing:
- 在训练过程中,强制使用真实的上一槽位输出作为下一步的输入条件。这意味着模型在训练时始终基于“正确的历史”进行预测,从而快速收敛。
- 反向传播:
- 根据 CrossEntropyLoss [1] 计算梯度,并通过 AdamW [1] 更新模型权重。
4.4 推理与训练的差异
- 训练时:使用 Ground Truth(真实标签)作为槽位输入,确保模型学习到最优的条件概率分布。
- 推理时:由于无法获取真实标签,模型采用束搜索(Beam Search) [1]。
- 束宽:默认为 5 [1]。
- 候选维护:每个候选路径独立维护其历史槽位序列及累计概率 [1]。
- 终止条件:当所有槽位填满(如 8×3=24 步)或所有候选分支的最高概率词均为终止符时退出 [1]。
5. 代码实现示意 (PyTorch)
import torch
import torch.nn as nn
class Expert(nn.Module):
def __init__(self, dim=512):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, dim)
)
def forward(self, x):
return self.net(x)
class InputMethodModel(nn.Module):
def __init__(self, vocab_size, pinyin_vocab_size, slot_vocab_size, dim=512, n_layers=4, n_heads=4, num_experts=20):
super().__init__()
# 1. Context Encoder
self.text_emb = nn.Embedding(vocab_size, dim)
self.pinyin_emb = nn.Embedding(pinyin_vocab_size, dim)
self.pos_emb = nn.Embedding(128, dim)
encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=n_heads)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
# 2. Slot Memory
self.slot_emb = nn.Embedding(slot_vocab_size, dim)
self.slot_pos_emb = nn.Embedding(5, dim) # 假设保留5个历史槽位
# 3. Cross-Attention
self.cross_attn = nn.MultiheadAttention(embed_dim=dim, num_heads=n_heads, batch_first=True)
# 4. MoE Layer
self.num_experts = num_experts
self.experts = nn.ModuleList([Expert(dim) for _ in range(num_experts)])
self.gate = nn.Linear(dim, num_experts)
# 5. Classification Head
self.classifier = nn.Linear(dim, vocab_size)
def forward(self, text_ids, pinyin_ids, history_slot_ids):
# Encode Context
x = self.text_emb(text_ids) + self.pinyin_emb(pinyin_ids)
x += self.pos_emb(torch.arange(x.size(1)).to(x.device))
H = self.transformer(x) # [B, L, 512]
# Encode Slots
S = self.slot_emb(history_slot_ids)
S += self.slot_pos_emb(torch.arange(S.size(1)).to(S.device))
# Cross-Attention: Q=Slots, K/V=Context
fused, _ = self.cross_attn(S, H, H) # [B, Slots, 512]
# MoE Processing
# 简化版 MoE: 对所有专家输出进行加权平均
gate_scores = torch.softmax(self.gate(fused), dim=-1) # [B, Slots, Num_Experts]
expert_outputs = torch.stack([expert(fused) for expert in self.experts], dim=-2) # [B, Slots, Num_Experts, Dim]
moe_out = torch.sum(gate_scores.unsqueeze(-1) * expert_outputs, dim=-2) # [B, Slots, Dim]
# Pooling & Predict
pooled = moe_out.mean(dim=1) # [B, 512]
logits = self.classifier(pooled)
return logits
6. 总结
本方案通过单流 Transformer 编码结合结构化槽位交叉注意力,并引入20个专家的 MoE 模块 [1],在保证模型轻量(4层 Transformer)的同时,有效利用了历史输入习惯并提升了模型表达上限。相比暴力拼接或双流架构,该设计在工程实现上更优雅,在推理效率上更高效,是轻量级输入法模型的局部最优解。