test.py+8
This commit is contained in:
parent
504353e895
commit
05e440bcfe
5
test.py
5
test.py
|
|
@ -55,7 +55,7 @@ if len_py < 24:
|
|||
else:
|
||||
pinyin_ids = pinyin_ids[:24]
|
||||
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long)
|
||||
|
||||
masked_labels = [0] * 8
|
||||
part3 = "。"
|
||||
part4 = ""
|
||||
|
||||
|
|
@ -82,12 +82,11 @@ sample = {
|
|||
"pinyin_ids": pinyin_ids,
|
||||
}
|
||||
|
||||
dataset = PinyinInputDataset(dataset_path, max_iter_length=max_iter_length)
|
||||
|
||||
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
|
||||
|
||||
checkpoint = torch.load("/home/songsenand/下载/best_model.pt", map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
|
||||
input_ids = sample["input_ids"]
|
||||
token_type_ids = sample["token_type_ids"]
|
||||
attention_mask = sample["attention_mask"]
|
||||
|
|
|
|||
Loading…
Reference in New Issue