36 lines
950 B
Python
36 lines
950 B
Python
import torch
|
|
|
|
from model.components import MoELayer
|
|
from model.model import InputMethodEngine
|
|
|
|
|
|
class BigExpert(InputMethodEngine):
    """InputMethodEngine variant that carries an additional MoE layer.

    On top of the parent model this class attaches an ``MoELayer`` with
    40 experts and ``top_k=3`` as ``self.moe`` and, on request, wraps
    ``self.forward`` with ``torch.compile``.
    """

    def __init__(self, *args, **kw):
        """Build the parent model, attach the MoE layer, optionally compile.

        Args:
            *args: Forwarded unchanged to ``InputMethodEngine.__init__``.
            **kw: Forwarded to ``InputMethodEngine.__init__`` with
                ``compile`` overridden to ``False``. Two keys are also
                read here:

                * ``compile`` -- when truthy, ``self.forward`` is wrapped
                  with ``torch.compile`` after ``self.moe`` exists.
                  Defaults to ``False``.
                * ``dim`` -- width handed to ``MoELayer``. Defaults to 512.
        """
        # Remember the caller's request, but always hand compile=False to the
        # parent; compilation (if any) is done below, after self.moe is set.
        # NOTE(review): presumably the parent would otherwise compile its own
        # forward before the MoE layer is attached -- confirm against
        # InputMethodEngine.
        want_compile = kw.pop("compile", False)
        kw["compile"] = False
        super().__init__(*args, **kw)

        # NOTE(review): only a keyword ``dim`` is seen here; a ``dim`` passed
        # positionally through *args would leave the 512 default in effect.
        self.moe = MoELayer(dim=kw.get("dim", 512), num_experts=40, top_k=3)

        if not want_compile:
            return

        # Explicit backend options are used instead of a preset
        # (mode="reduce-overhead" was previously tried and left disabled).
        backend_options = {
            "epilogue_fusion": True,
            "max_autotune": True,
            "triton.cudagraphs": True,
            "reorder_for_compute_comm_overlap": False,
        }
        self.forward = torch.compile(
            self.forward,
            fullgraph=False,
            dynamic=False,
            options=backend_options,
        )
|