167 lines
8.9 KiB
Markdown
167 lines
8.9 KiB
Markdown
# 输入法预测模型架构设计 (Input Method Prediction Model)
|
||
|
||
## 1. 概述
|
||
本项目旨在构建一个轻量级、高精度的中文输入法预测模型。核心设计理念是通过**结构化槽位记忆**与**交叉注意力机制**,将当前语境(光标前后文本+拼音)与历史输入习惯深度融合。为了在有限的计算资源下保持高表达能力,模型引入了**混合专家网络 (MoE)** 模块。
|
||
|
||
## 2. 核心架构流程
|
||
数据流遵循以下路径:
|
||
`输入编码` → `Transformer 上下文编码` → `槽位记忆嵌入` → `交叉注意力融合` → `门控+专家混合 (MoE)` → `分类预测` → `束搜索解码`
|
||
|
||
### 2.1 输入层设计
|
||
模型接收三类输入,分别处理以保持语义清晰:
|
||
1. **当前文本上下文**:包含光标前文本(Prefix)和光标后文本(Suffix)。
|
||
2. **拼音序列**:与当前文本对应的拼音信息,作为增强特征融入文本编码。
|
||
3. **历史槽位序列**:最近 N 个历史输入词汇,作为结构化记忆输入。
|
||
|
||
### 2.2 模块详解
|
||
|
||
#### A. Transformer 编码器 (Context Encoder)
|
||
负责提取当前语境的深层语义表示。
|
||
* **输入处理**:将 Prefix、Suffix 及拼音通过 Embedding 层映射。拼音采用**特征叠加**或**独立 Token** 方式融入,避免双流架构的复杂性。
|
||
* **骨干网络**:使用标准的 Transformer Encoder。
|
||
* **隐藏层维度**:512 [1]
|
||
* **Transformer 层数**:4 层(轻量级设计,从头训练) [1]
|
||
* **注意力头数**:4 头 [1]
|
||
* **输出**:上下文表示 $H$,形状为 `[batch, L, 512]` [1]。
|
||
|
||
#### B. 槽位记忆模块 (Slot Memory)
|
||
负责将非结构化的历史输入转化为结构化的记忆向量。
|
||
* **嵌入方式**:历史词汇通过独立的 `Slot Embedding` 查找表映射。
|
||
* **位置编码**:添加可学习的 `Positional Embedding` 以保留历史输入的时间顺序信息。
|
||
* **输出**:槽位序列 $S$,形状为 `[batch, Num_Slots, 512]`。
|
||
|
||
#### C. 交叉注意力融合 (Cross-Attention Fusion)
|
||
这是模型的核心创新点,用于动态关联“历史记忆”与“当前语境”。
|
||
* **Query (Q)**:当前步的槽位序列 $S$(经过位置编码后)。
|
||
* **Key/Value (K/V)**:Transformer 编码器输出的上下文表示 $H$ [1]。
|
||
* **机制**:让历史槽位主动关注当前文本语境,捕捉如“在‘班级第一名’语境下,‘王次香’比‘王慈祥’更相关”的逻辑。
|
||
* **输出**:融合后的特征序列,形状为 `[batch, Num_Slots, 512]`。
|
||
|
||
#### D. 门控与专家混合 (Gating + MoE)
|
||
实际测试表明,移除 MoE 会导致模型性能显著下降,因此该模块对于捕捉复杂分布至关重要。
|
||
* **专家数量**:20 个专家 [1]。
|
||
* **门控机制**:根据输入特征动态选择激活部分专家,实现稀疏激活,在增加模型容量的同时控制计算成本。
|
||
* **输出**:经过专家网络增强后的特征向量。
|
||
|
||
#### E. 分类头与解码
|
||
* **分类预测**:MoE 输出的特征向量通过全连接层映射到词表空间,输出下一个字/词的概率分布。
|
||
* **解码策略**:推理阶段使用**束搜索 (Beam Search)**,束宽设为 5 [1]。
|
||
|
||
## 3. 关键超参数配置
|
||
|
||
为确保模型性能与效率的平衡,建议采用以下超参数 [1]:
|
||
|
||
| 参数项 | 推荐值 | 说明 |
|
||
| :--- | :--- | :--- |
|
||
| **序列长度 (L)** | 128 | 上下文窗口大小 [1] |
|
||
| **隐藏层维度** | 512 | Embedding 及 Transformer 内部维度 [1] |
|
||
| **Transformer 层数** | 4 | 轻量级骨干,降低延迟 [1] |
|
||
| **注意力头数** | 4 | 适配 512 维度的高效配置 [1] |
|
||
| **专家数量** | 20 | MoE 层中的专家总数,对性能至关重要 [1] |
|
||
| **束宽 (Beam Width)** | 5 | 推理时平衡速度与准确率 [1] |
|
||
| **学习率** | 1e-4 ~ 5e-4 | 建议配合 Warmup 策略 [1] |
|
||
|
||
## 4. 训练策略
|
||
|
||
本模型采用标准的**序列到序列(Seq2Seq)监督学习**范式,直接对目标槽位序列进行逐步预测。
|
||
|
||
### 4.1 数据构造与标签
|
||
* **输入三元组**:训练数据由 `(上下文, 拼音, 目标槽位序列)` 构成 [1]。
|
||
* **上下文**:光标前后的文本片段。
|
||
* **拼音**:当前待输入字的拼音序列。
|
||
* **目标槽位序列**:真实用户输入的文字 ID 序列,作为模型的监督信号 [1]。
|
||
* **标签处理**:在每一个槽位步(Step),模型需要预测该步对应的真实文字 ID [1]。
|
||
|
||
### 4.2 损失函数与优化
|
||
* **损失函数**:使用 **CrossEntropyLoss** 计算每一步预测结果与真实标签之间的差异 [1]。
|
||
* **掩码机制**:仅计算非填充位置(Non-padding positions)的损失,忽略无效的时间步 [1]。
|
||
* **优化器**:采用 **AdamW** 进行参数更新 [1]。
|
||
|
||
### 4.3 训练流程细节
|
||
1. **前向传播**:
|
||
* 模型接收上下文和拼音,通过 Transformer 编码得到语境表示。
|
||
* 结合历史槽位记忆,通过交叉注意力和 MoE 模块融合特征。
|
||
* 分类头输出当前步所有候选字的概率分布。
|
||
2. **Teacher Forcing**:
|
||
* 在训练过程中,**强制使用真实的上一槽位输出**作为下一步的输入条件。这意味着模型在训练时始终基于“正确的历史”进行预测,从而快速收敛。
|
||
3. **反向传播**:
|
||
* 根据 CrossEntropyLoss [1] 计算梯度,并通过 AdamW [1] 更新模型权重。
|
||
|
||
### 4.4 推理与训练的差异
|
||
* **训练时**:使用 Ground Truth(真实标签)作为槽位输入,确保模型学习到最优的条件概率分布。
|
||
* **推理时**:由于无法获取真实标签,模型采用**束搜索(Beam Search)** [1]。
|
||
* **束宽**:默认为 5 [1]。
|
||
* **候选维护**:每个候选路径独立维护其历史槽位序列及累计概率 [1]。
|
||
* **终止条件**:当所有槽位填满(如 8×3=24 步)或所有候选分支的最高概率词均为终止符时退出 [1]。
|
||
|
||
|
||
## 5. 代码实现示意 (PyTorch)
|
||
|
||
```python
|
||
import torch
|
||
import torch.nn as nn
|
||
|
||
class Expert(nn.Module):
|
||
def __init__(self, dim=512):
|
||
super().__init__()
|
||
self.net = nn.Sequential(
|
||
nn.Linear(dim, dim * 4),
|
||
nn.GELU(),
|
||
nn.Linear(dim * 4, dim)
|
||
)
|
||
def forward(self, x):
|
||
return self.net(x)
|
||
|
||
class InputMethodModel(nn.Module):
|
||
def __init__(self, vocab_size, pinyin_vocab_size, slot_vocab_size, dim=512, n_layers=4, n_heads=4, num_experts=20):
|
||
super().__init__()
|
||
# 1. Context Encoder
|
||
self.text_emb = nn.Embedding(vocab_size, dim)
|
||
self.pinyin_emb = nn.Embedding(pinyin_vocab_size, dim)
|
||
self.pos_emb = nn.Embedding(128, dim)
|
||
encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=n_heads)
|
||
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
|
||
|
||
# 2. Slot Memory
|
||
self.slot_emb = nn.Embedding(slot_vocab_size, dim)
|
||
self.slot_pos_emb = nn.Embedding(5, dim) # 假设保留5个历史槽位
|
||
|
||
# 3. Cross-Attention
|
||
self.cross_attn = nn.MultiheadAttention(embed_dim=dim, num_heads=n_heads, batch_first=True)
|
||
|
||
# 4. MoE Layer
|
||
self.num_experts = num_experts
|
||
self.experts = nn.ModuleList([Expert(dim) for _ in range(num_experts)])
|
||
self.gate = nn.Linear(dim, num_experts)
|
||
|
||
# 5. Classification Head
|
||
self.classifier = nn.Linear(dim, vocab_size)
|
||
|
||
def forward(self, text_ids, pinyin_ids, history_slot_ids):
|
||
# Encode Context
|
||
x = self.text_emb(text_ids) + self.pinyin_emb(pinyin_ids)
|
||
x += self.pos_emb(torch.arange(x.size(1)).to(x.device))
|
||
H = self.transformer(x) # [B, L, 512]
|
||
|
||
# Encode Slots
|
||
S = self.slot_emb(history_slot_ids)
|
||
S += self.slot_pos_emb(torch.arange(S.size(1)).to(S.device))
|
||
|
||
# Cross-Attention: Q=Slots, K/V=Context
|
||
fused, _ = self.cross_attn(S, H, H) # [B, Slots, 512]
|
||
|
||
# MoE Processing
|
||
# 简化版 MoE: 对所有专家输出进行加权平均
|
||
gate_scores = torch.softmax(self.gate(fused), dim=-1) # [B, Slots, Num_Experts]
|
||
expert_outputs = torch.stack([expert(fused) for expert in self.experts], dim=-2) # [B, Slots, Num_Experts, Dim]
|
||
moe_out = torch.sum(gate_scores.unsqueeze(-1) * expert_outputs, dim=-2) # [B, Slots, Dim]
|
||
|
||
# Pooling & Predict
|
||
pooled = moe_out.mean(dim=1) # [B, 512]
|
||
logits = self.classifier(pooled)
|
||
return logits
|
||
```
|
||
|
||
## 6. 总结
|
||
本方案通过**单流 Transformer 编码**结合**结构化槽位交叉注意力**,并引入**20个专家的 MoE 模块** [1],在保证模型轻量(4层 Transformer)的同时,有效利用了历史输入习惯并提升了模型表达上限。相比暴力拼接或双流架构,该设计在工程实现上更优雅,在推理效率上更高效,是轻量级输入法模型的局部最优解。
|