feat(BigExpert): 添加 torch.compile 支持并优化编译参数

This commit is contained in:
songsenand 2026-04-08 00:21:15 +08:00
parent 813dce2224
commit 5dda0e6f85
2 changed files with 32 additions and 6 deletions

View File

@ -1,8 +1,35 @@
import torch
from model.components import MoELayer
from model.model import InputMethodEngine
class BigExpert(InputMethodEngine):
def __init__(self, *args, **kw):
if "compile" in kw:
compile = kw.pop("compile")
else:
compile = False
kw["compile"] = False
super().__init__(*args, **kw)
if "dim" in kw:
dim = kw["dim"]
else:
dim = 512
self.moe = MoELayer(dim=dim, num_experts=40, top_k=3)
if compile:
self.forward = torch.compile(
self.forward,
# mode="reduce-overhead",
fullgraph=False,
dynamic=False,
options={
"epilogue_fusion": True,
"max_autotune": True,
"triton.cudagraphs": True,
"reorder_for_compute_comm_overlap": False,
},
)

View File

@ -82,14 +82,13 @@ class InputMethodEngine(nn.Module):
if compile:
self.forward = torch.compile(
self.forward,
mode="reduce-overhead",
# mode="reduce-overhead",
fullgraph=False,
dynamic=False,
options={
"epilogue_fusion": True,
"max_autotune": True, # 启用自动调优
"max_autotune": True,
"triton.cudagraphs": True,
# 尝试控制归约策略
"reorder_for_compute_comm_overlap": False,
},
)