SUimeModelTraner/src/model/model.py

144 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
# 导入 components.py 中的组件
from .components import (
ContextEncoder,
CrossAttentionFusion,
MoELayer,
SlotMemory,
)
class InputMethodEngine(nn.Module):
    """Input-method (IME) next-character prediction model.

    A lightweight predictor that chains a context encoder, a slot memory over
    the user's input history, cross-attention fusion, a mixture-of-experts
    (MoE) layer, attention pooling over the slots and a linear classifier.

    Forward inputs:
        input_ids: ``[batch_size, seq_len]`` text token ids.
        token_type_ids: ``[batch_size, seq_len]`` prefix/suffix type ids.
            NOTE: accepted for interface compatibility but currently NOT used
            by ``forward``.
        attention_mask: ``[batch_size, seq_len]``; 1 marks a valid position,
            0 marks padding.
        pinyin_ids: pinyin token ids passed to the context encoder; id 0 is
            treated as padding when building the pinyin mask. (Shape presumably
            ``[batch_size, pinyin_len]`` -- confirm against ``ContextEncoder``.)
        history_slot_ids: ``[batch_size, num_slots]`` or ``[num_slots]``
            history slot ids. The tensor is reshaped to ``[-1, num_slots]``, and
            a single history row is repeated across the batch.

    Output:
        logits: ``[batch_size, vocab_size]`` unnormalized next-character scores
        (no softmax applied).
    """

    def __init__(
        self,
        vocab_size: int = 10019,
        pinyin_vocab_size: int = 28,
        dim: int = 512,
        num_slots: int = 8,
        n_layers: int = 4,
        n_heads: int = 4,
        num_experts: int = 10,
        max_seq_len: int = 128,
        compile: bool = False,
        moe_mode: str = "all",  # "all" / "sparse" / "sparse_allow_graph"
        moe_top_k: int = 3,
        moe_num_resblocks: int = 12,
    ):
        """Build all sub-modules.

        Args:
            vocab_size: number of character classes; passed to the context
                encoder and slot memory and used as the classifier output size.
            pinyin_vocab_size: pinyin vocabulary size for ``ContextEncoder``.
            dim: hidden size shared by every sub-module.
            num_slots: number of history slots (each slot holds one step).
            n_layers: passed to ``ContextEncoder`` as ``n_layers``.
            n_heads: attention heads for the encoder and the cross-attention.
            num_experts: number of experts in ``MoELayer``.
            max_seq_len: passed to ``ContextEncoder`` as ``max_len``.
            compile: if True, wrap ``forward`` with ``torch.compile``
                (mode "reduce-overhead", static shapes).
            moe_mode: MoE mode passed through to ``MoELayer``:
                "all", "sparse" or "sparse_allow_graph".
            moe_top_k: ``top_k`` passed to ``MoELayer`` (defaults to 3, the
                previously hard-coded value).
            moe_num_resblocks: ``num_resblocks`` passed to ``MoELayer``
                (defaults to 12, the previously hard-coded value).
        """
        super().__init__()
        self.dim = dim
        self.num_slots = num_slots
        self.vocab_size = vocab_size

        # 1. Context encoder: encodes the text tokens and pinyin ids.
        self.context_encoder = ContextEncoder(
            vocab_size=vocab_size,
            pinyin_vocab_size=pinyin_vocab_size,
            dim=dim,
            n_layers=n_layers,
            n_heads=n_heads,
            max_len=max_seq_len,
        )

        # 2. Slot memory over the history: num_slots slots, one step per slot
        #    (each slot corresponds to a single entry rather than several steps).
        self.slot_memory = SlotMemory(
            vocab_size=vocab_size,
            max_slots=num_slots,
            steps_per_slot=1,
            dim=dim,
        )

        # 3. Cross-attention fusion of slot states with context/pinyin states.
        self.cross_attn = CrossAttentionFusion(dim=dim, n_heads=n_heads)

        # 4. Mixture-of-experts layer.
        self.moe = MoELayer(
            dim=dim,
            num_experts=num_experts,
            top_k=moe_top_k,
            num_resblocks=moe_num_resblocks,
            moe_mode=moe_mode,
        )

        # 5. Slot attention pooling: one scalar score per slot.
        self.slot_attention = nn.Linear(dim, 1)

        # 6. Classification head over the character vocabulary.
        self.classifier = nn.Linear(dim, vocab_size)

        if compile:
            # NOTE(review): this clears inductor's process-global TritonTemplate
            # registry through a private torch API, which affects every model in
            # the process. Presumably a workaround -- confirm it is still needed
            # on the torch version in use.
            from torch._inductor.select_algorithm import TritonTemplate

            TritonTemplate.all_templates.clear()
            # Store the compiled bound method as an instance attribute so calling
            # the module goes through the compiled function. "reduce-overhead"
            # uses CUDA graphs; dynamic=False specializes on input shapes.
            self.forward = torch.compile(
                self.forward,
                mode="reduce-overhead",
                fullgraph=False,
                dynamic=False,
            )

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pinyin_ids: torch.Tensor,
        history_slot_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Compute next-character logits.

        Args:
            input_ids: ``[batch_size, seq_len]`` text token ids.
            token_type_ids: currently unused; kept for interface compatibility.
            attention_mask: ``[batch_size, seq_len]``; 0 marks padding.
            pinyin_ids: pinyin ids; 0 is treated as padding.
            history_slot_ids: ``[batch_size, num_slots]`` or ``[num_slots]``
                (any tensor whose element count is a multiple of num_slots).

        Returns:
            ``[batch_size, vocab_size]`` logits (no softmax applied).
        """
        batch_size = input_ids.size(0)

        # Normalize the history to [rows, num_slots]. Unlike view(), reshape()
        # also accepts non-contiguous inputs.
        history_slot_ids = history_slot_ids.reshape(-1, self.num_slots)
        # A single shared history ([num_slots] or [1, num_slots]) is repeated
        # across the batch, as documented, instead of relying on downstream
        # broadcasting.
        if history_slot_ids.size(0) == 1 and batch_size != 1:
            history_slot_ids = history_slot_ids.repeat(batch_size, 1)

        H, P = self.context_encoder(input_ids, pinyin_ids, mask=attention_mask)
        S = self.slot_memory(history_slot_ids)

        # Boolean masks that are True at padding positions; presumably
        # CrossAttentionFusion ignores True positions -- confirm.
        context_mask = attention_mask == 0
        pinyin_mask = pinyin_ids == 0
        fused = self.cross_attn(
            S, H, P, context_mask=context_mask, pinyin_mask=pinyin_mask
        )
        moe_out = self.moe(fused)

        # Attention pooling over dim=1: score each slot, softmax-normalize the
        # scores, then take the weighted sum.
        # NOTE(review): assumes moe_out is [batch_size, num_slots, dim]. If the
        # history uses a padding id, padded slots are NOT masked out here.
        slot_scores = self.slot_attention(moe_out).squeeze(-1)
        slot_weights = torch.softmax(slot_scores, dim=1)
        pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1)
        return self.classifier(pooled)