From 1d2ae677f985de86a2bba934c324c122bf6a8404 Mon Sep 17 00:00:00 2001
From: songsenand
Date: Mon, 23 Mar 2026 16:32:11 +0800
Subject: [PATCH] Remove the standalone text encoder, switch to a pretrained encoder, and adjust the related parameters
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 hello.py           |  6 ---
 src/model/model.py | 92 ++++++++++++----------------------------------
 2 files changed, 23 insertions(+), 75 deletions(-)
 delete mode 100644 hello.py

diff --git a/hello.py b/hello.py
deleted file mode 100644
index 162518b..0000000
--- a/hello.py
+++ /dev/null
@@ -1,6 +0,0 @@
-def main():
-    print("Hello from suimemodeltraner!")
-
-
-if __name__ == "__main__":
-    main()

diff --git a/src/model/model.py b/src/model/model.py
index 13a0fae..3d55db7 100644
--- a/src/model/model.py
+++ b/src/model/model.py
@@ -6,25 +6,17 @@ from .components import AttentionPooling, Expert  # , ResidualBlock  # assumed
 
 
 class InputMethodEngine(nn.Module):
-    """
-    Input method engine model.
-    Input: encoded sequences for the pre-/post-cursor text, pinyin, and history (four segments)
-    Output: probability distribution over the slot sequence (at most 24 character IDs)
-    """
-
     def __init__(
         self,
-        input_vocab_size: int,  # input-text vocabulary size (incl. special symbols)
-        output_vocab_size: int,  # output-character vocabulary size (incl. terminator 0)
-        hidden_size: int = 512,  # hidden dimension
-        num_layers: int = 4,  # number of Transformer layers
-        num_heads: int = 4,  # number of attention heads
-        max_slot_steps: int = 24,  # maximum slot steps (8 slots x 3 steps)
-        num_experts: int = 20,  # number of experts
-        top_k: int = 3,  # experts selected per token
-        expert_res_blocks: int = 4,  # residual blocks inside each expert
+        pretrained_encoder,  # pretrained encoder, already loaded and extended
+        output_vocab_size: int,
+        hidden_size: int = 512,  # must match the pretrained model's hidden dimension
+        max_slot_steps: int = 24,
+        num_experts: int = 20,
+        top_k: int = 3,
+        expert_res_blocks: int = 4,
         dropout: float = 0.3,
-        use_attention_pooling: bool = False,  # whether to pool the slot features (currently unused)
+        use_attention_pooling: bool = False,
         **kwargs,
     ):
         super().__init__()
@@ -35,45 +27,24 @@ class InputMethodEngine(nn.Module):
         self.top_k = top_k
         self.use_attention_pooling = use_attention_pooling
 
-        # -------------------- 1. Text encoder --------------------
-        self.token_embedding = nn.Embedding(input_vocab_size, hidden_size)
-        self.position_embedding = nn.Embedding(
-            512, hidden_size
-        )  # learnable positional embedding; 512 is long enough
-        self.token_type_embedding = nn.Embedding(4, hidden_size)  # 4 segment types
+        # Pretrained encoder
+        self.encoder = pretrained_encoder
 
-        encoder_layer = nn.TransformerEncoderLayer(
-            d_model=hidden_size,
-            nhead=num_heads,
-            dim_feedforward=hidden_size * 4,
-            dropout=dropout,
-            activation="gelu",
-            batch_first=True,
-        )
-        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
-
-        # -------------------- 2. Slot components --------------------
-        # Slot character embedding (over the output vocabulary)
+        # The remaining components are unchanged
         self.slot_embedding = nn.Embedding(output_vocab_size, hidden_size)
-        # Slot position embedding (learnable; max steps + 1 for the start symbol)
         self.slot_position_embedding = nn.Embedding(max_slot_steps + 1, hidden_size)
-        # Cross-attention (Query = slot sequence, Key/Value = encoded text)
         self.cross_attention = nn.MultiheadAttention(
             embed_dim=hidden_size,
-            num_heads=num_heads,
+            num_heads=4,  # could be made configurable
             dropout=dropout,
             batch_first=True,
         )
-        # Optional: attention pooling (aggregates the slot-sequence features when use_attention_pooling=True)
         if use_attention_pooling:
             self.attention_pooling = AttentionPooling(hidden_size)
 
-        # -------------------- 3. Gating network + expert layer --------------------
-        self.gate = nn.Linear(hidden_size, num_experts)  # logits used to select experts
-
-        # Each expert is an independent residual network (output dim = hidden_size)
+        self.gate = nn.Linear(hidden_size, num_experts)
         self.experts = nn.ModuleList(
             [
                 Expert(
@@ -87,7 +58,6 @@ class InputMethodEngine(nn.Module):
             ]
         )
 
-        # -------------------- 4. Classification head (same-width FFN) --------------------
         self.classifier = nn.Sequential(
             nn.Linear(hidden_size, hidden_size, bias=False),
             nn.LayerNorm(hidden_size),
@@ -95,36 +65,20 @@ class InputMethodEngine(nn.Module):
             nn.Linear(hidden_size, output_vocab_size),
         )
 
-        # Initialization
-        self._init_weights()
-
-    def _init_weights(self):
-        """Initialize weights (simple example)."""
-        for p in self.parameters():
-            if p.dim() > 1:
-                nn.init.xavier_uniform_(p)
-
     def encode_text(self, input_ids, token_type_ids, attention_mask):
         """
-        Encode the input text (pre-cursor text, pinyin, post-cursor text, history).
-        Returns: [batch, seq_len, hidden_size] contextual representation of the input
+        Encode text with the pretrained encoder.
+        Note: the pretrained model's output may contain last_hidden_state and pooler_output.
         """
-        seq_len = input_ids.size(1)
-        # Position embedding
-        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
-        pos_emb = self.position_embedding(positions)
-        # Token-type embedding
-        type_emb = self.token_type_embedding(token_type_ids)
-        # Token embedding
-        token_emb = self.token_embedding(input_ids)
-        # Sum the three embeddings
-        x = token_emb + pos_emb + type_emb
-        # Transformer encoder (needs the padding mask)
-        # attention_mask: [batch, seq_len], 1 = valid, 0 = padding
-        # TransformerEncoder expects a [batch, seq_len] mask where True marks positions to ignore
-        src_key_padding_mask = attention_mask == 0  # True at padding positions
-        x = self.transformer(x, src_key_padding_mask=src_key_padding_mask)
-        return x
+        outputs = self.encoder(
+            input_ids=input_ids,
+            token_type_ids=token_type_ids,
+            attention_mask=attention_mask,
+            return_dict=True,
+        )
+        # Take last_hidden_state: [batch, seq_len, hidden]
+        return outputs.last_hidden_state
 
     def forward_single_step(self, context, slot_seq_emb, slot_seq_mask=None):
         """
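
Usage note (not part of the patch): a minimal sketch of how the
"pretrained encoder, already loaded and extended" argument could be
prepared. It assumes a Hugging Face BERT-style checkpoint; the
checkpoint name, the output vocabulary size, and the token-type resize
below are illustrative assumptions, not code from this repository.

    # Illustrative only; names below are assumptions, not repo code.
    import torch
    from transformers import AutoModel

    from src.model.model import InputMethodEngine

    # Placeholder checkpoint (hidden size 768, type_vocab_size 2).
    encoder = AutoModel.from_pretrained("bert-base-chinese")

    # "Extend" the encoder to the four segment types (pre-cursor text,
    # pinyin, post-cursor text, history): grow token_type_embeddings
    # from 2 rows to 4, keeping the pretrained rows.
    old_tte = encoder.embeddings.token_type_embeddings
    new_tte = torch.nn.Embedding(4, old_tte.embedding_dim)
    new_tte.weight.data[: old_tte.num_embeddings] = old_tte.weight.data
    encoder.embeddings.token_type_embeddings = new_tte
    encoder.config.type_vocab_size = 4

    model = InputMethodEngine(
        pretrained_encoder=encoder,
        output_vocab_size=8000,                  # placeholder
        hidden_size=encoder.config.hidden_size,  # must match, per the patch
    )

With _init_weights() removed, the new heads fall back to PyTorch's
default initialization and the encoder keeps its pretrained weights.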
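
For context on the gate/experts pair this patch keeps: a top-k
mixture-of-experts dispatch over self.gate and self.experts typically
looks like the sketch below. This is an assumption for illustration;
the repository's actual routing lives in the unchanged forward path,
which this diff does not show.

    # Hypothetical top-k MoE routing (e.g. num_experts=20, top_k=3).
    import torch
    import torch.nn.functional as F

    def moe_dispatch(x, gate, experts, top_k):
        # x: [N, hidden]; gate: nn.Linear(hidden, num_experts)
        logits = gate(x)                           # [N, num_experts]
        top_w, top_i = logits.topk(top_k, dim=-1)  # [N, top_k]
        top_w = F.softmax(top_w, dim=-1)           # weights over chosen experts
        out = torch.zeros_like(x)
        for e, expert in enumerate(experts):
            hit = top_i == e                       # [N, top_k] bool
            rows = hit.any(dim=-1)                 # tokens routed to expert e
            if rows.any():
                w = (top_w * hit).sum(dim=-1, keepdim=True)[rows]  # [n_e, 1]
                out[rows] = out[rows] + w * expert(x[rows])
        return out

Each token keeps only its top_k gate logits, softmax-normalizes them,
and sums the weighted outputs of the selected experts.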