test.py+8
This commit is contained in:
parent
504353e895
commit
05e440bcfe
5
test.py
5
test.py
|
|
@ -55,7 +55,7 @@ if len_py < 24:
|
||||||
else:
|
else:
|
||||||
pinyin_ids = pinyin_ids[:24]
|
pinyin_ids = pinyin_ids[:24]
|
||||||
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long)
|
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long)
|
||||||
|
masked_labels = [0] * 8
|
||||||
part3 = "。"
|
part3 = "。"
|
||||||
part4 = ""
|
part4 = ""
|
||||||
|
|
||||||
|
|
@ -82,12 +82,11 @@ sample = {
|
||||||
"pinyin_ids": pinyin_ids,
|
"pinyin_ids": pinyin_ids,
|
||||||
}
|
}
|
||||||
|
|
||||||
dataset = PinyinInputDataset(dataset_path, max_iter_length=max_iter_length)
|
|
||||||
|
|
||||||
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
|
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
|
||||||
|
|
||||||
checkpoint = torch.load("/home/songsenand/下载/best_model.pt", map_location="cpu")
|
checkpoint = torch.load("/home/songsenand/下载/best_model.pt", map_location="cpu")
|
||||||
model.load_state_dict(checkpoint["model_state_dict"])
|
model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
|
|
||||||
input_ids = sample["input_ids"]
|
input_ids = sample["input_ids"]
|
||||||
token_type_ids = sample["token_type_ids"]
|
token_type_ids = sample["token_type_ids"]
|
||||||
attention_mask = sample["attention_mask"]
|
attention_mask = sample["attention_mask"]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue