144 lines
4.5 KiB
Python
144 lines
4.5 KiB
Python
from typing import Optional
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
|
||
# 导入 components.py 中的组件
|
||
from .components import (
|
||
ContextEncoder,
|
||
CrossAttentionFusion,
|
||
MoELayer,
|
||
SlotMemory,
|
||
)
|
||
|
||
|
||
class InputMethodEngine(nn.Module):
    """Input-method (IME) next-character prediction model.

    A lightweight model, designed per the project README, that chains a
    context encoder, a slot memory, cross-attention fusion, a
    mixture-of-experts (MoE) layer, attention pooling over the slots and a
    linear classification head.

    Forward inputs:
        input_ids (torch.Tensor): [batch_size, seq_len] text token ids.
        token_type_ids (torch.Tensor): [batch_size, seq_len] segment markers
            (prefix / suffix, ...). Accepted for interface compatibility;
            ``forward`` in this class does not read it.
        attention_mask (torch.Tensor): [batch_size, seq_len] attention mask,
            1 marks a valid position.
        pinyin_ids (torch.Tensor): pinyin token ids; entries equal to 0 are
            treated as padding (presumably [batch_size, pinyin_len] --
            TODO confirm against ContextEncoder).
        history_slot_ids (torch.Tensor): [batch_size, num_slots] or
            [num_slots] history slot ids. A single shared history is
            repeated for every sample in the batch.

    Forward output:
        logits (torch.Tensor): [batch_size, vocab_size] unnormalised scores
            for the next character (no softmax applied).
    """

    def __init__(
        self,
        vocab_size: int = 10019,
        pinyin_vocab_size: int = 28,
        dim: int = 512,
        num_slots: int = 8,
        n_layers: int = 4,
        n_heads: int = 4,
        num_experts: int = 10,
        max_seq_len: int = 128,
        compile: bool = False,
        moe_mode: str = "all",  # "all" / "sparse" / "sparse_allow_graph"
    ):
        """Build all sub-modules.

        Args:
            vocab_size: Size of the character vocabulary; also the width of
                the classifier output.
            pinyin_vocab_size: Size of the pinyin vocabulary.
            dim: Hidden size shared by every sub-module.
            num_slots: Number of history slots.
            n_layers: Layer count forwarded to the context encoder.
            n_heads: Head count for the encoder and the cross-attention.
            num_experts: Number of experts in the MoE layer.
            max_seq_len: Maximum sequence length given to the encoder.
            compile: When True, ``forward`` is wrapped with ``torch.compile``.
                The name shadows the builtin but is kept for backward
                compatibility with existing callers.
            moe_mode: Forwarded unchanged to ``MoELayer``.
        """
        super().__init__()
        self.dim = dim
        self.num_slots = num_slots
        self.vocab_size = vocab_size

        # 1. Context encoder.
        # NOTE(review): an earlier note said that when pinyin is disabled one
        # should pass pinyin_vocab_size=1 and pin the embedding to zero; this
        # constructor has no such switch, so that is up to the caller.
        self.context_encoder = ContextEncoder(
            vocab_size=vocab_size,
            pinyin_vocab_size=pinyin_vocab_size,
            dim=dim,
            n_layers=n_layers,
            n_heads=n_heads,
            max_len=max_seq_len,
        )

        # 2. Slot memory: one step per slot, i.e. each history slot holds a
        # single entry rather than a multi-step sequence.
        self.slot_memory = SlotMemory(
            vocab_size=vocab_size,
            max_slots=num_slots,
            steps_per_slot=1,
            dim=dim,
        )

        # 3. Cross-attention fusion of the slots with the encoder outputs.
        self.cross_attn = CrossAttentionFusion(dim=dim, n_heads=n_heads)

        # 4. Mixture-of-experts layer.
        self.moe = MoELayer(
            dim=dim,
            num_experts=num_experts,
            top_k=3,
            num_resblocks=12,
            moe_mode=moe_mode,
        )

        # 5. Scores each slot so the slots can be attention-pooled.
        self.slot_attention = nn.Linear(dim, 1)

        # 6. Classification head.
        self.classifier = nn.Linear(dim, vocab_size)

        if compile:
            # Clearing the Triton template registry is a process-wide side
            # effect and relies on a private torch API, so tolerate its
            # absence instead of failing model construction.
            try:
                from torch._inductor.select_algorithm import TritonTemplate

                TritonTemplate.all_templates.clear()
            except (ImportError, AttributeError) as exc:
                import warnings

                warnings.warn(
                    f"Could not clear inductor Triton templates: {exc}",
                    RuntimeWarning,
                    stacklevel=2,
                )

            self.forward = torch.compile(
                self.forward,
                mode="reduce-overhead",
                fullgraph=False,
                dynamic=False,
            )

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pinyin_ids: torch.Tensor,
        history_slot_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Run the forward pass and return next-character logits.

        ``token_type_ids`` is accepted but not used here. See the class
        docstring for the remaining arguments.
        """
        batch_size = input_ids.size(0)

        # Normalise to [rows, num_slots]; a bare [num_slots] history becomes
        # a single row, which is then shared by every sample in the batch.
        history_slot_ids = history_slot_ids.view(-1, self.num_slots)
        if history_slot_ids.size(0) == 1 and batch_size > 1:
            history_slot_ids = history_slot_ids.repeat(batch_size, 1)

        H, P = self.context_encoder(input_ids, pinyin_ids, mask=attention_mask)
        S = self.slot_memory(history_slot_ids)

        # Boolean masks where True marks a padded position.
        context_mask = attention_mask == 0
        pinyin_mask = pinyin_ids == 0
        fused = self.cross_attn(
            S, H, P, context_mask=context_mask, pinyin_mask=pinyin_mask
        )

        moe_out = self.moe(fused)

        # Attention pooling over dim 1 (the slot axis, assuming moe_out is
        # [batch_size, num_slots, dim] -- TODO confirm against MoELayer).
        slot_scores = self.slot_attention(moe_out).squeeze(-1)
        slot_weights = torch.softmax(slot_scores, dim=1)
        pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1)

        return self.classifier(pooled)
|