fix(model): 移除 torch.compile 的注释和未使用配置

This commit is contained in:
songsenand 2026-05-10 10:38:14 +08:00
parent 432132a108
commit 483e4d4f98
1 changed files with 0 additions and 12 deletions

View File

@ -87,24 +87,12 @@ class InputMethodEngine(nn.Module):
# 6. 分类头
self.classifier = nn.Linear(dim, vocab_size)
# 开启 torch.compile 优化 (如果请求)
# 在模型编译时添加优化选项
if compile:
from torch._inductor.select_algorithm import TritonTemplate
TritonTemplate.all_templates.clear()
self.forward = torch.compile(
self.forward,
mode="reduce-overhead",
fullgraph=False,
dynamic=False,
# options={
# "epilogue_fusion": True,
# "max_autotune": True,
# "triton.cudagraphs": True,
# "reorder_for_compute_comm_overlap": False,
# },
)
def forward(