diff --git a/src/trainer/model.py b/src/trainer/model.py index 723954b..8183dcc 100644 --- a/src/trainer/model.py +++ b/src/trainer/model.py @@ -230,8 +230,8 @@ class MoEModel(nn.Module): with torch.no_grad(): for batch in eval_dataloader: # 移动数据到模型设备 - input_ids = batch["input_ids"].to(self.device) - attention_mask = batch["attention_mask"].to(self.device) + input_ids = batch['hint']["input_ids"].to(self.device) + attention_mask = batch['hint']["attention_mask"].to(self.device) pg = batch["pg"].to(self.device) labels = batch["char_id"].to(self.device)