SUimeModelTraner/README.md

100 lines
5.2 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# SUimeModelTraner
> 深度学习输入法引擎技术方案
## 1. 任务目标
设计一个基于上下文的输入法引擎模型输入包括光标前文本、拼音、光标后文本及历史记录输出候选词序列文字ID序列支持束搜索解码实现精准的拼音到文字转换。
## 2. 输入表示
- **四段文本**
- 光标前文本
- 拼音如“bangdao”
- 光标后文本
- 符合条件的历史记录(如“绑到|邦道|…”)
- **编码方式**使用BERT类Tokenizer统一序列长度 **L=128**或88
- **段落区分**:通过 `token_type_ids` 标记段落,取值为 **0,1,2,3**(分别对应四类输入)。
- **拼音处理**:暂将拼音作为普通文本输入,预留专用嵌入接口供后续优化。
## 3. 模型架构概览
整体结构分为:输入编码 → Transformer编码 → 槽位记忆 → 交叉注意力 → 门控+专家混合 → 分类头 → 束搜索解码。
![模型架构示意图]
## 4. 核心模块设计
### 4.1 Embedding层与Transformer编码器
- **Embedding维度**512
- **Transformer层数**4层
- **多头注意力头数**4或8
- **输出**:上下文表示 `H`(形状:`[batch, L, 512]`
- **预训练**可选加载structBERT轻量级骨干链接已失效当前从头训练
### 4.2 槽位记忆模块
- **槽位结构**共8个槽位每个槽位最多可填充3步总计最多24步。
- **槽位嵌入**每一步选择的文字ID通过共享的embedding层转换为512维向量。
- **历史槽位表示**:将当前步之前的所有槽位嵌入**按顺序拼接**,形成动态增长的序列。
- **位置编码**:为拼接后的槽位序列添加可学习的位置嵌入,帮助模型捕捉时序关系。
- **初始状态**:第一步时历史为空,用特殊 `[START]` 嵌入向量作为占位。
### 4.3 交叉注意力与注意力池化
- **Query**:当前步的槽位序列(经过位置编码后)
- **Key/Value**Transformer编码器输出 `H`
- **输出**:槽位相关的特征序列
- **注意力池化**:对交叉注意力输出序列进行池化(如均值池化或可学习注意力池化),得到固定长度的特征向量 `f`维度512
### 4.4 门控网络与专家层MoE
- **门控网络**:输入 `f`输出20个专家的权重选择**top-3**专家。
- **专家结构**:每个专家为残差网络,输出维度 **512**(与隐藏层一致)。
- **专家组合**对选中的3个专家输出进行加权求和得到融合特征 `e`
### 4.5 分类头
- **设计**采用同维FFN结构为
`Linear(512, 512, bias=False) → LayerNorm → GELU → Linear(512, 10019)`
输出10019维ID范围010018其中0表示终止符
- **优点**:保留特征维度,引入非线性,参数量适中。
## 5. 解码策略:束搜索
- **搜索范围**:在每个槽位步内部执行束搜索,束宽设为 **k**默认为3
- **候选维护**:每个候选路径独立维护历史槽位序列(拼接后的嵌入)及累计概率。
- **终止条件**
1. 所有槽位已填满8×3=24步
2. 当前步所有候选分支的最高概率词均为 **0终止符**,则强制退出。
- **输出**:概率最高的完整槽位序列。
## 6. 训练设置
- **优化器**AdamW
- **损失函数**每一步的CrossEntropyLoss仅计算非填充位置
- **训练数据**:真实用户输入日志,构造(上下文,拼音,目标槽位序列)三元组
- **标签处理**每个槽位步的真实文字ID作为监督信号
## 7. 关键设计考量与潜在风险
| 设计点 | 考量 | 潜在风险/优化 |
|--------|------|----------------|
| 拼音编码 | 当前作为普通文本,实现简单 | 多字词场景可能对齐困难,后续可增加专用拼音嵌入层 |
| 槽位更新 | 拼接+位置编码,保留完整历史 | 序列长度动态增长最多24交叉注意力计算量可控 |
| 专家层 | 20专家top-3组合增强特征选择性 | 需确保门控网络负载均衡,避免专家“死”掉 |
| 分类头 | 同维FFN兼顾容量与非线性 | 若过拟合可降维或增加Dropout |
| 束搜索 | 槽位内束搜索,工程复杂度适中 | 需维护多个候选路径的历史状态,内存管理需注意 |
## 8. 后续优化方向
- **拼音增强**:引入拼音专用嵌入(音节级编码)
- **槽位建模增强**若拼接方式难以学习长距离依赖可替换为轻量级GRU但当前暂不采用
- **预训练**尝试加载开源中文BERT权重作为编码器初始化
- **知识蒸馏**:若模型过大,可蒸馏为更小版本用于端侧部署
---
## 附录:关键超参数(待定)
- 序列长度128
- 隐藏层维度512
- Transformer层数4
- 注意力头数4
- 专家数量20
- 束宽5
- 学习率:待调(建议 1e-4 5e-4带warmup
---
此方案结构完整,模块间接口清晰,可立即进入原型实现阶段。建议先在小规模数据上验证前向与训练流程,再逐步扩展至全量数据调优。