feat: update the input-method model architecture design document and refactor core component code

parent: fd49058764
commit: 1af85a36bc
@@ -175,3 +175,5 @@ cython_debug/
|
||||||
*.vsix
|
*.vsix
|
||||||
|
|
||||||
uv.lock
|
uv.lock
|
||||||
|
|
||||||
|
data/*
|
||||||
|
|
|
||||||
225
README.md
|
|
@@ -1,95 +1,166 @@
|
||||||
# SUimeModelTraner
|
# 输入法预测模型架构设计 (Input Method Prediction Model)
|
||||||
|
|
||||||
> 深度学习输入法引擎技术方案
|
## 1. 概述
|
||||||
|
本项目旨在构建一个轻量级、高精度的中文输入法预测模型。核心设计理念是通过**结构化槽位记忆**与**交叉注意力机制**,将当前语境(光标前后文本+拼音)与历史输入习惯深度融合。为了在有限的计算资源下保持高表达能力,模型引入了**混合专家网络 (MoE)** 模块。
|
||||||
|
|
||||||
## 1. 任务目标
|
## 2. 核心架构流程
|
||||||
设计一个基于上下文的输入法引擎模型,输入包括光标前文本、拼音、光标后文本及历史记录,输出候选词序列(文字ID序列),支持束搜索解码,实现精准的拼音到文字转换。
|
数据流遵循以下路径:
|
||||||
|
`输入编码` → `Transformer 上下文编码` → `槽位记忆嵌入` → `交叉注意力融合` → `门控+专家混合 (MoE)` → `分类预测` → `束搜索解码`
|
||||||
|
|
||||||
## 2. 输入表示
|
### 2.1 输入层设计
|
||||||
- **四段文本**:
|
模型接收三类输入,分别处理以保持语义清晰:
|
||||||
- 光标前文本
|
1. **当前文本上下文**:包含光标前文本(Prefix)和光标后文本(Suffix)。
|
||||||
- 拼音(如“bangdao”)
|
2. **拼音序列**:与当前文本对应的拼音信息,作为增强特征融入文本编码。
|
||||||
- 光标后文本
|
3. **历史槽位序列**:最近 N 个历史输入词汇,作为结构化记忆输入。
|
||||||
- 符合条件的历史记录(如“绑到|邦道|…”)
|
|
||||||
- **编码方式**:使用BERT类Tokenizer,统一序列长度 **L=128**(或88)。
|
|
||||||
- **段落区分**:通过 `token_type_ids` 标记段落,取值为 **0,1,2,3**(分别对应四类输入)。
|
|
||||||
- **拼音处理**:暂将拼音作为普通文本输入,预留专用嵌入接口供后续优化。
|
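上述“四段输入 + token_type_ids 标记”的编码方式,可以用如下极简示意代码勾勒。注意这是假设性示例:为保持自包含,用字符码点代替真实 BERT Tokenizer 的词表 ID,仅演示段落类型 0/1/2/3 的构造与统一截断/填充逻辑:

```python
# 示意:为四段输入(光标前文本、拼音、光标后文本、历史记录)构造 token_type_ids。
# 假设性示例:用 ord(c) 代替真实 Tokenizer 的词表 ID,仅演示段落标记逻辑。
def build_segments(parts, max_length=16, pad_id=0):
    input_ids, token_type_ids = [], []
    for seg_idx, part in enumerate(parts):
        ids = [ord(c) for c in part]              # 字符级占位编码
        input_ids.extend(ids)
        token_type_ids.extend([seg_idx] * len(ids))
    # 统一截断 / 填充到固定长度 L
    input_ids = (input_ids + [pad_id] * max_length)[:max_length]
    token_type_ids = (token_type_ids + [0] * max_length)[:max_length]
    attention_mask = [1 if t != pad_id else 0 for t in input_ids]
    return input_ids, token_type_ids, attention_mask

ids, types, mask = build_segments(["你好", "bangdao", "了", "绑到"])
```

真实实现中 `max_length` 即文中的 L=128,填充位置的 `token_type_id` 约定为 0。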
|
||||||
|
|
||||||
## 3. 模型架构概览
|
### 2.2 模块详解
|
||||||
整体结构分为:输入编码 → Transformer编码 → 槽位记忆 → 交叉注意力 → 门控+专家混合 → 分类头 → 束搜索解码。
|
|
||||||
|
|
||||||
![模型架构示意图]
|
#### A. Transformer 编码器 (Context Encoder)
|
||||||
|
负责提取当前语境的深层语义表示。
|
||||||
|
* **输入处理**:将 Prefix、Suffix 及拼音通过 Embedding 层映射。拼音采用**特征叠加**或**独立 Token** 方式融入,避免双流架构的复杂性。
|
||||||
|
* **骨干网络**:使用标准的 Transformer Encoder。
|
||||||
|
* **隐藏层维度**:512 [1]
|
||||||
|
* **Transformer 层数**:4 层(轻量级设计,从头训练) [1]
|
||||||
|
* **注意力头数**:4 头 [1]
|
||||||
|
* **输出**:上下文表示 $H$,形状为 `[batch, L, 512]` [1]。
|
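按文中推荐的 512 维、4 层、4 头配置,骨干网络可以用 PyTorch 自带组件直接搭出(示意代码,`dim_feedforward=dim*4` 为假设的常见取值,仅用于验证输出形状):

```python
import torch
import torch.nn as nn

# 轻量配置:512 维、4 层、4 头;dim_feedforward 取 4 倍隐藏维(假设值)
dim, n_layers, n_heads, L = 512, 4, 4, 128
layer = nn.TransformerEncoderLayer(
    d_model=dim, nhead=n_heads, dim_feedforward=dim * 4, batch_first=True
)
encoder = nn.TransformerEncoder(layer, num_layers=n_layers)

x = torch.randn(2, L, dim)   # 输入嵌入 [batch, L, 512]
H = encoder(x)               # 上下文表示 H,形状 [batch, L, 512]
```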
||||||
|
|
||||||
## 4. 核心模块设计
|
#### B. 槽位记忆模块 (Slot Memory)
|
||||||
|
负责将非结构化的历史输入转化为结构化的记忆向量。
|
||||||
|
* **嵌入方式**:历史词汇通过独立的 `Slot Embedding` 查找表映射。
|
||||||
|
* **位置编码**:添加可学习的 `Positional Embedding` 以保留历史输入的时间顺序信息。
|
||||||
|
* **输出**:槽位序列 $S$,形状为 `[batch, Num_Slots, 512]`。
|
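槽位记忆的“独立嵌入表 + 可学习位置编码”可以勾勒如下(示意代码,槽位词表大小 10019 取自文中分类头设定,槽位数 8 为文中配置):

```python
import torch
import torch.nn as nn

# 独立的 Slot Embedding 查找表 + 可学习位置编码
slot_vocab, num_slots, dim = 10019, 8, 512
slot_emb = nn.Embedding(slot_vocab, dim)
pos_emb = nn.Embedding(num_slots, dim)

history_ids = torch.randint(0, slot_vocab, (2, num_slots))    # [batch, Num_Slots]
S = slot_emb(history_ids) + pos_emb(torch.arange(num_slots))  # 位置编码广播相加
```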
||||||
|
|
||||||
### 4.1 Embedding层与Transformer编码器
|
#### C. 交叉注意力融合 (Cross-Attention Fusion)
|
||||||
- **Embedding维度**:512
|
这是模型的核心创新点,用于动态关联“历史记忆”与“当前语境”。
|
||||||
- **Transformer层数**:4层
|
* **Query (Q)**:当前步的槽位序列 $S$(经过位置编码后)。
|
||||||
- **多头注意力头数**:4或8
|
* **Key/Value (K/V)**:Transformer 编码器输出的上下文表示 $H$ [1]。
|
||||||
- **输出**:上下文表示 `H`(形状:`[batch, L, 512]`)
|
* **机制**:让历史槽位主动关注当前文本语境,捕捉如“在‘班级第一名’语境下,‘王次香’比‘王慈祥’更相关”的逻辑。
|
||||||
- **预训练**:可选加载structBERT轻量级骨干(链接已失效,当前从头训练)。
|
* **输出**:融合后的特征序列,形状为 `[batch, Num_Slots, 512]`。
|
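“Q 取槽位、K/V 取上下文”的交叉注意力,用 `nn.MultiheadAttention` 一行即可表达(示意代码,输出形状与 Query 一致):

```python
import torch
import torch.nn as nn

# 交叉注意力:Query = 槽位序列 S,Key/Value = 上下文表示 H
dim, n_heads = 512, 4
attn = nn.MultiheadAttention(embed_dim=dim, num_heads=n_heads, batch_first=True)

S = torch.randn(2, 8, dim)      # [batch, Num_Slots, 512]
H = torch.randn(2, 128, dim)    # [batch, L, 512]
fused, weights = attn(S, H, H)  # fused 与 S 同形;weights 为槽位对上下文的注意力分布
```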
||||||
|
|
||||||
### 4.2 槽位记忆模块
|
#### D. 门控与专家混合 (Gating + MoE)
|
||||||
- **槽位结构**:共8个槽位,每个槽位最多可填充3步,总计最多24步。
|
实际测试表明,移除 MoE 会导致模型性能显著下降,因此该模块对于捕捉复杂分布至关重要。
|
||||||
- **槽位嵌入**:每一步选择的文字ID通过共享的embedding层转换为512维向量。
|
* **专家数量**:20 个专家 [1]。
|
||||||
- **历史槽位表示**:将当前步之前的所有槽位嵌入**按顺序拼接**,形成动态增长的序列。
|
* **门控机制**:根据输入特征动态选择激活部分专家,实现稀疏激活,在增加模型容量的同时控制计算成本。
|
||||||
- **位置编码**:为拼接后的槽位序列添加可学习的位置嵌入,帮助模型捕捉时序关系。
|
* **输出**:经过专家网络增强后的特征向量。
|
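“20 专家、top-3 稀疏激活”的门控流程可以用如下示意代码理解(专家在此简化为单层线性层,仅演示 top-k 选择与加权组合,不含负载均衡等工程细节):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# 稀疏门控示意:20 个专家中选 top-3,按 softmax 归一化权重加权求和
dim, num_experts, top_k = 512, 20, 3
gate = nn.Linear(dim, num_experts)
experts = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_experts)])

f = torch.randn(2, dim)                                  # 输入特征 [B, 512]
scores = gate(f)                                         # [B, 20]
topk_vals, topk_idx = torch.topk(scores, top_k, dim=-1)  # 每个样本激活 3 个专家
w = F.softmax(topk_vals, dim=-1)                         # 仅对选中专家归一化

out = torch.zeros_like(f)
for k in range(top_k):
    for b in range(f.size(0)):
        e = topk_idx[b, k].item()
        out[b] += w[b, k] * experts[e](f[b])             # 加权组合专家输出
```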
||||||
- **初始状态**:第一步时历史为空,用特殊 `[START]` 嵌入向量作为占位。
|
|
||||||
|
|
||||||
### 4.3 交叉注意力与注意力池化
|
#### E. 分类头与解码
|
||||||
- **Query**:当前步的槽位序列(经过位置编码后)
|
* **分类预测**:MoE 输出的特征向量通过全连接层映射到词表空间,输出下一个字/词的概率分布。
|
||||||
- **Key/Value**:Transformer编码器输出 `H`
|
* **解码策略**:推理阶段使用**束搜索 (Beam Search)**,束宽设为 5 [1]。
|
||||||
- **输出**:槽位相关的特征序列
|
|
||||||
- **注意力池化**:对交叉注意力输出序列进行池化(如均值池化或可学习注意力池化),得到固定长度的特征向量 `f`(维度512)。
|
|
||||||
|
|
||||||
### 4.4 门控网络与专家层(MoE)
|
## 3. 关键超参数配置
|
||||||
- **门控网络**:输入 `f`,输出20个专家的权重,选择**top-3**专家。
|
|
||||||
- **专家结构**:每个专家为残差网络,输出维度 **512**(与隐藏层一致)。
|
|
||||||
- **专家组合**:对选中的3个专家输出进行加权求和,得到融合特征 `e`。
|
|
||||||
|
|
||||||
### 4.5 分类头
|
为确保模型性能与效率的平衡,建议采用以下超参数 [1]:
|
||||||
- **设计**:采用同维FFN,结构为
|
|
||||||
`Linear(512, 512, bias=False) → LayerNorm → GELU → Linear(512, 10019)`
|
|
||||||
输出10019维(ID范围0~10018,其中0表示终止符)。
|
|
||||||
- **优点**:保留特征维度,引入非线性,参数量适中。
|
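上述分类头结构可直接按文中给出的层序写成 PyTorch 模块(输出维度 10019、0 为终止符均取自文中设定):

```python
import torch
import torch.nn as nn

# 分类头:Linear(512,512,bias=False) → LayerNorm → GELU → Linear(512,10019)
head = nn.Sequential(
    nn.Linear(512, 512, bias=False),
    nn.LayerNorm(512),
    nn.GELU(),
    nn.Linear(512, 10019),   # ID 范围 0~10018,其中 0 表示终止符
)
logits = head(torch.randn(2, 512))   # 融合特征 e → 词表 logits
```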
|
||||||
|
|
||||||
## 5. 解码策略:束搜索
|
| 参数项 | 推荐值 | 说明 |
|
||||||
- **搜索范围**:在每个槽位步内部执行束搜索,束宽设为 **k**(默认为3)。
|
| :--- | :--- | :--- |
|
||||||
- **候选维护**:每个候选路径独立维护历史槽位序列(拼接后的嵌入)及累计概率。
|
| **序列长度 (L)** | 128 | 上下文窗口大小 [1] |
|
||||||
- **终止条件**:
|
| **隐藏层维度** | 512 | Embedding 及 Transformer 内部维度 [1] |
|
||||||
1. 所有槽位已填满(8×3=24步);
|
| **Transformer 层数** | 4 | 轻量级骨干,降低延迟 [1] |
|
||||||
2. 当前步所有候选分支的最高概率词均为 **0(终止符)**,则强制退出。
|
| **注意力头数** | 4 | 适配 512 维度的高效配置 [1] |
|
||||||
- **输出**:概率最高的完整槽位序列。
|
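槽位内束搜索的骨架可以用如下纯 Python 示意代码表达。其中 `step_log_probs(seq)` 是假设的打分接口:给定已选历史,返回各候选词的对数概率;词 0 为终止符,所有分支最高概率词均为 0 时提前退出:

```python
import math

# 束搜索示意:维护 (历史词序列, 累计对数概率) 的候选路径
def beam_search(step_log_probs, beam_width=3, max_steps=24):
    beams = [([], 0.0)]
    for _ in range(max_steps):
        candidates = []
        for seq, score in beams:
            if seq and seq[-1] == 0:          # 已终止的路径原样保留
                candidates.append((seq, score))
                continue
            for tok, lp in step_log_probs(seq):
                candidates.append((seq + [tok], score + lp))
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
        if all(seq and seq[-1] == 0 for seq, _ in beams):  # 全部分支输出终止符
            break
    return beams[0][0]                        # 概率最高的完整序列

# 玩具打分函数(假设):前两步偏好词 7,之后偏好终止符 0
def toy_step(seq):
    if len(seq) < 2:
        return [(7, math.log(0.9)), (0, math.log(0.1))]
    return [(7, math.log(0.1)), (0, math.log(0.9))]
```

实际模型中 `step_log_probs` 对应“分类头输出 log-softmax”,且每条路径还需各自维护槽位嵌入等历史状态。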
| **专家数量** | 20 | MoE 层中的专家总数,对性能至关重要 [1] |
|
||||||
|
| **束宽 (Beam Width)** | 5 | 推理时平衡速度与准确率 [1] |
|
||||||
|
| **学习率** | 1e-4 ~ 5e-4 | 建议配合 Warmup 策略 [1] |
|
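表中建议的“学习率配合 Warmup”可以用 `LambdaLR` 勾勒如下(示意代码:线性升温是假设的具体策略,峰值学习率取区间内的 3e-4;升温后保持恒定,实际可替换为衰减):

```python
import torch

# 线性 Warmup 示意:前 warmup_steps 步学习率从 0 线性升到峰值 3e-4
model = torch.nn.Linear(512, 512)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

warmup_steps = 1000
def lr_lambda(step):
    return min(1.0, (step + 1) / warmup_steps)   # 升温后保持 1.0 倍峰值

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
```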
||||||
|
|
||||||
## 6. 训练设置
|
## 4. 训练策略
|
||||||
- **优化器**:AdamW
|
|
||||||
- **损失函数**:每一步的CrossEntropyLoss(仅计算非填充位置)
|
|
||||||
- **训练数据**:真实用户输入日志,构造(上下文,拼音,目标槽位序列)三元组
|
|
||||||
- **标签处理**:每个槽位步的真实文字ID作为监督信号
|
|
||||||
|
|
||||||
## 7. 关键设计考量与潜在风险
|
本模型采用标准的**序列到序列(Seq2Seq)监督学习**范式,直接对目标槽位序列进行逐步预测。
|
||||||
|
|
||||||
| 设计点 | 考量 | 潜在风险/优化 |
|
### 4.1 数据构造与标签
|
||||||
|--------|------|----------------|
|
* **输入三元组**:训练数据由 `(上下文, 拼音, 目标槽位序列)` 构成 [1]。
|
||||||
| 拼音编码 | 当前作为普通文本,实现简单 | 多字词场景可能对齐困难,后续可增加专用拼音嵌入层 |
|
* **上下文**:光标前后的文本片段。
|
||||||
| 槽位更新 | 拼接+位置编码,保留完整历史 | 序列长度动态增长(最多24),交叉注意力计算量可控 |
|
* **拼音**:当前待输入字的拼音序列。
|
||||||
| 专家层 | 20专家,top-3组合,增强特征选择性 | 需确保门控网络负载均衡,避免专家“死”掉 |
|
* **目标槽位序列**:真实用户输入的文字 ID 序列,作为模型的监督信号 [1]。
|
||||||
| 分类头 | 同维FFN,兼顾容量与非线性 | 若过拟合可降维或增加Dropout |
|
* **标签处理**:在每一个槽位步(Step),模型需要预测该步对应的真实文字 ID [1]。
|
||||||
| 束搜索 | 槽位内束搜索,工程复杂度适中 | 需维护多个候选路径的历史状态,内存管理需注意 |
|
|
||||||
|
|
||||||
## 8. 后续优化方向
|
### 4.2 损失函数与优化
|
||||||
- **拼音增强**:引入拼音专用嵌入(音节级编码)
|
* **损失函数**:使用 **CrossEntropyLoss** 计算每一步预测结果与真实标签之间的差异 [1]。
|
||||||
- **槽位建模增强**:若拼接方式难以学习长距离依赖,可替换为轻量级GRU(但当前暂不采用)
|
* **掩码机制**:仅计算非填充位置(Non-padding positions)的损失,忽略无效的时间步 [1]。
|
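“仅在非填充位置计算 CrossEntropyLoss”的掩码机制,常见做法是把填充步的标签置为 `ignore_index`(示意代码,步数 24 与词表 10019 取自文中配置):

```python
import torch
import torch.nn.functional as F

# 掩码损失示意:ignore_index 使填充时间步不参与损失计算
PAD = -100                                    # 约定的忽略标签值
logits = torch.randn(2, 24, 10019)            # [batch, steps, vocab]
labels = torch.randint(0, 10019, (2, 24))
labels[:, 20:] = PAD                          # 末尾 4 步视为填充

loss = F.cross_entropy(
    logits.reshape(-1, 10019), labels.reshape(-1), ignore_index=PAD
)
```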
||||||
- **预训练**:尝试加载开源中文BERT权重作为编码器初始化
|
* **优化器**:采用 **AdamW** 进行参数更新 [1]。
|
||||||
- **知识蒸馏**:若模型过大,可蒸馏为更小版本用于端侧部署
|
|
||||||
|
|
||||||
---
|
### 4.3 训练流程细节
|
||||||
|
1. **前向传播**:
|
||||||
|
* 模型接收上下文和拼音,通过 Transformer 编码得到语境表示。
|
||||||
|
* 结合历史槽位记忆,通过交叉注意力和 MoE 模块融合特征。
|
||||||
|
* 分类头输出当前步所有候选字的概率分布。
|
||||||
|
2. **Teacher Forcing**:
|
||||||
|
* 在训练过程中,**强制使用真实的上一槽位输出**作为下一步的输入条件。这意味着模型在训练时始终基于“正确的历史”进行预测,从而快速收敛。
|
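Teacher Forcing 在数据层面就是把真实标签序列右移一位作为各步的历史输入(示意代码,`start_id` 为假设的 `[START]` 占位 ID):

```python
import torch

# Teacher Forcing 示意:第 t 步的槽位输入 = [START] + 真实序列的前 t 个词
gold = torch.tensor([[101, 205, 309, 0]])   # 真实槽位序列,0 为终止符
start_id = 1                                # 假设的 [START] 占位 ID
history = torch.cat(
    [torch.full((1, 1), start_id), gold[:, :-1]], dim=1
)                                           # 右移一位:[[1, 101, 205, 309]]
```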
||||||
|
3. **反向传播**:
|
||||||
|
* 根据 CrossEntropyLoss [1] 计算梯度,并通过 AdamW [1] 更新模型权重。
|
||||||
|
|
||||||
## 附录:关键超参数(待定)
|
### 4.4 推理与训练的差异
|
||||||
- 序列长度:128
|
* **训练时**:使用 Ground Truth(真实标签)作为槽位输入,确保模型学习到最优的条件概率分布。
|
||||||
- 隐藏层维度:512
|
* **推理时**:由于无法获取真实标签,模型采用**束搜索(Beam Search)** [1]。
|
||||||
- Transformer层数:4
|
* **束宽**:默认为 5 [1]。
|
||||||
- 注意力头数:4
|
* **候选维护**:每个候选路径独立维护其历史槽位序列及累计概率 [1]。
|
||||||
- 专家数量:20
|
* **终止条件**:当所有槽位填满(如 8×3=24 步)或所有候选分支的最高概率词均为终止符时退出 [1]。
|
||||||
- 束宽:5
|
|
||||||
- 学习率:待调(建议 1e-4 ~ 5e-4,带warmup)
|
|
||||||
|
## 5. 代码实现示意 (PyTorch)
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
class Expert(nn.Module):
|
||||||
|
def __init__(self, dim=512):
|
||||||
|
super().__init__()
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
nn.Linear(dim, dim * 4),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Linear(dim * 4, dim)
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
return self.net(x)
|
||||||
|
|
||||||
|
class InputMethodModel(nn.Module):
|
||||||
|
def __init__(self, vocab_size, pinyin_vocab_size, slot_vocab_size, dim=512, n_layers=4, n_heads=4, num_experts=20):
|
||||||
|
super().__init__()
|
||||||
|
# 1. Context Encoder
|
||||||
|
self.text_emb = nn.Embedding(vocab_size, dim)
|
||||||
|
self.pinyin_emb = nn.Embedding(pinyin_vocab_size, dim)
|
||||||
|
self.pos_emb = nn.Embedding(128, dim)
|
||||||
|
encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=n_heads)
|
||||||
|
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
|
||||||
|
|
||||||
|
# 2. Slot Memory
|
||||||
|
self.slot_emb = nn.Embedding(slot_vocab_size, dim)
|
||||||
|
self.slot_pos_emb = nn.Embedding(5, dim) # 假设保留5个历史槽位
|
||||||
|
|
||||||
|
# 3. Cross-Attention
|
||||||
|
self.cross_attn = nn.MultiheadAttention(embed_dim=dim, num_heads=n_heads, batch_first=True)
|
||||||
|
|
||||||
|
# 4. MoE Layer
|
||||||
|
self.num_experts = num_experts
|
||||||
|
self.experts = nn.ModuleList([Expert(dim) for _ in range(num_experts)])
|
||||||
|
self.gate = nn.Linear(dim, num_experts)
|
||||||
|
|
||||||
|
# 5. Classification Head
|
||||||
|
self.classifier = nn.Linear(dim, vocab_size)
|
||||||
|
|
||||||
|
def forward(self, text_ids, pinyin_ids, history_slot_ids):
|
||||||
|
# Encode Context
|
||||||
|
x = self.text_emb(text_ids) + self.pinyin_emb(pinyin_ids)
|
||||||
|
x += self.pos_emb(torch.arange(x.size(1)).to(x.device))
|
||||||
|
H = self.transformer(x) # [B, L, 512]
|
||||||
|
|
||||||
|
# Encode Slots
|
||||||
|
S = self.slot_emb(history_slot_ids)
|
||||||
|
S += self.slot_pos_emb(torch.arange(S.size(1)).to(S.device))
|
||||||
|
|
||||||
|
# Cross-Attention: Q=Slots, K/V=Context
|
||||||
|
fused, _ = self.cross_attn(S, H, H) # [B, Slots, 512]
|
||||||
|
|
||||||
|
# MoE Processing
|
||||||
|
# 简化版 MoE: 对所有专家输出进行加权平均
|
||||||
|
gate_scores = torch.softmax(self.gate(fused), dim=-1) # [B, Slots, Num_Experts]
|
||||||
|
expert_outputs = torch.stack([expert(fused) for expert in self.experts], dim=-2) # [B, Slots, Num_Experts, Dim]
|
||||||
|
moe_out = torch.sum(gate_scores.unsqueeze(-1) * expert_outputs, dim=-2) # [B, Slots, Dim]
|
||||||
|
|
||||||
|
# Pooling & Predict
|
||||||
|
pooled = moe_out.mean(dim=1) # [B, 512]
|
||||||
|
logits = self.classifier(pooled)
|
||||||
|
return logits
|
||||||
|
```
|
||||||
|
|
||||||
|
## 6. 总结
|
||||||
|
本方案通过**单流 Transformer 编码**结合**结构化槽位交叉注意力**,并引入**20个专家的 MoE 模块** [1],在保证模型轻量(4层 Transformer)的同时,有效利用了历史输入习惯并提升了模型表达上限。相比暴力拼接或双流架构,该设计在工程实现上更优雅,在推理效率上更高效,是轻量级输入法模型中兼顾效果与效率的较优折中方案。
|
||||||
|
|
|
||||||
|
|
@@ -1,30 +0,0 @@
|
||||||
{
|
|
||||||
"add_cross_attention": false,
|
|
||||||
"attention_probs_dropout_prob": 0.1,
|
|
||||||
"bos_token_id": null,
|
|
||||||
"classifier_dropout": null,
|
|
||||||
"directionality": "bidi",
|
|
||||||
"eos_token_id": null,
|
|
||||||
"hidden_act": "gelu",
|
|
||||||
"hidden_dropout_prob": 0.1,
|
|
||||||
"hidden_size": 512,
|
|
||||||
"initializer_range": 0.02,
|
|
||||||
"intermediate_size": 2048,
|
|
||||||
"is_decoder": false,
|
|
||||||
"layer_norm_eps": 1e-12,
|
|
||||||
"max_position_embeddings": 512,
|
|
||||||
"model_type": "bert",
|
|
||||||
"num_attention_heads": 8,
|
|
||||||
"num_hidden_layers": 6,
|
|
||||||
"pad_token_id": 0,
|
|
||||||
"pooler_fc_size": 768,
|
|
||||||
"pooler_num_attention_heads": 12,
|
|
||||||
"pooler_num_fc_layers": 3,
|
|
||||||
"pooler_size_per_head": 128,
|
|
||||||
"pooler_type": "first_token_transform",
|
|
||||||
"tie_word_embeddings": true,
|
|
||||||
"transformers_version": "5.1.0",
|
|
||||||
"type_vocab_size": 4,
|
|
||||||
"use_cache": true,
|
|
||||||
"vocab_size": 21128
|
|
||||||
}
|
|
||||||
|
|
@@ -1,9 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.amp as amp
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------- 注意力池化模块----------------------------
|
# ---------------------------- 注意力池化模块----------------------------
|
||||||
|
|
@@ -80,3 +78,270 @@ class Expert(nn.Module):
|
||||||
for block in self.res_blocks:
|
for block in self.res_blocks:
|
||||||
x = block(x)
|
x = block(x)
|
||||||
return self.output(x)
|
return self.output(x)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# 1. 上下文编码器 (Context Encoder)
|
||||||
|
# 对应 README 4.1: 4层 Transformer, 512维, 输出 H [1]
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
class ContextEncoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, vocab_size, pinyin_vocab_size, dim=512, n_layers=4, n_heads=4, max_len=128
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
|
||||||
|
# Embeddings
|
||||||
|
self.text_emb = nn.Embedding(vocab_size, dim)
|
||||||
|
self.pinyin_emb = nn.Embedding(pinyin_vocab_size, dim)
|
||||||
|
self.pos_emb = nn.Embedding(max_len, dim)
|
||||||
|
|
||||||
|
# Transformer Encoder (4 layers, 4 heads) [1]
|
||||||
|
encoder_layer = nn.TransformerEncoderLayer(
|
||||||
|
d_model=dim,
|
||||||
|
nhead=n_heads,
|
||||||
|
dim_feedforward=dim * 4,
|
||||||
|
dropout=0.1,
|
||||||
|
batch_first=True, # 方便处理 [B, L, D]
|
||||||
|
)
|
||||||
|
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
|
||||||
|
|
||||||
|
# LayerNorm for stability
|
||||||
|
self.ln = nn.LayerNorm(dim)
|
||||||
|
|
||||||
|
def forward(self, text_ids, pinyin_ids, mask=None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
text_ids: [batch, seq_len]
|
||||||
|
pinyin_ids: [batch, seq_len] (假设已对齐,若不对齐需预处理)
|
||||||
|
mask: [batch, seq_len] optional padding mask
|
||||||
|
Returns:
|
||||||
|
H: [batch, seq_len, 512] Context representation [1]
|
||||||
|
"""
|
||||||
|
# 1. Embedding Fusion: Text + Pinyin + Position
|
||||||
|
# 策略:拼音作为增强特征叠加到文本上,符合轻量级设计
|
||||||
|
x = self.text_emb(text_ids) + self.pinyin_emb(pinyin_ids)
|
||||||
|
|
||||||
|
seq_len = x.size(1)
|
||||||
|
pos_ids = (
|
||||||
|
torch.arange(seq_len, device=x.device).unsqueeze(0).expand_as(text_ids)
|
||||||
|
)
|
||||||
|
x += self.pos_emb(pos_ids)
|
||||||
|
|
||||||
|
# 2. Transformer Encoding
|
||||||
|
# src_key_padding_mask expects True for padding positions
|
||||||
|
if mask is not None:
|
||||||
|
# Convert 0/1 mask to bool mask where True is padding
|
||||||
|
src_mask = mask == 0
|
||||||
|
else:
|
||||||
|
src_mask = None
|
||||||
|
|
||||||
|
H = self.transformer(x, src_key_padding_mask=src_mask)
|
||||||
|
return self.ln(H)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# 2. 槽位记忆模块 (Slot Memory)
|
||||||
|
# 对应 README 4.2: 8个槽位, 每槽3步, 拼接+位置编码 [1]
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
class SlotMemory(nn.Module):
|
||||||
|
def __init__(self, vocab_size, max_slots=8, steps_per_slot=3, dim=512):
|
||||||
|
super().__init__()
|
||||||
|
self.max_slots = max_slots
|
||||||
|
self.steps_per_slot = steps_per_slot
|
||||||
|
self.total_steps = max_slots * steps_per_slot # 24 steps [1]
|
||||||
|
|
||||||
|
# Shared embedding layer for history tokens [1]
|
||||||
|
self.emb = nn.Embedding(vocab_size, dim)
|
||||||
|
|
||||||
|
# Learnable positional embeddings for the flattened sequence [1]
|
||||||
|
self.pos_emb = nn.Embedding(self.total_steps, dim)
|
||||||
|
|
||||||
|
# Start token embedding for empty slots [1]
|
||||||
|
self.start_emb = nn.Parameter(torch.randn(1, 1, dim))
|
||||||
|
|
||||||
|
def forward(self, history_ids):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
history_ids: [batch, total_steps]
|
||||||
|
Flattened sequence of history tokens.
|
||||||
|
Empty positions should be filled with a special PAD or handled via mask.
|
||||||
|
Returns:
|
||||||
|
S: [batch, total_steps, 512] Slot sequence representation [1]
|
||||||
|
"""
|
||||||
|
# Embed history tokens
|
||||||
|
S = self.emb(history_ids) # [B, 24, 512]
|
||||||
|
|
||||||
|
# Add positional embeddings
|
||||||
|
pos_ids = (
|
||||||
|
torch.arange(S.size(1), device=S.device).unsqueeze(0).expand_as(history_ids)
|
||||||
|
)
|
||||||
|
S += self.pos_emb(pos_ids)
|
||||||
|
|
||||||
|
return S
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# 3. 交叉注意力融合 (Cross-Attention Fusion)
|
||||||
|
# 对应 README: Query=Slots, Key/Value=Context H [1]
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
class CrossAttentionFusion(nn.Module):
|
||||||
|
def __init__(self, dim=512, n_heads=4):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.n_heads = n_heads
|
||||||
|
self.head_dim = dim // n_heads
|
||||||
|
assert self.head_dim * n_heads == dim, "dim must be divisible by n_heads"
|
||||||
|
|
||||||
|
# Linear projections for Q, K, V
|
||||||
|
self.q_proj = nn.Linear(dim, dim, bias=False)
|
||||||
|
self.k_proj = nn.Linear(dim, dim, bias=False)
|
||||||
|
self.v_proj = nn.Linear(dim, dim, bias=False)
|
||||||
|
self.out_proj = nn.Linear(dim, dim, bias=False)
|
||||||
|
self.ln = nn.LayerNorm(dim)
|
||||||
|
|
||||||
|
def forward(self, slots_S, context_H, slot_mask=None, context_mask=None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
slots_S: [batch, num_slots_steps, dim] Query
|
||||||
|
context_H: [batch, ctx_len, dim] Key/Value
|
||||||
|
slot_mask: [batch, num_slots_steps] Optional (not used in scaled_dot_product_attention)
|
||||||
|
context_mask: [batch, ctx_len] Optional padding mask
|
||||||
|
Returns:
|
||||||
|
Fused: [batch, num_slots_steps, dim]
|
||||||
|
"""
|
||||||
|
batch_size, num_slots, _ = slots_S.shape
|
||||||
|
_, ctx_len, _ = context_H.shape
|
||||||
|
|
||||||
|
# Project queries, keys, values
|
||||||
|
Q = self.q_proj(slots_S) # [batch, num_slots, dim]
|
||||||
|
K = self.k_proj(context_H) # [batch, ctx_len, dim]
|
||||||
|
V = self.v_proj(context_H) # [batch, ctx_len, dim]
|
||||||
|
|
||||||
|
# Reshape for multi-head attention: [batch, seq_len, n_heads, head_dim] -> [batch, n_heads, seq_len, head_dim]
|
||||||
|
Q = Q.view(batch_size, num_slots, self.n_heads, self.head_dim).transpose(1, 2)
|
||||||
|
K = K.view(batch_size, ctx_len, self.n_heads, self.head_dim).transpose(1, 2)
|
||||||
|
V = V.view(batch_size, ctx_len, self.n_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
# Prepare attention mask if context_mask is provided
|
||||||
|
attn_mask = None
|
||||||
|
if context_mask is not None:
|
||||||
|
# context_mask: [batch, ctx_len] where 0 means padding
|
||||||
|
# Convert to bool mask and reshape for broadcasting
|
||||||
|
bool_mask = context_mask == 0 # [batch, ctx_len]
|
||||||
|
bool_mask = bool_mask[:, None, None, :] # [batch, 1, 1, ctx_len]
|
||||||
|
# Convert to float mask where True (padding) becomes -inf
|
||||||
|
attn_mask = bool_mask.float().masked_fill(bool_mask, float("-inf"))
|
||||||
|
|
||||||
|
# Scaled dot-product attention
|
||||||
|
attn_output = F.scaled_dot_product_attention(
|
||||||
|
Q,
|
||||||
|
K,
|
||||||
|
V,
|
||||||
|
attn_mask=attn_mask,
|
||||||
|
dropout_p=0.0, # no dropout
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reshape back: [batch, n_heads, num_slots, head_dim] -> [batch, num_slots, dim]
|
||||||
|
attn_output = (
|
||||||
|
attn_output.transpose(1, 2)
|
||||||
|
.contiguous()
|
||||||
|
.view(batch_size, num_slots, self.dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Project back
|
||||||
|
fused = self.out_proj(attn_output)
|
||||||
|
|
||||||
|
# Residual connection and layer norm
|
||||||
|
fused = self.ln(fused + slots_S)
|
||||||
|
|
||||||
|
return fused
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# 4. 专家混合层 (MoE Layer)
|
||||||
|
# 对应 README: 20个专家 [1], 使用 components.py 中的 Expert 类
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
class MoELayer(nn.Module):
|
||||||
|
def __init__(self, dim=512, num_experts=20, top_k=2, export_resblocks=4):
|
||||||
|
super().__init__()
|
||||||
|
self.num_experts = num_experts
|
||||||
|
self.top_k = top_k
|
||||||
|
self.dim = dim
|
||||||
|
|
||||||
|
# Import Expert from your existing components
|
||||||
|
# Assuming Expert class is defined as in components.py [2]
|
||||||
|
self.experts = nn.ModuleList(
|
||||||
|
[
|
||||||
|
Expert(
|
||||||
|
input_dim=dim,
|
||||||
|
d_model=dim,
|
||||||
|
num_resblocks=export_resblocks,
|
||||||
|
output_multiplier=1,
|
||||||
|
)
|
||||||
|
for _ in range(num_experts)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Gating Network [2]
|
||||||
|
self.gate = nn.Linear(dim, num_experts)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: [batch, seq_len, dim]
|
||||||
|
Returns:
|
||||||
|
out: [batch, seq_len, dim]
|
||||||
|
"""
|
||||||
|
B, L, D = x.shape
|
||||||
|
|
||||||
|
# 1. Compute Gating Scores
|
||||||
|
gates = self.gate(x) # [B, L, num_experts]
|
||||||
|
|
||||||
|
# 2. Select Top-K Experts
|
||||||
|
topk_vals, topk_indices = torch.topk(gates, self.top_k, dim=-1) # [B, L, K]
|
||||||
|
|
||||||
|
# Normalize weights for selected experts
|
||||||
|
weights = F.softmax(topk_vals, dim=-1) # [B, L, K]
|
||||||
|
|
||||||
|
# 3. Dispatch and Compute
|
||||||
|
# Initialize output
|
||||||
|
out = torch.zeros_like(x)
|
||||||
|
|
||||||
|
# Reshape for easier processing: flatten batch and sequence dimensions
|
||||||
|
x_flat = x.view(-1, D) # [B*L, D]
|
||||||
|
weights_flat = weights.view(-1, self.top_k) # [B*L, K]
|
||||||
|
topk_indices_flat = topk_indices.view(-1, self.top_k) # [B*L, K]
|
||||||
|
|
||||||
|
# For each of the top-k positions
|
||||||
|
for k in range(self.top_k):
|
||||||
|
# Get expert indices and weights for this position
|
||||||
|
expert_indices = topk_indices_flat[:, k] # [B*L]
|
||||||
|
expert_weights = weights_flat[:, k].unsqueeze(-1) # [B*L, 1]
|
||||||
|
|
||||||
|
# Process each expert separately
|
||||||
|
for e_idx in range(self.num_experts):
|
||||||
|
# Mask for tokens assigned to this expert at position k
|
||||||
|
mask = expert_indices == e_idx # [B*L]
|
||||||
|
if not mask.any():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Extract tokens for this expert
|
||||||
|
x_selected = x_flat[mask] # [N_selected, D]
|
||||||
|
if x_selected.numel() == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Pass through expert
|
||||||
|
expert_out = self.experts[e_idx](x_selected) # [N_selected, D]
|
||||||
|
|
||||||
|
# Apply expert weights and add to output
|
||||||
|
weighted_out = expert_out * expert_weights[mask]
|
||||||
|
|
||||||
|
# Scatter back to flat output
|
||||||
|
out_flat = out.view(-1, D)
|
||||||
|
out_flat[mask] += weighted_out
|
||||||
|
|
||||||
|
# Reshape back to original shape
|
||||||
|
out = out.view(B, L, D)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
|
||||||
|
|
@@ -1,13 +1,13 @@
|
||||||
import random
|
import random
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
|
||||||
from importlib.resources import files
|
from importlib.resources import files
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from modelscope import AutoModel, AutoTokenizer
|
from modelscope import AutoTokenizer
|
||||||
from pypinyin import Style, lazy_pinyin
|
from pypinyin import Style, lazy_pinyin
|
||||||
from pypinyin.contrib.tone_convert import to_initials
|
from pypinyin.contrib.tone_convert import to_initials
|
||||||
from torch.utils.data import DataLoader, IterableDataset
|
from torch.utils.data import DataLoader, IterableDataset
|
||||||
|
|
@@ -20,7 +20,8 @@ class PinyinInputDataset(IterableDataset):
|
||||||
self,
|
self,
|
||||||
data_path: str,
|
data_path: str,
|
||||||
max_workes: int = -1,
|
max_workes: int = -1,
|
||||||
max_length=128,
|
max_iter_length=1e6,
|
||||||
|
max_seq_length=128,
|
||||||
text_field: str = "text",
|
text_field: str = "text",
|
||||||
py_style_weight=(9, 2, 1),
|
py_style_weight=(9, 2, 1),
|
||||||
shuffle_buffer_size: int = 5000,
|
shuffle_buffer_size: int = 5000,
|
||||||
|
|
@@ -37,7 +38,9 @@ class PinyinInputDataset(IterableDataset):
|
||||||
Path(files(__package__) / "assets" / "tokenizer")
|
Path(files(__package__) / "assets" / "tokenizer")
|
||||||
)
|
)
|
||||||
self.data_path = data_path
|
self.data_path = data_path
|
||||||
self.max_length = max_length
|
|
||||||
|
self.max_iter_length = max_iter_length
|
||||||
|
self.max_seq_length = max_seq_length
|
||||||
self.text_field = text_field
|
self.text_field = text_field
|
||||||
self.dataset = load_dataset(data_path, split="train", streaming=True)
|
self.dataset = load_dataset(data_path, split="train", streaming=True)
|
||||||
self.max_workers = max_workes
|
self.max_workers = max_workes
|
||||||
|
|
@@ -48,8 +51,6 @@ class PinyinInputDataset(IterableDataset):
|
||||||
|
|
||||||
self.query_engine = QueryEngine()
|
self.query_engine = QueryEngine()
|
||||||
self.query_engine.load()
|
self.query_engine.load()
|
||||||
self.shuffle_buffer_size = shuffle_buffer_size
|
|
||||||
self.buffer = []
|
|
||||||
|
|
||||||
# 提取每个样本的目标字符及其频率
|
# 提取每个样本的目标字符及其频率
|
||||||
self.sample_freqs = self.query_engine.get_all_weights()
|
self.sample_freqs = self.query_engine.get_all_weights()
|
||||||
|
|
@@ -95,76 +96,14 @@ class PinyinInputDataset(IterableDataset):
|
||||||
else:
|
else:
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
def tokenize_with_four_seg(self, parts: List[str]) -> Dict[str, Any]:
|
|
||||||
input_ids = []
|
|
||||||
token_type_ids = []
|
|
||||||
|
|
||||||
# 添加 [CLS] (Type 0)
|
|
||||||
cls_id = self.tokenizer.cls_token_id
|
|
||||||
input_ids.append(cls_id)
|
|
||||||
token_type_ids.append(0)
|
|
||||||
|
|
||||||
for seg_idx, part in enumerate(parts):
|
|
||||||
if not part:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Tokenize 单个部分,不加特殊符号
|
|
||||||
# 注意:这里先不截断,最后统一截断,保证优先级高的段落(如part1)完整
|
|
||||||
encoded_part = self.tokenizer(
|
|
||||||
part, add_special_tokens=False, truncation=False
|
|
||||||
)
|
|
||||||
|
|
||||||
part_ids = encoded_part["input_ids"]
|
|
||||||
|
|
||||||
# 如果加上当前部分会超过 MAX_LEN - 1 (留一个位置给最后的SEP或截断),则截断当前部分
|
|
||||||
remaining_space = (
|
|
||||||
self.max_length - len(input_ids) - 1
|
|
||||||
) # -1 for final SEP or safety
|
|
||||||
if len(part_ids) > remaining_space:
|
|
||||||
part_ids = part_ids[:remaining_space]
|
|
||||||
|
|
||||||
if not part_ids:
|
|
||||||
continue
|
|
||||||
|
|
||||||
input_ids.extend(part_ids)
|
|
||||||
# 当前段落的 type_id 即为 seg_idx (0, 1, 2, 3)
|
|
||||||
token_type_ids.extend([seg_idx] * len(part_ids))
|
|
||||||
|
|
||||||
# 添加 [SEP] (Type 跟随当前段落)
|
|
||||||
sep_id = self.tokenizer.sep_token_id
|
|
||||||
input_ids.append(sep_id)
|
|
||||||
token_type_ids.append(seg_idx)
|
|
||||||
|
|
||||||
# 如果已经达到最大长度,提前退出
|
|
||||||
if len(input_ids) >= self.max_length:
|
|
||||||
break
|
|
||||||
|
|
||||||
# 4. 处理 Padding 或 最终截断
|
|
||||||
if len(input_ids) > self.max_length:
|
|
||||||
input_ids = input_ids[: self.max_length]
|
|
||||||
token_type_ids = token_type_ids[: self.max_length]
|
|
||||||
else:
|
|
||||||
pad_len = self.max_length - len(input_ids)
|
|
||||||
input_ids += [self.tokenizer.pad_token_id] * pad_len
|
|
||||||
token_type_ids += [0] * pad_len # Padding mask type 通常为 0
|
|
||||||
|
|
||||||
# 5. 生成 Attention Mask
|
|
||||||
attention_mask = [
|
|
||||||
1 if i != self.tokenizer.pad_token_id else 0 for i in input_ids
|
|
||||||
]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"input_ids": torch.tensor(input_ids, dtype=torch.long),
|
|
||||||
"token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
|
|
||||||
"attention_mask": torch.tensor(attention_mask, dtype=torch.long),
|
|
||||||
}
|
|
||||||
|
|
||||||
# 生成对应文本的拼音
|
# 生成对应文本的拼音
|
||||||
def generate_pinyin(self, text: str) -> List[List[str]]:
|
def generate_pinyin(self, text: str) -> List[str]:
|
||||||
return lazy_pinyin(text, errors=lambda x: [c for c in x])
|
return lazy_pinyin(text, errors=lambda x: [c for c in x])
|
||||||
|
|
||||||
# 生成需要预测汉字对应的拼音,并进行加强
|
# 生成需要预测汉字对应的拼音,并进行加强
|
||||||
def get_mask_pinyin(self, text: str, pinyin_list: List[str]) -> (int, List[str]):
|
def get_mask_pinyin(
|
||||||
|
self, text: str, pinyin_list: List[str]
|
||||||
|
) -> Tuple[int, List[str]]:
|
||||||
mask_pinyin = []
|
mask_pinyin = []
|
||||||
for i in range(len(text)):
|
for i in range(len(text)):
|
||||||
if not self.query_engine.is_chinese_char(text[i]):
|
if not self.query_engine.is_chinese_char(text[i]):
|
||||||
|
|
@@ -193,14 +132,41 @@ class PinyinInputDataset(IterableDataset):
|
||||||
|
|
||||||
self.dataset = self.dataset.shard(num_shards=num_workers, index=worker_id)
|
self.dataset = self.dataset.shard(num_shards=num_workers, index=worker_id)
|
||||||
|
|
||||||
|
# 计算每个worker的配额
|
||||||
|
# 将 max_iter_length 转换为整数以确保整数除法
|
||||||
|
total_quota = int(self.max_iter_length)
|
||||||
|
base_quota = total_quota // num_workers
|
||||||
|
remainder = total_quota % num_workers
|
||||||
|
|
||||||
|
# 最后一个worker处理剩余的样本(如果有余数)
|
||||||
|
if worker_id == num_workers - 1:
|
||||||
|
worker_quota = base_quota + remainder
|
||||||
|
else:
|
||||||
|
worker_quota = base_quota
|
||||||
|
else:
|
||||||
|
# 单worker情况,使用全部配额
|
||||||
|
worker_quota = int(self.max_iter_length)
|
||||||
|
num_workers = 1
|
||||||
|
|
||||||
|
# 每个worker有自己的迭代计数器
|
||||||
|
current_iter_index = 0
|
||||||
|
|
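上方新增的 worker 配额划分逻辑(均分 `max_iter_length`,余数归最后一个 worker)可以抽成一个纯函数来理解(示意代码):

```python
# 配额划分示意:total 个样本按 num_workers 均分,余数全部分给最后一个 worker
def worker_quota(total, num_workers, worker_id):
    base, rem = divmod(int(total), num_workers)
    return base + rem if worker_id == num_workers - 1 else base

quotas = [worker_quota(10, 3, w) for w in range(3)]   # 例如 10 个样本分给 3 个 worker
```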
||||||
batch_samples = []
|
batch_samples = []
|
||||||
for sample in self.dataset:
|
for sample in self.dataset:
|
||||||
|
# 检查是否达到最大迭代次数
|
||||||
|
if current_iter_index >= worker_quota:
|
||||||
|
break
|
||||||
|
|
||||||
text = sample.get(self.text_field, "")
|
text = sample.get(self.text_field, "")
|
||||||
if text:
|
if text:
|
||||||
pinyin_list = self.generate_pinyin(text)
|
pinyin_list = self.generate_pinyin(text)
|
||||||
for i in range(len(text)):
|
for i in range(len(text)):
|
||||||
labels = []
|
# 在开始处理每个字符前检查配额
|
||||||
# 如果text[i]不再字符库中,则跳过
|
if current_iter_index >= worker_quota:
|
||||||
|
break
|
||||||
|
|
||||||
|
labels = [] # 添加起始符
|
||||||
|
# 如果text[i]不在字符库中,则跳过
|
||||||
# 当i小于48时候,则将part1取text[0:i]
|
# 当i小于48时候,则将part1取text[0:i]
|
||||||
# 当i大于48时候,则将part1取text[i-48:i]
|
# 当i大于48时候,则将part1取text[i-48:i]
|
||||||
if not self.query_engine.is_chinese_char(text[i]):
|
if not self.query_engine.is_chinese_char(text[i]):
|
||||||
|
|
@@ -256,34 +222,66 @@ class PinyinInputDataset(IterableDataset):
|
||||||
|
|
||||||
labels = [
|
labels = [
|
||||||
self.query_engine.get_char_info_by_char_pinyin(c, p).id
|
self.query_engine.get_char_info_by_char_pinyin(c, p).id
|
||||||
for c, p in zip(text[i:i + pinyin_len], pinyin_list[i:i + pinyin_len])
|
for c, p in zip(
|
||||||
|
text[i : i + pinyin_len],
|
||||||
|
pinyin_list[i : i + pinyin_len],
|
||||||
|
)
|
||||||
]
|
]
|
||||||
labels.append(0)
|
|
||||||
|
|
||||||
encoded = self.tokenize_with_four_seg(
|
encoded = self.tokenizer(
|
||||||
[
|
f"{part4}|{part1}",
|
||||||
part1,
|
|
||||||
part2,
|
|
||||||
part3,
|
part3,
|
||||||
part4,
|
max_length=self.max_seq_length,
|
||||||
]
|
padding="max_length",
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
return_token_type_ids=True,
|
||||||
)
|
)
|
||||||
repeats = self.adjust_frequency(
|
samples = []
|
||||||
min([self.sample_freqs[i] for i in labels])
|
for i, label in enumerate(labels):
|
||||||
)
|
repeats = self.adjust_frequency(label)
|
||||||
sample = {
|
l = labels[:i]
|
||||||
|
len_l = len(l)
|
||||||
|
l.extend([0] * (8 - len_l))
|
||||||
|
|
||||||
|
samples.extend(
|
||||||
|
[
|
||||||
|
{
|
||||||
"input_ids": encoded["input_ids"],
|
"input_ids": encoded["input_ids"],
|
||||||
"token_type_ids": encoded["token_type_ids"],
|
"token_type_ids": encoded["token_type_ids"],
|
||||||
"attention_mask": encoded["attention_mask"],
|
"attention_mask": encoded["attention_mask"],
|
||||||
"labels": torch.tensor(labels, dtype=torch.long),
|
"label": torch.tensor([label], dtype=torch.long),
|
||||||
"part1": part1,
|
"history_slot_ids": torch.tensor(
|
||||||
"part2": part2,
|
l, dtype=torch.long
|
||||||
"part3": part3,
|
),
|
||||||
"part4": part4,
|
"prefix": f"{part4}^{part1}",
|
||||||
|
"suffix": part3,
|
||||||
|
"pinyin": part2,
|
||||||
}
|
}
|
||||||
batch_samples.extend([sample] * repeats)
|
]
|
||||||
|
* repeats
|
||||||
|
)
|
||||||
|
|
||||||
|
# 添加到缓冲区
|
||||||
|
batch_samples.extend(samples)
|
||||||
|
|
||||||
|
# 处理shuffle buffer
|
||||||
if len(batch_samples) >= self.shuffle_buffer_size:
|
if len(batch_samples) >= self.shuffle_buffer_size:
|
||||||
indices = np.random.permutation(len(batch_samples))
|
indices = np.random.permutation(len(batch_samples))
|
||||||
self.buffer.extend([batch_samples[i] for i in indices])
|
for idx in indices:
|
||||||
|
if current_iter_index >= worker_quota:
|
||||||
|
# 清空batch_samples并返回
|
||||||
batch_samples = []
|
batch_samples = []
|
||||||
yield from self.buffer
|
return # 使用return而不是break,因为我们在生成器函数中
|
||||||
|
yield batch_samples[idx]
|
||||||
|
current_iter_index += 1
|
||||||
|
batch_samples = []
|
||||||
|
|
||||||
|
# 处理剩余的样本
|
||||||
|
if batch_samples:
|
||||||
|
indices = np.random.permutation(len(batch_samples))
|
||||||
|
for idx in indices:
|
||||||
|
if current_iter_index >= worker_quota:
|
||||||
|
return
|
||||||
|
yield batch_samples[idx]
|
||||||
|
current_iter_index += 1
|
||||||
|
|
|
||||||
|
|
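上面的 shuffle buffer + worker 配额逻辑(缓冲区攒满后整体乱序、逐条产出,配额用尽时用 `return` 结束生成器)可以用一个独立的极简示意来说明。注意 `shuffled_stream`、`quota` 等名称均为本文假设,并非仓库中的实现:

```python
import random


def shuffled_stream(sample_iter, buffer_size, quota):
    """近似还原上面的产出逻辑:样本先攒进缓冲区,
    攒满后整体打乱再逐条 yield;达到配额时直接 return(而非 break),
    因为这是一个生成器函数。"""
    buffer, emitted = [], 0
    for sample in sample_iter:
        buffer.append(sample)
        if len(buffer) >= buffer_size:
            random.shuffle(buffer)
            for s in buffer:
                if emitted >= quota:
                    return
                yield s
                emitted += 1
            buffer = []
    # 处理剩余的样本
    random.shuffle(buffer)
    for s in buffer:
        if emitted >= quota:
            return
        yield s
        emitted += 1
```

这样每个 DataLoader worker 只产出 `quota` 条样本,且乱序只发生在缓冲区粒度内,与修改后的 `__iter__` 行为一致。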
@@ -1,448 +1,134 @@
-import math
-from pathlib import Path
-from typing import Optional, Union
-
-import torch
-import torch.nn as nn
-import torch.amp as amp
-import torch.optim as optim
-import torch.nn.functional as F
-
-from loguru import logger
-from modelscope import AutoTokenizer
-from tqdm.notebook import tqdm
-
-from .components import AttentionPooling, Expert  # , ResidualBlock # 假设已实现
-
-
-class InputMethodEngine(nn.Module):
-    def __init__(
-        self,
-        pretrained_encoder,  # 已加载并扩展好的预训练编码器
-        output_vocab_size: int,
-        hidden_size: int = 512,  # 需与预训练模型隐藏维度一致
-        max_slot_steps: int = 24,
-        num_experts: int = 20,
-        top_k: int = 3,
-        expert_res_blocks: int = 4,
-        dropout: float = 0.3,
-        use_attention_pooling: bool = False,
-        **kwargs,
-    ):
-        super().__init__()
-        self.hidden_size = hidden_size
-        self.output_vocab_size = output_vocab_size
-        self.max_slot_steps = max_slot_steps
-        self.num_experts = num_experts
-        self.top_k = top_k
-        self.use_attention_pooling = use_attention_pooling
-
-        # 预训练编码器
-        self.encoder = pretrained_encoder
-
-        self.slot_embedding = nn.Embedding(output_vocab_size, hidden_size)
-        self.slot_position_embedding = nn.Embedding(max_slot_steps + 1, hidden_size)
-
-        self.cross_attention = nn.MultiheadAttention(
-            embed_dim=hidden_size,
-            num_heads=4,  # 可配置
-            dropout=dropout,
-            batch_first=True,
-        )
-
-        if use_attention_pooling:
-            self.attention_pooling = AttentionPooling(hidden_size)
-
-        self.gate = nn.Linear(hidden_size, num_experts)
-        self.experts = nn.ModuleList(
-            [
-                Expert(
-                    input_dim=hidden_size,
-                    d_model=hidden_size,
-                    num_resblocks=expert_res_blocks,
-                    output_multiplier=1,
-                    dropout_prob=dropout,
-                )
-                for _ in range(num_experts)
-            ]
-        )
-
-        self.classifier = nn.Sequential(
-            nn.Linear(hidden_size, hidden_size, bias=False),
-            nn.LayerNorm(hidden_size),
-            nn.GELU(),
-            nn.Linear(hidden_size, output_vocab_size),
-        )
-
-        self._init_weights()
-
-    def encode_text(self, input_ids, token_type_ids, attention_mask):
-        """
-        使用预训练编码器编码文本
-        注意:预训练模型输出可能包含 last_hidden_state 和 pooler_output
-        """
-        outputs = self.encoder(
-            input_ids=input_ids,
-            token_type_ids=token_type_ids,
-            attention_mask=attention_mask,
-            return_dict=True,
-        )
-        # 取 last_hidden_state [batch, seq_len, hidden]
-        return outputs.last_hidden_state
-
-    def forward_single_step(self, context, slot_seq_emb, slot_seq_mask=None):
-        """
-        单步预测:根据当前槽位序列(已拼接的嵌入),预测下一个文字的概率分布
-        context: [batch, seq_len, hidden] 文本编码结果
-        slot_seq_emb: [batch, current_len, hidden] 当前槽位序列的嵌入(已拼接)
-        slot_seq_mask: [batch, current_len] 有效位置mask(1有效)
-        返回: [batch, output_vocab_size] 概率分布
-        """
-        batch_size = slot_seq_emb.size(0)
-        # 交叉注意力:Query是槽位序列(通常只取最后一个步的嵌入作为Query,但这里我们使用整个序列)
-        # 为了简单,我们使用整个序列作为Query,然后取最后一个位置的输出(因为自回归)
-        # 方法1:Query = 最后一个位置的嵌入(单个向量)
-        last_query = slot_seq_emb[:, -1:, :]  # [batch, 1, hidden]
-        # 交叉注意力
-        attn_out, _ = self.cross_attention(
-            query=last_query,
-            key=context,
-            value=context,
-            key_padding_mask=(
-                context.sum(-1) == 0
-            ),  # 忽略填充位置,实际应传入attention_mask
-        )  # [batch, 1, hidden]
-        attn_out = attn_out.squeeze(1)  # [batch, hidden]
-
-        # 门控网络:选择top-k专家
-        gate_logits = self.gate(attn_out)  # [batch, num_experts]
-        topk_weights, topk_indices = torch.topk(
-            F.softmax(gate_logits, dim=-1), self.top_k, dim=-1
-        )
-        # 归一化权重
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-        # 计算专家输出加权和
-        expert_outputs = torch.zeros_like(attn_out)  # [batch, hidden]
-        for i in range(self.top_k):
-            expert_idx = topk_indices[:, i]  # [batch]
-            weight = topk_weights[:, i].unsqueeze(1)  # [batch, 1]
-            # 批量获取专家输出(需对每个样本分别处理,或用循环)
-            # 这里简单循环,实际可优化为 gather
-            for b in range(batch_size):
-                expert_out = self.experts[expert_idx[b]](attn_out[b : b + 1])
-                expert_outputs[b] += weight[b] * expert_out.squeeze(0)
-
-        # 分类头
-        logits = self.classifier(expert_outputs)  # [batch, output_vocab_size]
-        return logits
-
-    def forward_train(
-        self,
-        input_ids,
-        token_type_ids,
-        attention_mask,
-        slot_target_ids,
-        slot_target_mask=None,
-    ):
-        """
-        训练模式:一次性并行计算所有槽位步的预测
-        slot_target_ids: [batch, max_slot_steps] 真实标签(teacher forcing)
-        slot_target_mask: [batch, max_slot_steps] 有效步mask(1表示该步存在)
-        返回: 各步logits [batch, max_slot_steps, output_vocab_size]
-        """
-        batch_size, max_steps = slot_target_ids.shape
-        device = input_ids.device
-
-        # 1. 编码文本
-        context = self.encode_text(
-            input_ids, token_type_ids, attention_mask
-        )  # [b, L, h]
-
-        # 2. 构建槽位序列嵌入(teacher forcing:使用真实标签构建)
-        # 将每个样本的槽位ID转为embedding,并添加位置编码
-        # slot_target_ids 形状 [b, T]
-        slot_emb = self.slot_embedding(slot_target_ids)  # [b, T, h]
-        # 位置编码(位置从0开始)
-        positions = torch.arange(max_steps, device=device).unsqueeze(0)  # [1, T]
-        pos_emb = self.slot_position_embedding(positions)  # [1, T, h]
-        slot_emb = slot_emb + pos_emb
-        # 可选:加入mask(无效位置置0)
-        if slot_target_mask is not None:
-            slot_emb = slot_emb * slot_target_mask.unsqueeze(-1)
-
-        # 3. 交叉注意力:Query为整个槽位序列(每个位置独立预测)
-        # 注意:我们这里使用 self-attention 的方式?实际上应该是 cross-attention 且 Query 是槽位序列
-        # 但常规做法是每个位置的 Query 只与文本编码交互,不与其他槽位交互,所以不需要掩码。
-        # 使用 MultiheadAttention 的 query 和 key/value 不同即可。
-        attn_out, _ = self.cross_attention(
-            query=slot_emb,  # [b, T, h]
-            key=context,  # [b, L, h]
-            value=context,
-            key_padding_mask=(attention_mask == 0),  # 忽略文本填充位置
-        )  # [b, T, h]
-
-        # 4. 对每个槽位步分别通过门控+专家+分类头
-        # 由于不同步之间共享参数,我们可以将 batch 和 steps 合并处理
-        b, T, h = attn_out.shape
-        attn_flat = attn_out.view(b * T, h)  # [b*T, h]
-
-        # 门控网络
-        gate_logits = self.gate(attn_flat)  # [b*T, num_experts]
-        gate_probs = F.softmax(gate_logits, dim=-1)
-        topk_probs, topk_indices = torch.topk(
-            gate_probs, self.top_k, dim=-1
-        )  # [b*T, top_k]
-        topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)  # 归一化
-
-        # 计算专家输出(并行化较复杂,这里用循环简化,实际可用 scatter 优化)
-        expert_out_flat = torch.zeros_like(attn_flat)  # [b*T, h]
-        for i in range(self.top_k):
-            weight = topk_probs[:, i].unsqueeze(1)  # [b*T, 1]
-            idx = topk_indices[:, i]  # [b*T]
-            # 批量获取专家输出
-            # 注意:每个专家处理所有样本,这里用循环专家,效率较低
-            # 实际生产环境应优化为组卷积或使用 einsum,此处保持清晰
-            for expert_id in range(self.num_experts):
-                mask = idx == expert_id
-                if mask.any():
-                    # 取出属于该专家的样本
-                    sub_input = attn_flat[mask]  # [k, h]
-                    sub_output = self.experts[expert_id](sub_input)  # [k, h]
-                    expert_out_flat[mask] += weight[mask] * sub_output
-
-        # 分类头
-        logits_flat = self.classifier(expert_out_flat)  # [b*T, output_vocab_size]
-        logits = logits_flat.view(b, T, -1)  # [b, T, output_vocab_size]
-
-        return logits
-
-    def forward(
-        self,
-        input_ids,
-        token_type_ids,
-        attention_mask,
-        slot_target_ids=None,
-        slot_target_mask=None,
-        mode="train",
-    ):
-        """
-        统一接口
-        mode='train': 使用 teacher forcing 并行计算所有步 logits,返回 [b, T, vocab]
-        mode='infer': 需结合外部循环,使用 forward_single_step
-        """
-        if mode == "train":
-            return self.forward_train(
-                input_ids,
-                token_type_ids,
-                attention_mask,
-                slot_target_ids,
-                slot_target_mask,
-            )
-        else:
-            raise NotImplementedError(
-                "Inference mode should call forward_single_step directly"
-            )
-
-    def to(self, device):
-        """重写 to 方法,记录设备"""
-        self.device = device
-        return super().to(device)
-
-    def fit(
-        self,
-        train_dataloader,
-        eval_dataloader=None,
-        monitor=None,
-        criterion=None,
-        optimizer=None,
-        num_epochs=1,
-        stop_batch=2e5,
-        eval_frequency=500,
-        grad_accum_steps=1,
-        clip_grad_norm=1.0,
-        loss_weight=None,
-        mixed_precision=True,
-        weight_decay=0.1,
-        warmup_ratio=0.1,
-        label_smoothing=0.15,
-        lr=1e-4,
-        lr_schedule=None,
-        save_dir=None,
-        save_frequency=1000,
-    ):
-        """
-        训练模型
-        """
-
-        def default_lr_schedule(_lr, _processed_batches, _stop_batch, _warmup_steps):
-            if _processed_batches < _warmup_steps:
-                current_lr = _lr * (_processed_batches / _warmup_steps)
-            else:
-                progress = (_processed_batches - _warmup_steps) / (
-                    _stop_batch - _warmup_steps
-                )
-                current_lr = _lr * (0.5 * (1.0 + math.cos(math.pi * progress)))
-            return current_lr
-
-        if self.device is None:
-            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.to(self.device)
-        if lr_schedule is None:
-            lr_schedule = default_lr_schedule
-
-        self.train()
-
-        if optimizer is None:
-            optimizer = optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)
-
-        # 损失函数:需要设置 ignore_index=-1,因为标签中无效位置用 -1 表示
-        if criterion is None:
-            if loss_weight is not None:
-                criterion = nn.CrossEntropyLoss(
-                    weight=loss_weight, label_smoothing=label_smoothing, ignore_index=-1
-                )
-            else:
-                criterion = nn.CrossEntropyLoss(
-                    label_smoothing=label_smoothing, ignore_index=-1
-                )
-
-        scaler = amp.GradScaler(enabled=mixed_precision)
-
-        total_steps = stop_batch
-        warmup_steps = int(total_steps * warmup_ratio)
-        logger.info(f"Training Start: Steps={total_steps}, Warmup={warmup_steps}")
-        processed_batches = 0
-        batch_loss_sum = 0.0
-        optimizer.zero_grad()
-
-        try:
-            for epoch in range(num_epochs):
-                for batch_idx, batch in enumerate(
-                    tqdm(train_dataloader, total=int(stop_batch))
-                ):
-                    # 学习率调度
-                    current_lr = lr_schedule(
-                        lr, stop_batch, processed_batches, warmup_steps
-                    )
-                    for param_group in optimizer.param_groups:
-                        param_group["lr"] = current_lr
-
-                    # 从 batch 中获取数据
-                    input_ids = batch["hint"]["input_ids"].to(self.device)
-                    attention_mask = batch["hint"]["attention_mask"].to(self.device)
-                    token_type_ids = batch["hint"]["token_type_ids"].to(self.device)
-                    labels = batch["char_id"].to(self.device)  # [batch, max_slot_steps]
-
-                    # 构建 slot_target_mask:有效位置为 1,无效位置为 0(假设无效标签为 -1)
-                    slot_target_mask = (labels != -1).float()  # [batch, max_slot_steps]
-
-                    with torch.amp.autocast(
-                        device_type=self.device.type, enabled=mixed_precision
-                    ):
-                        # 调用模型(训练模式)
-                        logits = self(
-                            input_ids=input_ids,
-                            token_type_ids=token_type_ids,
-                            attention_mask=attention_mask,
-                            slot_target_ids=labels,
-                            slot_target_mask=slot_target_mask,
-                            mode="train",
-                        )  # logits: [batch, max_slot_steps, output_vocab_size]
-
-                        # 计算损失(忽略填充位置,ignore_index=-1 已在 criterion 中设置)
-                        loss = criterion(
-                            logits.view(-1, self.output_vocab_size), labels.view(-1)
-                        )
-                        loss = loss / grad_accum_steps
-
-                    scaler.scale(loss).backward()
-
-                    # 梯度累积更新
-                    if (processed_batches + 1) % grad_accum_steps == 0:
-                        scaler.unscale_(optimizer)
-                        torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
-                        scaler.step(optimizer)
-                        scaler.update()
-                        optimizer.zero_grad()
-                        batch_loss_sum += loss.item() * grad_accum_steps
-
-                        # 定期评估
-                        if processed_batches % eval_frequency == 0:
-                            if eval_dataloader:
-                                self.eval()
-                                acc, eval_loss = self.model_eval(eval_dataloader, criterion)
-                                self.train()
-                                if monitor:
-                                    monitor.add_step(
-                                        processed_batches,
-                                        {
-                                            "train_loss": batch_loss_sum
-                                            / (eval_frequency if processed_batches > 0 else 1),
-                                            "acc": acc,
-                                            "loss": eval_loss,
-                                            "lr": current_lr,
-                                        },
-                                    )
-                                logger.info(
-                                    f"step: {processed_batches}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, "
-                                    f"batch_loss_sum: {batch_loss_sum / (eval_frequency if processed_batches > 0 else 1):.4f}, "
-                                    f"current_lr: {current_lr}"
-                                )
-                            else:
-                                logger.info(
-                                    f"step: {processed_batches}, batch_loss_sum: {batch_loss_sum / (eval_frequency if processed_batches > 0 else 1):.4f}, "
-                                    f"current_lr: {current_lr}"
-                                )
-                            batch_loss_sum = 0.0
-
-                        processed_batches += 1
-                        if processed_batches >= stop_batch:
-                            break
-
-                    else:
-                        # 未达到梯度累积步数,只累加损失值,但不更新计数器(因为 processed_batches 在梯度更新时才增加)
-                        # 注意:这里需要小心,原代码中 processed_batches 是在梯度更新后才增加,所以上面已经统一在更新后增加
-                        # 但为了兼容原有逻辑,这里不做额外处理
-                        pass
-
-            # 训练结束通知
-            if monitor:
-                monitor.finish()
-
-        except KeyboardInterrupt:
-            logger.info("Training interrupted by user")
-
-    def load_from_state_dict(self, state_dict_path: Union[str, Path]):
-        state_dict = torch.load(
-            state_dict_path, weights_only=True, map_location=self.device
-        )
-        self.load_state_dict(state_dict)
-
-    def load_from_pretrained_base_model(
-        self,
-        BaseModel,
-        snapshot_path: Union[str, Path],
-        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
-        *args,
-        **kwargs,
-    ):
-        base_model = BaseModel(*args, **kwargs)
-        base_model.load_state_dict(torch.load(snapshot_path, map_location=device))
-        self_static_dict = self.state_dict()
-        pretrained_dict = base_model.state_dict()
-
-        freeze_layers = []
-
-        for key in self_static_dict.keys():
-            if key in pretrained_dict.keys():
-                if self_static_dict[key].shape == pretrained_dict[key].shape:
-                    self_static_dict[key] = pretrained_dict[key].to(self.device)
-                    freeze_layers.append(key)
-        self.load_state_dict(self_static_dict)
-        for name, param in self.named_parameters():
-            if name in freeze_layers:
-                param.requires_grad = False
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# 导入 components.py 中的组件
+from .components import (
+    AttentionPooling,  # 可选,暂不使用
+    ContextEncoder,
+    CrossAttentionFusion,
+    MoELayer,
+    SlotMemory,
+)
+
+
+class InputMethodEngine(nn.Module):
+    """
+    输入法预测引擎模型。
+
+    基于 README 设计的轻量级输入法预测模型,整合了上下文编码、槽位记忆、
+    交叉注意力融合、混合专家网络 (MoE) 以及分类预测。
+
+    输入参数:
+        input_ids (torch.Tensor): [batch_size, seq_len] 文本 token ids。
+        token_type_ids (torch.Tensor): [batch_size, seq_len] 标识前缀/后缀等类型,用于偏置。
+        attention_mask (torch.Tensor): [batch_size, seq_len] 注意力掩码,1 表示有效位置。
+        history_slot_ids (torch.Tensor): [batch_size, num_slots] 或 [num_slots] 历史槽位 ID。
+            如果输入为 [num_slots],内部会自动扩展 batch 维度。
+
+    输出:
+        logits (torch.Tensor): [batch_size, vocab_size] 下一个字符的概率分布(未经过 softmax)。
+    """
+
+    def __init__(
+        self,
+        vocab_size: int = 10019,
+        pinyin_vocab_size: int = 28,
+        dim: int = 512,
+        num_slots: int = 8,  # 历史槽位数量 (对应 README 中的 8 个槽位)
+        n_layers: int = 4,  # Transformer 层数
+        n_heads: int = 4,  # 注意力头数
+        num_experts: int = 20,  # MoE 专家数量
+        max_seq_len: int = 128,  # 最大上下文长度
+        use_pinyin: bool = False,  # 是否使用拼音特征(若为 False,拼音嵌入恒为零)
+    ):
+        super().__init__()
+        self.dim = dim
+        self.num_slots = num_slots
+        self.use_pinyin = use_pinyin
+        self.vocab_size = vocab_size
+
+        # 1. 上下文编码器 (ContextEncoder)
+        # 若 use_pinyin=False,则传入 pinyin_vocab_size=1 并固定嵌入为零
+        self.context_encoder = ContextEncoder(
+            vocab_size=vocab_size,
+            pinyin_vocab_size=pinyin_vocab_size if use_pinyin else 1,
+            dim=dim,
+            n_layers=n_layers,
+            n_heads=n_heads,
+            max_len=max_seq_len,
+        )
+
+        # 2. 槽位记忆模块 (SlotMemory)
+        # 适配历史槽位数量为 num_slots(每个槽位对应一个词,而非多步)
+        self.slot_memory = SlotMemory(
+            vocab_size=vocab_size,
+            max_slots=num_slots,
+            steps_per_slot=1,  # 每个槽位只占一步
+            dim=dim,
+        )
+
+        # 3. 交叉注意力融合 (CrossAttentionFusion)
+        # 使用 F.scaled_dot_product_attention 实现的版本
+        self.cross_attn = CrossAttentionFusion(dim=dim, n_heads=n_heads)
+
+        # 4. 混合专家层 (MoE)
+        self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=2)
+
+        # 5. 分类头
+        self.classifier = nn.Linear(dim, vocab_size)
+
+        # 可选:如果不需要拼音,将拼音嵌入矩阵固定为零
+        if not use_pinyin and hasattr(self.context_encoder, "pinyin_emb"):
+            # 将拼音嵌入权重置零,确保对输出无影响
+            nn.init.zeros_(self.context_encoder.pinyin_emb.weight)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        token_type_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        history_slot_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        前向传播。
+        """
+        batch_size = input_ids.size(0)
+
+        # 处理 history_slot_ids:若为 [num_slots] 则扩展 batch 维度
+        if history_slot_ids.dim() == 1:
+            history_slot_ids = history_slot_ids.unsqueeze(0).expand(batch_size, -1)
+
+        # 1. 构造拼音输入(如果 use_pinyin=False,则使用全零占位符)
+        if self.use_pinyin:
+            # 注意:这里需要真实的拼音 ids,但当前输入未提供,故用零占位(实际应用中应从外部获取)
+            # 为简化演示,此处使用全零张量,并假设拼音词汇表中 0 为 padding。
+            pinyin_ids = torch.zeros_like(input_ids)
+        else:
+            pinyin_ids = torch.zeros_like(input_ids)
+
+        # 2. 上下文编码 -> H [batch, seq_len, dim]
+        # 注意:ContextEncoder.forward 接受 text_ids, pinyin_ids, mask
+        H = self.context_encoder(input_ids, pinyin_ids, mask=attention_mask)
+
+        # 3. 槽位记忆编码 -> S [batch, num_slots, dim]
+        S = self.slot_memory(history_slot_ids)  # history_slot_ids: [batch, num_slots]
+
+        # 4. 交叉注意力融合 (使用 CrossAttentionFusion)
+        fused = self.cross_attn(S, H, context_mask=attention_mask)
+
+        # 5. MoE 处理 -> [batch, num_slots, dim]
+        moe_out = self.moe(fused)
+
+        # 6. 池化与分类:对槽位维度求平均(或使用 mask 池化)
+        # 这里简单平均,若需要忽略 padding 槽位,可根据 history_slot_ids 是否为 0 构造 mask
+        slot_mask = (history_slot_ids != 0).float()  # [batch, num_slots]
+        slot_mask = slot_mask.unsqueeze(-1)  # [batch, num_slots, 1]
+        pooled = (moe_out * slot_mask).sum(dim=1) / (slot_mask.sum(dim=1) + 1e-8)
+        # 如果所有槽位均为 padding,则降级为全局平均
+        if torch.isnan(pooled).any():
+            pooled = moe_out.mean(dim=1)
+
+        logits = self.classifier(pooled)  # [batch, vocab_size]
+        return logits
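新版 `forward` 第 6 步的带掩码槽位池化(忽略 id 为 0 的 padding 槽位,全 padding 时退化为普通平均)可以用一个不依赖任何框架的纯 Python 示意来验证其数值行为。`masked_slot_pool` 是本文假设的名称,并非仓库中的函数:

```python
def masked_slot_pool(slot_vectors, history_slot_ids):
    """对槽位向量做带掩码平均:跳过 id == 0 的 padding 槽位;
    若全部为 padding,则退化为对所有槽位求平均
    (对应 InputMethodEngine.forward 中的 NaN 兜底分支)。

    slot_vectors: 每个槽位一个向量(list of list of float)
    history_slot_ids: 每个槽位的 id(list of int)
    """
    valid = [v for v, sid in zip(slot_vectors, history_slot_ids) if sid != 0]
    pool = valid if valid else slot_vectors  # 兜底:全局平均
    n = len(pool)
    dim = len(slot_vectors[0])
    return [sum(vec[d] for vec in pool) / n for d in range(dim)]
```

与张量实现不同的是,这里直接用条件分支代替 `1e-8` 平滑和 NaN 检测,但两者在有效槽位非空时结果一致。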
test.py (18 行变更)
@@ -1,7 +1,17 @@
+import sys
+
+from tqdm import tqdm
+
 from model.dataset import PinyinInputDataset
 
-dataset = PinyinInputDataset('/home/songsenand/Data/corpus/CCI-Data/')
+if sys.platform == "win32":
+    dataset_path = "data"
+else:
+    dataset_path = "/home/songsenand/Data/corpus/CCI-Data/"
+
+dataset = PinyinInputDataset(dataset_path, max_iter_length=20, max_workes=3)
 for i, line in enumerate(dataset):
-    print(line['labels'])
-    if i > 10:
-        break
+    for k, v in line.items():
+        if isinstance(v, str):
+            continue
+        print(k, v.shape)
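无论是旧实现中手写的「门控 softmax → top-k 选择 → 权重归一化 → 专家加权求和」循环,还是新版的 `MoELayer(dim=dim, num_experts=num_experts, top_k=2)`,核心计算都是同一个 top-k 门控混合。下面是一个与框架无关的标量示意(`topk_moe` 及参数均为本文假设,专家用任意可调用对象代替):

```python
import math


def topk_moe(x, gate_scores, experts, top_k=2):
    """top-k 门控专家混合的最小示意:
    1) 对门控分数做 softmax;
    2) 取概率最高的 top_k 个专家;
    3) 在所选专家内重新归一化权重;
    4) 输出为所选专家输出的加权和。
    """
    # softmax(数值稳定:先减最大值)
    m = max(gate_scores)
    exp = [math.exp(s - m) for s in gate_scores]
    z = sum(exp)
    probs = [e / z for e in exp]
    # 选出 top-k 专家并重新归一化权重
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    w_sum = sum(probs[i] for i in top)
    out = 0.0
    for i in top:
        out += (probs[i] / w_sum) * experts[i](x)
    return out
```

真实实现中专家输入输出是 `[batch, dim]` 张量,并且通常按专家分组批量计算(如旧代码中按 `expert_id` 掩码分发),但权重的计算方式与此一致。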