import torch

from model.components import MoELayer
from model.model import InputMethodEngine


class BigExpert(InputMethodEngine):
    """``InputMethodEngine`` subclass that attaches a 40-expert, top-3 ``MoELayer``.

    The layer is stored as ``self.moe``. How (or whether) ``forward`` uses it
    is not visible in this class; that depends on the parent implementation.
    """

    def __init__(self, *args, **kw):
        """Build the parent engine, attach the MoE layer, optionally compile.

        Args:
            *args: Forwarded unchanged to ``InputMethodEngine.__init__``.
            **kw: Forwarded to ``InputMethodEngine.__init__`` with ``compile``
                forced to ``False``. Two keys are also read here:

                * ``compile`` (default ``False``): when truthy, the bound
                  ``forward`` is wrapped with ``torch.compile`` once
                  construction is complete.
                * ``dim`` (default ``512``): passed to ``MoELayer`` as its
                  ``dim`` argument.
        """
        # Pull the caller's flag out of kw and always hand the parent
        # compile=False, so compilation is decided in this class only.
        # Presumably this keeps the parent from compiling before ``self.moe``
        # exists -- confirm against InputMethodEngine.
        use_compile = kw.pop("compile", False)
        kw["compile"] = False
        super().__init__(*args, **kw)

        # NOTE(review): only a keyword ``dim`` is seen here; if a caller passes
        # the dimension positionally, the MoE layer silently falls back to 512.
        dim = kw.get("dim", 512)
        self.moe = MoELayer(dim=dim, num_experts=40, top_k=3)

        if use_compile:
            # Replace the bound method on this instance with its compiled form.
            self.forward = torch.compile(
                self.forward,
                mode="reduce-overhead",
                fullgraph=False,
                dynamic=False,
            )