diff --git a/test.py b/test.py index b71e675..abcd548 100644 --- a/test.py +++ b/test.py @@ -55,7 +55,7 @@ if len_py < 24: else: pinyin_ids = pinyin_ids[:24] pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long) - +masked_labels = [0] * 8 part3 = "。" part4 = "" @@ -82,12 +82,11 @@ sample = { "pinyin_ids": pinyin_ids, } -dataset = PinyinInputDataset(dataset_path, max_iter_length=max_iter_length) - model = InputMethodEngine(pinyin_vocab_size=30, compile=False) checkpoint = torch.load("/home/songsenand/下载/best_model.pt", map_location="cpu") model.load_state_dict(checkpoint["model_state_dict"]) + input_ids = sample["input_ids"] token_type_ids = sample["token_type_ids"] attention_mask = sample["attention_mask"]