feat: 添加输入法引擎模型设计与实现,包含架构、核心模块及训练策略

This commit is contained in:
songsenand 2026-03-23 16:30:38 +08:00
parent a8ccb3e4fe
commit b6a677f15d
8 changed files with 1685 additions and 0 deletions

1
.python-version Normal file
View File

@ -0,0 +1 @@
3.12

View File

@ -1,2 +1,99 @@
# SUimeModelTraner
> 深度学习输入法引擎技术方案
## 1. 任务目标
设计一个基于上下文的输入法引擎模型输入包括光标前文本、拼音、光标后文本及历史记录输出候选词序列文字ID序列支持束搜索解码实现精准的拼音到文字转换。
## 2. 输入表示
- **四段文本**
- 光标前文本
- 拼音如“bangdao”
- 光标后文本
- 符合条件的历史记录(如“绑到|邦道|…”)
- **编码方式**使用BERT类Tokenizer统一序列长度 **L=128**或88
- **段落区分**:通过 `token_type_ids` 标记段落,取值为 **0,1,2,3**(分别对应四类输入)。
- **拼音处理**:暂将拼音作为普通文本输入,预留专用嵌入接口供后续优化。
## 3. 模型架构概览
整体结构分为:输入编码 → Transformer编码 → 槽位记忆 → 交叉注意力 → 门控+专家混合 → 分类头 → 束搜索解码。
![模型架构示意图]
## 4. 核心模块设计
### 4.1 Embedding层与Transformer编码器
- **Embedding维度**512
- **Transformer层数**4层
- **多头注意力头数**4或8
- **输出**:上下文表示 `H`(形状:`[batch, L, 512]`
- **预训练**可选加载structBERT轻量级骨干链接已失效当前从头训练
### 4.2 槽位记忆模块
- **槽位结构**共8个槽位每个槽位最多可填充3步总计最多24步。
- **槽位嵌入**每一步选择的文字ID通过共享的embedding层转换为512维向量。
- **历史槽位表示**:将当前步之前的所有槽位嵌入**按顺序拼接**,形成动态增长的序列。
- **位置编码**:为拼接后的槽位序列添加可学习的位置嵌入,帮助模型捕捉时序关系。
- **初始状态**:第一步时历史为空,用特殊 `[START]` 嵌入向量作为占位。
### 4.3 交叉注意力与注意力池化
- **Query**:当前步的槽位序列(经过位置编码后)
- **Key/Value**Transformer编码器输出 `H`
- **输出**:槽位相关的特征序列
- **注意力池化**:对交叉注意力输出序列进行池化(如均值池化或可学习注意力池化),得到固定长度的特征向量 `f`维度512
### 4.4 门控网络与专家层MoE
- **门控网络**:输入 `f`输出20个专家的权重选择**top-3**专家。
- **专家结构**:每个专家为残差网络,输出维度 **512**(与隐藏层一致)。
- **专家组合**对选中的3个专家输出进行加权求和得到融合特征 `e`
### 4.5 分类头
- **设计**采用同维FFN结构为
`Linear(512, 512, bias=False) → LayerNorm → GELU → Linear(512, 10019)`
输出10019维ID范围010018其中0表示终止符
- **优点**:保留特征维度,引入非线性,参数量适中。
## 5. 解码策略:束搜索
- **搜索范围**:在每个槽位步内部执行束搜索,束宽设为 **k**如5
- **候选维护**:每个候选路径独立维护历史槽位序列(拼接后的嵌入)及累计概率。
- **终止条件**
1. 所有槽位已填满8×3=24步
2. 当前步所有候选分支的最高概率词均为 **0终止符**,则强制退出。
- **输出**:概率最高的完整槽位序列。
## 6. 训练设置
- **优化器**AdamW
- **损失函数**每一步的CrossEntropyLoss仅计算非填充位置
- **训练数据**:真实用户输入日志,构造(上下文,拼音,目标槽位序列)三元组
- **标签处理**每个槽位步的真实文字ID作为监督信号
## 7. 关键设计考量与潜在风险
| 设计点 | 考量 | 潜在风险/优化 |
|--------|------|----------------|
| 拼音编码 | 当前作为普通文本,实现简单 | 多字词场景可能对齐困难,后续可增加专用拼音嵌入层 |
| 槽位更新 | 拼接+位置编码,保留完整历史 | 序列长度动态增长最多24交叉注意力计算量可控 |
| 专家层 | 20专家top-3组合增强特征选择性 | 需确保门控网络负载均衡,避免专家“死”掉 |
| 分类头 | 同维FFN兼顾容量与非线性 | 若过拟合可降维或增加Dropout |
| 束搜索 | 槽位内束搜索,工程复杂度适中 | 需维护多个候选路径的历史状态,内存管理需注意 |
## 8. 后续优化方向
- **拼音增强**:引入拼音专用嵌入(音节级编码)
- **槽位建模增强**若拼接方式难以学习长距离依赖可替换为轻量级GRU但当前暂不采用
- **预训练**尝试加载开源中文BERT权重作为编码器初始化
- **知识蒸馏**:若模型过大,可蒸馏为更小版本用于端侧部署
---
## 附录:关键超参数(待定)
- 序列长度128
- 隐藏层维度512
- Transformer层数4
- 注意力头数4
- 专家数量20
- 束宽5
- 学习率:待调(建议 1e-4 5e-4带warmup
---
此方案结构完整,模块间接口清晰,可立即进入原型实现阶段。建议先在小规模数据上验证前向与训练流程,再逐步扩展至全量数据调优。

6
hello.py Normal file
View File

@ -0,0 +1,6 @@
def main():
print("Hello from suimemodeltraner!")
if __name__ == "__main__":
main()

37
pyproject.toml Normal file
View File

@ -0,0 +1,37 @@
[project]
name = "suimemodeltraner"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"modelscope>=1.35.1",
"onnx>=1.20.1",
"torch>=2.10.0",
]
[tool.uv]
# 设置当前项目的默认索引源
index-url = "https://pypi.tuna.tsinghua.edu.cn/simple"
[dependency-groups]
dev = [
"autocommit",
"pytest>=9.0.2",
]
[tool.uv.sources]
autocommit = { git = "https://gitea.winkinshly.site/songsenand/autocommit.git" }
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"
[tool.setuptools]
# 👇 这是关键:指定包在 src/ 下
package-dir = {"" = "src"}
[tool.setuptools.packages.find]
where = ["src"]
[tool.setuptools.package-data]
model = ["assets/*"]

0
src/model/__init__.py Normal file
View File

82
src/model/components.py Normal file
View File

@ -0,0 +1,82 @@
import torch
import torch.amp as amp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from loguru import logger
# ---------------------------- 注意力池化模块----------------------------
class AttentionPooling(nn.Module):
def __init__(self, hidden_size):
super().__init__()
self.attn = nn.Linear(hidden_size, 1)
# 三个可学习偏置:文本、拼音、个性化
self.bias = nn.Parameter(torch.zeros(3)) # [text_bias, pinyin_bias, user_bias]
def forward(self, x, mask=None, token_type_ids=None):
scores = self.attn(x).squeeze(-1) # [batch, seq_len]
if token_type_ids is not None:
# 根据 token_type_ids 添加对应偏置
# bias 形状 [3],通过索引扩展为 [batch, seq_len]
bias_per_token = self.bias[token_type_ids] # [batch, seq_len]
scores = scores + bias_per_token
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
weights = torch.softmax(scores, dim=-1)
pooled = torch.sum(weights.unsqueeze(-1) * x, dim=1)
return pooled
# ---------------------------- 残差块 ----------------------------
class ResidualBlock(nn.Module):
def __init__(self, dim, dropout_prob=0.3):
super().__init__()
self.linear1 = nn.Linear(dim, dim)
self.ln1 = nn.LayerNorm(dim)
self.linear2 = nn.Linear(dim, dim)
self.ln2 = nn.LayerNorm(dim)
self.gelu = nn.GELU()
self.dropout = nn.Dropout(dropout_prob)
def forward(self, x):
residual = x
# 修复:使用 self.gelu 而不是未定义的 self.relu
x = self.gelu(self.linear1(x))
x = self.ln1(x)
x = self.linear2(x)
x = self.ln2(x)
x = self.dropout(x)
x = x + residual
return self.gelu(x)
# ---------------------------- 专家网络 ----------------------------
class Expert(nn.Module):
def __init__(
self,
input_dim,
d_model=512,
num_resblocks=4,
output_multiplier=1,
dropout_prob=0.3,
):
super().__init__()
self.output_dim = input_dim * output_multiplier
self.linear_in = nn.Linear(input_dim, d_model)
self.res_blocks = nn.ModuleList(
[ResidualBlock(d_model, dropout_prob) for _ in range(num_resblocks)]
)
self.output = nn.Sequential(
nn.Linear(d_model, d_model),
nn.GELU(),
nn.Dropout(dropout_prob),
nn.Linear(d_model, self.output_dim),
)
def forward(self, x):
x = self.linear_in(x)
for block in self.res_blocks:
x = block(x)
return self.output(x)

281
src/model/model.py Normal file
View File

@ -0,0 +1,281 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .components import AttentionPooling, Expert # , ResidualBlock # 假设已实现
class InputMethodEngine(nn.Module):
"""
输入法引擎模型
输入光标前/后文本拼音历史记录四段的编码序列
输出槽位序列最多24个文字ID的概率分布
"""
def __init__(
self,
input_vocab_size: int, # 输入文本的词汇表大小(含特殊符号)
output_vocab_size: int, # 输出文字的词汇表大小含终止符0
hidden_size: int = 512, # 隐藏层维度
num_layers: int = 4, # Transformer 层数
num_heads: int = 4, # 多头注意力头数
max_slot_steps: int = 24, # 最大槽位步数8槽×3步
num_experts: int = 20, # 专家数量
top_k: int = 3, # 每个token选择的专家数
expert_res_blocks: int = 4, # 每个专家内部的残差块数
dropout: float = 0.3,
use_attention_pooling: bool = False, # 是否在槽位特征后加池化(暂未使用)
**kwargs,
):
super().__init__()
self.hidden_size = hidden_size
self.output_vocab_size = output_vocab_size
self.max_slot_steps = max_slot_steps
self.num_experts = num_experts
self.top_k = top_k
self.use_attention_pooling = use_attention_pooling
# -------------------- 1. 文本编码器 --------------------
self.token_embedding = nn.Embedding(input_vocab_size, hidden_size)
self.position_embedding = nn.Embedding(
512, hidden_size
) # 可学习位置编码512足够长
self.token_type_embedding = nn.Embedding(4, hidden_size) # 4种段落类型
encoder_layer = nn.TransformerEncoderLayer(
d_model=hidden_size,
nhead=num_heads,
dim_feedforward=hidden_size * 4,
dropout=dropout,
activation="gelu",
batch_first=True,
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
# -------------------- 2. 槽位相关组件 --------------------
# 槽位文字嵌入(输出词汇表)
self.slot_embedding = nn.Embedding(output_vocab_size, hidden_size)
# 槽位位置编码(可学习,最大步数+1用于起始符
self.slot_position_embedding = nn.Embedding(max_slot_steps + 1, hidden_size)
# 交叉注意力Query = 槽位序列Key/Value = 文本编码结果)
self.cross_attention = nn.MultiheadAttention(
embed_dim=hidden_size,
num_heads=num_heads,
dropout=dropout,
batch_first=True,
)
# 可选:注意力池化(若 use_attention_pooling=True则用于聚合槽位序列特征
if use_attention_pooling:
self.attention_pooling = AttentionPooling(hidden_size)
# -------------------- 3. 门控网络 + 专家层 --------------------
self.gate = nn.Linear(hidden_size, num_experts) # 输出 logits用于选择专家
# 每个专家是一个独立的残差网络(输出维度=hidden_size
self.experts = nn.ModuleList(
[
Expert(
input_dim=hidden_size,
d_model=hidden_size,
num_resblocks=expert_res_blocks,
output_multiplier=1,
dropout_prob=dropout,
)
for _ in range(num_experts)
]
)
# -------------------- 4. 分类头同维FFN--------------------
self.classifier = nn.Sequential(
nn.Linear(hidden_size, hidden_size, bias=False),
nn.LayerNorm(hidden_size),
nn.GELU(),
nn.Linear(hidden_size, output_vocab_size),
)
# 初始化
self._init_weights()
def _init_weights(self):
"""初始化权重(简单示例)"""
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def encode_text(self, input_ids, token_type_ids, attention_mask):
"""
编码输入文本光标前拼音光标后历史记录
返回: [batch, seq_len, hidden_size] 编码后的上下文表示
"""
seq_len = input_ids.size(1)
# 位置编码
positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
pos_emb = self.position_embedding(positions)
# token类型编码
type_emb = self.token_type_embedding(token_type_ids)
# 词嵌入
token_emb = self.token_embedding(input_ids)
# 相加
x = token_emb + pos_emb + type_emb
# Transformer编码器需要padding mask
# attention_mask: [batch, seq_len] 1表示有效0表示填充
# TransformerEncoder 要求 mask 为 [batch, seq_len] 且 1 表示忽略
src_key_padding_mask = attention_mask == 0 # True表示填充位置
x = self.transformer(x, src_key_padding_mask=src_key_padding_mask)
return x
def forward_single_step(self, context, slot_seq_emb, slot_seq_mask=None):
"""
单步预测根据当前槽位序列已拼接的嵌入预测下一个文字的概率分布
context: [batch, seq_len, hidden] 文本编码结果
slot_seq_emb: [batch, current_len, hidden] 当前槽位序列的嵌入已拼接
slot_seq_mask: [batch, current_len] 有效位置mask1有效
返回: [batch, output_vocab_size] 概率分布
"""
batch_size = slot_seq_emb.size(0)
# 交叉注意力Query是槽位序列通常只取最后一个步的嵌入作为Query但这里我们使用整个序列
# 为了简单我们使用整个序列作为Query然后取最后一个位置的输出因为自回归
# 方法1Query = 最后一个位置的嵌入(单个向量)
last_query = slot_seq_emb[:, -1:, :] # [batch, 1, hidden]
# 交叉注意力
attn_out, _ = self.cross_attention(
query=last_query,
key=context,
value=context,
key_padding_mask=(
context.sum(-1) == 0
), # 忽略填充位置实际应传入attention_mask
) # [batch, 1, hidden]
attn_out = attn_out.squeeze(1) # [batch, hidden]
# 门控网络选择top-k专家
gate_logits = self.gate(attn_out) # [batch, num_experts]
topk_weights, topk_indices = torch.topk(
F.softmax(gate_logits, dim=-1), self.top_k, dim=-1
)
# 归一化权重
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
# 计算专家输出加权和
expert_outputs = torch.zeros_like(attn_out) # [batch, hidden]
for i in range(self.top_k):
expert_idx = topk_indices[:, i] # [batch]
weight = topk_weights[:, i].unsqueeze(1) # [batch, 1]
# 批量获取专家输出(需对每个样本分别处理,或用循环)
# 这里简单循环,实际可优化为 gather
for b in range(batch_size):
expert_out = self.experts[expert_idx[b]](attn_out[b : b + 1])
expert_outputs[b] += weight[b] * expert_out.squeeze(0)
# 分类头
logits = self.classifier(expert_outputs) # [batch, output_vocab_size]
return logits
def forward_train(
self,
input_ids,
token_type_ids,
attention_mask,
slot_target_ids,
slot_target_mask=None,
):
"""
训练模式一次性并行计算所有槽位步的预测
slot_target_ids: [batch, max_slot_steps] 真实标签teacher forcing
slot_target_mask: [batch, max_slot_steps] 有效步mask1表示该步存在
返回: 各步logits [batch, max_slot_steps, output_vocab_size]
"""
batch_size, max_steps = slot_target_ids.shape
device = input_ids.device
# 1. 编码文本
context = self.encode_text(
input_ids, token_type_ids, attention_mask
) # [b, L, h]
# 2. 构建槽位序列嵌入teacher forcing使用真实标签构建
# 将每个样本的槽位ID转为embedding并添加位置编码
# slot_target_ids 形状 [b, T]
slot_emb = self.slot_embedding(slot_target_ids) # [b, T, h]
# 位置编码位置从0开始
positions = torch.arange(max_steps, device=device).unsqueeze(0) # [1, T]
pos_emb = self.slot_position_embedding(positions) # [1, T, h]
slot_emb = slot_emb + pos_emb
# 可选加入mask无效位置置0
if slot_target_mask is not None:
slot_emb = slot_emb * slot_target_mask.unsqueeze(-1)
# 3. 交叉注意力Query为整个槽位序列每个位置独立预测
# 注意:我们这里使用 self-attention 的方式?实际上应该是 cross-attention 且 Query 是槽位序列
# 但常规做法是每个位置的 Query 只与文本编码交互,不与其他槽位交互,所以不需要掩码。
# 使用 MultiheadAttention 的 query 和 key/value 不同即可。
attn_out, _ = self.cross_attention(
query=slot_emb, # [b, T, h]
key=context, # [b, L, h]
value=context,
key_padding_mask=(attention_mask == 0), # 忽略文本填充位置
) # [b, T, h]
# 4. 对每个槽位步分别通过门控+专家+分类头
# 由于不同步之间共享参数,我们可以将 batch 和 steps 合并处理
b, T, h = attn_out.shape
attn_flat = attn_out.view(b * T, h) # [b*T, h]
# 门控网络
gate_logits = self.gate(attn_flat) # [b*T, num_experts]
gate_probs = F.softmax(gate_logits, dim=-1)
topk_probs, topk_indices = torch.topk(
gate_probs, self.top_k, dim=-1
) # [b*T, top_k]
topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True) # 归一化
# 计算专家输出(并行化较复杂,这里用循环简化,实际可用 scatter 优化)
expert_out_flat = torch.zeros_like(attn_flat) # [b*T, h]
for i in range(self.top_k):
weight = topk_probs[:, i].unsqueeze(1) # [b*T, 1]
idx = topk_indices[:, i] # [b*T]
# 批量获取专家输出
# 注意:每个专家处理所有样本,这里用循环专家,效率较低
# 实际生产环境应优化为组卷积或使用 einsum此处保持清晰
for expert_id in range(self.num_experts):
mask = idx == expert_id
if mask.any():
# 取出属于该专家的样本
sub_input = attn_flat[mask] # [k, h]
sub_output = self.experts[expert_id](sub_input) # [k, h]
expert_out_flat[mask] += weight[mask] * sub_output
# 分类头
logits_flat = self.classifier(expert_out_flat) # [b*T, output_vocab_size]
logits = logits_flat.view(b, T, -1) # [b, T, output_vocab_size]
return logits
def forward(
self,
input_ids,
token_type_ids,
attention_mask,
slot_target_ids=None,
slot_target_mask=None,
mode="train",
):
"""
统一接口
mode='train': 使用 teacher forcing 并行计算所有步 logits返回 [b, T, vocab]
mode='infer': 需结合外部循环使用 forward_single_step
"""
if mode == "train":
return self.forward_train(
input_ids,
token_type_ids,
attention_mask,
slot_target_ids,
slot_target_mask,
)
else:
raise NotImplementedError(
"Inference mode should call forward_single_step directly"
)

1181
uv.lock Normal file

File diff suppressed because it is too large Load Diff