feat: add input method engine model design and implementation, covering architecture, core modules, and training strategy
This commit is contained in:
parent a8ccb3e4fe
commit b6a677f15d

@@ -0,0 +1 @@
3.12

README.md | 97
@@ -1,2 +1,99 @@
# SUimeModelTraner

> Technical design for a deep-learning input method engine

## 1. Task Objective

Design a context-aware input method engine model. The input consists of the text before the cursor, the pinyin, the text after the cursor, and matching history records; the output is a candidate character sequence (a sequence of character IDs). Decoding uses beam search, yielding accurate pinyin-to-character conversion.
## 2. Input Representation

- **Four text segments**:
  - Text before the cursor
  - Pinyin (e.g. "bangdao")
  - Text after the cursor
  - Matching history records (e.g. "绑到|邦道|…")
- **Encoding**: a BERT-style tokenizer with a fixed sequence length of **L=128** (or 88).
- **Segment separation**: `token_type_ids` mark the segments with values **0, 1, 2, 3**, one per input type (a packing sketch follows this list).
- **Pinyin handling**: for now pinyin is fed in as plain text; a dedicated embedding interface is reserved for later optimization.
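A minimal packing sketch, assuming a generic `tokenize` callable standing in for the (unspecified) BERT-style tokenizer; all names here are illustrative rather than part of the design:

```python
L = 128  # fixed sequence length

def pack_inputs(before, pinyin, after, history, tokenize, pad_id=0):
    """Pack the four segments into fixed-length model inputs."""
    input_ids, token_type_ids = [], []
    for seg_type, text in enumerate([before, pinyin, after, history]):
        ids = tokenize(text)                          # list[int]
        input_ids.extend(ids)
        token_type_ids.extend([seg_type] * len(ids))  # 0,1,2,3 per segment
    input_ids, token_type_ids = input_ids[:L], token_type_ids[:L]
    attention_mask = [1] * len(input_ids)
    pad = L - len(input_ids)                          # right-pad to L
    return (input_ids + [pad_id] * pad,
            token_type_ids + [0] * pad,
            attention_mask + [0] * pad)
```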
## 3. Model Architecture Overview

The pipeline: input encoding → Transformer encoder → slot memory → cross-attention → gating + mixture of experts → classification head → beam-search decoding.

![Model architecture diagram]
## 4. Core Module Design

### 4.1 Embedding Layer and Transformer Encoder

- **Embedding dimension**: 512
- **Transformer layers**: 4
- **Attention heads**: 4 or 8
- **Output**: contextual representation `H` (shape `[batch, L, 512]`)
- **Pretraining**: optionally load a lightweight structBERT backbone (the link is dead; training currently starts from scratch).
### 4.2 Slot Memory Module

- **Slot structure**: 8 slots, each holding up to 3 steps, for at most 24 steps in total.
- **Slot embedding**: the character ID chosen at each step is mapped to a 512-d vector by a shared embedding layer.
- **History representation**: the embeddings of all slots before the current step are **concatenated in order**, forming a dynamically growing sequence.
- **Positional encoding**: learnable position embeddings are added to the concatenated slot sequence to help the model capture ordering.
- **Initial state**: at the first step the history is empty, so a special `[START]` embedding serves as a placeholder (see the sketch after this list).
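A sketch of assembling the step-t query sequence during decoding. `slot_embedding`, `slot_position_embedding`, and `start_embedding` refer to the model components defined later in this commit; the helper itself is hypothetical:

```python
import torch

def build_slot_query(history_ids, slot_embedding, slot_position_embedding,
                     start_embedding):
    """history_ids: [batch, t] character IDs chosen so far (t may be 0)."""
    batch = history_ids.size(0)
    seq = start_embedding.expand(batch, 1, -1)          # [START] placeholder
    if history_ids.size(1) > 0:
        hist = slot_embedding(history_ids)              # [batch, t, h]
        seq = torch.cat([seq, hist], dim=1)             # [batch, t+1, h]
    positions = torch.arange(seq.size(1), device=seq.device).unsqueeze(0)
    return seq + slot_position_embedding(positions)     # learnable positions
```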
### 4.3 Cross-Attention and Attention Pooling

- **Query**: the current step's slot sequence (after positional encoding)
- **Key/Value**: the Transformer encoder output `H`
- **Output**: a slot-conditioned feature sequence
- **Attention pooling**: the cross-attention output sequence is pooled (mean pooling or learnable attention pooling) into a fixed-length 512-d feature vector `f`.
### 4.4 Gating Network and Expert Layer (MoE)

- **Gating network**: takes `f` and outputs weights over 20 experts, selecting the **top 3**.
- **Expert structure**: each expert is a residual network with output dimension **512** (matching the hidden size).
- **Expert combination**: the outputs of the 3 selected experts are summed with their weights to give the fused feature `e` (a routing sketch follows this list).
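A compact sketch of the top-3 routing and weighted sum, with the gate and expert modules passed in; it mirrors the per-expert batching used in the training code later in this commit:

```python
import torch
import torch.nn.functional as F

def moe_mix(f, gate, experts, top_k=3):
    """f: [batch, h]; gate: nn.Linear(h, num_experts); experts: nn.ModuleList."""
    probs = F.softmax(gate(f), dim=-1)              # [batch, num_experts]
    w, idx = torch.topk(probs, top_k, dim=-1)       # [batch, top_k]
    w = w / w.sum(dim=-1, keepdim=True)             # renormalize over the top-k
    e = torch.zeros_like(f)
    for expert_id, expert in enumerate(experts):    # batch samples per expert
        for i in range(top_k):
            routed = idx[:, i] == expert_id
            if routed.any():
                e[routed] += w[routed, i].unsqueeze(1) * expert(f[routed])
    return e                                        # fused feature [batch, h]
```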
### 4.5 Classification Head

- **Design**: a same-width FFN:
  `Linear(512, 512, bias=False) → LayerNorm → GELU → Linear(512, 10019)`
  The output is 10019-dimensional (IDs 0~10018, where 0 is the terminator).
- **Advantages**: preserves the feature dimension, adds non-linearity, and keeps the parameter count moderate.
## 5. Decoding Strategy: Beam Search

- **Search scope**: beam search runs within each slot step, with beam width **k** (e.g. 5).
- **Candidate bookkeeping**: each candidate path keeps its own history slot sequence (the concatenated embeddings) and cumulative probability.
- **Termination**:
  1. all slots are filled (8×3 = 24 steps), or
  2. the highest-probability token of every candidate branch at the current step is **0 (the terminator)**, which forces an exit.
- **Output**: the complete slot sequence with the highest probability.

A decoding sketch follows.
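A minimal single-sample sketch of this loop, driving `forward_single_step` with the hypothetical `build_slot_query` helper from §4.2; the termination rule is simplified to "every beam has emitted the terminator":

```python
import torch

@torch.no_grad()
def beam_search(model, context, beam_width=5, max_steps=24, end_id=0):
    """context: [1, L, h] encoded text for one sample."""
    beams = [([], 0.0)]  # (character IDs so far, cumulative log-prob)
    for _ in range(max_steps):
        candidates = []
        for ids, score in beams:
            if ids and ids[-1] == end_id:        # finished beam: carry forward
                candidates.append((ids, score))
                continue
            hist = torch.tensor([ids], dtype=torch.long, device=context.device)
            query = build_slot_query(hist, model.slot_embedding,
                                     model.slot_position_embedding,
                                     model.start_embedding)
            logits = model.forward_single_step(context, query)  # [1, vocab]
            log_probs = torch.log_softmax(logits, dim=-1).squeeze(0)
            top_lp, top_ids = torch.topk(log_probs, beam_width)
            for lp, tid in zip(top_lp.tolist(), top_ids.tolist()):
                candidates.append((ids + [tid], score + lp))
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
        if all(b[0][-1] == end_id for b in beams):  # all beams terminated
            break
    return beams[0][0]  # highest-probability slot sequence
```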
## 6. Training Setup

- **Optimizer**: AdamW
- **Loss**: per-step CrossEntropyLoss, computed only at non-padding positions
- **Training data**: real user input logs, formed into (context, pinyin, target slot sequence) triples
- **Labels**: the gold character ID of each slot step is the supervision signal

A sketch of one training step follows.
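A sketch of a single optimization step under these settings; `model` is an `InputMethodEngine` instance (defined later in this commit) and the batch field names simply mirror the `forward_train` signature:

```python
import torch
import torch.nn.functional as F

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

def train_step(model, batch):
    logits = model(                      # [b, T, vocab], teacher forcing
        batch["input_ids"], batch["token_type_ids"], batch["attention_mask"],
        slot_target_ids=batch["slot_target_ids"],
        slot_target_mask=batch["slot_target_mask"],
        mode="train",
    )
    b, T, V = logits.shape
    loss = F.cross_entropy(
        logits.view(b * T, V),
        batch["slot_target_ids"].view(b * T),
        reduction="none",
    )
    mask = batch["slot_target_mask"].view(b * T).float()  # 1 = real step
    loss = (loss * mask).sum() / mask.sum()               # skip padded steps
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```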
## 7. Key Design Considerations and Risks

| Design point | Rationale | Risk / optimization |
|--------------|-----------|---------------------|
| Pinyin encoding | treated as plain text for now; simple to implement | alignment may be hard for multi-character words; add a dedicated pinyin embedding layer later |
| Slot updates | concatenation + positional encoding keeps the full history | the sequence grows dynamically (up to 24 steps), but the cross-attention cost stays manageable |
| Expert layer | 20 experts with top-3 mixing improves feature selectivity | the gate needs load balancing so experts do not "die" |
| Classification head | same-width FFN balances capacity and non-linearity | reduce the width or add dropout if it overfits |
| Beam search | per-slot beam search keeps engineering complexity moderate | each candidate path carries its own history state; watch memory usage |
## 8. Future Optimization Directions

- **Pinyin enhancement**: a dedicated pinyin embedding (syllable-level encoding)
- **Slot modeling**: if plain concatenation struggles with long-range dependencies, swap in a lightweight GRU (not adopted for now)
- **Pretraining**: initialize the encoder from open-source Chinese BERT weights
- **Knowledge distillation**: if the model grows too large, distill a smaller version for on-device deployment
---

## Appendix: Key Hyperparameters (TBD)

- Sequence length: 128
- Hidden size: 512
- Transformer layers: 4
- Attention heads: 4
- Number of experts: 20
- Beam width: 5
- Learning rate: to be tuned (suggested 1e-4 ~ 5e-4, with warmup)
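One common way to realize the suggested warmup, sketched with PyTorch's `LambdaLR`; the base learning rate and warmup length are assumptions to be tuned:

```python
import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)  # within 1e-4 ~ 5e-4

warmup_steps = 1000  # assumed; scale with dataset size

def lr_lambda(step):
    # Linear warmup to the base LR, then constant
    return min(1.0, (step + 1) / warmup_steps)

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
# call scheduler.step() after each optimizer.step()
```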
---

This design is structurally complete, with clear interfaces between modules, and is ready for prototyping. Validate the forward pass and training loop on a small dataset first, then scale up to the full data for tuning.
@@ -0,0 +1,6 @@
def main():
    print("Hello from suimemodeltraner!")


if __name__ == "__main__":
    main()
@@ -0,0 +1,37 @@
[project]
name = "suimemodeltraner"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "modelscope>=1.35.1",
    "onnx>=1.20.1",
    "torch>=2.10.0",
]

[tool.uv]
# Default package index for this project
index-url = "https://pypi.tuna.tsinghua.edu.cn/simple"

[dependency-groups]
dev = [
    "autocommit",
    "pytest>=9.0.2",
]

[tool.uv.sources]
autocommit = { git = "https://gitea.winkinshly.site/songsenand/autocommit.git" }

[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

[tool.setuptools]
# Key setting: packages live under src/
package-dir = {"" = "src"}

[tool.setuptools.packages.find]
where = ["src"]

[tool.setuptools.package-data]
model = ["assets/*"]
@@ -0,0 +1,82 @@
import torch
import torch.nn as nn


# ---------------------------- Attention pooling ----------------------------
class AttentionPooling(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.attn = nn.Linear(hidden_size, 1)
        # Three learnable biases: text, pinyin, personalization
        self.bias = nn.Parameter(torch.zeros(3))  # [text_bias, pinyin_bias, user_bias]

    def forward(self, x, mask=None, token_type_ids=None):
        scores = self.attn(x).squeeze(-1)  # [batch, seq_len]
        if token_type_ids is not None:
            # Add the per-segment bias selected by token_type_ids.
            # self.bias has shape [3]; indexing broadcasts it to [batch, seq_len].
            bias_per_token = self.bias[token_type_ids]  # [batch, seq_len]
            scores = scores + bias_per_token
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = torch.softmax(scores, dim=-1)
        pooled = torch.sum(weights.unsqueeze(-1) * x, dim=1)
        return pooled


# ---------------------------- Residual block ----------------------------
class ResidualBlock(nn.Module):
    def __init__(self, dim, dropout_prob=0.3):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.ln1 = nn.LayerNorm(dim)
        self.linear2 = nn.Linear(dim, dim)
        self.ln2 = nn.LayerNorm(dim)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        residual = x
        # Fix: use self.gelu rather than the undefined self.relu
        x = self.gelu(self.linear1(x))
        x = self.ln1(x)
        x = self.linear2(x)
        x = self.ln2(x)
        x = self.dropout(x)
        x = x + residual  # skip connection
        return self.gelu(x)


# ---------------------------- Expert network ----------------------------
class Expert(nn.Module):
    def __init__(
        self,
        input_dim,
        d_model=512,
        num_resblocks=4,
        output_multiplier=1,
        dropout_prob=0.3,
    ):
        super().__init__()
        self.output_dim = input_dim * output_multiplier
        self.linear_in = nn.Linear(input_dim, d_model)
        self.res_blocks = nn.ModuleList(
            [ResidualBlock(d_model, dropout_prob) for _ in range(num_resblocks)]
        )
        self.output = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout_prob),
            nn.Linear(d_model, self.output_dim),
        )

    def forward(self, x):
        x = self.linear_in(x)
        for block in self.res_blocks:
            x = block(x)
        return self.output(x)
@@ -0,0 +1,281 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .components import AttentionPooling, Expert  # , ResidualBlock  # assumed implemented


class InputMethodEngine(nn.Module):
    """
    Input method engine model.
    Input: encoded sequences of the four segments (text before/after the cursor,
    pinyin, history records).
    Output: probability distributions over the slot sequence (up to 24 character IDs).
    """

    def __init__(
        self,
        input_vocab_size: int,   # input-text vocabulary size (incl. special tokens)
        output_vocab_size: int,  # output-character vocabulary size (incl. terminator 0)
        hidden_size: int = 512,  # hidden dimension
        num_layers: int = 4,     # number of Transformer layers
        num_heads: int = 4,      # number of attention heads
        max_slot_steps: int = 24,    # maximum slot steps (8 slots x 3 steps)
        num_experts: int = 20,   # number of experts
        top_k: int = 3,          # experts selected per token
        expert_res_blocks: int = 4,  # residual blocks inside each expert
        dropout: float = 0.3,
        use_attention_pooling: bool = False,  # pool slot features (currently unused)
        **kwargs,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_vocab_size = output_vocab_size
        self.max_slot_steps = max_slot_steps
        self.num_experts = num_experts
        self.top_k = top_k
        self.use_attention_pooling = use_attention_pooling

        # -------------------- 1. Text encoder --------------------
        self.token_embedding = nn.Embedding(input_vocab_size, hidden_size)
        self.position_embedding = nn.Embedding(
            512, hidden_size
        )  # learnable positions; 512 is long enough
        self.token_type_embedding = nn.Embedding(4, hidden_size)  # 4 segment types

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=num_heads,
            dim_feedforward=hidden_size * 4,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # -------------------- 2. Slot components --------------------
        # Slot character embedding (output vocabulary)
        self.slot_embedding = nn.Embedding(output_vocab_size, hidden_size)
        # Slot positional encoding (learnable; max steps + 1 to cover the start token)
        self.slot_position_embedding = nn.Embedding(max_slot_steps + 1, hidden_size)
        # Learnable [START] embedding: placeholder for the empty history at step 0
        self.start_embedding = nn.Parameter(torch.zeros(1, 1, hidden_size))

        # Cross-attention (Query = slot sequence, Key/Value = encoded text)
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        # Optional: attention pooling to aggregate slot-sequence features
        # (only built when use_attention_pooling=True)
        if use_attention_pooling:
            self.attention_pooling = AttentionPooling(hidden_size)

        # -------------------- 3. Gating network + expert layer --------------------
        self.gate = nn.Linear(hidden_size, num_experts)  # logits for expert selection

        # Each expert is an independent residual network (output dim = hidden_size)
        self.experts = nn.ModuleList(
            [
                Expert(
                    input_dim=hidden_size,
                    d_model=hidden_size,
                    num_resblocks=expert_res_blocks,
                    output_multiplier=1,
                    dropout_prob=dropout,
                )
                for _ in range(num_experts)
            ]
        )

        # -------------------- 4. Classification head (same-width FFN) --------------------
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size, bias=False),
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, output_vocab_size),
        )

        # Weight initialization
        self._init_weights()

    def _init_weights(self):
        """Simple weight initialization."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def encode_text(self, input_ids, token_type_ids, attention_mask):
        """
        Encode the input text (before-cursor, pinyin, after-cursor, history).
        Returns: [batch, seq_len, hidden_size] contextual representation.
        """
        seq_len = input_ids.size(1)
        # Positional encoding
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        pos_emb = self.position_embedding(positions)
        # Token-type (segment) encoding
        type_emb = self.token_type_embedding(token_type_ids)
        # Token embedding
        token_emb = self.token_embedding(input_ids)
        # Sum the three embeddings
        x = token_emb + pos_emb + type_emb
        # The Transformer encoder needs a padding mask:
        # attention_mask is [batch, seq_len] with 1 = valid, 0 = padding,
        # while src_key_padding_mask expects True at positions to ignore.
        src_key_padding_mask = attention_mask == 0  # True at padding positions
        x = self.transformer(x, src_key_padding_mask=src_key_padding_mask)
        return x

    def forward_single_step(self, context, slot_seq_emb, slot_seq_mask=None):
        """
        Single decoding step: given the (concatenated) slot-sequence embeddings
        so far, predict the distribution over the next character.
        context: [batch, seq_len, hidden] encoded text
        slot_seq_emb: [batch, current_len, hidden] concatenated slot embeddings
        slot_seq_mask: [batch, current_len] validity mask (1 = valid)
        Returns: [batch, output_vocab_size] logits
        """
        batch_size = slot_seq_emb.size(0)
        # Cross-attention. Because decoding is autoregressive, only the last
        # position of the slot sequence is needed as the query.
        last_query = slot_seq_emb[:, -1:, :]  # [batch, 1, hidden]
        attn_out, _ = self.cross_attention(
            query=last_query,
            key=context,
            value=context,
            key_padding_mask=(
                context.sum(-1) == 0
            ),  # proxy for padding; should really receive the text attention_mask
        )  # [batch, 1, hidden]
        attn_out = attn_out.squeeze(1)  # [batch, hidden]

        # Gating network: pick the top-k experts
        gate_logits = self.gate(attn_out)  # [batch, num_experts]
        topk_weights, topk_indices = torch.topk(
            F.softmax(gate_logits, dim=-1), self.top_k, dim=-1
        )
        # Renormalize the selected weights
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

        # Weighted sum of expert outputs
        expert_outputs = torch.zeros_like(attn_out)  # [batch, hidden]
        for i in range(self.top_k):
            expert_idx = topk_indices[:, i]  # [batch]
            weight = topk_weights[:, i].unsqueeze(1)  # [batch, 1]
            # Simple per-sample loop; could be batched per expert
            # as in forward_train below
            for b in range(batch_size):
                expert_out = self.experts[expert_idx[b]](attn_out[b : b + 1])
                expert_outputs[b] += weight[b] * expert_out.squeeze(0)

        # Classification head
        logits = self.classifier(expert_outputs)  # [batch, output_vocab_size]
        return logits

    def forward_train(
        self,
        input_ids,
        token_type_ids,
        attention_mask,
        slot_target_ids,
        slot_target_mask=None,
    ):
        """
        Training mode: compute predictions for all slot steps in parallel.
        slot_target_ids: [batch, max_slot_steps] gold labels (teacher forcing)
        slot_target_mask: [batch, max_slot_steps] step mask (1 = step exists)
        Returns: per-step logits [batch, max_slot_steps, output_vocab_size]
        """
        batch_size, max_steps = slot_target_ids.shape
        device = input_ids.device

        # 1. Encode the text
        context = self.encode_text(
            input_ids, token_type_ids, attention_mask
        )  # [b, L, h]

        # 2. Build the slot-sequence embeddings (teacher forcing with gold labels).
        # Shift right so the query at step t is the embedding of the *previous*
        # gold character, with [START] at step 0; otherwise each step would see
        # its own target.
        slot_emb = self.slot_embedding(slot_target_ids)  # [b, T, h]
        start = self.start_embedding.expand(batch_size, 1, -1)  # [b, 1, h]
        slot_emb = torch.cat([start, slot_emb[:, :-1, :]], dim=1)  # [b, T, h]
        # Positional encoding (positions start at 0)
        positions = torch.arange(max_steps, device=device).unsqueeze(0)  # [1, T]
        pos_emb = self.slot_position_embedding(positions)  # [1, T, h]
        slot_emb = slot_emb + pos_emb
        # Optional: zero out invalid steps
        if slot_target_mask is not None:
            slot_emb = slot_emb * slot_target_mask.unsqueeze(-1)

        # 3. Cross-attention with the whole slot sequence as Query. Each slot
        # position attends only to the text encoding, not to other slot steps,
        # so no causal mask is needed; Query and Key/Value simply differ.
        attn_out, _ = self.cross_attention(
            query=slot_emb,  # [b, T, h]
            key=context,     # [b, L, h]
            value=context,
            key_padding_mask=(attention_mask == 0),  # ignore text padding
        )  # [b, T, h]

        # 4. Run gating + experts + classifier for every slot step.
        # Parameters are shared across steps, so batch and steps can be merged.
        b, T, h = attn_out.shape
        attn_flat = attn_out.view(b * T, h)  # [b*T, h]

        # Gating network
        gate_logits = self.gate(attn_flat)  # [b*T, num_experts]
        gate_probs = F.softmax(gate_logits, dim=-1)
        topk_probs, topk_indices = torch.topk(
            gate_probs, self.top_k, dim=-1
        )  # [b*T, top_k]
        topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)  # renormalize

        # Expert outputs, batched per expert (full parallelization is more
        # involved; this explicit loop keeps the routing readable)
        expert_out_flat = torch.zeros_like(attn_flat)  # [b*T, h]
        for i in range(self.top_k):
            weight = topk_probs[:, i].unsqueeze(1)  # [b*T, 1]
            idx = topk_indices[:, i]  # [b*T]
            for expert_id in range(self.num_experts):
                mask = idx == expert_id
                if mask.any():
                    # Gather the tokens routed to this expert
                    sub_input = attn_flat[mask]  # [k, h]
                    sub_output = self.experts[expert_id](sub_input)  # [k, h]
                    expert_out_flat[mask] += weight[mask] * sub_output

        # Classification head
        logits_flat = self.classifier(expert_out_flat)  # [b*T, output_vocab_size]
        logits = logits_flat.view(b, T, -1)  # [b, T, output_vocab_size]

        return logits

    def forward(
        self,
        input_ids,
        token_type_ids,
        attention_mask,
        slot_target_ids=None,
        slot_target_mask=None,
        mode="train",
    ):
        """
        Unified entry point.
        mode='train': teacher forcing, parallel logits for all steps, [b, T, vocab]
        mode='infer': drive forward_single_step from an external decoding loop
        """
        if mode == "train":
            return self.forward_train(
                input_ids,
                token_type_ids,
                attention_mask,
                slot_target_ids,
                slot_target_mask,
            )
        else:
            raise NotImplementedError(
                "Inference mode should call forward_single_step directly"
            )