Remove the standalone text encoder, switch to a pretrained encoder, and adjust the related parameters
This commit is contained in:
parent b6a677f15d
commit 1d2ae677f9

hello.py: 6 changed lines
hello.py
@@ -1,6 +0,0 @@
-def main():
-    print("Hello from suimemodeltraner!")
-
-
-if __name__ == "__main__":
-    main()
@@ -6,25 +6,17 @@ from .components import AttentionPooling, Expert  # , ResidualBlock  # assuming already ...


 class InputMethodEngine(nn.Module):
     """
     Input-method engine model.
     Input: encoded sequences for the pre-cursor text, post-cursor text, pinyin, and history (four segments).
     Output: probability distribution over the slot sequence (at most 24 character IDs).
     """

     def __init__(
         self,
-        input_vocab_size: int,       # input-text vocabulary size (incl. special symbols)
-        output_vocab_size: int,      # output-character vocabulary size (incl. the end token 0)
-        hidden_size: int = 512,      # hidden dimension
-        num_layers: int = 4,         # number of Transformer layers
-        num_heads: int = 4,          # number of attention heads
-        max_slot_steps: int = 24,    # maximum slot steps (8 slots × 3 steps)
-        num_experts: int = 20,       # number of experts
-        top_k: int = 3,              # experts selected per token
-        expert_res_blocks: int = 4,  # residual blocks inside each expert
+        pretrained_encoder,          # pretrained encoder, already loaded and extended
+        output_vocab_size: int,
+        hidden_size: int = 512,      # must match the pretrained model's hidden dimension
+        max_slot_steps: int = 24,
+        num_experts: int = 20,
+        top_k: int = 3,
+        expert_res_blocks: int = 4,
         dropout: float = 0.3,
-        use_attention_pooling: bool = False,  # whether to pool the slot features afterwards (currently unused)
+        use_attention_pooling: bool = False,
         **kwargs,
     ):
         super().__init__()
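For orientation, a minimal sketch of how the trimmed constructor might now be fed. The checkpoint name, the import path, and the token-type-embedding extension are assumptions; the commit only states that the encoder is passed in already loaded and extended, and that hidden_size must match it.

# Sketch only: checkpoint, import path and extension steps are assumed,
# not taken from this repository.
import torch
from transformers import BertModel

from model import InputMethodEngine  # hypothetical import path

encoder = BertModel.from_pretrained("bert-base-chinese")

# The docstring mentions four segment types; a stock BERT ships with only two
# token-type embeddings, so they would have to be widened to 4 beforehand.
old_tte = encoder.embeddings.token_type_embeddings
new_tte = torch.nn.Embedding(4, old_tte.embedding_dim)
new_tte.weight.data[: old_tte.num_embeddings] = old_tte.weight.data
encoder.embeddings.token_type_embeddings = new_tte
encoder.config.type_vocab_size = 4

engine = InputMethodEngine(
    pretrained_encoder=encoder,
    output_vocab_size=10000,                 # placeholder vocabulary size
    hidden_size=encoder.config.hidden_size,  # 768 for bert-base; must match the encoder
)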
@@ -35,45 +27,24 @@ class InputMethodEngine(nn.Module):
         self.top_k = top_k
         self.use_attention_pooling = use_attention_pooling

-        # -------------------- 1. Text encoder --------------------
-        self.token_embedding = nn.Embedding(input_vocab_size, hidden_size)
-        self.position_embedding = nn.Embedding(
-            512, hidden_size
-        )  # learnable positional encoding; 512 is long enough
-        self.token_type_embedding = nn.Embedding(4, hidden_size)  # 4 segment types
+        # Pretrained encoder
+        self.encoder = pretrained_encoder

-        encoder_layer = nn.TransformerEncoderLayer(
-            d_model=hidden_size,
-            nhead=num_heads,
-            dim_feedforward=hidden_size * 4,
-            dropout=dropout,
-            activation="gelu",
-            batch_first=True,
-        )
-        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
-
         # -------------------- 2. Slot components --------------------
-        # Slot character embedding (output vocabulary)
+        # The remaining components are the same as before
         self.slot_embedding = nn.Embedding(output_vocab_size, hidden_size)
         # Slot position embedding (learnable; max steps + 1 for the start token)
         self.slot_position_embedding = nn.Embedding(max_slot_steps + 1, hidden_size)

         # Cross-attention (Query = slot sequence, Key/Value = encoded text)
         self.cross_attention = nn.MultiheadAttention(
             embed_dim=hidden_size,
-            num_heads=num_heads,
+            num_heads=4,  # configurable
             dropout=dropout,
             batch_first=True,
         )

         # Optional: attention pooling (aggregates slot-sequence features when use_attention_pooling=True)
         if use_attention_pooling:
             self.attention_pooling = AttentionPooling(hidden_size)

         # -------------------- 3. Gating network + expert layer --------------------
-        self.gate = nn.Linear(hidden_size, num_experts)  # logits used to select experts
-
-        # Each expert is an independent residual network (output dim = hidden_size)
+        self.gate = nn.Linear(hidden_size, num_experts)
         self.experts = nn.ModuleList(
             [
                 Expert(
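The forward pass is not part of this diff, so for readers unfamiliar with how gate, experts and top_k usually combine, here is a generic top-k mixture-of-experts routing sketch. It illustrates the standard technique and assumes each Expert maps hidden to hidden; it is not necessarily this model's exact implementation, and moe_route is a hypothetical helper.

import torch
import torch.nn.functional as F

def moe_route(x, gate, experts, top_k):
    """Generic dense top-k MoE routing over features x: [batch, steps, hidden]."""
    logits = gate(x)                               # [batch, steps, num_experts]
    weights, indices = logits.topk(top_k, dim=-1)  # keep the top_k experts per position
    weights = F.softmax(weights, dim=-1)           # renormalise the kept scores
    out = torch.zeros_like(x)
    for e, expert in enumerate(experts):
        # weight this expert receives at each position (0 where it was not selected)
        w = (weights * (indices == e)).sum(dim=-1, keepdim=True)  # [batch, steps, 1]
        if w.any():
            out = out + w * expert(x)              # dense: every expert sees every position
    return out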
@@ -87,7 +58,6 @@ class InputMethodEngine(nn.Module):
             ]
         )

         # -------------------- 4. Classification head (same-dimension FFN) --------------------
         self.classifier = nn.Sequential(
             nn.Linear(hidden_size, hidden_size, bias=False),
             nn.LayerNorm(hidden_size),
@@ -95,36 +65,21 @@
             nn.Linear(hidden_size, output_vocab_size),
         )

         # Initialisation
         self._init_weights()

     def _init_weights(self):
         """Initialise weights (simple example)."""
         for p in self.parameters():
             if p.dim() > 1:
                 nn.init.xavier_uniform_(p)

     def encode_text(self, input_ids, token_type_ids, attention_mask):
         """
-        Encode the input text (pre-cursor text, pinyin, post-cursor text, history).
-        Returns: [batch, seq_len, hidden_size] encoded context representation.
+        Encode the text with the pretrained encoder.
+        Note: the pretrained model's output may contain last_hidden_state and pooler_output.
         """
-        seq_len = input_ids.size(1)
-        # Positional encoding
-        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
-        pos_emb = self.position_embedding(positions)
-        # Token-type encoding
-        type_emb = self.token_type_embedding(token_type_ids)
-        # Token embedding
-        token_emb = self.token_embedding(input_ids)
-        # Sum the three
-        x = token_emb + pos_emb + type_emb
-        # Transformer encoder (needs a padding mask)
-        # attention_mask: [batch, seq_len], 1 = valid, 0 = padding
-        # TransformerEncoder expects a [batch, seq_len] mask where True means the position is ignored
-        src_key_padding_mask = attention_mask == 0  # True at padding positions
-        x = self.transformer(x, src_key_padding_mask=src_key_padding_mask)
-        return x
+        outputs = self.encoder(
+            input_ids=input_ids,
+            token_type_ids=token_type_ids,
+            attention_mask=attention_mask,
+            return_dict=True,
+        )
+        # Take last_hidden_state: [batch, seq_len, hidden]
+        return outputs.last_hidden_state

     def forward_single_step(self, context, slot_seq_emb, slot_seq_mask=None):
         """
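A short usage sketch of the new encode_text, continuing the constructor example above (engine). The tokenizer checkpoint and the way the four segments are mapped to token_type_ids are assumptions, not part of this commit.

import torch
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-chinese")  # assumed checkpoint
batch = tokenizer(["你好世", "拼音输入"], padding=True, return_tensors="pt")

# The model distinguishes 4 segment types (pre-cursor, pinyin, post-cursor, history);
# here every token is simply placed in segment 0 for illustration.
token_type_ids = torch.zeros_like(batch["input_ids"])

context = engine.encode_text(
    input_ids=batch["input_ids"],
    token_type_ids=token_type_ids,
    attention_mask=batch["attention_mask"],
)
print(context.shape)  # [batch, seq_len, hidden_size], i.e. the encoder's last_hidden_state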