feat: Update the input method model architecture design doc and refactor core component code

This commit is contained in:
songsenand 2026-04-03 17:04:35 +08:00
parent fd49058764
commit 1af85a36bc
8 changed files with 2445 additions and 2425 deletions

2
.gitignore vendored

@@ -175,3 +175,5 @@ cython_debug/
*.vsix
uv.lock
data/*

225
README.md

@@ -1,95 +1,166 @@
# SUimeModelTraner
# Input Method Prediction Model Architecture Design
> Technical design for a deep-learning input method engine
## 1. Overview
This project aims to build a lightweight, high-accuracy Chinese input method prediction model. The core design idea is to deeply fuse the current context (the text around the cursor plus the pinyin) with the user's historical input habits through **structured slot memory** and a **cross-attention mechanism**. To retain strong expressive power under limited compute, the model introduces a **Mixture-of-Experts (MoE)** module.
## 1. Task Objective
Design a context-aware input method engine model. Its inputs are the text before the cursor, the pinyin, the text after the cursor, and the history; its output is a candidate sequence (a sequence of character IDs). Beam-search decoding is supported, enabling accurate pinyin-to-character conversion.
## 2. Core Architecture Flow
Data flows along the following path:
`Input Encoding` → `Transformer Context Encoding` → `Slot Memory Embedding` → `Cross-Attention Fusion` → `Gating + Mixture of Experts (MoE)` → `Classification` → `Beam-Search Decoding`
## 2. Input Representation
- **Four text segments**:
  - Text before the cursor
  - Pinyin (e.g., "bangdao")
  - Text after the cursor
  - Qualifying history entries (e.g., "绑到|邦道|…")
- **Encoding**: a BERT-style tokenizer with a unified sequence length of **L=128** (or 88).
- **Segment separation**: segments are marked with `token_type_ids` taking values **0, 1, 2, 3** (one per input segment); see the sketch below.
- **Pinyin handling**: pinyin is fed in as plain text for now; a dedicated embedding interface is reserved for later optimization.
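A minimal sketch of this four-segment packing (a simplified version of the repository's `tokenize_with_four_seg`; the `bert-base-chinese` tokenizer and the simplified truncation handling are assumptions):
```python
from transformers import AutoTokenizer  # assumption: any BERT-style tokenizer works

tok = AutoTokenizer.from_pretrained("bert-base-chinese")

def encode_four_segments(parts, max_length=128):
    """parts: [text_before_cursor, pinyin, text_after_cursor, history]; type ids 0..3."""
    input_ids = [tok.cls_token_id]
    token_type_ids = [0]
    for seg_idx, part in enumerate(parts):
        if not part:
            continue
        ids = tok(part, add_special_tokens=False)["input_ids"]
        input_ids += ids + [tok.sep_token_id]          # a SEP closes each segment
        token_type_ids += [seg_idx] * (len(ids) + 1)   # the SEP keeps its segment's type
    input_ids = input_ids[:max_length]                 # final truncation
    token_type_ids = token_type_ids[:max_length]
    pad = max_length - len(input_ids)
    attention_mask = [1] * len(input_ids) + [0] * pad
    return {
        "input_ids": input_ids + [tok.pad_token_id] * pad,
        "token_type_ids": token_type_ids + [0] * pad,  # padding type is 0
        "attention_mask": attention_mask,
    }
```
Note that type ids 0–3 require `type_vocab_size >= 4` in the encoder config.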
### 2.1 Input Layer Design
The model receives three kinds of input, each processed separately to keep the semantics clean:
1. **Current text context**: the text before the cursor (Prefix) and after the cursor (Suffix).
2. **Pinyin sequence**: the pinyin matching the current text, fused into the text encoding as an auxiliary feature.
3. **History slot sequence**: the most recent N history tokens, fed in as structured memory.
## 3. Model Architecture Overview
Overall pipeline: input encoding → Transformer encoding → slot memory → cross-attention → gating + mixture of experts → classification head → beam-search decoding.
### 2.2 Module Details
![Model architecture diagram]
#### A. Transformer Encoder (Context Encoder)
Extracts a deep semantic representation of the current context.
* **Input processing**: Prefix, Suffix, and pinyin are mapped through an embedding layer. Pinyin is fused by **feature addition** or as **separate tokens**, avoiding the complexity of a two-stream architecture.
* **Backbone**: a standard Transformer encoder.
* **Hidden dimension**: 512 [1]
* **Transformer layers**: 4 (lightweight design, trained from scratch) [1]
* **Attention heads**: 4 [1]
* **Output**: context representation $H$ of shape `[batch, L, 512]` [1].
## 4. Core Module Design
#### B. Slot Memory Module
Turns the unstructured input history into structured memory vectors.
* **Embedding**: history tokens are mapped through a dedicated `Slot Embedding` lookup table.
* **Positional encoding**: a learnable `Positional Embedding` is added to preserve the temporal order of the history.
* **Output**: slot sequence $S$ of shape `[batch, Num_Slots, 512]`.
### 4.1 Embedding Layer and Transformer Encoder
- **Embedding dimension**: 512
- **Transformer layers**: 4
- **Attention heads**: 4 or 8
- **Output**: context representation `H` (shape: `[batch, L, 512]`)
- **Pretraining**: optionally load a lightweight structBERT backbone (the link is now dead); currently trained from scratch.
#### C. Cross-Attention Fusion
The core novelty of the model: it dynamically links "history memory" with the "current context".
* **Query (Q)**: the current step's slot sequence $S$ (after positional encoding).
* **Key/Value (K/V)**: the context representation $H$ from the Transformer encoder [1].
* **Mechanism**: the history slots actively attend to the current text context, capturing logic such as "in the context '班级第一名', '王次香' is more relevant than '王慈祥'".
* **Output**: fused feature sequence of shape `[batch, Num_Slots, 512]`.
### 4.2 Slot Memory Module
- **Slot structure**: 8 slots in total; each slot holds up to 3 steps, for at most 24 steps overall.
- **Slot embedding**: the character ID chosen at each step is mapped to a 512-dim vector through a shared embedding layer.
- **History slot representation**: all slot embeddings before the current step are **concatenated in order**, forming a dynamically growing sequence.
- **Positional encoding**: learnable positional embeddings are added to the concatenated slot sequence to help the model capture temporal relations.
- **Initial state**: at the first step the history is empty; a special `[START]` embedding vector serves as a placeholder (see the sketch below).
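A minimal sketch of this history construction, assuming a shared character embedding, learnable positions over the flattened steps, and the `[START]` placeholder described above (all names are illustrative):
```python
import torch
import torch.nn as nn

dim, max_steps, vocab = 512, 24, 10019
char_emb = nn.Embedding(vocab, dim)                # shared with the per-step character IDs
pos_emb = nn.Embedding(max_steps, dim)             # learnable positions over the flattened steps
start_emb = nn.Parameter(torch.randn(1, 1, dim))   # [START] placeholder for an empty history

def build_history(history_ids):
    """history_ids: [batch, t] character IDs chosen so far (t grows step by step)."""
    if history_ids.size(1) == 0:                   # first step: history is empty
        return start_emb.expand(history_ids.size(0), 1, dim)
    s = char_emb(history_ids)                      # [batch, t, 512]
    pos = torch.arange(history_ids.size(1), device=history_ids.device)
    return s + pos_emb(pos)                        # ordered concatenation + positions
```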
#### D. Gating and Mixture of Experts (Gating + MoE)
Testing shows that removing MoE significantly degrades model performance, so this module is essential for capturing complex distributions.
* **Number of experts**: 20 [1].
* **Gating**: experts are activated sparsely based on the input features, raising model capacity while keeping compute in check.
* **Output**: feature vectors enhanced by the expert network.
### 4.3 Cross-Attention and Attention Pooling
- **Query**: the current step's slot sequence (after positional encoding)
- **Key/Value**: the Transformer encoder output `H`
- **Output**: a slot-conditioned feature sequence
- **Attention pooling**: the cross-attention output sequence is pooled (mean pooling or learnable attention pooling) into a fixed-length feature vector `f` (dim 512).
#### E. Classification Head and Decoding
* **Classification**: the MoE output feature vector is mapped to the vocabulary space through a fully connected layer, giving a probability distribution over the next character/word.
* **Decoding**: inference uses **beam search** with a beam width of 5 [1].
### 4.4 Gating Network and Expert Layer (MoE)
- **Gating network**: takes `f` as input, outputs weights over the 20 experts, and selects the **top-3**.
- **Expert structure**: each expert is a residual network with output dimension **512** (matching the hidden size).
- **Expert combination**: the 3 selected expert outputs are combined by weighted sum into the fused feature `e`, as sketched below.
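A sketch of this gate-and-combine step on the pooled feature `f` (the two-layer experts below stand in for the residual experts; the batched implementation lives in `MoELayer` in `components.py`):
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

num_experts, top_k, dim = 20, 3, 512
gate = nn.Linear(dim, num_experts)
experts = nn.ModuleList(
    [nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
     for _ in range(num_experts)]  # stand-ins for the residual experts
)

def moe_combine(f):
    """f: [batch, 512] pooled feature -> e: [batch, 512] fused feature."""
    weights, idx = torch.topk(gate(f), top_k, dim=-1)  # pick the top-3 experts
    weights = F.softmax(weights, dim=-1)               # renormalize over the selected experts
    e = torch.zeros_like(f)
    for k in range(top_k):                             # per-sample dispatch, kept simple
        for b in range(f.size(0)):
            e[b] += weights[b, k] * experts[idx[b, k]](f[b])
    return e
```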
## 3. Key Hyperparameter Configuration
### 4.5 Classification Head
- **Design**: a same-dimension FFN:
  `Linear(512, 512, bias=False) → LayerNorm → GELU → Linear(512, 10019)`
  The output is 10019-dimensional (ID range 0–10018, where 0 is the terminator); see the sketch below.
- **Benefits**: preserves the feature dimension, adds non-linearity, and keeps the parameter count moderate.
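The head written out directly in PyTorch (vocabulary size 10019 as stated above):
```python
import torch.nn as nn

classifier = nn.Sequential(
    nn.Linear(512, 512, bias=False),  # same-dimension projection
    nn.LayerNorm(512),
    nn.GELU(),
    nn.Linear(512, 10019),            # logits over character IDs 0..10018 (0 = terminator)
)
```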
To balance model performance and efficiency, the following hyperparameters are recommended [1]:
## 5. Decoding Strategy: Beam Search
- **Search scope**: beam search runs inside each slot step, with beam width **k** (default 3).
- **Candidate maintenance**: each candidate path independently maintains its history slot sequence (the concatenated embeddings) and its cumulative probability.
- **Termination conditions**:
  1. All slots are filled (8×3 = 24 steps);
  2. If the highest-probability token of every candidate branch at the current step is **0 (the terminator)**, decoding exits immediately.
- **Output**: the complete slot sequence with the highest probability, as sketched below.
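A minimal beam-search sketch under the conventions above (terminator ID 0, per-path history, cumulative log-probability); `score_fn`, standing in for a model forward pass over one candidate's history, is an assumption:
```python
import torch

def beam_search(score_fn, beam_width=3, max_steps=24, eos_id=0):
    """score_fn(history: list[int]) -> [vocab] log-probs for the next character ID."""
    beams = [([], 0.0)]  # (history IDs, cumulative log-prob)
    for _ in range(max_steps):
        candidates = []
        for hist, score in beams:
            log_probs = score_fn(hist)                       # [vocab]
            top_lp, top_id = torch.topk(log_probs, beam_width)
            for lp, tid in zip(top_lp.tolist(), top_id.tolist()):
                candidates.append((hist + [tid], score + lp))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_width]                      # keep the k best paths
        if all(hist[-1] == eos_id for hist, _ in beams):     # every branch chose the terminator
            break
    return beams[0][0]  # highest-probability complete sequence
```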
| Parameter | Recommended Value | Notes |
| :--- | :--- | :--- |
| **Sequence length (L)** | 128 | Context window size [1] |
| **Hidden dimension** | 512 | Embedding and internal Transformer width [1] |
| **Transformer layers** | 4 | Lightweight backbone for low latency [1] |
| **Attention heads** | 4 | Efficient configuration for 512 dims [1] |
| **Number of experts** | 20 | Total experts in the MoE layer; critical for performance [1] |
| **Beam width** | 5 | Balances speed and accuracy at inference [1] |
| **Learning rate** | 1e-4 ~ 5e-4 | Recommended together with warmup [1] |
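The same values gathered into a config sketch (field names are illustrative):
```python
from dataclasses import dataclass

@dataclass
class ModelConfig:
    seq_len: int = 128      # context window L
    hidden_dim: int = 512   # embedding / Transformer width
    n_layers: int = 4       # lightweight backbone
    n_heads: int = 4
    num_experts: int = 20   # MoE experts
    beam_width: int = 5
    lr: float = 3e-4        # within the recommended 1e-4..5e-4, with warmup
```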
## 6. Training Setup
- **Optimizer**: AdamW
- **Loss**: per-step CrossEntropyLoss, computed only at non-padding positions (see the sketch below)
- **Training data**: real user input logs, organized into (context, pinyin, target slot sequence) triples
- **Labels**: the ground-truth character ID at each slot step serves as the supervision signal
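A sketch of the masked per-step loss (following the convention in the training code, where padded steps carry label -1):
```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss(ignore_index=-1)  # padded slot steps are skipped
logits = torch.randn(2, 24, 10019)                # [batch, steps, vocab]
labels = torch.randint(0, 10019, (2, 24))         # [batch, steps]
labels[:, 20:] = -1                               # mark trailing steps as padding
loss = criterion(logits.view(-1, 10019), labels.view(-1))
```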
## 4. Training Strategy
## 7. Key Design Considerations and Potential Risks
The model follows the standard **sequence-to-sequence (Seq2Seq) supervised learning** paradigm, predicting the target slot sequence step by step.
| Design point | Rationale | Risk / Optimization |
|--------|------|----------------|
| Pinyin encoding | Treated as plain text for now; simple to implement | Alignment may be hard for multi-character words; add a dedicated pinyin embedding layer later |
| Slot update | Concatenation + positional encoding preserves the full history | Sequence length grows dynamically (max 24); cross-attention cost stays manageable |
| Expert layer | 20 experts with top-3 combination improves feature selectivity | The gating network's load must stay balanced to avoid "dead" experts (see the sketch below) |
| Classification head | Same-dimension FFN balances capacity and non-linearity | If it overfits, reduce the dimension or add dropout |
| Beam search | In-slot beam search keeps engineering complexity moderate | Each candidate path keeps its own history state; watch memory management |
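One common mitigation for the dead-expert risk above is an auxiliary load-balancing loss added to the training objective; below is a sketch of the standard Switch-Transformer-style term (not part of the current codebase; `top_k=3` follows Section 4.4):
```python
import torch
import torch.nn.functional as F

def load_balancing_loss(gate_logits, top_k=3):
    """gate_logits: [tokens, num_experts]; small when experts are used uniformly."""
    num_experts = gate_logits.size(-1)
    probs = F.softmax(gate_logits, dim=-1)                       # [tokens, E]
    topk_idx = torch.topk(probs, top_k, dim=-1).indices          # [tokens, k]
    assigned = F.one_hot(topk_idx, num_experts).sum(1).float()   # [tokens, E]
    fraction_routed = assigned.mean(0) / top_k                   # share of tokens per expert
    mean_prob = probs.mean(0)                                    # average gate probability
    return num_experts * torch.sum(fraction_routed * mean_prob)
```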
### 4.1 Data Construction and Labels
* **Input triples**: training data consists of `(context, pinyin, target slot sequence)` triples [1].
    * **Context**: text fragments before and after the cursor.
    * **Pinyin**: the pinyin sequence of the characters being typed.
    * **Target slot sequence**: the character ID sequence the user actually typed, used as the model's supervision signal [1].
* **Labels**: at every slot step (Step), the model must predict that step's ground-truth character ID [1].
## 8. Future Optimization Directions
- **Pinyin enhancement**: introduce a dedicated pinyin embedding (syllable-level encoding)
- **Stronger slot modeling**: if concatenation struggles with long-range dependencies, replace it with a lightweight GRU (not adopted for now)
- **Pretraining**: try initializing the encoder from open-source Chinese BERT weights
- **Knowledge distillation**: if the model grows too large, distill a smaller version for on-device deployment
### 4.2 Loss Function and Optimization
* **Loss function**: **CrossEntropyLoss** between each step's prediction and its ground-truth label [1].
* **Masking**: the loss is computed only at non-padding positions, ignoring invalid time steps [1].
* **Optimizer**: parameters are updated with **AdamW** [1].
---
### 4.3 Training Flow Details
1. **Forward pass**:
    * The model takes the context and pinyin and produces a context representation through the Transformer encoder.
    * Combined with the history slot memory, features are fused through cross-attention and the MoE module.
    * The classification head outputs a probability distribution over all candidate characters for the current step.
2. **Teacher forcing**:
    * During training, the **ground-truth output of the previous slot** is forced in as the conditioning input for the next step (see the sketch below), so the model always predicts from a "correct history" and converges quickly.
3. **Backpropagation**:
    * Gradients are computed from the CrossEntropyLoss [1] and weights are updated with AdamW [1].
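A sketch of the teacher-forcing input construction (the dedicated `START_ID` outside the 0..10018 label range is an assumption):
```python
import torch

START_ID = 10019  # assumption: a dedicated [START] ID outside the label range
targets = torch.randint(1, 10019, (4, 24))  # [batch, T] ground-truth slot IDs
start = torch.full((targets.size(0), 1), START_ID)
decoder_in = torch.cat([start, targets[:, :-1]], dim=1)  # the model always sees the "correct history"
```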
## Appendix: Key Hyperparameters (TBD)
- Sequence length: 128
- Hidden dimension: 512
- Transformer layers: 4
- Attention heads: 4
- Experts: 20
- Beam width: 5
- Learning rate: to be tuned (suggest 1e-4 to 5e-4, with warmup)
### 4.4 Differences Between Inference and Training
* **Training**: ground-truth labels feed the slot inputs, so the model learns the proper conditional distribution.
* **Inference**: since ground truth is unavailable, the model uses **beam search** [1].
    * **Beam width**: defaults to 5 [1].
    * **Candidate maintenance**: each candidate path keeps its own history slot sequence and cumulative probability [1].
    * **Termination**: decoding stops when all slots are filled (e.g., 8×3 = 24 steps) or the top token of every candidate branch is the terminator [1].
## 5. Reference Implementation Sketch (PyTorch)
```python
import torch
import torch.nn as nn

class Expert(nn.Module):
    def __init__(self, dim=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, x):
        return self.net(x)

class InputMethodModel(nn.Module):
    def __init__(self, vocab_size, pinyin_vocab_size, slot_vocab_size, dim=512, n_layers=4, n_heads=4, num_experts=20):
        super().__init__()
        # 1. Context Encoder
        self.text_emb = nn.Embedding(vocab_size, dim)
        self.pinyin_emb = nn.Embedding(pinyin_vocab_size, dim)
        self.pos_emb = nn.Embedding(128, dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=n_heads, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        # 2. Slot Memory
        self.slot_emb = nn.Embedding(slot_vocab_size, dim)
        self.slot_pos_emb = nn.Embedding(5, dim)  # assume 5 history slots are kept
        # 3. Cross-Attention
        self.cross_attn = nn.MultiheadAttention(embed_dim=dim, num_heads=n_heads, batch_first=True)
        # 4. MoE Layer
        self.num_experts = num_experts
        self.experts = nn.ModuleList([Expert(dim) for _ in range(num_experts)])
        self.gate = nn.Linear(dim, num_experts)
        # 5. Classification Head
        self.classifier = nn.Linear(dim, vocab_size)

    def forward(self, text_ids, pinyin_ids, history_slot_ids):
        # Encode Context
        x = self.text_emb(text_ids) + self.pinyin_emb(pinyin_ids)
        x = x + self.pos_emb(torch.arange(x.size(1), device=x.device))
        H = self.transformer(x)  # [B, L, 512]
        # Encode Slots
        S = self.slot_emb(history_slot_ids)
        S = S + self.slot_pos_emb(torch.arange(S.size(1), device=S.device))
        # Cross-Attention: Q=Slots, K/V=Context
        fused, _ = self.cross_attn(S, H, H)  # [B, Slots, 512]
        # Simplified MoE: dense weighted average over all expert outputs
        gate_scores = torch.softmax(self.gate(fused), dim=-1)  # [B, Slots, Num_Experts]
        expert_outputs = torch.stack([expert(fused) for expert in self.experts], dim=-2)  # [B, Slots, Num_Experts, Dim]
        moe_out = torch.sum(gate_scores.unsqueeze(-1) * expert_outputs, dim=-2)  # [B, Slots, Dim]
        # Pooling & Predict
        pooled = moe_out.mean(dim=1)  # [B, 512]
        logits = self.classifier(pooled)
        return logits
```
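A quick shape check of the sketch above (vocabulary sizes follow the defaults used elsewhere in this document; the 5 history slots match `slot_pos_emb`):
```python
model = InputMethodModel(vocab_size=10019, pinyin_vocab_size=28, slot_vocab_size=10019)
text_ids = torch.randint(0, 10019, (2, 128))   # [batch, L]
pinyin_ids = torch.randint(0, 28, (2, 128))
history = torch.randint(0, 10019, (2, 5))      # 5 history slot IDs
logits = model(text_ids, pinyin_ids, history)
print(logits.shape)  # torch.Size([2, 10019])
```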
## 6. Summary
By combining **single-stream Transformer encoding** with **structured slot cross-attention** and a **20-expert MoE module** [1], this design keeps the model lightweight (a 4-layer Transformer) while exploiting historical input habits and raising the model's expressive ceiling. Compared with brute-force concatenation or a two-stream architecture, it is cleaner to engineer and more efficient at inference, a local optimum for a lightweight input method model.


@@ -1,30 +0,0 @@
{
"add_cross_attention": false,
"attention_probs_dropout_prob": 0.1,
"bos_token_id": null,
"classifier_dropout": null,
"directionality": "bidi",
"eos_token_id": null,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 512,
"initializer_range": 0.02,
"intermediate_size": 2048,
"is_decoder": false,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 8,
"num_hidden_layers": 6,
"pad_token_id": 0,
"pooler_fc_size": 768,
"pooler_num_attention_heads": 12,
"pooler_num_fc_layers": 3,
"pooler_size_per_head": 128,
"pooler_type": "first_token_transform",
"tie_word_embeddings": true,
"transformers_version": "5.1.0",
"type_vocab_size": 4,
"use_cache": true,
"vocab_size": 21128
}


@@ -1,9 +1,7 @@
import torch
import torch.amp as amp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from loguru import logger
# ---------------------------- Attention pooling module ----------------------------
@@ -80,3 +78,270 @@ class Expert(nn.Module):
for block in self.res_blocks:
x = block(x)
return self.output(x)
# ------------------------------------------------------------------
# 1. Context Encoder
# See README 4.1: 4-layer Transformer, 512-dim, outputs H [1]
# ------------------------------------------------------------------
class ContextEncoder(nn.Module):
def __init__(
self, vocab_size, pinyin_vocab_size, dim=512, n_layers=4, n_heads=4, max_len=128
):
super().__init__()
self.dim = dim
# Embeddings
self.text_emb = nn.Embedding(vocab_size, dim)
self.pinyin_emb = nn.Embedding(pinyin_vocab_size, dim)
self.pos_emb = nn.Embedding(max_len, dim)
# Transformer Encoder (4 layers, 4 heads) [1]
encoder_layer = nn.TransformerEncoderLayer(
d_model=dim,
nhead=n_heads,
dim_feedforward=dim * 4,
dropout=0.1,
batch_first=True,  # convenient for [B, L, D] tensors
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
# LayerNorm for stability
self.ln = nn.LayerNorm(dim)
def forward(self, text_ids, pinyin_ids, mask=None):
"""
Args:
text_ids: [batch, seq_len]
pinyin_ids: [batch, seq_len] (assumed aligned with the text; preprocess if not)
mask: [batch, seq_len] optional padding mask
Returns:
H: [batch, seq_len, 512] Context representation [1]
"""
# 1. Embedding fusion: text + pinyin + position
# Strategy: pinyin is added onto the text embedding as an auxiliary feature, in line with the lightweight design
x = self.text_emb(text_ids) + self.pinyin_emb(pinyin_ids)
seq_len = x.size(1)
pos_ids = (
torch.arange(seq_len, device=x.device).unsqueeze(0).expand_as(text_ids)
)
x += self.pos_emb(pos_ids)
# 2. Transformer Encoding
# src_key_padding_mask expects True for padding positions
if mask is not None:
# Convert 0/1 mask to bool mask where True is padding
src_mask = mask == 0
else:
src_mask = None
H = self.transformer(x, src_key_padding_mask=src_mask)
return self.ln(H)
# ------------------------------------------------------------------
# 2. Slot Memory
# See README 4.2: 8 slots, 3 steps per slot, concatenation + positional encoding [1]
# ------------------------------------------------------------------
class SlotMemory(nn.Module):
def __init__(self, vocab_size, max_slots=8, steps_per_slot=3, dim=512):
super().__init__()
self.max_slots = max_slots
self.steps_per_slot = steps_per_slot
self.total_steps = max_slots * steps_per_slot # 24 steps [1]
# Shared embedding layer for history tokens [1]
self.emb = nn.Embedding(vocab_size, dim)
# Learnable positional embeddings for the flattened sequence [1]
self.pos_emb = nn.Embedding(self.total_steps, dim)
# Start token embedding for empty slots [1]
self.start_emb = nn.Parameter(torch.randn(1, 1, dim))
def forward(self, history_ids):
"""
Args:
history_ids: [batch, total_steps]
Flattened sequence of history tokens.
Empty positions should be filled with a special PAD or handled via mask.
Returns:
S: [batch, total_steps, 512] Slot sequence representation [1]
"""
# Embed history tokens
S = self.emb(history_ids) # [B, 24, 512]
# Add positional embeddings
pos_ids = (
torch.arange(S.size(1), device=S.device).unsqueeze(0).expand_as(history_ids)
)
S += self.pos_emb(pos_ids)
return S
# ------------------------------------------------------------------
# 3. Cross-Attention Fusion
# See README: Query=Slots, Key/Value=Context H [1]
# ------------------------------------------------------------------
class CrossAttentionFusion(nn.Module):
def __init__(self, dim=512, n_heads=4):
super().__init__()
self.dim = dim
self.n_heads = n_heads
self.head_dim = dim // n_heads
assert self.head_dim * n_heads == dim, "dim must be divisible by n_heads"
# Linear projections for Q, K, V
self.q_proj = nn.Linear(dim, dim, bias=False)
self.k_proj = nn.Linear(dim, dim, bias=False)
self.v_proj = nn.Linear(dim, dim, bias=False)
self.out_proj = nn.Linear(dim, dim, bias=False)
self.ln = nn.LayerNorm(dim)
def forward(self, slots_S, context_H, slot_mask=None, context_mask=None):
"""
Args:
slots_S: [batch, num_slots_steps, dim] Query
context_H: [batch, ctx_len, dim] Key/Value
slot_mask: [batch, num_slots_steps] Optional (not used in scaled_dot_product_attention)
context_mask: [batch, ctx_len] Optional padding mask
Returns:
Fused: [batch, num_slots_steps, dim]
"""
batch_size, num_slots, _ = slots_S.shape
_, ctx_len, _ = context_H.shape
# Project queries, keys, values
Q = self.q_proj(slots_S) # [batch, num_slots, dim]
K = self.k_proj(context_H) # [batch, ctx_len, dim]
V = self.v_proj(context_H) # [batch, ctx_len, dim]
# Reshape for multi-head attention: [batch, seq_len, n_heads, head_dim] -> [batch, n_heads, seq_len, head_dim]
Q = Q.view(batch_size, num_slots, self.n_heads, self.head_dim).transpose(1, 2)
K = K.view(batch_size, ctx_len, self.n_heads, self.head_dim).transpose(1, 2)
V = V.view(batch_size, ctx_len, self.n_heads, self.head_dim).transpose(1, 2)
# Prepare attention mask if context_mask is provided
attn_mask = None
if context_mask is not None:
# context_mask: [batch, ctx_len] where 0 means padding
# Convert to bool mask and reshape for broadcasting
bool_mask = context_mask == 0 # [batch, ctx_len]
bool_mask = bool_mask[:, None, None, :] # [batch, 1, 1, ctx_len]
# Convert to float mask where True (padding) becomes -inf
attn_mask = bool_mask.float().masked_fill(bool_mask, float("-inf"))
# Scaled dot-product attention
attn_output = F.scaled_dot_product_attention(
Q,
K,
V,
attn_mask=attn_mask,
dropout_p=0.0, # no dropout
)
# Reshape back: [batch, n_heads, num_slots, head_dim] -> [batch, num_slots, dim]
attn_output = (
attn_output.transpose(1, 2)
.contiguous()
.view(batch_size, num_slots, self.dim)
)
# Project back
fused = self.out_proj(attn_output)
# Residual connection and layer norm
fused = self.ln(fused + slots_S)
return fused
# ------------------------------------------------------------------
# 4. MoE Layer
# See README: 20 experts [1]; uses the Expert class from components.py
# ------------------------------------------------------------------
class MoELayer(nn.Module):
def __init__(self, dim=512, num_experts=20, top_k=2, expert_resblocks=4):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
self.dim = dim
# Import Expert from your existing components
# Assuming Expert class is defined as in components.py [2]
self.experts = nn.ModuleList(
[
Expert(
input_dim=dim,
d_model=dim,
num_resblocks=expert_resblocks,
output_multiplier=1,
)
for _ in range(num_experts)
]
)
# Gating Network [2]
self.gate = nn.Linear(dim, num_experts)
def forward(self, x):
"""
Args:
x: [batch, seq_len, dim]
Returns:
out: [batch, seq_len, dim]
"""
B, L, D = x.shape
# 1. Compute Gating Scores
gates = self.gate(x) # [B, L, num_experts]
# 2. Select Top-K Experts
topk_vals, topk_indices = torch.topk(gates, self.top_k, dim=-1) # [B, L, K]
# Normalize weights for selected experts
weights = F.softmax(topk_vals, dim=-1) # [B, L, K]
# 3. Dispatch and Compute
# Initialize output
out = torch.zeros_like(x)
# Reshape for easier processing: flatten batch and sequence dimensions
x_flat = x.view(-1, D) # [B*L, D]
weights_flat = weights.view(-1, self.top_k) # [B*L, K]
topk_indices_flat = topk_indices.view(-1, self.top_k) # [B*L, K]
# For each of the top-k positions
for k in range(self.top_k):
# Get expert indices and weights for this position
expert_indices = topk_indices_flat[:, k] # [B*L]
expert_weights = weights_flat[:, k].unsqueeze(-1) # [B*L, 1]
# Process each expert separately
for e_idx in range(self.num_experts):
# Mask for tokens assigned to this expert at position k
mask = expert_indices == e_idx # [B*L]
if not mask.any():
continue
# Extract tokens for this expert
x_selected = x_flat[mask] # [N_selected, D]
if x_selected.numel() == 0:
continue
# Pass through expert
expert_out = self.experts[e_idx](x_selected) # [N_selected, D]
# Apply expert weights and add to output
weighted_out = expert_out * expert_weights[mask]
# Scatter back to flat output
out_flat = out.view(-1, D)
out_flat[mask] += weighted_out
# Reshape back to original shape
out = out.view(B, L, D)
return out


@@ -1,13 +1,13 @@
import random
from typing import Any, Dict, List, Optional, Tuple
from importlib.resources import files
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
from datasets import load_dataset
from loguru import logger
from modelscope import AutoModel, AutoTokenizer
from modelscope import AutoTokenizer
from pypinyin import Style, lazy_pinyin
from pypinyin.contrib.tone_convert import to_initials
from torch.utils.data import DataLoader, IterableDataset
@@ -20,7 +20,8 @@ class PinyinInputDataset(IterableDataset):
self,
data_path: str,
max_workers: int = -1,
max_length=128,
max_iter_length=1e6,
max_seq_length=128,
text_field: str = "text",
py_style_weight=(9, 2, 1),
shuffle_buffer_size: int = 5000,
@@ -37,7 +38,9 @@ class PinyinInputDataset(IterableDataset):
Path(files(__package__) / "assets" / "tokenizer")
)
self.data_path = data_path
self.max_length = max_length
self.max_iter_length = max_iter_length
self.max_seq_length = max_seq_length
self.text_field = text_field
self.dataset = load_dataset(data_path, split="train", streaming=True)
self.max_workers = max_workers
@@ -48,8 +51,6 @@ class PinyinInputDataset(IterableDataset):
self.query_engine = QueryEngine()
self.query_engine.load()
self.shuffle_buffer_size = shuffle_buffer_size
self.buffer = []
# Extract each sample's target characters and their frequencies
self.sample_freqs = self.query_engine.get_all_weights()
@@ -95,76 +96,14 @@ class PinyinInputDataset(IterableDataset):
else:
return 1
def tokenize_with_four_seg(self, parts: List[str]) -> Dict[str, Any]:
input_ids = []
token_type_ids = []
# Append [CLS] (type 0)
cls_id = self.tokenizer.cls_token_id
input_ids.append(cls_id)
token_type_ids.append(0)
for seg_idx, part in enumerate(parts):
if not part:
continue
# Tokenize this part without special tokens
# Note: do not truncate yet; truncate once at the end so higher-priority parts (e.g., part1) stay intact
encoded_part = self.tokenizer(
part, add_special_tokens=False, truncation=False
)
part_ids = encoded_part["input_ids"]
# If adding this part would exceed MAX_LEN - 1 (reserving one slot for the final SEP), truncate it
remaining_space = (
self.max_length - len(input_ids) - 1
) # -1 for final SEP or safety
if len(part_ids) > remaining_space:
part_ids = part_ids[:remaining_space]
if not part_ids:
continue
input_ids.extend(part_ids)
# The segment's type_id is its seg_idx (0, 1, 2, 3)
token_type_ids.extend([seg_idx] * len(part_ids))
# Append [SEP] (its type follows the current segment)
sep_id = self.tokenizer.sep_token_id
input_ids.append(sep_id)
token_type_ids.append(seg_idx)
# Exit early once max length is reached
if len(input_ids) >= self.max_length:
break
# 4. Final padding or truncation
if len(input_ids) > self.max_length:
input_ids = input_ids[: self.max_length]
token_type_ids = token_type_ids[: self.max_length]
else:
pad_len = self.max_length - len(input_ids)
input_ids += [self.tokenizer.pad_token_id] * pad_len
token_type_ids += [0] * pad_len  # padding type id is conventionally 0
# 5. Build the attention mask
attention_mask = [
1 if i != self.tokenizer.pad_token_id else 0 for i in input_ids
]
return {
"input_ids": torch.tensor(input_ids, dtype=torch.long),
"token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
"attention_mask": torch.tensor(attention_mask, dtype=torch.long),
}
# Generate the pinyin for a given text
def generate_pinyin(self, text: str) -> List[List[str]]:
def generate_pinyin(self, text: str) -> List[str]:
return lazy_pinyin(text, errors=lambda x: [c for c in x])
# Generate the pinyin of the characters to predict, with augmentation
def get_mask_pinyin(self, text: str, pinyin_list: List[str]) -> (int, List[str]):
def get_mask_pinyin(
self, text: str, pinyin_list: List[str]
) -> Tuple[int, List[str]]:
mask_pinyin = []
for i in range(len(text)):
if not self.query_engine.is_chinese_char(text[i]):
@@ -193,14 +132,41 @@ class PinyinInputDataset(IterableDataset):
self.dataset = self.dataset.shard(num_shards=num_workers, index=worker_id)
# Compute each worker's quota
# Cast max_iter_length to int to ensure integer division
total_quota = int(self.max_iter_length)
base_quota = total_quota // num_workers
remainder = total_quota % num_workers
# The last worker takes the remaining samples, if there is a remainder
if worker_id == num_workers - 1:
worker_quota = base_quota + remainder
else:
worker_quota = base_quota
else:
# Single-worker case: use the full quota
worker_quota = int(self.max_iter_length)
num_workers = 1
# Each worker keeps its own iteration counter
current_iter_index = 0
batch_samples = []
for sample in self.dataset:
# Stop once the max iteration count is reached
if current_iter_index >= worker_quota:
break
text = sample.get(self.text_field, "")
if text:
pinyin_list = self.generate_pinyin(text)
for i in range(len(text)):
labels = []
# Skip text[i] if it is not in the character inventory
# Check the quota before processing each character
if current_iter_index >= worker_quota:
break
labels = []  # add the start marker
# Skip text[i] if it is not in the character inventory
# If i < 48, part1 = text[0:i]
# If i >= 48, part1 = text[i-48:i]
if not self.query_engine.is_chinese_char(text[i]):
@@ -256,34 +222,66 @@ class PinyinInputDataset(IterableDataset):
labels = [
self.query_engine.get_char_info_by_char_pinyin(c, p).id
for c, p in zip(text[i:i + pinyin_len], pinyin_list[i:i + pinyin_len])
for c, p in zip(
text[i : i + pinyin_len],
pinyin_list[i : i + pinyin_len],
)
]
labels.append(0)
encoded = self.tokenize_with_four_seg(
[
part1,
part2,
encoded = self.tokenizer(
f"{part4}|{part1}",
part3,
part4,
]
max_length=self.max_seq_length,
padding="max_length",
truncation=True,
return_tensors="pt",
return_token_type_ids=True,
)
repeats = self.adjust_frequency(
min([self.sample_freqs[i] for i in labels])
)
sample = {
samples = []
for i, label in enumerate(labels):
repeats = self.adjust_frequency(label)
l = labels[:i]
len_l = len(l)
l.extend([0] * (8 - len_l))
samples.extend(
[
{
"input_ids": encoded["input_ids"],
"token_type_ids": encoded["token_type_ids"],
"attention_mask": encoded["attention_mask"],
"labels": torch.tensor(labels, dtype=torch.long),
"part1": part1,
"part2": part2,
"part3": part3,
"part4": part4,
"label": torch.tensor([label], dtype=torch.long),
"history_slot_ids": torch.tensor(
l, dtype=torch.long
),
"prefix": f"{part4}^{part1}",
"suffix": part3,
"pinyin": part2,
}
batch_samples.extend([sample] * repeats)
]
* repeats
)
# Append to the buffer
batch_samples.extend(samples)
# Handle the shuffle buffer
if len(batch_samples) >= self.shuffle_buffer_size:
indices = np.random.permutation(len(batch_samples))
self.buffer.extend([batch_samples[i] for i in indices])
for idx in indices:
if current_iter_index >= worker_quota:
# Clear batch_samples and return
batch_samples = []
yield from self.buffer
return  # use return rather than break, since this is a generator function
yield batch_samples[idx]
current_iter_index += 1
batch_samples = []
# Flush the remaining samples
if batch_samples:
indices = np.random.permutation(len(batch_samples))
for idx in indices:
if current_iter_index >= worker_quota:
return
yield batch_samples[idx]
current_iter_index += 1


@@ -1,448 +1,134 @@
import math
from pathlib import Path
from typing import Optional, Union
from typing import Optional
import torch
import torch.nn as nn
import torch.amp as amp
import torch.optim as optim
import torch.nn.functional as F
from loguru import logger
from modelscope import AutoTokenizer
from tqdm.notebook import tqdm
from .components import AttentionPooling, Expert  # , ResidualBlock  # assumed implemented
# Import components from components.py
from .components import (
AttentionPooling,  # optional, unused for now
ContextEncoder,
CrossAttentionFusion,
MoELayer,
SlotMemory,
)
class InputMethodEngine(nn.Module):
"""
Input method prediction engine.
A lightweight input method prediction model following the README design, integrating context encoding, slot memory,
cross-attention fusion, a mixture-of-experts (MoE) network, and classification.
Inputs:
    input_ids (torch.Tensor): [batch_size, seq_len] text token ids
    token_type_ids (torch.Tensor): [batch_size, seq_len] marks prefix/suffix segment types, used as a bias
    attention_mask (torch.Tensor): [batch_size, seq_len] attention mask; 1 marks valid positions
    history_slot_ids (torch.Tensor): [batch_size, num_slots] or [num_slots] history slot IDs;
        a [num_slots] input is automatically expanded with a batch dimension
Output:
    logits (torch.Tensor): [batch_size, vocab_size] distribution over the next character (pre-softmax)
"""
def __init__(
self,
pretrained_encoder,  # a pretrained encoder, already loaded and extended
output_vocab_size: int,
hidden_size: int = 512,  # must match the pretrained model's hidden size
max_slot_steps: int = 24,
num_experts: int = 20,
top_k: int = 3,
expert_res_blocks: int = 4,
dropout: float = 0.3,
use_attention_pooling: bool = False,
**kwargs,
vocab_size: int = 10019,
pinyin_vocab_size: int = 28,
dim: int = 512,
num_slots: int = 8,  # number of history slots (the 8 slots in the README)
n_layers: int = 4,  # number of Transformer layers
n_heads: int = 4,  # number of attention heads
num_experts: int = 20,  # number of MoE experts
max_seq_len: int = 128,  # maximum context length
use_pinyin: bool = False,  # whether to use pinyin features (if False, the pinyin embedding is fixed at zero)
):
super().__init__()
self.hidden_size = hidden_size
self.output_vocab_size = output_vocab_size
self.max_slot_steps = max_slot_steps
self.num_experts = num_experts
self.top_k = top_k
self.use_attention_pooling = use_attention_pooling
self.dim = dim
self.num_slots = num_slots
self.use_pinyin = use_pinyin
self.vocab_size = vocab_size
# Pretrained encoder
self.encoder = pretrained_encoder
self.slot_embedding = nn.Embedding(output_vocab_size, hidden_size)
self.slot_position_embedding = nn.Embedding(max_slot_steps + 1, hidden_size)
self.cross_attention = nn.MultiheadAttention(
embed_dim=hidden_size,
num_heads=4,  # configurable
dropout=dropout,
batch_first=True,
# 1. Context encoder (ContextEncoder)
# If use_pinyin=False, pass pinyin_vocab_size=1 and pin the embedding at zero
self.context_encoder = ContextEncoder(
vocab_size=vocab_size,
pinyin_vocab_size=pinyin_vocab_size if use_pinyin else 1,
dim=dim,
n_layers=n_layers,
n_heads=n_heads,
max_len=max_seq_len,
)
if use_attention_pooling:
self.attention_pooling = AttentionPooling(hidden_size)
self.gate = nn.Linear(hidden_size, num_experts)
self.experts = nn.ModuleList(
[
Expert(
input_dim=hidden_size,
d_model=hidden_size,
num_resblocks=expert_res_blocks,
output_multiplier=1,
dropout_prob=dropout,
)
for _ in range(num_experts)
]
# 2. Slot memory (SlotMemory)
# Adapted to num_slots history slots, one token per slot rather than multiple steps
self.slot_memory = SlotMemory(
vocab_size=vocab_size,
max_slots=num_slots,
steps_per_slot=1,  # each slot takes exactly one step
dim=dim,
)
self.classifier = nn.Sequential(
nn.Linear(hidden_size, hidden_size, bias=False),
nn.LayerNorm(hidden_size),
nn.GELU(),
nn.Linear(hidden_size, output_vocab_size),
)
# 3. Cross-attention fusion (CrossAttentionFusion)
# Version implemented with F.scaled_dot_product_attention
self.cross_attn = CrossAttentionFusion(dim=dim, n_heads=n_heads)
self._init_weights()
# 4. Mixture-of-experts layer (MoE)
self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=2)
def encode_text(self, input_ids, token_type_ids, attention_mask):
"""
Encode text with the pretrained encoder.
Note: the pretrained model's output may contain both last_hidden_state and pooler_output.
"""
outputs = self.encoder(
input_ids=input_ids,
token_type_ids=token_type_ids,
attention_mask=attention_mask,
return_dict=True,
)
# Take last_hidden_state [batch, seq_len, hidden]
return outputs.last_hidden_state
# 5. Classification head
self.classifier = nn.Linear(dim, vocab_size)
def forward_single_step(self, context, slot_seq_emb, slot_seq_mask=None):
"""
Single-step prediction: given the current (concatenated) slot-sequence embedding, predict the distribution over the next character.
context: [batch, seq_len, hidden] encoded text
slot_seq_emb: [batch, current_len, hidden] embedding of the current slot sequence (concatenated)
slot_seq_mask: [batch, current_len] validity mask (1 = valid)
Returns: [batch, output_vocab_size] distribution
"""
batch_size = slot_seq_emb.size(0)
# Cross-attention: the query is the slot sequence; usually only the last step's embedding serves as the query, though the whole sequence could be used
# For simplicity, use the whole sequence as the query and take the last position's output (autoregressive)
# Approach 1: query = the last position's embedding (a single vector)
last_query = slot_seq_emb[:, -1:, :] # [batch, 1, hidden]
# Cross-attention
attn_out, _ = self.cross_attention(
query=last_query,
key=context,
value=context,
key_padding_mask=(
context.sum(-1) == 0
),  # ignore padding positions (attention_mask should really be passed in)
) # [batch, 1, hidden]
attn_out = attn_out.squeeze(1) # [batch, hidden]
# Gating network: select the top-k experts
gate_logits = self.gate(attn_out) # [batch, num_experts]
topk_weights, topk_indices = torch.topk(
F.softmax(gate_logits, dim=-1), self.top_k, dim=-1
)
# Normalize the weights
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
# Weighted sum of expert outputs
expert_outputs = torch.zeros_like(attn_out) # [batch, hidden]
for i in range(self.top_k):
expert_idx = topk_indices[:, i] # [batch]
weight = topk_weights[:, i].unsqueeze(1) # [batch, 1]
# Gather expert outputs (handled per sample, or via a loop)
# A simple loop here; could be optimized with gather
for b in range(batch_size):
expert_out = self.experts[expert_idx[b]](attn_out[b : b + 1])
expert_outputs[b] += weight[b] * expert_out.squeeze(0)
# Classification head
logits = self.classifier(expert_outputs) # [batch, output_vocab_size]
return logits
def forward_train(
self,
input_ids,
token_type_ids,
attention_mask,
slot_target_ids,
slot_target_mask=None,
):
"""
Training mode: compute predictions for all slot steps in parallel.
slot_target_ids: [batch, max_slot_steps] ground-truth labels (teacher forcing)
slot_target_mask: [batch, max_slot_steps] valid-step mask (1 = step exists)
Returns: per-step logits [batch, max_slot_steps, output_vocab_size]
"""
batch_size, max_steps = slot_target_ids.shape
device = input_ids.device
# 1. Encode the text
context = self.encode_text(
input_ids, token_type_ids, attention_mask
) # [b, L, h]
# 2. Build slot-sequence embeddings (teacher forcing: built from the ground-truth labels)
# Convert each sample's slot IDs to embeddings and add positional encodings
# slot_target_ids has shape [b, T]
slot_emb = self.slot_embedding(slot_target_ids) # [b, T, h]
# Positional encoding (positions start at 0)
positions = torch.arange(max_steps, device=device).unsqueeze(0) # [1, T]
pos_emb = self.slot_position_embedding(positions) # [1, T, h]
slot_emb = slot_emb + pos_emb
# Optional: zero out invalid positions using the mask
if slot_target_mask is not None:
slot_emb = slot_emb * slot_target_mask.unsqueeze(-1)
# 3. Cross-attention: the query is the whole slot sequence; each position predicts independently
# Note: this is not self-attention; it is cross-attention with the slot sequence as the query.
# Each position's query interacts only with the text encoding, not with other slots, so no mask is needed.
# Simply give MultiheadAttention different query and key/value inputs.
attn_out, _ = self.cross_attention(
query=slot_emb, # [b, T, h]
key=context, # [b, L, h]
value=context,
key_padding_mask=(attention_mask == 0),  # ignore text padding positions
) # [b, T, h]
# 4. Pass each slot step through gating + experts + the classification head
# Since steps share parameters, the batch and step dimensions can be merged
b, T, h = attn_out.shape
attn_flat = attn_out.view(b * T, h) # [b*T, h]
# Gating network
gate_logits = self.gate(attn_flat) # [b*T, num_experts]
gate_probs = F.softmax(gate_logits, dim=-1)
topk_probs, topk_indices = torch.topk(
gate_probs, self.top_k, dim=-1
) # [b*T, top_k]
topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)  # renormalize
# Compute expert outputs (full parallelization is complex; a loop keeps it simple, scatter could optimize)
expert_out_flat = torch.zeros_like(attn_flat) # [b*T, h]
for i in range(self.top_k):
weight = topk_probs[:, i].unsqueeze(1) # [b*T, 1]
idx = topk_indices[:, i] # [b*T]
# Gather expert outputs in batch
# Note: looping over experts is inefficient since each expert scans all samples
# Production code should switch to grouped ops or einsum; kept clear here
for expert_id in range(self.num_experts):
mask = idx == expert_id
if mask.any():
# Select the samples routed to this expert
sub_input = attn_flat[mask] # [k, h]
sub_output = self.experts[expert_id](sub_input) # [k, h]
expert_out_flat[mask] += weight[mask] * sub_output
# Classification head
logits_flat = self.classifier(expert_out_flat) # [b*T, output_vocab_size]
logits = logits_flat.view(b, T, -1) # [b, T, output_vocab_size]
return logits
# Optional: if pinyin is unused, pin the pinyin embedding matrix at zero
if not use_pinyin and hasattr(self.context_encoder, "pinyin_emb"):
# Zero the pinyin embedding weights so they cannot affect the output
nn.init.zeros_(self.context_encoder.pinyin_emb.weight)
def forward(
self,
input_ids,
token_type_ids,
attention_mask,
slot_target_ids=None,
slot_target_mask=None,
mode="train",
):
input_ids: torch.Tensor,
token_type_ids: torch.Tensor,
attention_mask: torch.Tensor,
history_slot_ids: torch.Tensor,
) -> torch.Tensor:
"""
Unified interface.
mode='train': teacher forcing computes logits for all steps in parallel; returns [b, T, vocab]
mode='infer': use forward_single_step together with an external loop
Forward pass.
"""
if mode == "train":
return self.forward_train(
input_ids,
token_type_ids,
attention_mask,
slot_target_ids,
slot_target_mask,
)
batch_size = input_ids.size(0)
# Handle history_slot_ids: if given as [num_slots], expand the batch dimension
if history_slot_ids.dim() == 1:
history_slot_ids = history_slot_ids.unsqueeze(0).expand(batch_size, -1)
# 1. Build the pinyin input (if use_pinyin=False, an all-zeros placeholder is used)
if self.use_pinyin:
# Note: real pinyin ids are needed here, but the current inputs do not provide them, so zeros stand in; in practice they should come from outside.
# To keep the demo simple, an all-zeros tensor is used, assuming 0 is padding in the pinyin vocabulary.
pinyin_ids = torch.zeros_like(input_ids)
else:
raise NotImplementedError(
"Inference mode should call forward_single_step directly"
)
pinyin_ids = torch.zeros_like(input_ids)
def to(self, device):
"""重写 to 方法,记录设备"""
self.device = device
return super().to(device)
# 2. Context encoding -> H [batch, seq_len, dim]
# Note: ContextEncoder.forward takes text_ids, pinyin_ids, mask
H = self.context_encoder(input_ids, pinyin_ids, mask=attention_mask)
def fit(
self,
train_dataloader,
eval_dataloader=None,
monitor=None,
criterion=None,
optimizer=None,
num_epochs=1,
stop_batch=2e5,
eval_frequency=500,
grad_accum_steps=1,
clip_grad_norm=1.0,
loss_weight=None,
mixed_precision=True,
weight_decay=0.1,
warmup_ratio=0.1,
label_smoothing=0.15,
lr=1e-4,
lr_schedule=None,
save_dir=None,
save_frequency=1000,
):
"""
Train the model.
"""
def default_lr_schedule(_lr, _processed_batches, _stop_batch, _warmup_steps):
if _processed_batches < _warmup_steps:
current_lr = _lr * (_processed_batches / _warmup_steps)
else:
progress = (_processed_batches - _warmup_steps) / (
_stop_batch - _warmup_steps
)
current_lr = _lr * (0.5 * (1.0 + math.cos(math.pi * progress)))
return current_lr
# 3. Slot memory encoding -> S [batch, num_slots, dim]
S = self.slot_memory(history_slot_ids) # history_slot_ids: [batch, num_slots]
if self.device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.to(self.device)
if lr_schedule is None:
lr_schedule = default_lr_schedule
# 4. Cross-attention fusion (via CrossAttentionFusion)
fused = self.cross_attn(S, H, context_mask=attention_mask)
self.train()
# 5. MoE processing -> [batch, num_slots, dim]
moe_out = self.moe(fused)
if optimizer is None:
optimizer = optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)
# Loss: set ignore_index=-1, since invalid label positions are marked with -1
if criterion is None:
if loss_weight is not None:
criterion = nn.CrossEntropyLoss(
weight=loss_weight, label_smoothing=label_smoothing, ignore_index=-1
)
else:
criterion = nn.CrossEntropyLoss(
label_smoothing=label_smoothing, ignore_index=-1
)
scaler = amp.GradScaler(enabled=mixed_precision)
total_steps = stop_batch
warmup_steps = int(total_steps * warmup_ratio)
logger.info(f"Training Start: Steps={total_steps}, Warmup={warmup_steps}")
processed_batches = 0
batch_loss_sum = 0.0
optimizer.zero_grad()
try:
for epoch in range(num_epochs):
for batch_idx, batch in enumerate(
tqdm(train_dataloader, total=int(stop_batch))
):
# Learning-rate schedule
current_lr = lr_schedule(
lr, processed_batches, stop_batch, warmup_steps
)
for param_group in optimizer.param_groups:
param_group["lr"] = current_lr
# Pull tensors from the batch
input_ids = batch["hint"]["input_ids"].to(self.device)
attention_mask = batch["hint"]["attention_mask"].to(self.device)
token_type_ids = batch["hint"]["token_type_ids"].to(self.device)
labels = batch["char_id"].to(self.device) # [batch, max_slot_steps]
# Build slot_target_mask: 1 at valid positions, 0 at invalid ones (invalid labels are -1)
slot_target_mask = (labels != -1).float() # [batch, max_slot_steps]
with torch.amp.autocast(
device_type=self.device.type, enabled=mixed_precision
):
# Call the model (training mode)
logits = self(
input_ids=input_ids,
token_type_ids=token_type_ids,
attention_mask=attention_mask,
slot_target_ids=labels,
slot_target_mask=slot_target_mask,
mode="train",
) # logits: [batch, max_slot_steps, output_vocab_size]
# Compute the loss, ignoring padding (ignore_index=-1 is already set on the criterion)
loss = criterion(
logits.view(-1, self.output_vocab_size), labels.view(-1)
)
loss = loss / grad_accum_steps
scaler.scale(loss).backward()
# Gradient-accumulation update
if (processed_batches + 1) % grad_accum_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
batch_loss_sum += loss.item() * grad_accum_steps
# Periodic evaluation
if processed_batches % eval_frequency == 0:
if eval_dataloader:
self.eval()
acc, eval_loss = self.model_eval(eval_dataloader, criterion)
self.train()
if monitor:
monitor.add_step(
processed_batches,
{
"train_loss": batch_loss_sum
/ (eval_frequency if processed_batches > 0 else 1),
"acc": acc,
"loss": eval_loss,
"lr": current_lr,
},
)
logger.info(
f"step: {processed_batches}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, "
f"batch_loss_sum: {batch_loss_sum / (eval_frequency if processed_batches > 0 else 1):.4f}, "
f"current_lr: {current_lr}"
)
else:
logger.info(
f"step: {processed_batches}, batch_loss_sum: {batch_loss_sum / (eval_frequency if processed_batches > 0 else 1):.4f}, "
f"current_lr: {current_lr}"
)
batch_loss_sum = 0.0
processed_batches += 1
if processed_batches >= stop_batch:
break
else:
# Gradient-accumulation step not yet reached: only accumulate the loss, do not advance the counter (processed_batches increments only on a gradient update)
# Note: care is needed here; in the original code processed_batches advanced after the update, which is already handled uniformly above
# Nothing extra is done here, to stay compatible with the original logic
pass
# Training-finished notification
if monitor:
monitor.finish()
except KeyboardInterrupt:
logger.info("Training interrupted by user")
def load_from_state_dict(self, state_dict_path: Union[str, Path]):
state_dict = torch.load(
state_dict_path, weights_only=True, map_location=self.device
)
self.load_state_dict(state_dict)
def load_from_pretrained_base_model(
self,
BaseModel,
snapshot_path: Union[str, Path],
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
*args,
**kwargs,
):
base_model = BaseModel(*args, **kwargs)
base_model.load_state_dict(torch.load(snapshot_path, map_location=device))
self_static_dict = self.state_dict()
pretrained_dict = base_model.state_dict()
freeze_layers = []
for key in self_static_dict.keys():
if key in pretrained_dict.keys():
if self_static_dict[key].shape == pretrained_dict[key].shape:
self_static_dict[key] = pretrained_dict[key].to(self.device)
freeze_layers.append(key)
self.load_state_dict(self_static_dict)
for name, param in self.named_parameters():
if name in freeze_layers:
param.requires_grad = False
# 6. Pooling and classification: average over the slot dimension (or masked pooling)
# Simple averaging here; to ignore padding slots, build a mask from whether history_slot_ids equals 0
slot_mask = (history_slot_ids != 0).float() # [batch, num_slots]
slot_mask = slot_mask.unsqueeze(-1) # [batch, num_slots, 1]
pooled = (moe_out * slot_mask).sum(dim=1) / (slot_mask.sum(dim=1) + 1e-8)
# If every slot is padding, fall back to a global average
if torch.isnan(pooled).any():
pooled = moe_out.mean(dim=1)
logits = self.classifier(pooled) # [batch, vocab_size]
return logits

18
test.py

@@ -1,7 +1,17 @@
import sys
from tqdm import tqdm
from model.dataset import PinyinInputDataset
dataset = PinyinInputDataset('/home/songsenand/Data/corpus/CCI-Data/')
if sys.platform == "win32":
dataset_path = "data"
else:
dataset_path = "/home/songsenand/Data/corpus/CCI-Data/"
dataset = PinyinInputDataset(dataset_path, max_iter_length=20, max_workers=3)
for i, line in enumerate(dataset):
print(line['labels'])
if i > 10:
break
for k, v in line.items():
if isinstance(v, str):
continue
print(k, v.shape)

3610
uv.lock

File diff suppressed because it is too large Load Diff