100 lines
5.2 KiB
Markdown
100 lines
5.2 KiB
Markdown
# SUimeModelTraner
|
||
|
||
> 深度学习输入法引擎技术方案
|
||
|
||
## 1. 任务目标
|
||
设计一个基于上下文的输入法引擎模型,输入包括光标前文本、拼音、光标后文本及历史记录,输出候选词序列(文字ID序列),支持束搜索解码,实现精准的拼音到文字转换。
|
||
|
||
## 2. 输入表示
|
||
- **四段文本**:
|
||
- 光标前文本
|
||
- 拼音(如“bangdao”)
|
||
- 光标后文本
|
||
- 符合条件的历史记录(如“绑到|邦道|…”)
|
||
- **编码方式**:使用BERT类Tokenizer,统一序列长度 **L=128**(或88)。
|
||
- **段落区分**:通过 `token_type_ids` 标记段落,取值为 **0,1,2,3**(分别对应四类输入)。
|
||
- **拼音处理**:暂将拼音作为普通文本输入,预留专用嵌入接口供后续优化。
|
||
|
||
## 3. 模型架构概览
|
||
整体结构分为:输入编码 → Transformer编码 → 槽位记忆 → 交叉注意力 → 门控+专家混合 → 分类头 → 束搜索解码。
|
||
|
||
![模型架构示意图]
|
||
|
||
## 4. 核心模块设计
|
||
|
||
### 4.1 Embedding层与Transformer编码器
|
||
- **Embedding维度**:512
|
||
- **Transformer层数**:4层
|
||
- **多头注意力头数**:4或8
|
||
- **输出**:上下文表示 `H`(形状:`[batch, L, 512]`)
|
||
- **预训练**:可选加载structBERT轻量级骨干(链接已失效,当前从头训练)。
|
||
|
||
### 4.2 槽位记忆模块
|
||
- **槽位结构**:共8个槽位,每个槽位最多可填充3步,总计最多24步。
|
||
- **槽位嵌入**:每一步选择的文字ID通过共享的embedding层转换为512维向量。
|
||
- **历史槽位表示**:将当前步之前的所有槽位嵌入**按顺序拼接**,形成动态增长的序列。
|
||
- **位置编码**:为拼接后的槽位序列添加可学习的位置嵌入,帮助模型捕捉时序关系。
|
||
- **初始状态**:第一步时历史为空,用特殊 `[START]` 嵌入向量作为占位。
|
||
|
||
### 4.3 交叉注意力与注意力池化
|
||
- **Query**:当前步的槽位序列(经过位置编码后)
|
||
- **Key/Value**:Transformer编码器输出 `H`
|
||
- **输出**:槽位相关的特征序列
|
||
- **注意力池化**:对交叉注意力输出序列进行池化(如均值池化或可学习注意力池化),得到固定长度的特征向量 `f`(维度512)。
|
||
|
||
### 4.4 门控网络与专家层(MoE)
|
||
- **门控网络**:输入 `f`,输出20个专家的权重,选择**top-3**专家。
|
||
- **专家结构**:每个专家为残差网络,输出维度 **512**(与隐藏层一致)。
|
||
- **专家组合**:对选中的3个专家输出进行加权求和,得到融合特征 `e`。
|
||
|
||
### 4.5 分类头
|
||
- **设计**:采用同维FFN,结构为
|
||
`Linear(512, 512, bias=False) → LayerNorm → GELU → Linear(512, 10019)`
|
||
输出10019维(ID范围0~10018,其中0表示终止符)。
|
||
- **优点**:保留特征维度,引入非线性,参数量适中。
|
||
|
||
## 5. 解码策略:束搜索
|
||
- **搜索范围**:在每个槽位步内部执行束搜索,束宽设为 **k**(默认为3)。
|
||
- **候选维护**:每个候选路径独立维护历史槽位序列(拼接后的嵌入)及累计概率。
|
||
- **终止条件**:
|
||
1. 所有槽位已填满(8×3=24步);
|
||
2. 当前步所有候选分支的最高概率词均为 **0(终止符)**,则强制退出。
|
||
- **输出**:概率最高的完整槽位序列。
|
||
|
||
## 6. 训练设置
|
||
- **优化器**:AdamW
|
||
- **损失函数**:每一步的CrossEntropyLoss(仅计算非填充位置)
|
||
- **训练数据**:真实用户输入日志,构造(上下文,拼音,目标槽位序列)三元组
|
||
- **标签处理**:每个槽位步的真实文字ID作为监督信号
|
||
|
||
## 7. 关键设计考量与潜在风险
|
||
|
||
| 设计点 | 考量 | 潜在风险/优化 |
|
||
|--------|------|----------------|
|
||
| 拼音编码 | 当前作为普通文本,实现简单 | 多字词场景可能对齐困难,后续可增加专用拼音嵌入层 |
|
||
| 槽位更新 | 拼接+位置编码,保留完整历史 | 序列长度动态增长(最多24),交叉注意力计算量可控 |
|
||
| 专家层 | 20专家,top-3组合,增强特征选择性 | 需确保门控网络负载均衡,避免专家“死”掉 |
|
||
| 分类头 | 同维FFN,兼顾容量与非线性 | 若过拟合可降维或增加Dropout |
|
||
| 束搜索 | 槽位内束搜索,工程复杂度适中 | 需维护多个候选路径的历史状态,内存管理需注意 |
|
||
|
||
## 8. 后续优化方向
|
||
- **拼音增强**:引入拼音专用嵌入(音节级编码)
|
||
- **槽位建模增强**:若拼接方式难以学习长距离依赖,可替换为轻量级GRU(但当前暂不采用)
|
||
- **预训练**:尝试加载开源中文BERT权重作为编码器初始化
|
||
- **知识蒸馏**:若模型过大,可蒸馏为更小版本用于端侧部署
|
||
|
||
---
|
||
|
||
## 附录:关键超参数(待定)
|
||
- 序列长度:128
|
||
- 隐藏层维度:512
|
||
- Transformer层数:4
|
||
- 注意力头数:4
|
||
- 专家数量:20
|
||
- 束宽:5
|
||
- 学习率:待调(建议 1e-4 ~ 5e-4,带warmup)
|
||
|
||
---
|
||
|
||
此方案结构完整,模块间接口清晰,可立即进入原型实现阶段。建议先在小规模数据上验证前向与训练流程,再逐步扩展至全量数据调优。
|