fix(model): 移除 torch.compile 的注释和未使用配置
This commit is contained in:
parent
432132a108
commit
483e4d4f98
|
|
@ -87,24 +87,12 @@ class InputMethodEngine(nn.Module):
|
||||||
# 6. 分类头
|
# 6. 分类头
|
||||||
self.classifier = nn.Linear(dim, vocab_size)
|
self.classifier = nn.Linear(dim, vocab_size)
|
||||||
|
|
||||||
# 开启 torch.compile 优化 (如果请求)
|
|
||||||
# 在模型编译时添加优化选项
|
|
||||||
if compile:
|
if compile:
|
||||||
from torch._inductor.select_algorithm import TritonTemplate
|
|
||||||
|
|
||||||
TritonTemplate.all_templates.clear()
|
|
||||||
|
|
||||||
self.forward = torch.compile(
|
self.forward = torch.compile(
|
||||||
self.forward,
|
self.forward,
|
||||||
mode="reduce-overhead",
|
mode="reduce-overhead",
|
||||||
fullgraph=False,
|
fullgraph=False,
|
||||||
dynamic=False,
|
dynamic=False,
|
||||||
# options={
|
|
||||||
# "epilogue_fusion": True,
|
|
||||||
# "max_autotune": True,
|
|
||||||
# "triton.cudagraphs": True,
|
|
||||||
# "reorder_for_compute_comm_overlap": False,
|
|
||||||
# },
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue