Remove the standalone text encoder, switch to a pretrained encoder, and adjust the related parameters

songsenand 2026-03-23 16:32:11 +08:00
parent b6a677f15d
commit 1d2ae677f9
2 changed files with 23 additions and 74 deletions

View File

@@ -1,6 +0,0 @@
-def main():
-    print("Hello from suimemodeltraner!")
-
-
-if __name__ == "__main__":
-    main()

View File

@@ -6,25 +6,17 @@ from .components import AttentionPooling, Expert  # , ResidualBlock  # 假设已
 class InputMethodEngine(nn.Module):
     """
     Input method engine model.
     Input: encoded sequences of the four segments (text before/after the cursor, pinyin, history).
     Output: probability distribution over the slot sequence (at most 24 character IDs).
     """

     def __init__(
         self,
-        input_vocab_size: int,  # input-text vocabulary size (including special symbols)
-        output_vocab_size: int,  # output-character vocabulary size (including terminator 0)
-        hidden_size: int = 512,  # hidden dimension
-        num_layers: int = 4,  # number of Transformer layers
-        num_heads: int = 4,  # number of attention heads
-        max_slot_steps: int = 24,  # maximum slot steps (8 slots x 3 steps)
-        num_experts: int = 20,  # number of experts
-        top_k: int = 3,  # experts selected per token
-        expert_res_blocks: int = 4,  # residual blocks inside each expert
+        pretrained_encoder,  # pretrained encoder, already loaded and extended
+        output_vocab_size: int,
+        hidden_size: int = 512,  # must match the pretrained model's hidden dimension
+        max_slot_steps: int = 24,
+        num_experts: int = 20,
+        top_k: int = 3,
+        expert_res_blocks: int = 4,
         dropout: float = 0.3,
-        use_attention_pooling: bool = False,  # optional pooling over slot features (currently unused)
+        use_attention_pooling: bool = False,
         **kwargs,
     ):
         super().__init__()
@@ -35,45 +27,24 @@ class InputMethodEngine(nn.Module):
         self.top_k = top_k
         self.use_attention_pooling = use_attention_pooling

-        # -------------------- 1. Text encoder --------------------
-        self.token_embedding = nn.Embedding(input_vocab_size, hidden_size)
-        self.position_embedding = nn.Embedding(
-            512, hidden_size
-        )  # learnable positional encoding (512 is long enough)
-        self.token_type_embedding = nn.Embedding(4, hidden_size)  # 4 segment types
-
-        encoder_layer = nn.TransformerEncoderLayer(
-            d_model=hidden_size,
-            nhead=num_heads,
-            dim_feedforward=hidden_size * 4,
-            dropout=dropout,
-            activation="gelu",
-            batch_first=True,
-        )
-        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
-
-        # -------------------- 2. Slot components --------------------
-        # Slot character embedding (output vocabulary)
+        # Pretrained encoder
+        self.encoder = pretrained_encoder
+
+        # The remaining components are unchanged
         self.slot_embedding = nn.Embedding(output_vocab_size, hidden_size)
-        # Slot positional encoding (learnable; max steps + 1 for the start token)
         self.slot_position_embedding = nn.Embedding(max_slot_steps + 1, hidden_size)
-        # Cross-attention (query = slot sequence, key/value = text encoder output)
         self.cross_attention = nn.MultiheadAttention(
             embed_dim=hidden_size,
-            num_heads=num_heads,
+            num_heads=4,  # configurable
             dropout=dropout,
             batch_first=True,
         )
-        # Optional: attention pooling to aggregate slot features when use_attention_pooling=True
         if use_attention_pooling:
             self.attention_pooling = AttentionPooling(hidden_size)

-        # -------------------- 3. Gating network + expert layers --------------------
-        self.gate = nn.Linear(hidden_size, num_experts)  # logits used to select experts
-        # Each expert is an independent residual network (output dim = hidden_size)
+        self.gate = nn.Linear(hidden_size, num_experts)
         self.experts = nn.ModuleList(
             [
                 Expert(
@@ -87,7 +58,6 @@ class InputMethodEngine(nn.Module):
             ]
         )

-        # -------------------- 4. Classification head (same-dimension FFN) --------------------
         self.classifier = nn.Sequential(
             nn.Linear(hidden_size, hidden_size, bias=False),
             nn.LayerNorm(hidden_size),
@@ -95,36 +65,21 @@ class InputMethodEngine(nn.Module):
             nn.Linear(hidden_size, output_vocab_size),
         )

-        # Initialization
         self._init_weights()

-    def _init_weights(self):
-        """Initialize weights (simple example)."""
-        for p in self.parameters():
-            if p.dim() > 1:
-                nn.init.xavier_uniform_(p)
-
     def encode_text(self, input_ids, token_type_ids, attention_mask):
         """
-        Encode the input text (before-cursor text, pinyin, after-cursor text, history).
-        Returns: [batch, seq_len, hidden_size] contextual representation.
+        Encode the text with the pretrained encoder.
+        Note: the pretrained model's output may contain last_hidden_state and pooler_output.
         """
-        seq_len = input_ids.size(1)
-        # Positional encoding
-        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
-        pos_emb = self.position_embedding(positions)
-        # Token-type (segment) encoding
-        type_emb = self.token_type_embedding(token_type_ids)
-        # Token embedding
-        token_emb = self.token_embedding(input_ids)
-        # Sum the three embeddings
-        x = token_emb + pos_emb + type_emb
-        # The Transformer encoder needs a padding mask
-        # attention_mask: [batch, seq_len], 1 = valid token, 0 = padding
-        # TransformerEncoder expects a [batch, seq_len] mask where True marks positions to ignore
-        src_key_padding_mask = attention_mask == 0  # True marks padding positions
-        x = self.transformer(x, src_key_padding_mask=src_key_padding_mask)
-        return x
+        outputs = self.encoder(
+            input_ids=input_ids,
+            token_type_ids=token_type_ids,
+            attention_mask=attention_mask,
+            return_dict=True,
+        )
+        # Take last_hidden_state: [batch, seq_len, hidden]
+        return outputs.last_hidden_state

     def forward_single_step(self, context, slot_seq_emb, slot_seq_mask=None):
         """