feat(BigExpert): 添加 torch.compile 支持并优化编译参数
This commit is contained in:
parent
813dce2224
commit
5dda0e6f85
|
|
@ -1,8 +1,35 @@
|
|||
import torch
|
||||
|
||||
from model.components import MoELayer
|
||||
from model.model import InputMethodEngine
|
||||
|
||||
|
||||
class BigExpert(InputMethodEngine):
|
||||
def __init__(self, *args, **kw):
|
||||
if "compile" in kw:
|
||||
compile = kw.pop("compile")
|
||||
|
||||
else:
|
||||
compile = False
|
||||
kw["compile"] = False
|
||||
super().__init__(*args, **kw)
|
||||
if "dim" in kw:
|
||||
dim = kw["dim"]
|
||||
else:
|
||||
dim = 512
|
||||
|
||||
self.moe = MoELayer(dim=dim, num_experts=40, top_k=3)
|
||||
|
||||
if compile:
|
||||
self.forward = torch.compile(
|
||||
self.forward,
|
||||
# mode="reduce-overhead",
|
||||
fullgraph=False,
|
||||
dynamic=False,
|
||||
options={
|
||||
"epilogue_fusion": True,
|
||||
"max_autotune": True,
|
||||
"triton.cudagraphs": True,
|
||||
"reorder_for_compute_comm_overlap": False,
|
||||
},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -82,15 +82,14 @@ class InputMethodEngine(nn.Module):
|
|||
if compile:
|
||||
self.forward = torch.compile(
|
||||
self.forward,
|
||||
mode="reduce-overhead",
|
||||
# mode="reduce-overhead",
|
||||
fullgraph=False,
|
||||
dynamic=False,
|
||||
options={
|
||||
"epilogue_fusion": True,
|
||||
"max_autotune": True, # 启用自动调优
|
||||
"triton.cudagraphs": True,
|
||||
# 尝试控制归约策略
|
||||
"reorder_for_compute_comm_overlap": False,
|
||||
"epilogue_fusion": True,
|
||||
"max_autotune": True,
|
||||
"triton.cudagraphs": True,
|
||||
"reorder_for_compute_comm_overlap": False,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue