From 05e440bcfe896b73cc2c74bd19c2015f6e3a7f74 Mon Sep 17 00:00:00 2001 From: songsenand Date: Thu, 9 Apr 2026 17:49:23 +0800 Subject: [PATCH] test.py+8 --- test.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test.py b/test.py index b71e675..abcd548 100644 --- a/test.py +++ b/test.py @@ -55,7 +55,7 @@ if len_py < 24: else: pinyin_ids = pinyin_ids[:24] pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long) - +masked_labels = [0] * 8 part3 = "。" part4 = "" @@ -82,12 +82,11 @@ sample = { "pinyin_ids": pinyin_ids, } -dataset = PinyinInputDataset(dataset_path, max_iter_length=max_iter_length) - model = InputMethodEngine(pinyin_vocab_size=30, compile=False) checkpoint = torch.load("/home/songsenand/下载/best_model.pt", map_location="cpu") model.load_state_dict(checkpoint["model_state_dict"]) + input_ids = sample["input_ids"] token_type_ids = sample["token_type_ids"] attention_mask = sample["attention_mask"]