diff --git a/test.py b/test.py index 9702402..a1aebfc 100644 --- a/test.py +++ b/test.py @@ -80,11 +80,11 @@ sample = { "pinyin_ids": pinyin_ids, } - model = InputMethodEngine(pinyin_vocab_size=30, compile=False) checkpoint = torch.load("/home/songsenand/下载/best_model.pt", map_location="cpu") model.load_state_dict(checkpoint["model_state_dict"]) + input_ids = sample["input_ids"] token_type_ids = sample["token_type_ids"] attention_mask = sample["attention_mask"]