65 lines
2.3 KiB
Rust
65 lines
2.3 KiB
Rust
use std::path::PathBuf;
|
|
use std::time::Instant;
|
|
|
|
mod model;
|
|
mod tokenizers;
|
|
mod vocabs;
|
|
|
|
use tokenizers::HFTokenizer;
|
|
// use crate::vocabs::Dictionary;
|
|
use crate::model::OnnxModel;
|
|
|
|
fn main() {
|
|
let mut tokenizer_json_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
|
tokenizer_json_path.push("assets");
|
|
tokenizer_json_path.push("tokenizer.json");
|
|
|
|
let mut model_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
|
model_path.push("assets");
|
|
model_path.push("20260228suinput_fp32.onnx");
|
|
|
|
let mut session = OnnxModel::new(model_path, 4).unwrap();
|
|
let tokenizer = HFTokenizer::new(tokenizer_json_path).unwrap();
|
|
let start = Instant::now(); // 开始计时
|
|
let sample = tokenizer.gen_predict_sample("从北京到上", "hai").unwrap();
|
|
let _ = session.predict(sample).unwrap();
|
|
let duration = start.elapsed(); // 结束计时
|
|
print!("predict Time elapsed: {:?}\n", duration);
|
|
|
|
let start = Instant::now(); // 开始计时
|
|
let sample = tokenizer.gen_predict_sample("从北京到上", "hai").unwrap();
|
|
let logits = session
|
|
.predict_to_sorted_pairs_simple(sample, false)
|
|
.unwrap();
|
|
let duration = start.elapsed(); // 结束计时
|
|
print!("predict Time elapsed: {:?}\n", duration);
|
|
|
|
let start = Instant::now(); // 开始计时
|
|
let sample = tokenizer.gen_predict_sample("从北京到上", "hai").unwrap();
|
|
let probs = session
|
|
.predict_to_sorted_pairs_simple(sample, true)
|
|
.unwrap();
|
|
let duration = start.elapsed(); // 结束计时
|
|
print!("predict Time elapsed: {:?}\n", duration);
|
|
print!("logits: {:?}", &logits[0..10].to_vec());
|
|
print!("probs: {:?}", &probs[0..10].to_vec());
|
|
|
|
/*
|
|
if let Ok(tokenizer) = HFTokenizer::new(tokenizer_json_path) {
|
|
println!("Tokenizer loaded successfully");
|
|
if let Ok(sample) = tokenizer.gen_predict_sample("从北京到上", "hai") {
|
|
let logits: ndarray::ArrayBase<
|
|
ndarray::OwnedRepr<f32>,
|
|
ndarray::Dim<ndarray::IxDynImpl>,
|
|
f32,
|
|
> = session.predict(sample).unwrap();
|
|
|
|
let duration = start.elapsed(); // 结束计时
|
|
|
|
println!("Time elapsed: {:?}", duration);
|
|
println!("Model input generated successfully");
|
|
println!("Logits: {:?}", logits);
|
|
}
|
|
}*/
|
|
}
|