# SUimeModelTraner > 深度学习输入法引擎技术方案 ## 1. 任务目标 设计一个基于上下文的输入法引擎模型,输入包括光标前文本、拼音、光标后文本及历史记录,输出候选词序列(文字ID序列),支持束搜索解码,实现精准的拼音到文字转换。 ## 2. 输入表示 - **四段文本**: - 光标前文本 - 拼音(如“bangdao”) - 光标后文本 - 符合条件的历史记录(如“绑到|邦道|…”) - **编码方式**:使用BERT类Tokenizer,统一序列长度 **L=128**(或88)。 - **段落区分**:通过 `token_type_ids` 标记段落,取值为 **0,1,2,3**(分别对应四类输入)。 - **拼音处理**:暂将拼音作为普通文本输入,预留专用嵌入接口供后续优化。 ## 3. 模型架构概览 整体结构分为:输入编码 → Transformer编码 → 槽位记忆 → 交叉注意力 → 门控+专家混合 → 分类头 → 束搜索解码。 ![模型架构示意图] ## 4. 核心模块设计 ### 4.1 Embedding层与Transformer编码器 - **Embedding维度**:512 - **Transformer层数**:4层 - **多头注意力头数**:4或8 - **输出**:上下文表示 `H`(形状:`[batch, L, 512]`) - **预训练**:可选加载structBERT轻量级骨干(链接已失效,当前从头训练)。 ### 4.2 槽位记忆模块 - **槽位结构**:共8个槽位,每个槽位最多可填充3步,总计最多24步。 - **槽位嵌入**:每一步选择的文字ID通过共享的embedding层转换为512维向量。 - **历史槽位表示**:将当前步之前的所有槽位嵌入**按顺序拼接**,形成动态增长的序列。 - **位置编码**:为拼接后的槽位序列添加可学习的位置嵌入,帮助模型捕捉时序关系。 - **初始状态**:第一步时历史为空,用特殊 `[START]` 嵌入向量作为占位。 ### 4.3 交叉注意力与注意力池化 - **Query**:当前步的槽位序列(经过位置编码后) - **Key/Value**:Transformer编码器输出 `H` - **输出**:槽位相关的特征序列 - **注意力池化**:对交叉注意力输出序列进行池化(如均值池化或可学习注意力池化),得到固定长度的特征向量 `f`(维度512)。 ### 4.4 门控网络与专家层(MoE) - **门控网络**:输入 `f`,输出20个专家的权重,选择**top-3**专家。 - **专家结构**:每个专家为残差网络,输出维度 **512**(与隐藏层一致)。 - **专家组合**:对选中的3个专家输出进行加权求和,得到融合特征 `e`。 ### 4.5 分类头 - **设计**:采用同维FFN,结构为 `Linear(512, 512, bias=False) → LayerNorm → GELU → Linear(512, 10019)` 输出10019维(ID范围0~10018,其中0表示终止符)。 - **优点**:保留特征维度,引入非线性,参数量适中。 ## 5. 解码策略:束搜索 - **搜索范围**:在每个槽位步内部执行束搜索,束宽设为 **k**(如5)。 - **候选维护**:每个候选路径独立维护历史槽位序列(拼接后的嵌入)及累计概率。 - **终止条件**: 1. 所有槽位已填满(8×3=24步); 2. 当前步所有候选分支的最高概率词均为 **0(终止符)**,则强制退出。 - **输出**:概率最高的完整槽位序列。 ## 6. 训练设置 - **优化器**:AdamW - **损失函数**:每一步的CrossEntropyLoss(仅计算非填充位置) - **训练数据**:真实用户输入日志,构造(上下文,拼音,目标槽位序列)三元组 - **标签处理**:每个槽位步的真实文字ID作为监督信号 ## 7. 关键设计考量与潜在风险 | 设计点 | 考量 | 潜在风险/优化 | |--------|------|----------------| | 拼音编码 | 当前作为普通文本,实现简单 | 多字词场景可能对齐困难,后续可增加专用拼音嵌入层 | | 槽位更新 | 拼接+位置编码,保留完整历史 | 序列长度动态增长(最多24),交叉注意力计算量可控 | | 专家层 | 20专家,top-3组合,增强特征选择性 | 需确保门控网络负载均衡,避免专家“死”掉 | | 分类头 | 同维FFN,兼顾容量与非线性 | 若过拟合可降维或增加Dropout | | 束搜索 | 槽位内束搜索,工程复杂度适中 | 需维护多个候选路径的历史状态,内存管理需注意 | ## 8. 后续优化方向 - **拼音增强**:引入拼音专用嵌入(音节级编码) - **槽位建模增强**:若拼接方式难以学习长距离依赖,可替换为轻量级GRU(但当前暂不采用) - **预训练**:尝试加载开源中文BERT权重作为编码器初始化 - **知识蒸馏**:若模型过大,可蒸馏为更小版本用于端侧部署 --- ## 附录:关键超参数(待定) - 序列长度:128 - 隐藏层维度:512 - Transformer层数:4 - 注意力头数:4 - 专家数量:20 - 束宽:5 - 学习率:待调(建议 1e-4 ~ 5e-4,带warmup) --- 此方案结构完整,模块间接口清晰,可立即进入原型实现阶段。建议先在小规模数据上验证前向与训练流程,再逐步扩展至全量数据调优。