使用 hint 字段替代原始 input_ids 和 attention_mask 进行推理
This commit is contained in:
parent
bb72b4542b
commit
35e835f618
|
|
@ -230,8 +230,8 @@ class MoEModel(nn.Module):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch in eval_dataloader:
|
for batch in eval_dataloader:
|
||||||
# 移动数据到模型设备
|
# 移动数据到模型设备
|
||||||
input_ids = batch["input_ids"].to(self.device)
|
input_ids = batch['hint']["input_ids"].to(self.device)
|
||||||
attention_mask = batch["attention_mask"].to(self.device)
|
attention_mask = batch['hint']["attention_mask"].to(self.device)
|
||||||
pg = batch["pg"].to(self.device)
|
pg = batch["pg"].to(self.device)
|
||||||
labels = batch["char_id"].to(self.device)
|
labels = batch["char_id"].to(self.device)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue