test.py+8

This commit is contained in:
songsenand 2026-04-09 17:49:23 +08:00
parent 504353e895
commit 05e440bcfe
1 changed files with 2 additions and 3 deletions

View File

@ -55,7 +55,7 @@ if len_py < 24:
else: else:
pinyin_ids = pinyin_ids[:24] pinyin_ids = pinyin_ids[:24]
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long) pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long)
masked_labels = [0] * 8
part3 = "" part3 = ""
part4 = "" part4 = ""
@ -82,12 +82,11 @@ sample = {
"pinyin_ids": pinyin_ids, "pinyin_ids": pinyin_ids,
} }
dataset = PinyinInputDataset(dataset_path, max_iter_length=max_iter_length)
model = InputMethodEngine(pinyin_vocab_size=30, compile=False) model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
checkpoint = torch.load("/home/songsenand/下载/best_model.pt", map_location="cpu") checkpoint = torch.load("/home/songsenand/下载/best_model.pt", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"]) model.load_state_dict(checkpoint["model_state_dict"])
input_ids = sample["input_ids"] input_ids = sample["input_ids"]
token_type_ids = sample["token_type_ids"] token_type_ids = sample["token_type_ids"]
attention_mask = sample["attention_mask"] attention_mask = sample["attention_mask"]