SUimeModelTraner/big_expert.py

30 lines
710 B
Python

import torch
from model.components import MoELayer
from model.model import InputMethodEngine
class BigExpert(InputMethodEngine):
    """InputMethodEngine variant with a larger mixture-of-experts layer.

    After the parent constructor runs, ``self.moe`` is assigned a
    ``MoELayer`` with ``MOE_NUM_EXPERTS`` experts and ``MOE_TOP_K`` routing.
    Any ``torch.compile`` of ``forward`` is done here, after that assignment;
    the parent constructor is always called with ``compile=False``.

    Keyword arguments inspected here (everything else is forwarded as-is):
        compile: if truthy, wrap ``self.forward`` with ``torch.compile`` at
            the end of construction. Consumed here; the parent sees False.
        dim: width passed to ``MoELayer``; it is also forwarded to the parent.
            Falls back to ``DEFAULT_DIM`` when not supplied as a keyword.
    """

    # Settings for the MoE layer built in __init__; subclasses may override.
    DEFAULT_DIM = 512
    MOE_NUM_EXPERTS = 40
    MOE_TOP_K = 3

    def __init__(self, *args, **kw):
        # Take ownership of the compile flag (default False) and force the
        # parent not to compile, so compilation happens only below, after
        # self.moe has been assigned.
        use_compile = kw.pop("compile", False)
        kw["compile"] = False
        super().__init__(*args, **kw)

        # NOTE(review): only a *keyword* ``dim`` is visible here; if callers
        # pass dim positionally through *args, the MoE layer silently uses
        # DEFAULT_DIM instead — confirm against InputMethodEngine's signature.
        dim = kw.get("dim", self.DEFAULT_DIM)
        self.moe = MoELayer(
            dim=dim,
            num_experts=self.MOE_NUM_EXPERTS,
            top_k=self.MOE_TOP_K,
        )

        if use_compile:
            # The compiled bound method is stored as an instance attribute,
            # which takes precedence over the class-level forward for this
            # instance only.
            self.forward = torch.compile(
                self.forward,
                mode="reduce-overhead",
                fullgraph=False,
                dynamic=False,
            )