Compare commits: a8ccb3e4fe ... 1d2ae677f9 (2 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 1d2ae677f9 | |
| | b6a677f15d | |
@@ -0,0 +1 @@
3.12

README.md (+97)
@@ -1,2 +1,99 @@

# SUimeModelTraner

> Technical design for a deep-learning input-method engine

## 1. Objective

Design a context-aware input-method engine model. The inputs are the text before the cursor, the pinyin, the text after the cursor, and matching history entries; the output is a candidate sequence of character IDs. Decoding uses beam search to achieve accurate pinyin-to-character conversion.
## 2. Input representation

- **Four text segments**:
  - Text before the cursor
  - Pinyin (e.g. "bangdao")
  - Text after the cursor
  - Matching history entries (e.g. "绑到|邦道|…")
- **Encoding**: a BERT-style tokenizer with a fixed sequence length of **L=128** (or 88).
- **Segment separation**: `token_type_ids` mark the segments with values **0, 1, 2, 3** (one per input type).
- **Pinyin handling**: for now, pinyin is fed in as plain text; a dedicated pinyin-embedding hook is reserved for later optimization.
## 3. Architecture overview

The pipeline is: input encoding → Transformer encoder → slot memory → cross-attention → gating + mixture of experts → classification head → beam-search decoding.

![Model architecture diagram]
## 4. Core module design

### 4.1 Embedding layer and Transformer encoder

- **Embedding dimension**: 512
- **Transformer layers**: 4
- **Attention heads**: 4 or 8
- **Output**: contextual representation `H` of shape `[batch, L, 512]`
- **Pretraining**: optionally load a lightweight structBERT backbone (link no longer valid; currently training from scratch).
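A sketch of this encoder stage, assuming summed token/segment/position embeddings (the design calls for 512 dims, 4 layers, and 4 heads; toy sizes are used here to keep the example light):

```python
import torch
import torch.nn as nn

# Token, segment, and position embeddings are summed, then run through a
# Transformer encoder to produce the contextual representation H of shape
# [batch, L, hidden]. Sizes are toy values, not the 512/4/4 of the spec.
class TextEncoder(nn.Module):
    def __init__(self, vocab_size=100, hidden=64, layers=2, heads=4, max_len=128):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, hidden)
        self.seg = nn.Embedding(4, hidden)  # token_type_ids 0..3
        self.pos = nn.Embedding(max_len, hidden)
        layer = nn.TransformerEncoderLayer(hidden, heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, layers)

    def forward(self, input_ids, token_type_ids, attention_mask):
        pos = torch.arange(input_ids.size(1), device=input_ids.device)
        x = self.tok(input_ids) + self.seg(token_type_ids) + self.pos(pos)
        # Padding positions are excluded from attention via the key-padding mask
        return self.encoder(x, src_key_padding_mask=(attention_mask == 0))

H = TextEncoder()(
    torch.randint(0, 100, (2, 16)),
    torch.zeros(2, 16, dtype=torch.long),
    torch.ones(2, 16, dtype=torch.long),
)  # H: [2, 16, 64]
```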
### 4.2 Slot memory module

- **Slot structure**: 8 slots, each holding up to 3 steps, for at most 24 steps in total.
- **Slot embedding**: the character ID chosen at each step is mapped to a 512-dim vector through the shared embedding layer.
- **History representation**: the embeddings of all steps before the current one are **concatenated in order**, forming a dynamically growing sequence.
- **Positional encoding**: a learnable position embedding is added to the concatenated slot sequence to help the model capture temporal order.
- **Initial state**: at the first step the history is empty, so a special `[START]` embedding serves as a placeholder.
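The bookkeeping above can be sketched like this; sizes are toy values (the design uses 512 dims and up to 24 steps), and the `[START]` vector stands in for the empty history:

```python
import torch
import torch.nn as nn

# Previously chosen character IDs are embedded, concatenated in order,
# and given learnable position embeddings; a [START] placeholder fills
# the empty history at the first step.
class SlotMemory(nn.Module):
    def __init__(self, vocab_size=100, hidden=32, max_steps=24):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.pos = nn.Embedding(max_steps + 1, hidden)
        self.start = nn.Parameter(torch.zeros(hidden))  # [START] placeholder

    def history(self, chosen_ids):
        """chosen_ids: [batch, t] IDs chosen so far; returns [batch, t+1, hidden]."""
        b, t = chosen_ids.shape
        start = self.start.expand(b, 1, -1)
        seq = torch.cat([start, self.embed(chosen_ids)], dim=1)  # [b, t+1, h]
        positions = torch.arange(t + 1, device=chosen_ids.device)
        return seq + self.pos(positions)

mem = SlotMemory()
hist = mem.history(torch.tensor([[5, 7, 9]]))  # three steps chosen so far
```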
### 4.3 Cross-attention and attention pooling

- **Query**: the current slot sequence (after positional encoding)
- **Key/Value**: the Transformer encoder output `H`
- **Output**: a slot-conditioned feature sequence
- **Attention pooling**: the cross-attention output is pooled (mean pooling or learnable attention pooling) into a fixed-length 512-dim feature vector `f`.
### 4.4 Gating network and expert layer (MoE)

- **Gating network**: takes `f`, outputs weights over 20 experts, and selects the **top 3**.
- **Expert structure**: each expert is a residual network with output dimension **512** (matching the hidden size).
- **Expert combination**: the three selected expert outputs are combined by weighted sum into a fused feature `e`.
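The top-k gating logic can be shown in plain Python; the experts here are toy callables standing in for the residual networks, and the scores stand in for the gate's logits:

```python
import math

# Softmax over expert scores, keep the top-k, renormalize the kept weights,
# and take the weighted sum of the selected experts' outputs.
def moe_combine(f, experts, gate_scores, top_k=3):
    exp = [math.exp(s - max(gate_scores)) for s in gate_scores]
    probs = [e / sum(exp) for e in exp]
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    norm = sum(probs[i] for i in top)  # renormalize over the selected experts
    out = [0.0] * len(f)
    for i in top:
        y = experts[i](f)  # toy expert call; real experts are residual nets
        out = [o + (probs[i] / norm) * v for o, v in zip(out, y)]
    return out

# Toy experts: expert k scales its input by k
experts = [lambda x, k=k: [v * k for v in x] for k in range(5)]
e = moe_combine([1.0, 2.0], experts, [0.1, 0.5, 0.2, 0.9, 0.3], top_k=3)
```

A load-balancing loss (as noted in section 7) would be added on top of `probs` in a real implementation.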
### 4.5 Classification head

- **Design**: a same-width FFN:

  `Linear(512, 512, bias=False) → LayerNorm → GELU → Linear(512, 10019)`

  The output has 10019 dimensions (IDs 0–10018, where 0 is the end-of-sequence token).
- **Strengths**: preserves the feature dimension, adds non-linearity, and keeps the parameter count moderate.
## 5. Decoding strategy: beam search

- **Search scope**: beam search runs within each slot step, with beam width **k** (e.g. 5).
- **Candidate state**: each candidate path keeps its own slot-history sequence (the concatenated embeddings) and cumulative probability.
- **Termination**:
  1. all slots are filled (8×3 = 24 steps), or
  2. the most probable token on every candidate branch at the current step is **0 (the end token)**, which forces an exit.
- **Output**: the complete slot sequence with the highest probability.
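The search loop and the forced-exit rule can be sketched in plain Python. `step_log_probs(seq)` is a stand-in for a model call returning next-token log-probabilities; ID 0 is the end token, as in the spec:

```python
import math

# Beam search with the two termination rules above: a step cap, and a forced
# exit when every surviving beam just emitted the end token 0.
def beam_search(step_log_probs, vocab_size, beam_width=5, max_steps=24):
    beams = [([], 0.0)]  # (id sequence, cumulative log-prob)
    for _ in range(max_steps):
        candidates = []
        for seq, score in beams:
            log_probs = step_log_probs(seq)
            for tok in range(vocab_size):
                candidates.append((seq + [tok], score + log_probs[tok]))
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
        # Forced exit: every surviving beam ends with the end token 0
        if all(seq[-1] == 0 for seq, _ in beams):
            break
    return max(beams, key=lambda c: c[1])

# Toy model: prefers token 2 for three steps, then the end token 0
def toy_model(seq):
    if len(seq) < 3:
        return [math.log(p) for p in (0.1, 0.2, 0.7)]
    return [math.log(p) for p in (0.8, 0.1, 0.1)]

best_seq, best_score = beam_search(toy_model, vocab_size=3, beam_width=2)
```

In the real engine each beam would also carry its concatenated slot embeddings, which is the memory-management concern flagged in section 7.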
## 6. Training setup

- **Optimizer**: AdamW
- **Loss**: per-step CrossEntropyLoss (computed only on non-padding positions)
- **Training data**: real user input logs, built into (context, pinyin, target slot sequence) triples
- **Labels**: the true character ID at each slot step is the supervision signal
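The masked per-step loss can be sketched as follows (shapes are toy values; the real vocabulary is 10019):

```python
import torch
import torch.nn.functional as F

# Flatten the [batch, steps] predictions and average cross-entropy
# over valid (non-padding) steps only.
def slot_loss(logits, targets, step_mask):
    """logits: [b, T, V]; targets: [b, T]; step_mask: [b, T], 1 = valid step."""
    b, T, V = logits.shape
    per_step = F.cross_entropy(
        logits.view(b * T, V), targets.view(b * T), reduction="none"
    )  # [b*T]
    mask = step_mask.view(b * T).float()
    return (per_step * mask).sum() / mask.sum().clamp(min=1)

logits = torch.randn(2, 4, 50)
targets = torch.randint(0, 50, (2, 4))
step_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
loss = slot_loss(logits, targets, step_mask)
```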
## 7. Key design considerations and risks

| Design point | Rationale | Risk / mitigation |
|--------|------|----------------|
| Pinyin encoding | Treated as plain text for now; simple to implement | Alignment may be hard for multi-character words; add a dedicated pinyin embedding layer later |
| Slot updates | Concatenation + positional encoding keeps the full history | Sequence length grows dynamically (up to 24); cross-attention cost stays manageable |
| Expert layer | 20 experts with top-3 combination improves feature selectivity | The gate needs load balancing to keep experts from going "dead" |
| Classification head | Same-width FFN balances capacity and non-linearity | Reduce the width or add dropout if it overfits |
| Beam search | In-slot beam search keeps engineering complexity moderate | Each candidate path carries its own history state; watch memory use |
## 8. Future work

- **Pinyin**: add a dedicated pinyin embedding (syllable-level encoding)
- **Slot modeling**: if concatenation struggles with long-range dependencies, swap in a lightweight GRU (not adopted for now)
- **Pretraining**: try initializing the encoder from open-source Chinese BERT weights
- **Distillation**: if the model is too large, distill a smaller version for on-device deployment

---
## Appendix: key hyperparameters (tentative)

- Sequence length: 128
- Hidden size: 512
- Transformer layers: 4
- Attention heads: 4
- Number of experts: 20
- Beam width: 5
- Learning rate: to be tuned (suggested 1e-4 to 5e-4 with warmup)

---
The design is complete and the module interfaces are clear, so prototyping can start immediately. Validate the forward pass and training loop on a small dataset first, then scale up to the full data for tuning.

@@ -0,0 +1,37 @@
[project]
name = "suimemodeltraner"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "modelscope>=1.35.1",
    "onnx>=1.20.1",
    "torch>=2.10.0",
]

[tool.uv]
# Default package index for this project
index-url = "https://pypi.tuna.tsinghua.edu.cn/simple"

[dependency-groups]
dev = [
    "autocommit",
    "pytest>=9.0.2",
]

[tool.uv.sources]
autocommit = { git = "https://gitea.winkinshly.site/songsenand/autocommit.git" }

[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

[tool.setuptools]
# Key setting: packages live under src/
package-dir = {"" = "src"}

[tool.setuptools.packages.find]
where = ["src"]

[tool.setuptools.package-data]
model = ["assets/*"]

@@ -0,0 +1,82 @@
import torch
import torch.nn as nn


# ---------------------------- Attention pooling ----------------------------
class AttentionPooling(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.attn = nn.Linear(hidden_size, 1)
        # Three learnable biases: text, pinyin, personalization
        self.bias = nn.Parameter(torch.zeros(3))  # [text_bias, pinyin_bias, user_bias]

    def forward(self, x, mask=None, token_type_ids=None):
        scores = self.attn(x).squeeze(-1)  # [batch, seq_len]
        if token_type_ids is not None:
            # Add the segment-specific bias to each token's score:
            # self.bias has shape [3]; indexing expands it to [batch, seq_len]
            bias_per_token = self.bias[token_type_ids]  # [batch, seq_len]
            scores = scores + bias_per_token
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = torch.softmax(scores, dim=-1)
        pooled = torch.sum(weights.unsqueeze(-1) * x, dim=1)
        return pooled


# ---------------------------- Residual block ----------------------------
class ResidualBlock(nn.Module):
    def __init__(self, dim, dropout_prob=0.3):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.ln1 = nn.LayerNorm(dim)
        self.linear2 = nn.Linear(dim, dim)
        self.ln2 = nn.LayerNorm(dim)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        residual = x
        x = self.gelu(self.linear1(x))
        x = self.ln1(x)
        x = self.linear2(x)
        x = self.ln2(x)
        x = self.dropout(x)
        x = x + residual
        return self.gelu(x)


# ---------------------------- Expert network ----------------------------
class Expert(nn.Module):
    def __init__(
        self,
        input_dim,
        d_model=512,
        num_resblocks=4,
        output_multiplier=1,
        dropout_prob=0.3,
    ):
        super().__init__()
        self.output_dim = input_dim * output_multiplier
        self.linear_in = nn.Linear(input_dim, d_model)
        self.res_blocks = nn.ModuleList(
            [ResidualBlock(d_model, dropout_prob) for _ in range(num_resblocks)]
        )
        self.output = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout_prob),
            nn.Linear(d_model, self.output_dim),
        )

    def forward(self, x):
        x = self.linear_in(x)
        for block in self.res_blocks:
            x = block(x)
        return self.output(x)

@@ -0,0 +1,236 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .components import AttentionPooling, Expert  # ResidualBlock is used inside Expert


class InputMethodEngine(nn.Module):
    def __init__(
        self,
        pretrained_encoder,  # pretrained encoder, already loaded and resized
        output_vocab_size: int,
        hidden_size: int = 512,  # must match the pretrained model's hidden size
        max_slot_steps: int = 24,
        num_experts: int = 20,
        top_k: int = 3,
        expert_res_blocks: int = 4,
        dropout: float = 0.3,
        use_attention_pooling: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_vocab_size = output_vocab_size
        self.max_slot_steps = max_slot_steps
        self.num_experts = num_experts
        self.top_k = top_k
        self.use_attention_pooling = use_attention_pooling

        # Pretrained encoder
        self.encoder = pretrained_encoder

        # Slot memory components
        self.slot_embedding = nn.Embedding(output_vocab_size, hidden_size)
        self.slot_position_embedding = nn.Embedding(max_slot_steps + 1, hidden_size)
        # Learnable [START] placeholder for the empty history at step one
        self.start_embedding = nn.Parameter(torch.zeros(hidden_size))

        self.cross_attention = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=4,  # configurable
            dropout=dropout,
            batch_first=True,
        )

        if use_attention_pooling:
            self.attention_pooling = AttentionPooling(hidden_size)

        self.gate = nn.Linear(hidden_size, num_experts)
        self.experts = nn.ModuleList(
            [
                Expert(
                    input_dim=hidden_size,
                    d_model=hidden_size,
                    num_resblocks=expert_res_blocks,
                    output_multiplier=1,
                    dropout_prob=dropout,
                )
                for _ in range(num_experts)
            ]
        )

        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size, bias=False),
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, output_vocab_size),
        )

        self._init_weights()

    def _init_weights(self):
        # Initialize only the modules added on top of the pretrained encoder
        nn.init.normal_(self.slot_embedding.weight, std=0.02)
        nn.init.normal_(self.slot_position_embedding.weight, std=0.02)
        nn.init.normal_(self.start_embedding, std=0.02)

    def encode_text(self, input_ids, token_type_ids, attention_mask):
        """
        Encode the text with the pretrained encoder.
        Note: the encoder output may contain both last_hidden_state and pooler_output.
        """
        outputs = self.encoder(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        # Take last_hidden_state [batch, seq_len, hidden]
        return outputs.last_hidden_state

    def forward_single_step(self, context, slot_seq_emb, slot_seq_mask=None):
        """
        Single-step prediction: given the concatenated slot-history embeddings,
        predict the distribution over the next character.
        context: [batch, seq_len, hidden] encoded text
        slot_seq_emb: [batch, current_len, hidden] concatenated slot embeddings
                      (at the first step, pass self.start_embedding as the history)
        slot_seq_mask: [batch, current_len] validity mask (1 = valid)
        returns: [batch, output_vocab_size] logits
        """
        batch_size = slot_seq_emb.size(0)
        # Cross-attention: use only the most recent slot embedding as the query,
        # since the next token is predicted autoregressively from it
        last_query = slot_seq_emb[:, -1:, :]  # [batch, 1, hidden]
        attn_out, _ = self.cross_attention(
            query=last_query,
            key=context,
            value=context,
            # Crude padding heuristic; the real attention_mask should be passed in
            key_padding_mask=(context.sum(-1) == 0),
        )  # [batch, 1, hidden]
        attn_out = attn_out.squeeze(1)  # [batch, hidden]

        # Gating network: pick the top-k experts
        gate_logits = self.gate(attn_out)  # [batch, num_experts]
        topk_weights, topk_indices = torch.topk(
            F.softmax(gate_logits, dim=-1), self.top_k, dim=-1
        )
        # Renormalize the selected weights
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

        # Weighted sum of expert outputs
        expert_outputs = torch.zeros_like(attn_out)  # [batch, hidden]
        for i in range(self.top_k):
            expert_idx = topk_indices[:, i]  # [batch]
            weight = topk_weights[:, i].unsqueeze(1)  # [batch, 1]
            # Per-sample loop for clarity; could be batched with gather
            for b in range(batch_size):
                expert_out = self.experts[expert_idx[b]](attn_out[b : b + 1])
                expert_outputs[b] += weight[b] * expert_out.squeeze(0)

        # Classification head
        logits = self.classifier(expert_outputs)  # [batch, output_vocab_size]
        return logits

    def forward_train(
        self,
        input_ids,
        token_type_ids,
        attention_mask,
        slot_target_ids,
        slot_target_mask=None,
    ):
        """
        Training mode: predict all slot steps in parallel with teacher forcing.
        slot_target_ids: [batch, max_slot_steps] ground-truth labels
        slot_target_mask: [batch, max_slot_steps] step validity mask (1 = step exists)
        returns: per-step logits [batch, max_slot_steps, output_vocab_size]
        """
        batch_size, max_steps = slot_target_ids.shape
        device = input_ids.device

        # 1. Encode the text
        context = self.encode_text(
            input_ids, token_type_ids, attention_mask
        )  # [b, L, h]

        # 2. Build the slot-sequence embeddings (teacher forcing), shifted right:
        #    the query at step t is the embedding of target t-1, with [START] at
        #    t=0, matching forward_single_step at inference time. Without the
        #    shift, each step would see the very token it has to predict.
        slot_emb = self.slot_embedding(slot_target_ids)  # [b, T, h]
        start = self.start_embedding.expand(batch_size, 1, -1)  # [b, 1, h]
        slot_emb = torch.cat([start, slot_emb[:, :-1, :]], dim=1)  # [b, T, h]
        # Learnable positional encoding (positions start at 0)
        positions = torch.arange(max_steps, device=device).unsqueeze(0)  # [1, T]
        pos_emb = self.slot_position_embedding(positions)  # [1, T, h]
        slot_emb = slot_emb + pos_emb
        # Optional: zero out invalid steps
        if slot_target_mask is not None:
            slot_emb = slot_emb * slot_target_mask.unsqueeze(-1)

        # 3. Cross-attention: every slot position queries the text encoding only.
        #    Slot positions do not attend to each other, so no causal mask is
        #    needed; query and key/value simply differ.
        attn_out, _ = self.cross_attention(
            query=slot_emb,  # [b, T, h]
            key=context,  # [b, L, h]
            value=context,
            key_padding_mask=(attention_mask == 0),  # ignore text padding
        )  # [b, T, h]

        # 4. Gate + experts + classifier per slot step.
        #    All steps share parameters, so the batch and step dims can be merged.
        b, T, h = attn_out.shape
        attn_flat = attn_out.view(b * T, h)  # [b*T, h]

        # Gating network
        gate_logits = self.gate(attn_flat)  # [b*T, num_experts]
        gate_probs = F.softmax(gate_logits, dim=-1)
        topk_probs, topk_indices = torch.topk(
            gate_probs, self.top_k, dim=-1
        )  # [b*T, top_k]
        topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)  # renormalize

        # Expert outputs, grouped by expert for clarity
        # (a scatter- or einsum-based version would be faster in production)
        expert_out_flat = torch.zeros_like(attn_flat)  # [b*T, h]
        for i in range(self.top_k):
            weight = topk_probs[:, i].unsqueeze(1)  # [b*T, 1]
            idx = topk_indices[:, i]  # [b*T]
            for expert_id in range(self.num_experts):
                mask = idx == expert_id
                if mask.any():
                    # Route the tokens assigned to this expert through it
                    sub_input = attn_flat[mask]  # [k, h]
                    sub_output = self.experts[expert_id](sub_input)  # [k, h]
                    expert_out_flat[mask] += weight[mask] * sub_output

        # Classification head
        logits_flat = self.classifier(expert_out_flat)  # [b*T, output_vocab_size]
        logits = logits_flat.view(b, T, -1)  # [b, T, output_vocab_size]

        return logits

    def forward(
        self,
        input_ids,
        token_type_ids,
        attention_mask,
        slot_target_ids=None,
        slot_target_mask=None,
        mode="train",
    ):
        """
        Unified entry point.
        mode='train': teacher-forced parallel logits for all steps, [b, T, vocab]
        mode='infer': drive forward_single_step from an external decoding loop
        """
        if mode == "train":
            return self.forward_train(
                input_ids,
                token_type_ids,
                attention_mask,
                slot_target_ids,
                slot_target_mask,
            )
        else:
            raise NotImplementedError(
                "Inference mode should call forward_single_step directly"
            )