SUimeModelTraner/README.md

167 lines
8.9 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 输入法预测模型架构设计 (Input Method Prediction Model)
## 1. 概述
本项目旨在构建一个轻量级、高精度的中文输入法预测模型。核心设计理念是通过**结构化槽位记忆**与**交叉注意力机制**,将当前语境(光标前后文本+拼音)与历史输入习惯深度融合。为了在有限的计算资源下保持高表达能力,模型引入了**混合专家网络 (MoE)** 模块。
## 2. 核心架构流程
数据流遵循以下路径:
`输入编码``Transformer 上下文编码``槽位记忆嵌入``交叉注意力融合``门控+专家混合 (MoE)``分类预测``束搜索解码`
### 2.1 输入层设计
模型接收三类输入,分别处理以保持语义清晰:
1. **当前文本上下文**包含光标前文本Prefix和光标后文本Suffix
2. **拼音序列**:与当前文本对应的拼音信息,作为增强特征融入文本编码。
3. **历史槽位序列**:最近 N 个历史输入词汇,作为结构化记忆输入。
### 2.2 模块详解
#### A. Transformer 编码器 (Context Encoder)
负责提取当前语境的深层语义表示。
* **输入处理**:将 Prefix、Suffix 及拼音通过 Embedding 层映射。拼音采用**特征叠加**或**独立 Token** 方式融入,避免双流架构的复杂性。
* **骨干网络**:使用标准的 Transformer Encoder。
* **隐藏层维度**512 [1]
* **Transformer 层数**4 层(轻量级设计,从头训练) [1]
* **注意力头数**4 头 [1]
* **输出**:上下文表示 $H$,形状为 `[batch, L, 512]` [1]。
#### B. 槽位记忆模块 (Slot Memory)
负责将非结构化的历史输入转化为结构化的记忆向量。
* **嵌入方式**:历史词汇通过独立的 `Slot Embedding` 查找表映射。
* **位置编码**:添加可学习的 `Positional Embedding` 以保留历史输入的时间顺序信息。
* **输出**:槽位序列 $S$,形状为 `[batch, Num_Slots, 512]`
#### C. 交叉注意力融合 (Cross-Attention Fusion)
这是模型的核心创新点,用于动态关联“历史记忆”与“当前语境”。
* **Query (Q)**:当前步的槽位序列 $S$(经过位置编码后)。
* **Key/Value (K/V)**Transformer 编码器输出的上下文表示 $H$ [1]。
* **机制**:让历史槽位主动关注当前文本语境,捕捉如“在‘班级第一名’语境下,‘王次香’比‘王慈祥’更相关”的逻辑。
* **输出**:融合后的特征序列,形状为 `[batch, Num_Slots, 512]`
#### D. 门控与专家混合 (Gating + MoE)
实际测试表明,移除 MoE 会导致模型性能显著下降,因此该模块对于捕捉复杂分布至关重要。
* **专家数量**20 个专家 [1]。
* **门控机制**:根据输入特征动态选择激活部分专家,实现稀疏激活,在增加模型容量的同时控制计算成本。
* **输出**:经过专家网络增强后的特征向量。
#### E. 分类头与解码
* **分类预测**MoE 输出的特征向量通过全连接层映射到词表空间,输出下一个字/词的概率分布。
* **解码策略**:推理阶段使用**束搜索 (Beam Search)**,束宽设为 5 [1]。
## 3. 关键超参数配置
为确保模型性能与效率的平衡,建议采用以下超参数 [1]
| 参数项 | 推荐值 | 说明 |
| :--- | :--- | :--- |
| **序列长度 (L)** | 128 | 上下文窗口大小 [1] |
| **隐藏层维度** | 512 | Embedding 及 Transformer 内部维度 [1] |
| **Transformer 层数** | 4 | 轻量级骨干,降低延迟 [1] |
| **注意力头数** | 4 | 适配 512 维度的高效配置 [1] |
| **专家数量** | 20 | MoE 层中的专家总数,对性能至关重要 [1] |
| **束宽 (Beam Width)** | 5 | 推理时平衡速度与准确率 [1] |
| **学习率** | 1e-4 ~ 5e-4 | 建议配合 Warmup 策略 [1] |
## 4. 训练策略
本模型采用标准的**序列到序列Seq2Seq监督学习**范式,直接对目标槽位序列进行逐步预测。
### 4.1 数据构造与标签
* **输入三元组**:训练数据由 `(上下文, 拼音, 目标槽位序列)` 构成 [1]。
* **上下文**:光标前后的文本片段。
* **拼音**:当前待输入字的拼音序列。
* **目标槽位序列**:真实用户输入的文字 ID 序列,作为模型的监督信号 [1]。
* **标签处理**在每一个槽位步Step模型需要预测该步对应的真实文字 ID [1]。
### 4.2 损失函数与优化
* **损失函数**:使用 **CrossEntropyLoss** 计算每一步预测结果与真实标签之间的差异 [1]。
* **掩码机制**仅计算非填充位置Non-padding positions的损失忽略无效的时间步 [1]。
* **优化器**:采用 **AdamW** 进行参数更新 [1]。
### 4.3 训练流程细节
1. **前向传播**
* 模型接收上下文和拼音,通过 Transformer 编码得到语境表示。
* 结合历史槽位记忆,通过交叉注意力和 MoE 模块融合特征。
* 分类头输出当前步所有候选字的概率分布。
2. **Teacher Forcing**
* 在训练过程中,**强制使用真实的上一槽位输出**作为下一步的输入条件。这意味着模型在训练时始终基于“正确的历史”进行预测,从而快速收敛。
3. **反向传播**
* 根据 CrossEntropyLoss [1] 计算梯度,并通过 AdamW [1] 更新模型权重。
### 4.4 推理与训练的差异
* **训练时**:使用 Ground Truth真实标签作为槽位输入确保模型学习到最优的条件概率分布。
* **推理时**:由于无法获取真实标签,模型采用**束搜索Beam Search** [1]。
* **束宽**:默认为 5 [1]。
* **候选维护**:每个候选路径独立维护其历史槽位序列及累计概率 [1]。
* **终止条件**:当所有槽位填满(如 8×3=24 步)或所有候选分支的最高概率词均为终止符时退出 [1]。
## 5. 代码实现示意 (PyTorch)
```python
import torch
import torch.nn as nn
class Expert(nn.Module):
def __init__(self, dim=512):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, dim)
)
def forward(self, x):
return self.net(x)
class InputMethodModel(nn.Module):
def __init__(self, vocab_size, pinyin_vocab_size, slot_vocab_size, dim=512, n_layers=4, n_heads=4, num_experts=20):
super().__init__()
# 1. Context Encoder
self.text_emb = nn.Embedding(vocab_size, dim)
self.pinyin_emb = nn.Embedding(pinyin_vocab_size, dim)
self.pos_emb = nn.Embedding(128, dim)
encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=n_heads)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
# 2. Slot Memory
self.slot_emb = nn.Embedding(slot_vocab_size, dim)
self.slot_pos_emb = nn.Embedding(5, dim) # 假设保留5个历史槽位
# 3. Cross-Attention
self.cross_attn = nn.MultiheadAttention(embed_dim=dim, num_heads=n_heads, batch_first=True)
# 4. MoE Layer
self.num_experts = num_experts
self.experts = nn.ModuleList([Expert(dim) for _ in range(num_experts)])
self.gate = nn.Linear(dim, num_experts)
# 5. Classification Head
self.classifier = nn.Linear(dim, vocab_size)
def forward(self, text_ids, pinyin_ids, history_slot_ids):
# Encode Context
x = self.text_emb(text_ids) + self.pinyin_emb(pinyin_ids)
x += self.pos_emb(torch.arange(x.size(1)).to(x.device))
H = self.transformer(x) # [B, L, 512]
# Encode Slots
S = self.slot_emb(history_slot_ids)
S += self.slot_pos_emb(torch.arange(S.size(1)).to(S.device))
# Cross-Attention: Q=Slots, K/V=Context
fused, _ = self.cross_attn(S, H, H) # [B, Slots, 512]
# MoE Processing
# 简化版 MoE: 对所有专家输出进行加权平均
gate_scores = torch.softmax(self.gate(fused), dim=-1) # [B, Slots, Num_Experts]
expert_outputs = torch.stack([expert(fused) for expert in self.experts], dim=-2) # [B, Slots, Num_Experts, Dim]
moe_out = torch.sum(gate_scores.unsqueeze(-1) * expert_outputs, dim=-2) # [B, Slots, Dim]
# Pooling & Predict
pooled = moe_out.mean(dim=1) # [B, 512]
logits = self.classifier(pooled)
return logits
```
## 6. 总结
本方案通过**单流 Transformer 编码**结合**结构化槽位交叉注意力**,并引入**20个专家的 MoE 模块** [1]在保证模型轻量4层 Transformer的同时有效利用了历史输入习惯并提升了模型表达上限。相比暴力拼接或双流架构该设计在工程实现上更优雅在推理效率上更高效是轻量级输入法模型的局部最优解。