使用 hint 字段替代原始 input_ids 和 attention_mask 进行推理
This commit is contained in:
parent
bb72b4542b
commit
35e835f618
|
|
@ -230,8 +230,8 @@ class MoEModel(nn.Module):
|
|||
with torch.no_grad():
|
||||
for batch in eval_dataloader:
|
||||
# 移动数据到模型设备
|
||||
input_ids = batch["input_ids"].to(self.device)
|
||||
attention_mask = batch["attention_mask"].to(self.device)
|
||||
input_ids = batch['hint']["input_ids"].to(self.device)
|
||||
attention_mask = batch['hint']["attention_mask"].to(self.device)
|
||||
pg = batch["pg"].to(self.device)
|
||||
labels = batch["char_id"].to(self.device)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue