使用 hint 字段替代原始 input_ids 和 attention_mask 进行推理

This commit is contained in:
songsenand 2026-02-13 01:31:18 +08:00
parent bb72b4542b
commit 35e835f618
1 changed files with 2 additions and 2 deletions

View File

@ -230,8 +230,8 @@ class MoEModel(nn.Module):
with torch.no_grad():
for batch in eval_dataloader:
# 移动数据到模型设备
input_ids = batch["input_ids"].to(self.device)
attention_mask = batch["attention_mask"].to(self.device)
input_ids = batch['hint']["input_ids"].to(self.device)
attention_mask = batch['hint']["attention_mask"].to(self.device)
pg = batch["pg"].to(self.device)
labels = batch["char_id"].to(self.device)