feat: update the input method model architecture design doc and refactor core component code
This commit is contained in:
parent fd49058764
commit 1af85a36bc
@@ -174,4 +174,6 @@ cython_debug/

# Built Visual Studio Code Extensions
*.vsix

uv.lock
uv.lock

data/*
225 README.md
@@ -1,95 +1,166 @@

# SUimeModelTraner

# Input Method Prediction Model: Architecture Design

> Technical design for a deep-learning input method engine

## 1. Overview

This project builds a lightweight, high-accuracy Chinese input method prediction model. The core design idea is to deeply fuse the current context (text before and after the cursor plus pinyin) with the user's historical input habits through **structured slot memory** and a **cross-attention mechanism**. To retain strong expressive power under a limited compute budget, the model adds a **Mixture-of-Experts (MoE)** module.

## 1. Task Objective

Design a context-aware input method engine model. The inputs are the text before the cursor, the pinyin, the text after the cursor, and the input history; the output is a candidate sequence (a sequence of character IDs). The model supports beam search decoding for accurate pinyin-to-character conversion.
## 2. Core Architecture Flow

The data flow follows this path:

`Input Encoding` → `Transformer Context Encoding` → `Slot Memory Embedding` → `Cross-Attention Fusion` → `Gating + Mixture of Experts (MoE)` → `Classification` → `Beam Search Decoding`
## 2. Input Representation

- **Four text segments**:
  - Text before the cursor
  - Pinyin (e.g. "bangdao")
  - Text after the cursor
  - Qualifying history entries (e.g. "绑到|邦道|…")
- **Encoding**: a BERT-style tokenizer with a unified sequence length of **L=128** (or 88).
- **Segment separation**: segments are marked via `token_type_ids` with values **0, 1, 2, 3** (one per input type).
- **Pinyin handling**: pinyin is treated as plain text for now; a dedicated embedding interface is reserved for later optimization.

### 2.1 Input Layer Design

The model takes three kinds of input, processed separately to keep their semantics distinct (see the sketch after this list):

1. **Current text context**: the text before the cursor (Prefix) and after the cursor (Suffix).
2. **Pinyin sequence**: the pinyin matching the current text, fused into the text encoding as an auxiliary feature.
3. **History slot sequence**: the most recent N history words, fed in as structured memory.
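A minimal sketch of the four-segment encoding, assuming a HuggingFace-style BERT tokenizer (`bert-base-chinese` is a stand-in here; the repository loads its own tokenizer from packaged assets, and the full version is `tokenize_with_four_seg` in the dataset code below):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")  # stand-in tokenizer

def encode_four_segments(prefix, pinyin, suffix, history, max_len=128):
    input_ids = [tokenizer.cls_token_id]
    token_type_ids = [0]
    for seg_idx, part in enumerate([prefix, pinyin, suffix, history]):
        ids = tokenizer(part, add_special_tokens=False)["input_ids"]
        input_ids += ids + [tokenizer.sep_token_id]
        token_type_ids += [seg_idx] * (len(ids) + 1)  # segment ids 0..3; SEP inherits
    input_ids = input_ids[:max_len]
    token_type_ids = token_type_ids[:max_len]
    pad = max_len - len(input_ids)
    attention_mask = [1] * len(input_ids) + [0] * pad
    input_ids += [tokenizer.pad_token_id] * pad
    token_type_ids += [0] * pad
    return input_ids, token_type_ids, attention_mask
```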
## 3. Model Architecture Overview

Overall structure: input encoding → Transformer encoding → slot memory → cross-attention → gating + mixture of experts → classification head → beam search decoding.

### 2.2 Module Details

![Model architecture diagram]

#### A. Transformer Encoder (Context Encoder)

Extracts a deep semantic representation of the current context.

* **Input processing**: Prefix, Suffix, and pinyin are mapped through an Embedding layer. Pinyin is fused either by **feature addition** or as **separate tokens**, avoiding the complexity of a two-stream architecture.
* **Backbone**: a standard Transformer Encoder.
  * **Hidden size**: 512 [1]
  * **Transformer layers**: 4 (lightweight design, trained from scratch) [1]
  * **Attention heads**: 4 [1]
* **Output**: the context representation $H$ with shape `[batch, L, 512]` [1].
## 4. Core Module Design

#### B. Slot Memory Module

Turns unstructured input history into structured memory vectors.

* **Embedding**: history words are mapped through a dedicated `Slot Embedding` lookup table.
* **Positional encoding**: a learnable `Positional Embedding` is added to preserve the temporal order of past inputs.
* **Output**: the slot sequence $S$ with shape `[batch, Num_Slots, 512]`.

### 4.1 Embedding Layer and Transformer Encoder

- **Embedding dimension**: 512
- **Transformer layers**: 4
- **Multi-head attention heads**: 4 or 8
- **Output**: context representation `H` (shape: `[batch, L, 512]`)
- **Pretraining**: optionally load a lightweight structBERT backbone (link is dead; currently trained from scratch).

#### C. Cross-Attention Fusion

This is the model's core innovation: it dynamically relates "historical memory" to the "current context".

* **Query (Q)**: the current step's slot sequence $S$ (after positional encoding).
* **Key/Value (K/V)**: the context representation $H$ output by the Transformer encoder [1].
* **Mechanism**: the history slots actively attend to the current text context, capturing logic such as "in the context '班级第一名', '王次香' is more relevant than '王慈祥'".
* **Output**: the fused feature sequence, shape `[batch, Num_Slots, 512]`.

### 4.2 Slot Memory Module

- **Slot structure**: 8 slots, each holding up to 3 steps, for at most 24 steps in total.
- **Slot embedding**: the character ID chosen at each step is mapped to a 512-dim vector through a shared embedding layer.
- **History slot representation**: the embeddings of all slots before the current step are **concatenated in order**, forming a dynamically growing sequence (see the sketch below).
- **Positional encoding**: learnable position embeddings are added to the concatenated slot sequence to help the model capture temporal relations.
- **Initial state**: at the first step the history is empty and a special `[START]` embedding vector is used as a placeholder.
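A minimal sketch of the history-slot concatenation, with assumed toy shapes (the full version is the `SlotMemory` module further below; the explicit `[START]` prepend here is illustrative):

```python
import torch
import torch.nn as nn

dim, total_steps, vocab = 512, 24, 10019          # 8 slots x 3 steps
emb = nn.Embedding(vocab, dim)                    # shared slot embedding table
pos_emb = nn.Embedding(total_steps, dim)          # learnable positions
start_emb = nn.Parameter(torch.randn(1, 1, dim))  # [START] placeholder for step 0

history_ids = torch.tensor([[312, 57, 4081]])     # character IDs chosen so far
S = emb(history_ids)                              # [1, 3, 512]
S = torch.cat([start_emb, S], dim=1)              # prepend [START] -> [1, 4, 512]
S = S + pos_emb(torch.arange(S.size(1)))          # add temporal position encoding
```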
#### D. Gating and Mixture of Experts (Gating + MoE)

Empirically, removing the MoE caused a significant performance drop, so this module is essential for capturing complex distributions.

* **Number of experts**: 20 [1].
* **Gating**: the gate dynamically activates a subset of experts per input (sparse activation), adding model capacity while keeping compute in check.
* **Output**: the feature vector enhanced by the expert network.

### 4.3 Cross-Attention and Attention Pooling

- **Query**: the current step's slot sequence (after positional encoding)
- **Key/Value**: the Transformer encoder output `H`
- **Output**: the slot-conditioned feature sequence
- **Attention pooling**: the cross-attention output sequence is pooled (mean pooling or learnable attention pooling) into a fixed-length feature vector `f` (dimension 512).
#### E. Classification Head and Decoding

* **Classification**: the MoE output feature vector is projected to the vocabulary space through a fully connected layer, yielding the probability distribution of the next character/word.
* **Decoding**: inference uses **beam search** with a beam width of 5 [1].

### 4.4 Gating Network and Expert Layer (MoE)

- **Gating network**: takes `f` as input, outputs weights over the 20 experts, and selects the **top-3** (see the sketch below).
- **Expert structure**: each expert is a residual network with output dimension **512** (matching the hidden size).
- **Expert combination**: the outputs of the 3 selected experts are combined by a weighted sum into the fused feature `e`.
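A minimal functional sketch of the top-3 gating described above (the linear experts here are stand-ins; the repository's `Expert` is a residual network with matching input/output dimension):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

dim, num_experts, top_k = 512, 20, 3
gate = nn.Linear(dim, num_experts)
experts = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_experts))  # stand-ins

f = torch.randn(4, dim)                            # pooled feature f: [batch, 512]
scores, idx = torch.topk(gate(f), top_k, dim=-1)   # pick the top-3 experts per sample
weights = F.softmax(scores, dim=-1)                # renormalize over the selected experts
e = torch.zeros_like(f)
for k in range(top_k):                             # weighted sum of the 3 expert outputs
    for b in range(f.size(0)):
        e[b] += weights[b, k] * experts[int(idx[b, k])](f[b])
```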
## 3. Key Hyperparameter Configuration

### 4.5 Classification Head

- **Design**: a same-dimension FFN (see the module sketch below), structured as
  `Linear(512, 512, bias=False) → LayerNorm → GELU → Linear(512, 10019)`
  The output is 10019-dimensional (ID range 0-10018, where 0 is the end-of-sequence token).
- **Advantages**: preserves the feature dimension, adds non-linearity, and keeps the parameter count moderate.
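As a PyTorch module this head is simply (a sketch, with the 10019-way output specified above):

```python
import torch.nn as nn

classifier = nn.Sequential(
    nn.Linear(512, 512, bias=False),  # same-dimension projection
    nn.LayerNorm(512),
    nn.GELU(),
    nn.Linear(512, 10019),            # logits over 10019 character IDs (0 = end token)
)
```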
To balance model performance and efficiency, the following hyperparameters are recommended [1]:
## 5. Decoding Strategy: Beam Search

- **Search scope**: beam search runs within each slot step, with beam width **k** (default 3), as sketched below.
- **Candidate bookkeeping**: each candidate path maintains its own history slot sequence (concatenated embeddings) and cumulative probability.
- **Termination conditions**:
  1. All slots are filled (8×3 = 24 steps);
  2. If the highest-probability token of every candidate branch at the current step is **0 (the end token)**, decoding is forced to stop.
- **Output**: the complete slot sequence with the highest probability.
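A minimal beam search sketch over slot steps, assuming a hypothetical `step_logits(history_ids)` callback that returns next-token logits for one candidate (the real model also conditions on the encoded context and handles empty history via the `[START]` placeholder):

```python
import torch
import torch.nn.functional as F

def beam_search(step_logits, beam_width=3, max_steps=24, end_id=0):
    # Each beam: (history token ids, cumulative log-probability).
    beams = [([], 0.0)]
    for _ in range(max_steps):
        candidates = []
        for hist, score in beams:
            logp = F.log_softmax(step_logits(hist), dim=-1)   # [vocab]
            topv, topi = torch.topk(logp, beam_width)
            for v, i in zip(topv.tolist(), topi.tolist()):
                candidates.append((hist + [i], score + v))
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
        # Forced stop: every surviving branch just chose the end token 0.
        if all(hist[-1] == end_id for hist, _ in beams):
            break
    return beams[0][0]
```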
| Parameter | Recommended | Notes |
| :--- | :--- | :--- |
| **Sequence length (L)** | 128 | Context window size [1] |
| **Hidden size** | 512 | Embedding and Transformer internal dimension [1] |
| **Transformer layers** | 4 | Lightweight backbone, lower latency [1] |
| **Attention heads** | 4 | Efficient configuration for the 512-dim model [1] |
| **Number of experts** | 20 | Total experts in the MoE layer; critical for performance [1] |
| **Beam width** | 5 | Balances speed and accuracy at inference [1] |
| **Learning rate** | 1e-4 ~ 5e-4 | Recommended together with warmup [1] |
## 6. Training Setup

- **Optimizer**: AdamW
- **Loss**: per-step CrossEntropyLoss (computed only at non-padding positions)
- **Training data**: real user input logs, turned into (context, pinyin, target slot sequence) triples
- **Labels**: the ground-truth character ID at each slot step serves as the supervision signal

## 4. Training Strategy

## 7. Key Design Considerations and Potential Risks

The model follows a standard **sequence-to-sequence (Seq2Seq) supervised learning** paradigm, predicting the target slot sequence step by step.
| Design point | Rationale | Risk / optimization |
|--------|------|----------------|
| Pinyin encoding | Treated as plain text for now; simple to implement | Alignment may be hard for multi-character words; add a dedicated pinyin embedding layer later |
| Slot updates | Concatenation + positional encoding keeps the full history | Sequence length grows dynamically (up to 24); cross-attention cost stays manageable |
| Expert layer | 20 experts with top-3 combination improve feature selectivity | The gating network must stay load-balanced to avoid "dead" experts |
| Classification head | Same-dimension FFN balances capacity and non-linearity | If it overfits, reduce the dimension or add dropout |
| Beam search | In-slot beam search, moderate engineering complexity | Each candidate path keeps its own history state; watch memory usage |
### 4.1 Data Construction and Labels

* **Input triples**: training data consists of `(context, pinyin, target slot sequence)` triples [1].
  * **Context**: text fragments before and after the cursor.
  * **Pinyin**: the pinyin sequence of the characters to be typed.
  * **Target slot sequence**: the character-ID sequence the user actually typed, used as the supervision signal [1].
* **Label handling**: at every slot step, the model must predict that step's ground-truth character ID [1].
## 8. Future Optimization Directions

- **Pinyin enhancement**: introduce dedicated pinyin embeddings (syllable-level encoding)
- **Stronger slot modeling**: if concatenation struggles with long-range dependencies, replace it with a lightweight GRU (not adopted for now)
- **Pretraining**: try initializing the encoder from open-source Chinese BERT weights
- **Knowledge distillation**: if the model is too large, distill it into a smaller version for on-device deployment
### 4.2 Loss Function and Optimization

* **Loss function**: **CrossEntropyLoss** between each step's prediction and the ground-truth label [1] (see the sketch below).
* **Masking**: the loss is computed only at non-padding positions; invalid time steps are ignored [1].
* **Optimizer**: **AdamW** for parameter updates [1].
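A minimal sketch of the masked per-step loss with assumed toy shapes (invalid steps are marked with `-1`, matching `ignore_index=-1` in the training code later in this commit):

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss(ignore_index=-1)        # -1 marks padded slot steps

logits = torch.randn(4, 24, 10019)                      # [batch, steps, vocab]
labels = torch.full((4, 24), -1, dtype=torch.long)      # start fully padded ...
labels[:, :3] = torch.randint(1, 10019, (4, 3))         # ... then fill 3 real steps

loss = criterion(logits.view(-1, 10019), labels.view(-1))
```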
---

### 4.3 Training Flow Details

1. **Forward pass**:
   * The model encodes the context and pinyin with the Transformer to obtain the contextual representation.
   * It fuses the historical slot memory through the cross-attention and MoE modules.
   * The classification head outputs the probability distribution over candidate characters for the current step.
2. **Teacher Forcing**:
   * During training, the **ground-truth output of the previous slot is forced** as the conditioning input for the next step. The model always predicts from the "correct history", which speeds up convergence (see the sketch below).
3. **Backward pass**:
   * Gradients are computed from the CrossEntropyLoss [1] and the weights are updated with AdamW [1].
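A sketch of teacher forcing as a step loop, with a hypothetical `model_step` callback (in practice all steps are computed in parallel, as in `forward_train` later in this commit):

```python
import torch
import torch.nn.functional as F

def train_step(model_step, target_ids):
    """target_ids: [batch, steps] ground-truth slot sequence."""
    losses = []
    for t in range(target_ids.size(1)):
        history = target_ids[:, :t]          # condition on ground truth, not model output
        logits = model_step(history)         # [batch, vocab]
        losses.append(F.cross_entropy(logits, target_ids[:, t]))
    return torch.stack(losses).mean()
```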
## Appendix: Key Hyperparameters (TBD)

- Sequence length: 128
- Hidden size: 512
- Transformer layers: 4
- Attention heads: 4
- Number of experts: 20
- Beam width: 5
- Learning rate: to be tuned (suggest 1e-4 ~ 5e-4, with warmup)
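These values can be collected into a single config object; a sketch follows (field names are illustrative, not the repository's API):

```python
from dataclasses import dataclass

@dataclass
class ModelConfig:
    seq_len: int = 128
    dim: int = 512
    n_layers: int = 4
    n_heads: int = 4
    num_experts: int = 20
    beam_width: int = 5
    lr: float = 1e-4        # tune within 1e-4 ~ 5e-4, with warmup
```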
### 4.4 Differences Between Training and Inference

* **Training**: ground-truth labels feed the slot input, so the model learns the optimal conditional distribution.
* **Inference**: with no ground truth available, the model uses **beam search** [1].
  * **Beam width**: 5 by default [1].
  * **Candidate bookkeeping**: each candidate path keeps its own history slot sequence and cumulative probability [1].
  * **Termination**: decoding stops when all slots are filled (e.g. 8×3 = 24 steps) or when the top token of every candidate branch is the end token [1].
## 5. Implementation Sketch (PyTorch)
```python
import torch
import torch.nn as nn


class Expert(nn.Module):
    def __init__(self, dim=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, x):
        return self.net(x)


class InputMethodModel(nn.Module):
    def __init__(self, vocab_size, pinyin_vocab_size, slot_vocab_size, dim=512, n_layers=4, n_heads=4, num_experts=20):
        super().__init__()
        # 1. Context Encoder
        self.text_emb = nn.Embedding(vocab_size, dim)
        self.pinyin_emb = nn.Embedding(pinyin_vocab_size, dim)
        self.pos_emb = nn.Embedding(128, dim)
        # batch_first=True so inputs/outputs are [B, L, D]
        encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=n_heads, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        # 2. Slot Memory
        self.slot_emb = nn.Embedding(slot_vocab_size, dim)
        self.slot_pos_emb = nn.Embedding(5, dim)  # assume 5 history slots are kept

        # 3. Cross-Attention
        self.cross_attn = nn.MultiheadAttention(embed_dim=dim, num_heads=n_heads, batch_first=True)

        # 4. MoE Layer
        self.num_experts = num_experts
        self.experts = nn.ModuleList([Expert(dim) for _ in range(num_experts)])
        self.gate = nn.Linear(dim, num_experts)

        # 5. Classification Head
        self.classifier = nn.Linear(dim, vocab_size)

    def forward(self, text_ids, pinyin_ids, history_slot_ids):
        # Encode Context
        x = self.text_emb(text_ids) + self.pinyin_emb(pinyin_ids)
        x = x + self.pos_emb(torch.arange(x.size(1)).to(x.device))
        H = self.transformer(x)  # [B, L, 512]

        # Encode Slots
        S = self.slot_emb(history_slot_ids)
        S = S + self.slot_pos_emb(torch.arange(S.size(1)).to(S.device))

        # Cross-Attention: Q=Slots, K/V=Context
        fused, _ = self.cross_attn(S, H, H)  # [B, Slots, 512]

        # MoE Processing
        # Simplified MoE: weighted average over all expert outputs
        gate_scores = torch.softmax(self.gate(fused), dim=-1)  # [B, Slots, Num_Experts]
        expert_outputs = torch.stack([expert(fused) for expert in self.experts], dim=-2)  # [B, Slots, Num_Experts, Dim]
        moe_out = torch.sum(gate_scores.unsqueeze(-1) * expert_outputs, dim=-2)  # [B, Slots, Dim]

        # Pooling & Predict
        pooled = moe_out.mean(dim=1)  # [B, 512]
        logits = self.classifier(pooled)
        return logits
```
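A quick smoke test of the sketch above (toy inputs; `history_slot_ids` uses 5 slots to match `slot_pos_emb`):

```python
model = InputMethodModel(vocab_size=10019, pinyin_vocab_size=28, slot_vocab_size=10019)
text_ids = torch.randint(0, 10019, (2, 128))
pinyin_ids = torch.randint(0, 28, (2, 128))
history_slot_ids = torch.randint(0, 10019, (2, 5))

logits = model(text_ids, pinyin_ids, history_slot_ids)
print(logits.shape)  # torch.Size([2, 10019])
```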
## 6. Summary

This design combines a **single-stream Transformer encoder** with **structured slot cross-attention** and a **20-expert MoE module** [1]. While staying lightweight (a 4-layer Transformer), it effectively exploits historical input habits and raises the model's expressive ceiling. Compared to brute-force concatenation or a two-stream architecture, the design is cleaner to engineer and more efficient at inference: a locally optimal solution for a lightweight input method model.
@@ -1,30 +0,0 @@

{
  "add_cross_attention": false,
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": null,
  "classifier_dropout": null,
  "directionality": "bidi",
  "eos_token_id": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 512,
  "initializer_range": 0.02,
  "intermediate_size": 2048,
  "is_decoder": false,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 8,
  "num_hidden_layers": 6,
  "pad_token_id": 0,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "tie_word_embeddings": true,
  "transformers_version": "5.1.0",
  "type_vocab_size": 4,
  "use_cache": true,
  "vocab_size": 21128
}
@@ -1,9 +1,7 @@

import torch
import torch.amp as amp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from loguru import logger


# ---------------------------- Attention pooling module ----------------------------
@@ -80,3 +78,270 @@ class Expert(nn.Module):

        for block in self.res_blocks:
            x = block(x)
        return self.output(x)


# ------------------------------------------------------------------
# 1. Context Encoder
# Matches README 4.1: 4-layer Transformer, 512-dim, outputs H [1]
# ------------------------------------------------------------------
class ContextEncoder(nn.Module):
    def __init__(
        self, vocab_size, pinyin_vocab_size, dim=512, n_layers=4, n_heads=4, max_len=128
    ):
        super().__init__()
        self.dim = dim

        # Embeddings
        self.text_emb = nn.Embedding(vocab_size, dim)
        self.pinyin_emb = nn.Embedding(pinyin_vocab_size, dim)
        self.pos_emb = nn.Embedding(max_len, dim)

        # Transformer Encoder (4 layers, 4 heads) [1]
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=n_heads,
            dim_feedforward=dim * 4,
            dropout=0.1,
            batch_first=True,  # convenient for [B, L, D] tensors
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        # LayerNorm for stability
        self.ln = nn.LayerNorm(dim)

    def forward(self, text_ids, pinyin_ids, mask=None):
        """
        Args:
            text_ids: [batch, seq_len]
            pinyin_ids: [batch, seq_len] (assumed aligned; preprocess if not)
            mask: [batch, seq_len] optional padding mask
        Returns:
            H: [batch, seq_len, 512] Context representation [1]
        """
        # 1. Embedding fusion: text + pinyin + position
        # Strategy: pinyin is added onto the text embedding as an auxiliary
        # feature, in line with the lightweight design
        x = self.text_emb(text_ids) + self.pinyin_emb(pinyin_ids)

        seq_len = x.size(1)
        pos_ids = (
            torch.arange(seq_len, device=x.device).unsqueeze(0).expand_as(text_ids)
        )
        x += self.pos_emb(pos_ids)

        # 2. Transformer encoding
        # src_key_padding_mask expects True for padding positions
        if mask is not None:
            # Convert 0/1 mask to bool mask where True is padding
            src_mask = mask == 0
        else:
            src_mask = None

        H = self.transformer(x, src_key_padding_mask=src_mask)
        return self.ln(H)
# ------------------------------------------------------------------
# 2. Slot Memory
# Matches README 4.2: 8 slots, 3 steps per slot, concatenation + positional encoding [1]
# ------------------------------------------------------------------
class SlotMemory(nn.Module):
    def __init__(self, vocab_size, max_slots=8, steps_per_slot=3, dim=512):
        super().__init__()
        self.max_slots = max_slots
        self.steps_per_slot = steps_per_slot
        self.total_steps = max_slots * steps_per_slot  # 24 steps [1]

        # Shared embedding layer for history tokens [1]
        self.emb = nn.Embedding(vocab_size, dim)

        # Learnable positional embeddings for the flattened sequence [1]
        self.pos_emb = nn.Embedding(self.total_steps, dim)

        # Start token embedding for empty slots [1]
        self.start_emb = nn.Parameter(torch.randn(1, 1, dim))

    def forward(self, history_ids):
        """
        Args:
            history_ids: [batch, total_steps]
                Flattened sequence of history tokens.
                Empty positions should be filled with a special PAD or handled via mask.
        Returns:
            S: [batch, total_steps, 512] Slot sequence representation [1]
        """
        # Embed history tokens
        S = self.emb(history_ids)  # [B, 24, 512]

        # Add positional embeddings
        pos_ids = (
            torch.arange(S.size(1), device=S.device).unsqueeze(0).expand_as(history_ids)
        )
        S += self.pos_emb(pos_ids)

        return S
# ------------------------------------------------------------------
# 3. Cross-Attention Fusion
# Matches README: Query=Slots, Key/Value=Context H [1]
# ------------------------------------------------------------------
class CrossAttentionFusion(nn.Module):
    def __init__(self, dim=512, n_heads=4):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        assert self.head_dim * n_heads == dim, "dim must be divisible by n_heads"

        # Linear projections for Q, K, V
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)
        self.ln = nn.LayerNorm(dim)

    def forward(self, slots_S, context_H, slot_mask=None, context_mask=None):
        """
        Args:
            slots_S: [batch, num_slots_steps, dim] Query
            context_H: [batch, ctx_len, dim] Key/Value
            slot_mask: [batch, num_slots_steps] Optional (not used in scaled_dot_product_attention)
            context_mask: [batch, ctx_len] Optional padding mask
        Returns:
            Fused: [batch, num_slots_steps, dim]
        """
        batch_size, num_slots, _ = slots_S.shape
        _, ctx_len, _ = context_H.shape

        # Project queries, keys, values
        Q = self.q_proj(slots_S)  # [batch, num_slots, dim]
        K = self.k_proj(context_H)  # [batch, ctx_len, dim]
        V = self.v_proj(context_H)  # [batch, ctx_len, dim]

        # Reshape for multi-head attention: [batch, seq_len, n_heads, head_dim] -> [batch, n_heads, seq_len, head_dim]
        Q = Q.view(batch_size, num_slots, self.n_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, ctx_len, self.n_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, ctx_len, self.n_heads, self.head_dim).transpose(1, 2)

        # Prepare attention mask if context_mask is provided
        attn_mask = None
        if context_mask is not None:
            # context_mask: [batch, ctx_len] where 0 means padding
            # Convert to bool mask and reshape for broadcasting
            bool_mask = context_mask == 0  # [batch, ctx_len]
            bool_mask = bool_mask[:, None, None, :]  # [batch, 1, 1, ctx_len]
            # Convert to float mask where True (padding) becomes -inf
            attn_mask = bool_mask.float().masked_fill(bool_mask, float("-inf"))

        # Scaled dot-product attention
        attn_output = F.scaled_dot_product_attention(
            Q,
            K,
            V,
            attn_mask=attn_mask,
            dropout_p=0.0,  # no dropout
        )

        # Reshape back: [batch, n_heads, num_slots, head_dim] -> [batch, num_slots, dim]
        attn_output = (
            attn_output.transpose(1, 2)
            .contiguous()
            .view(batch_size, num_slots, self.dim)
        )

        # Project back
        fused = self.out_proj(attn_output)

        # Residual connection and layer norm
        fused = self.ln(fused + slots_S)

        return fused
# ------------------------------------------------------------------
# 4. Mixture-of-Experts layer (MoE Layer)
# Matches README: 20 experts [1]; uses the Expert class from components.py
# ------------------------------------------------------------------
class MoELayer(nn.Module):
    def __init__(self, dim=512, num_experts=20, top_k=2, export_resblocks=4):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.dim = dim

        # Import Expert from your existing components
        # Assuming Expert class is defined as in components.py [2]
        self.experts = nn.ModuleList(
            [
                Expert(
                    input_dim=dim,
                    d_model=dim,
                    num_resblocks=export_resblocks,
                    output_multiplier=1,
                )
                for _ in range(num_experts)
            ]
        )

        # Gating Network [2]
        self.gate = nn.Linear(dim, num_experts)

    def forward(self, x):
        """
        Args:
            x: [batch, seq_len, dim]
        Returns:
            out: [batch, seq_len, dim]
        """
        B, L, D = x.shape

        # 1. Compute gating scores
        gates = self.gate(x)  # [B, L, num_experts]

        # 2. Select top-k experts
        topk_vals, topk_indices = torch.topk(gates, self.top_k, dim=-1)  # [B, L, K]

        # Normalize weights for selected experts
        weights = F.softmax(topk_vals, dim=-1)  # [B, L, K]

        # 3. Dispatch and compute
        # Initialize output
        out = torch.zeros_like(x)

        # Reshape for easier processing: flatten batch and sequence dimensions
        x_flat = x.view(-1, D)  # [B*L, D]
        weights_flat = weights.view(-1, self.top_k)  # [B*L, K]
        topk_indices_flat = topk_indices.view(-1, self.top_k)  # [B*L, K]

        # For each of the top-k positions
        for k in range(self.top_k):
            # Get expert indices and weights for this position
            expert_indices = topk_indices_flat[:, k]  # [B*L]
            expert_weights = weights_flat[:, k].unsqueeze(-1)  # [B*L, 1]

            # Process each expert separately
            for e_idx in range(self.num_experts):
                # Mask for tokens assigned to this expert at position k
                mask = expert_indices == e_idx  # [B*L]
                if not mask.any():
                    continue

                # Extract tokens for this expert
                x_selected = x_flat[mask]  # [N_selected, D]
                if x_selected.numel() == 0:
                    continue

                # Pass through expert
                expert_out = self.experts[e_idx](x_selected)  # [N_selected, D]

                # Apply expert weights and add to output
                weighted_out = expert_out * expert_weights[mask]

                # Scatter back to flat output
                out_flat = out.view(-1, D)
                out_flat[mask] += weighted_out

        # Reshape back to original shape
        out = out.view(B, L, D)

        return out
@@ -1,13 +1,13 @@

import random
from typing import Any, Dict, List, Optional, Tuple
from importlib.resources import files
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from datasets import load_dataset
from loguru import logger
from modelscope import AutoModel, AutoTokenizer
from modelscope import AutoTokenizer
from pypinyin import Style, lazy_pinyin
from pypinyin.contrib.tone_convert import to_initials
from torch.utils.data import DataLoader, IterableDataset
@@ -20,7 +20,8 @@ class PinyinInputDataset(IterableDataset):

        self,
        data_path: str,
        max_workes: int = -1,
        max_length=128,
        max_iter_length=1e6,
        max_seq_length=128,
        text_field: str = "text",
        py_style_weight=(9, 2, 1),
        shuffle_buffer_size: int = 5000,
@@ -37,7 +38,9 @@ class PinyinInputDataset(IterableDataset):

            Path(files(__package__) / "assets" / "tokenizer")
        )
        self.data_path = data_path
        self.max_length = max_length

        self.max_iter_length = max_iter_length
        self.max_seq_length = max_seq_length
        self.text_field = text_field
        self.dataset = load_dataset(data_path, split="train", streaming=True)
        self.max_workers = max_workes
@@ -48,8 +51,6 @@ class PinyinInputDataset(IterableDataset):

        self.query_engine = QueryEngine()
        self.query_engine.load()
        self.shuffle_buffer_size = shuffle_buffer_size
        self.buffer = []

        # Extract each sample's target characters and their frequencies
        self.sample_freqs = self.query_engine.get_all_weights()
@@ -95,76 +96,14 @@ class PinyinInputDataset(IterableDataset):

        else:
            return 1

    def tokenize_with_four_seg(self, parts: List[str]) -> Dict[str, Any]:
        input_ids = []
        token_type_ids = []

        # Add [CLS] (type 0)
        cls_id = self.tokenizer.cls_token_id
        input_ids.append(cls_id)
        token_type_ids.append(0)

        for seg_idx, part in enumerate(parts):
            if not part:
                continue

            # Tokenize a single part without special tokens.
            # Note: do not truncate here; truncate at the end so that
            # higher-priority segments (e.g. part1) stay intact.
            encoded_part = self.tokenizer(
                part, add_special_tokens=False, truncation=False
            )

            part_ids = encoded_part["input_ids"]

            # If adding this part would exceed MAX_LEN - 1 (reserve one slot
            # for the final SEP or truncation), truncate the current part.
            remaining_space = (
                self.max_length - len(input_ids) - 1
            )  # -1 for final SEP or safety
            if len(part_ids) > remaining_space:
                part_ids = part_ids[:remaining_space]

            if not part_ids:
                continue

            input_ids.extend(part_ids)
            # The current segment's type_id is seg_idx (0, 1, 2, 3)
            token_type_ids.extend([seg_idx] * len(part_ids))

            # Add [SEP] (its type follows the current segment)
            sep_id = self.tokenizer.sep_token_id
            input_ids.append(sep_id)
            token_type_ids.append(seg_idx)

            # Early exit once the maximum length is reached
            if len(input_ids) >= self.max_length:
                break

        # 4. Padding or final truncation
        if len(input_ids) > self.max_length:
            input_ids = input_ids[: self.max_length]
            token_type_ids = token_type_ids[: self.max_length]
        else:
            pad_len = self.max_length - len(input_ids)
            input_ids += [self.tokenizer.pad_token_id] * pad_len
            token_type_ids += [0] * pad_len  # padding type is usually 0

        # 5. Build the attention mask
        attention_mask = [
            1 if i != self.tokenizer.pad_token_id else 0 for i in input_ids
        ]

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
        }

    # Generate the pinyin for a given text
    def generate_pinyin(self, text: str) -> List[List[str]]:
    def generate_pinyin(self, text: str) -> List[str]:
        return lazy_pinyin(text, errors=lambda x: [c for c in x])

    # Generate the pinyin of the characters to predict, with augmentation
    def get_mask_pinyin(self, text: str, pinyin_list: List[str]) -> (int, List[str]):
    def get_mask_pinyin(
        self, text: str, pinyin_list: List[str]
    ) -> Tuple[int, List[str]]:
        mask_pinyin = []
        for i in range(len(text)):
            if not self.query_engine.is_chinese_char(text[i]):
@@ -193,14 +132,41 @@ class PinyinInputDataset(IterableDataset):

            self.dataset = self.dataset.shard(num_shards=num_workers, index=worker_id)

            # Compute each worker's quota
            # Convert max_iter_length to int to ensure integer division
            total_quota = int(self.max_iter_length)
            base_quota = total_quota // num_workers
            remainder = total_quota % num_workers

            # The last worker also handles the remainder samples (if any)
            if worker_id == num_workers - 1:
                worker_quota = base_quota + remainder
            else:
                worker_quota = base_quota
        else:
            # Single-worker case: use the full quota
            worker_quota = int(self.max_iter_length)
            num_workers = 1

        # Each worker keeps its own iteration counter
        current_iter_index = 0

        batch_samples = []
        for sample in self.dataset:
            # Stop once the maximum number of iterations is reached
            if current_iter_index >= worker_quota:
                break

            text = sample.get(self.text_field, "")
            if text:
                pinyin_list = self.generate_pinyin(text)
                for i in range(len(text)):
                    labels = []
                    # Skip if text[i] is not in the character inventory
                    # Check the quota before processing each character
                    if current_iter_index >= worker_quota:
                        break

                    labels = []  # add the start marker
                    # Skip if text[i] is not in the character inventory
                    # If i < 48, part1 = text[0:i]
                    # If i >= 48, part1 = text[i-48:i]
                    if not self.query_engine.is_chinese_char(text[i]):
@@ -256,34 +222,66 @@ class PinyinInputDataset(IterableDataset):

                    labels = [
                        self.query_engine.get_char_info_by_char_pinyin(c, p).id
                        for c, p in zip(text[i:i + pinyin_len], pinyin_list[i:i + pinyin_len])
                        for c, p in zip(
                            text[i : i + pinyin_len],
                            pinyin_list[i : i + pinyin_len],
                        )
                    ]
                    labels.append(0)

                    encoded = self.tokenize_with_four_seg(
                        [
                            part1,
                            part2,
                            part3,
                            part4,
                        ]
                    encoded = self.tokenizer(
                        f"{part4}|{part1}",
                        part3,
                        max_length=self.max_seq_length,
                        padding="max_length",
                        truncation=True,
                        return_tensors="pt",
                        return_token_type_ids=True,
                    )
                    repeats = self.adjust_frequency(
                        min([self.sample_freqs[i] for i in labels])
                    )
                    sample = {
                        "input_ids": encoded["input_ids"],
                        "token_type_ids": encoded["token_type_ids"],
                        "attention_mask": encoded["attention_mask"],
                        "labels": torch.tensor(labels, dtype=torch.long),
                        "part1": part1,
                        "part2": part2,
                        "part3": part3,
                        "part4": part4,
                    }
                    batch_samples.extend([sample] * repeats)
                    samples = []
                    for i, label in enumerate(labels):
                        repeats = self.adjust_frequency(label)
                        l = labels[:i]
                        len_l = len(l)
                        l.extend([0] * (8 - len_l))

                        samples.extend(
                            [
                                {
                                    "input_ids": encoded["input_ids"],
                                    "token_type_ids": encoded["token_type_ids"],
                                    "attention_mask": encoded["attention_mask"],
                                    "label": torch.tensor([label], dtype=torch.long),
                                    "history_slot_ids": torch.tensor(
                                        l, dtype=torch.long
                                    ),
                                    "prefix": f"{part4}^{part1}",
                                    "suffix": part3,
                                    "pinyin": part2,
                                }
                            ]
                            * repeats
                        )

                    # Append to the buffer
                    batch_samples.extend(samples)

                    # Flush the shuffle buffer
                    if len(batch_samples) >= self.shuffle_buffer_size:
                        indices = np.random.permutation(len(batch_samples))
                        self.buffer.extend([batch_samples[i] for i in indices])
                        for idx in indices:
                            if current_iter_index >= worker_quota:
                                # Clear batch_samples and return
                                batch_samples = []
                                return  # use return, not break: this is a generator function
                            yield batch_samples[idx]
                            current_iter_index += 1
                        batch_samples = []
                        yield from self.buffer

        # Yield the remaining samples
        if batch_samples:
            indices = np.random.permutation(len(batch_samples))
            for idx in indices:
                if current_iter_index >= worker_quota:
                    return
                yield batch_samples[idx]
                current_iter_index += 1
@@ -1,448 +1,134 @@

import math
from pathlib import Path
from typing import Optional, Union
from typing import Optional

import torch
import torch.nn as nn
import torch.amp as amp
import torch.optim as optim
import torch.nn.functional as F

from loguru import logger
from modelscope import AutoTokenizer
from tqdm.notebook import tqdm


from .components import AttentionPooling, Expert  # , ResidualBlock  # assumed implemented
# Import the components from components.py
from .components import (
    AttentionPooling,  # optional, unused for now
    ContextEncoder,
    CrossAttentionFusion,
    MoELayer,
    SlotMemory,
)
class InputMethodEngine(nn.Module):
    """
    Input method prediction engine.
    A lightweight input method prediction model following the README design,
    integrating context encoding, slot memory, cross-attention fusion,
    a Mixture-of-Experts (MoE) network, and classification.

    Inputs:
        input_ids (torch.Tensor): [batch_size, seq_len] text token ids.
        token_type_ids (torch.Tensor): [batch_size, seq_len] marks prefix/suffix and other segment types, used as a bias.
        attention_mask (torch.Tensor): [batch_size, seq_len] attention mask; 1 marks valid positions.
        history_slot_ids (torch.Tensor): [batch_size, num_slots] or [num_slots] history slot IDs.
            If given as [num_slots], a batch dimension is added automatically.

    Output:
        logits (torch.Tensor): [batch_size, vocab_size] distribution over the next character (pre-softmax).
    """
    def __init__(
        self,
        pretrained_encoder,  # a pretrained encoder, already loaded and extended
        output_vocab_size: int,
        hidden_size: int = 512,  # must match the pretrained model's hidden size
        max_slot_steps: int = 24,
        num_experts: int = 20,
        top_k: int = 3,
        expert_res_blocks: int = 4,
        dropout: float = 0.3,
        use_attention_pooling: bool = False,
        **kwargs,
        vocab_size: int = 10019,
        pinyin_vocab_size: int = 28,
        dim: int = 512,
        num_slots: int = 8,  # number of history slots (the 8 slots in the README)
        n_layers: int = 4,  # number of Transformer layers
        n_heads: int = 4,  # number of attention heads
        num_experts: int = 20,  # number of MoE experts
        max_seq_len: int = 128,  # maximum context length
        use_pinyin: bool = False,  # whether to use pinyin features (if False, the pinyin embedding is fixed at zero)
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_vocab_size = output_vocab_size
        self.max_slot_steps = max_slot_steps
        self.num_experts = num_experts
        self.top_k = top_k
        self.use_attention_pooling = use_attention_pooling
        self.dim = dim
        self.num_slots = num_slots
        self.use_pinyin = use_pinyin
        self.vocab_size = vocab_size

        # Pretrained encoder
        self.encoder = pretrained_encoder

        self.slot_embedding = nn.Embedding(output_vocab_size, hidden_size)
        self.slot_position_embedding = nn.Embedding(max_slot_steps + 1, hidden_size)

        self.cross_attention = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=4,  # configurable
            dropout=dropout,
            batch_first=True,
        # 1. Context encoder (ContextEncoder)
        # If use_pinyin=False, pass pinyin_vocab_size=1 and fix the embedding at zero
        self.context_encoder = ContextEncoder(
            vocab_size=vocab_size,
            pinyin_vocab_size=pinyin_vocab_size if use_pinyin else 1,
            dim=dim,
            n_layers=n_layers,
            n_heads=n_heads,
            max_len=max_seq_len,
        )

        if use_attention_pooling:
            self.attention_pooling = AttentionPooling(hidden_size)

        self.gate = nn.Linear(hidden_size, num_experts)
        self.experts = nn.ModuleList(
            [
                Expert(
                    input_dim=hidden_size,
                    d_model=hidden_size,
                    num_resblocks=expert_res_blocks,
                    output_multiplier=1,
                    dropout_prob=dropout,
                )
                for _ in range(num_experts)
            ]
        # 2. Slot memory (SlotMemory)
        # Adapted so the number of history slots is num_slots (one word per slot, not multiple steps)
        self.slot_memory = SlotMemory(
            vocab_size=vocab_size,
            max_slots=num_slots,
            steps_per_slot=1,  # one step per slot
            dim=dim,
        )

        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size, bias=False),
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, output_vocab_size),
        )
        # 3. Cross-attention fusion (CrossAttentionFusion)
        # Version implemented with F.scaled_dot_product_attention
        self.cross_attn = CrossAttentionFusion(dim=dim, n_heads=n_heads)

        self._init_weights()
        # 4. Mixture-of-Experts layer (MoE)
        self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=2)
    def encode_text(self, input_ids, token_type_ids, attention_mask):
        """
        Encode text with the pretrained encoder.
        Note: the pretrained model's output may contain both last_hidden_state and pooler_output.
        """
        outputs = self.encoder(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        # Take last_hidden_state [batch, seq_len, hidden]
        return outputs.last_hidden_state
        # 5. Classification head
        self.classifier = nn.Linear(dim, vocab_size)
    def forward_single_step(self, context, slot_seq_emb, slot_seq_mask=None):
        """
        Single-step prediction: given the current (concatenated) slot-sequence
        embeddings, predict the distribution over the next character.
        context: [batch, seq_len, hidden] encoded text
        slot_seq_emb: [batch, current_len, hidden] embeddings of the current slot sequence (concatenated)
        slot_seq_mask: [batch, current_len] valid-position mask (1 = valid)
        Returns: [batch, output_vocab_size] probability distribution
        """
        batch_size = slot_seq_emb.size(0)
        # Cross-attention: the Query is the slot sequence (usually only the last
        # step's embedding, but here we consider the whole sequence).
        # For simplicity, use the whole sequence as Query and take the last
        # position's output (autoregressive).
        # Option 1: Query = embedding of the last position (a single vector)
        last_query = slot_seq_emb[:, -1:, :]  # [batch, 1, hidden]
        # Cross-attention
        attn_out, _ = self.cross_attention(
            query=last_query,
            key=context,
            value=context,
            key_padding_mask=(
                context.sum(-1) == 0
            ),  # skip padded positions; should really use attention_mask
        )  # [batch, 1, hidden]
        attn_out = attn_out.squeeze(1)  # [batch, hidden]

        # Gating network: pick the top-k experts
        gate_logits = self.gate(attn_out)  # [batch, num_experts]
        topk_weights, topk_indices = torch.topk(
            F.softmax(gate_logits, dim=-1), self.top_k, dim=-1
        )
        # Normalize the weights
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

        # Weighted sum of the expert outputs
        expert_outputs = torch.zeros_like(attn_out)  # [batch, hidden]
        for i in range(self.top_k):
            expert_idx = topk_indices[:, i]  # [batch]
            weight = topk_weights[:, i].unsqueeze(1)  # [batch, 1]
            # Fetch expert outputs per sample (simple loop here;
            # could be optimized with gather)
            for b in range(batch_size):
                expert_out = self.experts[expert_idx[b]](attn_out[b : b + 1])
                expert_outputs[b] += weight[b] * expert_out.squeeze(0)

        # Classification head
        logits = self.classifier(expert_outputs)  # [batch, output_vocab_size]
        return logits

    def forward_train(
        self,
        input_ids,
        token_type_ids,
        attention_mask,
        slot_target_ids,
        slot_target_mask=None,
    ):
        """
        Training mode: compute predictions for all slot steps in parallel.
        slot_target_ids: [batch, max_slot_steps] ground-truth labels (teacher forcing)
        slot_target_mask: [batch, max_slot_steps] valid-step mask (1 = step exists)
        Returns: per-step logits [batch, max_slot_steps, output_vocab_size]
        """
        batch_size, max_steps = slot_target_ids.shape
        device = input_ids.device

        # 1. Encode the text
        context = self.encode_text(
            input_ids, token_type_ids, attention_mask
        )  # [b, L, h]

        # 2. Build slot-sequence embeddings (teacher forcing: built from ground truth)
        # Convert each sample's slot IDs to embeddings and add positional encoding
        # slot_target_ids has shape [b, T]
        slot_emb = self.slot_embedding(slot_target_ids)  # [b, T, h]
        # Positional encoding (positions start at 0)
        positions = torch.arange(max_steps, device=device).unsqueeze(0)  # [1, T]
        pos_emb = self.slot_position_embedding(positions)  # [1, T, h]
        slot_emb = slot_emb + pos_emb
        # Optional mask (zero out invalid positions)
        if slot_target_mask is not None:
            slot_emb = slot_emb * slot_target_mask.unsqueeze(-1)

        # 3. Cross-attention: the Query is the whole slot sequence (each position predicts independently)
        # Note: this is cross-attention with the slot sequence as Query, not self-attention.
        # Each position's Query interacts only with the text encoding, not with other slots,
        # so no causal mask is needed; MultiheadAttention with different query and key/value suffices.
        attn_out, _ = self.cross_attention(
            query=slot_emb,  # [b, T, h]
            key=context,  # [b, L, h]
            value=context,
            key_padding_mask=(attention_mask == 0),  # skip padded text positions
        )  # [b, T, h]

        # 4. Run each slot step through gating + experts + classification head
        # Parameters are shared across steps, so batch and steps can be merged
        b, T, h = attn_out.shape
        attn_flat = attn_out.view(b * T, h)  # [b*T, h]

        # Gating network
        gate_logits = self.gate(attn_flat)  # [b*T, num_experts]
        gate_probs = F.softmax(gate_logits, dim=-1)
        topk_probs, topk_indices = torch.topk(
            gate_probs, self.top_k, dim=-1
        )  # [b*T, top_k]
        topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)  # normalize

        # Expert outputs (full parallelization is complex; a loop keeps it clear,
        # production code could use scatter)
        expert_out_flat = torch.zeros_like(attn_flat)  # [b*T, h]
        for i in range(self.top_k):
            weight = topk_probs[:, i].unsqueeze(1)  # [b*T, 1]
            idx = topk_indices[:, i]  # [b*T]
            # Fetch expert outputs in batch
            # Note: looping over experts is inefficient; production code should
            # use grouped ops or einsum, kept simple here for clarity
            for expert_id in range(self.num_experts):
                mask = idx == expert_id
                if mask.any():
                    # Select the samples routed to this expert
                    sub_input = attn_flat[mask]  # [k, h]
                    sub_output = self.experts[expert_id](sub_input)  # [k, h]
                    expert_out_flat[mask] += weight[mask] * sub_output

        # Classification head
        logits_flat = self.classifier(expert_out_flat)  # [b*T, output_vocab_size]
        logits = logits_flat.view(b, T, -1)  # [b, T, output_vocab_size]

        return logits
        # Optional: if pinyin is unused, fix the pinyin embedding matrix at zero
        if not use_pinyin and hasattr(self.context_encoder, "pinyin_emb"):
            # Zero the pinyin embedding weights so they have no effect on the output
            nn.init.zeros_(self.context_encoder.pinyin_emb.weight)
    def forward(
        self,
        input_ids,
        token_type_ids,
        attention_mask,
        slot_target_ids=None,
        slot_target_mask=None,
        mode="train",
    ):
        input_ids: torch.Tensor,
        token_type_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        history_slot_ids: torch.Tensor,
    ) -> torch.Tensor:
        """
        Unified interface.
        mode='train': teacher forcing, computes all step logits in parallel, returns [b, T, vocab]
        mode='infer': combine with an external loop using forward_single_step
        Forward pass.
        """
        if mode == "train":
            return self.forward_train(
                input_ids,
                token_type_ids,
                attention_mask,
                slot_target_ids,
                slot_target_mask,
            )
        batch_size = input_ids.size(0)

        # Handle history_slot_ids: expand a [num_slots] tensor with a batch dimension
        if history_slot_ids.dim() == 1:
            history_slot_ids = history_slot_ids.unsqueeze(0).expand(batch_size, -1)

        # 1. Build the pinyin input (all-zero placeholder when use_pinyin=False)
        if self.use_pinyin:
            # Note: real pinyin ids are needed here, but the current inputs do not
            # provide them, so zeros are used as a placeholder (in practice they
            # should come from outside).
            # For this demo a zero tensor is used, assuming pinyin id 0 is padding.
            pinyin_ids = torch.zeros_like(input_ids)
        else:
            raise NotImplementedError(
                "Inference mode should call forward_single_step directly"
            )
            pinyin_ids = torch.zeros_like(input_ids)

    def to(self, device):
        """Override to() to record the device."""
        self.device = device
        return super().to(device)
    def fit(
        self,
        train_dataloader,
        eval_dataloader=None,
        monitor=None,
        criterion=None,
        optimizer=None,
        num_epochs=1,
        stop_batch=2e5,
        eval_frequency=500,
        grad_accum_steps=1,
        clip_grad_norm=1.0,
        loss_weight=None,
        mixed_precision=True,
        weight_decay=0.1,
        warmup_ratio=0.1,
        label_smoothing=0.15,
        lr=1e-4,
        lr_schedule=None,
        save_dir=None,
        save_frequency=1000,
    ):
        """
        Train the model.
        """
        def default_lr_schedule(_lr, _processed_batches, _stop_batch, _warmup_steps):
            if _processed_batches < _warmup_steps:
                current_lr = _lr * (_processed_batches / _warmup_steps)
            else:
                progress = (_processed_batches - _warmup_steps) / (
                    _stop_batch - _warmup_steps
                )
                current_lr = _lr * (0.5 * (1.0 + math.cos(math.pi * progress)))
            return current_lr
        # 2. Context encoding -> H [batch, seq_len, dim]
        # Note: ContextEncoder.forward takes text_ids, pinyin_ids, mask
        H = self.context_encoder(input_ids, pinyin_ids, mask=attention_mask)

        if self.device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.to(self.device)
        if lr_schedule is None:
            lr_schedule = default_lr_schedule
        # 3. Slot memory encoding -> S [batch, num_slots, dim]
        S = self.slot_memory(history_slot_ids)  # history_slot_ids: [batch, num_slots]

        self.train()
        # 4. Cross-attention fusion (via CrossAttentionFusion)
        fused = self.cross_attn(S, H, context_mask=attention_mask)

        if optimizer is None:
            optimizer = optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)
        # 5. MoE processing -> [batch, num_slots, dim]
        moe_out = self.moe(fused)

        # Loss: set ignore_index=-1 because invalid label positions are marked -1
        if criterion is None:
            if loss_weight is not None:
                criterion = nn.CrossEntropyLoss(
                    weight=loss_weight, label_smoothing=label_smoothing, ignore_index=-1
                )
            else:
                criterion = nn.CrossEntropyLoss(
                    label_smoothing=label_smoothing, ignore_index=-1
                )

        scaler = amp.GradScaler(enabled=mixed_precision)

        total_steps = stop_batch
        warmup_steps = int(total_steps * warmup_ratio)
        logger.info(f"Training Start: Steps={total_steps}, Warmup={warmup_steps}")
        processed_batches = 0
        batch_loss_sum = 0.0
        optimizer.zero_grad()
        try:
            for epoch in range(num_epochs):
                for batch_idx, batch in enumerate(
                    tqdm(train_dataloader, total=int(stop_batch))
                ):
                    # Learning-rate schedule (argument order matches default_lr_schedule)
                    current_lr = lr_schedule(
                        lr, processed_batches, stop_batch, warmup_steps
                    )
                    for param_group in optimizer.param_groups:
                        param_group["lr"] = current_lr

                    # Unpack the batch
                    input_ids = batch["hint"]["input_ids"].to(self.device)
                    attention_mask = batch["hint"]["attention_mask"].to(self.device)
                    token_type_ids = batch["hint"]["token_type_ids"].to(self.device)
                    labels = batch["char_id"].to(self.device)  # [batch, max_slot_steps]

                    # Build slot_target_mask: 1 for valid positions, 0 for invalid (label -1)
                    slot_target_mask = (labels != -1).float()  # [batch, max_slot_steps]

                    with torch.amp.autocast(
                        device_type=self.device.type, enabled=mixed_precision
                    ):
                        # Call the model (training mode)
                        logits = self(
                            input_ids=input_ids,
                            token_type_ids=token_type_ids,
                            attention_mask=attention_mask,
                            slot_target_ids=labels,
                            slot_target_mask=slot_target_mask,
                            mode="train",
                        )  # logits: [batch, max_slot_steps, output_vocab_size]

                        # Compute the loss (padding ignored via ignore_index=-1 in the criterion)
                        loss = criterion(
                            logits.view(-1, self.output_vocab_size), labels.view(-1)
                        )
                        loss = loss / grad_accum_steps

                    scaler.scale(loss).backward()

                    # Gradient-accumulation update
                    if (processed_batches + 1) % grad_accum_steps == 0:
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
                        scaler.step(optimizer)
                        scaler.update()
                        optimizer.zero_grad()
                        batch_loss_sum += loss.item() * grad_accum_steps

                        # Periodic evaluation
                        if processed_batches % eval_frequency == 0:
                            if eval_dataloader:
                                self.eval()
                                acc, eval_loss = self.model_eval(eval_dataloader, criterion)
                                self.train()
                                if monitor:
                                    monitor.add_step(
                                        processed_batches,
                                        {
                                            "train_loss": batch_loss_sum
                                            / (eval_frequency if processed_batches > 0 else 1),
                                            "acc": acc,
                                            "loss": eval_loss,
                                            "lr": current_lr,
                                        },
                                    )
                                logger.info(
                                    f"step: {processed_batches}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, "
                                    f"batch_loss_sum: {batch_loss_sum / (eval_frequency if processed_batches > 0 else 1):.4f}, "
                                    f"current_lr: {current_lr}"
                                )
                            else:
                                logger.info(
                                    f"step: {processed_batches}, batch_loss_sum: {batch_loss_sum / (eval_frequency if processed_batches > 0 else 1):.4f}, "
                                    f"current_lr: {current_lr}"
                                )
                            batch_loss_sum = 0.0

                        processed_batches += 1
                        if processed_batches >= stop_batch:
                            break

                    else:
                        # Accumulation threshold not reached: only the loss is accumulated,
                        # without bumping the counter (processed_batches increments on update).
                        # Note: the original code already increments after the update above,
                        # so no extra handling is needed here for compatibility.
                        pass

            # Notify that training finished
            if monitor:
                monitor.finish()

        except KeyboardInterrupt:
            logger.info("Training interrupted by user")
    def load_from_state_dict(self, state_dict_path: Union[str, Path]):
        state_dict = torch.load(
            state_dict_path, weights_only=True, map_location=self.device
        )
        self.load_state_dict(state_dict)

    def load_from_pretrained_base_model(
        self,
        BaseModel,
        snapshot_path: Union[str, Path],
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        *args,
        **kwargs,
    ):
        base_model = BaseModel(*args, **kwargs)
        base_model.load_state_dict(torch.load(snapshot_path, map_location=device))
        self_static_dict = self.state_dict()
        pretrained_dict = base_model.state_dict()

        freeze_layers = []

        for key in self_static_dict.keys():
            if key in pretrained_dict.keys():
                if self_static_dict[key].shape == pretrained_dict[key].shape:
                    self_static_dict[key] = pretrained_dict[key].to(self.device)
                    freeze_layers.append(key)
        self.load_state_dict(self_static_dict)
        for name, param in self.named_parameters():
            if name in freeze_layers:
                param.requires_grad = False
        # 6. Pooling and classification: average over the slot dimension (or masked pooling)
        # Simple mean here; to ignore padded slots, build a mask from whether history_slot_ids is 0
        slot_mask = (history_slot_ids != 0).float()  # [batch, num_slots]
        slot_mask = slot_mask.unsqueeze(-1)  # [batch, num_slots, 1]
        pooled = (moe_out * slot_mask).sum(dim=1) / (slot_mask.sum(dim=1) + 1e-8)
        # If every slot is padding, fall back to a plain mean
        if torch.isnan(pooled).any():
            pooled = moe_out.mean(dim=1)

        logits = self.classifier(pooled)  # [batch, vocab_size]
        return logits
18 test.py
@@ -1,7 +1,17 @@

import sys

from tqdm import tqdm

from model.dataset import PinyinInputDataset

dataset = PinyinInputDataset('/home/songsenand/Data/corpus/CCI-Data/')
if sys.platform == "win32":
    dataset_path = "data"
else:
    dataset_path = "/home/songsenand/Data/corpus/CCI-Data/"

dataset = PinyinInputDataset(dataset_path, max_iter_length=20, max_workes=3)
for i, line in enumerate(dataset):
    print(line['labels'])
    if i > 10:
        break
    for k, v in line.items():
        if isinstance(v, str):
            continue
        print(k, v.shape)