|
|
||
|---|---|---|
| src/model | ||
| .gitignore | ||
| .python-version | ||
| LICENSE | ||
| README.md | ||
| pyproject.toml | ||
| resign_stat.py | ||
| test.py | ||
| uv.lock | ||
README.md
SUimeModelTraner
深度学习输入法引擎技术方案
1. 任务目标
设计一个基于上下文的输入法引擎模型,输入包括光标前文本、拼音、光标后文本及历史记录,输出候选词序列(文字ID序列),支持束搜索解码,实现精准的拼音到文字转换。
2. 输入表示
- 四段文本:
- 光标前文本
- 拼音(如“bangdao”)
- 光标后文本
- 符合条件的历史记录(如“绑到|邦道|…”)
- 编码方式:使用BERT类Tokenizer,统一序列长度 L=128(或88)。
- 段落区分:通过
token_type_ids标记段落,取值为 0,1,2,3(分别对应四类输入)。 - 拼音处理:暂将拼音作为普通文本输入,预留专用嵌入接口供后续优化。
3. 模型架构概览
整体结构分为:输入编码 → Transformer编码 → 槽位记忆 → 交叉注意力 → 门控+专家混合 → 分类头 → 束搜索解码。
![模型架构示意图]
4. 核心模块设计
4.1 Embedding层与Transformer编码器
- Embedding维度:512
- Transformer层数:4层
- 多头注意力头数:4或8
- 输出:上下文表示
H(形状:[batch, L, 512]) - 预训练:可选加载structBERT轻量级骨干(链接已失效,当前从头训练)。
4.2 槽位记忆模块
- 槽位结构:共8个槽位,每个槽位最多可填充3步,总计最多24步。
- 槽位嵌入:每一步选择的文字ID通过共享的embedding层转换为512维向量。
- 历史槽位表示:将当前步之前的所有槽位嵌入按顺序拼接,形成动态增长的序列。
- 位置编码:为拼接后的槽位序列添加可学习的位置嵌入,帮助模型捕捉时序关系。
- 初始状态:第一步时历史为空,用特殊
[START]嵌入向量作为占位。
4.3 交叉注意力与注意力池化
- Query:当前步的槽位序列(经过位置编码后)
- Key/Value:Transformer编码器输出
H - 输出:槽位相关的特征序列
- 注意力池化:对交叉注意力输出序列进行池化(如均值池化或可学习注意力池化),得到固定长度的特征向量
f(维度512)。
4.4 门控网络与专家层(MoE)
- 门控网络:输入
f,输出20个专家的权重,选择top-3专家。 - 专家结构:每个专家为残差网络,输出维度 512(与隐藏层一致)。
- 专家组合:对选中的3个专家输出进行加权求和,得到融合特征
e。
4.5 分类头
- 设计:采用同维FFN,结构为
Linear(512, 512, bias=False) → LayerNorm → GELU → Linear(512, 10019)
输出10019维(ID范围0~10018,其中0表示终止符)。 - 优点:保留特征维度,引入非线性,参数量适中。
5. 解码策略:束搜索
- 搜索范围:在每个槽位步内部执行束搜索,束宽设为 k(默认为3)。
- 候选维护:每个候选路径独立维护历史槽位序列(拼接后的嵌入)及累计概率。
- 终止条件:
- 所有槽位已填满(8×3=24步);
- 当前步所有候选分支的最高概率词均为 0(终止符),则强制退出。
- 输出:概率最高的完整槽位序列。
6. 训练设置
- 优化器:AdamW
- 损失函数:每一步的CrossEntropyLoss(仅计算非填充位置)
- 训练数据:真实用户输入日志,构造(上下文,拼音,目标槽位序列)三元组
- 标签处理:每个槽位步的真实文字ID作为监督信号
7. 关键设计考量与潜在风险
| 设计点 | 考量 | 潜在风险/优化 |
|---|---|---|
| 拼音编码 | 当前作为普通文本,实现简单 | 多字词场景可能对齐困难,后续可增加专用拼音嵌入层 |
| 槽位更新 | 拼接+位置编码,保留完整历史 | 序列长度动态增长(最多24),交叉注意力计算量可控 |
| 专家层 | 20专家,top-3组合,增强特征选择性 | 需确保门控网络负载均衡,避免专家“死”掉 |
| 分类头 | 同维FFN,兼顾容量与非线性 | 若过拟合可降维或增加Dropout |
| 束搜索 | 槽位内束搜索,工程复杂度适中 | 需维护多个候选路径的历史状态,内存管理需注意 |
8. 后续优化方向
- 拼音增强:引入拼音专用嵌入(音节级编码)
- 槽位建模增强:若拼接方式难以学习长距离依赖,可替换为轻量级GRU(但当前暂不采用)
- 预训练:尝试加载开源中文BERT权重作为编码器初始化
- 知识蒸馏:若模型过大,可蒸馏为更小版本用于端侧部署
附录:关键超参数(待定)
- 序列长度:128
- 隐藏层维度:512
- Transformer层数:4
- 注意力头数:4
- 专家数量:20
- 束宽:5
- 学习率:待调(建议 1e-4 ~ 5e-4,带warmup)