diff --git a/src/model/model.py b/src/model/model.py index a5ac2c4..ab81651 100644 --- a/src/model/model.py +++ b/src/model/model.py @@ -87,24 +87,12 @@ class InputMethodEngine(nn.Module): # 6. 分类头 self.classifier = nn.Linear(dim, vocab_size) - # 开启 torch.compile 优化 (如果请求) - # 在模型编译时添加优化选项 if compile: - from torch._inductor.select_algorithm import TritonTemplate - - TritonTemplate.all_templates.clear() - self.forward = torch.compile( self.forward, mode="reduce-overhead", fullgraph=False, dynamic=False, - # options={ - # "epilogue_fusion": True, - # "max_autotune": True, - # "triton.cudagraphs": True, - # "reorder_for_compute_comm_overlap": False, - # }, ) def forward(