fix(model): 移除 torch.compile 的注释和未使用配置
This commit is contained in:
parent
432132a108
commit
483e4d4f98
|
|
@ -87,24 +87,12 @@ class InputMethodEngine(nn.Module):
|
|||
# 6. 分类头
|
||||
self.classifier = nn.Linear(dim, vocab_size)
|
||||
|
||||
# 开启 torch.compile 优化 (如果请求)
|
||||
# 在模型编译时添加优化选项
|
||||
if compile:
|
||||
from torch._inductor.select_algorithm import TritonTemplate
|
||||
|
||||
TritonTemplate.all_templates.clear()
|
||||
|
||||
self.forward = torch.compile(
|
||||
self.forward,
|
||||
mode="reduce-overhead",
|
||||
fullgraph=False,
|
||||
dynamic=False,
|
||||
# options={
|
||||
# "epilogue_fusion": True,
|
||||
# "max_autotune": True,
|
||||
# "triton.cudagraphs": True,
|
||||
# "reorder_for_compute_comm_overlap": False,
|
||||
# },
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
|
|
|||
Loading…
Reference in New Issue