From 483e4d4f98c0a93ce54502cc3b4720dc8cf4bf2b Mon Sep 17 00:00:00 2001 From: songsenand Date: Sun, 10 May 2026 10:38:14 +0800 Subject: [PATCH] =?UTF-8?q?fix(model):=20=E7=A7=BB=E9=99=A4=20torch.compil?= =?UTF-8?q?e=20=E7=9A=84=E6=B3=A8=E9=87=8A=E5=92=8C=E6=9C=AA=E4=BD=BF?= =?UTF-8?q?=E7=94=A8=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/model/model.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/model/model.py b/src/model/model.py index a5ac2c4..ab81651 100644 --- a/src/model/model.py +++ b/src/model/model.py @@ -87,24 +87,12 @@ class InputMethodEngine(nn.Module): # 6. 分类头 self.classifier = nn.Linear(dim, vocab_size) - # 开启 torch.compile 优化 (如果请求) - # 在模型编译时添加优化选项 if compile: - from torch._inductor.select_algorithm import TritonTemplate - - TritonTemplate.all_templates.clear() - self.forward = torch.compile( self.forward, mode="reduce-overhead", fullgraph=False, dynamic=False, - # options={ - # "epilogue_fusion": True, - # "max_autotune": True, - # "triton.cudagraphs": True, - # "reorder_for_compute_comm_overlap": False, - # }, ) def forward(