移除独立文本编码器,改用预训练编码器并调整相关参数
This commit is contained in:
parent
b6a677f15d
commit
1d2ae677f9
6
hello.py
6
hello.py
|
|
@ -1,6 +0,0 @@
|
||||||
def main():
|
|
||||||
print("Hello from suimemodeltraner!")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|
@ -6,25 +6,17 @@ from .components import AttentionPooling, Expert # , ResidualBlock # 假设已
|
||||||
|
|
||||||
|
|
||||||
class InputMethodEngine(nn.Module):
|
class InputMethodEngine(nn.Module):
|
||||||
"""
|
|
||||||
输入法引擎模型
|
|
||||||
输入:光标前/后文本、拼音、历史记录(四段)的编码序列
|
|
||||||
输出:槽位序列(最多24个文字ID)的概率分布
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
input_vocab_size: int, # 输入文本的词汇表大小(含特殊符号)
|
pretrained_encoder, # 已加载并扩展好的预训练编码器
|
||||||
output_vocab_size: int, # 输出文字的词汇表大小(含终止符0)
|
output_vocab_size: int,
|
||||||
hidden_size: int = 512, # 隐藏层维度
|
hidden_size: int = 512, # 需与预训练模型隐藏维度一致
|
||||||
num_layers: int = 4, # Transformer 层数
|
max_slot_steps: int = 24,
|
||||||
num_heads: int = 4, # 多头注意力头数
|
num_experts: int = 20,
|
||||||
max_slot_steps: int = 24, # 最大槽位步数(8槽×3步)
|
top_k: int = 3,
|
||||||
num_experts: int = 20, # 专家数量
|
expert_res_blocks: int = 4,
|
||||||
top_k: int = 3, # 每个token选择的专家数
|
|
||||||
expert_res_blocks: int = 4, # 每个专家内部的残差块数
|
|
||||||
dropout: float = 0.3,
|
dropout: float = 0.3,
|
||||||
use_attention_pooling: bool = False, # 是否在槽位特征后加池化(暂未使用)
|
use_attention_pooling: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
@ -35,45 +27,24 @@ class InputMethodEngine(nn.Module):
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
self.use_attention_pooling = use_attention_pooling
|
self.use_attention_pooling = use_attention_pooling
|
||||||
|
|
||||||
# -------------------- 1. 文本编码器 --------------------
|
# 预训练编码器
|
||||||
self.token_embedding = nn.Embedding(input_vocab_size, hidden_size)
|
self.encoder = pretrained_encoder
|
||||||
self.position_embedding = nn.Embedding(
|
|
||||||
512, hidden_size
|
|
||||||
) # 可学习位置编码,512足够长
|
|
||||||
self.token_type_embedding = nn.Embedding(4, hidden_size) # 4种段落类型
|
|
||||||
|
|
||||||
encoder_layer = nn.TransformerEncoderLayer(
|
# 其他组件与之前相同
|
||||||
d_model=hidden_size,
|
|
||||||
nhead=num_heads,
|
|
||||||
dim_feedforward=hidden_size * 4,
|
|
||||||
dropout=dropout,
|
|
||||||
activation="gelu",
|
|
||||||
batch_first=True,
|
|
||||||
)
|
|
||||||
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
|
|
||||||
|
|
||||||
# -------------------- 2. 槽位相关组件 --------------------
|
|
||||||
# 槽位文字嵌入(输出词汇表)
|
|
||||||
self.slot_embedding = nn.Embedding(output_vocab_size, hidden_size)
|
self.slot_embedding = nn.Embedding(output_vocab_size, hidden_size)
|
||||||
# 槽位位置编码(可学习,最大步数+1用于起始符)
|
|
||||||
self.slot_position_embedding = nn.Embedding(max_slot_steps + 1, hidden_size)
|
self.slot_position_embedding = nn.Embedding(max_slot_steps + 1, hidden_size)
|
||||||
|
|
||||||
# 交叉注意力(Query = 槽位序列,Key/Value = 文本编码结果)
|
|
||||||
self.cross_attention = nn.MultiheadAttention(
|
self.cross_attention = nn.MultiheadAttention(
|
||||||
embed_dim=hidden_size,
|
embed_dim=hidden_size,
|
||||||
num_heads=num_heads,
|
num_heads=4, # 可配置
|
||||||
dropout=dropout,
|
dropout=dropout,
|
||||||
batch_first=True,
|
batch_first=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 可选:注意力池化(若 use_attention_pooling=True,则用于聚合槽位序列特征)
|
|
||||||
if use_attention_pooling:
|
if use_attention_pooling:
|
||||||
self.attention_pooling = AttentionPooling(hidden_size)
|
self.attention_pooling = AttentionPooling(hidden_size)
|
||||||
|
|
||||||
# -------------------- 3. 门控网络 + 专家层 --------------------
|
self.gate = nn.Linear(hidden_size, num_experts)
|
||||||
self.gate = nn.Linear(hidden_size, num_experts) # 输出 logits,用于选择专家
|
|
||||||
|
|
||||||
# 每个专家是一个独立的残差网络(输出维度=hidden_size)
|
|
||||||
self.experts = nn.ModuleList(
|
self.experts = nn.ModuleList(
|
||||||
[
|
[
|
||||||
Expert(
|
Expert(
|
||||||
|
|
@ -87,7 +58,6 @@ class InputMethodEngine(nn.Module):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
# -------------------- 4. 分类头(同维FFN)--------------------
|
|
||||||
self.classifier = nn.Sequential(
|
self.classifier = nn.Sequential(
|
||||||
nn.Linear(hidden_size, hidden_size, bias=False),
|
nn.Linear(hidden_size, hidden_size, bias=False),
|
||||||
nn.LayerNorm(hidden_size),
|
nn.LayerNorm(hidden_size),
|
||||||
|
|
@ -95,36 +65,21 @@ class InputMethodEngine(nn.Module):
|
||||||
nn.Linear(hidden_size, output_vocab_size),
|
nn.Linear(hidden_size, output_vocab_size),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 初始化
|
|
||||||
self._init_weights()
|
self._init_weights()
|
||||||
|
|
||||||
def _init_weights(self):
|
|
||||||
"""初始化权重(简单示例)"""
|
|
||||||
for p in self.parameters():
|
|
||||||
if p.dim() > 1:
|
|
||||||
nn.init.xavier_uniform_(p)
|
|
||||||
|
|
||||||
def encode_text(self, input_ids, token_type_ids, attention_mask):
|
def encode_text(self, input_ids, token_type_ids, attention_mask):
|
||||||
"""
|
"""
|
||||||
编码输入文本(光标前、拼音、光标后、历史记录)
|
使用预训练编码器编码文本
|
||||||
返回: [batch, seq_len, hidden_size] 编码后的上下文表示
|
注意:预训练模型输出可能包含 last_hidden_state 和 pooler_output
|
||||||
"""
|
"""
|
||||||
seq_len = input_ids.size(1)
|
outputs = self.encoder(
|
||||||
# 位置编码
|
input_ids=input_ids,
|
||||||
positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
|
token_type_ids=token_type_ids,
|
||||||
pos_emb = self.position_embedding(positions)
|
attention_mask=attention_mask,
|
||||||
# token类型编码
|
return_dict=True,
|
||||||
type_emb = self.token_type_embedding(token_type_ids)
|
)
|
||||||
# 词嵌入
|
# 取 last_hidden_state [batch, seq_len, hidden]
|
||||||
token_emb = self.token_embedding(input_ids)
|
return outputs.last_hidden_state
|
||||||
# 相加
|
|
||||||
x = token_emb + pos_emb + type_emb
|
|
||||||
# Transformer编码器(需要padding mask)
|
|
||||||
# attention_mask: [batch, seq_len] 1表示有效,0表示填充
|
|
||||||
# TransformerEncoder 要求 mask 为 [batch, seq_len] 且 1 表示忽略
|
|
||||||
src_key_padding_mask = attention_mask == 0 # True表示填充位置
|
|
||||||
x = self.transformer(x, src_key_padding_mask=src_key_padding_mask)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def forward_single_step(self, context, slot_seq_emb, slot_seq_mask=None):
|
def forward_single_step(self, context, slot_seq_emb, slot_seq_mask=None):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue